89 lines
3.3 KiB
Java
89 lines
3.3 KiB
Java
package eu.dnetlib.deeplearning.support;
|
|
|
|
import eu.dnetlib.featureextraction.Utilities;
|
|
import eu.dnetlib.support.Author;
|
|
import eu.dnetlib.support.ConnectedComponent;
|
|
import eu.dnetlib.support.Relation;
|
|
import org.apache.spark.api.java.JavaRDD;
|
|
import org.codehaus.jackson.map.ObjectMapper;
|
|
import org.jetbrains.annotations.NotNull;
|
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
import org.nd4j.linalg.dataset.MultiDataSet;
|
|
import org.nd4j.linalg.factory.Nd4j;
|
|
|
|
import java.io.IOException;
|
|
import java.util.*;
|
|
import java.util.stream.Collectors;
|
|
import java.util.stream.Stream;
|
|
|
|
public class DataSetProcessor {
|
|
|
|
public static JavaRDD<MultiDataSet> entityGroupToMultiDataset(JavaRDD<ConnectedComponent> groupEntity, String idJPath, String featureJPath, String groundTruthJPath) {
|
|
|
|
return groupEntity.map(g -> {
|
|
Map<String, double[]> featuresMap = new HashMap<>();
|
|
List<String> groundTruth = new ArrayList<>();
|
|
Set<String> entities = g.getDocs();
|
|
for(String json:entities) {
|
|
featuresMap.put(
|
|
Utilities.getJPathString(idJPath, json),
|
|
Utilities.getJPathArray(featureJPath, json)
|
|
);
|
|
groundTruth.add(Utilities.getJPathString(groundTruthJPath, json));
|
|
}
|
|
|
|
Set<Relation> relations = g.getSimrels();
|
|
|
|
return getMultiDataSet(featuresMap, relations, groundTruth);
|
|
});
|
|
}
|
|
|
|
public static MultiDataSet getMultiDataSet(Map<String, double[]> featuresMap, Set<Relation> relations, List<String> groundTruth) {
|
|
|
|
List<String> identifiers = new ArrayList<>(featuresMap.keySet());
|
|
|
|
int numNodes = identifiers.size();
|
|
|
|
//initialize arrays
|
|
INDArray adjacency = Nd4j.zeros(numNodes, numNodes);
|
|
INDArray features = Nd4j.zeros(numNodes, featuresMap.get(identifiers.get(0)).length); //feature size taken from the first element (it's equal for every element)
|
|
INDArray degree = Nd4j.zeros(numNodes, numNodes);
|
|
|
|
//create adjacency
|
|
for(Relation r: relations) {
|
|
adjacency.put(identifiers.indexOf(r.getSource()), identifiers.indexOf(r.getTarget()), 1);
|
|
adjacency.put(identifiers.indexOf(r.getTarget()), identifiers.indexOf(r.getSource()), 1);
|
|
}
|
|
adjacency.addi(Nd4j.eye(numNodes));
|
|
|
|
//create degree and features
|
|
List<String> degreeSupport = relations.stream().flatMap(r -> Stream.of(r.getSource(), r.getTarget())).collect(Collectors.toList());
|
|
for(int i=0; i< identifiers.size(); i++) {
|
|
degree.put(i, i, Collections.frequency(degreeSupport, identifiers.get(i)));
|
|
features.putRow(i, Nd4j.create(featuresMap.get(identifiers.get(i))));
|
|
}
|
|
|
|
//infer label
|
|
INDArray label = Nd4j.zeros(1, 2);
|
|
if (groundTruth.stream().distinct().count()==1) {
|
|
//correct (same elements)
|
|
label.put(0, 0, 1.0);
|
|
}
|
|
else {
|
|
//wrong (different elements)
|
|
label.put(0, 1, 1.0);
|
|
}
|
|
|
|
return new MultiDataSet(
|
|
new INDArray[]{
|
|
features,
|
|
adjacency,
|
|
degree
|
|
},
|
|
new INDArray[]{
|
|
label
|
|
}
|
|
);
|
|
}
|
|
}
|