dnet-and/dnet-feature-extraction/src/main/java/eu/dnetlib/deeplearning/support/DataSetProcessor.java

89 lines
3.3 KiB
Java

package eu.dnetlib.deeplearning.support;
import eu.dnetlib.featureextraction.Utilities;
import eu.dnetlib.support.Author;
import eu.dnetlib.support.ConnectedComponent;
import eu.dnetlib.support.Relation;
import org.apache.spark.api.java.JavaRDD;
import org.codehaus.jackson.map.ObjectMapper;
import org.jetbrains.annotations.NotNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import java.io.IOException;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
 * Converts groups of entities (connected components) into ND4J {@link MultiDataSet}s made of
 * node features, an adjacency matrix and a degree matrix, labelled by ground-truth agreement.
 */
public class DataSetProcessor {

    /**
     * Maps every connected component of the RDD to a {@link MultiDataSet}.
     *
     * <p>For each JSON document of a component, the node identifier, the feature array and the
     * ground-truth value are extracted with the given JSON paths; the component's similarity
     * relations become the graph edges. The layout of the produced data set is described in
     * {@link #getMultiDataSet(Map, Set, List)}.
     *
     * @param groupEntity      connected components to convert
     * @param idJPath          JSON path of the node identifier inside each document
     * @param featureJPath     JSON path of the numeric feature array inside each document
     * @param groundTruthJPath JSON path of the ground-truth value inside each document
     * @return one {@link MultiDataSet} per connected component
     */
    public static JavaRDD<MultiDataSet> entityGroupToMultiDataset(JavaRDD<ConnectedComponent> groupEntity, String idJPath, String featureJPath, String groundTruthJPath) {
        return groupEntity.map(g -> {
            Map<String, double[]> featuresMap = new HashMap<>();
            List<String> groundTruth = new ArrayList<>();
            Set<String> entities = g.getDocs();
            for (String json : entities) {
                // NOTE(review): documents sharing the same identifier overwrite each other's
                // features here, while each of them still contributes a ground-truth entry.
                featuresMap.put(
                        Utilities.getJPathString(idJPath, json),
                        Utilities.getJPathArray(featureJPath, json)
                );
                groundTruth.add(Utilities.getJPathString(groundTruthJPath, json));
            }
            Set<Relation> relations = g.getSimrels();
            return getMultiDataSet(featuresMap, relations, groundTruth);
        });
    }

    /**
     * Builds a graph {@link MultiDataSet} from node features, edges and ground-truth values.
     *
     * <p>Features, in this order:
     * <ol>
     *   <li>{@code features}: {@code numNodes x featureSize} matrix, one row per node;</li>
     *   <li>{@code adjacency}: symmetric {@code numNodes x numNodes} matrix with 1 for every
     *       relation (both directions) plus the identity matrix (self-loops);</li>
     *   <li>{@code degree}: diagonal {@code numNodes x numNodes} matrix whose entry
     *       {@code (i, i)} is the number of relation endpoints equal to node {@code i}
     *       (the self-loops added to the adjacency are not counted).</li>
     * </ol>
     * Labels: a single {@code 1 x 2} one-hot row, {@code [1, 0]} when all ground-truth values
     * are equal and {@code [0, 1]} otherwise.
     *
     * <p>Node order follows the iteration order of {@code featuresMap.keySet()}.
     *
     * @param featuresMap node identifier to feature vector; all vectors must have the same length
     * @param relations   edges between node identifiers; both endpoints must be keys of
     *                    {@code featuresMap}
     * @param groundTruth ground-truth values of the nodes, used only to infer the label
     * @return the data set described above
     * @throws IllegalArgumentException if {@code featuresMap} is empty, if a relation references
     *                                  an identifier that is not a key of {@code featuresMap}, or
     *                                  if feature vectors have different lengths
     */
    public static MultiDataSet getMultiDataSet(Map<String, double[]> featuresMap, Set<Relation> relations, List<String> groundTruth) {
        if (featuresMap.isEmpty()) {
            throw new IllegalArgumentException("cannot build a MultiDataSet from an empty feature map");
        }
        List<String> identifiers = new ArrayList<>(featuresMap.keySet());
        int numNodes = identifiers.size();

        // identifier -> row/column index. Replaces List.indexOf, which costs O(n) per lookup and
        // returns -1 for unknown ids instead of failing.
        Map<String, Integer> indexById = new HashMap<>();
        for (int i = 0; i < numNodes; i++) {
            indexById.put(identifiers.get(i), i);
        }

        // feature size is taken from the first element; every vector is checked against it below
        int featureSize = featuresMap.get(identifiers.get(0)).length;

        //initialize arrays
        INDArray adjacency = Nd4j.zeros(numNodes, numNodes);
        INDArray features = Nd4j.zeros(numNodes, featureSize);
        INDArray degree = Nd4j.zeros(numNodes, numNodes);

        // create the symmetric adjacency and count relation endpoints per node (for the degree)
        int[] endpointCount = new int[numNodes];
        for (Relation r : relations) {
            int source = nodeIndex(indexById, r.getSource());
            int target = nodeIndex(indexById, r.getTarget());
            adjacency.put(source, target, 1);
            adjacency.put(target, source, 1);
            endpointCount[source]++;
            endpointCount[target]++;
        }
        adjacency.addi(Nd4j.eye(numNodes));

        //create degree and features
        for (int i = 0; i < numNodes; i++) {
            double[] vector = featuresMap.get(identifiers.get(i));
            if (vector.length != featureSize) {
                throw new IllegalArgumentException("feature vector of node " + identifiers.get(i)
                        + " has length " + vector.length + ", expected " + featureSize);
            }
            degree.put(i, i, endpointCount[i]);
            features.putRow(i, Nd4j.create(vector));
        }

        //infer label
        INDArray label = Nd4j.zeros(1, 2);
        if (groundTruth.stream().distinct().count() == 1) {
            //correct (same elements)
            label.put(0, 0, 1.0);
        }
        else {
            //wrong (different elements)
            label.put(0, 1, 1.0);
        }
        return new MultiDataSet(
                new INDArray[]{
                        features,
                        adjacency,
                        degree
                },
                new INDArray[]{
                        label
                }
        );
    }

    /** Returns the matrix index of {@code id}, failing fast if it is not a node of the group. */
    private static int nodeIndex(Map<String, Integer> indexById, String id) {
        Integer index = indexById.get(id);
        if (index == null) {
            throw new IllegalArgumentException("relation references unknown node identifier: " + id);
        }
        return index;
    }
}