48 lines
1.4 KiB
Java
48 lines
1.4 KiB
Java
|
package eu.dnetlib.deeplearning;
|
||
|
|
||
|
import com.beust.jcommander.internal.Sets;
|
||
|
import com.google.common.collect.Lists;
|
||
|
import eu.dnetlib.deeplearning.support.DataSetProcessor;
|
||
|
import eu.dnetlib.support.Relation;
|
||
|
import org.junit.jupiter.api.BeforeAll;
|
||
|
import org.junit.jupiter.api.Test;
|
||
|
import org.nd4j.linalg.dataset.MultiDataSet;
|
||
|
|
||
|
import java.util.*;
|
||
|
import java.util.stream.Collectors;
|
||
|
|
||
|
public class DataSetProcessorTest {
|
||
|
|
||
|
static Map<String, double[]> features;
|
||
|
static Set<Relation> relations;
|
||
|
static List<String> groundTruth;
|
||
|
|
||
|
@BeforeAll
|
||
|
public static void init(){
|
||
|
//initialize example features
|
||
|
features = new HashMap<>();
|
||
|
features.put("0", new double[]{0.0,0.0});
|
||
|
features.put("1", new double[]{1.0,1.0});
|
||
|
features.put("2", new double[]{2.0,2.0});
|
||
|
|
||
|
//initialize example relations
|
||
|
relations = new HashSet<>(Lists.newArrayList(
|
||
|
new Relation("0", "1", "simrel"),
|
||
|
new Relation("1", "2", "simrel")
|
||
|
));
|
||
|
|
||
|
//initialize example ground truth
|
||
|
groundTruth = Lists.newArrayList("class1", "class1", "class2");
|
||
|
|
||
|
}
|
||
|
|
||
|
@Test
|
||
|
public void getMultiDataSetTest() throws Exception {
|
||
|
MultiDataSet multiDataSet = DataSetProcessor.getMultiDataSet(features, relations, groundTruth);
|
||
|
System.out.println("multiDataSet = " + multiDataSet);
|
||
|
|
||
|
multiDataSet.asList();
|
||
|
}
|
||
|
|
||
|
}
|