implementation of GCN network: first commit

parent 88451e3832 → commit 4c9d33171d
@@ -36,7 +36,7 @@ public abstract class AbstractSparkJob implements Serializable {
         this.spark = spark;
     }
 
-    abstract void run() throws IOException;
+    protected abstract void run() throws IOException;
 
     protected static SparkSession getSparkSession(SparkConf conf) {
         return SparkSession.builder().config(conf).getOrCreate();
@@ -45,7 +45,7 @@ public class SparkLDAAnalysis extends AbstractSparkJob {
     }
 
     @Override
-    void run() throws IOException {
+    protected void run() throws IOException {
        // read oozie parameters
        final String authorsPath = parser.get("authorsPath");
        final String workingPath = parser.get("workingPath");
SparkCreateGroupDataSet.java (new file)
@@ -0,0 +1,75 @@
package eu.dnetlib.jobs.deeplearning;

import eu.dnetlib.deeplearning.support.DataSetProcessor;
import eu.dnetlib.jobs.AbstractSparkJob;
import eu.dnetlib.jobs.SparkLDATuning;
import eu.dnetlib.support.ArgumentApplicationParser;
import eu.dnetlib.support.ConnectedComponent;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
import org.codehaus.jackson.map.ObjectMapper;
import org.deeplearning4j.spark.data.BatchAndExportDataSetsFunction;
import org.deeplearning4j.spark.data.BatchAndExportMultiDataSetsFunction;
import org.deeplearning4j.spark.datavec.iterator.IteratorUtils;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.List;
import java.util.Optional;

public class SparkCreateGroupDataSet extends AbstractSparkJob {

    private static final Logger log = LoggerFactory.getLogger(SparkCreateGroupDataSet.class);

    public SparkCreateGroupDataSet(ArgumentApplicationParser parser, SparkSession spark) {
        super(parser, spark);
    }

    public static void main(String[] args) throws Exception {
        ArgumentApplicationParser parser = new ArgumentApplicationParser(
                readResource("/jobs/parameters/createGroupDataset_parameters.json", SparkLDATuning.class)
        );

        parser.parseArgument(args);

        SparkConf conf = new SparkConf();

        new SparkCreateGroupDataSet(
                parser,
                getSparkSession(conf)
        ).run();
    }

    @Override
    public void run() throws IOException {
        // read oozie parameters
        final String groupsPath = parser.get("groupsPath");
        final String workingPath = parser.get("workingPath");
        final String groundTruthJPath = parser.get("groundTruthJPath");
        final String idJPath = parser.get("idJPath");
        final String featuresJPath = parser.get("featuresJPath");
        final int numPartitions = Optional
                .ofNullable(parser.get("numPartitions"))
                .map(Integer::valueOf)
                .orElse(NUM_PARTITIONS);

        log.info("groupsPath: '{}'", groupsPath);
        log.info("workingPath: '{}'", workingPath);
        log.info("groundTruthJPath: '{}'", groundTruthJPath);
        log.info("idJPath: '{}'", idJPath);
        log.info("featuresJPath: '{}'", featuresJPath);
        log.info("numPartitions: '{}'", numPartitions);

        JavaSparkContext context = JavaSparkContext.fromSparkContext(spark.sparkContext());

        JavaRDD<ConnectedComponent> groups = context.textFile(groupsPath).map(g -> new ObjectMapper().readValue(g, ConnectedComponent.class));

        JavaRDD<MultiDataSet> dataset = DataSetProcessor.entityGroupToMultiDataset(groups, idJPath, featuresJPath, groundTruthJPath);

        dataset.saveAsObjectFile(workingPath + "/groupDataset");
    }
}
SparkGraphClassificationTraining.java (new file)
@@ -0,0 +1,99 @@
package eu.dnetlib.jobs.deeplearning;

import eu.dnetlib.deeplearning.support.DataSetProcessor;
import eu.dnetlib.deeplearning.support.NetworkConfigurations;
import eu.dnetlib.jobs.AbstractSparkJob;
import eu.dnetlib.jobs.SparkLDATuning;
import eu.dnetlib.support.ArgumentApplicationParser;
import eu.dnetlib.support.ConnectedComponent;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
import org.codehaus.jackson.map.ObjectMapper;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.optimize.listeners.PerformanceListener;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.threshold.AdaptiveThresholdAlgorithm;
import org.deeplearning4j.spark.api.RDDTrainingApproach;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.impl.graph.SparkComputationGraph;
import org.deeplearning4j.spark.parameterserver.training.SharedTrainingMaster;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.Optional;

public class SparkGraphClassificationTraining extends AbstractSparkJob {

    private static final Logger log = LoggerFactory.getLogger(SparkGraphClassificationTraining.class);

    public SparkGraphClassificationTraining(ArgumentApplicationParser parser, SparkSession spark) {
        super(parser, spark);
    }

    public static void main(String[] args) throws Exception {
        ArgumentApplicationParser parser = new ArgumentApplicationParser(
                readResource("/jobs/parameters/graphClassificationTraining_parameters.json", SparkLDATuning.class)
        );

        parser.parseArgument(args);

        SparkConf conf = new SparkConf();

        new SparkGraphClassificationTraining(
                parser,
                getSparkSession(conf)
        ).run();
    }

    @Override
    public void run() throws IOException {
        // read oozie parameters
        final String workingPath = parser.get("workingPath");
        final int numPartitions = Optional
                .ofNullable(parser.get("numPartitions"))
                .map(Integer::valueOf)
                .orElse(NUM_PARTITIONS);
        log.info("workingPath: '{}'", workingPath);
        log.info("numPartitions: '{}'", numPartitions);

        JavaSparkContext context = JavaSparkContext.fromSparkContext(spark.sparkContext());

        VoidConfiguration conf = VoidConfiguration.builder()
                .unicastPort(40123)
//                .networkMask("255.255.148.0/22")
                .controllerAddress("127.0.0.1")
                .build();

        TrainingMaster trainingMaster = new SharedTrainingMaster.Builder(conf, 1)
                .rngSeed(12345)
                .collectTrainingStats(false)
                .thresholdAlgorithm(new AdaptiveThresholdAlgorithm(1e-3))
                .batchSizePerWorker(32)
                .workersPerNode(4)
                .rddTrainingApproach(RDDTrainingApproach.Direct)
                .build();

        JavaRDD<MultiDataSet> trainData = context.objectFile(workingPath + "/groupDataset");

        SparkComputationGraph sparkComputationGraph = new SparkComputationGraph(
                context,
                NetworkConfigurations.getSimpleGCN(3, 2, 5, 2),
                trainingMaster);
        sparkComputationGraph.setListeners(new PerformanceListener(10, true));

        //execute training
        for (int i = 0; i < 20; i++) {
            sparkComputationGraph.fitMultiDataSet(trainData);
        }

        ComputationGraph network = sparkComputationGraph.getNetwork();

        System.out.println("network = " + network.getConfiguration().toJson());
    }
}
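Note: the GCN shape arguments passed above (3 layers, 2 inputs, 5 hidden units, 2 classes) are hard-coded. For experimentation outside the cluster, the same configuration can be fitted without Spark. A minimal sketch (not part of the commit), assuming a java.util.List<MultiDataSet> built by DataSetProcessor is at hand:

import eu.dnetlib.deeplearning.support.NetworkConfigurations;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.dataset.api.MultiDataSet;

import java.util.List;

public class LocalTrainingSketch {
    // fits the same 3-layer GCN locally; trainSet is assumed to come from
    // DataSetProcessor.getMultiDataSet(...) applied to each entity group
    public static ComputationGraph train(List<MultiDataSet> trainSet, int epochs) {
        ComputationGraph net = new ComputationGraph(NetworkConfigurations.getSimpleGCN(3, 2, 5, 2));
        net.init();
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (MultiDataSet mds : trainSet) {
                net.fit(mds); // one graph (one group) per fit call
            }
        }
        return net;
    }
}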
createGroupDataset_parameters.json (new file)
@@ -0,0 +1,38 @@
[
  {
    "paramName": "i",
    "paramLongName": "groupsPath",
    "paramDescription": "the input data: groups to be transformed into dataset",
    "paramRequired": true
  },
  {
    "paramName": "f",
    "paramLongName": "featuresJPath",
    "paramDescription": "the jpath of the features field",
    "paramRequired": true
  },
  {
    "paramName": "w",
    "paramLongName": "workingPath",
    "paramDescription": "path of the working directory",
    "paramRequired": true
  },
  {
    "paramName": "np",
    "paramLongName": "numPartitions",
    "paramDescription": "number of partitions for the similarity relations intermediate phases",
    "paramRequired": false
  },
  {
    "paramName": "id",
    "paramLongName": "idJPath",
    "paramDescription": "the jpath of the id field",
    "paramRequired": true
  },
  {
    "paramName": "gt",
    "paramLongName": "groundTruthJPath",
    "paramDescription": "the jpath of the field to be used as ground truth",
    "paramRequired": true
  }
]
graphClassificationTraining_parameters.json (new file)
@@ -0,0 +1,14 @@
[
  {
    "paramName": "w",
    "paramLongName": "workingPath",
    "paramDescription": "path of the working directory",
    "paramRequired": true
  },
  {
    "paramName": "np",
    "paramLongName": "numPartitions",
    "paramDescription": "number of partitions for the similarity relations intermediate phases",
    "paramRequired": false
  }
]
GNNTrainingTest.java (new file)
@@ -0,0 +1,102 @@
package eu.dnetlib.jobs.deeplearning;

import eu.dnetlib.jobs.AbstractSparkJob;
import eu.dnetlib.support.ArgumentApplicationParser;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
import org.junit.jupiter.api.*;

import java.io.File;
import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.file.Paths;

@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class GNNTrainingTest {

    static SparkSession spark;
    static JavaSparkContext context;
    final static String workingPath = "/tmp/working_dir";

    final static String numPartitions = "20";
    final String inputDataPath = Paths
            .get(getClass().getResource("/eu/dnetlib/jobs/examples/authors.groups.example.json").toURI())
            .toFile()
            .getAbsolutePath();
    final static String groundTruthJPath = "$.orcid";
    final static String idJPath = "$.id";
    final static String featuresJPath = "$.topics";

    public GNNTrainingTest() throws URISyntaxException {}

    public static void cleanup() throws IOException {
        //remove directories and clean workspace
        FileUtils.deleteDirectory(new File(workingPath));
    }

    @BeforeAll
    public void setup() throws IOException {
        cleanup();

        spark = SparkSession
                .builder()
                .appName("Testing")
                .master("local[*]")
                .getOrCreate();

        context = JavaSparkContext.fromSparkContext(spark.sparkContext());
    }

    @AfterAll
    public static void finalCleanUp() throws IOException {
        cleanup();
    }

    @Test
    @Order(1)
    public void createGroupDataSetTest() throws Exception {
        ArgumentApplicationParser parser = new ArgumentApplicationParser(readResource("/jobs/parameters/createGroupDataset_parameters.json", SparkCreateGroupDataSet.class));

        parser.parseArgument(
                new String[] {
                        "-i", inputDataPath,
                        "-gt", groundTruthJPath,
                        "-id", idJPath,
                        "-f", featuresJPath,
                        "-w", workingPath,
                        "-np", numPartitions
                }
        );

        new SparkCreateGroupDataSet(
                parser,
                spark
        ).run();
    }

    @Test
    @Order(2)
    public void graphClassificationTrainingTest() throws Exception {
        ArgumentApplicationParser parser = new ArgumentApplicationParser(readResource("/jobs/parameters/graphClassificationTraining_parameters.json", SparkGraphClassificationTraining.class));

        parser.parseArgument(
                new String[] {
                        "-w", workingPath,
                        "-np", numPartitions
                }
        );

        new SparkGraphClassificationTraining(
                parser,
                spark
        ).run();
    }

    public static String readResource(String path, Class<? extends AbstractSparkJob> clazz) throws IOException {
        return IOUtils.toString(clazz.getResourceAsStream(path));
    }
}

File diff suppressed because one or more lines are too long
@@ -51,6 +51,45 @@
             <artifactId>junit-jupiter</artifactId>
             <scope>test</scope>
         </dependency>
+
+        <!--DEEPLEARNING4J -->
+        <dependency>
+            <groupId>org.nd4j</groupId>
+            <artifactId>${nd4j.backend}</artifactId>
+        </dependency>
+        <dependency>
+            <groupId>org.deeplearning4j</groupId>
+            <artifactId>deeplearning4j-core</artifactId>
+        </dependency>
+        <dependency>
+            <groupId>org.deeplearning4j</groupId>
+            <artifactId>deeplearning4j-datasets</artifactId>
+        </dependency>
+        <dependency>
+            <groupId>org.deeplearning4j</groupId>
+            <artifactId>dl4j-spark-parameterserver_2.11</artifactId>
+        </dependency>
+        <dependency>
+            <groupId>org.deeplearning4j</groupId>
+            <artifactId>dl4j-spark_2.11</artifactId>
+        </dependency>
+
+        <!--PLOT-->
+        <dependency>
+            <groupId>jfree</groupId>
+            <artifactId>jfreechart</artifactId>
+        </dependency>
+        <dependency>
+            <groupId>org.jfree</groupId>
+            <artifactId>jcommon</artifactId>
+        </dependency>
+
+        <!--DNET DEDUP-->
+        <dependency>
+            <groupId>eu.dnetlib</groupId>
+            <artifactId>dnet-dedup-test</artifactId>
+        </dependency>
+
     </dependencies>
 
 </project>
GroupClassifier.java (new file, entirely commented out)
@@ -0,0 +1,130 @@
//package eu.dnetlib.deeplearning;
//
///* *****************************************************************************
// *
// *
// *
// * This program and the accompanying materials are made available under the
// * terms of the Apache License, Version 2.0 which is available at
// * https://www.apache.org/licenses/LICENSE-2.0.
// * See the NOTICE file distributed with this work for additional
// * information regarding copyright ownership.
// *
// * Unless required by applicable law or agreed to in writing, software
// * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// * License for the specific language governing permissions and limitations
// * under the License.
// *
// * SPDX-License-Identifier: Apache-2.0
// ******************************************************************************/
//
//import org.datavec.api.records.reader.RecordReader;
//import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
//import org.datavec.api.split.FileSplit;
//import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
//import org.deeplearning4j.examples.utils.DownloaderUtility;
//import org.deeplearning4j.examples.utils.PlotUtil;
//import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
//import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
//import org.deeplearning4j.nn.conf.layers.DenseLayer;
//import org.deeplearning4j.nn.conf.layers.OutputLayer;
//import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
//import org.deeplearning4j.nn.weights.WeightInit;
//import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
//import org.nd4j.evaluation.classification.Evaluation;
//import org.nd4j.linalg.activations.Activation;
//import org.nd4j.linalg.api.ndarray.INDArray;
//import org.nd4j.linalg.dataset.DataSet;
//import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
//import org.nd4j.linalg.learning.config.Nesterovs;
//import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
//
//import java.io.File;
//import java.util.concurrent.TimeUnit;
//
//public class GroupClassifier {
//
//    public static boolean visualize = true;
//    public static String dataLocalPath;
//
//    public static void main(String[] args) throws Exception {
//        int seed = 123;
//        double learningRate = 0.01;
//        int batchSize = 50;
//        int nEpochs = 30;
//
//        int numInputs = 2;
//        int numOutputs = 2;
//        int numHiddenNodes = 20;
//
//        dataLocalPath = DownloaderUtility.CLASSIFICATIONDATA.Download();
//        //Load the training data:
//        RecordReader rr = new CSVRecordReader();
//        rr.initialize(new FileSplit(new File(dataLocalPath, "linear_data_train.csv")));
//        DataSetIterator trainIter = new RecordReaderDataSetIterator(rr, batchSize, 0, 2);
//
//        //Load the test/evaluation data:
//        RecordReader rrTest = new CSVRecordReader();
//        rrTest.initialize(new FileSplit(new File(dataLocalPath, "linear_data_eval.csv")));
//        DataSetIterator testIter = new RecordReaderDataSetIterator(rrTest, batchSize, 0, 2);
//
//        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
//                .seed(seed)
//                .weightInit(WeightInit.XAVIER)
//                .updater(new Nesterovs(learningRate, 0.9))
//                .list()
//                .layer(new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
//                        .activation(Activation.RELU)
//                        .build())
//                .layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
//                        .activation(Activation.SOFTMAX)
//                        .nIn(numHiddenNodes).nOut(numOutputs).build())
//                .build();
//
//
//        MultiLayerNetwork model = new MultiLayerNetwork(conf);
//        model.init();
//        model.setListeners(new ScoreIterationListener(10)); //Print score every 10 parameter updates
//
//        model.fit(trainIter, nEpochs);
//
//        System.out.println("Evaluate model....");
//        Evaluation eval = new Evaluation(numOutputs);
//        while (testIter.hasNext()) {
//            DataSet t = testIter.next();
//            INDArray features = t.getFeatures();
//            INDArray labels = t.getLabels();
//            INDArray predicted = model.output(features, false);
//            eval.eval(labels, predicted);
//        }
//        //An alternate way to do the above loop
//        //Evaluation evalResults = model.evaluate(testIter);
//
//        //Print the evaluation statistics
//        System.out.println(eval.stats());
//
//        System.out.println("\n****************Example finished********************");
//        //Training is complete. Code that follows is for plotting the data & predictions only
//        generateVisuals(model, trainIter, testIter);
//    }
//
//    public static void generateVisuals(MultiLayerNetwork model, DataSetIterator trainIter, DataSetIterator testIter) throws Exception {
//        if (visualize) {
//            double xMin = 0;
//            double xMax = 1.0;
//            double yMin = -0.2;
//            double yMax = 0.8;
//            int nPointsPerAxis = 100;
//
//            //Generate x,y points that span the whole range of features
//            INDArray allXYPoints = PlotUtil.generatePointsOnGraph(xMin, xMax, yMin, yMax, nPointsPerAxis);
//            //Get train data and plot with predictions
//            PlotUtil.plotTrainingData(model, trainIter, allXYPoints, nPointsPerAxis);
//            TimeUnit.SECONDS.sleep(3);
//            //Get test data, run the test data through the network to generate predictions, and plot those predictions:
//            PlotUtil.plotTestData(model, testIter, allXYPoints, nPointsPerAxis);
//        }
//    }
//}
//
GraphConvolutionVertex.java (new file)
@@ -0,0 +1,24 @@
package eu.dnetlib.deeplearning.layers;

import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaVertex;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.util.Map;

public class GraphConvolutionVertex extends SameDiffLambdaVertex {

    @Override
    public SDVariable defineVertex(SameDiff sameDiff, VertexInputs inputs) {
        SDVariable features = inputs.getInput(0);
        SDVariable adjacency = inputs.getInput(1);
        SDVariable degree = inputs.getInput(2).pow(0.5);

        //result: DegreeMatrix^-0.5 x Adjacent x DegreeMatrix^-0.5 x Features
        return degree.mmul(adjacency).mmul(degree).mmul(features);
    }
}
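A note on the normalization: the comment states the symmetric GCN rule D^-1/2 · A · D^-1/2 · X, but `.pow(0.5)` is applied elementwise to the full degree matrix, which appears to compute D^1/2 instead (a naive `.pow(-0.5)` would also turn the off-diagonal zeros into infinities, since 0^-0.5 is infinite). A minimal eager sketch of the intended rule with ND4J, assuming the inputs that DataSetProcessor below builds (adjacency with self-loops, diagonal degree matrix); this is an illustration, not part of the commit:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

public class GcnRuleSketch {
    public static void main(String[] args) {
        // toy 3-node path graph: adjacency already includes self-loops (A + I)
        INDArray adjacency = Nd4j.create(new double[][]{{1, 1, 0}, {1, 1, 1}, {0, 1, 1}});
        INDArray degrees = Nd4j.create(new double[]{1, 2, 1}); // simrel endpoint counts, as DataSetProcessor computes them
        INDArray features = Nd4j.rand(3, 5);                    // 5 features per node

        // D^{-1/2} built by exponentiating only the degree values, then placing them on the diagonal
        INDArray dInvSqrt = Nd4j.diag(Transforms.pow(degrees, -0.5));

        INDArray propagated = dInvSqrt.mmul(adjacency).mmul(dInvSqrt).mmul(features);
        System.out.println(propagated);
    }
}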
GraphGlobalAddPool.java (new file)
@@ -0,0 +1,21 @@
package eu.dnetlib.deeplearning.layers;

import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaLayer;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer;
import org.nd4j.autodiff.samediff.SDIndex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;

import java.util.Map;

public class GraphGlobalAddPool extends SameDiffLambdaLayer {

    int size;
    public GraphGlobalAddPool(int size) {
        this.size = size;
    }

    @Override
    public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput) {
        return layerInput.mean(0).reshape(1, size); //reshape because output layer expects 2-dimensional arrays
    }
}
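Despite the "add pool" name, defineLayer averages node embeddings (mean over dimension 0); a literal global add pool would sum them. A hypothetical sum variant under the same reshape convention, shown for contrast only:

import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaLayer;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;

public class GraphGlobalSumPool extends SameDiffLambdaLayer {
    int size;
    public GraphGlobalSumPool(int size) { this.size = size; }

    @Override
    public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput) {
        return layerInput.sum(0).reshape(1, size); // sum over nodes -> one graph-level vector
    }
}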
DataSetProcessor.java (new file)
@@ -0,0 +1,88 @@
package eu.dnetlib.deeplearning.support;

import eu.dnetlib.featureextraction.Utilities;
import eu.dnetlib.support.Author;
import eu.dnetlib.support.ConnectedComponent;
import eu.dnetlib.support.Relation;
import org.apache.spark.api.java.JavaRDD;
import org.codehaus.jackson.map.ObjectMapper;
import org.jetbrains.annotations.NotNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;

import java.io.IOException;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class DataSetProcessor {

    public static JavaRDD<MultiDataSet> entityGroupToMultiDataset(JavaRDD<ConnectedComponent> groupEntity, String idJPath, String featureJPath, String groundTruthJPath) {

        return groupEntity.map(g -> {
            Map<String, double[]> featuresMap = new HashMap<>();
            List<String> groundTruth = new ArrayList<>();
            Set<String> entities = g.getDocs();
            for (String json : entities) {
                featuresMap.put(
                        Utilities.getJPathString(idJPath, json),
                        Utilities.getJPathArray(featureJPath, json)
                );
                groundTruth.add(Utilities.getJPathString(groundTruthJPath, json));
            }

            Set<Relation> relations = g.getSimrels();

            return getMultiDataSet(featuresMap, relations, groundTruth);
        });
    }

    public static MultiDataSet getMultiDataSet(Map<String, double[]> featuresMap, Set<Relation> relations, List<String> groundTruth) {

        List<String> identifiers = new ArrayList<>(featuresMap.keySet());

        int numNodes = identifiers.size();

        //initialize arrays
        INDArray adjacency = Nd4j.zeros(numNodes, numNodes);
        INDArray features = Nd4j.zeros(numNodes, featuresMap.get(identifiers.get(0)).length); //feature size taken from the first element (it's equal for every element)
        INDArray degree = Nd4j.zeros(numNodes, numNodes);

        //create adjacency
        for (Relation r : relations) {
            adjacency.put(identifiers.indexOf(r.getSource()), identifiers.indexOf(r.getTarget()), 1);
            adjacency.put(identifiers.indexOf(r.getTarget()), identifiers.indexOf(r.getSource()), 1);
        }
        adjacency.addi(Nd4j.eye(numNodes));

        //create degree and features
        List<String> degreeSupport = relations.stream().flatMap(r -> Stream.of(r.getSource(), r.getTarget())).collect(Collectors.toList());
        for (int i = 0; i < identifiers.size(); i++) {
            degree.put(i, i, Collections.frequency(degreeSupport, identifiers.get(i)));
            features.putRow(i, Nd4j.create(featuresMap.get(identifiers.get(i))));
        }

        //infer label
        INDArray label = Nd4j.zeros(1, 2);
        if (groundTruth.stream().distinct().count() == 1) {
            //correct (same elements)
            label.put(0, 0, 1.0);
        }
        else {
            //wrong (different elements)
            label.put(0, 1, 1.0);
        }

        return new MultiDataSet(
                new INDArray[]{
                        features,
                        adjacency,
                        degree
                },
                new INDArray[]{
                        label
                }
        );
    }
}
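To make the shapes concrete: for a group of N entities with K-dimensional features, getMultiDataSet returns features N×K, adjacency N×N (with self-loops added), a diagonal N×N degree matrix, and a one-hot 1×2 label — [1, 0] when all ground-truth values agree, [0, 1] otherwise. A hedged sketch with hypothetical identifiers and ORCID values (the same pattern DataSetProcessorTest uses later in this commit):

import com.google.common.collect.Lists;
import eu.dnetlib.deeplearning.support.DataSetProcessor;
import eu.dnetlib.support.Relation;
import org.nd4j.linalg.dataset.MultiDataSet;

import java.util.*;

public class GroupShapeSketch {
    public static void main(String[] args) {
        // two authors sharing an ORCID, connected by one simrel: a "correct" group
        Map<String, double[]> features = new HashMap<>();
        features.put("author::1", new double[]{0.1, 0.9});
        features.put("author::2", new double[]{0.2, 0.8});
        Set<Relation> relations = new HashSet<>(Lists.newArrayList(
                new Relation("author::1", "author::2", "simrel")));
        List<String> groundTruth = Lists.newArrayList("0000-0001", "0000-0001"); // hypothetical ORCIDs

        MultiDataSet mds = DataSetProcessor.getMultiDataSet(features, relations, groundTruth);
        // features 2x2, adjacency 2x2 (+I), degree 2x2 diagonal, label [1, 0]
        System.out.println(mds);
    }
}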
GroupMultiDataSet.java (new file)
@@ -0,0 +1,11 @@
package eu.dnetlib.deeplearning.support;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;

import java.io.*;
import java.util.List;

public class GroupMultiDataSet extends MultiDataSet {

}
NetworkConfigurations.java (new file)
@@ -0,0 +1,97 @@
package eu.dnetlib.deeplearning.support;

import eu.dnetlib.deeplearning.layers.GraphConvolutionVertex;
import eu.dnetlib.deeplearning.layers.GraphGlobalAddPool;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class NetworkConfigurations {

    //parameters default values
    protected static final int SEED = 12345;
    protected static final double LEARNING_RATE = 1e-3;
    protected static final String ADJACENCY_MATRIX = "adjacency";
    protected static final String FEATURES_MATRIX = "features";
    protected static final String DEGREE_MATRIX = "degrees";

    public static MultiLayerConfiguration getLinearDataClassifier(int numInputs, int numHiddenNodes, int numOutputs) {
        return new NeuralNetConfiguration.Builder()
                .seed(SEED)
                .weightInit(WeightInit.XAVIER)
                .updater(new Nesterovs(LEARNING_RATE, 0.9))
                .list()
                .layer(new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                        .activation(Activation.RELU)
                        .build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .activation(Activation.SOFTMAX)
                        .nIn(numHiddenNodes).nOut(numOutputs).build())
                .build();
    }

    public static ComputationGraphConfiguration getSimpleGCN(int numLayers, int numInputs, int numHiddenNodes, int numClasses) {

        ComputationGraphConfiguration.GraphBuilder baseConfig = new NeuralNetConfiguration.Builder()
                .seed(SEED)
                .updater(new Adam(LEARNING_RATE))
                .weightInit(WeightInit.XAVIER)
                .graphBuilder()
                .addInputs(FEATURES_MATRIX, ADJACENCY_MATRIX, DEGREE_MATRIX)
                //first convolution layer
                .addVertex("layer1",
                        new GraphConvolutionVertex(),
                        FEATURES_MATRIX, ADJACENCY_MATRIX, DEGREE_MATRIX)
                .layer("conv1",
                        new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                                .activation(Activation.RELU)
                                .build(),
                        "layer1")
                .layer("batch1",
                        new BatchNormalization.Builder().nOut(numHiddenNodes).build(),
                        "conv1");

        //add as many layers as requested
        for (int i = 2; i <= numLayers; i++) {
            baseConfig = baseConfig.addVertex("layer" + i,
                    new GraphConvolutionVertex(),
                    "batch" + (i-1), ADJACENCY_MATRIX, DEGREE_MATRIX)
                    .layer("conv" + i,
                            new DenseLayer.Builder().nIn(numHiddenNodes).nOut(numHiddenNodes)
                                    .activation(Activation.RELU)
                                    .build(),
                            "layer" + i)
                    .layer("batch" + i,
                            new BatchNormalization.Builder().nOut(numHiddenNodes).build(),
                            "conv" + i);
        }

        baseConfig = baseConfig
                .layer("pool",
                        new GraphGlobalAddPool(numHiddenNodes),
                        "batch" + numLayers)
                .layer("fc1",
                        new DenseLayer.Builder().nIn(numHiddenNodes).nOut(numHiddenNodes)
                                .activation(Activation.RELU)
                                .weightInit(WeightInit.XAVIER)
                                .build(),
                        "pool")
                .layer("out",
                        new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                                .activation(Activation.SOFTMAX)
                                .nIn(numHiddenNodes).nOut(numClasses).build(),
                        "fc1");

        return baseConfig.setOutputs("out").build();
    }
}
PlotUtils.java (new file, entirely commented out)
@@ -0,0 +1,253 @@
//package eu.dnetlib.deeplearning.support;
//
//import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
//import org.jfree.chart.ChartPanel;
//import org.jfree.chart.ChartUtilities;
//import org.jfree.chart.JFreeChart;
//import org.jfree.chart.axis.AxisLocation;
//import org.jfree.chart.axis.NumberAxis;
//import org.jfree.chart.block.BlockBorder;
//import org.jfree.chart.plot.DatasetRenderingOrder;
//import org.jfree.chart.plot.XYPlot;
//import org.jfree.chart.renderer.GrayPaintScale;
//import org.jfree.chart.renderer.PaintScale;
//import org.jfree.chart.renderer.xy.XYBlockRenderer;
//import org.jfree.chart.renderer.xy.XYLineAndShapeRenderer;
//import org.jfree.chart.title.PaintScaleLegend;
//import org.jfree.data.xy.*;
//import org.jfree.ui.RectangleEdge;
//import org.jfree.ui.RectangleInsets;
//import org.nd4j.linalg.api.ndarray.INDArray;
//import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
//import org.nd4j.linalg.dataset.DataSet;
//import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
//import org.nd4j.linalg.factory.Nd4j;
//
//import javax.swing.*;
//import java.awt.*;
//import java.util.ArrayList;
//import java.util.List;
//
///**
// * Simple plotting methods for the MLPClassifier quickstart examples
// *
// * @author Alex Black
// */
//public class PlotUtils {
//
//    /**
//     * Plot the training data. Assume 2d input, classification output
//     *
//     * @param model        Model to use to get predictions
//     * @param trainIter    DataSet Iterator
//     * @param backgroundIn sets of x,y points in input space, plotted in the background
//     * @param nDivisions   Number of points (per axis, for the backgroundIn/backgroundOut arrays)
//     */
//    public static void plotTrainingData(MultiLayerNetwork model, DataSetIterator trainIter, INDArray backgroundIn, int nDivisions) {
//        double[] mins = backgroundIn.min(0).data().asDouble();
//        double[] maxs = backgroundIn.max(0).data().asDouble();
//
//        DataSet ds = allBatches(trainIter);
//        INDArray backgroundOut = model.output(backgroundIn);
//
//        XYZDataset backgroundData = createBackgroundData(backgroundIn, backgroundOut);
//        JPanel panel = new ChartPanel(createChart(backgroundData, mins, maxs, nDivisions, createDataSetTrain(ds.getFeatures(), ds.getLabels())));
//
//        JFrame f = new JFrame();
//        f.add(panel);
//        f.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
//        f.pack();
//        f.setTitle("Training Data");
//
//        f.setVisible(true);
//        f.setLocation(0, 0);
//    }
//
//    /**
//     * Plot the test data. Assume 2d input, classification output
//     *
//     * @param model        Model to use to get predictions
//     * @param testIter     Test Iterator
//     * @param backgroundIn sets of x,y points in input space, plotted in the background
//     * @param nDivisions   Number of points (per axis, for the backgroundIn/backgroundOut arrays)
//     */
//    public static void plotTestData(MultiLayerNetwork model, DataSetIterator testIter, INDArray backgroundIn, int nDivisions) {
//
//        double[] mins = backgroundIn.min(0).data().asDouble();
//        double[] maxs = backgroundIn.max(0).data().asDouble();
//
//        INDArray backgroundOut = model.output(backgroundIn);
//        XYZDataset backgroundData = createBackgroundData(backgroundIn, backgroundOut);
//        DataSet ds = allBatches(testIter);
//        INDArray predicted = model.output(ds.getFeatures());
//        JPanel panel = new ChartPanel(createChart(backgroundData, mins, maxs, nDivisions, createDataSetTest(ds.getFeatures(), ds.getLabels(), predicted)));
//
//        JFrame f = new JFrame();
//        f.add(panel);
//        f.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
//        f.pack();
//        f.setTitle("Test Data");
//
//        f.setVisible(true);
//        f.setLocationRelativeTo(null);
//        //f.setLocation(100,100);
//    }
//
//    /**
//     * Create data for the background data set
//     */
//    private static XYZDataset createBackgroundData(INDArray backgroundIn, INDArray backgroundOut) {
//        int nRows = backgroundIn.rows();
//        double[] xValues = new double[nRows];
//        double[] yValues = new double[nRows];
//        double[] zValues = new double[nRows];
//        for (int i = 0; i < nRows; i++) {
//            xValues[i] = backgroundIn.getDouble(i, 0);
//            yValues[i] = backgroundIn.getDouble(i, 1);
//            zValues[i] = backgroundOut.getDouble(i, 0);
//        }
//
//        DefaultXYZDataset dataset = new DefaultXYZDataset();
//        dataset.addSeries("Series 1",
//                new double[][]{xValues, yValues, zValues});
//        return dataset;
//    }
//
//    //Training data
//    private static XYDataset createDataSetTrain(INDArray features, INDArray labels) {
//        int nRows = features.rows();
//
//        int nClasses = 2; // Binary classification using one output call end sigmoid.
//
//        XYSeries[] series = new XYSeries[nClasses];
//        for (int i = 0; i < series.length; i++) series[i] = new XYSeries("Class " + i);
//        INDArray argMax = Nd4j.getExecutioner().exec(new ArgMax(new INDArray[]{labels}, false, new int[]{1}))[0];
//        for (int i = 0; i < nRows; i++) {
//            int classIdx = (int) argMax.getDouble(i);
//            series[classIdx].add(features.getDouble(i, 0), features.getDouble(i, 1));
//        }
//
//        XYSeriesCollection c = new XYSeriesCollection();
//        for (XYSeries s : series) c.addSeries(s);
//        return c;
//    }
//
//    //Test data
//    private static XYDataset createDataSetTest(INDArray features, INDArray labels, INDArray predicted) {
//        int nRows = features.rows();
//
//        int nClasses = 2; // Binary classification using one output call end sigmoid.
//
//        XYSeries[] series = new XYSeries[nClasses * nClasses];
//        int[] series_index = new int[]{0, 3, 2, 1}; //little hack to make the charts look consistent.
//        for (int i = 0; i < nClasses * nClasses; i++) {
//            int trueClass = i / nClasses;
//            int predClass = i % nClasses;
//            String label = "actual=" + trueClass + ", pred=" + predClass;
//            series[series_index[i]] = new XYSeries(label);
//        }
//        INDArray actualIdx = labels.argMax(1);
//        INDArray predictedIdx = predicted.argMax(1);
//        for (int i = 0; i < nRows; i++) {
//            int classIdx = actualIdx.getInt(i);
//            int predIdx = predictedIdx.getInt(i);
//            int idx = series_index[classIdx * nClasses + predIdx];
//            series[idx].add(features.getDouble(i, 0), features.getDouble(i, 1));
//        }
//
//        XYSeriesCollection c = new XYSeriesCollection();
//        for (XYSeries s : series) c.addSeries(s);
//        return c;
//    }
//
//    private static JFreeChart createChart(XYZDataset dataset, double[] mins, double[] maxs, int nPoints, XYDataset xyData) {
//        NumberAxis xAxis = new NumberAxis("X");
//        xAxis.setRange(mins[0], maxs[0]);
//
//        NumberAxis yAxis = new NumberAxis("Y");
//        yAxis.setRange(mins[1], maxs[1]);
//
//        XYBlockRenderer renderer = new XYBlockRenderer();
//        renderer.setBlockWidth((maxs[0] - mins[0]) / (nPoints - 1));
//        renderer.setBlockHeight((maxs[1] - mins[1]) / (nPoints - 1));
//        PaintScale scale = new GrayPaintScale(0, 1.0);
//        renderer.setPaintScale(scale);
//        XYPlot plot = new XYPlot(dataset, xAxis, yAxis, renderer);
//        plot.setBackgroundPaint(Color.lightGray);
//        plot.setDomainGridlinesVisible(false);
//        plot.setRangeGridlinesVisible(false);
//        plot.setAxisOffset(new RectangleInsets(5, 5, 5, 5));
//        JFreeChart chart = new JFreeChart("", plot);
//        chart.getXYPlot().getRenderer().setSeriesVisibleInLegend(0, false);
//
//        NumberAxis scaleAxis = new NumberAxis("Probability (class 1)");
//        scaleAxis.setAxisLinePaint(Color.white);
//        scaleAxis.setTickMarkPaint(Color.white);
//        scaleAxis.setTickLabelFont(new Font("Dialog", Font.PLAIN, 7));
//        PaintScaleLegend legend = new PaintScaleLegend(new GrayPaintScale(),
//                scaleAxis);
//        legend.setStripOutlineVisible(false);
//        legend.setSubdivisionCount(20);
//        legend.setAxisLocation(AxisLocation.BOTTOM_OR_LEFT);
//        legend.setAxisOffset(5.0);
//        legend.setMargin(new RectangleInsets(5, 5, 5, 5));
//        legend.setFrame(new BlockBorder(Color.red));
//        legend.setPadding(new RectangleInsets(10, 10, 10, 10));
//        legend.setStripWidth(10);
//        legend.setPosition(RectangleEdge.LEFT);
//        chart.addSubtitle(legend);
//
//        ChartUtilities.applyCurrentTheme(chart);
//
//        plot.setDataset(1, xyData);
//        XYLineAndShapeRenderer renderer2 = new XYLineAndShapeRenderer();
//        renderer2.setBaseLinesVisible(false);
//        plot.setRenderer(1, renderer2);
//
//        plot.setDatasetRenderingOrder(DatasetRenderingOrder.FORWARD);
//
//        return chart;
//    }
//
//    public static INDArray generatePointsOnGraph(double xMin, double xMax, double yMin, double yMax, int nPointsPerAxis) {
//        //generate all the x,y points
//        double[][] evalPoints = new double[nPointsPerAxis * nPointsPerAxis][2];
//        int count = 0;
//        for (int i = 0; i < nPointsPerAxis; i++) {
//            for (int j = 0; j < nPointsPerAxis; j++) {
//                double x = i * (xMax - xMin) / (nPointsPerAxis - 1) + xMin;
//                double y = j * (yMax - yMin) / (nPointsPerAxis - 1) + yMin;
//
//                evalPoints[count][0] = x;
//                evalPoints[count][1] = y;
//
//                count++;
//            }
//        }
//
//        return Nd4j.create(evalPoints);
//    }
//
//    /**
//     * This is to collect all the data and return it as one minibatch. Obviously only for use here with small datasets
//     * @param iter
//     * @return
//     */
//    private static DataSet allBatches(DataSetIterator iter) {
//
//        List<DataSet> fullSet = new ArrayList<>();
//        iter.reset();
//        while (iter.hasNext()) {
//            List<DataSet> miniBatchList = iter.next().asList();
//            fullSet.addAll(miniBatchList);
//        }
//        iter.reset();
//        return new ListDataSetIterator<>(fullSet, fullSet.size()).next();
//    }
//
//}
//
@@ -23,6 +23,7 @@ import java.io.BufferedReader;
 import java.io.IOException;
 import java.io.InputStreamReader;
 import java.io.Serializable;
+import java.math.BigDecimal;
 import java.text.Normalizer;
 import java.util.List;
 import java.util.regex.Matcher;
@@ -71,6 +72,30 @@ public class Utilities implements Serializable {
         }
     }
 
+    public static double[] getJPathArray(final String jsonPath, final String inputJson) {
+        try {
+            Object o = JsonPath.read(inputJson, jsonPath);
+            if (o instanceof double[])
+                return (double[]) o;
+            if (o instanceof JSONArray) {
+                Object[] objects = ((JSONArray) o).toArray();
+                double[] array = new double[objects.length];
+                for (int i = 0; i < objects.length; i++) {
+                    if (objects[i] instanceof BigDecimal)
+                        array[i] = ((BigDecimal) objects[i]).doubleValue();
+                    else
+                        array[i] = (double) objects[i];
+                }
+                return array;
+            }
+            return new double[0];
+        }
+        catch (Exception e) {
+            e.printStackTrace();
+            return new double[0];
+        }
+    }
+
 // public static String normalize(final String s) {
 //     return Normalizer.normalize(s, Normalizer.Form.NFD)
 //         .replaceAll("[^\\w\\s-]", "") // Remove all non-word, non-space or non-dash characters
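A hedged usage sketch (the record and field names are hypothetical): getJPathArray resolves a JsonPath expression and coerces the resulting JSON array into a double[], falling back to an empty array on any type mismatch or parse error.

// "$.topics" mirrors the featuresJPath the GNN jobs above pass in
String json = "{\"id\": \"author::1\", \"topics\": [0.1, 0.7, 0.2]}";
double[] topics = Utilities.getJPathArray("$.topics", json);   // -> [0.1, 0.7, 0.2]
double[] missing = Utilities.getJPathArray("$.absent", json);  // -> [] (empty array on failure)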
@@ -27,6 +27,6 @@ public class UtilityTest {
         Author a = new Author("De Bonis, Michele", "Æ", "De Bonis", new ArrayList<CoAuthor>(), new double[]{0.0, 1.0}, "author::id", "pub::id", "orcid");
         System.out.println("a = " + a.isAccurate());
         System.out.println(AuthorsFactory.getLNFI(a));
 
     }
 }
DataSetProcessorTest.java (new file)
@@ -0,0 +1,47 @@
package eu.dnetlib.deeplearning;

import com.beust.jcommander.internal.Sets;
import com.google.common.collect.Lists;
import eu.dnetlib.deeplearning.support.DataSetProcessor;
import eu.dnetlib.support.Relation;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.dataset.MultiDataSet;

import java.util.*;
import java.util.stream.Collectors;

public class DataSetProcessorTest {

    static Map<String, double[]> features;
    static Set<Relation> relations;
    static List<String> groundTruth;

    @BeforeAll
    public static void init() {
        //initialize example features
        features = new HashMap<>();
        features.put("0", new double[]{0.0, 0.0});
        features.put("1", new double[]{1.0, 1.0});
        features.put("2", new double[]{2.0, 2.0});

        //initialize example relations
        relations = new HashSet<>(Lists.newArrayList(
                new Relation("0", "1", "simrel"),
                new Relation("1", "2", "simrel")
        ));

        //initialize example ground truth
        groundTruth = Lists.newArrayList("class1", "class1", "class2");
    }

    @Test
    public void getMultiDataSetTest() throws Exception {
        MultiDataSet multiDataSet = DataSetProcessor.getMultiDataSet(features, relations, groundTruth);
        System.out.println("multiDataSet = " + multiDataSet);

        multiDataSet.asList();
    }

}
NetworkConfigurationTests.java (new file)
@@ -0,0 +1,33 @@
package eu.dnetlib.deeplearning;

import eu.dnetlib.deeplearning.support.NetworkConfigurations;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class NetworkConfigurationTests {

    public final static int N = 3; //number of nodes
    public final static int K = 7; //number of features

    public static INDArray[] exampleGraph = new INDArray[]{
            Nd4j.zeros(N, K), //features
            Nd4j.ones(N, N),  //adjacency
            Nd4j.ones(N, N)   //degree
    };

    @Test
    public void simpleGCNTest() {

        ComputationGraphConfiguration simpleGCNConf = NetworkConfigurations.getSimpleGCN(3, K, 5, 2);
        ComputationGraph simpleGCN = new ComputationGraph(simpleGCNConf);
        simpleGCN.init();

        INDArray[] output = simpleGCN.output(exampleGraph);
        System.out.println("output = " + output[0]);
    }

}
pom.xml
@@ -247,7 +247,7 @@
         <google.guava.version>15.0</google.guava.version>
 
         <spark.version>2.2.0</spark.version>
-        <sparknlp.version>2.5.5</sparknlp.version>
+        <scala.binary.version>2.11</scala.binary.version>
         <jackson.version>2.6.5</jackson.version>
         <mockito-core.version>3.3.3</mockito-core.version>
@@ -282,6 +282,10 @@
         <junit-jupiter.version>5.6.1</junit-jupiter.version>
         <maven.dependency.eu.dnetlib.dhp.dhp-build-assembly-resources.jar.path>../dhp-build/dhp-build-assembly-resources/target/dhp-build-assembly-resources-${project.version}.jar</maven.dependency.eu.dnetlib.dhp.dhp-build-assembly-resources.jar.path>
+
+        <!--deeplearning4j-->
+        <dl4j-master.version>1.0.0-beta7</dl4j-master.version>
+        <nd4j.backend>nd4j-native</nd4j.backend>
     </properties>
 
     <dependencyManagement>
@@ -446,6 +450,90 @@
                 <version>2.10.29</version>
             </dependency>
+
+            <!--DEEPLEARNING4J-->
+            <dependency>
+                <groupId>org.nd4j</groupId>
+                <artifactId>${nd4j.backend}</artifactId>
+                <version>${dl4j-master.version}</version>
+            </dependency>
+
+            <dependency>
+                <groupId>org.datavec</groupId>
+                <artifactId>datavec-api</artifactId>
+                <version>${dl4j-master.version}</version>
+            </dependency>
+            <dependency>
+                <groupId>org.datavec</groupId>
+                <artifactId>datavec-data-image</artifactId>
+                <version>${dl4j-master.version}</version>
+            </dependency>
+            <dependency>
+                <groupId>org.datavec</groupId>
+                <artifactId>datavec-local</artifactId>
+                <version>${dl4j-master.version}</version>
+            </dependency>
+            <dependency>
+                <groupId>org.deeplearning4j</groupId>
+                <artifactId>deeplearning4j-datasets</artifactId>
+                <version>${dl4j-master.version}</version>
+            </dependency>
+            <dependency>
+                <groupId>org.deeplearning4j</groupId>
+                <artifactId>deeplearning4j-core</artifactId>
+                <version>${dl4j-master.version}</version>
+            </dependency>
+
+            <dependency>
+                <groupId>org.deeplearning4j</groupId>
+                <artifactId>resources</artifactId>
+                <version>${dl4j-master.version}</version>
+            </dependency>
+
+            <dependency>
+                <groupId>org.deeplearning4j</groupId>
+                <artifactId>deeplearning4j-ui</artifactId>
+                <version>${dl4j-master.version}</version>
+            </dependency>
+            <dependency>
+                <groupId>org.deeplearning4j</groupId>
+                <artifactId>deeplearning4j-zoo</artifactId>
+                <version>${dl4j-master.version}</version>
+            </dependency>
+            <dependency>
+                <groupId>org.deeplearning4j</groupId>
+                <artifactId>dl4j-spark-parameterserver_2.11</artifactId>
+                <version>${dl4j-master.version}</version>
+            </dependency>
+            <dependency>
+                <groupId>org.deeplearning4j</groupId>
+                <artifactId>dl4j-spark_2.11</artifactId>
+                <version>${dl4j-master.version}</version>
+            </dependency>
+
+            <!--PLOT-->
+            <dependency>
+                <groupId>jfree</groupId>
+                <artifactId>jfreechart</artifactId>
+                <version>1.0.13</version>
+            </dependency>
+            <dependency>
+                <groupId>org.jfree</groupId>
+                <artifactId>jcommon</artifactId>
+                <version>1.0.23</version>
+            </dependency>
+            <dependency>
+                <groupId>org.deeplearning4j</groupId>
+                <artifactId>deeplearning4j-datasets</artifactId>
+                <version>${dl4j-master.version}</version>
+            </dependency>
+
+            <!--DNET DEDUP-->
+            <dependency>
+                <groupId>eu.dnetlib</groupId>
+                <artifactId>dnet-dedup-test</artifactId>
+                <version>4.1.13-SNAPSHOT</version>
+            </dependency>
+
         </dependencies>