argos/dmp-backend/web/src/main/java/eu/eudat/logic/security/validators/configurableProvider/Saml2SSOUtils.java

688 lines
32 KiB
Java

package eu.eudat.logic.security.validators.configurableProvider;
import eu.eudat.logic.security.customproviders.ConfigurableProvider.entities.saml2.Saml2ConfigurableProvider;
import jakarta.xml.soap.*;
import net.shibboleth.utilities.java.support.component.ComponentInitializationException;
import net.shibboleth.utilities.java.support.resolver.CriteriaSet;
import net.shibboleth.utilities.java.support.xml.BasicParserPool;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.http.HttpResponse;
import org.apache.http.client.HttpClient;
import org.apache.http.client.ResponseHandler;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.conn.ssl.SSLConnectionSocketFactory;
import org.apache.http.conn.ssl.TrustSelfSignedStrategy;
import org.apache.http.entity.ContentType;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.BasicResponseHandler;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.ssl.SSLContextBuilder;
import org.apache.xml.security.c14n.Canonicalizer;
import org.apache.xml.security.signature.XMLSignature;
import org.opensaml.core.config.ConfigurationService;
import org.opensaml.core.config.InitializationException;
import org.opensaml.core.config.InitializationService;
import org.opensaml.core.criterion.EntityIdCriterion;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.core.xml.XMLObjectBuilder;
import org.opensaml.core.xml.config.XMLObjectProviderRegistry;
import org.opensaml.core.xml.io.*;
import org.opensaml.core.xml.schema.*;
import org.opensaml.saml.common.SAMLObject;
import org.opensaml.saml.common.SAMLVersion;
import org.opensaml.saml.common.xml.SAMLConstants;
import org.opensaml.saml.criterion.EntityRoleCriterion;
import org.opensaml.saml.criterion.ProtocolCriterion;
import org.opensaml.saml.metadata.resolver.impl.HTTPMetadataResolver;
import org.opensaml.saml.metadata.resolver.impl.PredicateRoleDescriptorResolver;
import org.opensaml.saml.saml2.core.*;
import org.opensaml.saml.saml2.encryption.Decrypter;
import org.opensaml.saml.saml2.metadata.IDPSSODescriptor;
import org.opensaml.saml.security.impl.MetadataCredentialResolver;
import org.opensaml.security.credential.Credential;
import org.opensaml.security.credential.CredentialSupport;
import org.opensaml.security.credential.UsageType;
import org.opensaml.security.criteria.UsageCriterion;
import org.opensaml.security.x509.BasicX509Credential;
import org.opensaml.security.x509.X509Credential;
import org.opensaml.security.x509.impl.KeyStoreX509CredentialAdapter;
import org.opensaml.soap.soap11.Body;
import org.opensaml.soap.soap11.Envelope;
import org.opensaml.xml.util.Base64;
import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap;
import org.opensaml.xmlsec.encryption.EncryptedKey;
import org.opensaml.xmlsec.keyinfo.KeyInfoCredentialResolver;
import org.opensaml.xmlsec.keyinfo.impl.StaticKeyInfoCredentialResolver;
import org.opensaml.xmlsec.signature.KeyInfo;
import org.opensaml.xmlsec.signature.Signature;
import org.opensaml.xmlsec.signature.X509Data;
import org.opensaml.xmlsec.signature.support.SignatureValidator;
import org.opensaml.xmlsec.signature.support.Signer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.bootstrap.DOMImplementationRegistry;
import org.w3c.dom.ls.DOMImplementationLS;
import org.w3c.dom.ls.LSOutput;
import org.w3c.dom.ls.LSSerializer;
import org.xml.sax.SAXException;
import javax.crypto.SecretKey;
import javax.xml.namespace.QName;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import javax.xml.transform.TransformerException;
import javax.xml.transform.TransformerFactory;
import javax.xml.transform.dom.DOMSource;
import javax.xml.transform.stream.StreamResult;
import java.io.*;
import java.net.UnknownHostException;
import java.nio.charset.StandardCharsets;
import java.security.*;
import java.security.cert.CertificateEncodingException;
import java.security.cert.X509Certificate;
import java.time.Instant;
import java.util.*;
import java.util.stream.Collectors;
import java.util.zip.Inflater;
import java.util.zip.InflaterInputStream;
/**
 * Static helpers for the SAML 2.0 Web SSO profile used by the configurable SAML2 provider.
 *
 * <p>Covers OpenSAML bootstrapping, XML (un)marshalling, the HTTP-Artifact binding (building,
 * signing and sending an {@code ArtifactResolve} over SOAP) and validation of the resulting
 * {@code Response}/{@code Assertion}: decryption, issuer, subject, audience, validity window and
 * signatures.
 */
public class Saml2SSOUtils {

    private static final Logger logger = LoggerFactory.getLogger(Saml2SSOUtils.class);

    /**
     * Clock difference tolerated between this service provider and the IdP when the assertion's
     * NotBefore / NotOnOrAfter conditions are checked.
     */
    private static final java.time.Duration ALLOWED_CLOCK_SKEW = java.time.Duration.ofMinutes(3);

    // Lazily initialised OpenSAML state; written only inside the synchronized doBootstrap().
    private static boolean isBootStrapped = false;
    private static BasicParserPool parserPool;
    private static XMLObjectProviderRegistry registry;

    private Saml2SSOUtils() {
    }

    /**
     * Initialises OpenSAML once per JVM.
     *
     * <p>Synchronized so that concurrent logins cannot initialise the static state twice. If
     * initialisation fails the flag stays unset, so the next call retries instead of running on a
     * half-initialised library.
     *
     * @throws Exception if OpenSAML or the parser pool cannot be initialised
     */
    private static synchronized void doBootstrap() throws Exception {
        if (!isBootStrapped) {
            try {
                bootstrap();
                isBootStrapped = true;
            } catch (Exception e) {
                throw new Exception("Error in bootstrapping the OpenSAML2 library", e);
            }
        }
    }

    /**
     * Builds a hardened parser pool (no DTDs, no external entities, no XInclude), registers a
     * fresh {@link XMLObjectProviderRegistry} and runs the OpenSAML initializers. The static fields
     * are published only after every step succeeds.
     */
    private static void bootstrap() throws ComponentInitializationException, InitializationException {
        BasicParserPool pool = new BasicParserPool();
        pool.setMaxPoolSize(100);
        pool.setCoalescing(true);
        pool.setIgnoreComments(true);
        pool.setIgnoreElementContentWhitespace(true);
        pool.setNamespaceAware(true);
        pool.setExpandEntityReferences(false);
        pool.setXincludeAware(false);
        final Map<String, Boolean> features = new HashMap<>();
        features.put("http://xml.org/sax/features/external-general-entities", Boolean.FALSE);
        features.put("http://xml.org/sax/features/external-parameter-entities", Boolean.FALSE);
        features.put("http://apache.org/xml/features/disallow-doctype-decl", Boolean.TRUE);
        features.put("http://apache.org/xml/features/validation/schema/normalized-value", Boolean.FALSE);
        features.put("http://javax.xml.XMLConstants/feature/secure-processing", Boolean.TRUE);
        pool.setBuilderFeatures(features);
        pool.setBuilderAttributes(new HashMap<String, Object>());
        pool.initialize();

        XMLObjectProviderRegistry providerRegistry = new XMLObjectProviderRegistry();
        ConfigurationService.register(XMLObjectProviderRegistry.class, providerRegistry);
        providerRegistry.setParserPool(pool);
        InitializationService.initialize();

        parserPool = pool;
        registry = providerRegistry;
    }

    /**
     * Creates an empty SAML/XML object for the given element name using the registered builders.
     *
     * @throws Exception if no builder is registered for {@code objectQName}
     */
    private static XMLObject buildXMLObject(QName objectQName) throws Exception {
        doBootstrap();
        XMLObjectBuilder<?> builder = registry.getBuilderFactory().getBuilder(objectQName);
        if (builder == null) {
            throw new Exception("Unable to retrieve builder for object QName " + objectQName);
        }
        return builder.buildObject(objectQName.getNamespaceURI(), objectQName.getLocalPart(), objectQName.getPrefix());
    }

    /**
     * Returns the attribute's key according to the provider's naming preference.
     *
     * @param attribute   the SAML attribute
     * @param usingFormat FRIENDLY_NAME prefers {@code FriendlyName}; anything else prefers {@code Name}
     * @return the preferred name, falling back to the other one when the preferred one is null
     */
    public static String getAttributeName(Attribute attribute, Saml2ConfigurableProvider.SAML2UsingFormat usingFormat) {
        String friendlyName = attribute.getFriendlyName();
        String name = attribute.getName();
        if (usingFormat.getName().equals(Saml2ConfigurableProvider.SAML2UsingFormat.FRIENDLY_NAME.getName())) {
            return (friendlyName != null) ? friendlyName : name;
        }
        return (name != null) ? name : friendlyName;
    }

    /**
     * Extracts the Java value of an attribute value, casting it to the XML schema type configured
     * for the attribute.
     *
     * @param attribute     the attribute value object; must actually be of the configured type,
     *                      otherwise a {@link ClassCastException} is thrown
     * @param attributeType the configured XML schema type
     * @return the typed value, or {@code null} for an unknown type
     */
    public static Object getAttributeType(XMLObject attribute, Saml2ConfigurableProvider.SAML2AttributeType attributeType) {
        String type = attributeType.getType();
        if (type.equals(Saml2ConfigurableProvider.SAML2AttributeType.XSSTRING.getType())) {
            return ((XSString) attribute).getValue();
        } else if (type.equals(Saml2ConfigurableProvider.SAML2AttributeType.XSINTEGER.getType())) {
            return ((XSInteger) attribute).getValue();
        } else if (type.equals(Saml2ConfigurableProvider.SAML2AttributeType.XSDATETIME.getType())) {
            return ((XSDateTime) attribute).getValue();
        } else if (type.equals(Saml2ConfigurableProvider.SAML2AttributeType.XSBOOLEAN.getType())) {
            return ((XSBoolean) attribute).getValue();
        } else if (type.equals(Saml2ConfigurableProvider.SAML2AttributeType.XSBASE64BINARY.getType())) {
            return ((XSBase64Binary) attribute).getValue();
        } else if (type.equals(Saml2ConfigurableProvider.SAML2AttributeType.XSURI.getType())) {
            return ((XSURI) attribute).getURI();
        } else if (type.equals(Saml2ConfigurableProvider.SAML2AttributeType.XSQNAME.getType())) {
            return ((XSQName) attribute).getValue();
        } else if (type.equals(Saml2ConfigurableProvider.SAML2AttributeType.XSANY.getType())) {
            return ((XSAny) attribute).getTextContent();
        }
        return null;
    }

    /**
     * Marshalls an XMLObject to its DOM form and serialises it as a UTF-8 XML string.
     *
     * @throws Exception wrapping any marshalling or serialisation failure
     */
    private static String marshall(XMLObject xmlObject) throws Exception {
        try {
            Marshaller marshaller = registry.getMarshallerFactory().getMarshaller(xmlObject);
            if (marshaller == null) {
                throw new MarshallingException("No marshaller registered for " + xmlObject.getElementQName());
            }
            Element element = marshaller.marshall(xmlObject);
            ByteArrayOutputStream byteArrayOutputStrm = new ByteArrayOutputStream();
            // Local name must not shadow the static OpenSAML 'registry' field.
            DOMImplementationRegistry domRegistry = DOMImplementationRegistry.newInstance();
            DOMImplementationLS impl = (DOMImplementationLS) domRegistry.getDOMImplementation("LS");
            LSSerializer writer = impl.createLSSerializer();
            LSOutput output = impl.createLSOutput();
            // Pin the output encoding so the bytes can be decoded deterministically below,
            // rather than with the platform default charset.
            output.setEncoding(StandardCharsets.UTF_8.name());
            output.setByteStream(byteArrayOutputStrm);
            writer.write(element, output);
            return new String(byteArrayOutputStrm.toByteArray(), StandardCharsets.UTF_8);
        } catch (Exception e) {
            throw new Exception("Error Serializing the SAML Response", e);
        }
    }

    /**
     * Parses an XML string and unmarshalls its root element into an OpenSAML object.
     *
     * <p>The input is attacker-controllable (it comes from the browser or the IdP), so DTDs,
     * external entities and XInclude are disabled to prevent XXE and entity-expansion attacks.
     *
     * @throws Exception wrapping parse or unmarshalling failures
     */
    private static XMLObject unmarshall(String saml2SSOString) throws Exception {
        DocumentBuilderFactory documentBuilderFactory = DocumentBuilderFactory.newInstance();
        documentBuilderFactory.setNamespaceAware(true);
        try {
            documentBuilderFactory.setFeature("http://apache.org/xml/features/disallow-doctype-decl", true);
            documentBuilderFactory.setFeature("http://xml.org/sax/features/external-general-entities", false);
            documentBuilderFactory.setFeature("http://xml.org/sax/features/external-parameter-entities", false);
            documentBuilderFactory.setFeature("http://javax.xml.XMLConstants/feature/secure-processing", true);
            documentBuilderFactory.setXIncludeAware(false);
            documentBuilderFactory.setExpandEntityReferences(false);
            DocumentBuilder docBuilder = documentBuilderFactory.newDocumentBuilder();
            Document document = docBuilder.parse(new ByteArrayInputStream(saml2SSOString.getBytes(StandardCharsets.UTF_8)));
            Element element = document.getDocumentElement();
            Unmarshaller unmarshaller = registry.getUnmarshallerFactory().getUnmarshaller(element);
            if (unmarshaller == null) {
                throw new UnmarshallingException("No unmarshaller registered for element " + element.getLocalName());
            }
            return unmarshaller.unmarshall(element);
        } catch (ParserConfigurationException | UnmarshallingException | SAXException | IOException e) {
            throw new Exception("Error in unmarshalling SAML2SSO Request from the encoded String", e);
        }
    }

    /**
     * Resolves a SAML artifact at the IdP (HTTP-Artifact binding) and validates the returned
     * Response.
     *
     * @param artifactString the {@code SAMLart} value received from the browser
     * @param saml2Provider  provider configuration
     * @return the validated assertion
     * @throws Exception if the artifact is null, resolution fails, or validation fails
     */
    public static Assertion processArtifactResponse(String artifactString, Saml2ConfigurableProvider saml2Provider) throws Exception {
        doBootstrap();
        if (artifactString == null) {
            throw new Exception("Invalid SAML2 Artifact. SAML2 Artifact can not be null.");
        }
        ArtifactResolve artifactResolve = generateArtifactResolveReq(artifactString, saml2Provider);
        ArtifactResponse artifactResponse = sendArtifactResolveRequest(artifactResolve, saml2Provider.getIdpArtifactUrl());
        SAMLObject message = artifactResponse.getMessage();
        // Guard the cast: an IdP may legitimately embed other protocol messages.
        if (!(message instanceof Response)) {
            throw new Exception("Artifact response does not carry a SAML2 Response message.");
        }
        return processSSOResponse((Response) message, saml2Provider);
    }

    /** Builds the ArtifactResolve request and signs it when the provider requires signing. */
    private static ArtifactResolve generateArtifactResolveReq(String samlArtReceived, Saml2ConfigurableProvider saml2Provider) throws Exception {
        ArtifactResolve artifactResolve = createArtifactResolveObject(samlArtReceived, saml2Provider.getSpEntityId());
        if (saml2Provider.isSignatureRequired()) {
            signArtifactResolveReq(artifactResolve, saml2Provider);
        }
        return artifactResolve;
    }

    /** Creates an unsigned SAML 2.0 ArtifactResolve carrying the artifact and the SP as issuer. */
    private static ArtifactResolve createArtifactResolveObject(String samlArtReceived, String spEntityId) throws Exception {
        ArtifactResolve artifactResolve = (ArtifactResolve) buildXMLObject(ArtifactResolve.DEFAULT_ELEMENT_NAME);
        artifactResolve.setVersion(SAMLVersion.VERSION_20);
        // SAML IDs are xs:ID (an NCName) and must not start with a digit, which a bare UUID may do.
        artifactResolve.setID("_" + UUID.randomUUID());
        artifactResolve.setIssueInstant(Instant.now());
        Artifact artifact = (Artifact) buildXMLObject(Artifact.DEFAULT_ELEMENT_NAME);
        artifact.setValue(samlArtReceived);
        Issuer issuer = (Issuer) buildXMLObject(Issuer.DEFAULT_ELEMENT_NAME);
        issuer.setValue(spEntityId);
        artifactResolve.setIssuer(issuer);
        artifactResolve.setArtifact(artifact);
        return artifactResolve;
    }

    /**
     * Signs the ArtifactResolve with the SP key from the configured JKS keystore.
     *
     * @throws Exception wrapping keystore, marshalling or signing failures
     */
    private static void signArtifactResolveReq(ArtifactResolve artifactResolve, Saml2ConfigurableProvider saml2Provider) throws Exception {
        try {
            KeyStore ks = KeyStore.getInstance("JKS");
            String archivePassword = saml2Provider.getSignatureKeyStorePassword();
            char[] pwdArray = (archivePassword != null) ? archivePassword.toCharArray() : "changeit".toCharArray();
            try (InputStream keyStoreStream = new FileInputStream(saml2Provider.getSignaturePath())) {
                ks.load(keyStoreStream, pwdArray);
            }
            X509Credential cred = new KeyStoreX509CredentialAdapter(ks, saml2Provider.getSignatureKeyAlias(), saml2Provider.getSignatureKeyPassword().toCharArray());
            Signature signature = setSignatureRaw(XMLSignature.ALGO_ID_SIGNATURE_RSA, cred);
            artifactResolve.setSignature(signature);
            // The object must be marshalled to DOM before Signer can compute the signature.
            registry.getMarshallerFactory().getMarshaller(artifactResolve).marshall(artifactResolve);
            org.apache.xml.security.Init.init();
            Signer.signObjects(Collections.singletonList(signature));
        } catch (Exception e) {
            throw new Exception("Error while signing the SAML Request message", e);
        }
    }

    /**
     * Builds an unsigned Signature element using exclusive C14N and embeds the signing certificate
     * in KeyInfo/X509Data.
     */
    private static Signature setSignatureRaw(String signatureAlgorithm, X509Credential cred) throws Exception {
        Signature signature = (Signature) buildXMLObject(Signature.DEFAULT_ELEMENT_NAME);
        signature.setSigningCredential(cred);
        signature.setSignatureAlgorithm(signatureAlgorithm);
        signature.setCanonicalizationAlgorithm(Canonicalizer.ALGO_ID_C14N_EXCL_OMIT_COMMENTS);
        try {
            KeyInfo keyInfo = (KeyInfo) buildXMLObject(KeyInfo.DEFAULT_ELEMENT_NAME);
            X509Data data = (X509Data) buildXMLObject(X509Data.DEFAULT_ELEMENT_NAME);
            org.opensaml.xmlsec.signature.X509Certificate cert =
                    (org.opensaml.xmlsec.signature.X509Certificate) buildXMLObject(
                            org.opensaml.xmlsec.signature.X509Certificate.DEFAULT_ELEMENT_NAME);
            String value = org.apache.commons.codec.binary.Base64.encodeBase64String(cred.getEntityCertificate().getEncoded());
            cert.setValue(value);
            data.getX509Certificates().add(cert);
            keyInfo.getX509Datas().add(data);
            signature.setKeyInfo(keyInfo);
            return signature;
        } catch (CertificateEncodingException e) {
            throw new Exception("Error getting certificate", e);
        }
    }

    /** Sends the ArtifactResolve over the SOAP binding and returns the validated ArtifactResponse. */
    private static ArtifactResponse sendArtifactResolveRequest(ArtifactResolve artifactResolve, String idpArtifactUrl) throws Exception {
        Envelope envelope = buildSOAPMessage(artifactResolve);
        String envelopeElement;
        try {
            envelopeElement = marshall(envelope);
        } catch (Exception e) {
            throw new Exception("Encountered error marshalling SOAP message with artifact " + "resolve, into its DOM representation", e);
        }
        String artifactResponseString = sendSOAP(envelopeElement, idpArtifactUrl);
        ArtifactResponse artifactResponse = extractArtifactResponse(artifactResponseString);
        validateArtifactResponse(artifactResolve, artifactResponse);
        return artifactResponse;
    }

    /** Wraps a SAML message in a SOAP 1.1 Envelope/Body. */
    private static Envelope buildSOAPMessage(SAMLObject samlMessage) throws Exception {
        Envelope envelope = (Envelope) buildXMLObject(Envelope.DEFAULT_ELEMENT_NAME);
        Body body = (Body) buildXMLObject(Body.DEFAULT_ELEMENT_NAME);
        body.getUnknownXMLObjects().add(samlMessage);
        envelope.setBody(body);
        return envelope;
    }

    /**
     * POSTs a SOAP message to the IdP artifact resolution endpoint.
     *
     * @return the response body of a 200 response
     * @throws Exception on null arguments, a non-200 status, an empty body, or I/O failure
     */
    private static String sendSOAP(String message, String idpArtifactUrl) throws Exception {
        if (message == null) {
            throw new Exception("Cannot send null SOAP message.");
        }
        if (idpArtifactUrl == null) {
            throw new Exception("Cannot send SOAP message to null URL.");
        }
        String soapResponse;
        // try-with-resources releases the client's connection pool after every call.
        try (CloseableHttpClient httpClient = getHttpClient()) {
            HttpPost httpPost = new HttpPost(idpArtifactUrl);
            setRequestProperties(idpArtifactUrl, message, httpPost);
            HttpResponse httpResponse = httpClient.execute(httpPost);
            int responseCode = httpResponse.getStatusLine().getStatusCode();
            if (responseCode != 200) {
                throw new Exception("Problem in communicating with: " + idpArtifactUrl + ". Received response: " + responseCode);
            }
            soapResponse = getResponseBody(httpResponse);
            if (soapResponse == null) {
                throw new Exception("Empty SOAP response received from: " + idpArtifactUrl);
            }
        } catch (UnknownHostException e) {
            throw new Exception("Unknown targeted host: " + idpArtifactUrl, e);
        } catch (IOException e) {
            throw new Exception("Could not open connection with host: " + idpArtifactUrl, e);
        }
        return soapResponse;
    }

    /** Sets the SOAP 1.1 headers and the UTF-8 XML entity on the request. */
    private static void setRequestProperties(String idpArtifactUrl, String message, HttpPost httpPost) {
        httpPost.addHeader("Content-Type", "text/xml; charset=utf-8");
        httpPost.addHeader("Accept", "text/xml; charset=utf-8");
        String sbSOAPAction = "\"" + idpArtifactUrl + "\"";
        httpPost.addHeader("SOAPAction", sbSOAPAction);
        httpPost.addHeader("Pragma", "no-cache");
        httpPost.addHeader("Cache-Control", "no-cache, no-store");
        httpPost.setEntity(new StringEntity(message, ContentType.create("text/xml", StandardCharsets.UTF_8)));
    }

    /**
     * Builds the HTTP client used for the SOAP back-channel call to the IdP.
     *
     * <p>NOTE(review): {@link TrustSelfSignedStrategy} accepts any self-signed server certificate,
     * which weakens TLS authentication of the artifact resolution endpoint. Consider an explicit
     * trust store for the IdP instead.
     *
     * @return a new client; the caller must close it
     */
    private static CloseableHttpClient getHttpClient() throws Exception {
        try {
            SSLContextBuilder builder = new SSLContextBuilder();
            builder.loadTrustMaterial(null, new TrustSelfSignedStrategy());
            SSLConnectionSocketFactory sslsf = new SSLConnectionSocketFactory(builder.build());
            return HttpClients.custom().setSSLSocketFactory(sslsf).build();
        } catch (NoSuchAlgorithmException | KeyStoreException e) {
            throw new Exception("Error while building trust store.", e);
        } catch (KeyManagementException e) {
            throw new Exception("Error while building socket factory.", e);
        }
    }

    /** Reads the response entity as a string (may be {@code null} when there is no entity). */
    private static String getResponseBody(HttpResponse response) throws Exception {
        ResponseHandler<String> responseHandler = new BasicResponseHandler();
        try {
            return responseHandler.handleResponse(response);
        } catch (IOException e) {
            throw new Exception("Error when retrieving the HTTP response body.", e);
        }
    }

    /**
     * Parses the SOAP envelope and unmarshalls the SAML2 ArtifactResponse found in its Body.
     *
     * @return the ArtifactResponse, or {@code null} if the Body contains no element
     * @throws Exception if the SOAP message is malformed or the Body holds an unexpected element
     */
    private static ArtifactResponse extractArtifactResponse(String artifactResponseString) throws Exception {
        ArtifactResponse artifactResponse = null;
        InputStream stream = new ByteArrayInputStream(artifactResponseString.getBytes(StandardCharsets.UTF_8));
        try {
            MessageFactory messageFactory = MessageFactory.newInstance();
            SOAPMessage soapMessage = messageFactory.createMessage(new MimeHeaders(), stream);
            SOAPBody soapBody = soapMessage.getSOAPBody();
            Iterator<Node> iterator = soapBody.getChildElements();
            while (iterator.hasNext()) {
                Node child = iterator.next();
                // getChildElements() also yields Text nodes (e.g. indentation whitespace); skip them
                // instead of failing with a ClassCastException.
                if (!(child instanceof SOAPBodyElement)) {
                    continue;
                }
                SOAPBodyElement artifactResponseElement = (SOAPBodyElement) child;
                if (StringUtils.equals(SAMLConstants.SAML20P_NS, artifactResponseElement.getNamespaceURI()) &&
                        StringUtils.equals(ArtifactResponse.DEFAULT_ELEMENT_LOCAL_NAME,
                                artifactResponseElement.getLocalName())) {
                    DOMSource source = new DOMSource(artifactResponseElement);
                    StringWriter stringResult = new StringWriter();
                    TransformerFactory.newInstance().newTransformer().transform(
                            source, new StreamResult(stringResult));
                    artifactResponse = (ArtifactResponse) unmarshall(stringResult.toString());
                } else {
                    throw new Exception("Received invalid artifact response with nameSpaceURI: " +
                            artifactResponseElement.getNamespaceURI() + " and localName: " +
                            artifactResponseElement.getLocalName());
                }
            }
        } catch (SOAPException | IOException | TransformerException e) {
            throw new Exception("Didn't receive valid artifact response.", e);
        } catch (Exception e) {
            throw new Exception("Encountered error unmarshalling response into SAML2 object", e);
        }
        return artifactResponse;
    }

    /**
     * Checks the ArtifactResponse: it must be non-null, answer our request (InResponseTo), report
     * Success, and embed a message.
     */
    private static void validateArtifactResponse(ArtifactResolve artifactResolve, ArtifactResponse artifactResponse) throws Exception {
        if (artifactResponse == null) {
            throw new Exception("Received artifact response message was null.");
        }
        String artifactResolveId = artifactResolve.getID();
        String artifactResponseInResponseTo = artifactResponse.getInResponseTo();
        if (!artifactResolveId.equals(artifactResponseInResponseTo)) {
            throw new Exception("Artifact resolve ID: " + artifactResolveId + " is not equal to " +
                    "artifact response InResponseTo : " + artifactResponseInResponseTo);
        }
        String artifactResponseStatus = null;
        if (artifactResponse.getStatus() != null && artifactResponse.getStatus().getStatusCode() != null) {
            artifactResponseStatus = artifactResponse.getStatus().getStatusCode().getValue();
        }
        if (!StatusCode.SUCCESS.equals(artifactResponseStatus)) {
            throw new Exception("Unsuccessful artifact response with status: " +
                    artifactResponseStatus);
        }
        SAMLObject message = artifactResponse.getMessage();
        if (message == null) {
            throw new Exception("No SAML response embedded into the artifact response.");
        }
    }

    /**
     * Decodes a SAML Response received from the browser and validates it.
     *
     * @param saml2SSOResponse base64 response; DEFLATE-compressed unless the binding is "Post"
     * @param saml2Provider    provider configuration
     * @return the validated assertion
     * @throws Exception if the response is null, cannot be decoded, or fails validation
     */
    public static Assertion processResponse(String saml2SSOResponse, Saml2ConfigurableProvider saml2Provider) throws Exception {
        doBootstrap();
        if (saml2SSOResponse == null) {
            throw new Exception("Invalid SAML2 Response. SAML2 Response can not be null.");
        }
        byte[] decodedResponse = Base64.decode(saml2SSOResponse);
        String response;
        if (!saml2Provider.getBinding().equals("Post")) {
            response = inflate(decodedResponse);
        } else {
            // SAML XML is UTF-8; never rely on the platform default charset.
            response = new String(decodedResponse, StandardCharsets.UTF_8);
        }
        Response saml2Response = (Response) Saml2SSOUtils.unmarshall(response);
        return processSSOResponse(saml2Response, saml2Provider);
    }

    /**
     * Raw-DEFLATE-decompresses {@code deflated} and decodes it as UTF-8 text, joining lines with
     * {@code "\n"}. The Inflater is ended explicitly because InflaterInputStream does not end a
     * caller-supplied Inflater on close.
     */
    private static String inflate(byte[] deflated) throws IOException {
        Inflater inflater = new Inflater(true);
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(
                new InflaterInputStream(new ByteArrayInputStream(deflated), inflater), StandardCharsets.UTF_8))) {
            return reader.lines().collect(Collectors.joining("\n"));
        } finally {
            inflater.end();
        }
    }

    /**
     * Validates a SAML Response and returns its first (optionally decrypted) assertion.
     *
     * <p>The checks run in this order: issuer, subject, audience, validity window, then signatures
     * against the IdP signing key taken from its metadata.
     */
    private static Assertion processSSOResponse(Response saml2Response, Saml2ConfigurableProvider saml2Provider) throws Exception {
        Assertion assertion = extractAssertion(saml2Response, saml2Provider);
        validateIssuer(assertion, saml2Provider.getIdpEntityId());
        validateSubject(assertion);
        validateAudienceRestriction(assertion, saml2Provider.getSpEntityId());
        validateConditionsValidity(assertion);
        Credential credential = resolveIdpSigningCredential(saml2Provider);
        validateSignature(saml2Response, assertion, saml2Provider.isResponseSigned(), saml2Provider.isAssertionSigned(), credential);
        return assertion;
    }

    /** Returns the first assertion, decrypting it when the provider expects encrypted assertions. */
    private static Assertion extractAssertion(Response saml2Response, Saml2ConfigurableProvider saml2Provider) throws Exception {
        Assertion assertion = null;
        if (saml2Provider.isAssertionEncrypted()) {
            List<EncryptedAssertion> encryptedAssertions = saml2Response.getEncryptedAssertions();
            if (!CollectionUtils.isEmpty(encryptedAssertions)) {
                try {
                    assertion = getDecryptedAssertion(encryptedAssertions.get(0), saml2Provider);
                } catch (Exception e) {
                    throw new Exception("Unable to decrypt the SAML2 Assertion", e);
                }
            }
        } else {
            List<Assertion> assertions = saml2Response.getAssertions();
            if (!CollectionUtils.isEmpty(assertions)) {
                assertion = assertions.get(0);
            }
        }
        if (assertion == null) {
            throw new Exception("SAML2 Assertion not found in the Response");
        }
        return assertion;
    }

    /** Requires a non-empty assertion Issuer equal to the configured IdP entity id. */
    private static void validateIssuer(Assertion assertion, String expectedIdpEntityId) throws Exception {
        Issuer issuer = assertion.getIssuer();
        String idPEntityIdValue = (issuer != null) ? issuer.getValue() : null;
        if (idPEntityIdValue == null || idPEntityIdValue.isEmpty()) {
            throw new Exception("SAML2 Response does not contain an Issuer value");
        } else if (!idPEntityIdValue.equals(expectedIdpEntityId)) {
            throw new Exception("SAML2 Response Issuer verification failed");
        }
    }

    /** Requires the assertion to carry a Subject NameID value. */
    private static void validateSubject(Assertion assertion) throws Exception {
        String subject = null;
        if (assertion.getSubject() != null && assertion.getSubject().getNameID() != null) {
            subject = assertion.getSubject().getNameID().getValue();
        }
        if (subject == null) {
            throw new Exception("SAML2 Response does not contain the name of the subject");
        }
    }

    /**
     * Fetches the IdP metadata and resolves its signing credential.
     *
     * <p>The resolver is destroyed and its HTTP client closed afterwards. HTTPMetadataResolver
     * starts its own background refresh timer, which would otherwise leak on every login.
     *
     * @return the IdP signing credential, or {@code null} if none was found
     */
    private static Credential resolveIdpSigningCredential(Saml2ConfigurableProvider saml2Provider) throws Exception {
        CloseableHttpClient metadataHttpClient = HttpClientBuilder.create().build();
        HTTPMetadataResolver metadataResolver = null;
        try {
            metadataResolver = new HTTPMetadataResolver(metadataHttpClient, saml2Provider.getIdpMetadataUrl());
            metadataResolver.setId(metadataResolver.getClass().getCanonicalName());
            metadataResolver.setParserPool(parserPool);
            metadataResolver.initialize();
            final MetadataCredentialResolver metadataCredentialResolver = new MetadataCredentialResolver();
            final PredicateRoleDescriptorResolver roleResolver = new PredicateRoleDescriptorResolver(metadataResolver);
            final KeyInfoCredentialResolver keyResolver = DefaultSecurityConfigurationBootstrap.buildBasicInlineKeyInfoCredentialResolver();
            metadataCredentialResolver.setKeyInfoCredentialResolver(keyResolver);
            metadataCredentialResolver.setRoleDescriptorResolver(roleResolver);
            metadataCredentialResolver.initialize();
            roleResolver.initialize();
            CriteriaSet criteriaSet = new CriteriaSet();
            criteriaSet.add(new UsageCriterion(UsageType.SIGNING));
            criteriaSet.add(new EntityRoleCriterion(IDPSSODescriptor.DEFAULT_ELEMENT_NAME));
            criteriaSet.add(new ProtocolCriterion(SAMLConstants.SAML20P_NS));
            criteriaSet.add(new EntityIdCriterion(saml2Provider.getIdpEntityId()));
            return metadataCredentialResolver.resolveSingle(criteriaSet);
        } finally {
            if (metadataResolver != null) {
                metadataResolver.destroy();
            }
            try {
                metadataHttpClient.close();
            } catch (IOException e) {
                // Best effort: a close failure must not fail an otherwise valid login.
                logger.warn("Failed to close the IdP metadata HTTP client", e);
            }
        }
    }

    /**
     * Decrypts an EncryptedAssertion with the SP private key from the configured keystore (JKS or
     * PKCS12). The key transport is decrypted first; the resulting symmetric key is then used to
     * decrypt the assertion into a new document.
     */
    private static Assertion getDecryptedAssertion(EncryptedAssertion encryptedAssertion, Saml2ConfigurableProvider saml2Provider) throws Exception {
        try {
            KeyStore ks = (saml2Provider.getKeyFormat().getType().equals("JKS")) ? KeyStore.getInstance("JKS") : KeyStore.getInstance("PKCS12");
            String archivePassword = saml2Provider.getArchivePassword();
            char[] pwdArray = (archivePassword != null) ? archivePassword.toCharArray() : "changeit".toCharArray();
            try (InputStream keyStoreStream = new FileInputStream(saml2Provider.getCredentialPath())) {
                ks.load(keyStoreStream, pwdArray);
            }
            X509Certificate cert = (X509Certificate) ks.getCertificate(saml2Provider.getKeyAlias());
            PrivateKey pk = (PrivateKey) ks.getKey(saml2Provider.getKeyAlias(), saml2Provider.getKeyPassword().toCharArray());
            KeyInfoCredentialResolver keyResolver = new StaticKeyInfoCredentialResolver(
                    new BasicX509Credential(cert, pk));
            EncryptedKey key = encryptedAssertion.getEncryptedData().getKeyInfo().getEncryptedKeys().get(0);
            Decrypter decrypter = new Decrypter(null, keyResolver, null);
            SecretKey dkey = (SecretKey) decrypter.decryptKey(key, encryptedAssertion.getEncryptedData().getEncryptionMethod().getAlgorithm());
            Credential shared = CredentialSupport.getSimpleCredential(dkey);
            decrypter = new Decrypter(new StaticKeyInfoCredentialResolver(shared), null, null);
            decrypter.setRootInNewDocument(true);
            return decrypter.decrypt(encryptedAssertion);
        } catch (Exception e) {
            throw new Exception("Decrypted assertion error", e);
        }
    }

    /**
     * Requires Conditions with at least one AudienceRestriction listing {@code requiredSPEntityId}.
     * A null assertion is accepted without checks.
     */
    private static void validateAudienceRestriction(Assertion assertion, String requiredSPEntityId) throws Exception {
        if (assertion == null) {
            return;
        }
        Conditions conditions = assertion.getConditions();
        if (conditions == null) {
            throw new Exception("SAML2 Response doesn't contain Conditions");
        }
        List<AudienceRestriction> audienceRestrictions = conditions.getAudienceRestrictions();
        if (audienceRestrictions == null || audienceRestrictions.isEmpty()) {
            throw new Exception("SAML2 Response doesn't contain AudienceRestrictions");
        }
        for (AudienceRestriction audienceRestriction : audienceRestrictions) {
            List<Audience> audiences = audienceRestriction.getAudiences();
            if (audiences == null) {
                continue;
            }
            for (Audience audience : audiences) {
                if (requiredSPEntityId.equals(audience.getURI())) {
                    return;
                }
            }
        }
        throw new Exception("SAML2 Assertion Audience Restriction validation failed");
    }

    /**
     * Enforces the Conditions NotBefore / NotOnOrAfter window, allowing {@link #ALLOWED_CLOCK_SKEW}.
     * Without this check, expired or replayed assertions would be accepted indefinitely.
     */
    private static void validateConditionsValidity(Assertion assertion) throws Exception {
        Conditions conditions = assertion.getConditions();
        if (conditions == null) {
            return;
        }
        Instant now = Instant.now();
        Instant notBefore = conditions.getNotBefore();
        if (notBefore != null && now.plus(ALLOWED_CLOCK_SKEW).isBefore(notBefore)) {
            throw new Exception("SAML2 Assertion is not yet valid (NotBefore " + notBefore + ")");
        }
        Instant notOnOrAfter = conditions.getNotOnOrAfter();
        if (notOnOrAfter != null && !now.minus(ALLOWED_CLOCK_SKEW).isBefore(notOnOrAfter)) {
            throw new Exception("SAML2 Assertion has expired (NotOnOrAfter " + notOnOrAfter + ")");
        }
    }

    /**
     * Validates the Response and/or Assertion signatures as configured.
     *
     * @throws Exception if a required signature is missing or invalid (the cause is preserved)
     */
    private static void validateSignature(Response response, Assertion assertion, Boolean isResponseSigned, Boolean isAssertionSigned, Credential credential) throws Exception {
        if (isResponseSigned) {
            if (response.getSignature() == null) {
                throw new Exception("SAML2 Response signing is enabled, but signature element not found in SAML2 Response element");
            }
            verifySignature(response.getSignature(), credential, "Signature validation failed for SAML2 Response");
        }
        if (isAssertionSigned) {
            if (assertion.getSignature() == null) {
                throw new Exception("SAML2 Assertion signing is enabled, but signature element not found in SAML2 Assertion element");
            }
            verifySignature(assertion.getSignature(), credential, "Signature validation failed for SAML2 Assertion");
        }
    }

    /**
     * Checks the SAML signature profile, then verifies the signature cryptographically. The profile
     * check rejects signatures whose Reference does not point at the signed element itself
     * (XML Signature Wrapping).
     */
    private static void verifySignature(Signature signature, Credential credential, String failureMessage) throws Exception {
        try {
            new org.opensaml.saml.security.impl.SAMLSignatureProfileValidator().validate(signature);
            SignatureValidator.validate(signature, credential);
        } catch (Exception e) {
            throw new Exception(failureMessage, e);
        }
    }
}