// dnet-dedup/dnet-pace-core/src/main/java/eu/dnetlib/pace/util/PaceResolver.java
//
// 84 lines
// 4.3 KiB
// Java

package eu.dnetlib.pace.util;
import eu.dnetlib.pace.clustering.ClusteringClass;
import eu.dnetlib.pace.clustering.ClusteringFunction;
import eu.dnetlib.pace.condition.ConditionAlgo;
import eu.dnetlib.pace.condition.ConditionClass;
import eu.dnetlib.pace.distance.DistanceAlgo;
import eu.dnetlib.pace.distance.DistanceClass;
import eu.dnetlib.pace.model.FieldDef;
import eu.dnetlib.pace.tree.support.Comparator;
import eu.dnetlib.pace.tree.support.ComparatorClass;
import org.reflections.Reflections;
import java.io.Serializable;
import java.lang.reflect.InvocationTargetException;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
 * Resolves symbolic names to PACE implementation classes and instantiates them reflectively.
 *
 * <p>On construction, each of the four static {@link Reflections} scanners is asked for the types
 * carrying the corresponding annotation ({@link ClusteringClass}, {@link ConditionClass},
 * {@link DistanceClass}, {@link ComparatorClass}). Types that are not assignable to the expected
 * base type are dropped, and the remaining ones are indexed by the annotation's {@code value()}.
 *
 * <p>Every {@code get*} method reports an unknown name, as well as any reflective instantiation
 * failure, as a {@link PaceException} whose message is {@code name + " not found "}.
 */
public class PaceResolver implements Serializable {

    /** Scanner for the package {@code eu.dnetlib.pace.clustering}. */
    public static final Reflections CLUSTERING_RESOLVER = new Reflections("eu.dnetlib.pace.clustering");

    /** Scanner for the package {@code eu.dnetlib.pace.condition}. */
    public static final Reflections CONDITION_RESOLVER = new Reflections("eu.dnetlib.pace.condition");

    /**
     * Scanner for the package {@code eu.dnetlib.pace.compare.algo}.
     *
     * <p>NOTE(review): {@link DistanceAlgo} and {@link DistanceClass} are imported from
     * {@code eu.dnetlib.pace.distance}, not from the package scanned here; confirm that the
     * annotated implementations really live under {@code eu.dnetlib.pace.compare.algo}.
     */
    public static final Reflections DISTANCE_RESOLVER = new Reflections("eu.dnetlib.pace.compare.algo");

    /** Scanner for the package {@code eu.dnetlib.pace.tree}. */
    public static final Reflections COMPARATOR_RESOLVER = new Reflections("eu.dnetlib.pace.tree");

    private final Map<String, Class<ClusteringFunction>> clusteringFunctions;
    private final Map<String, Class<ConditionAlgo>> conditionAlgos;
    private final Map<String, Class<DistanceAlgo>> distanceAlgos;
    private final Map<String, Class<Comparator>> comparators;

    /**
     * Builds the four name-to-class registries from the static scanners.
     *
     * <p>{@link Collectors#toMap} is used without a merge function, so two classes of the same
     * kind declaring the same annotation value make construction fail with an
     * {@link IllegalStateException}.
     */
    // The casts below are unchecked only because of erasure: every element has already passed the
    // isAssignableFrom filter for the target base type.
    @SuppressWarnings("unchecked")
    public PaceResolver() {
        this.clusteringFunctions = CLUSTERING_RESOLVER.getTypesAnnotatedWith(ClusteringClass.class).stream()
                .filter(ClusteringFunction.class::isAssignableFrom)
                .collect(Collectors.toMap(
                        cl -> cl.getAnnotation(ClusteringClass.class).value(),
                        cl -> (Class<ClusteringFunction>) cl));

        this.conditionAlgos = CONDITION_RESOLVER.getTypesAnnotatedWith(ConditionClass.class).stream()
                .filter(ConditionAlgo.class::isAssignableFrom)
                .collect(Collectors.toMap(
                        cl -> cl.getAnnotation(ConditionClass.class).value(),
                        cl -> (Class<ConditionAlgo>) cl));

        this.distanceAlgos = DISTANCE_RESOLVER.getTypesAnnotatedWith(DistanceClass.class).stream()
                .filter(DistanceAlgo.class::isAssignableFrom)
                .collect(Collectors.toMap(
                        cl -> cl.getAnnotation(DistanceClass.class).value(),
                        cl -> (Class<DistanceAlgo>) cl));

        this.comparators = COMPARATOR_RESOLVER.getTypesAnnotatedWith(ComparatorClass.class).stream()
                .filter(Comparator.class::isAssignableFrom)
                .collect(Collectors.toMap(
                        cl -> cl.getAnnotation(ComparatorClass.class).value(),
                        cl -> (Class<Comparator>) cl));
    }

    /**
     * Instantiates the clustering function registered under {@code name} through its
     * {@code (Map)} constructor.
     *
     * @param name   annotation value identifying the implementation
     * @param params argument handed to the implementation's constructor
     * @return a new instance of the registered class
     * @throws PaceException if no class is registered under {@code name} or it cannot be instantiated
     */
    public ClusteringFunction getClusteringFunction(String name, Map<String, Integer> params) throws PaceException {
        final Class<ClusteringFunction> type = lookup(clusteringFunctions, name);
        try {
            return type.getDeclaredConstructor(Map.class).newInstance(params);
        } catch (ReflectiveOperationException e) {
            throw new PaceException(name + " not found ", e);
        }
    }

    /**
     * Instantiates the distance algorithm registered under {@code name} through its
     * {@code (Map)} constructor.
     *
     * @param name   annotation value identifying the implementation
     * @param params argument handed to the implementation's constructor
     * @return a new instance of the registered class
     * @throws PaceException if no class is registered under {@code name} or it cannot be instantiated
     */
    public DistanceAlgo getDistanceAlgo(String name, Map<String, Number> params) throws PaceException {
        final Class<DistanceAlgo> type = lookup(distanceAlgos, name);
        try {
            return type.getDeclaredConstructor(Map.class).newInstance(params);
        } catch (ReflectiveOperationException e) {
            throw new PaceException(name + " not found ", e);
        }
    }

    /**
     * Instantiates the condition algorithm registered under {@code name} through its
     * {@code (String, List)} constructor; {@code name} itself is passed as the first argument.
     *
     * @param name   annotation value identifying the implementation
     * @param fields second argument handed to the implementation's constructor
     * @return a new instance of the registered class
     * @throws PaceException if no class is registered under {@code name} or it cannot be instantiated
     */
    public ConditionAlgo getConditionAlgo(String name, List<FieldDef> fields) throws PaceException {
        final Class<ConditionAlgo> type = lookup(conditionAlgos, name);
        try {
            return type.getDeclaredConstructor(String.class, List.class).newInstance(name, fields);
        } catch (ReflectiveOperationException e) {
            throw new PaceException(name + " not found ", e);
        }
    }

    /**
     * Instantiates the comparator registered under {@code name} through its {@code (Map)}
     * constructor.
     *
     * @param name   annotation value identifying the implementation
     * @param params argument handed to the implementation's constructor
     * @return a new instance of the registered class
     * @throws PaceException if no class is registered under {@code name} or it cannot be instantiated
     */
    public Comparator getComparator(String name, Map<String, Number> params) throws PaceException {
        final Class<Comparator> type = lookup(comparators, name);
        try {
            return type.getDeclaredConstructor(Map.class).newInstance(params);
        } catch (ReflectiveOperationException e) {
            throw new PaceException(name + " not found ", e);
        }
    }

    /**
     * Fetches the class registered under {@code name}, turning a missing entry into a
     * {@link PaceException} instead of letting a {@link NullPointerException} escape later.
     */
    private static <T> Class<T> lookup(Map<String, Class<T>> registry, String name) throws PaceException {
        final Class<T> type = registry.get(name);
        if (type == null) {
            throw new PaceException(
                    name + " not found ",
                    new IllegalArgumentException("no class registered under name '" + name + "'"));
        }
        return type;
    }
}