/*
 * dnet-and/dnet-feature-extraction/src/test/java/eu/dnetlib/deeplearning/NetworkConfigurationTests.java
 *
 * 34 lines
 * 1.0 KiB
 * Java
 */

package eu.dnetlib.deeplearning;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

import eu.dnetlib.deeplearning.support.NetworkConfigurations;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
/**
 * Smoke tests for the network configurations built by {@link NetworkConfigurations}.
 */
public class NetworkConfigurationTests {

    /** Number of nodes in the example graph. */
    public static final int N = 3;

    /** Number of features per node in the example graph. */
    public static final int K = 7;

    /**
     * Example graph fed to the network as its three inputs, in this order:
     * node features ({@code N x K}, all zeros), adjacency matrix ({@code N x N}, all ones)
     * and degree matrix ({@code N x N}, all ones).
     *
     * <p>NOTE(review): a degree matrix is normally diagonal (for an all-ones adjacency every
     * node has degree N, i.e. {@code N * I}); confirm which form {@code getSimpleGCN} expects.
     */
    public static final INDArray[] exampleGraph = new INDArray[]{
            Nd4j.zeros(N, K), // features
            Nd4j.ones(N, N),  // adjacency
            Nd4j.ones(N, N)   // degree
    };

    /**
     * Builds the simple GCN configuration, initialises the network and runs a forward pass on
     * {@link #exampleGraph}. Checks that the network produces an output with one row per node.
     */
    @Test
    public void simpleGCNTest() {
        // NOTE(review): the first argument (3) has the same value as N; if it is the node count,
        // pass N instead so the configuration stays in sync with the example graph.
        ComputationGraphConfiguration simpleGCNConf = NetworkConfigurations.getSimpleGCN(3, K, 5, 2);
        ComputationGraph simpleGCN = new ComputationGraph(simpleGCNConf);
        simpleGCN.init();

        INDArray[] output = simpleGCN.output(exampleGraph);

        // Without assertions this test could only fail by throwing; pin the basic output contract.
        assertNotNull(output, "network returned no outputs");
        assertTrue(output.length > 0, "network returned an empty output array");
        assertNotNull(output[0], "first network output is null");
        // DL4J treats dimension 0 of every input as the minibatch dimension (here: the N nodes),
        // so the output is expected to keep one row per node.
        assertEquals(N, output[0].size(0), "output should have one row per node");

        System.out.println("output = " + output[0]);
    }
}