package eu.dnetlib.deeplearning; import eu.dnetlib.deeplearning.support.NetworkConfigurations; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; public class NetworkConfigurationTests { public final static int N = 3; //number of nodes public final static int K = 7; //number of features public static INDArray[] exampleGraph = new INDArray[]{ Nd4j.zeros(N, K), //features Nd4j.ones(N, N), //adjacency Nd4j.ones(N, N) //degree }; @Test public void simpleGCNTest() { ComputationGraphConfiguration simpleGCNConf = NetworkConfigurations.getSimpleGCN(3, K, 5, 2); ComputationGraph simpleGCN = new ComputationGraph(simpleGCNConf); simpleGCN.init(); INDArray[] output = simpleGCN.output(exampleGraph); System.out.println("output = " + output[0]); } }