package org.gcube.common.keycloak.model; import java.util.Arrays; import java.util.Base64; import java.util.List; import org.gcube.com.fasterxml.jackson.annotation.JsonInclude.Include; import org.gcube.com.fasterxml.jackson.core.JsonProcessingException; import org.gcube.com.fasterxml.jackson.databind.ObjectMapper; import org.gcube.com.fasterxml.jackson.databind.ObjectWriter; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class ModelUtils { protected static final Logger logger = LoggerFactory.getLogger(ModelUtils.class); private static final String ACCOUNT_AUDIENCE_RESOURCE = "account"; private static final ObjectMapper mapper = new ObjectMapper(); static { mapper.setSerializationInclusion(Include.NON_NULL); } public static String toJSONString(Object object) { return toJSONString(object, false); } public static String toJSONString(Object object, boolean prettyPrint) { ObjectWriter writer = prettyPrint ? mapper.writerWithDefaultPrettyPrinter() : mapper.writer(); try { return writer.writeValueAsString(object); } catch (JsonProcessingException e) { logger.error("Cannot pretty print object", e); return null; } } public static String getAccessTokenPayloadJSONStringFrom(TokenResponse tokenResponse) throws Exception { return getAccessTokenPayloadJSONStringFrom(tokenResponse, true); } public static String getAccessTokenPayloadJSONStringFrom(TokenResponse tokenResponse, boolean prettyPrint) throws Exception { return toJSONString(getAccessTokenFrom(tokenResponse, Object.class), prettyPrint); } public static AccessToken getAccessTokenFrom(TokenResponse tokenResponse) throws Exception { return getAccessTokenFrom(tokenResponse, AccessToken.class); } public static AccessToken getAccessTokenFrom(String authorizationHeaderOrBase64EncodedJWT) throws Exception { return getAccessTokenFrom(authorizationHeaderOrBase64EncodedJWT.matches("[b|B]earer ") ? authorizationHeaderOrBase64EncodedJWT.substring("bearer ".length()) : authorizationHeaderOrBase64EncodedJWT, AccessToken.class); } private static T getAccessTokenFrom(TokenResponse tokenResponse, Class clazz) throws Exception { return getAccessTokenFrom(tokenResponse.getAccessToken(), clazz); } private static T getAccessTokenFrom(String accessToken, Class clazz) throws Exception { return mapper.readValue(getDecodedPayload(accessToken), clazz); } public static String getRefreshTokenPayloadStringFrom(TokenResponse tokenResponse) throws Exception { return getRefreshTokenPayloadStringFrom(tokenResponse, true); } public static String getRefreshTokenPayloadStringFrom(TokenResponse tokenResponse, boolean prettyPrint) throws Exception { return toJSONString(getRefreshTokenFrom(tokenResponse, Object.class), prettyPrint); } public static RefreshToken getRefreshTokenFrom(TokenResponse tokenResponse) throws Exception { return getRefreshTokenFrom(tokenResponse.getRefreshToken()); } public static RefreshToken getRefreshTokenFrom(String base64EncodedJWT) throws Exception { return mapper.readValue(getDecodedPayload(base64EncodedJWT), RefreshToken.class); } private static T getRefreshTokenFrom(TokenResponse tokenResponse, Class clazz) throws Exception { return mapper.readValue(getDecodedPayload(tokenResponse.getRefreshToken()), clazz); } protected static byte[] getBase64Decoded(String string) { return Base64.getDecoder().decode(string); } protected static String splitAndGet(String encodedJWT, int index) { String[] split = encodedJWT.split("\\."); if (split.length == 3) { return split[index]; } else { return null; } } public static byte[] getDecodedHeader(String value) { return getBase64Decoded(getEncodedHeader(value)); } public static String getEncodedHeader(String encodedJWT) { return splitAndGet(encodedJWT, 0); } public static byte[] getDecodedPayload(String value) { return getBase64Decoded(getEncodedPayload(value)); } public static String getEncodedPayload(String encodedJWT) { return splitAndGet(encodedJWT, 1); } public static byte[] getDecodedSignature(String value) { return getBase64Decoded(getEncodedSignature(value)); } public static String getEncodedSignature(String encodedJWT) { return splitAndGet(encodedJWT, 2); } public static String getClientIdFromToken(AccessToken accessToken) { String clientId; logger.debug("Client id not provided, using authorized party field (azp)"); clientId = accessToken.getIssuedFor(); if (clientId == null) { logger.warn("Issued for field (azp) not present, getting first of the audience field (aud)"); clientId = getFirstAudienceNoAccount(accessToken); } return clientId; } private static String getFirstAudienceNoAccount(AccessToken accessToken) { // Trying to get it from the token's audience ('aud' field), getting the first except the 'account' List tokenAud = Arrays.asList(accessToken.getAudience()); tokenAud.remove(ACCOUNT_AUDIENCE_RESOURCE); if (tokenAud.size() > 0) { return tokenAud.iterator().next(); } else { // Setting it to empty string to avoid NPE in encoding return ""; } } }