25 lines
934 B
Java
25 lines
934 B
Java
package eu.dnetlib.deeplearning.layers;
|
|
|
|
import org.deeplearning4j.nn.conf.graph.GraphVertex;
|
|
import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams;
|
|
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaVertex;
|
|
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex;
|
|
import org.nd4j.autodiff.samediff.SDVariable;
|
|
import org.nd4j.autodiff.samediff.SameDiff;
|
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
|
|
import java.util.Map;
|
|
|
|
public class GraphConvolutionVertex extends SameDiffLambdaVertex {
|
|
|
|
@Override
|
|
public SDVariable defineVertex(SameDiff sameDiff, VertexInputs inputs) {
|
|
SDVariable features = inputs.getInput(0);
|
|
SDVariable adjacency = inputs.getInput(1);
|
|
SDVariable degree = inputs.getInput(2).pow(0.5);
|
|
|
|
//result: DegreeMatrix^-0.5 x Adjacent x DegreeMatrix^-0.5 x Features
|
|
return degree.mmul(adjacency).mmul(degree).mmul(features);
|
|
}
|
|
}
|