diff --git a/src/main/java/eu/dnetlib/uoaauthorizationlibrary/security/AuthorizationService.java b/src/main/java/eu/dnetlib/uoaauthorizationlibrary/security/AuthorizationService.java index 2c23b41..7eec0f1 100644 --- a/src/main/java/eu/dnetlib/uoaauthorizationlibrary/security/AuthorizationService.java +++ b/src/main/java/eu/dnetlib/uoaauthorizationlibrary/security/AuthorizationService.java @@ -1,6 +1,7 @@ package eu.dnetlib.uoaauthorizationlibrary.security; import org.apache.log4j.Logger; +import org.springframework.security.core.Authentication; import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.stereotype.Component; @@ -69,7 +70,7 @@ public class AuthorizationService { } public List getRoles() { - OpenAIREAuthentication authentication = (OpenAIREAuthentication) SecurityContextHolder.getContext().getAuthentication(); + OpenAIREAuthentication authentication = getAuthentication(); if (authentication != null && authentication.isAuthenticated()) { return authentication.getAuthorities().stream().map(GrantedAuthority::getAuthority).collect(Collectors.toList()); } @@ -77,7 +78,7 @@ public class AuthorizationService { } public String getAaiId() { - OpenAIREAuthentication authentication = (OpenAIREAuthentication) SecurityContextHolder.getContext().getAuthentication(); + OpenAIREAuthentication authentication = getAuthentication(); if (authentication != null && authentication.isAuthenticated()) { return authentication.getUser().getSub(); } @@ -85,10 +86,19 @@ public class AuthorizationService { } public String getEmail() { - OpenAIREAuthentication authentication = (OpenAIREAuthentication) SecurityContextHolder.getContext().getAuthentication(); + OpenAIREAuthentication authentication = getAuthentication(); if (authentication != null && authentication.isAuthenticated()) { return authentication.getUser().getEmail(); } return null; } + + private OpenAIREAuthentication getAuthentication() { + Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); + if(authentication instanceof OpenAIREAuthentication) { + return (OpenAIREAuthentication) authentication; + } else { + return null; + } + } } diff --git a/src/main/java/eu/dnetlib/uoaauthorizationlibrary/utils/AuthorizationUtils.java b/src/main/java/eu/dnetlib/uoaauthorizationlibrary/utils/AuthorizationUtils.java index 45b72d0..8add38c 100644 --- a/src/main/java/eu/dnetlib/uoaauthorizationlibrary/utils/AuthorizationUtils.java +++ b/src/main/java/eu/dnetlib/uoaauthorizationlibrary/utils/AuthorizationUtils.java @@ -10,12 +10,15 @@ import org.springframework.web.client.RestTemplate; import javax.servlet.http.Cookie; import javax.servlet.http.HttpServletRequest; +import java.util.Arrays; import java.util.Collections; @Component public class AuthorizationUtils { private final Logger log = Logger.getLogger(this.getClass()); private final SecurityConfig securityConfig; + private final static String TOKEN = "AccessToken"; + private final static String SESSION = "OpenAIRESession"; @Autowired AuthorizationUtils(SecurityConfig securityConfig) { @@ -27,10 +30,9 @@ public class AuthorizationUtils { return null; } for (Cookie c : request.getCookies()) { - if (c.getName().equals("AccessToken")) { + if (c.getName().equals(TOKEN)) { return c.getValue(); } - } return null; } @@ -39,11 +41,23 @@ public class AuthorizationUtils { String url = securityConfig.getUserInfoUrl() + (securityConfig.isDeprecated()?getToken(request):""); RestTemplate restTemplate = new RestTemplate(); try { - ResponseEntity response = restTemplate.exchange(url, HttpMethod.GET, createHeaders(request), UserInfo.class); - return response.getBody(); - } catch (RestClientException e) { - log.error(e.getMessage()); + if(hasCookie(request)) { + ResponseEntity response = restTemplate.exchange(url, HttpMethod.GET, createHeaders(request), UserInfo.class); + return response.getBody(); + } return null; + } catch (RestClientException e) { + log.error(url + ":" + e.getMessage()); + return null; + } + } + + private boolean hasCookie(HttpServletRequest request) { + Cookie[] cookies = request.getCookies(); + if(securityConfig.isDeprecated()) { + return Arrays.stream(cookies).anyMatch(cookie -> cookie.getName().equalsIgnoreCase(TOKEN)); + } else { + return Arrays.stream(cookies).anyMatch(cookie -> cookie.getName().equalsIgnoreCase(SESSION)); } }