34 lines
1.0 KiB
Java
34 lines
1.0 KiB
Java
package eu.dnetlib.deeplearning;
|
|
|
|
import eu.dnetlib.deeplearning.support.NetworkConfigurations;
|
|
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
|
import org.deeplearning4j.nn.graph.ComputationGraph;
|
|
import org.junit.jupiter.api.Test;
|
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
import org.nd4j.linalg.factory.Nd4j;
|
|
|
|
public class NetworkConfigurationTests {
|
|
|
|
public final static int N = 3; //number of nodes
|
|
public final static int K = 7; //number of features
|
|
|
|
public static INDArray[] exampleGraph = new INDArray[]{
|
|
Nd4j.zeros(N, K), //features
|
|
Nd4j.ones(N, N), //adjacency
|
|
Nd4j.ones(N, N) //degree
|
|
};
|
|
|
|
@Test
|
|
public void simpleGCNTest() {
|
|
|
|
ComputationGraphConfiguration simpleGCNConf = NetworkConfigurations.getSimpleGCN(3, K, 5, 2);
|
|
ComputationGraph simpleGCN = new ComputationGraph(simpleGCNConf);
|
|
simpleGCN.init();
|
|
|
|
INDArray[] output = simpleGCN.output(exampleGraph);
|
|
System.out.println("output = " + output[0]);
|
|
|
|
}
|
|
|
|
}
|