Return BadRequest when unknown parameters are given

This commit is contained in:
Serafeim Chatzopoulos 2024-06-04 18:08:41 +03:00
parent af48b4e4b5
commit 1089f68d91
5 changed files with 129 additions and 20 deletions

View File

@ -4,6 +4,8 @@ import eu.dnetlib.dhp.oa.model.graph.Datasource;
import eu.openaire.api.dto.request.DataSourceRequest;
import eu.openaire.api.dto.response.SearchResponse;
import eu.openaire.api.errors.ErrorResponse;
import eu.openaire.api.errors.RequestValidator;
import eu.openaire.api.errors.exceptions.BadRequestException;
import eu.openaire.api.services.DataSourceService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
@ -15,10 +17,9 @@ import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.validation.Valid;
import lombok.RequiredArgsConstructor;
import org.springdoc.core.annotations.ParameterObject;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.validation.BindingResult;
import org.springframework.web.bind.WebDataBinder;
import org.springframework.web.bind.annotation.*;
@RestController
@RequestMapping("/v2/dataSources")
@ -29,6 +30,13 @@ public class DataSourceController
private final DataSourceService dataSourceService;
private final RequestValidator requestValidator;
@InitBinder
protected void initBinder(WebDataBinder binder) {
binder.setValidator(requestValidator);
}
@Operation(
summary = "Retrieve an data source by id",
description = "Get a data source object by specifying its id."
@ -53,7 +61,12 @@ public class DataSourceController
@ApiResponse(responseCode = "500", content = { @Content(schema = @Schema(implementation = ErrorResponse.class), mediaType = "application/json") })
})
@GetMapping(value = "")
public Object search(@Valid @ParameterObject final DataSourceRequest request) {
public Object search(@Valid @ParameterObject final DataSourceRequest request, BindingResult validationResult) {
if (validationResult.hasErrors()) {
throw new BadRequestException(RequestValidator.getErrorMessage(validationResult.getAllErrors()));
}
return dataSourceService.search(request);
}
}

View File

@ -4,6 +4,8 @@ import eu.dnetlib.dhp.oa.model.graph.Organization;
import eu.openaire.api.dto.request.OrganizationRequest;
import eu.openaire.api.dto.response.SearchResponse;
import eu.openaire.api.errors.ErrorResponse;
import eu.openaire.api.errors.RequestValidator;
import eu.openaire.api.errors.exceptions.BadRequestException;
import eu.openaire.api.services.OrganizationService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
@ -15,10 +17,9 @@ import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.validation.Valid;
import lombok.RequiredArgsConstructor;
import org.springdoc.core.annotations.ParameterObject;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.validation.BindingResult;
import org.springframework.web.bind.WebDataBinder;
import org.springframework.web.bind.annotation.*;
@RestController
@RequestMapping("/v2/organizations")
@ -29,6 +30,13 @@ public class OrganizationController
private final OrganizationService organizationService;
private final RequestValidator requestValidator;
@InitBinder
protected void initBinder(WebDataBinder binder) {
binder.setValidator(requestValidator);
}
@Operation(
summary = "Retrieve an organization by id",
description = "Get a organization object by specifying its id."
@ -53,7 +61,12 @@ public class OrganizationController
@ApiResponse(responseCode = "500", content = { @Content(schema = @Schema(implementation = ErrorResponse.class), mediaType = "application/json") })
})
@GetMapping(value = "")
public Object search(@Valid @ParameterObject final OrganizationRequest request) {
public Object search(@Valid @ParameterObject final OrganizationRequest request, BindingResult validationResult) {
if (validationResult.hasErrors()) {
throw new BadRequestException(RequestValidator.getErrorMessage(validationResult.getAllErrors()));
}
return organizationService.search(request);
}
}

View File

@ -4,6 +4,8 @@ import eu.dnetlib.dhp.oa.model.graph.Project;
import eu.openaire.api.dto.request.ProjectRequest;
import eu.openaire.api.dto.response.SearchResponse;
import eu.openaire.api.errors.ErrorResponse;
import eu.openaire.api.errors.RequestValidator;
import eu.openaire.api.errors.exceptions.BadRequestException;
import eu.openaire.api.services.ProjectService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
@ -15,10 +17,9 @@ import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.validation.Valid;
import lombok.RequiredArgsConstructor;
import org.springdoc.core.annotations.ParameterObject;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.validation.BindingResult;
import org.springframework.web.bind.WebDataBinder;
import org.springframework.web.bind.annotation.*;
@RestController
@RequestMapping("/v2/projects")
@ -29,6 +30,13 @@ public class ProjectController
private final ProjectService projectService;
private final RequestValidator requestValidator;
@InitBinder
protected void initBinder(WebDataBinder binder) {
binder.setValidator(requestValidator);
}
@Operation(
summary = "Retrieve a project by id",
description = "Get a project object by specifying its id."
@ -53,7 +61,12 @@ public class ProjectController
@ApiResponse(responseCode = "500", content = { @Content(schema = @Schema(implementation = ErrorResponse.class), mediaType = "application/json") })
})
@GetMapping(value = "")
public Object search(@Valid @ParameterObject final ProjectRequest request) {
public Object search(@Valid @ParameterObject final ProjectRequest request, BindingResult validationResult) {
if (validationResult.hasErrors()) {
throw new BadRequestException(RequestValidator.getErrorMessage(validationResult.getAllErrors()));
}
return projectService.search(request);
}
}

View File

@ -4,6 +4,8 @@ import eu.dnetlib.dhp.oa.model.graph.GraphResult;
import eu.openaire.api.dto.request.ResearchProductsRequest;
import eu.openaire.api.dto.response.SearchResponse;
import eu.openaire.api.errors.ErrorResponse;
import eu.openaire.api.errors.RequestValidator;
import eu.openaire.api.errors.exceptions.BadRequestException;
import eu.openaire.api.services.ResearchProductService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
@ -15,10 +17,9 @@ import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.validation.Valid;
import lombok.RequiredArgsConstructor;
import org.springdoc.core.annotations.ParameterObject;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.validation.BindingResult;
import org.springframework.web.bind.WebDataBinder;
import org.springframework.web.bind.annotation.*;
@RestController
@RequestMapping("/v2/researchProducts")
@ -29,6 +30,13 @@ public class ResearchProductsController
private final ResearchProductService researchProductService;
private final RequestValidator requestValidator;
@InitBinder
protected void initBinder(WebDataBinder binder) {
binder.setValidator(requestValidator);
}
@Operation(
summary = "Retrieve a research product by id",
description = "Get a research product object by specifying its id."
@ -53,7 +61,12 @@ public class ResearchProductsController
@ApiResponse(responseCode = "500", content = { @Content(schema = @Schema(implementation = ErrorResponse.class), mediaType = "application/json") })
})
@GetMapping(value = "")
public Object search(@Valid @ParameterObject final ResearchProductsRequest request) {
public Object search(@Valid @ParameterObject final ResearchProductsRequest request, BindingResult validationResult) {
if (validationResult.hasErrors()) {
throw new BadRequestException(RequestValidator.getErrorMessage(validationResult.getAllErrors()));
}
return researchProductService.search(request);
}
}

View File

@ -0,0 +1,57 @@
package eu.openaire.api.errors;
import eu.openaire.api.dto.request.DataSourceRequest;
import eu.openaire.api.dto.request.OrganizationRequest;
import eu.openaire.api.dto.request.ProjectRequest;
import eu.openaire.api.dto.request.ResearchProductsRequest;
import jakarta.servlet.http.HttpServletRequest;
import org.springframework.context.support.DefaultMessageSourceResolvable;
import org.springframework.stereotype.Component;
import org.springframework.validation.Errors;
import org.springframework.validation.ObjectError;
import org.springframework.validation.Validator;
import java.lang.reflect.Field;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@Component
public class RequestValidator implements Validator {
private final HttpServletRequest request;
public RequestValidator(HttpServletRequest request) {
this.request = request;
}
@Override
public boolean supports(Class<?> clazz) {
// Add any classes you want this validator to support
return ResearchProductsRequest.class.isAssignableFrom(clazz)
|| OrganizationRequest.class.isAssignableFrom(clazz)
|| DataSourceRequest.class.isAssignableFrom(clazz)
|| ProjectRequest.class.isAssignableFrom(clazz);
}
@Override
public void validate(Object target, Errors errors) {
Set<String> validParams = Stream.of(target.getClass().getDeclaredFields())
.map(Field::getName)
.collect(Collectors.toSet());
// reject any parameters that are not in the allowed request object
request.getParameterMap().forEach((key, value) -> {
if (!validParams.contains(key)) {
errors.reject(key, "Unknown parameter: " + key + "; valid parameters are: " + validParams);
}
});
}
public static String getErrorMessage(List<ObjectError> allErrors) {
return allErrors.stream()
.map(DefaultMessageSourceResolvable::getDefaultMessage)
.collect(Collectors.joining("\n"));
}
}