package eu.dnetlib.deeplearning.layers; import org.deeplearning4j.nn.conf.graph.GraphVertex; import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams; import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaVertex; import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Map; public class GraphConvolutionVertex extends SameDiffLambdaVertex { @Override public SDVariable defineVertex(SameDiff sameDiff, VertexInputs inputs) { SDVariable features = inputs.getInput(0); SDVariable adjacency = inputs.getInput(1); SDVariable degree = inputs.getInput(2).pow(0.5); //result: DegreeMatrix^-0.5 x Adjacent x DegreeMatrix^-0.5 x Features return degree.mmul(adjacency).mmul(degree).mmul(features); } }