package eu.dnetlib.deeplearning.layers; import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaLayer; import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer; import org.nd4j.autodiff.samediff.SDIndex; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import java.util.Map; public class GraphGlobalAddPool extends SameDiffLambdaLayer { int size; public GraphGlobalAddPool(int size) { this.size = size; } @Override public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput) { return layerInput.mean(0).reshape(1, size); //reshape because output layer expects 2-dimensional arrays } }