From 1089f68d91a5ca034277a6cb960fd6fb2fe1d2ef Mon Sep 17 00:00:00 2001 From: Serafeim Chatzopoulos Date: Tue, 4 Jun 2024 18:08:41 +0300 Subject: [PATCH] Return BadRequest when unknown parameters are given --- .../api/controllers/DataSourceController.java | 23 ++++++-- .../controllers/OrganizationController.java | 23 ++++++-- .../api/controllers/ProjectController.java | 23 ++++++-- .../ResearchProductsController.java | 23 ++++++-- .../openaire/api/errors/RequestValidator.java | 57 +++++++++++++++++++ 5 files changed, 129 insertions(+), 20 deletions(-) create mode 100644 src/main/java/eu/openaire/api/errors/RequestValidator.java diff --git a/src/main/java/eu/openaire/api/controllers/DataSourceController.java b/src/main/java/eu/openaire/api/controllers/DataSourceController.java index 5fce9bb..2ac4b1a 100644 --- a/src/main/java/eu/openaire/api/controllers/DataSourceController.java +++ b/src/main/java/eu/openaire/api/controllers/DataSourceController.java @@ -4,6 +4,8 @@ import eu.dnetlib.dhp.oa.model.graph.Datasource; import eu.openaire.api.dto.request.DataSourceRequest; import eu.openaire.api.dto.response.SearchResponse; import eu.openaire.api.errors.ErrorResponse; +import eu.openaire.api.errors.RequestValidator; +import eu.openaire.api.errors.exceptions.BadRequestException; import eu.openaire.api.services.DataSourceService; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Parameter; @@ -15,10 +17,9 @@ import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.validation.Valid; import lombok.RequiredArgsConstructor; import org.springdoc.core.annotations.ParameterObject; -import org.springframework.web.bind.annotation.GetMapping; -import org.springframework.web.bind.annotation.PathVariable; -import org.springframework.web.bind.annotation.RequestMapping; -import org.springframework.web.bind.annotation.RestController; +import org.springframework.validation.BindingResult; +import org.springframework.web.bind.WebDataBinder; +import org.springframework.web.bind.annotation.*; @RestController @RequestMapping("/v2/dataSources") @@ -29,6 +30,13 @@ public class DataSourceController private final DataSourceService dataSourceService; + private final RequestValidator requestValidator; + + @InitBinder + protected void initBinder(WebDataBinder binder) { + binder.setValidator(requestValidator); + } + @Operation( summary = "Retrieve an data source by id", description = "Get a data source object by specifying its id." @@ -53,7 +61,12 @@ public class DataSourceController @ApiResponse(responseCode = "500", content = { @Content(schema = @Schema(implementation = ErrorResponse.class), mediaType = "application/json") }) }) @GetMapping(value = "") - public Object search(@Valid @ParameterObject final DataSourceRequest request) { + public Object search(@Valid @ParameterObject final DataSourceRequest request, BindingResult validationResult) { + + if (validationResult.hasErrors()) { + throw new BadRequestException(RequestValidator.getErrorMessage(validationResult.getAllErrors())); + } + return dataSourceService.search(request); } } diff --git a/src/main/java/eu/openaire/api/controllers/OrganizationController.java b/src/main/java/eu/openaire/api/controllers/OrganizationController.java index 62d7f58..8d7e6f5 100644 --- a/src/main/java/eu/openaire/api/controllers/OrganizationController.java +++ b/src/main/java/eu/openaire/api/controllers/OrganizationController.java @@ -4,6 +4,8 @@ import eu.dnetlib.dhp.oa.model.graph.Organization; import eu.openaire.api.dto.request.OrganizationRequest; import eu.openaire.api.dto.response.SearchResponse; import eu.openaire.api.errors.ErrorResponse; +import eu.openaire.api.errors.RequestValidator; +import eu.openaire.api.errors.exceptions.BadRequestException; import eu.openaire.api.services.OrganizationService; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Parameter; @@ -15,10 +17,9 @@ import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.validation.Valid; import lombok.RequiredArgsConstructor; import org.springdoc.core.annotations.ParameterObject; -import org.springframework.web.bind.annotation.GetMapping; -import org.springframework.web.bind.annotation.PathVariable; -import org.springframework.web.bind.annotation.RequestMapping; -import org.springframework.web.bind.annotation.RestController; +import org.springframework.validation.BindingResult; +import org.springframework.web.bind.WebDataBinder; +import org.springframework.web.bind.annotation.*; @RestController @RequestMapping("/v2/organizations") @@ -29,6 +30,13 @@ public class OrganizationController private final OrganizationService organizationService; + private final RequestValidator requestValidator; + + @InitBinder + protected void initBinder(WebDataBinder binder) { + binder.setValidator(requestValidator); + } + @Operation( summary = "Retrieve an organization by id", description = "Get a organization object by specifying its id." @@ -53,7 +61,12 @@ public class OrganizationController @ApiResponse(responseCode = "500", content = { @Content(schema = @Schema(implementation = ErrorResponse.class), mediaType = "application/json") }) }) @GetMapping(value = "") - public Object search(@Valid @ParameterObject final OrganizationRequest request) { + public Object search(@Valid @ParameterObject final OrganizationRequest request, BindingResult validationResult) { + + if (validationResult.hasErrors()) { + throw new BadRequestException(RequestValidator.getErrorMessage(validationResult.getAllErrors())); + } + return organizationService.search(request); } } diff --git a/src/main/java/eu/openaire/api/controllers/ProjectController.java b/src/main/java/eu/openaire/api/controllers/ProjectController.java index 48f03db..9c56dd4 100644 --- a/src/main/java/eu/openaire/api/controllers/ProjectController.java +++ b/src/main/java/eu/openaire/api/controllers/ProjectController.java @@ -4,6 +4,8 @@ import eu.dnetlib.dhp.oa.model.graph.Project; import eu.openaire.api.dto.request.ProjectRequest; import eu.openaire.api.dto.response.SearchResponse; import eu.openaire.api.errors.ErrorResponse; +import eu.openaire.api.errors.RequestValidator; +import eu.openaire.api.errors.exceptions.BadRequestException; import eu.openaire.api.services.ProjectService; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Parameter; @@ -15,10 +17,9 @@ import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.validation.Valid; import lombok.RequiredArgsConstructor; import org.springdoc.core.annotations.ParameterObject; -import org.springframework.web.bind.annotation.GetMapping; -import org.springframework.web.bind.annotation.PathVariable; -import org.springframework.web.bind.annotation.RequestMapping; -import org.springframework.web.bind.annotation.RestController; +import org.springframework.validation.BindingResult; +import org.springframework.web.bind.WebDataBinder; +import org.springframework.web.bind.annotation.*; @RestController @RequestMapping("/v2/projects") @@ -29,6 +30,13 @@ public class ProjectController private final ProjectService projectService; + private final RequestValidator requestValidator; + + @InitBinder + protected void initBinder(WebDataBinder binder) { + binder.setValidator(requestValidator); + } + @Operation( summary = "Retrieve a project by id", description = "Get a project object by specifying its id." @@ -53,7 +61,12 @@ public class ProjectController @ApiResponse(responseCode = "500", content = { @Content(schema = @Schema(implementation = ErrorResponse.class), mediaType = "application/json") }) }) @GetMapping(value = "") - public Object search(@Valid @ParameterObject final ProjectRequest request) { + public Object search(@Valid @ParameterObject final ProjectRequest request, BindingResult validationResult) { + + if (validationResult.hasErrors()) { + throw new BadRequestException(RequestValidator.getErrorMessage(validationResult.getAllErrors())); + } + return projectService.search(request); } } diff --git a/src/main/java/eu/openaire/api/controllers/ResearchProductsController.java b/src/main/java/eu/openaire/api/controllers/ResearchProductsController.java index 5d2c665..beec8e0 100644 --- a/src/main/java/eu/openaire/api/controllers/ResearchProductsController.java +++ b/src/main/java/eu/openaire/api/controllers/ResearchProductsController.java @@ -4,6 +4,8 @@ import eu.dnetlib.dhp.oa.model.graph.GraphResult; import eu.openaire.api.dto.request.ResearchProductsRequest; import eu.openaire.api.dto.response.SearchResponse; import eu.openaire.api.errors.ErrorResponse; +import eu.openaire.api.errors.RequestValidator; +import eu.openaire.api.errors.exceptions.BadRequestException; import eu.openaire.api.services.ResearchProductService; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Parameter; @@ -15,10 +17,9 @@ import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.validation.Valid; import lombok.RequiredArgsConstructor; import org.springdoc.core.annotations.ParameterObject; -import org.springframework.web.bind.annotation.GetMapping; -import org.springframework.web.bind.annotation.PathVariable; -import org.springframework.web.bind.annotation.RequestMapping; -import org.springframework.web.bind.annotation.RestController; +import org.springframework.validation.BindingResult; +import org.springframework.web.bind.WebDataBinder; +import org.springframework.web.bind.annotation.*; @RestController @RequestMapping("/v2/researchProducts") @@ -29,6 +30,13 @@ public class ResearchProductsController private final ResearchProductService researchProductService; + private final RequestValidator requestValidator; + + @InitBinder + protected void initBinder(WebDataBinder binder) { + binder.setValidator(requestValidator); + } + @Operation( summary = "Retrieve a research product by id", description = "Get a research product object by specifying its id." @@ -53,7 +61,12 @@ public class ResearchProductsController @ApiResponse(responseCode = "500", content = { @Content(schema = @Schema(implementation = ErrorResponse.class), mediaType = "application/json") }) }) @GetMapping(value = "") - public Object search(@Valid @ParameterObject final ResearchProductsRequest request) { + public Object search(@Valid @ParameterObject final ResearchProductsRequest request, BindingResult validationResult) { + + if (validationResult.hasErrors()) { + throw new BadRequestException(RequestValidator.getErrorMessage(validationResult.getAllErrors())); + } + return researchProductService.search(request); } } diff --git a/src/main/java/eu/openaire/api/errors/RequestValidator.java b/src/main/java/eu/openaire/api/errors/RequestValidator.java new file mode 100644 index 0000000..f45a7ba --- /dev/null +++ b/src/main/java/eu/openaire/api/errors/RequestValidator.java @@ -0,0 +1,57 @@ +package eu.openaire.api.errors; + +import eu.openaire.api.dto.request.DataSourceRequest; +import eu.openaire.api.dto.request.OrganizationRequest; +import eu.openaire.api.dto.request.ProjectRequest; +import eu.openaire.api.dto.request.ResearchProductsRequest; +import jakarta.servlet.http.HttpServletRequest; +import org.springframework.context.support.DefaultMessageSourceResolvable; +import org.springframework.stereotype.Component; +import org.springframework.validation.Errors; +import org.springframework.validation.ObjectError; +import org.springframework.validation.Validator; + +import java.lang.reflect.Field; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +@Component +public class RequestValidator implements Validator { + + private final HttpServletRequest request; + + public RequestValidator(HttpServletRequest request) { + this.request = request; + } + + @Override + public boolean supports(Class clazz) { + // Add any classes you want this validator to support + return ResearchProductsRequest.class.isAssignableFrom(clazz) + || OrganizationRequest.class.isAssignableFrom(clazz) + || DataSourceRequest.class.isAssignableFrom(clazz) + || ProjectRequest.class.isAssignableFrom(clazz); + } + + @Override + public void validate(Object target, Errors errors) { + Set validParams = Stream.of(target.getClass().getDeclaredFields()) + .map(Field::getName) + .collect(Collectors.toSet()); + + // reject any parameters that are not in the allowed request object + request.getParameterMap().forEach((key, value) -> { + if (!validParams.contains(key)) { + errors.reject(key, "Unknown parameter: " + key + "; valid parameters are: " + validParams); + } + }); + } + + public static String getErrorMessage(List allErrors) { + return allErrors.stream() + .map(DefaultMessageSourceResolvable::getDefaultMessage) + .collect(Collectors.joining("\n")); + } +}