implementation of GCN network: first commit

This commit is contained in:
Michele De Bonis 2023-04-18 15:24:34 +02:00
parent 88451e3832
commit 4c9d33171d
21 changed files with 1642 additions and 4 deletions

View File

@@ -36,7 +36,7 @@ public abstract class AbstractSparkJob implements Serializable {
this.spark = spark;
}
-abstract void run() throws IOException;
+protected abstract void run() throws IOException;
protected static SparkSession getSparkSession(SparkConf conf) {
return SparkSession.builder().config(conf).getOrCreate();

View File

@@ -45,7 +45,7 @@ public class SparkLDAAnalysis extends AbstractSparkJob {
}
@Override
-void run() throws IOException {
+protected void run() throws IOException {
// read oozie parameters
final String authorsPath = parser.get("authorsPath");
final String workingPath = parser.get("workingPath");

View File

@@ -0,0 +1,75 @@
package eu.dnetlib.jobs.deeplearning;
import eu.dnetlib.deeplearning.support.DataSetProcessor;
import eu.dnetlib.jobs.AbstractSparkJob;
import eu.dnetlib.jobs.SparkLDATuning;
import eu.dnetlib.support.ArgumentApplicationParser;
import eu.dnetlib.support.ConnectedComponent;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
import org.codehaus.jackson.map.ObjectMapper;
import org.deeplearning4j.spark.data.BatchAndExportDataSetsFunction;
import org.deeplearning4j.spark.data.BatchAndExportMultiDataSetsFunction;
import org.deeplearning4j.spark.datavec.iterator.IteratorUtils;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.List;
import java.util.Optional;
public class SparkCreateGroupDataSet extends AbstractSparkJob {
private static final Logger log = LoggerFactory.getLogger(SparkCreateGroupDataSet.class);
public SparkCreateGroupDataSet(ArgumentApplicationParser parser, SparkSession spark) {
super(parser, spark);
}
public static void main(String[] args) throws Exception {
ArgumentApplicationParser parser = new ArgumentApplicationParser(
readResource("/jobs/parameters/createGroupDataset_parameters.json", SparkLDATuning.class)
);
parser.parseArgument(args);
SparkConf conf = new SparkConf();
new SparkCreateGroupDataSet(
parser,
getSparkSession(conf)
).run();
}
@Override
public void run() throws IOException {
// read oozie parameters
final String groupsPath = parser.get("groupsPath");
final String workingPath = parser.get("workingPath");
final String groundTruthJPath = parser.get("groundTruthJPath");
final String idJPath = parser.get("idJPath");
final String featuresJPath = parser.get("featuresJPath");
final int numPartitions = Optional
.ofNullable(parser.get("numPartitions"))
.map(Integer::valueOf)
.orElse(NUM_PARTITIONS);
log.info("groupsPath: '{}'", groupsPath);
log.info("workingPath: '{}'", workingPath);
log.info("groundTruthJPath: '{}'", groundTruthJPath);
log.info("idJPath: '{}'", idJPath);
log.info("featuresJPath: '{}'", featuresJPath);
log.info("numPartitions: '{}'", numPartitions);
JavaSparkContext context = JavaSparkContext.fromSparkContext(spark.sparkContext());
JavaRDD<ConnectedComponent> groups = context.textFile(groupsPath).map(g -> new ObjectMapper().readValue(g, ConnectedComponent.class));
JavaRDD<MultiDataSet> dataset = DataSetProcessor.entityGroupToMultiDataset(groups, idJPath, featuresJPath, groundTruthJPath);
dataset.saveAsObjectFile(workingPath + "/groupDataset");
}
}
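// For orientation: with the JPath parameters used in GNNTrainingTest ($.id, $.topics, $.orcid),
// each line of groupsPath is a serialized ConnectedComponent roughly of the form (hypothetical,
// abridged field values):
//   {"docs":["{\"id\":\"author::1\",\"topics\":[0.1,0.9],\"orcid\":\"0000-...\"}", "..."],
//    "simrels":[{"source":"author::1","target":"author::2","type":"simrel"}]}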

View File

@@ -0,0 +1,99 @@
package eu.dnetlib.jobs.deeplearning;
import eu.dnetlib.deeplearning.support.DataSetProcessor;
import eu.dnetlib.deeplearning.support.NetworkConfigurations;
import eu.dnetlib.jobs.AbstractSparkJob;
import eu.dnetlib.jobs.SparkLDATuning;
import eu.dnetlib.support.ArgumentApplicationParser;
import eu.dnetlib.support.ConnectedComponent;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
import org.codehaus.jackson.map.ObjectMapper;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.optimize.listeners.PerformanceListener;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.threshold.AdaptiveThresholdAlgorithm;
import org.deeplearning4j.spark.api.RDDTrainingApproach;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.impl.graph.SparkComputationGraph;
import org.deeplearning4j.spark.parameterserver.training.SharedTrainingMaster;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.Optional;
public class SparkGraphClassificationTraining extends AbstractSparkJob {
private static final Logger log = LoggerFactory.getLogger(SparkGraphClassificationTraining.class);
public SparkGraphClassificationTraining(ArgumentApplicationParser parser, SparkSession spark) {
super(parser, spark);
}
public static void main(String[] args) throws Exception {
ArgumentApplicationParser parser = new ArgumentApplicationParser(
readResource("/jobs/parameters/graphClassificationTraining_parameters.json", SparkLDATuning.class)
);
parser.parseArgument(args);
SparkConf conf = new SparkConf();
new SparkGraphClassificationTraining(
parser,
getSparkSession(conf)
).run();
}
@Override
public void run() throws IOException {
// read oozie parameters
final String workingPath = parser.get("workingPath");
final int numPartitions = Optional
.ofNullable(parser.get("numPartitions"))
.map(Integer::valueOf)
.orElse(NUM_PARTITIONS);
log.info("workingPath: '{}'", workingPath);
log.info("numPartitions: '{}'", numPartitions);
JavaSparkContext context = JavaSparkContext.fromSparkContext(spark.sparkContext());
VoidConfiguration conf = VoidConfiguration.builder()
.unicastPort(40123)
// .networkMask("255.255.148.0/22")
.controllerAddress("127.0.0.1")
.build();
TrainingMaster trainingMaster = new SharedTrainingMaster.Builder(conf,1)
.rngSeed(12345)
.collectTrainingStats(false)
.thresholdAlgorithm(new AdaptiveThresholdAlgorithm(1e-3))
.batchSizePerWorker(32)
.workersPerNode(4)
.rddTrainingApproach(RDDTrainingApproach.Direct)
.build();
JavaRDD<MultiDataSet> trainData = context.objectFile(workingPath + "/groupDataset");
SparkComputationGraph sparkComputationGraph = new SparkComputationGraph(
context,
NetworkConfigurations.getSimpleGCN(3, 2, 5, 2),
trainingMaster);
sparkComputationGraph.setListeners(new PerformanceListener(10, true));
//execute training: 20 passes (epochs) over the full training set
for (int i = 0; i < 20; i++) {
sparkComputationGraph.fitMultiDataSet(trainData);
}
ComputationGraph network = sparkComputationGraph.getNetwork();
System.out.println("network = " + network.getConfiguration().toJson());
}
}

View File

@@ -0,0 +1,38 @@
[
{
"paramName": "i",
"paramLongName": "groupsPath",
"paramDescription": "the input data: groups to be transformed into dataset",
"paramRequired": true
},
{
"paramName": "f",
"paramLongName": "featuresJPath",
"paramDescription": "the jpath of the features field",
"paramRequired": true
},
{
"paramName": "w",
"paramLongName": "workingPath",
"paramDescription": "path of the working directory",
"paramRequired": true
},
{
"paramName": "np",
"paramLongName": "numPartitions",
"paramDescription": "number of partitions for the similarity relations intermediate phases",
"paramRequired": false
},
{
"paramName": "id",
"paramLongName": "idJPath",
"paramDescription": "the jpath of the id field",
"paramRequired": true
},
{
"paramName": "gt",
"paramLongName": "groundTruthJPath",
"paramDescription": "the jpath of the field to be used as ground truth",
"paramRequired": true
}
]
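Taken together with SparkCreateGroupDataSet above, these parameters map onto an argument list along the lines of (paths and JPaths are placeholders, mirroring GNNTrainingTest):

-i /data/authors/groups -gt "$.orcid" -id "$.id" -f "$.topics" -w /tmp/working_dir -np 20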

View File

@@ -0,0 +1,14 @@
[
{
"paramName": "w",
"paramLongName": "workingPath",
"paramDescription": "path of the working directory",
"paramRequired": true
},
{
"paramName": "np",
"paramLongName": "numPartitions",
"paramDescription": "number of partitions for the similarity relations intermediate phases",
"paramRequired": false
}
]

View File

@@ -0,0 +1,102 @@
package eu.dnetlib.jobs.deeplearning;
import eu.dnetlib.jobs.AbstractSparkJob;
import eu.dnetlib.support.ArgumentApplicationParser;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
import org.junit.jupiter.api.*;
import java.io.File;
import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.file.Paths;
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class GNNTrainingTest {
static SparkSession spark;
static JavaSparkContext context;
final static String workingPath = "/tmp/working_dir";
final static String numPartitions = "20";
final String inputDataPath = Paths
.get(getClass().getResource("/eu/dnetlib/jobs/examples/authors.groups.example.json").toURI())
.toFile()
.getAbsolutePath();
final static String groundTruthJPath = "$.orcid";
final static String idJPath = "$.id";
final static String featuresJPath = "$.topics";
public GNNTrainingTest() throws URISyntaxException {}
public static void cleanup() throws IOException {
//remove directories and clean workspace
FileUtils.deleteDirectory(new File(workingPath));
}
@BeforeAll
public void setup() throws IOException {
cleanup();
spark = SparkSession
.builder()
.appName("Testing")
.master("local[*]")
.getOrCreate();
context = JavaSparkContext.fromSparkContext(spark.sparkContext());
}
@AfterAll
public static void finalCleanUp() throws IOException {
cleanup();
}
@Test
@Order(1)
public void createGroupDataSetTest() throws Exception {
ArgumentApplicationParser parser = new ArgumentApplicationParser(readResource("/jobs/parameters/createGroupDataset_parameters.json", SparkCreateGroupDataSet.class));
parser.parseArgument(
new String[] {
"-i", inputDataPath,
"-gt", groundTruthJPath,
"-id", idJPath,
"-f", featuresJPath,
"-w", workingPath,
"-np", numPartitions
}
);
new SparkCreateGroupDataSet(
parser,
spark
).run();
}
@Test
@Order(2)
public void graphClassificationTrainingTest() throws Exception{
ArgumentApplicationParser parser = new ArgumentApplicationParser(readResource("/jobs/parameters/graphClassificationTraining_parameters.json", SparkGraphClassificationTraining.class));
parser.parseArgument(
new String[] {
"-w", workingPath,
"-np", numPartitions
}
);
new SparkGraphClassificationTraining(
parser,
spark
).run();
}
public static String readResource(String path, Class<? extends AbstractSparkJob> clazz) throws IOException {
return IOUtils.toString(clazz.getResourceAsStream(path));
}
}

File diff suppressed because one or more lines are too long

View File

@@ -51,6 +51,45 @@
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
<!--DEEPLEARNING4J -->
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>${nd4j.backend}</artifactId>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-core</artifactId>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-datasets</artifactId>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>dl4j-spark-parameterserver_2.11</artifactId>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>dl4j-spark_2.11</artifactId>
</dependency>
<!--PLOT-->
<dependency>
<groupId>jfree</groupId>
<artifactId>jfreechart</artifactId>
</dependency>
<dependency>
<groupId>org.jfree</groupId>
<artifactId>jcommon</artifactId>
</dependency>
<!--DNET DEDUP-->
<dependency>
<groupId>eu.dnetlib</groupId>
<artifactId>dnet-dedup-test</artifactId>
</dependency>
</dependencies>
</project>

View File

@@ -0,0 +1,130 @@
//package eu.dnetlib.deeplearning;
//
///* *****************************************************************************
// *
// *
// *
// * This program and the accompanying materials are made available under the
// * terms of the Apache License, Version 2.0 which is available at
// * https://www.apache.org/licenses/LICENSE-2.0.
// * See the NOTICE file distributed with this work for additional
// * information regarding copyright ownership.
// *
// * Unless required by applicable law or agreed to in writing, software
// * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// * License for the specific language governing permissions and limitations
// * under the License.
// *
// * SPDX-License-Identifier: Apache-2.0
// ******************************************************************************/
//
//import org.datavec.api.records.reader.RecordReader;
//import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
//import org.datavec.api.split.FileSplit;
//import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
//import org.deeplearning4j.examples.utils.DownloaderUtility;
//import org.deeplearning4j.examples.utils.PlotUtil;
//import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
//import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
//import org.deeplearning4j.nn.conf.layers.DenseLayer;
//import org.deeplearning4j.nn.conf.layers.OutputLayer;
//import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
//import org.deeplearning4j.nn.weights.WeightInit;
//import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
//import org.nd4j.evaluation.classification.Evaluation;
//import org.nd4j.linalg.activations.Activation;
//import org.nd4j.linalg.api.ndarray.INDArray;
//import org.nd4j.linalg.dataset.DataSet;
//import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
//import org.nd4j.linalg.learning.config.Nesterovs;
//import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
//
//import java.io.File;
//import java.util.concurrent.TimeUnit;
//
//public class GroupClassifier {
//
// public static boolean visualize = true;
// public static String dataLocalPath;
//
// public static void main(String[] args) throws Exception {
// int seed = 123;
// double learningRate = 0.01;
// int batchSize = 50;
// int nEpochs = 30;
//
// int numInputs = 2;
// int numOutputs = 2;
// int numHiddenNodes = 20;
//
// dataLocalPath = DownloaderUtility.CLASSIFICATIONDATA.Download();
// //Load the training data:
// RecordReader rr = new CSVRecordReader();
// rr.initialize(new FileSplit(new File(dataLocalPath, "linear_data_train.csv")));
// DataSetIterator trainIter = new RecordReaderDataSetIterator(rr, batchSize, 0, 2);
//
// //Load the test/evaluation data:
// RecordReader rrTest = new CSVRecordReader();
// rrTest.initialize(new FileSplit(new File(dataLocalPath, "linear_data_eval.csv")));
// DataSetIterator testIter = new RecordReaderDataSetIterator(rrTest, batchSize, 0, 2);
//
// MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
// .seed(seed)
// .weightInit(WeightInit.XAVIER)
// .updater(new Nesterovs(learningRate, 0.9))
// .list()
// .layer(new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
// .activation(Activation.RELU)
// .build())
// .layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
// .activation(Activation.SOFTMAX)
// .nIn(numHiddenNodes).nOut(numOutputs).build())
// .build();
//
//
// MultiLayerNetwork model = new MultiLayerNetwork(conf);
// model.init();
// model.setListeners(new ScoreIterationListener(10)); //Print score every 10 parameter updates
//
// model.fit(trainIter, nEpochs);
//
// System.out.println("Evaluate model....");
// Evaluation eval = new Evaluation(numOutputs);
// while (testIter.hasNext()) {
// DataSet t = testIter.next();
// INDArray features = t.getFeatures();
// INDArray labels = t.getLabels();
// INDArray predicted = model.output(features, false);
// eval.eval(labels, predicted);
// }
// //An alternate way to do the above loop
// //Evaluation evalResults = model.evaluate(testIter);
//
// //Print the evaluation statistics
// System.out.println(eval.stats());
//
// System.out.println("\n****************Example finished********************");
// //Training is complete. Code that follows is for plotting the data & predictions only
// generateVisuals(model, trainIter, testIter);
// }
//
// public static void generateVisuals(MultiLayerNetwork model, DataSetIterator trainIter, DataSetIterator testIter) throws Exception {
// if (visualize) {
// double xMin = 0;
// double xMax = 1.0;
// double yMin = -0.2;
// double yMax = 0.8;
// int nPointsPerAxis = 100;
//
// //Generate x,y points that span the whole range of features
// INDArray allXYPoints = PlotUtil.generatePointsOnGraph(xMin, xMax, yMin, yMax, nPointsPerAxis);
// //Get train data and plot with predictions
// PlotUtil.plotTrainingData(model, trainIter, allXYPoints, nPointsPerAxis);
// TimeUnit.SECONDS.sleep(3);
// //Get test data, run the test data through the network to generate predictions, and plot those predictions:
// PlotUtil.plotTestData(model, testIter, allXYPoints, nPointsPerAxis);
// }
// }
//}
//

View File

@@ -0,0 +1,24 @@
package eu.dnetlib.deeplearning.layers;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaVertex;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Map;
public class GraphConvolutionVertex extends SameDiffLambdaVertex {
@Override
public SDVariable defineVertex(SameDiff sameDiff, VertexInputs inputs) {
SDVariable features = inputs.getInput(0);
SDVariable adjacency = inputs.getInput(1);
SDVariable degree = inputs.getInput(2).pow(0.5); //element-wise power of the dense degree matrix
//intended result: DegreeMatrix^-0.5 x Adjacency x DegreeMatrix^-0.5 x Features (see the note after this class)
return degree.mmul(adjacency).mmul(degree).mmul(features);
}
}
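// For reference, the canonical GCN propagation rule this vertex is modelled on (Kipf & Welling):
//   H^{(l+1)} = sigma( D^{-1/2} (A + I) D^{-1/2} H^{(l)} W^{(l)} )
// The self-loops (A + I) are added in DataSetProcessor, and W / sigma come from the DenseLayer
// that follows this vertex in NetworkConfigurations, so the vertex computes only the matrix
// product. Caveat: the code applies an element-wise pow(0.5) to the dense degree input, while
// the canonical rule uses the inverse square root of the diagonal; an element-wise pow(-0.5)
// on a dense zero-padded matrix would put infinities off the diagonal.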

View File

@@ -0,0 +1,21 @@
package eu.dnetlib.deeplearning.layers;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaLayer;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer;
import org.nd4j.autodiff.samediff.SDIndex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import java.util.Map;
public class GraphGlobalAddPool extends SameDiffLambdaLayer {
int size;
public GraphGlobalAddPool(int size) {
this.size = size;
}
@Override
public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput) {
return layerInput.mean(0).reshape(1, size); //collapse [numNodes, size] node features into one [1, size] graph vector; reshape because the output layer expects 2-dimensional arrays (mean pooling, despite the "AddPool" name)
}
}

View File

@@ -0,0 +1,88 @@
package eu.dnetlib.deeplearning.support;
import eu.dnetlib.featureextraction.Utilities;
import eu.dnetlib.support.Author;
import eu.dnetlib.support.ConnectedComponent;
import eu.dnetlib.support.Relation;
import org.apache.spark.api.java.JavaRDD;
import org.codehaus.jackson.map.ObjectMapper;
import org.jetbrains.annotations.NotNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import java.io.IOException;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class DataSetProcessor {
public static JavaRDD<MultiDataSet> entityGroupToMultiDataset(JavaRDD<ConnectedComponent> groupEntity, String idJPath, String featureJPath, String groundTruthJPath) {
return groupEntity.map(g -> {
Map<String, double[]> featuresMap = new HashMap<>();
List<String> groundTruth = new ArrayList<>();
Set<String> entities = g.getDocs();
for(String json:entities) {
featuresMap.put(
Utilities.getJPathString(idJPath, json),
Utilities.getJPathArray(featureJPath, json)
);
groundTruth.add(Utilities.getJPathString(groundTruthJPath, json));
}
Set<Relation> relations = g.getSimrels();
return getMultiDataSet(featuresMap, relations, groundTruth);
});
}
public static MultiDataSet getMultiDataSet(Map<String, double[]> featuresMap, Set<Relation> relations, List<String> groundTruth) {
List<String> identifiers = new ArrayList<>(featuresMap.keySet());
int numNodes = identifiers.size();
//initialize arrays
INDArray adjacency = Nd4j.zeros(numNodes, numNodes);
INDArray features = Nd4j.zeros(numNodes, featuresMap.get(identifiers.get(0)).length); //feature size taken from the first element (it's equal for every element)
INDArray degree = Nd4j.zeros(numNodes, numNodes);
//create adjacency
for(Relation r: relations) {
adjacency.put(identifiers.indexOf(r.getSource()), identifiers.indexOf(r.getTarget()), 1);
adjacency.put(identifiers.indexOf(r.getTarget()), identifiers.indexOf(r.getSource()), 1);
}
adjacency.addi(Nd4j.eye(numNodes));
//create degree and features
List<String> degreeSupport = relations.stream().flatMap(r -> Stream.of(r.getSource(), r.getTarget())).collect(Collectors.toList());
for(int i=0; i< identifiers.size(); i++) {
degree.put(i, i, Collections.frequency(degreeSupport, identifiers.get(i)));
features.putRow(i, Nd4j.create(featuresMap.get(identifiers.get(i))));
}
//infer label
INDArray label = Nd4j.zeros(1, 2);
if (groundTruth.stream().distinct().count()==1) {
//correct (same elements)
label.put(0, 0, 1.0);
}
else {
//wrong (different elements)
label.put(0, 1, 1.0);
}
return new MultiDataSet(
new INDArray[]{
features,
adjacency,
degree
},
new INDArray[]{
label
}
);
}
}
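// Worked illustration of getMultiDataSet on the three-node chain used by DataSetProcessorTest
// below (relations 0-1 and 1-2, ground truth {class1, class1, class2}), assuming the identifiers
// come out of the map in the order 0, 1, 2:
//   adjacency (self-loops added):   degree (endpoint frequency; self-loops not counted):
//     [1 1 0]                         [1 0 0]
//     [1 1 1]                         [0 2 0]
//     [0 1 1]                         [0 0 1]
//   label = [0, 1], because the ground-truth values are not all equal ("wrong" group).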

View File

@@ -0,0 +1,11 @@
package eu.dnetlib.deeplearning.support;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import java.io.*;
import java.util.List;
public class GroupMultiDataSet extends MultiDataSet {
}

View File

@@ -0,0 +1,97 @@
package eu.dnetlib.deeplearning.support;
import eu.dnetlib.deeplearning.layers.GraphConvolutionVertex;
import eu.dnetlib.deeplearning.layers.GraphGlobalAddPool;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
public class NetworkConfigurations {
//parameters default values
protected static final int SEED = 12345;
protected static final double LEARNING_RATE = 1e-3;
protected static final String ADJACENCY_MATRIX = "adjacency";
protected static final String FEATURES_MATRIX = "features";
protected static final String DEGREE_MATRIX = "degrees";
public static MultiLayerConfiguration getLinearDataClassifier(int numInputs, int numHiddenNodes, int numOutputs) {
return new NeuralNetConfiguration.Builder()
.seed(SEED)
.weightInit(WeightInit.XAVIER)
.updater(new Nesterovs(LEARNING_RATE, 0.9))
.list()
.layer(new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
.activation(Activation.RELU)
.build())
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.activation(Activation.SOFTMAX)
.nIn(numHiddenNodes).nOut(numOutputs).build())
.build();
}
public static ComputationGraphConfiguration getSimpleGCN(int numLayers, int numInputs, int numHiddenNodes, int numClasses) {
ComputationGraphConfiguration.GraphBuilder baseConfig = new NeuralNetConfiguration.Builder()
.seed(SEED)
.updater(new Adam(LEARNING_RATE))
.weightInit(WeightInit.XAVIER)
.graphBuilder()
.addInputs(FEATURES_MATRIX, ADJACENCY_MATRIX, DEGREE_MATRIX)
//first convolution layer
.addVertex("layer1",
new GraphConvolutionVertex(),
FEATURES_MATRIX, ADJACENCY_MATRIX, DEGREE_MATRIX)
.layer("conv1",
new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
.activation(Activation.RELU)
.build(),
"layer1")
.layer("batch1",
new BatchNormalization.Builder().nOut(numHiddenNodes).build(),
"conv1");
//add as many layers as requested
for(int i=2; i<=numLayers; i++) {
baseConfig = baseConfig.addVertex("layer" + i,
new GraphConvolutionVertex(),
"batch" + (i-1), ADJACENCY_MATRIX, DEGREE_MATRIX)
.layer("conv" + i,
new DenseLayer.Builder().nIn(numHiddenNodes).nOut(numHiddenNodes)
.activation(Activation.RELU)
.build(),
"layer" + i)
.layer("batch" + i,
new BatchNormalization.Builder().nOut(numHiddenNodes).build(),
"conv" + i);
}
baseConfig = baseConfig
.layer("pool",
new GraphGlobalAddPool(numHiddenNodes),
"batch" + numLayers)
.layer("fc1",
new DenseLayer.Builder().nIn(numHiddenNodes).nOut(numHiddenNodes)
.activation(Activation.RELU)
.weightInit(WeightInit.XAVIER)
.build(),
"pool")
.layer("out",
new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.activation(Activation.SOFTMAX)
.nIn(numHiddenNodes).nOut(numClasses).build(),
"fc1");
return baseConfig.setOutputs("out").build();
}
}
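// For numLayers = 3 (the value used by SparkGraphClassificationTraining and the tests), the
// builder above wires, roughly:
//   {features, adjacency, degrees} -> layer1 (GCN vertex) -> conv1 (Dense, ReLU) -> batch1
//     -> layer2 -> conv2 -> batch2 -> layer3 -> conv3 -> batch3
//     -> pool (GraphGlobalAddPool) -> fc1 (Dense, ReLU) -> out (softmax over numClasses)
// where every layerN vertex re-reads the adjacency and degree inputs alongside the previous
// batch output.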

View File

@@ -0,0 +1,253 @@
//package eu.dnetlib.deeplearning.support;
//
//import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
//import org.jfree.chart.ChartPanel;
//import org.jfree.chart.ChartUtilities;
//import org.jfree.chart.JFreeChart;
//import org.jfree.chart.axis.AxisLocation;
//import org.jfree.chart.axis.NumberAxis;
//import org.jfree.chart.block.BlockBorder;
//import org.jfree.chart.plot.DatasetRenderingOrder;
//import org.jfree.chart.plot.XYPlot;
//import org.jfree.chart.renderer.GrayPaintScale;
//import org.jfree.chart.renderer.PaintScale;
//import org.jfree.chart.renderer.xy.XYBlockRenderer;
//import org.jfree.chart.renderer.xy.XYLineAndShapeRenderer;
//import org.jfree.chart.title.PaintScaleLegend;
//import org.jfree.data.xy.*;
//import org.jfree.ui.RectangleEdge;
//import org.jfree.ui.RectangleInsets;
//import org.nd4j.linalg.api.ndarray.INDArray;
//import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
//import org.nd4j.linalg.dataset.DataSet;
//import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
//import org.nd4j.linalg.factory.Nd4j;
//
//import javax.swing.*;
//import java.awt.*;
//import java.util.ArrayList;
//import java.util.List;
//
///**
// * Simple plotting methods for the MLPClassifier quickstart examples
// *
// * @author Alex Black
// */
//public class PlotUtils {
//
// /**
// * Plot the training data. Assume 2d input, classification output
// *
// * @param model Model to use to get predictions
// * @param trainIter DataSet Iterator
// * @param backgroundIn sets of x,y points in input space, plotted in the background
// * @param nDivisions Number of points (per axis, for the backgroundIn/backgroundOut arrays)
// */
// public static void plotTrainingData(MultiLayerNetwork model, DataSetIterator trainIter, INDArray backgroundIn, int nDivisions) {
// double[] mins = backgroundIn.min(0).data().asDouble();
// double[] maxs = backgroundIn.max(0).data().asDouble();
//
// DataSet ds = allBatches(trainIter);
// INDArray backgroundOut = model.output(backgroundIn);
//
// XYZDataset backgroundData = createBackgroundData(backgroundIn, backgroundOut);
// JPanel panel = new ChartPanel(createChart(backgroundData, mins, maxs, nDivisions, createDataSetTrain(ds.getFeatures(), ds.getLabels())));
//
// JFrame f = new JFrame();
// f.add(panel);
// f.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
// f.pack();
// f.setTitle("Training Data");
//
// f.setVisible(true);
// f.setLocation(0, 0);
// }
//
// /**
// * Plot the training data. Assume 2d input, classification output
// *
// * @param model Model to use to get predictions
// * @param testIter Test Iterator
// * @param backgroundIn sets of x,y points in input space, plotted in the background
// * @param nDivisions Number of points (per axis, for the backgroundIn/backgroundOut arrays)
// */
// public static void plotTestData(MultiLayerNetwork model, DataSetIterator testIter, INDArray backgroundIn, int nDivisions) {
//
// double[] mins = backgroundIn.min(0).data().asDouble();
// double[] maxs = backgroundIn.max(0).data().asDouble();
//
// INDArray backgroundOut = model.output(backgroundIn);
// XYZDataset backgroundData = createBackgroundData(backgroundIn, backgroundOut);
// DataSet ds = allBatches(testIter);
// INDArray predicted = model.output(ds.getFeatures());
// JPanel panel = new ChartPanel(createChart(backgroundData, mins, maxs, nDivisions, createDataSetTest(ds.getFeatures(), ds.getLabels(), predicted)));
//
// JFrame f = new JFrame();
// f.add(panel);
// f.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
// f.pack();
// f.setTitle("Test Data");
//
// f.setVisible(true);
// f.setLocationRelativeTo(null);
// //f.setLocation(100,100);
//
// }
//
//
// /**
// * Create data for the background data set
// */
// private static XYZDataset createBackgroundData(INDArray backgroundIn, INDArray backgroundOut) {
// int nRows = backgroundIn.rows();
// double[] xValues = new double[nRows];
// double[] yValues = new double[nRows];
// double[] zValues = new double[nRows];
// for (int i = 0; i < nRows; i++) {
// xValues[i] = backgroundIn.getDouble(i, 0);
// yValues[i] = backgroundIn.getDouble(i, 1);
// zValues[i] = backgroundOut.getDouble(i, 0);
//
// }
//
// DefaultXYZDataset dataset = new DefaultXYZDataset();
// dataset.addSeries("Series 1",
// new double[][]{xValues, yValues, zValues});
// return dataset;
// }
//
// //Training data
// private static XYDataset createDataSetTrain(INDArray features, INDArray labels) {
// int nRows = features.rows();
//
//        int nClasses = 2; // binary classification: two output classes
//
// XYSeries[] series = new XYSeries[nClasses];
// for (int i = 0; i < series.length; i++) series[i] = new XYSeries("Class " + i);
// INDArray argMax = Nd4j.getExecutioner().exec(new ArgMax(new INDArray[]{labels},false,new int[]{1}))[0];
// for (int i = 0; i < nRows; i++) {
// int classIdx = (int) argMax.getDouble(i);
// series[classIdx].add(features.getDouble(i, 0), features.getDouble(i, 1));
// }
//
// XYSeriesCollection c = new XYSeriesCollection();
// for (XYSeries s : series) c.addSeries(s);
// return c;
// }
//
// //Test data
// private static XYDataset createDataSetTest(INDArray features, INDArray labels, INDArray predicted) {
// int nRows = features.rows();
//
//        int nClasses = 2; // binary classification: two output classes
//
// XYSeries[] series = new XYSeries[nClasses * nClasses];
// int[] series_index = new int[]{0, 3, 2, 1}; //little hack to make the charts look consistent.
// for (int i = 0; i < nClasses * nClasses; i++) {
// int trueClass = i / nClasses;
// int predClass = i % nClasses;
// String label = "actual=" + trueClass + ", pred=" + predClass;
// series[series_index[i]] = new XYSeries(label);
// }
// INDArray actualIdx = labels.argMax(1);
// INDArray predictedIdx = predicted.argMax(1);
// for (int i = 0; i < nRows; i++) {
// int classIdx = actualIdx.getInt(i);
// int predIdx = predictedIdx.getInt(i);
// int idx = series_index[classIdx * nClasses + predIdx];
// series[idx].add(features.getDouble(i, 0), features.getDouble(i, 1));
// }
//
// XYSeriesCollection c = new XYSeriesCollection();
// for (XYSeries s : series) c.addSeries(s);
// return c;
// }
//
// private static JFreeChart createChart(XYZDataset dataset, double[] mins, double[] maxs, int nPoints, XYDataset xyData) {
// NumberAxis xAxis = new NumberAxis("X");
// xAxis.setRange(mins[0], maxs[0]);
//
//
// NumberAxis yAxis = new NumberAxis("Y");
// yAxis.setRange(mins[1], maxs[1]);
//
// XYBlockRenderer renderer = new XYBlockRenderer();
// renderer.setBlockWidth((maxs[0] - mins[0]) / (nPoints - 1));
// renderer.setBlockHeight((maxs[1] - mins[1]) / (nPoints - 1));
// PaintScale scale = new GrayPaintScale(0, 1.0);
// renderer.setPaintScale(scale);
// XYPlot plot = new XYPlot(dataset, xAxis, yAxis, renderer);
// plot.setBackgroundPaint(Color.lightGray);
// plot.setDomainGridlinesVisible(false);
// plot.setRangeGridlinesVisible(false);
// plot.setAxisOffset(new RectangleInsets(5, 5, 5, 5));
// JFreeChart chart = new JFreeChart("", plot);
// chart.getXYPlot().getRenderer().setSeriesVisibleInLegend(0, false);
//
//
// NumberAxis scaleAxis = new NumberAxis("Probability (class 1)");
// scaleAxis.setAxisLinePaint(Color.white);
// scaleAxis.setTickMarkPaint(Color.white);
// scaleAxis.setTickLabelFont(new Font("Dialog", Font.PLAIN, 7));
// PaintScaleLegend legend = new PaintScaleLegend(new GrayPaintScale(),
// scaleAxis);
// legend.setStripOutlineVisible(false);
// legend.setSubdivisionCount(20);
// legend.setAxisLocation(AxisLocation.BOTTOM_OR_LEFT);
// legend.setAxisOffset(5.0);
// legend.setMargin(new RectangleInsets(5, 5, 5, 5));
// legend.setFrame(new BlockBorder(Color.red));
// legend.setPadding(new RectangleInsets(10, 10, 10, 10));
// legend.setStripWidth(10);
// legend.setPosition(RectangleEdge.LEFT);
// chart.addSubtitle(legend);
//
// ChartUtilities.applyCurrentTheme(chart);
//
// plot.setDataset(1, xyData);
// XYLineAndShapeRenderer renderer2 = new XYLineAndShapeRenderer();
// renderer2.setBaseLinesVisible(false);
// plot.setRenderer(1, renderer2);
//
// plot.setDatasetRenderingOrder(DatasetRenderingOrder.FORWARD);
//
// return chart;
// }
//
// public static INDArray generatePointsOnGraph(double xMin, double xMax, double yMin, double yMax, int nPointsPerAxis) {
// //generate all the x,y points
// double[][] evalPoints = new double[nPointsPerAxis * nPointsPerAxis][2];
// int count = 0;
// for (int i = 0; i < nPointsPerAxis; i++) {
// for (int j = 0; j < nPointsPerAxis; j++) {
// double x = i * (xMax - xMin) / (nPointsPerAxis - 1) + xMin;
// double y = j * (yMax - yMin) / (nPointsPerAxis - 1) + yMin;
//
// evalPoints[count][0] = x;
// evalPoints[count][1] = y;
//
// count++;
// }
// }
//
// return Nd4j.create(evalPoints);
// }
//
// /**
// * This is to collect all the data and return it as one minibatch. Obviously only for use here with small datasets
// * @param iter
// * @return
// */
// private static DataSet allBatches(DataSetIterator iter) {
//
// List<DataSet> fullSet = new ArrayList<>();
// iter.reset();
// while (iter.hasNext()) {
// List<DataSet> miniBatchList = iter.next().asList();
// fullSet.addAll(miniBatchList);
// }
// iter.reset();
// return new ListDataSetIterator<>(fullSet,fullSet.size()).next();
// }
//
//}

View File

@@ -23,6 +23,7 @@ import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Serializable;
import java.math.BigDecimal;
import java.text.Normalizer;
import java.util.List;
import java.util.regex.Matcher;
@@ -71,6 +72,30 @@ public class Utilities implements Serializable {
}
}
public static double[] getJPathArray(final String jsonPath, final String inputJson) {
try {
Object o = JsonPath.read(inputJson, jsonPath);
if (o instanceof double[])
return (double[]) o;
if (o instanceof JSONArray) {
Object[] objects = ((JSONArray) o).toArray();
double[] array = new double[objects.length];
for (int i = 0; i < objects.length; i++) {
if (objects[i] instanceof BigDecimal)
array[i] = ((BigDecimal)objects[i]).doubleValue();
else
array[i] = (double) objects[i];
}
return array;
}
return new double[0];
}
catch (Exception e) {
e.printStackTrace();
return new double[0];
}
}
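// Example (hypothetical JSON literal): getJPathArray("$.topics", "{\"topics\":[0.1,0.7,0.2]}")
// returns {0.1, 0.7, 0.2}; on a missing path or an unexpected type the method prints the stack
// trace and returns an empty array.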
// public static String normalize(final String s) {
// return Normalizer.normalize(s, Normalizer.Form.NFD)
// .replaceAll("[^\\w\\s-]", "") // Remove all non-word, non-space or non-dash characters

View File

@@ -27,6 +27,6 @@ public class UtilityTest {
Author a = new Author("De Bonis, Michele", "Æ", "De Bonis", new ArrayList<CoAuthor>(), new double[]{0.0, 1.0}, "author::id", "pub::id", "orcid");
System.out.println("a = " + a.isAccurate());
System.out.println(AuthorsFactory.getLNFI(a));
}
}

View File

@@ -0,0 +1,47 @@
package eu.dnetlib.deeplearning;
import com.beust.jcommander.internal.Sets;
import com.google.common.collect.Lists;
import eu.dnetlib.deeplearning.support.DataSetProcessor;
import eu.dnetlib.support.Relation;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.dataset.MultiDataSet;
import java.util.*;
import java.util.stream.Collectors;
public class DataSetProcessorTest {
static Map<String, double[]> features;
static Set<Relation> relations;
static List<String> groundTruth;
@BeforeAll
public static void init(){
//initialize example features
features = new HashMap<>();
features.put("0", new double[]{0.0,0.0});
features.put("1", new double[]{1.0,1.0});
features.put("2", new double[]{2.0,2.0});
//initialize example relations
relations = new HashSet<>(Lists.newArrayList(
new Relation("0", "1", "simrel"),
new Relation("1", "2", "simrel")
));
//initialize example ground truth
groundTruth = Lists.newArrayList("class1", "class1", "class2");
}
@Test
public void getMultiDataSetTest() throws Exception {
MultiDataSet multiDataSet = DataSetProcessor.getMultiDataSet(features, relations, groundTruth);
System.out.println("multiDataSet = " + multiDataSet);
multiDataSet.asList();
}
}

View File

@@ -0,0 +1,33 @@
package eu.dnetlib.deeplearning;
import eu.dnetlib.deeplearning.support.NetworkConfigurations;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
public class NetworkConfigurationTests {
public final static int N = 3; //number of nodes
public final static int K = 7; //number of features
public static INDArray[] exampleGraph = new INDArray[]{
Nd4j.zeros(N, K), //features
Nd4j.ones(N, N), //adjacency
Nd4j.ones(N, N) //degree
};
@Test
public void simpleGCNTest() {
ComputationGraphConfiguration simpleGCNConf = NetworkConfigurations.getSimpleGCN(3, K, 5, 2);
ComputationGraph simpleGCN = new ComputationGraph(simpleGCNConf);
simpleGCN.init();
INDArray[] output = simpleGCN.output(exampleGraph);
System.out.println("output = " + output[0]);
}
}

pom.xml
View File

@@ -247,7 +247,7 @@
<google.guava.version>15.0</google.guava.version>
<spark.version>2.2.0</spark.version>
<sparknlp.version>2.5.5</sparknlp.version>
<scala.binary.version>2.11</scala.binary.version>
<jackson.version>2.6.5</jackson.version>
<mockito-core.version>3.3.3</mockito-core.version>
@@ -282,6 +282,10 @@
<junit-jupiter.version>5.6.1</junit-jupiter.version>
<maven.dependency.eu.dnetlib.dhp.dhp-build-assembly-resources.jar.path>../dhp-build/dhp-build-assembly-resources/target/dhp-build-assembly-resources-${project.version}.jar</maven.dependency.eu.dnetlib.dhp.dhp-build-assembly-resources.jar.path>
<!--deeplearning4j-->
<dl4j-master.version>1.0.0-beta7</dl4j-master.version>
<nd4j.backend>nd4j-native</nd4j.backend>
</properties>
<dependencyManagement>
@@ -446,6 +450,90 @@
<version>2.10.29</version>
</dependency>
<!--DEEPLEARNING4J-->
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>${nd4j.backend}</artifactId>
<version>${dl4j-master.version}</version>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-api</artifactId>
<version>${dl4j-master.version}</version>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-data-image</artifactId>
<version>${dl4j-master.version}</version>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-local</artifactId>
<version>${dl4j-master.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-datasets</artifactId>
<version>${dl4j-master.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-core</artifactId>
<version>${dl4j-master.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>resources</artifactId>
<version>${dl4j-master.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-ui</artifactId>
<version>${dl4j-master.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-zoo</artifactId>
<version>${dl4j-master.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>dl4j-spark-parameterserver_2.11</artifactId>
<version>${dl4j-master.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>dl4j-spark_2.11</artifactId>
<version>${dl4j-master.version}</version>
</dependency>
<!--PLOT-->
<dependency>
<groupId>jfree</groupId>
<artifactId>jfreechart</artifactId>
<version>1.0.13</version>
</dependency>
<dependency>
<groupId>org.jfree</groupId>
<artifactId>jcommon</artifactId>
<version>1.0.23</version>
</dependency>
<!--DNET DEDUP-->
<dependency>
<groupId>eu.dnetlib</groupId>
<artifactId>dnet-dedup-test</artifactId>
<version>4.1.13-SNAPSHOT</version>
</dependency>
</dependencies>