201 lines
7.6 KiB
Java
201 lines
7.6 KiB
Java
package org.gcube.common.keycloak.model;
|
|
|
|
import java.security.KeyFactory;
|
|
import java.security.interfaces.RSAPublicKey;
|
|
import java.security.spec.X509EncodedKeySpec;
|
|
import java.util.Arrays;
|
|
import java.util.Base64;
|
|
import java.util.List;
|
|
|
|
import org.gcube.com.fasterxml.jackson.annotation.JsonInclude.Include;
|
|
import org.gcube.com.fasterxml.jackson.core.JsonProcessingException;
|
|
import org.gcube.com.fasterxml.jackson.databind.ObjectMapper;
|
|
import org.gcube.com.fasterxml.jackson.databind.ObjectWriter;
|
|
import org.slf4j.Logger;
|
|
import org.slf4j.LoggerFactory;
|
|
|
|
import com.auth0.jwt.JWT;
|
|
import com.auth0.jwt.algorithms.Algorithm;
|
|
import com.auth0.jwt.exceptions.JWTVerificationException;
|
|
import com.auth0.jwt.interfaces.JWTVerifier;
|
|
|
|
/**
|
|
* @author <a href="mailto:mauro.mugnaini@nubisware.com">Mauro Mugnaini</a>
|
|
*/
|
|
public class ModelUtils {
|
|
|
|
protected static final Logger logger = LoggerFactory.getLogger(ModelUtils.class);
|
|
|
|
private static final String ACCOUNT_AUDIENCE_RESOURCE = "account";
|
|
|
|
private static final ObjectMapper mapper = new ObjectMapper();
|
|
|
|
static {
|
|
mapper.setSerializationInclusion(Include.NON_NULL);
|
|
}
|
|
|
|
public static String toJSONString(Object object) {
|
|
return toJSONString(object, false);
|
|
}
|
|
|
|
public static String toJSONString(Object object, boolean prettyPrint) {
|
|
ObjectWriter writer = prettyPrint ? mapper.writerWithDefaultPrettyPrinter() : mapper.writer();
|
|
try {
|
|
return writer.writeValueAsString(object);
|
|
} catch (JsonProcessingException e) {
|
|
logger.error("Cannot pretty print object", e);
|
|
return null;
|
|
}
|
|
}
|
|
|
|
public static RSAPublicKey createRSAPublicKey(String publicKeyPem) {
|
|
try {
|
|
String publicKey = publicKeyPem.replaceFirst("-----BEGIN .+-----\n", "");
|
|
publicKey = publicKey.replaceFirst("-----END .+-----", "");
|
|
|
|
byte[] encoded = Base64.getDecoder().decode(publicKey);
|
|
KeyFactory kf = KeyFactory.getInstance("RSA");
|
|
return (RSAPublicKey) kf.generatePublic(new X509EncodedKeySpec(encoded));
|
|
} catch (Exception e) {
|
|
throw new RuntimeException("Cant' create RSA public key from PEM string", e);
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Verifies the token's digital signature
|
|
*
|
|
* @param token the base64 JWT token string
|
|
* @param publicKey the realm's public key on server
|
|
* @return <code>true</code> if the signature is verified, <code>false</code> otherwise
|
|
* @throws RuntimeException if an error occurs constructing the digital signature verifier
|
|
*/
|
|
public static boolean isSignatureValid(String token, RSAPublicKey publicKey) throws RuntimeException {
|
|
JWTVerifier verifier = null;
|
|
try {
|
|
Algorithm algorithm = Algorithm.RSA256(publicKey, null);
|
|
verifier = JWT.require(algorithm).build();
|
|
} catch (Exception e) {
|
|
throw new RuntimeException("Cannot construct the JWT digital signature verifier", e);
|
|
}
|
|
try {
|
|
verifier.verify(token);
|
|
return true;
|
|
} catch (JWTVerificationException e) {
|
|
if (logger.isDebugEnabled()) {
|
|
logger.debug("JWT digital signature is not verified", e);
|
|
}
|
|
return false;
|
|
}
|
|
}
|
|
|
|
public static String getAccessTokenPayloadJSONStringFrom(TokenResponse tokenResponse) throws Exception {
|
|
return getAccessTokenPayloadJSONStringFrom(tokenResponse, true);
|
|
}
|
|
|
|
public static String getAccessTokenPayloadJSONStringFrom(TokenResponse tokenResponse, boolean prettyPrint)
|
|
throws Exception {
|
|
|
|
return toJSONString(getAccessTokenFrom(tokenResponse, Object.class), prettyPrint);
|
|
}
|
|
|
|
public static AccessToken getAccessTokenFrom(TokenResponse tokenResponse) throws Exception {
|
|
return getAccessTokenFrom(tokenResponse, AccessToken.class);
|
|
}
|
|
|
|
public static AccessToken getAccessTokenFrom(String authorizationHeaderOrBase64EncodedJWT) throws Exception {
|
|
return getAccessTokenFrom(authorizationHeaderOrBase64EncodedJWT.matches("[b|B]earer ")
|
|
? authorizationHeaderOrBase64EncodedJWT.substring("bearer ".length())
|
|
: authorizationHeaderOrBase64EncodedJWT, AccessToken.class);
|
|
}
|
|
|
|
private static <T> T getAccessTokenFrom(TokenResponse tokenResponse, Class<T> clazz) throws Exception {
|
|
return getAccessTokenFrom(tokenResponse.getAccessToken(), clazz);
|
|
}
|
|
|
|
private static <T> T getAccessTokenFrom(String accessToken, Class<T> clazz) throws Exception {
|
|
return mapper.readValue(getDecodedPayload(accessToken), clazz);
|
|
}
|
|
|
|
public static String getRefreshTokenPayloadStringFrom(TokenResponse tokenResponse) throws Exception {
|
|
return getRefreshTokenPayloadStringFrom(tokenResponse, true);
|
|
}
|
|
|
|
public static String getRefreshTokenPayloadStringFrom(TokenResponse tokenResponse, boolean prettyPrint)
|
|
throws Exception {
|
|
|
|
return toJSONString(getRefreshTokenFrom(tokenResponse, Object.class), prettyPrint);
|
|
}
|
|
|
|
public static RefreshToken getRefreshTokenFrom(TokenResponse tokenResponse) throws Exception {
|
|
return getRefreshTokenFrom(tokenResponse.getRefreshToken());
|
|
}
|
|
|
|
public static RefreshToken getRefreshTokenFrom(String base64EncodedJWT) throws Exception {
|
|
return mapper.readValue(getDecodedPayload(base64EncodedJWT), RefreshToken.class);
|
|
}
|
|
|
|
private static <T> T getRefreshTokenFrom(TokenResponse tokenResponse, Class<T> clazz) throws Exception {
|
|
return mapper.readValue(getDecodedPayload(tokenResponse.getRefreshToken()), clazz);
|
|
}
|
|
|
|
protected static byte[] getBase64Decoded(String string) {
|
|
return Base64.getDecoder().decode(string);
|
|
}
|
|
|
|
protected static String splitAndGet(String encodedJWT, int index) {
|
|
String[] split = encodedJWT.split("\\.");
|
|
if (split.length == 3) {
|
|
return split[index];
|
|
} else {
|
|
return null;
|
|
}
|
|
}
|
|
|
|
public static byte[] getDecodedHeader(String value) {
|
|
return getBase64Decoded(getEncodedHeader(value));
|
|
}
|
|
|
|
public static String getEncodedHeader(String encodedJWT) {
|
|
return splitAndGet(encodedJWT, 0);
|
|
}
|
|
|
|
public static byte[] getDecodedPayload(String value) {
|
|
return getBase64Decoded(getEncodedPayload(value));
|
|
}
|
|
|
|
public static String getEncodedPayload(String encodedJWT) {
|
|
return splitAndGet(encodedJWT, 1);
|
|
}
|
|
|
|
public static byte[] getDecodedSignature(String value) {
|
|
return getBase64Decoded(getEncodedSignature(value));
|
|
}
|
|
|
|
public static String getEncodedSignature(String encodedJWT) {
|
|
return splitAndGet(encodedJWT, 2);
|
|
}
|
|
|
|
public static String getClientIdFromToken(AccessToken accessToken) {
|
|
String clientId;
|
|
logger.debug("Client id not provided, using authorized party field (azp)");
|
|
clientId = accessToken.getIssuedFor();
|
|
if (clientId == null) {
|
|
logger.warn("Issued for field (azp) not present, getting first of the audience field (aud)");
|
|
clientId = getFirstAudienceNoAccount(accessToken);
|
|
}
|
|
return clientId;
|
|
}
|
|
|
|
private static String getFirstAudienceNoAccount(AccessToken accessToken) {
|
|
// Trying to get it from the token's audience ('aud' field), getting the first except the 'account'
|
|
List<String> tokenAud = Arrays.asList(accessToken.getAudience());
|
|
tokenAud.remove(ACCOUNT_AUDIENCE_RESOURCE);
|
|
if (tokenAud.size() > 0) {
|
|
return tokenAud.iterator().next();
|
|
} else {
|
|
// Setting it to empty string to avoid NPE in encoding
|
|
return "";
|
|
}
|
|
}
|
|
}
|