
298 lines
12 KiB

package org.gcube.common.keycloak.model;
import java.util.Arrays;
import java.util.Base64;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.auth0.jwt.JWT;
import com.auth0.jwt.algorithms.Algorithm;
import com.auth0.jwt.exceptions.TokenExpiredException;
import com.auth0.jwt.interfaces.JWTVerifier;
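/**
 * Static helpers for serializing objects to JSON, creating public keys from PEM strings,
 * checking JWT validity and reading the dot-separated segments of an encoded JWT.
 *
 * @author Mauro Mugnaini
 */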
public class ModelUtils {
// Logger shared by all the static helpers of this class.
protected static final Logger logger = LoggerFactory.getLogger(ModelUtils.class);
// The "account" audience value; a comment in getFirstAudienceNoAccount says this entry is meant to be skipped.
private static final String ACCOUNT_AUDIENCE_RESOURCE = "account";
// Single ObjectMapper instance used by the JSON writing and payload reading helpers of this class.
private static final ObjectMapper mapper = new ObjectMapper();
// NOTE(review): the body of this static initializer is not visible here; presumably it configures the mapper - confirm.
static {
public static String toJSONString(Object object) {
return toJSONString(object, false);
public static String toJSONString(Object object, boolean prettyPrint) {
ObjectWriter writer = prettyPrint ? mapper.writerWithDefaultPrettyPrinter() : mapper.writer();
try {
return writer.writeValueAsString(object);
} catch (JsonProcessingException e) {
logger.error("Cannot pretty print object", e);
return null;
* Creates a {@link RSAPublicKey} instance from its string PEM representation
* @param publicKeyPem the public key PEM string
* @return the RSA public key
* @throws Exception if it's not possible to create the RSA public key from the PEM string
public static RSAPublicKey createRSAPublicKey(String publicKeyPem) throws Exception {
return (RSAPublicKey) createPublicKey(publicKeyPem, "RSA");
* Creates a {@link PublicKey} instance from its string PEM representation
* @param publicKeyPem the public key PEM string
* @param algorithm the key type (e.g. RSA)
* @return the public key
* @throws Exception if it's not possible to create the public key from the PEM string
public static PublicKey createPublicKey(String publicKeyPem, String algorithm) throws Exception {
try {
String publicKey = publicKeyPem.replaceFirst("-----BEGIN (.*)-----\n", "");
publicKey = publicKey.replaceFirst("-----END (.*)-----", "");
publicKey = publicKey.replaceAll("\r\n", "");
publicKey = publicKey.replaceAll("\n", "");
byte[] encoded = Base64.getDecoder().decode(publicKey);
KeyFactory kf = KeyFactory.getInstance(algorithm);
return kf.generatePublic(new X509EncodedKeySpec(encoded));
} catch (Exception e) {
throw new RuntimeException("Cannot create public key from PEM string", e);
* Verifies the token validity
* @param token the base64 JWT token string
* @param rsaPublicKey the realm's RSA public key on server
* @return <code>true</code> if the token is valid, <code>false</code> otherwise
* @throws RuntimeException if an error occurs constructing the verifier
public static boolean isValid(String token, RSAPublicKey rsaPublicKey) throws Exception {
return isValid(token, rsaPublicKey, true);
* Verifies the token validity
* @param token the base64 JWT token string
* @param rsaPublicKey the realm's RSA public key on server
* @param checkExpiration if <code>false</code> token expiration check is disabled
* @return <code>true</code> if the token is valid, <code>false</code> otherwise
* @throws RuntimeException if an error occurs constructing the verifier
public static boolean isValid(String token, RSAPublicKey rsaPublicKey, boolean checkExpiration) throws Exception {
try {
return isValid(token, Algorithm.RSA256(rsaPublicKey, null), checkExpiration);
} catch (Exception e) {
throw new RuntimeException("Cannot construct the JWT verifier", e);
* Verifies the token validity
* @param token the base64 JWT token string
* @param publicKey the realm's public key on server
* @param keyAlgorithm the public key algorithm
* @return <code>true</code> if the token is valid, <code>false</code> otherwise
* @throws RuntimeException if an error occurs constructing the verifier
public static boolean isValid(String token, PublicKey publicKey, String keyAlgorithm) throws Exception {
return isValid(token, publicKey, keyAlgorithm, true);
* Verifies the token validity
* @param token the base64 JWT token string
* @param publicKey the realm's public key on server
* @param keyAlgorithm the public key algorithm
* @param checkExpiration if <code>false</code> token expiration check is disabled
* @return <code>true</code> if the token is valid, <code>false</code> otherwise
* @throws RuntimeException if an error occurs constructing the verifier
public static boolean isValid(String token, PublicKey publicKey, String keyAlgorithm, boolean checkExpiration) throws Exception {
try {
Algorithm algorithm = null;
switch (keyAlgorithm) {
case "RS256":
algorithm = Algorithm.RSA256((RSAPublicKey) publicKey, null);
case "RS384":
algorithm = Algorithm.RSA384((RSAPublicKey) publicKey, null);
case "RS512":
algorithm = Algorithm.RSA512((RSAPublicKey) publicKey, null);
throw new RuntimeException("Unsupported key algorithm: " + algorithm);
return isValid(token, algorithm, checkExpiration);
} catch (Exception e) {
throw new RuntimeException("Cannot construct the JWT verifier", e);
* Verifies the token validity
* @param token the base64 JWT token string
* @param algorithm the algorithm to use for verification
* @param checkExpiration if <code>false</code> token expiration check is disabled
* @return <code>true</code> if the token is valid, <code>false</code> otherwise
public static boolean isValid(String token, Algorithm algorithm, boolean checkExpiration) throws Exception {
JWTVerifier verifier = JWT.require(algorithm).build();;
try {
return true;
} catch (TokenExpiredException e) {
// This is OK because expiration check is after the signature validation in the implementation
if (logger.isDebugEnabled()) {
logger.debug("JWT is expired: {}", e.getMessage());
return !checkExpiration;
} catch (Exception e) {
if (logger.isDebugEnabled()) {
logger.debug("JWT is not verified: {}", e.getMessage());
return false;
public static String getAccessTokenPayloadJSONStringFrom(TokenResponse tokenResponse) throws Exception {
return getAccessTokenPayloadJSONStringFrom(tokenResponse, true);
public static String getAccessTokenPayloadJSONStringFrom(TokenResponse tokenResponse, boolean prettyPrint)
throws Exception {
return toJSONString(getAccessTokenFrom(tokenResponse, Object.class), prettyPrint);
public static AccessToken getAccessTokenFrom(TokenResponse tokenResponse) throws Exception {
return getAccessTokenFrom(tokenResponse, AccessToken.class);
public static AccessToken getAccessTokenFrom(String authorizationHeaderOrBase64EncodedJWT) throws Exception {
return getAccessTokenFrom(authorizationHeaderOrBase64EncodedJWT.matches("[b|B]earer ")
? authorizationHeaderOrBase64EncodedJWT.substring("bearer ".length())
: authorizationHeaderOrBase64EncodedJWT, AccessToken.class);
private static <T> T getAccessTokenFrom(TokenResponse tokenResponse, Class<T> clazz) throws Exception {
return getAccessTokenFrom(tokenResponse.getAccessToken(), clazz);
private static <T> T getAccessTokenFrom(String accessToken, Class<T> clazz) throws Exception {
return mapper.readValue(getDecodedPayload(accessToken), clazz);
public static String getRefreshTokenPayloadStringFrom(TokenResponse tokenResponse) throws Exception {
return getRefreshTokenPayloadStringFrom(tokenResponse, true);
public static String getRefreshTokenPayloadStringFrom(TokenResponse tokenResponse, boolean prettyPrint)
throws Exception {
return toJSONString(getRefreshTokenFrom(tokenResponse, Object.class), prettyPrint);
public static RefreshToken getRefreshTokenFrom(TokenResponse tokenResponse) throws Exception {
return getRefreshTokenFrom(tokenResponse.getRefreshToken());
public static RefreshToken getRefreshTokenFrom(String base64EncodedJWT) throws Exception {
return mapper.readValue(getDecodedPayload(base64EncodedJWT), RefreshToken.class);
private static <T> T getRefreshTokenFrom(TokenResponse tokenResponse, Class<T> clazz) throws Exception {
return mapper.readValue(getDecodedPayload(tokenResponse.getRefreshToken()), clazz);
protected static byte[] getBase64Decoded(String string) {
return Base64.getDecoder().decode(string);
protected static String splitAndGet(String encodedJWT, int index) {
String[] split = encodedJWT.split("\\.");
if (split.length == 3) {
return split[index];
} else {
return null;
public static byte[] getDecodedHeader(String value) {
return getBase64Decoded(getEncodedHeader(value));
public static String getEncodedHeader(String encodedJWT) {
return splitAndGet(encodedJWT, 0);
public static byte[] getDecodedPayload(String value) {
return getBase64Decoded(getEncodedPayload(value));
public static String getEncodedPayload(String encodedJWT) {
return splitAndGet(encodedJWT, 1);
public static byte[] getDecodedSignature(String value) {
return getBase64Decoded(getEncodedSignature(value));
public static String getEncodedSignature(String encodedJWT) {
return splitAndGet(encodedJWT, 2);
public static String getClientIdFromToken(AccessToken accessToken) {
String clientId;
logger.debug("Client id not provided, using authorized party field (azp)");
clientId = accessToken.getIssuedFor();
if (clientId == null) {
logger.warn("Issued for field (azp) not present, getting first of the audience field (aud)");
clientId = getFirstAudienceNoAccount(accessToken);
return clientId;
private static String getFirstAudienceNoAccount(AccessToken accessToken) {
// Trying to get it from the token's audience ('aud' field), getting the first except the 'account'
List<String> tokenAud = Arrays.asList(accessToken.getAudience());
if (tokenAud.size() > 0) {
return tokenAud.iterator().next();
} else {
// Setting it to empty string to avoid NPE in encoding
return "";