package eu.dnetlib.deeplearning.support; import eu.dnetlib.featureextraction.Utilities; import eu.dnetlib.support.Author; import eu.dnetlib.support.ConnectedComponent; import eu.dnetlib.support.Relation; import org.apache.spark.api.java.JavaRDD; import org.codehaus.jackson.map.ObjectMapper; import org.jetbrains.annotations.NotNull; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.MultiDataSet; import org.nd4j.linalg.factory.Nd4j; import java.io.IOException; import java.util.*; import java.util.stream.Collectors; import java.util.stream.Stream; public class DataSetProcessor { public static JavaRDD entityGroupToMultiDataset(JavaRDD groupEntity, String idJPath, String featureJPath, String groundTruthJPath) { return groupEntity.map(g -> { Map featuresMap = new HashMap<>(); List groundTruth = new ArrayList<>(); Set entities = g.getDocs(); for(String json:entities) { featuresMap.put( Utilities.getJPathString(idJPath, json), Utilities.getJPathArray(featureJPath, json) ); groundTruth.add(Utilities.getJPathString(groundTruthJPath, json)); } Set relations = g.getSimrels(); return getMultiDataSet(featuresMap, relations, groundTruth); }); } public static MultiDataSet getMultiDataSet(Map featuresMap, Set relations, List groundTruth) { List identifiers = new ArrayList<>(featuresMap.keySet()); int numNodes = identifiers.size(); //initialize arrays INDArray adjacency = Nd4j.zeros(numNodes, numNodes); INDArray features = Nd4j.zeros(numNodes, featuresMap.get(identifiers.get(0)).length); //feature size taken from the first element (it's equal for every element) INDArray degree = Nd4j.zeros(numNodes, numNodes); //create adjacency for(Relation r: relations) { adjacency.put(identifiers.indexOf(r.getSource()), identifiers.indexOf(r.getTarget()), 1); adjacency.put(identifiers.indexOf(r.getTarget()), identifiers.indexOf(r.getSource()), 1); } adjacency.addi(Nd4j.eye(numNodes)); //create degree and features List degreeSupport = relations.stream().flatMap(r -> Stream.of(r.getSource(), r.getTarget())).collect(Collectors.toList()); for(int i=0; i< identifiers.size(); i++) { degree.put(i, i, Collections.frequency(degreeSupport, identifiers.get(i))); features.putRow(i, Nd4j.create(featuresMap.get(identifiers.get(i)))); } //infer label INDArray label = Nd4j.zeros(1, 2); if (groundTruth.stream().distinct().count()==1) { //correct (same elements) label.put(0, 0, 1.0); } else { //wrong (different elements) label.put(0, 1, 1.0); } return new MultiDataSet( new INDArray[]{ features, adjacency, degree }, new INDArray[]{ label } ); } }