implementation of GCN network: first commit
This commit is contained in:
parent 88451e3832
commit 4c9d33171d
@@ -36,7 +36,7 @@ public abstract class AbstractSparkJob implements Serializable {
        this.spark = spark;
    }

-   abstract void run() throws IOException;
+   protected abstract void run() throws IOException;

    protected static SparkSession getSparkSession(SparkConf conf) {
        return SparkSession.builder().config(conf).getOrCreate();
@@ -45,7 +45,7 @@ public class SparkLDAAnalysis extends AbstractSparkJob {
    }

    @Override
-   void run() throws IOException {
+   protected void run() throws IOException {
        // read oozie parameters
        final String authorsPath = parser.get("authorsPath");
        final String workingPath = parser.get("workingPath");
@@ -0,0 +1,75 @@
package eu.dnetlib.jobs.deeplearning;

import eu.dnetlib.deeplearning.support.DataSetProcessor;
import eu.dnetlib.jobs.AbstractSparkJob;
import eu.dnetlib.jobs.SparkLDATuning;
import eu.dnetlib.support.ArgumentApplicationParser;
import eu.dnetlib.support.ConnectedComponent;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
import org.codehaus.jackson.map.ObjectMapper;
import org.deeplearning4j.spark.data.BatchAndExportDataSetsFunction;
import org.deeplearning4j.spark.data.BatchAndExportMultiDataSetsFunction;
import org.deeplearning4j.spark.datavec.iterator.IteratorUtils;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.List;
import java.util.Optional;

public class SparkCreateGroupDataSet extends AbstractSparkJob {

    private static final Logger log = LoggerFactory.getLogger(SparkCreateGroupDataSet.class);

    public SparkCreateGroupDataSet(ArgumentApplicationParser parser, SparkSession spark) {
        super(parser, spark);
    }

    public static void main(String[] args) throws Exception {
        ArgumentApplicationParser parser = new ArgumentApplicationParser(
                readResource("/jobs/parameters/createGroupDataset_parameters.json", SparkLDATuning.class)
        );

        parser.parseArgument(args);

        SparkConf conf = new SparkConf();

        new SparkCreateGroupDataSet(
                parser,
                getSparkSession(conf)
        ).run();
    }

    @Override
    public void run() throws IOException {
        // read oozie parameters
        final String groupsPath = parser.get("groupsPath");
        final String workingPath = parser.get("workingPath");
        final String groundTruthJPath = parser.get("groundTruthJPath");
        final String idJPath = parser.get("idJPath");
        final String featuresJPath = parser.get("featuresJPath");
        final int numPartitions = Optional
                .ofNullable(parser.get("numPartitions"))
                .map(Integer::valueOf)
                .orElse(NUM_PARTITIONS);

        log.info("groupsPath: '{}'", groupsPath);
        log.info("workingPath: '{}'", workingPath);
        log.info("groundTruthJPath: '{}'", groundTruthJPath);
        log.info("idJPath: '{}'", idJPath);
        log.info("featuresJPath: '{}'", featuresJPath);
        log.info("numPartitions: '{}'", numPartitions);

        JavaSparkContext context = JavaSparkContext.fromSparkContext(spark.sparkContext());

        JavaRDD<ConnectedComponent> groups = context.textFile(groupsPath).map(g -> new ObjectMapper().readValue(g, ConnectedComponent.class));

        JavaRDD<MultiDataSet> dataset = DataSetProcessor.entityGroupToMultiDataset(groups, idJPath, featuresJPath, groundTruthJPath);

        dataset.saveAsObjectFile(workingPath + "/groupDataset");
    }

}
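Note: the JPath parameters above are evaluated against each JSON document inside a group. A sketch of what one such document could look like, assuming the field names used by the tests further down ($.id, $.topics, $.orcid); the real schema is whatever ConnectedComponent serializes, so treat this as illustrative only:

    // Hypothetical single entity inside a group (illustrative, not the canonical schema).
    String doc = "{\"id\":\"author::1\",\"topics\":[0.1,0.7,0.2],\"orcid\":\"0000-0002-1825-0097\"}";
    double[] topics = Utilities.getJPathArray("$.topics", doc); // feature vector for one node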
@@ -0,0 +1,99 @@
package eu.dnetlib.jobs.deeplearning;

import eu.dnetlib.deeplearning.support.DataSetProcessor;
import eu.dnetlib.deeplearning.support.NetworkConfigurations;
import eu.dnetlib.jobs.AbstractSparkJob;
import eu.dnetlib.jobs.SparkLDATuning;
import eu.dnetlib.support.ArgumentApplicationParser;
import eu.dnetlib.support.ConnectedComponent;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
import org.codehaus.jackson.map.ObjectMapper;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.optimize.listeners.PerformanceListener;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.threshold.AdaptiveThresholdAlgorithm;
import org.deeplearning4j.spark.api.RDDTrainingApproach;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.impl.graph.SparkComputationGraph;
import org.deeplearning4j.spark.parameterserver.training.SharedTrainingMaster;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.Optional;

public class SparkGraphClassificationTraining extends AbstractSparkJob {

    private static final Logger log = LoggerFactory.getLogger(SparkGraphClassificationTraining.class);

    public SparkGraphClassificationTraining(ArgumentApplicationParser parser, SparkSession spark) {
        super(parser, spark);
    }

    public static void main(String[] args) throws Exception {
        ArgumentApplicationParser parser = new ArgumentApplicationParser(
                readResource("/jobs/parameters/graphClassificationTraining_parameters.json", SparkLDATuning.class)
        );

        parser.parseArgument(args);

        SparkConf conf = new SparkConf();

        new SparkGraphClassificationTraining(
                parser,
                getSparkSession(conf)
        ).run();
    }

    @Override
    public void run() throws IOException {
        // read oozie parameters
        final String workingPath = parser.get("workingPath");
        final int numPartitions = Optional
                .ofNullable(parser.get("numPartitions"))
                .map(Integer::valueOf)
                .orElse(NUM_PARTITIONS);
        log.info("workingPath: '{}'", workingPath);
        log.info("numPartitions: '{}'", numPartitions);

        JavaSparkContext context = JavaSparkContext.fromSparkContext(spark.sparkContext());

        VoidConfiguration conf = VoidConfiguration.builder()
                .unicastPort(40123)
//                .networkMask("255.255.148.0/22")
                .controllerAddress("127.0.0.1")
                .build();

        TrainingMaster trainingMaster = new SharedTrainingMaster.Builder(conf, 1)
                .rngSeed(12345)
                .collectTrainingStats(false)
                .thresholdAlgorithm(new AdaptiveThresholdAlgorithm(1e-3))
                .batchSizePerWorker(32)
                .workersPerNode(4)
                .rddTrainingApproach(RDDTrainingApproach.Direct)
                .build();

        JavaRDD<MultiDataSet> trainData = context.objectFile(workingPath + "/groupDataset");

        SparkComputationGraph sparkComputationGraph = new SparkComputationGraph(
                context,
                NetworkConfigurations.getSimpleGCN(3, 2, 5, 2),
                trainingMaster);
        sparkComputationGraph.setListeners(new PerformanceListener(10, true));

        // execute training
        for (int i = 0; i < 20; i++) {
            sparkComputationGraph.fitMultiDataSet(trainData);
        }

        ComputationGraph network = sparkComputationGraph.getNetwork();

        System.out.println("network = " + network.getConfiguration().toJson());

    }
}
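Note: run() ends by printing the trained configuration as JSON; the fitted weights themselves are not persisted. If saving the model is wanted, dl4j's ModelSerializer would do it along these lines (a sketch, not part of this commit; the local path is chosen arbitrarily):

    // Sketch: persist the trained graph after the training loop.
    ComputationGraph network = sparkComputationGraph.getNetwork();
    org.deeplearning4j.util.ModelSerializer.writeModel(
            network, new java.io.File("/tmp/gcnModel.zip"), true); // true = also save updater state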
@@ -0,0 +1,38 @@
[
  {
    "paramName": "i",
    "paramLongName": "groupsPath",
    "paramDescription": "the input data: groups to be transformed into dataset",
    "paramRequired": true
  },
  {
    "paramName": "f",
    "paramLongName": "featuresJPath",
    "paramDescription": "the jpath of the features field",
    "paramRequired": true
  },
  {
    "paramName": "w",
    "paramLongName": "workingPath",
    "paramDescription": "path of the working directory",
    "paramRequired": true
  },
  {
    "paramName": "np",
    "paramLongName": "numPartitions",
    "paramDescription": "number of partitions for the similarity relations intermediate phases",
    "paramRequired": false
  },
  {
    "paramName": "id",
    "paramLongName": "idJPath",
    "paramDescription": "the jpath of the id field",
    "paramRequired": true
  },
  {
    "paramName": "gt",
    "paramLongName": "groundTruthJPath",
    "paramDescription": "the jpath of the field to be used as ground truth",
    "paramRequired": true
  }
]
@@ -0,0 +1,14 @@
[
  {
    "paramName": "w",
    "paramLongName": "workingPath",
    "paramDescription": "path of the working directory",
    "paramRequired": true
  },
  {
    "paramName": "np",
    "paramLongName": "numPartitions",
    "paramDescription": "number of partitions for the similarity relations intermediate phases",
    "paramRequired": false
  }
]
@@ -0,0 +1,102 @@
package eu.dnetlib.jobs.deeplearning;

import eu.dnetlib.jobs.AbstractSparkJob;
import eu.dnetlib.support.ArgumentApplicationParser;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
import org.junit.jupiter.api.*;

import java.io.File;
import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.file.Paths;

@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class GNNTrainingTest {

    static SparkSession spark;
    static JavaSparkContext context;
    final static String workingPath = "/tmp/working_dir";

    final static String numPartitions = "20";
    final String inputDataPath = Paths
            .get(getClass().getResource("/eu/dnetlib/jobs/examples/authors.groups.example.json").toURI())
            .toFile()
            .getAbsolutePath();
    final static String groundTruthJPath = "$.orcid";
    final static String idJPath = "$.id";
    final static String featuresJPath = "$.topics";

    public GNNTrainingTest() throws URISyntaxException {}

    public static void cleanup() throws IOException {
        // remove directories and clean workspace
        FileUtils.deleteDirectory(new File(workingPath));
    }

    @BeforeAll
    public void setup() throws IOException {
        cleanup();

        spark = SparkSession
                .builder()
                .appName("Testing")
                .master("local[*]")
                .getOrCreate();

        context = JavaSparkContext.fromSparkContext(spark.sparkContext());
    }

    @AfterAll
    public static void finalCleanUp() throws IOException {
        cleanup();
    }

    @Test
    @Order(1)
    public void createGroupDataSetTest() throws Exception {
        ArgumentApplicationParser parser = new ArgumentApplicationParser(readResource("/jobs/parameters/createGroupDataset_parameters.json", SparkCreateGroupDataSet.class));

        parser.parseArgument(
                new String[] {
                        "-i", inputDataPath,
                        "-gt", groundTruthJPath,
                        "-id", idJPath,
                        "-f", featuresJPath,
                        "-w", workingPath,
                        "-np", numPartitions
                }
        );

        new SparkCreateGroupDataSet(
                parser,
                spark
        ).run();

    }

    @Test
    @Order(2)
    public void graphClassificationTrainingTest() throws Exception {
        ArgumentApplicationParser parser = new ArgumentApplicationParser(readResource("/jobs/parameters/graphClassificationTraining_parameters.json", SparkGraphClassificationTraining.class));

        parser.parseArgument(
                new String[] {
                        "-w", workingPath,
                        "-np", numPartitions
                }
        );

        new SparkGraphClassificationTraining(
                parser,
                spark
        ).run();
    }

    public static String readResource(String path, Class<? extends AbstractSparkJob> clazz) throws IOException {
        return IOUtils.toString(clazz.getResourceAsStream(path));
    }
}
File diff suppressed because one or more lines are too long
@@ -51,6 +51,45 @@
            <artifactId>junit-jupiter</artifactId>
            <scope>test</scope>
        </dependency>

        <!-- DEEPLEARNING4J -->
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>${nd4j.backend}</artifactId>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-core</artifactId>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-datasets</artifactId>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>dl4j-spark-parameterserver_2.11</artifactId>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>dl4j-spark_2.11</artifactId>
        </dependency>

        <!-- PLOT -->
        <dependency>
            <groupId>jfree</groupId>
            <artifactId>jfreechart</artifactId>
        </dependency>
        <dependency>
            <groupId>org.jfree</groupId>
            <artifactId>jcommon</artifactId>
        </dependency>

        <!-- DNET DEDUP -->
        <dependency>
            <groupId>eu.dnetlib</groupId>
            <artifactId>dnet-dedup-test</artifactId>
        </dependency>

    </dependencies>

</project>
@@ -0,0 +1,130 @@
//package eu.dnetlib.deeplearning;
//
///* *****************************************************************************
// *
// *
// *
// * This program and the accompanying materials are made available under the
// * terms of the Apache License, Version 2.0 which is available at
// * https://www.apache.org/licenses/LICENSE-2.0.
// * See the NOTICE file distributed with this work for additional
// * information regarding copyright ownership.
// *
// * Unless required by applicable law or agreed to in writing, software
// * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// * License for the specific language governing permissions and limitations
// * under the License.
// *
// * SPDX-License-Identifier: Apache-2.0
// ******************************************************************************/
//
//import org.datavec.api.records.reader.RecordReader;
//import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
//import org.datavec.api.split.FileSplit;
//import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
//import org.deeplearning4j.examples.utils.DownloaderUtility;
//import org.deeplearning4j.examples.utils.PlotUtil;
//import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
//import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
//import org.deeplearning4j.nn.conf.layers.DenseLayer;
//import org.deeplearning4j.nn.conf.layers.OutputLayer;
//import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
//import org.deeplearning4j.nn.weights.WeightInit;
//import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
//import org.nd4j.evaluation.classification.Evaluation;
//import org.nd4j.linalg.activations.Activation;
//import org.nd4j.linalg.api.ndarray.INDArray;
//import org.nd4j.linalg.dataset.DataSet;
//import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
//import org.nd4j.linalg.learning.config.Nesterovs;
//import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
//
//import java.io.File;
//import java.util.concurrent.TimeUnit;
//
//public class GroupClassifier {
//
//    public static boolean visualize = true;
//    public static String dataLocalPath;
//
//    public static void main(String[] args) throws Exception {
//        int seed = 123;
//        double learningRate = 0.01;
//        int batchSize = 50;
//        int nEpochs = 30;
//
//        int numInputs = 2;
//        int numOutputs = 2;
//        int numHiddenNodes = 20;
//
//        dataLocalPath = DownloaderUtility.CLASSIFICATIONDATA.Download();
//        // Load the training data:
//        RecordReader rr = new CSVRecordReader();
//        rr.initialize(new FileSplit(new File(dataLocalPath, "linear_data_train.csv")));
//        DataSetIterator trainIter = new RecordReaderDataSetIterator(rr, batchSize, 0, 2);
//
//        // Load the test/evaluation data:
//        RecordReader rrTest = new CSVRecordReader();
//        rrTest.initialize(new FileSplit(new File(dataLocalPath, "linear_data_eval.csv")));
//        DataSetIterator testIter = new RecordReaderDataSetIterator(rrTest, batchSize, 0, 2);
//
//        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
//                .seed(seed)
//                .weightInit(WeightInit.XAVIER)
//                .updater(new Nesterovs(learningRate, 0.9))
//                .list()
//                .layer(new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
//                        .activation(Activation.RELU)
//                        .build())
//                .layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
//                        .activation(Activation.SOFTMAX)
//                        .nIn(numHiddenNodes).nOut(numOutputs).build())
//                .build();
//
//
//        MultiLayerNetwork model = new MultiLayerNetwork(conf);
//        model.init();
//        model.setListeners(new ScoreIterationListener(10)); // Print score every 10 parameter updates
//
//        model.fit(trainIter, nEpochs);
//
//        System.out.println("Evaluate model....");
//        Evaluation eval = new Evaluation(numOutputs);
//        while (testIter.hasNext()) {
//            DataSet t = testIter.next();
//            INDArray features = t.getFeatures();
//            INDArray labels = t.getLabels();
//            INDArray predicted = model.output(features, false);
//            eval.eval(labels, predicted);
//        }
//        // An alternate way to do the above loop
//        // Evaluation evalResults = model.evaluate(testIter);
//
//        // Print the evaluation statistics
//        System.out.println(eval.stats());
//
//        System.out.println("\n****************Example finished********************");
//        // Training is complete. Code that follows is for plotting the data & predictions only
//        generateVisuals(model, trainIter, testIter);
//    }
//
//    public static void generateVisuals(MultiLayerNetwork model, DataSetIterator trainIter, DataSetIterator testIter) throws Exception {
//        if (visualize) {
//            double xMin = 0;
//            double xMax = 1.0;
//            double yMin = -0.2;
//            double yMax = 0.8;
//            int nPointsPerAxis = 100;
//
//            // Generate x,y points that span the whole range of features
//            INDArray allXYPoints = PlotUtil.generatePointsOnGraph(xMin, xMax, yMin, yMax, nPointsPerAxis);
//            // Get train data and plot with predictions
//            PlotUtil.plotTrainingData(model, trainIter, allXYPoints, nPointsPerAxis);
//            TimeUnit.SECONDS.sleep(3);
//            // Get test data, run the test data through the network to generate predictions, and plot those predictions:
//            PlotUtil.plotTestData(model, testIter, allXYPoints, nPointsPerAxis);
//        }
//    }
//}
//
@@ -0,0 +1,24 @@
package eu.dnetlib.deeplearning.layers;

import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaVertex;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.util.Map;

public class GraphConvolutionVertex extends SameDiffLambdaVertex {

    @Override
    public SDVariable defineVertex(SameDiff sameDiff, VertexInputs inputs) {
        SDVariable features = inputs.getInput(0);
        SDVariable adjacency = inputs.getInput(1);
        SDVariable degree = inputs.getInput(2).pow(0.5);

        // result: DegreeMatrix^-0.5 x Adjacency x DegreeMatrix^-0.5 x Features
        return degree.mmul(adjacency).mmul(degree).mmul(features);
    }
}
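Note on the math: this vertex implements the propagation step of a graph convolutional network (Kipf & Welling, 2017), whose standard form, with self-loops already folded into the adjacency as DataSetProcessor does below, is

    H^{(l+1)} = \sigma( \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} H^{(l)} W^{(l)} ),   where \tilde{A} = A + I

The weight matrix W and the nonlinearity sigma are supplied by the DenseLayer that consumes this vertex's output in NetworkConfigurations. One caveat: `pow(0.5)` yields D^{1/2}, not the D^{-1/2} the comment announces, and a naive element-wise `pow(-0.5)` would produce infinities on the zero off-diagonal entries, so a faithful normalization would invert only the diagonal.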
@@ -0,0 +1,21 @@
package eu.dnetlib.deeplearning.layers;

import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaLayer;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer;
import org.nd4j.autodiff.samediff.SDIndex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;

import java.util.Map;

public class GraphGlobalAddPool extends SameDiffLambdaLayer {

    int size;

    public GraphGlobalAddPool(int size) {
        this.size = size;
    }

    @Override
    public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput) {
        return layerInput.mean(0).reshape(1, size); // reshape because the output layer expects 2-dimensional arrays
    }
}
@@ -0,0 +1,88 @@
package eu.dnetlib.deeplearning.support;

import eu.dnetlib.featureextraction.Utilities;
import eu.dnetlib.support.Author;
import eu.dnetlib.support.ConnectedComponent;
import eu.dnetlib.support.Relation;
import org.apache.spark.api.java.JavaRDD;
import org.codehaus.jackson.map.ObjectMapper;
import org.jetbrains.annotations.NotNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;

import java.io.IOException;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class DataSetProcessor {

    public static JavaRDD<MultiDataSet> entityGroupToMultiDataset(JavaRDD<ConnectedComponent> groupEntity, String idJPath, String featureJPath, String groundTruthJPath) {

        return groupEntity.map(g -> {
            Map<String, double[]> featuresMap = new HashMap<>();
            List<String> groundTruth = new ArrayList<>();
            Set<String> entities = g.getDocs();
            for (String json : entities) {
                featuresMap.put(
                        Utilities.getJPathString(idJPath, json),
                        Utilities.getJPathArray(featureJPath, json)
                );
                groundTruth.add(Utilities.getJPathString(groundTruthJPath, json));
            }

            Set<Relation> relations = g.getSimrels();

            return getMultiDataSet(featuresMap, relations, groundTruth);
        });
    }

    public static MultiDataSet getMultiDataSet(Map<String, double[]> featuresMap, Set<Relation> relations, List<String> groundTruth) {

        List<String> identifiers = new ArrayList<>(featuresMap.keySet());

        int numNodes = identifiers.size();

        // initialize arrays
        INDArray adjacency = Nd4j.zeros(numNodes, numNodes);
        INDArray features = Nd4j.zeros(numNodes, featuresMap.get(identifiers.get(0)).length); // feature size taken from the first element (it's equal for every element)
        INDArray degree = Nd4j.zeros(numNodes, numNodes);

        // create adjacency
        for (Relation r : relations) {
            adjacency.put(identifiers.indexOf(r.getSource()), identifiers.indexOf(r.getTarget()), 1);
            adjacency.put(identifiers.indexOf(r.getTarget()), identifiers.indexOf(r.getSource()), 1);
        }
        adjacency.addi(Nd4j.eye(numNodes));

        // create degree and features
        List<String> degreeSupport = relations.stream().flatMap(r -> Stream.of(r.getSource(), r.getTarget())).collect(Collectors.toList());
        for (int i = 0; i < identifiers.size(); i++) {
            degree.put(i, i, Collections.frequency(degreeSupport, identifiers.get(i)));
            features.putRow(i, Nd4j.create(featuresMap.get(identifiers.get(i))));
        }

        // infer label
        INDArray label = Nd4j.zeros(1, 2);
        if (groundTruth.stream().distinct().count() == 1) {
            // correct (same elements)
            label.put(0, 0, 1.0);
        }
        else {
            // wrong (different elements)
            label.put(0, 1, 1.0);
        }

        return new MultiDataSet(
                new INDArray[]{
                        features,
                        adjacency,
                        degree
                },
                new INDArray[]{
                        label
                }
        );
    }
}
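Note: a worked example makes the encoding concrete. For the three-node group used in DataSetProcessorTest below (simrels 0-1 and 1-2, ground truth {class1, class1, class2}), getMultiDataSet yields, with nodes shown in order 0, 1, 2 for readability (HashMap iteration order is not guaranteed):

    // features (per-node rows)   adjacency (+ self-loops)   degree (edge-endpoint counts)
    // [0.0 0.0]                  [1 1 0]                    [1 0 0]
    // [1.0 1.0]                  [1 1 1]                    [0 2 0]
    // [2.0 2.0]                  [0 1 1]                    [0 0 1]
    //
    // ground truth has two distinct values, so the one-hot label is [0, 1] ("wrong" group);
    // a group whose members all share one value would get [1, 0] ("correct").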
@@ -0,0 +1,11 @@
package eu.dnetlib.deeplearning.support;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;

import java.io.*;
import java.util.List;

public class GroupMultiDataSet extends MultiDataSet {

}
@@ -0,0 +1,97 @@
package eu.dnetlib.deeplearning.support;

import eu.dnetlib.deeplearning.layers.GraphConvolutionVertex;
import eu.dnetlib.deeplearning.layers.GraphGlobalAddPool;
import org.bytedeco.opencv.opencv_dnn.PoolingLayer;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class NetworkConfigurations {

    // parameters default values
    protected static final int SEED = 12345;
    protected static final double LEARNING_RATE = 1e-3;
    protected static final String ADJACENCY_MATRIX = "adjacency";
    protected static final String FEATURES_MATRIX = "features";
    protected static final String DEGREE_MATRIX = "degrees";

    public static MultiLayerConfiguration getLinearDataClassifier(int numInputs, int numHiddenNodes, int numOutputs) {
        return new NeuralNetConfiguration.Builder()
                .seed(SEED)
                .weightInit(WeightInit.XAVIER)
                .updater(new Nesterovs(LEARNING_RATE, 0.9))
                .list()
                .layer(new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                        .activation(Activation.RELU)
                        .build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .activation(Activation.SOFTMAX)
                        .nIn(numHiddenNodes).nOut(numOutputs).build())
                .build();
    }

    public static ComputationGraphConfiguration getSimpleGCN(int numLayers, int numInputs, int numHiddenNodes, int numClasses) {

        ComputationGraphConfiguration.GraphBuilder baseConfig = new NeuralNetConfiguration.Builder()
                .seed(SEED)
                .updater(new Adam(LEARNING_RATE))
                .weightInit(WeightInit.XAVIER)
                .graphBuilder()
                .addInputs(FEATURES_MATRIX, ADJACENCY_MATRIX, DEGREE_MATRIX)
                // first convolution layer
                .addVertex("layer1",
                        new GraphConvolutionVertex(),
                        FEATURES_MATRIX, ADJACENCY_MATRIX, DEGREE_MATRIX)
                .layer("conv1",
                        new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                                .activation(Activation.RELU)
                                .build(),
                        "layer1")
                .layer("batch1",
                        new BatchNormalization.Builder().nOut(numHiddenNodes).build(),
                        "conv1");

        // add as many layers as requested
        for (int i = 2; i <= numLayers; i++) {
            baseConfig = baseConfig.addVertex("layer" + i,
                    new GraphConvolutionVertex(),
                    "batch" + (i - 1), ADJACENCY_MATRIX, DEGREE_MATRIX)
                    .layer("conv" + i,
                            new DenseLayer.Builder().nIn(numHiddenNodes).nOut(numHiddenNodes)
                                    .activation(Activation.RELU)
                                    .build(),
                            "layer" + i)
                    .layer("batch" + i,
                            new BatchNormalization.Builder().nOut(numHiddenNodes).build(),
                            "conv" + i);
        }

        baseConfig = baseConfig
                .layer("pool",
                        new GraphGlobalAddPool(numHiddenNodes),
                        "batch" + numLayers)
                .layer("fc1",
                        new DenseLayer.Builder().nIn(numHiddenNodes).nOut(numHiddenNodes)
                                .activation(Activation.RELU)
                                .weightInit(WeightInit.XAVIER)
                                .build(),
                        "pool")
                .layer("out",
                        new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                                .activation(Activation.SOFTMAX)
                                .nIn(numHiddenNodes).nOut(numClasses).build(),
                        "fc1");

        return baseConfig.setOutputs("out").build();
    }
}
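Note: the builder chains blocks named layer1 -> conv1 -> batch1 through layerN -> convN -> batchN, then pool, fc1 and out, so getSimpleGCN(3, 2, 5, 2) is a 3-block GCN over 2 input features, 5 hidden units and 2 output classes. A quick sketch to inspect the assembled topology:

    ComputationGraph net = new ComputationGraph(NetworkConfigurations.getSimpleGCN(3, 2, 5, 2));
    net.init();
    System.out.println(net.summary()); // lists layer1..3, conv1..3, batch1..3, pool, fc1, out with shapes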
@@ -0,0 +1,253 @@
//package eu.dnetlib.deeplearning.support;
//
//import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
//import org.jfree.chart.ChartPanel;
//import org.jfree.chart.ChartUtilities;
//import org.jfree.chart.JFreeChart;
//import org.jfree.chart.axis.AxisLocation;
//import org.jfree.chart.axis.NumberAxis;
//import org.jfree.chart.block.BlockBorder;
//import org.jfree.chart.plot.DatasetRenderingOrder;
//import org.jfree.chart.plot.XYPlot;
//import org.jfree.chart.renderer.GrayPaintScale;
//import org.jfree.chart.renderer.PaintScale;
//import org.jfree.chart.renderer.xy.XYBlockRenderer;
//import org.jfree.chart.renderer.xy.XYLineAndShapeRenderer;
//import org.jfree.chart.title.PaintScaleLegend;
//import org.jfree.data.xy.*;
//import org.jfree.ui.RectangleEdge;
//import org.jfree.ui.RectangleInsets;
//import org.nd4j.linalg.api.ndarray.INDArray;
//import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
//import org.nd4j.linalg.dataset.DataSet;
//import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
//import org.nd4j.linalg.factory.Nd4j;
//
//import javax.swing.*;
//import java.awt.*;
//import java.util.ArrayList;
//import java.util.List;
//
///**
// * Simple plotting methods for the MLPClassifier quickstart examples
// *
// * @author Alex Black
// */
//public class PlotUtils {
//
//    /**
//     * Plot the training data. Assume 2d input, classification output
//     *
//     * @param model        Model to use to get predictions
//     * @param trainIter    DataSet Iterator
//     * @param backgroundIn sets of x,y points in input space, plotted in the background
//     * @param nDivisions   Number of points (per axis, for the backgroundIn/backgroundOut arrays)
//     */
//    public static void plotTrainingData(MultiLayerNetwork model, DataSetIterator trainIter, INDArray backgroundIn, int nDivisions) {
//        double[] mins = backgroundIn.min(0).data().asDouble();
//        double[] maxs = backgroundIn.max(0).data().asDouble();
//
//        DataSet ds = allBatches(trainIter);
//        INDArray backgroundOut = model.output(backgroundIn);
//
//        XYZDataset backgroundData = createBackgroundData(backgroundIn, backgroundOut);
//        JPanel panel = new ChartPanel(createChart(backgroundData, mins, maxs, nDivisions, createDataSetTrain(ds.getFeatures(), ds.getLabels())));
//
//        JFrame f = new JFrame();
//        f.add(panel);
//        f.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
//        f.pack();
//        f.setTitle("Training Data");
//
//        f.setVisible(true);
//        f.setLocation(0, 0);
//    }
//
//    /**
//     * Plot the training data. Assume 2d input, classification output
//     *
//     * @param model        Model to use to get predictions
//     * @param testIter     Test Iterator
//     * @param backgroundIn sets of x,y points in input space, plotted in the background
//     * @param nDivisions   Number of points (per axis, for the backgroundIn/backgroundOut arrays)
//     */
//    public static void plotTestData(MultiLayerNetwork model, DataSetIterator testIter, INDArray backgroundIn, int nDivisions) {
//
//        double[] mins = backgroundIn.min(0).data().asDouble();
//        double[] maxs = backgroundIn.max(0).data().asDouble();
//
//        INDArray backgroundOut = model.output(backgroundIn);
//        XYZDataset backgroundData = createBackgroundData(backgroundIn, backgroundOut);
//        DataSet ds = allBatches(testIter);
//        INDArray predicted = model.output(ds.getFeatures());
//        JPanel panel = new ChartPanel(createChart(backgroundData, mins, maxs, nDivisions, createDataSetTest(ds.getFeatures(), ds.getLabels(), predicted)));
//
//        JFrame f = new JFrame();
//        f.add(panel);
//        f.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
//        f.pack();
//        f.setTitle("Test Data");
//
//        f.setVisible(true);
//        f.setLocationRelativeTo(null);
//        //f.setLocation(100,100);
//
//    }
//
//
//    /**
//     * Create data for the background data set
//     */
//    private static XYZDataset createBackgroundData(INDArray backgroundIn, INDArray backgroundOut) {
//        int nRows = backgroundIn.rows();
//        double[] xValues = new double[nRows];
//        double[] yValues = new double[nRows];
//        double[] zValues = new double[nRows];
//        for (int i = 0; i < nRows; i++) {
//            xValues[i] = backgroundIn.getDouble(i, 0);
//            yValues[i] = backgroundIn.getDouble(i, 1);
//            zValues[i] = backgroundOut.getDouble(i, 0);
//
//        }
//
//        DefaultXYZDataset dataset = new DefaultXYZDataset();
//        dataset.addSeries("Series 1",
//                new double[][]{xValues, yValues, zValues});
//        return dataset;
//    }
//
//    // Training data
//    private static XYDataset createDataSetTrain(INDArray features, INDArray labels) {
//        int nRows = features.rows();
//
//        int nClasses = 2; // Binary classification using one output call end sigmoid.
//
//        XYSeries[] series = new XYSeries[nClasses];
//        for (int i = 0; i < series.length; i++) series[i] = new XYSeries("Class " + i);
//        INDArray argMax = Nd4j.getExecutioner().exec(new ArgMax(new INDArray[]{labels}, false, new int[]{1}))[0];
//        for (int i = 0; i < nRows; i++) {
//            int classIdx = (int) argMax.getDouble(i);
//            series[classIdx].add(features.getDouble(i, 0), features.getDouble(i, 1));
//        }
//
//        XYSeriesCollection c = new XYSeriesCollection();
//        for (XYSeries s : series) c.addSeries(s);
//        return c;
//    }
//
//    // Test data
//    private static XYDataset createDataSetTest(INDArray features, INDArray labels, INDArray predicted) {
//        int nRows = features.rows();
//
//        int nClasses = 2; // Binary classification using one output call end sigmoid.
//
//        XYSeries[] series = new XYSeries[nClasses * nClasses];
//        int[] series_index = new int[]{0, 3, 2, 1}; // little hack to make the charts look consistent.
//        for (int i = 0; i < nClasses * nClasses; i++) {
//            int trueClass = i / nClasses;
//            int predClass = i % nClasses;
//            String label = "actual=" + trueClass + ", pred=" + predClass;
//            series[series_index[i]] = new XYSeries(label);
//        }
//        INDArray actualIdx = labels.argMax(1);
//        INDArray predictedIdx = predicted.argMax(1);
//        for (int i = 0; i < nRows; i++) {
//            int classIdx = actualIdx.getInt(i);
//            int predIdx = predictedIdx.getInt(i);
//            int idx = series_index[classIdx * nClasses + predIdx];
//            series[idx].add(features.getDouble(i, 0), features.getDouble(i, 1));
//        }
//
//        XYSeriesCollection c = new XYSeriesCollection();
//        for (XYSeries s : series) c.addSeries(s);
//        return c;
//    }
//
//    private static JFreeChart createChart(XYZDataset dataset, double[] mins, double[] maxs, int nPoints, XYDataset xyData) {
//        NumberAxis xAxis = new NumberAxis("X");
//        xAxis.setRange(mins[0], maxs[0]);
//
//
//        NumberAxis yAxis = new NumberAxis("Y");
//        yAxis.setRange(mins[1], maxs[1]);
//
//        XYBlockRenderer renderer = new XYBlockRenderer();
//        renderer.setBlockWidth((maxs[0] - mins[0]) / (nPoints - 1));
//        renderer.setBlockHeight((maxs[1] - mins[1]) / (nPoints - 1));
//        PaintScale scale = new GrayPaintScale(0, 1.0);
//        renderer.setPaintScale(scale);
//        XYPlot plot = new XYPlot(dataset, xAxis, yAxis, renderer);
//        plot.setBackgroundPaint(Color.lightGray);
//        plot.setDomainGridlinesVisible(false);
//        plot.setRangeGridlinesVisible(false);
//        plot.setAxisOffset(new RectangleInsets(5, 5, 5, 5));
//        JFreeChart chart = new JFreeChart("", plot);
//        chart.getXYPlot().getRenderer().setSeriesVisibleInLegend(0, false);
//
//
//        NumberAxis scaleAxis = new NumberAxis("Probability (class 1)");
//        scaleAxis.setAxisLinePaint(Color.white);
//        scaleAxis.setTickMarkPaint(Color.white);
//        scaleAxis.setTickLabelFont(new Font("Dialog", Font.PLAIN, 7));
//        PaintScaleLegend legend = new PaintScaleLegend(new GrayPaintScale(),
//                scaleAxis);
//        legend.setStripOutlineVisible(false);
//        legend.setSubdivisionCount(20);
//        legend.setAxisLocation(AxisLocation.BOTTOM_OR_LEFT);
//        legend.setAxisOffset(5.0);
//        legend.setMargin(new RectangleInsets(5, 5, 5, 5));
//        legend.setFrame(new BlockBorder(Color.red));
//        legend.setPadding(new RectangleInsets(10, 10, 10, 10));
//        legend.setStripWidth(10);
//        legend.setPosition(RectangleEdge.LEFT);
//        chart.addSubtitle(legend);
//
//        ChartUtilities.applyCurrentTheme(chart);
//
//        plot.setDataset(1, xyData);
//        XYLineAndShapeRenderer renderer2 = new XYLineAndShapeRenderer();
//        renderer2.setBaseLinesVisible(false);
//        plot.setRenderer(1, renderer2);
//
//        plot.setDatasetRenderingOrder(DatasetRenderingOrder.FORWARD);
//
//        return chart;
//    }
//
//    public static INDArray generatePointsOnGraph(double xMin, double xMax, double yMin, double yMax, int nPointsPerAxis) {
//        // generate all the x,y points
//        double[][] evalPoints = new double[nPointsPerAxis * nPointsPerAxis][2];
//        int count = 0;
//        for (int i = 0; i < nPointsPerAxis; i++) {
//            for (int j = 0; j < nPointsPerAxis; j++) {
//                double x = i * (xMax - xMin) / (nPointsPerAxis - 1) + xMin;
//                double y = j * (yMax - yMin) / (nPointsPerAxis - 1) + yMin;
//
//                evalPoints[count][0] = x;
//                evalPoints[count][1] = y;
//
//                count++;
//            }
//        }
//
//        return Nd4j.create(evalPoints);
//    }
//
//    /**
//     * This is to collect all the data and return it as one minibatch. Obviously only for use here with small datasets
//     * @param iter
//     * @return
//     */
//    private static DataSet allBatches(DataSetIterator iter) {
//
//        List<DataSet> fullSet = new ArrayList<>();
//        iter.reset();
//        while (iter.hasNext()) {
//            List<DataSet> miniBatchList = iter.next().asList();
//            fullSet.addAll(miniBatchList);
//        }
//        iter.reset();
//        return new ListDataSetIterator<>(fullSet, fullSet.size()).next();
//    }
//
//}
@@ -23,6 +23,7 @@ import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Serializable;
+import java.math.BigDecimal;
import java.text.Normalizer;
import java.util.List;
import java.util.regex.Matcher;
@@ -71,6 +72,30 @@ public class Utilities implements Serializable {
        }
    }

    public static double[] getJPathArray(final String jsonPath, final String inputJson) {
        try {
            Object o = JsonPath.read(inputJson, jsonPath);
            if (o instanceof double[])
                return (double[]) o;
            if (o instanceof JSONArray) {
                Object[] objects = ((JSONArray) o).toArray();
                double[] array = new double[objects.length];
                for (int i = 0; i < objects.length; i++) {
                    if (objects[i] instanceof BigDecimal)
                        array[i] = ((BigDecimal) objects[i]).doubleValue();
                    else
                        array[i] = (double) objects[i];
                }
                return array;
            }
            return new double[0];
        }
        catch (Exception e) {
            e.printStackTrace();
            return new double[0];
        }
    }

//    public static String normalize(final String s) {
//        return Normalizer.normalize(s, Normalizer.Form.NFD)
//                .replaceAll("[^\\w\\s-]", "") // Remove all non-word, non-space or non-dash characters
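Note: a small usage sketch for the new helper (the sample document is illustrative; the BigDecimal branch covers the numeric type JsonPath may hand back for JSON arrays):

    double[] topics = Utilities.getJPathArray("$.topics", "{\"topics\": [0.1, 0.7, 0.2]}");
    // topics -> {0.1, 0.7, 0.2}; a missing path or unparsable input falls through to an empty array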
@@ -27,6 +27,6 @@ public class UtilityTest {
        Author a = new Author("De Bonis, Michele", "Æ", "De Bonis", new ArrayList<CoAuthor>(), new double[]{0.0, 1.0}, "author::id", "pub::id", "orcid");
        System.out.println("a = " + a.isAccurate());
        System.out.println(AuthorsFactory.getLNFI(a));

    }

}
@@ -0,0 +1,47 @@
package eu.dnetlib.deeplearning;

import com.beust.jcommander.internal.Sets;
import com.google.common.collect.Lists;
import eu.dnetlib.deeplearning.support.DataSetProcessor;
import eu.dnetlib.support.Relation;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.dataset.MultiDataSet;

import java.util.*;
import java.util.stream.Collectors;

public class DataSetProcessorTest {

    static Map<String, double[]> features;
    static Set<Relation> relations;
    static List<String> groundTruth;

    @BeforeAll
    public static void init() {
        // initialize example features
        features = new HashMap<>();
        features.put("0", new double[]{0.0, 0.0});
        features.put("1", new double[]{1.0, 1.0});
        features.put("2", new double[]{2.0, 2.0});

        // initialize example relations
        relations = new HashSet<>(Lists.newArrayList(
                new Relation("0", "1", "simrel"),
                new Relation("1", "2", "simrel")
        ));

        // initialize example ground truth
        groundTruth = Lists.newArrayList("class1", "class1", "class2");

    }

    @Test
    public void getMultiDataSetTest() throws Exception {
        MultiDataSet multiDataSet = DataSetProcessor.getMultiDataSet(features, relations, groundTruth);
        System.out.println("multiDataSet = " + multiDataSet);

        multiDataSet.asList();
    }

}
@@ -0,0 +1,33 @@
package eu.dnetlib.deeplearning;

import eu.dnetlib.deeplearning.support.NetworkConfigurations;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class NetworkConfigurationTests {

    public final static int N = 3; // number of nodes
    public final static int K = 7; // number of features

    public static INDArray[] exampleGraph = new INDArray[]{
            Nd4j.zeros(N, K), // features
            Nd4j.ones(N, N),  // adjacency
            Nd4j.ones(N, N)   // degree
    };

    @Test
    public void simpleGCNTest() {

        ComputationGraphConfiguration simpleGCNConf = NetworkConfigurations.getSimpleGCN(3, K, 5, 2);
        ComputationGraph simpleGCN = new ComputationGraph(simpleGCNConf);
        simpleGCN.init();

        INDArray[] output = simpleGCN.output(exampleGraph);
        System.out.println("output = " + output[0]);

    }

}
pom.xml
@@ -247,7 +247,7 @@
        <google.guava.version>15.0</google.guava.version>

        <spark.version>2.2.0</spark.version>
        <sparknlp.version>2.5.5</sparknlp.version>
        <scala.binary.version>2.11</scala.binary.version>
        <jackson.version>2.6.5</jackson.version>
        <mockito-core.version>3.3.3</mockito-core.version>
@@ -282,6 +282,10 @@
        <junit-jupiter.version>5.6.1</junit-jupiter.version>
        <maven.dependency.eu.dnetlib.dhp.dhp-build-assembly-resources.jar.path>../dhp-build/dhp-build-assembly-resources/target/dhp-build-assembly-resources-${project.version}.jar</maven.dependency.eu.dnetlib.dhp.dhp-build-assembly-resources.jar.path>

        <!-- deeplearning4j -->
        <dl4j-master.version>1.0.0-beta7</dl4j-master.version>
        <nd4j.backend>nd4j-native</nd4j.backend>

    </properties>

    <dependencyManagement>
@@ -446,6 +450,90 @@
            <version>2.10.29</version>
        </dependency>

        <!-- DEEPLEARNING4J -->
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>${nd4j.backend}</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>

        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-api</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-data-image</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-local</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-datasets</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-core</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>

        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>resources</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>

        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-ui</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-zoo</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>dl4j-spark-parameterserver_2.11</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>dl4j-spark_2.11</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>

        <!-- PLOT -->
        <dependency>
            <groupId>jfree</groupId>
            <artifactId>jfreechart</artifactId>
            <version>1.0.13</version>
        </dependency>
        <dependency>
            <groupId>org.jfree</groupId>
            <artifactId>jcommon</artifactId>
            <version>1.0.23</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-datasets</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>

        <!-- DNET DEDUP -->
        <dependency>
            <groupId>eu.dnetlib</groupId>
            <artifactId>dnet-dedup-test</artifactId>
            <version>4.1.13-SNAPSHOT</version>
        </dependency>

    </dependencies>