98 lines
4.5 KiB
Java
98 lines
4.5 KiB
Java
package eu.dnetlib.deeplearning.support;
|
|
|
|
import eu.dnetlib.deeplearning.layers.GraphConvolutionVertex;
|
|
import eu.dnetlib.deeplearning.layers.GraphGlobalAddPool;
|
|
import org.bytedeco.opencv.opencv_dnn.PoolingLayer;
|
|
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
|
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
|
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
|
|
import org.deeplearning4j.nn.conf.graph.MergeVertex;
|
|
import org.deeplearning4j.nn.conf.layers.*;
|
|
import org.deeplearning4j.nn.weights.WeightInit;
|
|
import org.nd4j.autodiff.samediff.SDVariable;
|
|
import org.nd4j.linalg.activations.Activation;
|
|
import org.nd4j.linalg.api.buffer.DataType;
|
|
import org.nd4j.linalg.learning.config.Adam;
|
|
import org.nd4j.linalg.learning.config.Nesterovs;
|
|
import org.nd4j.linalg.lossfunctions.LossFunctions;
|
|
|
|
public class NetworkConfigurations {
|
|
|
|
//parameteres default values
|
|
protected static final int SEED = 12345;
|
|
protected static final double LEARNING_RATE = 1e-3;
|
|
protected static final String ADJACENCY_MATRIX = "adjacency";
|
|
protected static final String FEATURES_MATRIX = "features";
|
|
protected static final String DEGREE_MATRIX = "degrees";
|
|
|
|
public static MultiLayerConfiguration getLinearDataClassifier(int numInputs, int numHiddenNodes, int numOutputs) {
|
|
return new NeuralNetConfiguration.Builder()
|
|
.seed(SEED)
|
|
.weightInit(WeightInit.XAVIER)
|
|
.updater(new Nesterovs(LEARNING_RATE, 0.9))
|
|
.list()
|
|
.layer(new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
|
|
.activation(Activation.RELU)
|
|
.build())
|
|
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
|
|
.activation(Activation.SOFTMAX)
|
|
.nIn(numHiddenNodes).nOut(numOutputs).build())
|
|
.build();
|
|
}
|
|
|
|
public static ComputationGraphConfiguration getSimpleGCN(int numLayers, int numInputs, int numHiddenNodes, int numClasses) {
|
|
|
|
ComputationGraphConfiguration.GraphBuilder baseConfig = new NeuralNetConfiguration.Builder()
|
|
.seed(SEED)
|
|
.updater(new Adam(LEARNING_RATE))
|
|
.weightInit(WeightInit.XAVIER)
|
|
.graphBuilder()
|
|
.addInputs(FEATURES_MATRIX, ADJACENCY_MATRIX, DEGREE_MATRIX)
|
|
//first convolution layer
|
|
.addVertex("layer1",
|
|
new GraphConvolutionVertex(),
|
|
FEATURES_MATRIX, ADJACENCY_MATRIX, DEGREE_MATRIX)
|
|
.layer("conv1",
|
|
new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
|
|
.activation(Activation.RELU)
|
|
.build(),
|
|
"layer1")
|
|
.layer("batch1",
|
|
new BatchNormalization.Builder().nOut(numHiddenNodes).build(),
|
|
"conv1");
|
|
|
|
//ad as many layers as requested
|
|
for(int i=2; i<=numLayers; i++) {
|
|
baseConfig = baseConfig.addVertex("layer" + i,
|
|
new GraphConvolutionVertex(),
|
|
"batch" + (i-1), ADJACENCY_MATRIX, DEGREE_MATRIX)
|
|
.layer("conv" + i,
|
|
new DenseLayer.Builder().nIn(numHiddenNodes).nOut(numHiddenNodes)
|
|
.activation(Activation.RELU)
|
|
.build(),
|
|
"layer" + i)
|
|
.layer("batch" + i,
|
|
new BatchNormalization.Builder().nOut(numHiddenNodes).build(),
|
|
"conv" + i);
|
|
}
|
|
|
|
baseConfig = baseConfig
|
|
.layer("pool",
|
|
new GraphGlobalAddPool(numHiddenNodes),
|
|
"batch" + numLayers)
|
|
.layer("fc1",
|
|
new DenseLayer.Builder().nIn(numHiddenNodes).nOut(numHiddenNodes)
|
|
.activation(Activation.RELU)
|
|
.weightInit(WeightInit.XAVIER)
|
|
.build(),
|
|
"pool")
|
|
.layer("out",
|
|
new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
|
|
.activation(Activation.SOFTMAX)
|
|
.nIn(numHiddenNodes).nOut(numClasses).build(),
|
|
"fc1");
|
|
|
|
return baseConfig.setOutputs("out").build();
|
|
}
|
|
}
|