22 lines
697 B
Java
22 lines
697 B
Java
package eu.dnetlib.deeplearning.layers;
|
|
|
|
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaLayer;
|
|
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer;
|
|
import org.nd4j.autodiff.samediff.SDIndex;
|
|
import org.nd4j.autodiff.samediff.SDVariable;
|
|
import org.nd4j.autodiff.samediff.SameDiff;
|
|
|
|
import java.util.Map;
|
|
|
|
public class GraphGlobalAddPool extends SameDiffLambdaLayer {
|
|
|
|
int size;
|
|
public GraphGlobalAddPool(int size) {
|
|
this.size = size;
|
|
}
|
|
@Override
|
|
public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput) {
|
|
return layerInput.mean(0).reshape(1, size); //reshape because output layer expects 2-dimensional arrays
|
|
}
|
|
}
|