package eu.dnetlib.deeplearning; import com.beust.jcommander.internal.Sets; import com.google.common.collect.Lists; import eu.dnetlib.deeplearning.support.DataSetProcessor; import eu.dnetlib.support.Relation; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.nd4j.linalg.dataset.MultiDataSet; import java.util.*; import java.util.stream.Collectors; public class DataSetProcessorTest { static Map features; static Set relations; static List groundTruth; @BeforeAll public static void init(){ //initialize example features features = new HashMap<>(); features.put("0", new double[]{0.0,0.0}); features.put("1", new double[]{1.0,1.0}); features.put("2", new double[]{2.0,2.0}); //initialize example relations relations = new HashSet<>(Lists.newArrayList( new Relation("0", "1", "simrel"), new Relation("1", "2", "simrel") )); //initialize example ground truth groundTruth = Lists.newArrayList("class1", "class1", "class2"); } @Test public void getMultiDataSetTest() throws Exception { MultiDataSet multiDataSet = DataSetProcessor.getMultiDataSet(features, relations, groundTruth); System.out.println("multiDataSet = " + multiDataSet); multiDataSet.asList(); } }