dnet-and/dnet-feature-extraction/src/main/java/eu/dnetlib/deeplearning/layers/GraphConvolutionVertex.java

25 lines
934 B
Java

package eu.dnetlib.deeplearning.layers;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaVertex;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Map;
/**
 * Graph-convolution style vertex expressed as a SameDiff lambda.
 *
 * <p>It consumes three inputs:
 * <ul>
 *   <li>input 0: node feature matrix</li>
 *   <li>input 1: adjacency matrix</li>
 *   <li>input 2: degree matrix</li>
 * </ul>
 * It produces {@code P x A x P x X}, where {@code P} is input 2 raised element-wise to the
 * power 0.5.
 */
public class GraphConvolutionVertex extends SameDiffLambdaVertex {

    /**
     * Builds the SameDiff sub-graph for this vertex.
     *
     * <p>NOTE(review): the intended result is {@code D^-0.5 x A x D^-0.5 x X}, but this code
     * raises input 2 to {@code +0.5}. The two agree only if the caller supplies the inverse
     * degree matrix (D^-1) as input 2. Confirm against the code that builds the inputs.
     *
     * @param sameDiff the SameDiff instance (not used directly; all ops are created through the
     *                 input variables)
     * @param inputs   the vertex inputs: features (0), adjacency (1), degree (2)
     * @return the product {@code P x adjacency x P x features}, with {@code P = input2^0.5}
     */
    @Override
    public SDVariable defineVertex(SameDiff sameDiff, VertexInputs inputs) {
        final SDVariable nodeFeatures = inputs.getInput(0);
        final SDVariable adjacencyMatrix = inputs.getInput(1);
        final SDVariable degreeInput = inputs.getInput(2);

        // Square root of the third input; used as the symmetric normaliser on both sides of A.
        final SDVariable normaliser = degreeInput.pow(0.5);

        final SDVariable leftScaled = normaliser.mmul(adjacencyMatrix);
        final SDVariable bothScaled = leftScaled.mmul(normaliser);
        return bothScaled.mmul(nodeFeatures);
    }
}