2023-07-06 10:28:53 +02:00
|
|
|
|
2023-04-17 11:06:27 +02:00
|
|
|
package eu.dnetlib.pace.tree;
|
|
|
|
|
2023-07-06 10:28:53 +02:00
|
|
|
import java.util.HashMap;
|
|
|
|
import java.util.List;
|
|
|
|
import java.util.Map;
|
|
|
|
import java.util.stream.Collectors;
|
|
|
|
|
2023-04-17 11:06:27 +02:00
|
|
|
import eu.dnetlib.pace.config.Config;
|
|
|
|
import eu.dnetlib.pace.model.Field;
|
|
|
|
import eu.dnetlib.pace.model.FieldList;
|
|
|
|
import eu.dnetlib.pace.model.FieldValueImpl;
|
|
|
|
import eu.dnetlib.pace.model.Person;
|
|
|
|
import eu.dnetlib.pace.tree.support.AbstractComparator;
|
|
|
|
import eu.dnetlib.pace.tree.support.ComparatorClass;
|
|
|
|
|
|
|
|
@ComparatorClass("cosineSimilarity")
|
|
|
|
public class CosineSimilarity extends AbstractComparator {
|
|
|
|
|
2023-07-06 10:28:53 +02:00
|
|
|
Map<String, String> params;
|
2023-04-17 11:06:27 +02:00
|
|
|
|
2023-07-06 10:28:53 +02:00
|
|
|
public CosineSimilarity(Map<String, String> params) {
|
|
|
|
super(params);
|
|
|
|
}
|
2023-04-17 11:06:27 +02:00
|
|
|
|
2023-07-06 10:28:53 +02:00
|
|
|
@Override
|
|
|
|
public double compare(final Field a, final Field b, final Config conf) {
|
2023-04-17 11:06:27 +02:00
|
|
|
|
2023-07-06 10:28:53 +02:00
|
|
|
if (a.isEmpty() || b.isEmpty())
|
|
|
|
return -1;
|
2023-04-17 11:06:27 +02:00
|
|
|
|
2023-07-06 10:28:53 +02:00
|
|
|
double[] aVector = ((FieldValueImpl) a).doubleArrayValue();
|
|
|
|
double[] bVector = ((FieldValueImpl) b).doubleArrayValue();
|
2023-04-17 11:06:27 +02:00
|
|
|
|
2023-07-06 10:28:53 +02:00
|
|
|
return cosineSimilarity(aVector, bVector);
|
|
|
|
}
|
2023-04-17 11:06:27 +02:00
|
|
|
|
2023-07-06 10:28:53 +02:00
|
|
|
double cosineSimilarity(double[] a, double[] b) {
|
|
|
|
double dotProduct = 0;
|
|
|
|
double normASum = 0;
|
|
|
|
double normBSum = 0;
|
2023-04-17 11:06:27 +02:00
|
|
|
|
2023-07-06 10:28:53 +02:00
|
|
|
for (int i = 0; i < a.length; i++) {
|
|
|
|
dotProduct += a[i] * b[i];
|
|
|
|
normASum += a[i] * a[i];
|
|
|
|
normBSum += b[i] * b[i];
|
|
|
|
}
|
2023-04-17 11:06:27 +02:00
|
|
|
|
2023-07-06 10:28:53 +02:00
|
|
|
double eucledianDist = Math.sqrt(normASum) * Math.sqrt(normBSum);
|
|
|
|
return dotProduct / eucledianDist;
|
|
|
|
}
|
2023-04-17 11:06:27 +02:00
|
|
|
|
|
|
|
}
|