// dnet-and/dnet-feature-extraction/src/main/java/eu/dnetlib/deeplearning/support/NetworkConfigurations.java

package eu.dnetlib.deeplearning.support;

import eu.dnetlib.deeplearning.layers.GraphConvolutionVertex;
import eu.dnetlib.deeplearning.layers.GraphGlobalAddPool;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
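import org.deeplearning4j.nn.graph.ComputationGraph; // used by the usage sketch below
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; // used by the usage sketch below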

public class NetworkConfigurations {

    // default parameter values
    protected static final int SEED = 12345;
    protected static final double LEARNING_RATE = 1e-3;
    protected static final String ADJACENCY_MATRIX = "adjacency";
    protected static final String FEATURES_MATRIX = "features";
    protected static final String DEGREE_MATRIX = "degrees";
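
    /**
     * Builds a plain feed-forward classifier: one ReLU hidden layer followed by a
     * softmax output trained with negative log-likelihood.
     *
     * @param numInputs      size of the input feature vector
     * @param numHiddenNodes size of the single hidden layer
     * @param numOutputs     number of target classes
     */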
    public static MultiLayerConfiguration getLinearDataClassifier(int numInputs, int numHiddenNodes, int numOutputs) {
        return new NeuralNetConfiguration.Builder()
                .seed(SEED)
                .weightInit(WeightInit.XAVIER)
                .updater(new Nesterovs(LEARNING_RATE, 0.9))
                .list()
                .layer(new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                        .activation(Activation.RELU)
                        .build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .activation(Activation.SOFTMAX)
                        .nIn(numHiddenNodes).nOut(numOutputs).build())
                .build();
    }
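
    // A minimal usage sketch: buildLinearClassifierExample is a hypothetical helper,
    // not used elsewhere in the module, and the sizes are illustrative placeholders.
    // The configuration above plugs directly into MultiLayerNetwork for training.
    public static MultiLayerNetwork buildLinearClassifierExample() {
        MultiLayerConfiguration conf = getLinearDataClassifier(128, 64, 2); // 128 inputs, 64 hidden, 2 classes
        MultiLayerNetwork network = new MultiLayerNetwork(conf);
        network.init(); // allocate parameters before calling fit(...)
        return network;
    }

    /**
     * Builds a graph convolutional network (GCN) of the requested depth. Each block is a
     * GraphConvolutionVertex (which combines node features with the adjacency and degree
     * matrices), a ReLU dense layer, and batch normalization; a global add-pooling layer
     * and a fully connected layer feed the softmax output.
     *
     * @param numLayers      number of convolution blocks
     * @param numInputs      size of each node's input feature vector
     * @param numHiddenNodes size of the hidden representations
     * @param numClasses     number of target classes
     */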
    public static ComputationGraphConfiguration getSimpleGCN(int numLayers, int numInputs, int numHiddenNodes, int numClasses) {
        ComputationGraphConfiguration.GraphBuilder baseConfig = new NeuralNetConfiguration.Builder()
                .seed(SEED)
                .updater(new Adam(LEARNING_RATE))
                .weightInit(WeightInit.XAVIER)
                .graphBuilder()
                .addInputs(FEATURES_MATRIX, ADJACENCY_MATRIX, DEGREE_MATRIX)
                // first convolution block
                .addVertex("layer1",
                        new GraphConvolutionVertex(),
                        FEATURES_MATRIX, ADJACENCY_MATRIX, DEGREE_MATRIX)
                .layer("conv1",
                        new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                                .activation(Activation.RELU)
                                .build(),
                        "layer1")
                .layer("batch1",
                        new BatchNormalization.Builder().nOut(numHiddenNodes).build(),
                        "conv1");

        // add as many convolution blocks as requested
        for (int i = 2; i <= numLayers; i++) {
            baseConfig = baseConfig.addVertex("layer" + i,
                            new GraphConvolutionVertex(),
                            "batch" + (i - 1), ADJACENCY_MATRIX, DEGREE_MATRIX)
                    .layer("conv" + i,
                            new DenseLayer.Builder().nIn(numHiddenNodes).nOut(numHiddenNodes)
                                    .activation(Activation.RELU)
                                    .build(),
                            "layer" + i)
                    .layer("batch" + i,
                            new BatchNormalization.Builder().nOut(numHiddenNodes).build(),
                            "conv" + i);
        }

        // global pooling, fully connected layer and softmax output
        baseConfig = baseConfig
                .layer("pool",
                        new GraphGlobalAddPool(numHiddenNodes),
                        "batch" + numLayers)
                .layer("fc1",
                        new DenseLayer.Builder().nIn(numHiddenNodes).nOut(numHiddenNodes)
                                .activation(Activation.RELU)
                                .weightInit(WeightInit.XAVIER)
                                .build(),
                        "pool")
                .layer("out",
                        new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                                .activation(Activation.SOFTMAX)
                                .nIn(numHiddenNodes).nOut(numClasses).build(),
                        "fc1");

        return baseConfig.setOutputs("out").build();
    }
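
    // A minimal usage sketch: buildSimpleGCNExample is a hypothetical helper with
    // illustrative sizes. Graph configurations run on ComputationGraph, with input
    // arrays supplied in the order declared by addInputs (features, adjacency, degrees).
    public static ComputationGraph buildSimpleGCNExample() {
        ComputationGraphConfiguration conf = getSimpleGCN(2, 64, 32, 3); // 2 blocks, 64 features, 32 hidden, 3 classes
        ComputationGraph graph = new ComputationGraph(conf);
        graph.init(); // allocate parameters before calling fit(...)
        return graph;
    }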
}