package eu.eudat.logic.security.validators.configurableProvider; import eu.eudat.logic.security.customproviders.ConfigurableProvider.entities.saml2.CertificateInfo; import eu.eudat.logic.security.customproviders.ConfigurableProvider.entities.saml2.Saml2ConfigurableProvider; import eu.eudat.logic.utilities.builders.XmlBuilder; import eu.eudat.models.data.saml2.AuthnRequestModel; import jakarta.xml.soap.*; import net.shibboleth.utilities.java.support.component.ComponentInitializationException; import net.shibboleth.utilities.java.support.resolver.CriteriaSet; import net.shibboleth.utilities.java.support.xml.BasicParserPool; import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.lang3.StringUtils; import org.apache.http.HttpResponse; import org.apache.http.client.HttpClient; import org.apache.http.client.ResponseHandler; import org.apache.http.client.methods.HttpPost; import org.apache.http.conn.ssl.SSLConnectionSocketFactory; import org.apache.http.conn.ssl.TrustSelfSignedStrategy; import org.apache.http.entity.ContentType; import org.apache.http.entity.StringEntity; import org.apache.http.impl.client.BasicResponseHandler; import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.impl.client.HttpClientBuilder; import org.apache.http.impl.client.HttpClients; import org.apache.http.ssl.SSLContextBuilder; import org.apache.xml.security.c14n.Canonicalizer; import org.apache.xml.security.signature.XMLSignature; import org.opensaml.core.config.ConfigurationService; import org.opensaml.core.config.InitializationException; import org.opensaml.core.config.InitializationService; import org.opensaml.core.criterion.EntityIdCriterion; import org.opensaml.core.xml.XMLObject; import org.opensaml.core.xml.XMLObjectBuilder; import org.opensaml.core.xml.config.XMLObjectProviderRegistry; import org.opensaml.core.xml.io.*; import org.opensaml.core.xml.schema.*; import org.opensaml.saml.common.SAMLObject; import org.opensaml.saml.common.SAMLVersion; import org.opensaml.saml.common.xml.SAMLConstants; import org.opensaml.saml.criterion.EntityRoleCriterion; import org.opensaml.saml.criterion.ProtocolCriterion; import org.opensaml.saml.metadata.resolver.impl.HTTPMetadataResolver; import org.opensaml.saml.metadata.resolver.impl.PredicateRoleDescriptorResolver; import org.opensaml.saml.saml2.core.*; import org.opensaml.saml.saml2.encryption.Decrypter; import org.opensaml.saml.saml2.metadata.*; import org.opensaml.saml.security.impl.MetadataCredentialResolver; import org.opensaml.security.credential.Credential; import org.opensaml.security.credential.CredentialSupport; import org.opensaml.security.credential.UsageType; import org.opensaml.security.criteria.UsageCriterion; import org.opensaml.security.x509.BasicX509Credential; import org.opensaml.security.x509.X509Credential; import org.opensaml.security.x509.impl.KeyStoreX509CredentialAdapter; import org.opensaml.soap.soap11.Body; import org.opensaml.soap.soap11.Envelope; import org.opensaml.xml.util.Base64; import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap; import org.opensaml.xmlsec.encryption.EncryptedKey; import org.opensaml.xmlsec.keyinfo.KeyInfoCredentialResolver; import org.opensaml.xmlsec.keyinfo.KeyInfoGenerator; import org.opensaml.xmlsec.keyinfo.impl.StaticKeyInfoCredentialResolver; import org.opensaml.xmlsec.keyinfo.impl.X509KeyInfoGeneratorFactory; import org.opensaml.xmlsec.signature.KeyInfo; import org.opensaml.xmlsec.signature.Signature; import org.opensaml.xmlsec.signature.X509Data; import org.opensaml.xmlsec.signature.support.SignatureValidator; import org.opensaml.xmlsec.signature.support.Signer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.w3c.dom.Document; import org.w3c.dom.Element; import org.w3c.dom.bootstrap.DOMImplementationRegistry; import org.w3c.dom.ls.DOMImplementationLS; import org.w3c.dom.ls.LSOutput; import org.w3c.dom.ls.LSSerializer; import org.xml.sax.SAXException; import javax.crypto.SecretKey; import javax.xml.namespace.QName; import javax.xml.parsers.DocumentBuilder; import javax.xml.parsers.DocumentBuilderFactory; import javax.xml.parsers.ParserConfigurationException; import javax.xml.transform.TransformerException; import javax.xml.transform.TransformerFactory; import javax.xml.transform.dom.DOMSource; import javax.xml.transform.stream.StreamResult; import java.io.*; import java.net.URLEncoder; import java.net.UnknownHostException; import java.nio.charset.StandardCharsets; import java.security.*; import java.security.cert.CertificateEncodingException; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; import java.time.Instant; import java.util.*; import java.util.stream.Collectors; import java.util.zip.Inflater; import java.util.zip.InflaterInputStream; public class Saml2SSOUtils { private static final Logger logger = LoggerFactory.getLogger(Saml2SSOUtils.class); private static boolean isBootStrapped = false; private static BasicParserPool parserPool; private static XMLObjectProviderRegistry registry; private Saml2SSOUtils() { } private static void doBootstrap() throws Exception { if (!isBootStrapped) { try { boostrap(); isBootStrapped = true; } catch (Exception e) { throw new Exception("Error in bootstrapping the OpenSAML2 library", e); } } } private static void boostrap(){ parserPool = new BasicParserPool(); parserPool.setMaxPoolSize(100); parserPool.setCoalescing(true); parserPool.setIgnoreComments(true); parserPool.setIgnoreElementContentWhitespace(true); parserPool.setNamespaceAware(true); parserPool.setExpandEntityReferences(false); parserPool.setXincludeAware(false); final Map features = new HashMap(); features.put("http://xml.org/sax/features/external-general-entities", Boolean.FALSE); features.put("http://xml.org/sax/features/external-parameter-entities", Boolean.FALSE); features.put("http://apache.org/xml/features/disallow-doctype-decl", Boolean.TRUE); features.put("http://apache.org/xml/features/validation/schema/normalized-value", Boolean.FALSE); features.put("http://javax.xml.XMLConstants/feature/secure-processing", Boolean.TRUE); parserPool.setBuilderFeatures(features); parserPool.setBuilderAttributes(new HashMap()); try { parserPool.initialize(); } catch (ComponentInitializationException e) { logger.error(e.getMessage(), e); } registry = new XMLObjectProviderRegistry(); ConfigurationService.register(XMLObjectProviderRegistry.class, registry); registry.setParserPool(parserPool); try { InitializationService.initialize(); } catch (InitializationException e) { logger.error(e.getMessage(), e); } } private static XMLObject buildXMLObject(QName objectQName) throws Exception { doBootstrap(); XMLObjectBuilder builder = registry.getBuilderFactory().getBuilder(objectQName); if (builder == null) { throw new Exception("Unable to retrieve builder for object QName " + objectQName); } return builder.buildObject(objectQName.getNamespaceURI(), objectQName.getLocalPart(), objectQName.getPrefix()); } public static String getAttributeName(Attribute attribute, Saml2ConfigurableProvider.SAML2UsingFormat usingFormat){ String friendlyName = attribute.getFriendlyName(); String name = attribute.getName(); if(usingFormat.getName().equals(Saml2ConfigurableProvider.SAML2UsingFormat.FRIENDLY_NAME.getName())){ return (friendlyName != null) ? friendlyName : name; } else{ return (name != null) ? name : friendlyName; } } public static Object getAttributeType(XMLObject attribute, Saml2ConfigurableProvider.SAML2AttributeType attributeType){ if(attributeType.getType().equals(Saml2ConfigurableProvider.SAML2AttributeType.XSSTRING.getType())){ return ((XSString)attribute).getValue(); } else if(attributeType.getType().equals(Saml2ConfigurableProvider.SAML2AttributeType.XSINTEGER.getType())){ return ((XSInteger)attribute).getValue(); } else if(attributeType.getType().equals(Saml2ConfigurableProvider.SAML2AttributeType.XSDATETIME.getType())){ return ((XSDateTime)attribute).getValue(); } else if(attributeType.getType().equals(Saml2ConfigurableProvider.SAML2AttributeType.XSBOOLEAN.getType())){ return ((XSBoolean)attribute).getValue(); } else if(attributeType.getType().equals(Saml2ConfigurableProvider.SAML2AttributeType.XSBASE64BINARY.getType())){ return ((XSBase64Binary)attribute).getValue(); } else if(attributeType.getType().equals(Saml2ConfigurableProvider.SAML2AttributeType.XSURI.getType())){ return ((XSURI)attribute).getURI(); } else if(attributeType.getType().equals(Saml2ConfigurableProvider.SAML2AttributeType.XSQNAME.getType())){ return ((XSQName)attribute).getValue(); } else if(attributeType.getType().equals(Saml2ConfigurableProvider.SAML2AttributeType.XSANY.getType())){ return ((XSAny)attribute).getTextContent(); } else { return null; } } private static String marshall(XMLObject xmlObject) throws Exception { try { MarshallerFactory marshallerFactory = registry.getMarshallerFactory(); Marshaller marshaller = marshallerFactory.getMarshaller(xmlObject); Element element = marshaller.marshall(xmlObject); ByteArrayOutputStream byteArrayOutputStrm = new ByteArrayOutputStream(); DOMImplementationRegistry registry = DOMImplementationRegistry.newInstance(); DOMImplementationLS impl = (DOMImplementationLS) registry.getDOMImplementation("LS"); LSSerializer writer = impl.createLSSerializer(); LSOutput output = impl.createLSOutput(); output.setByteStream(byteArrayOutputStrm); writer.write(element, output); return byteArrayOutputStrm.toString(); } catch (Exception e) { throw new Exception("Error Serializing the SAML Response", e); } } private static XMLObject unmarshall(String saml2SSOString) throws Exception { DocumentBuilderFactory documentBuilderFactory = DocumentBuilderFactory.newInstance(); //documentBuilderFactory.setExpandEntityReferences(false); documentBuilderFactory.setNamespaceAware(true); try { DocumentBuilder docBuilder = documentBuilderFactory.newDocumentBuilder(); ByteArrayInputStream is = new ByteArrayInputStream(saml2SSOString.getBytes(StandardCharsets.UTF_8)); Document document = docBuilder.parse(is); Element element = document.getDocumentElement(); UnmarshallerFactory unmarshallerFactory = registry.getUnmarshallerFactory(); Unmarshaller unmarshaller = unmarshallerFactory.getUnmarshaller(element); return unmarshaller.unmarshall(element); } catch (ParserConfigurationException | UnmarshallingException | SAXException | IOException e) { throw new Exception("Error in unmarshalling SAML2SSO Request from the encoded String", e); } } public static Assertion processArtifactResponse(String artifactString, Saml2ConfigurableProvider saml2Provider) throws Exception { doBootstrap(); if (artifactString != null){ ArtifactResolve artifactResolve = generateArtifactResolveReq(artifactString, saml2Provider); ArtifactResponse artifactResponse = sendArtifactResolveRequest(artifactResolve, saml2Provider.getIdpArtifactUrl()); Response saml2Response = (Response)artifactResponse.getMessage(); return processSSOResponse(saml2Response, saml2Provider); } else { throw new Exception("Invalid SAML2 Artifact. SAML2 Artifact can not be null."); } } private static ArtifactResolve generateArtifactResolveReq(String samlArtReceived, Saml2ConfigurableProvider saml2Provider) throws Exception { ArtifactResolve artifactResolve = createArtifactResolveObject(samlArtReceived, saml2Provider.getSpEntityId()); if (saml2Provider.isSignatureRequired()) { signArtifactResolveReq(artifactResolve, saml2Provider.getSigningCert()); } return artifactResolve; } private static ArtifactResolve createArtifactResolveObject(String samlArtReceived, String spEntityId) throws Exception { ArtifactResolve artifactResolve = (ArtifactResolve)buildXMLObject(ArtifactResolve.DEFAULT_ELEMENT_NAME); artifactResolve.setVersion(SAMLVersion.VERSION_20); artifactResolve.setID(UUID.randomUUID().toString()); artifactResolve.setIssueInstant(Instant.now()); Artifact artifact = (Artifact)buildXMLObject(Artifact.DEFAULT_ELEMENT_NAME); artifact.setValue(samlArtReceived); Issuer issuer = (Issuer)buildXMLObject(Issuer.DEFAULT_ELEMENT_NAME); issuer.setValue(spEntityId); artifactResolve.setIssuer(issuer); artifactResolve.setArtifact(artifact); return artifactResolve; } private static void signArtifactResolveReq(ArtifactResolve artifactResolve, CertificateInfo singingCertificateInfo) throws Exception { try { KeyStore ks = KeyStore.getInstance("JKS"); String archivePassword = singingCertificateInfo.getKeystorePassword(); char[] pwdArray = (archivePassword != null) ? archivePassword.toCharArray() : "changeit".toCharArray(); ks.load(new FileInputStream(singingCertificateInfo.getKeystorePath()), pwdArray); X509Credential cred = new KeyStoreX509CredentialAdapter(ks, singingCertificateInfo.getAlias(), singingCertificateInfo.getPassword().toCharArray()); Signature signature = setSignatureRaw(XMLSignature.ALGO_ID_SIGNATURE_RSA, cred); artifactResolve.setSignature(signature); List signatureList = new ArrayList<>(); signatureList.add(signature); MarshallerFactory marshallerFactory = registry.getMarshallerFactory(); Marshaller marshaller = marshallerFactory.getMarshaller(artifactResolve); marshaller.marshall(artifactResolve); org.apache.xml.security.Init.init(); Signer.signObjects(signatureList); } catch (Exception e) { throw new Exception("Error while signing the SAML Request message", e); } } private static Signature setSignatureRaw(String signatureAlgorithm, X509Credential cred) throws Exception { Signature signature = (Signature)buildXMLObject(Signature.DEFAULT_ELEMENT_NAME); signature.setSigningCredential(cred); signature.setSignatureAlgorithm(signatureAlgorithm); signature.setCanonicalizationAlgorithm(Canonicalizer.ALGO_ID_C14N_EXCL_OMIT_COMMENTS); try { KeyInfo keyInfo = (KeyInfo)buildXMLObject(KeyInfo.DEFAULT_ELEMENT_NAME); X509Data data = (X509Data)buildXMLObject(X509Data.DEFAULT_ELEMENT_NAME); org.opensaml.xmlsec.signature.X509Certificate cert = (org.opensaml.xmlsec.signature.X509Certificate) buildXMLObject( org.opensaml.xmlsec.signature.X509Certificate.DEFAULT_ELEMENT_NAME); String value = org.apache.commons.codec.binary.Base64.encodeBase64String(cred.getEntityCertificate().getEncoded()); cert.setValue(value); data.getX509Certificates().add(cert); keyInfo.getX509Datas().add(data); signature.setKeyInfo(keyInfo); return signature; } catch (CertificateEncodingException e) { throw new Exception("Error getting certificate", e); } } private static ArtifactResponse sendArtifactResolveRequest(ArtifactResolve artifactResolve, String idpArtifactUrl) throws Exception { Envelope envelope = buildSOAPMessage(artifactResolve); String envelopeElement; try { envelopeElement = marshall(envelope); } catch (Exception e) { throw new Exception("Encountered error marshalling SOAP message with artifact " + "resolve, into its DOM representation", e); } String artifactResponseString = sendSOAP(envelopeElement, idpArtifactUrl); ArtifactResponse artifactResponse = extractArtifactResponse(artifactResponseString); validateArtifactResponse(artifactResolve, artifactResponse); return artifactResponse; } private static Envelope buildSOAPMessage(SAMLObject samlMessage) throws Exception { Envelope envelope = (Envelope)buildXMLObject(Envelope.DEFAULT_ELEMENT_NAME); Body body = (Body)buildXMLObject(Body.DEFAULT_ELEMENT_NAME); body.getUnknownXMLObjects().add(samlMessage); envelope.setBody(body); return envelope; } private static String sendSOAP(String message, String idpArtifactUrl) throws Exception { if (message == null) { throw new Exception("Cannot send null SOAP message."); } if (idpArtifactUrl == null) { throw new Exception("Cannot send SOAP message to null URL."); } StringBuilder soapResponse = new StringBuilder(); try { HttpPost httpPost = new HttpPost(idpArtifactUrl); setRequestProperties(idpArtifactUrl, message, httpPost); HttpClient httpClient = getHttpClient(); HttpResponse httpResponse = httpClient.execute(httpPost); int responseCode = httpResponse.getStatusLine().getStatusCode(); if (responseCode != 200) { throw new Exception("Problem in communicating with: " + idpArtifactUrl + ". Received response: " + responseCode); } else { soapResponse.append(getResponseBody(httpResponse)); } } catch (UnknownHostException e) { throw new Exception("Unknown targeted host: " + idpArtifactUrl, e); } catch (IOException e) { throw new Exception("Could not open connection with host: " + idpArtifactUrl, e); } return soapResponse.toString(); } private static void setRequestProperties(String idpArtifactUrl, String message, HttpPost httpPost) { httpPost.addHeader("Content-Type", "text/xml; charset=utf-8"); httpPost.addHeader("Accept", "text/xml; charset=utf-8"); String sbSOAPAction = "\"" + idpArtifactUrl + "\""; httpPost.addHeader("SOAPAction", sbSOAPAction); httpPost.addHeader("Pragma", "no-cache"); httpPost.addHeader("Cache-Control", "no-cache, no-store"); httpPost.setEntity(new StringEntity(message, ContentType.create("text/xml", StandardCharsets.UTF_8))); } private static HttpClient getHttpClient() throws Exception { CloseableHttpClient httpClient = null; SSLContextBuilder builder = new SSLContextBuilder(); try { builder.loadTrustMaterial(null, new TrustSelfSignedStrategy()); SSLConnectionSocketFactory sslsf = new SSLConnectionSocketFactory( builder.build()); httpClient = HttpClients.custom().setSSLSocketFactory( sslsf).build(); } catch (NoSuchAlgorithmException | KeyStoreException e) { throw new Exception("Error while building trust store.", e); } catch (KeyManagementException e) { throw new Exception("Error while building socket factory.", e); } return httpClient; } private static String getResponseBody(HttpResponse response) throws Exception { ResponseHandler responseHandler = new BasicResponseHandler(); String responseBody; try { responseBody = responseHandler.handleResponse(response); } catch (IOException e) { throw new Exception("Error when retrieving the HTTP response body.", e); } return responseBody; } private static ArtifactResponse extractArtifactResponse(String artifactResponseString) throws Exception { ArtifactResponse artifactResponse = null; InputStream stream = new ByteArrayInputStream(artifactResponseString.getBytes(StandardCharsets.UTF_8)); try { MessageFactory messageFactory = MessageFactory.newInstance(); SOAPMessage soapMessage = messageFactory.createMessage(new MimeHeaders(), stream); SOAPBody soapBody = soapMessage.getSOAPBody(); Iterator iterator = soapBody.getChildElements(); while (iterator.hasNext()) { SOAPBodyElement artifactResponseElement = (SOAPBodyElement) iterator.next(); if (StringUtils.equals(SAMLConstants.SAML20P_NS, artifactResponseElement.getNamespaceURI()) && StringUtils.equals(ArtifactResponse.DEFAULT_ELEMENT_LOCAL_NAME, artifactResponseElement.getLocalName())) { DOMSource source = new DOMSource(artifactResponseElement); StringWriter stringResult = new StringWriter(); TransformerFactory.newInstance().newTransformer().transform( source, new StreamResult(stringResult)); artifactResponse = (ArtifactResponse) unmarshall(stringResult.toString()); } else { throw new Exception("Received invalid artifact response with nameSpaceURI: " + artifactResponseElement.getNamespaceURI() + " and localName: " + artifactResponseElement.getLocalName()); } } } catch (SOAPException | IOException | TransformerException e) { throw new Exception("Didn't receive valid artifact response.", e); } catch (Exception e) { throw new Exception("Encountered error unmarshalling response into SAML2 object", e); } return artifactResponse; } private static void validateArtifactResponse(ArtifactResolve artifactResolve, ArtifactResponse artifactResponse) throws Exception { if (artifactResponse == null) { throw new Exception("Received artifact response message was null."); } String artifactResolveId = artifactResolve.getID(); String artifactResponseInResponseTo = artifactResponse.getInResponseTo(); if (!artifactResolveId.equals(artifactResponseInResponseTo)) { throw new Exception("Artifact resolve ID: " + artifactResolveId + " is not equal to " + "artifact response InResponseTo : " + artifactResponseInResponseTo); } String artifactResponseStatus = artifactResponse.getStatus().getStatusCode().getValue(); if (!StatusCode.SUCCESS.equals(artifactResponseStatus)) { throw new Exception("Unsuccessful artifact response with status: " + artifactResponseStatus); } SAMLObject message = artifactResponse.getMessage(); if (message == null) { throw new Exception("No SAML response embedded into the artifact response."); } } public static Assertion processResponse(String saml2SSOResponse, Saml2ConfigurableProvider saml2Provider) throws Exception { doBootstrap(); if (saml2SSOResponse != null) { byte[] decodedResponse = Base64.decode(saml2SSOResponse); String response; if(!saml2Provider.getBinding().equals("Post")){ ByteArrayInputStream bytesIn = new ByteArrayInputStream(decodedResponse); InflaterInputStream inflater = new InflaterInputStream(bytesIn, new Inflater(true)); response = new BufferedReader(new InputStreamReader(inflater, StandardCharsets.UTF_8)) .lines().collect(Collectors.joining("\n")); } else{ response = new String(decodedResponse); } Response saml2Response = (Response) Saml2SSOUtils.unmarshall(response); return processSSOResponse(saml2Response, saml2Provider); } else { throw new Exception("Invalid SAML2 Response. SAML2 Response can not be null."); } } private static Assertion processSSOResponse(Response saml2Response, Saml2ConfigurableProvider saml2Provider) throws Exception { Assertion assertion = null; if (saml2Provider.isAssertionEncrypted()) { List encryptedAssertions = saml2Response.getEncryptedAssertions(); EncryptedAssertion encryptedAssertion; if (!CollectionUtils.isEmpty(encryptedAssertions)) { encryptedAssertion = encryptedAssertions.get(0); try { assertion = getDecryptedAssertion(encryptedAssertion, saml2Provider.getEncryptionCert()); } catch (Exception e) { throw new Exception("Unable to decrypt the SAML2 Assertion"); } } } else { List assertions = saml2Response.getAssertions(); if (assertions != null && !assertions.isEmpty()) { assertion = assertions.get(0); } } if (assertion == null) { throw new Exception("SAML2 Assertion not found in the Response"); } String idPEntityIdValue = assertion.getIssuer().getValue(); if (idPEntityIdValue == null || idPEntityIdValue.isEmpty()) { throw new Exception("SAML2 Response does not contain an Issuer value"); } else if (!idPEntityIdValue.equals(saml2Provider.getIdpEntityId())) { throw new Exception("SAML2 Response Issuer verification failed"); } String subject = null; if (assertion.getSubject() != null && assertion.getSubject().getNameID() != null) { subject = assertion.getSubject().getNameID().getValue(); } if (subject == null) { throw new Exception("SAML2 Response does not contain the name of the subject"); } validateAudienceRestriction(assertion, saml2Provider.getSpEntityId()); final HTTPMetadataResolver metadataResolver = new HTTPMetadataResolver(HttpClientBuilder.create().build(), saml2Provider.getIdpMetadataUrl()); metadataResolver.setId(metadataResolver.getClass().getCanonicalName()); metadataResolver.setParserPool(parserPool); metadataResolver.initialize(); final MetadataCredentialResolver metadataCredentialResolver = new MetadataCredentialResolver(); final PredicateRoleDescriptorResolver roleResolver = new PredicateRoleDescriptorResolver(metadataResolver); final KeyInfoCredentialResolver keyResolver = DefaultSecurityConfigurationBootstrap.buildBasicInlineKeyInfoCredentialResolver(); metadataCredentialResolver.setKeyInfoCredentialResolver(keyResolver); metadataCredentialResolver.setRoleDescriptorResolver(roleResolver); metadataCredentialResolver.initialize(); roleResolver.initialize(); CriteriaSet criteriaSet = new CriteriaSet(); criteriaSet.add(new UsageCriterion(UsageType.SIGNING)); criteriaSet.add(new EntityRoleCriterion(IDPSSODescriptor.DEFAULT_ELEMENT_NAME)); criteriaSet.add(new ProtocolCriterion(SAMLConstants.SAML20P_NS)); criteriaSet.add(new EntityIdCriterion(saml2Provider.getIdpEntityId())); Credential credential = metadataCredentialResolver.resolveSingle(criteriaSet); validateSignature(saml2Response, assertion, saml2Provider.isResponseSigned(), saml2Provider.isAssertionSigned(), credential); return assertion; } private static Assertion getDecryptedAssertion(EncryptedAssertion encryptedAssertion, CertificateInfo encryptionCertificateInfo) throws Exception { try { KeyStore ks = (encryptionCertificateInfo.getKeyFormat().getType().equals("JKS")) ? KeyStore.getInstance("JKS") : KeyStore.getInstance("PKCS12"); String archivePassword = encryptionCertificateInfo.getKeystorePassword(); char[] pwdArray = (archivePassword != null) ? archivePassword.toCharArray() : "changeit".toCharArray(); ks.load(new FileInputStream(encryptionCertificateInfo.getKeystorePath()), pwdArray); X509Certificate cert = (X509Certificate)ks.getCertificate(encryptionCertificateInfo.getAlias()); PrivateKey pk = (PrivateKey) ks.getKey(encryptionCertificateInfo.getAlias(), encryptionCertificateInfo.getPassword().toCharArray()); KeyInfoCredentialResolver keyResolver = new StaticKeyInfoCredentialResolver( new BasicX509Credential(cert, pk)); EncryptedKey key = encryptedAssertion.getEncryptedData().getKeyInfo().getEncryptedKeys().get(0); Decrypter decrypter = new Decrypter(null, keyResolver, null); SecretKey dkey = (SecretKey) decrypter.decryptKey(key, encryptedAssertion.getEncryptedData().getEncryptionMethod().getAlgorithm()); Credential shared = CredentialSupport.getSimpleCredential(dkey); decrypter = new Decrypter(new StaticKeyInfoCredentialResolver(shared), null, null); decrypter.setRootInNewDocument(true); return decrypter.decrypt(encryptedAssertion); } catch (Exception e) { throw new Exception("Decrypted assertion error", e); } } private static void validateAudienceRestriction(Assertion assertion, String requiredSPEntityId) throws Exception { if (assertion != null) { Conditions conditions = assertion.getConditions(); if (conditions != null) { List audienceRestrictions = conditions.getAudienceRestrictions(); if (audienceRestrictions != null && !audienceRestrictions.isEmpty()) { boolean audienceFound = false; for (AudienceRestriction audienceRestriction : audienceRestrictions) { if (audienceRestriction.getAudiences() != null && !audienceRestriction.getAudiences().isEmpty() ) { for (Audience audience : audienceRestriction.getAudiences()) { if (requiredSPEntityId.equals(audience.getURI())) { audienceFound = true; break; } } } if (audienceFound) { break; } } if (!audienceFound) { throw new Exception("SAML2 Assertion Audience Restriction validation failed"); } } else { throw new Exception("SAML2 Response doesn't contain AudienceRestrictions"); } } else { throw new Exception("SAML2 Response doesn't contain Conditions"); } } } private static void validateSignature(Response response, Assertion assertion, Boolean isResponseSigned, Boolean isAssertionSigned, Credential credential) throws Exception { if (isResponseSigned) { if (response.getSignature() == null) { throw new Exception("SAML2 Response signing is enabled, but signature element not found in SAML2 Response element"); } else { try { SignatureValidator.validate(response.getSignature(), credential); } catch (Exception e) { throw new Exception("Signature validation failed for SAML2 Response"); } } } if (isAssertionSigned) { if (assertion.getSignature() == null) { throw new Exception("SAML2 Assertion signing is enabled, but signature element not found in SAML2 Assertion element"); } else { try { SignatureValidator.validate(assertion.getSignature(), credential); } catch (Exception e) { throw new Exception("Signature validation failed for SAML2 Assertion"); } } } } private static Credential getCredential(CertificateInfo certificateInfo) throws KeyStoreException, IOException, CertificateException, NoSuchAlgorithmException, UnrecoverableKeyException { KeyStore ks = (certificateInfo.getKeyFormat().getType().equals("JKS")) ? KeyStore.getInstance("JKS") : KeyStore.getInstance("PKCS12"); String archivePassword = certificateInfo.getKeystorePassword(); char[] pwdArray = (archivePassword != null) ? archivePassword.toCharArray() : "changeit".toCharArray(); ks.load(new FileInputStream(certificateInfo.getKeystorePath()), pwdArray); X509Certificate cert = (X509Certificate)ks.getCertificate(certificateInfo.getAlias()); PrivateKey pk = (PrivateKey) ks.getKey(certificateInfo.getAlias(), certificateInfo.getPassword().toCharArray()); return new BasicX509Credential(cert, pk); } public static String getMetadata(Saml2ConfigurableProvider provider) throws Exception { EntityDescriptor spEntityDescriptor = (EntityDescriptor) buildXMLObject(EntityDescriptor.DEFAULT_ELEMENT_NAME); spEntityDescriptor.setEntityID(provider.getSpEntityId()); SPSSODescriptor spSSODescriptor = (SPSSODescriptor) buildXMLObject(SPSSODescriptor.DEFAULT_ELEMENT_NAME); spSSODescriptor.setWantAssertionsSigned(provider.isWantAssertionsSigned()); spSSODescriptor.setAuthnRequestsSigned(provider.isAuthnRequestsSigned()); X509KeyInfoGeneratorFactory keyInfoGeneratorFactory = new X509KeyInfoGeneratorFactory(); keyInfoGeneratorFactory.setEmitEntityCertificate(true); KeyInfoGenerator keyInfoGenerator = keyInfoGeneratorFactory.newInstance(); if (provider.isAssertionEncrypted()) { KeyDescriptor encKeyDescriptor = (KeyDescriptor) buildXMLObject(KeyDescriptor.DEFAULT_ELEMENT_NAME); encKeyDescriptor.setUse(UsageType.ENCRYPTION); //Set usage // Generating key info. The element will contain the public key. The key is used to by the IDP to encrypt data try { encKeyDescriptor.setKeyInfo(keyInfoGenerator.generate(getCredential(provider.getEncryptionCert()))); } catch (SecurityException e) { logger.error(e.getMessage(), e); } spSSODescriptor.getKeyDescriptors().add(encKeyDescriptor); } if (provider.isWantAssertionsSigned()) { KeyDescriptor signKeyDescriptor = (KeyDescriptor) buildXMLObject(KeyDescriptor.DEFAULT_ELEMENT_NAME); signKeyDescriptor.setUse(UsageType.SIGNING); //Set usage // Generating key info. The element will contain the public key. The key is used to by the IDP to verify signatures try { signKeyDescriptor.setKeyInfo(keyInfoGenerator.generate(getCredential(provider.getSigningCert()))); } catch (SecurityException e) { logger.error(e.getMessage(), e); } spSSODescriptor.getKeyDescriptors().add(signKeyDescriptor); } NameIDFormat nameIDFormat = (NameIDFormat) buildXMLObject(NameIDFormat.DEFAULT_ELEMENT_NAME); nameIDFormat.setFormat("urn:oasis:names:tc:SAML:2.0:nameid-format:transient"); spSSODescriptor.getNameIDFormats().add(nameIDFormat); AssertionConsumerService assertionConsumerService = (AssertionConsumerService) buildXMLObject(AssertionConsumerService.DEFAULT_ELEMENT_NAME); assertionConsumerService.setIndex(0); switch (provider.getBinding()) { case "Redirect": assertionConsumerService.setBinding(SAMLConstants.SAML2_REDIRECT_BINDING_URI); break; case "Artifact": assertionConsumerService.setBinding(SAMLConstants.SAML2_ARTIFACT_BINDING_URI); break; case "Post": assertionConsumerService.setBinding(SAMLConstants.SAML2_POST_BINDING_URI); break; } assertionConsumerService.setLocation(provider.getAssertionConsumerServiceUrl()); spSSODescriptor.getAssertionConsumerServices().add(assertionConsumerService); spSSODescriptor.addSupportedProtocol(SAMLConstants.SAML20P_NS); spEntityDescriptor.getRoleDescriptors().add(spSSODescriptor); String metadataXML = null; try { DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance(); DocumentBuilder builder = factory.newDocumentBuilder(); Document document = builder.newDocument(); Marshaller out = registry.getMarshallerFactory().getMarshaller(spEntityDescriptor); out.marshall(spEntityDescriptor, document); metadataXML = XmlBuilder.generateXml(document); } catch (MarshallingException | ParserConfigurationException e) { logger.error(e.getMessage(), e); } return metadataXML; } public static AuthnRequestModel getAuthnRequest(Saml2ConfigurableProvider provider) throws Exception { AuthnRequest authnRequest = buildAuthnRequest(provider); String relayState = "spId=" + provider.getSpEntityId() + "&configurableLoginId=" + provider.getConfigurableLoginId(); String authnRequestXml = null; String signatureBase64 = null; try { DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance(); DocumentBuilder builder = factory.newDocumentBuilder(); Document document = builder.newDocument(); Marshaller out = registry.getMarshallerFactory().getMarshaller(authnRequest); out.marshall(authnRequest, document); authnRequestXml = XmlBuilder.generateXml(document); if(provider.isAuthnRequestsSigned()) { signatureBase64 = buildSignature(authnRequestXml, relayState, provider.getSigningCert()); } } catch (MarshallingException | ParserConfigurationException e) { logger.error(e.getMessage(), e); } AuthnRequestModel authnRequestModel = new AuthnRequestModel(); authnRequestModel.setAuthnRequestXml(authnRequestXml); authnRequestModel.setRelayState(relayState); authnRequestModel.setAlgorithm("http://www.w3.org/2000/09/xmldsig#rsa-sha1"); authnRequestModel.setSignature(signatureBase64); return authnRequestModel; } private static String buildSignature(String authnRequest, String relayState, CertificateInfo signingCertInfo) throws Exception{ KeyStore ks = (signingCertInfo.getKeyFormat().getType().equals("JKS")) ? KeyStore.getInstance("JKS") : KeyStore.getInstance("PKCS12"); String archivePassword = signingCertInfo.getKeystorePassword(); char[] pwdArray = (archivePassword != null) ? archivePassword.toCharArray() : "changeit".toCharArray(); ks.load(new FileInputStream(signingCertInfo.getKeystorePath()), pwdArray); PrivateKey pk = (PrivateKey) ks.getKey(signingCertInfo.getAlias(), signingCertInfo.getPassword().toCharArray()); String signAlgorithm = "http://www.w3.org/2000/09/xmldsig#rsa-sha1"; String message = "SAMLRequest=" + URLEncoder.encode(authnRequest, "UTF-8") + "&RelayState=" + URLEncoder.encode(relayState, "UTF-8") + "&SigAlg=" + URLEncoder.encode(signAlgorithm, "UTF-8"); String signature = null; try{ signature = new String(org.apache.commons.codec.binary.Base64.encodeBase64(sign(message, pk)), StandardCharsets.UTF_8); } catch(InvalidKeyException | SignatureException | NoSuchAlgorithmException e){ logger.error(e.getMessage(), e); } return signature; } private static byte[] sign(String message, PrivateKey key) throws InvalidKeyException, SignatureException, NoSuchAlgorithmException { java.security.Signature instance = java.security.Signature.getInstance("SHA1withRSA"); instance.initSign(key); instance.update(message.getBytes()); return instance.sign(); } private static AuthnRequest buildAuthnRequest(Saml2ConfigurableProvider provider) throws Exception { AuthnRequest authnRequest = (AuthnRequest) buildXMLObject(AuthnRequest.DEFAULT_ELEMENT_NAME); authnRequest.setIssueInstant(Instant.now()); authnRequest.setDestination(provider.getIdpUrl()); switch (provider.getBinding()) { case "Redirect": authnRequest.setProtocolBinding(SAMLConstants.SAML2_REDIRECT_BINDING_URI); break; case "Artifact": authnRequest.setProtocolBinding(SAMLConstants.SAML2_ARTIFACT_BINDING_URI); break; case "Post": authnRequest.setProtocolBinding(SAMLConstants.SAML2_POST_BINDING_URI); break; } authnRequest.setAssertionConsumerServiceURL(provider.getAssertionConsumerServiceUrl()); authnRequest.setID('_' + UUID.randomUUID().toString()); authnRequest.setIssuer(buildIssuer(provider.getSpEntityId())); authnRequest.setNameIDPolicy(buildNameIdPolicy()); return authnRequest; } private static NameIDPolicy buildNameIdPolicy() throws Exception { NameIDPolicy nameIDPolicy = (NameIDPolicy) buildXMLObject(NameIDPolicy.DEFAULT_ELEMENT_NAME); nameIDPolicy.setAllowCreate(true); nameIDPolicy.setFormat(NameIDType.TRANSIENT); return nameIDPolicy; } private static Issuer buildIssuer(String spEntityId) throws Exception { Issuer issuer = (Issuer) buildXMLObject(Issuer.DEFAULT_ELEMENT_NAME); issuer.setValue(spEntityId); return issuer; } }