dnet-and/dnet-feature-extraction/src/main/java/eu/dnetlib/deeplearning/layers/GraphGlobalAddPool.java

22 lines
697 B
Java

package eu.dnetlib.deeplearning.layers;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaLayer;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer;
import org.nd4j.autodiff.samediff.SDIndex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import java.util.Map;
/**
 * Global pooling layer for graph inputs: reduces the layer input along dimension 0 with a mean
 * and reshapes the result into a {@code [1, size]} row vector.
 *
 * <p>Presumably the input is a node-feature matrix of shape {@code [numNodes, size]}, so the
 * output is a single graph-level embedding. TODO confirm against the callers.
 *
 * <p>NOTE(review): despite the "Add" in the class name, this layer computes the <em>mean</em>
 * over dimension 0, not the sum. The behavior is kept as-is so existing trained models and
 * callers are unaffected. Confirm which reduction is intended before relying on the name.
 */
public class GraphGlobalAddPool extends SameDiffLambdaLayer {

    /** Number of elements in the pooled output; the result is reshaped to {@code [1, size]}. */
    int size;

    /**
     * No-arg constructor for JSON (Jackson) deserialization of the layer configuration, e.g. when
     * a saved model containing this layer is loaded. Not intended for direct use.
     */
    protected GraphGlobalAddPool() {
        // Intentionally empty: 'size' is expected to be populated from the serialized config.
    }

    /**
     * Creates the pooling layer.
     *
     * @param size number of elements in the pooled output row; must be positive
     * @throws IllegalArgumentException if {@code size <= 0}
     */
    public GraphGlobalAddPool(int size) {
        if (size <= 0) {
            throw new IllegalArgumentException("size must be positive, got " + size);
        }
        this.size = size;
    }

    /**
     * Defines the pooling computation in the SameDiff graph.
     *
     * @param sameDiff   the SameDiff instance the layer is defined in (unused here)
     * @param layerInput the layer input; reduced with a mean along dimension 0
     * @return the mean over dimension 0, reshaped to {@code [1, size]}
     */
    @Override
    public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput) {
        // Reshape because the downstream output layer expects 2-dimensional arrays.
        return layerInput.mean(0).reshape(1, size);
    }
}