diff --git a/dhp-build/dhp-build-properties-maven-plugin/test.properties b/dhp-build/dhp-build-properties-maven-plugin/test.properties
index 0573300..a584f65 100644
--- a/dhp-build/dhp-build-properties-maven-plugin/test.properties
+++ b/dhp-build/dhp-build-properties-maven-plugin/test.properties
@@ -1,2 +1,2 @@
-# Thu Apr 13 16:22:22 CEST 2023
+# Thu Apr 27 21:12:07 CEST 2023
projectPropertyKey=projectPropertyValue
diff --git a/dnet-and-test/job-override.properties b/dnet-and-test/job-override.properties
index 99d28c3..09cb572 100644
--- a/dnet-and-test/job-override.properties
+++ b/dnet-and-test/job-override.properties
@@ -11,11 +11,31 @@
#outputModelPath = /user/michele.debonis/lda_experiments/lda_dewey2.model
#LDA INFERENCE
+#numPartitions = 1000
+#inputFieldJPath = $.description[0].value
+#vocabularyPath = /user/michele.debonis/lda_experiments/dewey_vocabulary
+#entitiesPath = /tmp/publications_with_pid_pubmed
+#workingPath = /user/michele.debonis/lda_experiments/lda_inference_working_dir
+#ldaInferencePath = /user/michele.debonis/lda_experiments/publications_pubmed_topics
+#ldaModelPath = /user/michele.debonis/lda_experiments/lda_dewey.model
+#authorsPath = /user/michele.debonis/lda_experiments/authors_pubmed
+
+#GNN TRAINING
+#groupsPath = /user/michele.debonis/authors_dedup/gt_dedup/groupentities
+#workingPath = /user/michele.debonis/gnn_experiments
+#numPartitions = 1000
+#numEpochs = 100
+#groundTruthJPath = $.orcid
+#idJPath = $.id
+#featuresJPath = $.topics
+
+#FEATURE EXTRACTION
+publicationsPath = /tmp/publications_with_pid_pubmed
+workingPath = /user/michele.debonis/feature_extraction
numPartitions = 1000
-inputFieldJPath = $.description[0].value
-vocabularyPath = /user/michele.debonis/lda_experiments/dewey_vocabulary
-entitiesPath = /tmp/publications_with_pid_pubmed
-workingPath = /user/michele.debonis/lda_experiments/lda_inference_working_dir
-ldaInferencePath = /user/michele.debonis/lda_experiments/publications_pubmed_topics
-ldaModelPath = /user/michele.debonis/lda_experiments/lda_dewey.model
-authorsPath = /user/michele.debonis/lda_experiments/authors_pubmed
\ No newline at end of file
+featuresPath = /user/michele.debonis/feature_extraction/publications_pubmed_features
+topicsPath = /user/michele.debonis/lda_experiments/publications_pubmed_topics
+outputPath = /user/michele.debonis/feature_extraction/authors_pubmed
+wordEmbeddingsModel = /user/michele.debonis/nlp_models/glove_100d_en_2.4.0_2.4_1579690104032
+bertSentenceModel = /user/michele.debonis/nlp_models/sent_small_bert_L6_512_en_2.6.0_2.4_1598350624049
+bertModel = /user/michele.debonis/nlp_models/small_bert_L2_128_en_2.6.0_2.4_1598344320681
\ No newline at end of file
diff --git a/dnet-and-test/pom.xml b/dnet-and-test/pom.xml
index ddc3e22..7104cba 100644
--- a/dnet-and-test/pom.xml
+++ b/dnet-and-test/pom.xml
@@ -131,23 +131,11 @@
json-path
-
eu.dnetlib.dhp
dhp-schemas
-
-
-
-
-
-
-
-
-
-
-
diff --git a/dnet-and-test/src/main/java/eu/dnetlib/jobs/AbstractSparkJob.java b/dnet-and-test/src/main/java/eu/dnetlib/jobs/AbstractSparkJob.java
index ff21a8e..a1b9737 100644
--- a/dnet-and-test/src/main/java/eu/dnetlib/jobs/AbstractSparkJob.java
+++ b/dnet-and-test/src/main/java/eu/dnetlib/jobs/AbstractSparkJob.java
@@ -36,7 +36,7 @@ public abstract class AbstractSparkJob implements Serializable {
this.spark = spark;
}
- protected abstract void run() throws IOException;
+ protected abstract void run() throws IOException, InterruptedException;
protected static SparkSession getSparkSession(SparkConf conf) {
return SparkSession.builder().config(conf).getOrCreate();
diff --git a/dnet-and-test/src/main/java/eu/dnetlib/jobs/SparkCountVectorizer.java b/dnet-and-test/src/main/java/eu/dnetlib/jobs/SparkCountVectorizer.java
index ef1912c..d14b76a 100644
--- a/dnet-and-test/src/main/java/eu/dnetlib/jobs/SparkCountVectorizer.java
+++ b/dnet-and-test/src/main/java/eu/dnetlib/jobs/SparkCountVectorizer.java
@@ -2,7 +2,6 @@ package eu.dnetlib.jobs;
import eu.dnetlib.featureextraction.FeatureTransformer;
import eu.dnetlib.support.ArgumentApplicationParser;
-import org.apache.hadoop.fs.shell.Count;
import org.apache.spark.SparkConf;
import org.apache.spark.ml.feature.CountVectorizerModel;
import org.apache.spark.sql.Dataset;
diff --git a/dnet-and-test/src/main/java/eu/dnetlib/jobs/SparkCreateVocabulary.java b/dnet-and-test/src/main/java/eu/dnetlib/jobs/SparkCreateVocabulary.java
index 4b5cdf8..45b3011 100644
--- a/dnet-and-test/src/main/java/eu/dnetlib/jobs/SparkCreateVocabulary.java
+++ b/dnet-and-test/src/main/java/eu/dnetlib/jobs/SparkCreateVocabulary.java
@@ -9,8 +9,6 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
-import java.net.URISyntaxException;
-import java.nio.file.Paths;
import java.util.Optional;
public class SparkCreateVocabulary extends AbstractSparkJob{
diff --git a/dnet-and-test/src/main/java/eu/dnetlib/jobs/deeplearning/SparkCreateGroupDataSet.java b/dnet-and-test/src/main/java/eu/dnetlib/jobs/deeplearning/SparkCreateGroupDataSet.java
deleted file mode 100644
index e02f5ae..0000000
--- a/dnet-and-test/src/main/java/eu/dnetlib/jobs/deeplearning/SparkCreateGroupDataSet.java
+++ /dev/null
@@ -1,75 +0,0 @@
-package eu.dnetlib.jobs.deeplearning;
-
-import eu.dnetlib.deeplearning.support.DataSetProcessor;
-import eu.dnetlib.jobs.AbstractSparkJob;
-import eu.dnetlib.jobs.SparkLDATuning;
-import eu.dnetlib.support.ArgumentApplicationParser;
-import eu.dnetlib.support.ConnectedComponent;
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.SparkSession;
-import org.codehaus.jackson.map.ObjectMapper;
-import org.deeplearning4j.spark.data.BatchAndExportDataSetsFunction;
-import org.deeplearning4j.spark.data.BatchAndExportMultiDataSetsFunction;
-import org.deeplearning4j.spark.datavec.iterator.IteratorUtils;
-import org.nd4j.linalg.dataset.MultiDataSet;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.io.IOException;
-import java.util.List;
-import java.util.Optional;
-
-public class SparkCreateGroupDataSet extends AbstractSparkJob {
-
- private static final Logger log = LoggerFactory.getLogger(SparkCreateGroupDataSet.class);
-
- public SparkCreateGroupDataSet(ArgumentApplicationParser parser, SparkSession spark) {
- super(parser, spark);
- }
- public static void main(String[] args) throws Exception {
- ArgumentApplicationParser parser = new ArgumentApplicationParser(
- readResource("/jobs/parameters/createGroupDataset_parameters.json", SparkLDATuning.class)
- );
-
- parser.parseArgument(args);
-
- SparkConf conf = new SparkConf();
-
- new SparkCreateGroupDataSet(
- parser,
- getSparkSession(conf)
- ).run();
- }
-
- @Override
- public void run() throws IOException {
- // read oozie parameters
- final String groupsPath = parser.get("groupsPath");
- final String workingPath = parser.get("workingPath");
- final String groundTruthJPath = parser.get("groundTruthJPath");
- final String idJPath = parser.get("idJPath");
- final String featuresJPath = parser.get("featuresJPath");
- final int numPartitions = Optional
- .ofNullable(parser.get("numPartitions"))
- .map(Integer::valueOf)
- .orElse(NUM_PARTITIONS);
-
- log.info("groupsPath: '{}'", groupsPath);
- log.info("workingPath: '{}'", workingPath);
- log.info("groundTruthJPath: '{}'", groundTruthJPath);
- log.info("idJPath: '{}'", idJPath);
- log.info("featuresJPath: '{}'", featuresJPath);
- log.info("numPartitions: '{}'", numPartitions);
-
- JavaSparkContext context = JavaSparkContext.fromSparkContext(spark.sparkContext());
-
- JavaRDD groups = context.textFile(groupsPath).map(g -> new ObjectMapper().readValue(g, ConnectedComponent.class));
-
- JavaRDD dataset = DataSetProcessor.entityGroupToMultiDataset(groups, idJPath, featuresJPath, groundTruthJPath);
-
- dataset.saveAsObjectFile(workingPath + "/groupDataset");
- }
-
-}
diff --git a/dnet-and-test/src/main/java/eu/dnetlib/jobs/deeplearning/SparkGraphClassificationTraining.java b/dnet-and-test/src/main/java/eu/dnetlib/jobs/deeplearning/SparkGraphClassificationTraining.java
deleted file mode 100644
index e0772ae..0000000
--- a/dnet-and-test/src/main/java/eu/dnetlib/jobs/deeplearning/SparkGraphClassificationTraining.java
+++ /dev/null
@@ -1,99 +0,0 @@
-package eu.dnetlib.jobs.deeplearning;
-
-import eu.dnetlib.deeplearning.support.DataSetProcessor;
-import eu.dnetlib.deeplearning.support.NetworkConfigurations;
-import eu.dnetlib.jobs.AbstractSparkJob;
-import eu.dnetlib.jobs.SparkLDATuning;
-import eu.dnetlib.support.ArgumentApplicationParser;
-import eu.dnetlib.support.ConnectedComponent;
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.SparkSession;
-import org.codehaus.jackson.map.ObjectMapper;
-import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
-import org.deeplearning4j.nn.graph.ComputationGraph;
-import org.deeplearning4j.optimize.listeners.PerformanceListener;
-import org.deeplearning4j.optimize.solvers.accumulation.encoding.threshold.AdaptiveThresholdAlgorithm;
-import org.deeplearning4j.spark.api.RDDTrainingApproach;
-import org.deeplearning4j.spark.api.TrainingMaster;
-import org.deeplearning4j.spark.impl.graph.SparkComputationGraph;
-import org.deeplearning4j.spark.parameterserver.training.SharedTrainingMaster;
-import org.nd4j.linalg.dataset.api.MultiDataSet;
-import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.io.IOException;
-import java.util.Optional;
-
-public class SparkGraphClassificationTraining extends AbstractSparkJob {
-
- private static final Logger log = LoggerFactory.getLogger(SparkGraphClassificationTraining.class);
-
- public SparkGraphClassificationTraining(ArgumentApplicationParser parser, SparkSession spark) {
- super(parser, spark);
- }
- public static void main(String[] args) throws Exception {
- ArgumentApplicationParser parser = new ArgumentApplicationParser(
- readResource("/jobs/parameters/graphClassificationTraining_parameters.json", SparkLDATuning.class)
- );
-
- parser.parseArgument(args);
-
- SparkConf conf = new SparkConf();
-
- new SparkGraphClassificationTraining(
- parser,
- getSparkSession(conf)
- ).run();
- }
-
- @Override
- public void run() throws IOException {
- // read oozie parameters
- final String workingPath = parser.get("workingPath");
- final int numPartitions = Optional
- .ofNullable(parser.get("numPartitions"))
- .map(Integer::valueOf)
- .orElse(NUM_PARTITIONS);
- log.info("workingPath: '{}'", workingPath);
- log.info("numPartitions: '{}'", numPartitions);
-
- JavaSparkContext context = JavaSparkContext.fromSparkContext(spark.sparkContext());
-
- VoidConfiguration conf = VoidConfiguration.builder()
- .unicastPort(40123)
-// .networkMask("255.255.148.0/22")
- .controllerAddress("127.0.0.1")
- .build();
-
- TrainingMaster trainingMaster = new SharedTrainingMaster.Builder(conf,1)
- .rngSeed(12345)
- .collectTrainingStats(false)
- .thresholdAlgorithm(new AdaptiveThresholdAlgorithm(1e-3))
- .batchSizePerWorker(32)
- .workersPerNode(4)
- .rddTrainingApproach(RDDTrainingApproach.Direct)
- .build();
-
- JavaRDD trainData = context.objectFile(workingPath + "/groupDataset");
-
- SparkComputationGraph sparkComputationGraph = new SparkComputationGraph(
- context,
- NetworkConfigurations.getSimpleGCN(3, 2, 5, 2),
- trainingMaster);
- sparkComputationGraph.setListeners(new PerformanceListener(10, true));
-
- //execute training
- for (int i = 0; i < 20; i ++) {
- sparkComputationGraph.fitMultiDataSet(trainData);
- }
-
- ComputationGraph network = sparkComputationGraph.getNetwork();
-
- System.out.println("network = " + network.getConfiguration().toJson());
-
-
- }
-}
diff --git a/dnet-and-test/src/main/java/eu/dnetlib/jobs/SparkAuthorExtractor.java b/dnet-and-test/src/main/java/eu/dnetlib/jobs/featureextraction/SparkAuthorExtractor.java
similarity index 50%
rename from dnet-and-test/src/main/java/eu/dnetlib/jobs/SparkAuthorExtractor.java
rename to dnet-and-test/src/main/java/eu/dnetlib/jobs/featureextraction/SparkAuthorExtractor.java
index 5de5fc0..e014cea 100644
--- a/dnet-and-test/src/main/java/eu/dnetlib/jobs/SparkAuthorExtractor.java
+++ b/dnet-and-test/src/main/java/eu/dnetlib/jobs/featureextraction/SparkAuthorExtractor.java
@@ -1,5 +1,8 @@
-package eu.dnetlib.jobs;
+package eu.dnetlib.jobs.featureextraction;
+import eu.dnetlib.dhp.schema.oaf.Publication;
+import eu.dnetlib.jobs.AbstractSparkJob;
+import eu.dnetlib.jobs.SparkTokenizer;
import eu.dnetlib.support.ArgumentApplicationParser;
import eu.dnetlib.support.Author;
import eu.dnetlib.support.AuthorsFactory;
@@ -10,15 +13,18 @@ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.sql.SparkSession;
+import org.codehaus.jackson.map.DeserializationConfig;
import org.codehaus.jackson.map.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;
import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
import java.util.Optional;
-public class SparkAuthorExtractor extends AbstractSparkJob{
+public class SparkAuthorExtractor extends AbstractSparkJob {
private static final Logger log = LoggerFactory.getLogger(SparkAuthorExtractor.class);
public SparkAuthorExtractor(ArgumentApplicationParser parser, SparkSession spark) {
@@ -45,7 +51,8 @@ public class SparkAuthorExtractor extends AbstractSparkJob{
public void run() throws IOException {
// read oozie parameters
final String topicsPath = parser.get("topicsPath");
- final String entitiesPath = parser.get("entitiesPath");
+ final String featuresPath = parser.get("featuresPath");
+ final String publicationsPath = parser.get("publicationsPath");
final String workingPath = parser.get("workingPath");
final String outputPath = parser.get("outputPath");
final int numPartitions = Optional
@@ -53,25 +60,44 @@ public class SparkAuthorExtractor extends AbstractSparkJob{
.map(Integer::valueOf)
.orElse(NUM_PARTITIONS);
- log.info("entitiesPath: '{}'", entitiesPath);
- log.info("topicsPath: '{}'", topicsPath);
- log.info("workingPath: '{}'", workingPath);
- log.info("outputPath: '{}'", outputPath);
- log.info("numPartitions: '{}'", numPartitions);
+ log.info("publicationsPath: '{}'", publicationsPath);
+ log.info("topicsPath: '{}'", topicsPath);
+ log.info("featuresPath: '{}'", featuresPath);
+ log.info("workingPath: '{}'", workingPath);
+ log.info("outputPath: '{}'", outputPath);
+ log.info("numPartitions: '{}'", numPartitions);
//join publications with topics
JavaSparkContext context = JavaSparkContext.fromSparkContext(spark.sparkContext());
- JavaRDD entities = context.textFile(entitiesPath);
+ JavaRDD publications = context
+ .textFile(publicationsPath)
+ .map(x -> new ObjectMapper()
+ .configure(DeserializationConfig.Feature.FAIL_ON_UNKNOWN_PROPERTIES, false)
+ .readValue(x, Publication.class));
- JavaPairRDD topics = spark.read().load(topicsPath).toJavaRDD()
- .mapToPair(t -> new Tuple2<>(t.getString(0), (DenseVector) t.get(1)));
+ JavaPairRDD topics = spark.read().load(topicsPath).toJavaRDD()
+ .mapToPair(t -> new Tuple2<>(t.getString(0), ((DenseVector) t.get(1)).toArray()));
- JavaRDD authors = AuthorsFactory.extractAuthorsFromPublications(entities, topics);
+ //merge topics with other embeddings
+ JavaPairRDD> publicationEmbeddings = spark.read().load(featuresPath).toJavaRDD().mapToPair(t -> {
+ Map embeddings = new HashMap<>();
+ embeddings.put("word_embeddings", ((DenseVector) t.get(1)).toArray());
+ embeddings.put("bert_embeddings", ((DenseVector) t.get(2)).toArray());
+ embeddings.put("bert_sentence_embeddings", ((DenseVector) t.get(3)).toArray());
+ return new Tuple2<>(t.getString(0), embeddings);
+ })
+ .join(topics).mapToPair(e -> {
+ e._2()._1().put("lda_topics", e._2()._2());
+ return new Tuple2<>(e._1(), e._2()._1());
+ });
+
+ JavaRDD authors = AuthorsFactory.extractAuthorsFromPublications(publications, publicationEmbeddings);
authors
.map(a -> new ObjectMapper().writeValueAsString(a))
.saveAsTextFile(outputPath, GzipCodec.class);
}
+
}
diff --git a/dnet-and-test/src/main/java/eu/dnetlib/jobs/featureextraction/SparkPublicationFeatureExtractor.java b/dnet-and-test/src/main/java/eu/dnetlib/jobs/featureextraction/SparkPublicationFeatureExtractor.java
new file mode 100644
index 0000000..ab22c58
--- /dev/null
+++ b/dnet-and-test/src/main/java/eu/dnetlib/jobs/featureextraction/SparkPublicationFeatureExtractor.java
@@ -0,0 +1,107 @@
+package eu.dnetlib.jobs.featureextraction;
+
+import eu.dnetlib.dhp.schema.oaf.Publication;
+import eu.dnetlib.dhp.schema.oaf.StructuredProperty;
+import eu.dnetlib.featureextraction.ScalaFeatureTransformer;
+import eu.dnetlib.jobs.AbstractSparkJob;
+import eu.dnetlib.support.ArgumentApplicationParser;
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.*;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.codehaus.jackson.map.DeserializationConfig;
+import org.codehaus.jackson.map.ObjectMapper;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+public class SparkPublicationFeatureExtractor extends AbstractSparkJob {
+ private static final Logger log = LoggerFactory.getLogger(SparkPublicationFeatureExtractor.class);
+
+ public SparkPublicationFeatureExtractor(ArgumentApplicationParser parser, SparkSession spark) {
+ super(parser, spark);
+ }
+
+ public static void main(String[] args) throws Exception {
+
+ ArgumentApplicationParser parser = new ArgumentApplicationParser(
+ readResource("/jobs/parameters/publicationFeatureExtractor_parameters.json", SparkPublicationFeatureExtractor.class)
+ );
+
+ parser.parseArgument(args);
+
+ SparkConf conf = new SparkConf();
+
+ new SparkAuthorExtractor(
+ parser,
+ getSparkSession(conf)
+ ).run();
+ }
+
+ @Override
+ public void run() throws IOException {
+ // read oozie parameters
+ final String publicationsPath = parser.get("publicationsPath");
+ final String workingPath = parser.get("workingPath");
+ final String featuresPath = parser.get("featuresPath");
+ final String bertModel = parser.get("bertModelPath");
+ final String bertSentenceModel = parser.get("bertSentenceModel");
+ final String wordEmbeddingModel = parser.get("wordEmbeddingModel");
+ final int numPartitions = Optional
+ .ofNullable(parser.get("numPartitions"))
+ .map(Integer::valueOf)
+ .orElse(NUM_PARTITIONS);
+
+ log.info("publicationsPath: '{}'", publicationsPath);
+ log.info("workingPath: '{}'", workingPath);
+ log.info("numPartitions: '{}'", numPartitions);
+
+ JavaSparkContext context = JavaSparkContext.fromSparkContext(spark.sparkContext());
+
+ JavaRDD publications = context
+ .textFile(publicationsPath)
+ .map(x -> new ObjectMapper()
+ .configure(DeserializationConfig.Feature.FAIL_ON_UNKNOWN_PROPERTIES, false)
+ .readValue(x, Publication.class));
+
+ StructType inputSchema = new StructType(new StructField[]{
+ new StructField("id", DataTypes.StringType, false, Metadata.empty()),
+ new StructField("title", DataTypes.StringType, false, Metadata.empty()),
+ new StructField("abstract", DataTypes.StringType, false, Metadata.empty()),
+ new StructField("subjects", DataTypes.StringType, false, Metadata.empty())
+ });
+
+ //prepare Rows
+ Dataset inputData = spark.createDataFrame(
+ publications.map(p -> RowFactory.create(
+ p.getId(),
+ p.getTitle().get(0).getValue(),
+ p.getDescription().size()>0? p.getDescription().get(0).getValue(): "",
+ p.getSubject().stream().map(StructuredProperty::getValue).collect(Collectors.joining(" ")))),
+ inputSchema);
+
+ log.info("Generating word embeddings");
+ Dataset wordEmbeddingsData = ScalaFeatureTransformer.wordEmbeddings(inputData, "subjects", wordEmbeddingModel);
+
+ log.info("Generating bert embeddings");
+ Dataset bertEmbeddingsData = ScalaFeatureTransformer.bertEmbeddings(wordEmbeddingsData, "title", bertModel);
+
+ log.info("Generating bert sentence embeddings");
+ Dataset bertSentenceEmbeddingsData = ScalaFeatureTransformer.bertSentenceEmbeddings(bertEmbeddingsData, "abstract", bertSentenceModel);
+
+ Dataset features = bertSentenceEmbeddingsData.select("id", ScalaFeatureTransformer.WORD_EMBEDDINGS_COL(), ScalaFeatureTransformer.BERT_EMBEDDINGS_COL(), ScalaFeatureTransformer.BERT_SENTENCE_EMBEDDINGS_COL());
+
+ features
+ .write()
+ .mode(SaveMode.Overwrite)
+ .save(featuresPath);
+
+ }
+}
diff --git a/dnet-and-test/src/main/java/eu/dnetlib/jobs/SparkLDAAnalysis.java b/dnet-and-test/src/main/java/eu/dnetlib/jobs/featureextraction/lda/SparkLDAAnalysis.java
similarity index 96%
rename from dnet-and-test/src/main/java/eu/dnetlib/jobs/SparkLDAAnalysis.java
rename to dnet-and-test/src/main/java/eu/dnetlib/jobs/featureextraction/lda/SparkLDAAnalysis.java
index ececa9b..c0ee71f 100644
--- a/dnet-and-test/src/main/java/eu/dnetlib/jobs/SparkLDAAnalysis.java
+++ b/dnet-and-test/src/main/java/eu/dnetlib/jobs/featureextraction/lda/SparkLDAAnalysis.java
@@ -1,7 +1,8 @@
-package eu.dnetlib.jobs;
+package eu.dnetlib.jobs.featureextraction.lda;
import com.clearspring.analytics.util.Lists;
import eu.dnetlib.featureextraction.Utilities;
+import eu.dnetlib.jobs.AbstractSparkJob;
import eu.dnetlib.support.ArgumentApplicationParser;
import eu.dnetlib.support.Author;
import eu.dnetlib.support.AuthorsFactory;
@@ -108,7 +109,7 @@ public class SparkLDAAnalysis extends AbstractSparkJob {
else {
bRes = authors.get(i).getOrcid().equals(authors.get(j).getOrcid());
}
- results.add(new Tuple2<>(bRes, cosineSimilarity(authors.get(i).getTopics(), authors.get(j).getTopics())));
+ results.add(new Tuple2<>(bRes, cosineSimilarity(authors.get(i).getEmbeddings().get("lda_topics"), authors.get(j).getEmbeddings().get("lda_topics"))));
j++;
}
i++;
diff --git a/dnet-and-test/src/main/java/eu/dnetlib/jobs/SparkLDAInference.java b/dnet-and-test/src/main/java/eu/dnetlib/jobs/featureextraction/lda/SparkLDAInference.java
similarity index 93%
rename from dnet-and-test/src/main/java/eu/dnetlib/jobs/SparkLDAInference.java
rename to dnet-and-test/src/main/java/eu/dnetlib/jobs/featureextraction/lda/SparkLDAInference.java
index 83814a1..797c386 100644
--- a/dnet-and-test/src/main/java/eu/dnetlib/jobs/SparkLDAInference.java
+++ b/dnet-and-test/src/main/java/eu/dnetlib/jobs/featureextraction/lda/SparkLDAInference.java
@@ -1,6 +1,7 @@
-package eu.dnetlib.jobs;
+package eu.dnetlib.jobs.featureextraction.lda;
import eu.dnetlib.featureextraction.FeatureTransformer;
+import eu.dnetlib.jobs.AbstractSparkJob;
import eu.dnetlib.support.ArgumentApplicationParser;
import org.apache.spark.SparkConf;
import org.apache.spark.ml.clustering.LDAModel;
@@ -15,7 +16,7 @@ import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.*;
-public class SparkLDAInference extends AbstractSparkJob{
+public class SparkLDAInference extends AbstractSparkJob {
private static final Logger log = LoggerFactory.getLogger(SparkLDAInference.class);
diff --git a/dnet-and-test/src/main/java/eu/dnetlib/jobs/SparkLDATuning.java b/dnet-and-test/src/main/java/eu/dnetlib/jobs/featureextraction/lda/SparkLDATuning.java
similarity index 95%
rename from dnet-and-test/src/main/java/eu/dnetlib/jobs/SparkLDATuning.java
rename to dnet-and-test/src/main/java/eu/dnetlib/jobs/featureextraction/lda/SparkLDATuning.java
index b4ac8b9..443c30e 100644
--- a/dnet-and-test/src/main/java/eu/dnetlib/jobs/SparkLDATuning.java
+++ b/dnet-and-test/src/main/java/eu/dnetlib/jobs/featureextraction/lda/SparkLDATuning.java
@@ -1,7 +1,8 @@
-package eu.dnetlib.jobs;
+package eu.dnetlib.jobs.featureextraction.lda;
import eu.dnetlib.featureextraction.FeatureTransformer;
import eu.dnetlib.featureextraction.Utilities;
+import eu.dnetlib.jobs.AbstractSparkJob;
import eu.dnetlib.support.ArgumentApplicationParser;
import org.apache.spark.SparkConf;
import org.apache.spark.ml.clustering.LDAModel;
@@ -15,7 +16,7 @@ import scala.Tuple2;
import java.io.IOException;
import java.util.*;
-public class SparkLDATuning extends AbstractSparkJob{
+public class SparkLDATuning extends AbstractSparkJob {
private static final Logger log = LoggerFactory.getLogger(SparkLDATuning.class);
diff --git a/dnet-and-test/src/main/resources/feature_extraction/oozie_app/config-default.xml b/dnet-and-test/src/main/resources/feature_extraction/oozie_app/config-default.xml
new file mode 100644
index 0000000..2e0ed9a
--- /dev/null
+++ b/dnet-and-test/src/main/resources/feature_extraction/oozie_app/config-default.xml
@@ -0,0 +1,18 @@
+
+
+ jobTracker
+ yarnRM
+
+
+ nameNode
+ hdfs://nameservice1
+
+
+ oozie.use.system.libpath
+ true
+
+
+ oozie.action.sharelib.for.spark
+ spark2
+
+
\ No newline at end of file
diff --git a/dnet-and-test/src/main/resources/feature_extraction/oozie_app/workflow.xml b/dnet-and-test/src/main/resources/feature_extraction/oozie_app/workflow.xml
new file mode 100644
index 0000000..2aa77c0
--- /dev/null
+++ b/dnet-and-test/src/main/resources/feature_extraction/oozie_app/workflow.xml
@@ -0,0 +1,172 @@
+
+
+
+ publicationsPath
+ the input entity path
+
+
+ workingPath
+ path for the working directory
+
+
+ numPartitions
+ number of partitions for the spark files
+
+
+ featuresPath
+ location of the embeddings
+
+
+ topicsPath
+ location of the topics
+
+
+ outputPath
+ location of the output authors
+
+
+ bertModel
+ location of the bert model
+
+
+ bertSentenceModel
+ location of the bert sentence model
+
+
+ wordEmbeddingsModel
+ location of the word embeddings model
+
+
+ sparkDriverMemory
+ memory for driver process
+
+
+ sparkExecutorMemory
+ memory for individual executor
+
+
+ sparkExecutorCores
+ number of cores used by single executor
+
+
+ oozieActionShareLibForSpark2
+ oozie action sharelib for spark 2.*
+
+
+ spark2ExtraListeners
+ com.cloudera.spark.lineage.NavigatorAppListener
+ spark 2.* extra listeners classname
+
+
+ spark2SqlQueryExecutionListeners
+ com.cloudera.spark.lineage.NavigatorQueryListener
+ spark 2.* sql query execution listeners classname
+
+
+ spark2YarnHistoryServerAddress
+ spark 2.* yarn history server address
+
+
+ spark2EventLogDir
+ spark 2.* event log dir location
+
+
+
+
+ ${jobTracker}
+ ${nameNode}
+
+
+ mapreduce.job.queuename
+ ${queueName}
+
+
+ oozie.launcher.mapred.job.queue.name
+ ${oozieLauncherQueueName}
+
+
+ oozie.action.sharelib.for.spark
+ ${oozieActionShareLibForSpark2}
+
+
+
+
+
+
+
+ Action failed, error message[${wf:errorMessage(wf:lastErrorNode())}]
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ yarn
+ cluster
+ Publication Feature Extraction
+ eu.dnetlib.jobs.featureextraction.SparkPublicationFeatureExtractor
+ dnet-and-test-${projectVersion}.jar
+
+ --num-executors=32
+ --executor-memory=${sparkExecutorMemory}
+ --executor-cores=${sparkExecutorCores}
+ --driver-memory=${sparkDriverMemory}
+ --conf spark.extraListeners=${spark2ExtraListeners}
+ --conf spark.sql.queryExecutionListeners=${spark2SqlQueryExecutionListeners}
+ --conf spark.yarn.historyServer.address=${spark2YarnHistoryServerAddress}
+ --conf spark.eventLog.dir=${nameNode}${spark2EventLogDir}
+ --conf spark.sql.shuffle.partitions=3840
+ --conf spark.dynamicAllocation.enabled=false
+
+ --publicationsPath${publicationsPath}
+ --workingPath${workingPath}
+ --numPartitions${numPartitions}
+ --featuresPath${featuresPath}
+ --wordEmbeddingsModel${wordEmbeddingsModel}
+ --bertModel${bertModel}
+ --bertSentenceModel${bertSentenceModel}
+
+
+
+
+
+
+
+ yarn
+ cluster
+ Author Extraction
+ eu.dnetlib.jobs.featureextraction.SparkAuthorExtractor
+ dnet-and-test-${projectVersion}.jar
+
+ --num-executors=32
+ --executor-memory=${sparkExecutorMemory}
+ --executor-cores=${sparkExecutorCores}
+ --driver-memory=${sparkDriverMemory}
+ --conf spark.extraListeners=${spark2ExtraListeners}
+ --conf spark.sql.queryExecutionListeners=${spark2SqlQueryExecutionListeners}
+ --conf spark.yarn.historyServer.address=${spark2YarnHistoryServerAddress}
+ --conf spark.eventLog.dir=${nameNode}${spark2EventLogDir}
+ --conf spark.sql.shuffle.partitions=3840
+ --conf spark.dynamicAllocation.enabled=true
+
+ --workingPath${workingPath}
+ --numPartitions${numPartitions}
+ --publicationsPath${publicationsPath}
+ --topicsPath${topicsPath}
+ --featuresPath${featuresPath}
+ --outputPath${outputPath}
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/dnet-and-test/src/main/resources/jobs/parameters/authorExtractor_parameters.json b/dnet-and-test/src/main/resources/jobs/parameters/authorExtractor_parameters.json
index 4905b64..bee9c78 100644
--- a/dnet-and-test/src/main/resources/jobs/parameters/authorExtractor_parameters.json
+++ b/dnet-and-test/src/main/resources/jobs/parameters/authorExtractor_parameters.json
@@ -6,11 +6,23 @@
"paramRequired": true
},
{
- "paramName": "e",
- "paramLongName": "entitiesPath",
+ "paramName": "p",
+ "paramLongName": "publicationsPath",
"paramDescription": "location of the input entities",
"paramRequired": true
},
+ {
+ "paramName": "t",
+ "paramLongName": "topicsPath",
+ "paramDescription": "location of the lda topics",
+ "paramRequired": true
+ },
+ {
+ "paramName": "f",
+ "paramLongName": "featuresPath",
+ "paramDescription": "location of the features",
+ "paramRequired": true
+ },
{
"paramName": "np",
"paramLongName": "numPartitions",
@@ -22,11 +34,5 @@
"paramLongName": "outputPath",
"paramDescription": "location of the output author extracted",
"paramRequired": false
- },
- {
- "paramName": "t",
- "paramLongName": "topicsPath",
- "paramDescription": "location of the lda topics",
- "paramRequired": false
}
]
\ No newline at end of file
diff --git a/dnet-and-test/src/main/resources/jobs/parameters/graphClassificationTraining_parameters.json b/dnet-and-test/src/main/resources/jobs/parameters/graphClassificationTraining_parameters.json
deleted file mode 100644
index 8e6b4d2..0000000
--- a/dnet-and-test/src/main/resources/jobs/parameters/graphClassificationTraining_parameters.json
+++ /dev/null
@@ -1,14 +0,0 @@
-[
- {
- "paramName": "w",
- "paramLongName": "workingPath",
- "paramDescription": "path of the working directory",
- "paramRequired": true
- },
- {
- "paramName": "np",
- "paramLongName": "numPartitions",
- "paramDescription": "number of partitions for the similarity relations intermediate phases",
- "paramRequired": false
- }
-]
\ No newline at end of file
diff --git a/dnet-and-test/src/main/resources/jobs/parameters/publicationFeatureExtractor_parameters.json b/dnet-and-test/src/main/resources/jobs/parameters/publicationFeatureExtractor_parameters.json
new file mode 100644
index 0000000..db5af73
--- /dev/null
+++ b/dnet-and-test/src/main/resources/jobs/parameters/publicationFeatureExtractor_parameters.json
@@ -0,0 +1,44 @@
+[
+ {
+ "paramName": "w",
+ "paramLongName": "workingPath",
+ "paramDescription": "path of the working directory",
+ "paramRequired": true
+ },
+ {
+ "paramName": "np",
+ "paramLongName": "numPartitions",
+ "paramDescription": "number of partitions for the similarity relations intermediate phases",
+ "paramRequired": false
+ },
+ {
+ "paramName": "p",
+ "paramLongName": "publicationsPath",
+ "paramDescription": "location of the publications",
+ "paramRequired": true
+ },
+ {
+ "paramName": "f",
+ "paramLongName": "featuresPath",
+ "paramDescription": "location of the features",
+ "paramRequired": true
+ },
+ {
+ "paramName": "we",
+ "paramLongName": "wordEmbeddingsModel",
+ "paramDescription": "path of the word embeddings model",
+ "paramRequired": true
+ },
+ {
+ "paramName": "bm",
+ "paramLongName": "bertModel",
+ "paramDescription": "path of the bert model",
+ "paramRequired": true
+ },
+ {
+ "paramName": "bs",
+ "paramLongName": "bertSentenceModel",
+ "paramDescription": "path of the bert sentence model",
+ "paramRequired": true
+ }
+]
\ No newline at end of file
diff --git a/dnet-and-test/src/main/resources/lda_inference/oozie_app/workflow.xml b/dnet-and-test/src/main/resources/lda_inference/oozie_app/workflow.xml
index 4537200..e16e896 100644
--- a/dnet-and-test/src/main/resources/lda_inference/oozie_app/workflow.xml
+++ b/dnet-and-test/src/main/resources/lda_inference/oozie_app/workflow.xml
@@ -155,7 +155,7 @@
yarn
cluster
LDA Inference
- eu.dnetlib.jobs.SparkLDAInference
+ eu.dnetlib.jobs.featureextraction.lda.SparkLDAInference
dnet-and-test-${projectVersion}.jar
--executor-memory=${sparkExecutorMemory}
@@ -172,63 +172,8 @@
--ldaModelPath${ldaModelPath}
--numPartitions${numPartitions}
-
-
-
-
-
-
- yarn
- cluster
- LDA Inference
- eu.dnetlib.jobs.SparkAuthorExtractor
- dnet-and-test-${projectVersion}.jar
-
- --executor-memory=${sparkExecutorMemory}
- --executor-cores=${sparkExecutorCores}
- --driver-memory=${sparkDriverMemory}
- --conf spark.extraListeners=${spark2ExtraListeners}
- --conf spark.sql.queryExecutionListeners=${spark2SqlQueryExecutionListeners}
- --conf spark.yarn.historyServer.address=${spark2YarnHistoryServerAddress}
- --conf spark.eventLog.dir=${nameNode}${spark2EventLogDir}
- --conf spark.sql.shuffle.partitions=3840
-
- --entitiesPath${entitiesPath}
- --workingPath${workingPath}
- --outputPath${authorsPath}
- --numPartitions${numPartitions}
- --topicsPath${ldaInferencePath}
-
-
-
-
-
-
-
- yarn
- cluster
- LDA Threshold Analysis
- eu.dnetlib.jobs.SparkLDAAnalysis
- dnet-and-test-${projectVersion}.jar
-
- --num-executors=32
- --executor-memory=${sparkExecutorMemory}
- --executor-cores=${sparkExecutorCores}
- --driver-memory=${sparkDriverMemory}
- --conf spark.extraListeners=${spark2ExtraListeners}
- --conf spark.sql.queryExecutionListeners=${spark2SqlQueryExecutionListeners}
- --conf spark.yarn.historyServer.address=${spark2YarnHistoryServerAddress}
- --conf spark.eventLog.dir=${nameNode}${spark2EventLogDir}
- --conf spark.sql.shuffle.partitions=3840
- --conf spark.dynamicAllocation.enabled=false
-
- --authorsPath${authorsPath}
- --workingPath${workingPath}
- --numPartitions${numPartitions}
-
-
\ No newline at end of file
diff --git a/dnet-and-test/src/main/resources/lda_tuning/oozie_app/workflow.xml b/dnet-and-test/src/main/resources/lda_tuning/oozie_app/workflow.xml
index add9e78..e109b7b 100644
--- a/dnet-and-test/src/main/resources/lda_tuning/oozie_app/workflow.xml
+++ b/dnet-and-test/src/main/resources/lda_tuning/oozie_app/workflow.xml
@@ -195,7 +195,7 @@
yarn
cluster
LDA Tuning
- eu.dnetlib.jobs.SparkLDATuning
+ eu.dnetlib.jobs.featureextraction.lda.SparkLDATuning
dnet-and-test-${projectVersion}.jar
--executor-memory=${sparkExecutorMemory}
diff --git a/dnet-and-test/src/test/java/eu/dnetlib/jobs/deeplearning/GNNTrainingTest.java b/dnet-and-test/src/test/java/eu/dnetlib/jobs/featureextraction/FeatureExtractionJobTest.java
similarity index 63%
rename from dnet-and-test/src/test/java/eu/dnetlib/jobs/deeplearning/GNNTrainingTest.java
rename to dnet-and-test/src/test/java/eu/dnetlib/jobs/featureextraction/FeatureExtractionJobTest.java
index 762e175..2a57b5e 100644
--- a/dnet-and-test/src/test/java/eu/dnetlib/jobs/deeplearning/GNNTrainingTest.java
+++ b/dnet-and-test/src/test/java/eu/dnetlib/jobs/featureextraction/FeatureExtractionJobTest.java
@@ -1,6 +1,7 @@
-package eu.dnetlib.jobs.deeplearning;
+package eu.dnetlib.jobs.featureextraction;
import eu.dnetlib.jobs.AbstractSparkJob;
+import eu.dnetlib.jobs.SparkTokenizer;
import eu.dnetlib.support.ArgumentApplicationParser;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
@@ -15,22 +16,22 @@ import java.nio.file.Paths;
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
-public class GNNTrainingTest {
-
+public class FeatureExtractionJobTest {
static SparkSession spark;
static JavaSparkContext context;
final static String workingPath = "/tmp/working_dir";
- final static String numPartitions = "20";
final String inputDataPath = Paths
- .get(getClass().getResource("/eu/dnetlib/jobs/examples/authors.groups.example.json").toURI())
+ .get(getClass().getResource("/eu/dnetlib/jobs/examples/publications.subset.json").toURI())
.toFile()
.getAbsolutePath();
- final static String groundTruthJPath = "$.orcid";
- final static String idJPath = "$.id";
- final static String featuresJPath = "$.topics";
- public GNNTrainingTest() throws URISyntaxException {}
+ final String ldaTopicsPath = Paths
+ .get(getClass().getResource("/eu/dnetlib/jobs/examples/publications_lda_topics_subset").toURI())
+ .toFile()
+ .getAbsolutePath();
+
+ public FeatureExtractionJobTest() throws URISyntaxException {}
public static void cleanup() throws IOException {
//remove directories and clean workspace
@@ -57,43 +58,43 @@ public class GNNTrainingTest {
@Test
@Order(1)
- public void createGroupDataSetTest() throws Exception {
- ArgumentApplicationParser parser = new ArgumentApplicationParser(readResource("/jobs/parameters/createGroupDataset_parameters.json", SparkCreateGroupDataSet.class));
+ public void publicationFeatureExtractionTest() throws Exception {
+ ArgumentApplicationParser parser = new ArgumentApplicationParser(readResource("/jobs/parameters/publicationFeatureExtractor_parameters.json", SparkTokenizer.class));
parser.parseArgument(
new String[] {
- "-i", inputDataPath,
- "-gt", groundTruthJPath,
- "-id", idJPath,
- "-f", featuresJPath,
+ "-p", inputDataPath,
"-w", workingPath,
- "-np", numPartitions
+ "-np", "20"
}
);
- new SparkCreateGroupDataSet(
+ new SparkPublicationFeatureExtractor(
parser,
spark
).run();
-
}
@Test
@Order(2)
- public void graphClassificationTrainingTest() throws Exception{
- ArgumentApplicationParser parser = new ArgumentApplicationParser(readResource("/jobs/parameters/graphClassificationTraining_parameters.json", SparkGraphClassificationTraining.class));
+ public void authorExtractionTest() throws Exception {
+ ArgumentApplicationParser parser = new ArgumentApplicationParser(readResource("/jobs/parameters/authorExtractor_parameters.json", SparkAuthorExtractor.class));
parser.parseArgument(
- new String[] {
+ new String[]{
+ "-p", inputDataPath,
"-w", workingPath,
- "-np", numPartitions
- }
- );
+ "-np", "20",
+ "-t", ldaTopicsPath,
+ "-f", workingPath + "/publication_features",
+ "-o", workingPath + "/authors"
+ });
- new SparkGraphClassificationTraining(
+ new SparkAuthorExtractor(
parser,
spark
).run();
+
}
public static String readResource(String path, Class extends AbstractSparkJob> clazz) throws IOException {
diff --git a/dnet-and-test/src/test/java/eu/dnetlib/jobs/LDAAnalysisTest.java b/dnet-and-test/src/test/java/eu/dnetlib/jobs/featureextraction/lda/LDAAnalysisTest.java
similarity index 68%
rename from dnet-and-test/src/test/java/eu/dnetlib/jobs/LDAAnalysisTest.java
rename to dnet-and-test/src/test/java/eu/dnetlib/jobs/featureextraction/lda/LDAAnalysisTest.java
index 8c59008..80cbf90 100644
--- a/dnet-and-test/src/test/java/eu/dnetlib/jobs/LDAAnalysisTest.java
+++ b/dnet-and-test/src/test/java/eu/dnetlib/jobs/featureextraction/lda/LDAAnalysisTest.java
@@ -1,5 +1,9 @@
-package eu.dnetlib.jobs;
+package eu.dnetlib.jobs.featureextraction.lda;
+import eu.dnetlib.jobs.AbstractSparkJob;
+import eu.dnetlib.jobs.SparkCountVectorizer;
+import eu.dnetlib.jobs.SparkCreateVocabulary;
+import eu.dnetlib.jobs.SparkTokenizer;
import eu.dnetlib.support.ArgumentApplicationParser;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
@@ -157,74 +161,12 @@ public class LDAAnalysisTest {
parser,
spark
).run();
- }
-
- @Test
- @Order(6)
- public void authorExtractorTest() throws Exception {
- ArgumentApplicationParser parser = new ArgumentApplicationParser(readResource("/jobs/parameters/authorExtractor_parameters.json", SparkLDAInference.class));
-
- parser.parseArgument(
- new String[]{
- "-e", inputDataPath,
- "-o", authorsPath,
- "-t", topicsPath,
- "-w", workingPath,
- "-np", numPartitions
- });
-
- new SparkAuthorExtractor(
- parser,
- spark
- ).run();
- }
-
- @Test
- @Order(7)
- public void ldaAnalysis() throws Exception {
- ArgumentApplicationParser parser = new ArgumentApplicationParser(readResource("/jobs/parameters/ldaAnalysis_parameters.json", SparkLDAAnalysis.class));
-
- parser.parseArgument(
- new String[]{
- "-i", authorsPath,
- "-w", workingPath,
- "-np", numPartitions
- });
-
- new SparkLDAAnalysis(
- parser,
- spark
- ).run();
-
- Thread.sleep(1000000000);
+ Thread.sleep(100000);
}
public static String readResource(String path, Class extends AbstractSparkJob> clazz) throws IOException {
return IOUtils.toString(clazz.getResourceAsStream(path));
}
-
-// @Test
-// public void createVocabulary() {
-//
-// StructType inputSchema = new StructType(new StructField[]{
-// new StructField("id", DataTypes.StringType, false, Metadata.empty()),
-// new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
-// });
-//
-// JavaSparkContext sc = JavaSparkContext.fromSparkContext(spark.sparkContext());
-// JavaRDD rows = sc.textFile("/Users/miconis/Desktop/dewey").map(s -> s.substring(4)).map(s -> Utilities.normalize(s).replaceAll(" ", " ")).filter(s -> !s.contains("unassigned")).map(s -> RowFactory.create("id", s));
-//
-// Dataset dataFrame = spark.createDataFrame(rows, inputSchema);
-//
-// dataFrame = FeatureTransformer.tokenizeData(dataFrame);
-//
-// JavaRDD map = dataFrame.toJavaRDD().map(r -> r.getList(1)).flatMap(l -> l.iterator()).map(s -> s.toString()).distinct();
-//
-// map.coalesce(1).saveAsTextFile("/tmp/vocab_raw");
-// System.out.println("map = " + map.count());
-// System.out.println("dataFrame = " + map.first());
-// }
-
}
diff --git a/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/._SUCCESS.crc b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/._SUCCESS.crc
new file mode 100644
index 0000000..3b7b044
Binary files /dev/null and b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/._SUCCESS.crc differ
diff --git a/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/.part-00000-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet.crc b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/.part-00000-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet.crc
new file mode 100644
index 0000000..e546b9f
Binary files /dev/null and b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/.part-00000-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet.crc differ
diff --git a/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/.part-00001-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet.crc b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/.part-00001-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet.crc
new file mode 100644
index 0000000..01d9d71
Binary files /dev/null and b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/.part-00001-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet.crc differ
diff --git a/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/.part-00002-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet.crc b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/.part-00002-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet.crc
new file mode 100644
index 0000000..21019f7
Binary files /dev/null and b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/.part-00002-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet.crc differ
diff --git a/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/.part-00003-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet.crc b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/.part-00003-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet.crc
new file mode 100644
index 0000000..ab35987
Binary files /dev/null and b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/.part-00003-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet.crc differ
diff --git a/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/.part-00004-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet.crc b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/.part-00004-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet.crc
new file mode 100644
index 0000000..09aeb18
Binary files /dev/null and b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/.part-00004-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet.crc differ
diff --git a/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/.part-00005-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet.crc b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/.part-00005-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet.crc
new file mode 100644
index 0000000..f4d51d0
Binary files /dev/null and b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/.part-00005-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet.crc differ
diff --git a/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/.part-00006-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet.crc b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/.part-00006-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet.crc
new file mode 100644
index 0000000..5383af4
Binary files /dev/null and b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/.part-00006-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet.crc differ
diff --git a/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/.part-00007-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet.crc b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/.part-00007-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet.crc
new file mode 100644
index 0000000..72a90c4
Binary files /dev/null and b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/.part-00007-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet.crc differ
diff --git a/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/.part-00008-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet.crc b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/.part-00008-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet.crc
new file mode 100644
index 0000000..cd57565
Binary files /dev/null and b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/.part-00008-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet.crc differ
diff --git a/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/.part-00009-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet.crc b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/.part-00009-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet.crc
new file mode 100644
index 0000000..7733d1a
Binary files /dev/null and b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/.part-00009-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet.crc differ
diff --git a/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/_SUCCESS b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/_SUCCESS
new file mode 100644
index 0000000..e69de29
diff --git a/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/part-00000-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/part-00000-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet
new file mode 100644
index 0000000..c0e22dc
Binary files /dev/null and b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/part-00000-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet differ
diff --git a/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/part-00001-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/part-00001-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet
new file mode 100644
index 0000000..1235280
Binary files /dev/null and b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/part-00001-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet differ
diff --git a/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/part-00002-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/part-00002-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet
new file mode 100644
index 0000000..9bc1662
Binary files /dev/null and b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/part-00002-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet differ
diff --git a/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/part-00003-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/part-00003-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet
new file mode 100644
index 0000000..2bca203
Binary files /dev/null and b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/part-00003-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet differ
diff --git a/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/part-00004-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/part-00004-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet
new file mode 100644
index 0000000..5b2b06d
Binary files /dev/null and b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/part-00004-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet differ
diff --git a/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/part-00005-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/part-00005-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet
new file mode 100644
index 0000000..4d8d74d
Binary files /dev/null and b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/part-00005-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet differ
diff --git a/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/part-00006-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/part-00006-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet
new file mode 100644
index 0000000..d1626ce
Binary files /dev/null and b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/part-00006-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet differ
diff --git a/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/part-00007-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/part-00007-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet
new file mode 100644
index 0000000..8bd1a75
Binary files /dev/null and b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/part-00007-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet differ
diff --git a/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/part-00008-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/part-00008-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet
new file mode 100644
index 0000000..b70a199
Binary files /dev/null and b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/part-00008-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet differ
diff --git a/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/part-00009-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/part-00009-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet
new file mode 100644
index 0000000..00c3e19
Binary files /dev/null and b/dnet-and-test/src/test/resources/eu/dnetlib/jobs/examples/publications_lda_topics_subset/part-00009-c3abd217-3f3a-4ab9-a993-0cfde5f36081-c000.snappy.parquet differ
diff --git a/dnet-feature-extraction/pom.xml b/dnet-feature-extraction/pom.xml
index 6a169d8..7e65e43 100644
--- a/dnet-feature-extraction/pom.xml
+++ b/dnet-feature-extraction/pom.xml
@@ -13,6 +13,36 @@
dnet-feature-extraction
jar
+
+
+
+ net.alchim31.maven
+ scala-maven-plugin
+ 4.0.1
+
+
+ scala-compile-first
+ initialize
+
+ add-source
+ compile
+
+
+
+ scala-test-compile
+ process-test-resources
+
+ testCompile
+
+
+
+
+ ${scala.version}
+
+
+
+
+
org.apache.spark
@@ -53,42 +83,36 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
- org.nd4j
- ${nd4j.backend}
-
-
- org.deeplearning4j
- deeplearning4j-core
-
-
- org.deeplearning4j
- deeplearning4j-datasets
-
-
- org.deeplearning4j
- dl4j-spark-parameterserver_2.11
-
-
- org.deeplearning4j
- dl4j-spark_2.11
+ com.johnsnowlabs.nlp
+ spark-nlp_${scala.binary.version}
-
-
- jfree
- jfreechart
-
-
- org.jfree
- jcommon
-
-
-
-
- eu.dnetlib
- dnet-dedup-test
-
+
+
+
+
diff --git a/dnet-feature-extraction/src/main/java/eu/dnetlib/deeplearning/GroupClassifier.java b/dnet-feature-extraction/src/main/java/eu/dnetlib/deeplearning/GroupClassifier.java
deleted file mode 100644
index 09435ff..0000000
--- a/dnet-feature-extraction/src/main/java/eu/dnetlib/deeplearning/GroupClassifier.java
+++ /dev/null
@@ -1,130 +0,0 @@
-//package eu.dnetlib.deeplearning;
-//
-///* *****************************************************************************
-// *
-// *
-// *
-// * This program and the accompanying materials are made available under the
-// * terms of the Apache License, Version 2.0 which is available at
-// * https://www.apache.org/licenses/LICENSE-2.0.
-// * See the NOTICE file distributed with this work for additional
-// * information regarding copyright ownership.
-// *
-// * Unless required by applicable law or agreed to in writing, software
-// * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-// * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-// * License for the specific language governing permissions and limitations
-// * under the License.
-// *
-// * SPDX-License-Identifier: Apache-2.0
-// ******************************************************************************/
-//
-//import org.datavec.api.records.reader.RecordReader;
-//import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
-//import org.datavec.api.split.FileSplit;
-//import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
-//import org.deeplearning4j.examples.utils.DownloaderUtility;
-//import org.deeplearning4j.examples.utils.PlotUtil;
-//import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
-//import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-//import org.deeplearning4j.nn.conf.layers.DenseLayer;
-//import org.deeplearning4j.nn.conf.layers.OutputLayer;
-//import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
-//import org.deeplearning4j.nn.weights.WeightInit;
-//import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
-//import org.nd4j.evaluation.classification.Evaluation;
-//import org.nd4j.linalg.activations.Activation;
-//import org.nd4j.linalg.api.ndarray.INDArray;
-//import org.nd4j.linalg.dataset.DataSet;
-//import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
-//import org.nd4j.linalg.learning.config.Nesterovs;
-//import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
-//
-//import java.io.File;
-//import java.util.concurrent.TimeUnit;
-//
-//public class GroupClassifier {
-//
-// public static boolean visualize = true;
-// public static String dataLocalPath;
-//
-// public static void main(String[] args) throws Exception {
-// int seed = 123;
-// double learningRate = 0.01;
-// int batchSize = 50;
-// int nEpochs = 30;
-//
-// int numInputs = 2;
-// int numOutputs = 2;
-// int numHiddenNodes = 20;
-//
-// dataLocalPath = DownloaderUtility.CLASSIFICATIONDATA.Download();
-// //Load the training data:
-// RecordReader rr = new CSVRecordReader();
-// rr.initialize(new FileSplit(new File(dataLocalPath, "linear_data_train.csv")));
-// DataSetIterator trainIter = new RecordReaderDataSetIterator(rr, batchSize, 0, 2);
-//
-// //Load the test/evaluation data:
-// RecordReader rrTest = new CSVRecordReader();
-// rrTest.initialize(new FileSplit(new File(dataLocalPath, "linear_data_eval.csv")));
-// DataSetIterator testIter = new RecordReaderDataSetIterator(rrTest, batchSize, 0, 2);
-//
-// MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
-// .seed(seed)
-// .weightInit(WeightInit.XAVIER)
-// .updater(new Nesterovs(learningRate, 0.9))
-// .list()
-// .layer(new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
-// .activation(Activation.RELU)
-// .build())
-// .layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
-// .activation(Activation.SOFTMAX)
-// .nIn(numHiddenNodes).nOut(numOutputs).build())
-// .build();
-//
-//
-// MultiLayerNetwork model = new MultiLayerNetwork(conf);
-// model.init();
-// model.setListeners(new ScoreIterationListener(10)); //Print score every 10 parameter updates
-//
-// model.fit(trainIter, nEpochs);
-//
-// System.out.println("Evaluate model....");
-// Evaluation eval = new Evaluation(numOutputs);
-// while (testIter.hasNext()) {
-// DataSet t = testIter.next();
-// INDArray features = t.getFeatures();
-// INDArray labels = t.getLabels();
-// INDArray predicted = model.output(features, false);
-// eval.eval(labels, predicted);
-// }
-// //An alternate way to do the above loop
-// //Evaluation evalResults = model.evaluate(testIter);
-//
-// //Print the evaluation statistics
-// System.out.println(eval.stats());
-//
-// System.out.println("\n****************Example finished********************");
-// //Training is complete. Code that follows is for plotting the data & predictions only
-// generateVisuals(model, trainIter, testIter);
-// }
-//
-// public static void generateVisuals(MultiLayerNetwork model, DataSetIterator trainIter, DataSetIterator testIter) throws Exception {
-// if (visualize) {
-// double xMin = 0;
-// double xMax = 1.0;
-// double yMin = -0.2;
-// double yMax = 0.8;
-// int nPointsPerAxis = 100;
-//
-// //Generate x,y points that span the whole range of features
-// INDArray allXYPoints = PlotUtil.generatePointsOnGraph(xMin, xMax, yMin, yMax, nPointsPerAxis);
-// //Get train data and plot with predictions
-// PlotUtil.plotTrainingData(model, trainIter, allXYPoints, nPointsPerAxis);
-// TimeUnit.SECONDS.sleep(3);
-// //Get test data, run the test data through the network to generate predictions, and plot those predictions:
-// PlotUtil.plotTestData(model, testIter, allXYPoints, nPointsPerAxis);
-// }
-// }
-//}
-//
diff --git a/dnet-feature-extraction/src/main/java/eu/dnetlib/deeplearning/layers/GraphConvolutionVertex.java b/dnet-feature-extraction/src/main/java/eu/dnetlib/deeplearning/layers/GraphConvolutionVertex.java
deleted file mode 100644
index 6a16711..0000000
--- a/dnet-feature-extraction/src/main/java/eu/dnetlib/deeplearning/layers/GraphConvolutionVertex.java
+++ /dev/null
@@ -1,24 +0,0 @@
-package eu.dnetlib.deeplearning.layers;
-
-import org.deeplearning4j.nn.conf.graph.GraphVertex;
-import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams;
-import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaVertex;
-import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex;
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.linalg.api.ndarray.INDArray;
-
-import java.util.Map;
-
-public class GraphConvolutionVertex extends SameDiffLambdaVertex {
-
- @Override
- public SDVariable defineVertex(SameDiff sameDiff, VertexInputs inputs) {
- SDVariable features = inputs.getInput(0);
- SDVariable adjacency = inputs.getInput(1);
- SDVariable degree = inputs.getInput(2).pow(0.5);
-
- //result: DegreeMatrix^-0.5 x Adjacent x DegreeMatrix^-0.5 x Features
- return degree.mmul(adjacency).mmul(degree).mmul(features);
- }
-}
diff --git a/dnet-feature-extraction/src/main/java/eu/dnetlib/deeplearning/layers/GraphGlobalAddPool.java b/dnet-feature-extraction/src/main/java/eu/dnetlib/deeplearning/layers/GraphGlobalAddPool.java
deleted file mode 100644
index 74f2f3f..0000000
--- a/dnet-feature-extraction/src/main/java/eu/dnetlib/deeplearning/layers/GraphGlobalAddPool.java
+++ /dev/null
@@ -1,21 +0,0 @@
-package eu.dnetlib.deeplearning.layers;
-
-import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaLayer;
-import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer;
-import org.nd4j.autodiff.samediff.SDIndex;
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.autodiff.samediff.SameDiff;
-
-import java.util.Map;
-
-public class GraphGlobalAddPool extends SameDiffLambdaLayer {
-
- int size;
- public GraphGlobalAddPool(int size) {
- this.size = size;
- }
- @Override
- public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput) {
- return layerInput.mean(0).reshape(1, size); //reshape because output layer expects 2-dimensional arrays
- }
-}
diff --git a/dnet-feature-extraction/src/main/java/eu/dnetlib/deeplearning/support/DataSetProcessor.java b/dnet-feature-extraction/src/main/java/eu/dnetlib/deeplearning/support/DataSetProcessor.java
deleted file mode 100644
index cfaf9d2..0000000
--- a/dnet-feature-extraction/src/main/java/eu/dnetlib/deeplearning/support/DataSetProcessor.java
+++ /dev/null
@@ -1,88 +0,0 @@
-package eu.dnetlib.deeplearning.support;
-
-import eu.dnetlib.featureextraction.Utilities;
-import eu.dnetlib.support.Author;
-import eu.dnetlib.support.ConnectedComponent;
-import eu.dnetlib.support.Relation;
-import org.apache.spark.api.java.JavaRDD;
-import org.codehaus.jackson.map.ObjectMapper;
-import org.jetbrains.annotations.NotNull;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.dataset.MultiDataSet;
-import org.nd4j.linalg.factory.Nd4j;
-
-import java.io.IOException;
-import java.util.*;
-import java.util.stream.Collectors;
-import java.util.stream.Stream;
-
-public class DataSetProcessor {
-
- public static JavaRDD entityGroupToMultiDataset(JavaRDD groupEntity, String idJPath, String featureJPath, String groundTruthJPath) {
-
- return groupEntity.map(g -> {
- Map featuresMap = new HashMap<>();
- List groundTruth = new ArrayList<>();
- Set entities = g.getDocs();
- for(String json:entities) {
- featuresMap.put(
- Utilities.getJPathString(idJPath, json),
- Utilities.getJPathArray(featureJPath, json)
- );
- groundTruth.add(Utilities.getJPathString(groundTruthJPath, json));
- }
-
- Set relations = g.getSimrels();
-
- return getMultiDataSet(featuresMap, relations, groundTruth);
- });
- }
-
- public static MultiDataSet getMultiDataSet(Map featuresMap, Set relations, List groundTruth) {
-
- List identifiers = new ArrayList<>(featuresMap.keySet());
-
- int numNodes = identifiers.size();
-
- //initialize arrays
- INDArray adjacency = Nd4j.zeros(numNodes, numNodes);
- INDArray features = Nd4j.zeros(numNodes, featuresMap.get(identifiers.get(0)).length); //feature size taken from the first element (it's equal for every element)
- INDArray degree = Nd4j.zeros(numNodes, numNodes);
-
- //create adjacency
- for(Relation r: relations) {
- adjacency.put(identifiers.indexOf(r.getSource()), identifiers.indexOf(r.getTarget()), 1);
- adjacency.put(identifiers.indexOf(r.getTarget()), identifiers.indexOf(r.getSource()), 1);
- }
- adjacency.addi(Nd4j.eye(numNodes));
-
- //create degree and features
- List degreeSupport = relations.stream().flatMap(r -> Stream.of(r.getSource(), r.getTarget())).collect(Collectors.toList());
- for(int i=0; i< identifiers.size(); i++) {
- degree.put(i, i, Collections.frequency(degreeSupport, identifiers.get(i)));
- features.putRow(i, Nd4j.create(featuresMap.get(identifiers.get(i))));
- }
-
- //infer label
- INDArray label = Nd4j.zeros(1, 2);
- if (groundTruth.stream().distinct().count()==1) {
- //correct (same elements)
- label.put(0, 0, 1.0);
- }
- else {
- //wrong (different elements)
- label.put(0, 1, 1.0);
- }
-
- return new MultiDataSet(
- new INDArray[]{
- features,
- adjacency,
- degree
- },
- new INDArray[]{
- label
- }
- );
- }
-}
diff --git a/dnet-feature-extraction/src/main/java/eu/dnetlib/deeplearning/support/GroupMultiDataSet.java b/dnet-feature-extraction/src/main/java/eu/dnetlib/deeplearning/support/GroupMultiDataSet.java
deleted file mode 100644
index b6734dd..0000000
--- a/dnet-feature-extraction/src/main/java/eu/dnetlib/deeplearning/support/GroupMultiDataSet.java
+++ /dev/null
@@ -1,11 +0,0 @@
-package eu.dnetlib.deeplearning.support;
-
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.dataset.MultiDataSet;
-
-import java.io.*;
-import java.util.List;
-
-public class GroupMultiDataSet extends MultiDataSet {
-
-}
diff --git a/dnet-feature-extraction/src/main/java/eu/dnetlib/deeplearning/support/NetworkConfigurations.java b/dnet-feature-extraction/src/main/java/eu/dnetlib/deeplearning/support/NetworkConfigurations.java
deleted file mode 100644
index 52db977..0000000
--- a/dnet-feature-extraction/src/main/java/eu/dnetlib/deeplearning/support/NetworkConfigurations.java
+++ /dev/null
@@ -1,97 +0,0 @@
-package eu.dnetlib.deeplearning.support;
-
-import eu.dnetlib.deeplearning.layers.GraphConvolutionVertex;
-import eu.dnetlib.deeplearning.layers.GraphGlobalAddPool;
-import org.bytedeco.opencv.opencv_dnn.PoolingLayer;
-import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
-import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
-import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
-import org.deeplearning4j.nn.conf.graph.MergeVertex;
-import org.deeplearning4j.nn.conf.layers.*;
-import org.deeplearning4j.nn.weights.WeightInit;
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.linalg.activations.Activation;
-import org.nd4j.linalg.api.buffer.DataType;
-import org.nd4j.linalg.learning.config.Adam;
-import org.nd4j.linalg.learning.config.Nesterovs;
-import org.nd4j.linalg.lossfunctions.LossFunctions;
-
-public class NetworkConfigurations {
-
- //parameteres default values
- protected static final int SEED = 12345;
- protected static final double LEARNING_RATE = 1e-3;
- protected static final String ADJACENCY_MATRIX = "adjacency";
- protected static final String FEATURES_MATRIX = "features";
- protected static final String DEGREE_MATRIX = "degrees";
-
- public static MultiLayerConfiguration getLinearDataClassifier(int numInputs, int numHiddenNodes, int numOutputs) {
- return new NeuralNetConfiguration.Builder()
- .seed(SEED)
- .weightInit(WeightInit.XAVIER)
- .updater(new Nesterovs(LEARNING_RATE, 0.9))
- .list()
- .layer(new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
- .activation(Activation.RELU)
- .build())
- .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
- .activation(Activation.SOFTMAX)
- .nIn(numHiddenNodes).nOut(numOutputs).build())
- .build();
- }
-
- public static ComputationGraphConfiguration getSimpleGCN(int numLayers, int numInputs, int numHiddenNodes, int numClasses) {
-
- ComputationGraphConfiguration.GraphBuilder baseConfig = new NeuralNetConfiguration.Builder()
- .seed(SEED)
- .updater(new Adam(LEARNING_RATE))
- .weightInit(WeightInit.XAVIER)
- .graphBuilder()
- .addInputs(FEATURES_MATRIX, ADJACENCY_MATRIX, DEGREE_MATRIX)
- //first convolution layer
- .addVertex("layer1",
- new GraphConvolutionVertex(),
- FEATURES_MATRIX, ADJACENCY_MATRIX, DEGREE_MATRIX)
- .layer("conv1",
- new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
- .activation(Activation.RELU)
- .build(),
- "layer1")
- .layer("batch1",
- new BatchNormalization.Builder().nOut(numHiddenNodes).build(),
- "conv1");
-
- //ad as many layers as requested
- for(int i=2; i<=numLayers; i++) {
- baseConfig = baseConfig.addVertex("layer" + i,
- new GraphConvolutionVertex(),
- "batch" + (i-1), ADJACENCY_MATRIX, DEGREE_MATRIX)
- .layer("conv" + i,
- new DenseLayer.Builder().nIn(numHiddenNodes).nOut(numHiddenNodes)
- .activation(Activation.RELU)
- .build(),
- "layer" + i)
- .layer("batch" + i,
- new BatchNormalization.Builder().nOut(numHiddenNodes).build(),
- "conv" + i);
- }
-
- baseConfig = baseConfig
- .layer("pool",
- new GraphGlobalAddPool(numHiddenNodes),
- "batch" + numLayers)
- .layer("fc1",
- new DenseLayer.Builder().nIn(numHiddenNodes).nOut(numHiddenNodes)
- .activation(Activation.RELU)
- .weightInit(WeightInit.XAVIER)
- .build(),
- "pool")
- .layer("out",
- new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
- .activation(Activation.SOFTMAX)
- .nIn(numHiddenNodes).nOut(numClasses).build(),
- "fc1");
-
- return baseConfig.setOutputs("out").build();
- }
-}
diff --git a/dnet-feature-extraction/src/main/java/eu/dnetlib/deeplearning/support/PlotUtils.java b/dnet-feature-extraction/src/main/java/eu/dnetlib/deeplearning/support/PlotUtils.java
deleted file mode 100644
index 3653b1a..0000000
--- a/dnet-feature-extraction/src/main/java/eu/dnetlib/deeplearning/support/PlotUtils.java
+++ /dev/null
@@ -1,253 +0,0 @@
-//package eu.dnetlib.deeplearning.support;
-//
-//import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
-//import org.jfree.chart.ChartPanel;
-//import org.jfree.chart.ChartUtilities;
-//import org.jfree.chart.JFreeChart;
-//import org.jfree.chart.axis.AxisLocation;
-//import org.jfree.chart.axis.NumberAxis;
-//import org.jfree.chart.block.BlockBorder;
-//import org.jfree.chart.plot.DatasetRenderingOrder;
-//import org.jfree.chart.plot.XYPlot;
-//import org.jfree.chart.renderer.GrayPaintScale;
-//import org.jfree.chart.renderer.PaintScale;
-//import org.jfree.chart.renderer.xy.XYBlockRenderer;
-//import org.jfree.chart.renderer.xy.XYLineAndShapeRenderer;
-//import org.jfree.chart.title.PaintScaleLegend;
-//import org.jfree.data.xy.*;
-//import org.jfree.ui.RectangleEdge;
-//import org.jfree.ui.RectangleInsets;
-//import org.nd4j.linalg.api.ndarray.INDArray;
-//import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
-//import org.nd4j.linalg.dataset.DataSet;
-//import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
-//import org.nd4j.linalg.factory.Nd4j;
-//
-//import javax.swing.*;
-//import java.awt.*;
-//import java.util.ArrayList;
-//import java.util.List;
-//
-///**
-// * Simple plotting methods for the MLPClassifier quickstartexamples
-// *
-// * @author Alex Black
-// */
-//public class PlotUtils {
-//
-// /**
-// * Plot the training data. Assume 2d input, classification output
-// *
-// * @param model Model to use to get predictions
-// * @param trainIter DataSet Iterator
-// * @param backgroundIn sets of x,y points in input space, plotted in the background
-// * @param nDivisions Number of points (per axis, for the backgroundIn/backgroundOut arrays)
-// */
-// public static void plotTrainingData(MultiLayerNetwork model, DataSetIterator trainIter, INDArray backgroundIn, int nDivisions) {
-// double[] mins = backgroundIn.min(0).data().asDouble();
-// double[] maxs = backgroundIn.max(0).data().asDouble();
-//
-// DataSet ds = allBatches(trainIter);
-// INDArray backgroundOut = model.output(backgroundIn);
-//
-// XYZDataset backgroundData = createBackgroundData(backgroundIn, backgroundOut);
-// JPanel panel = new ChartPanel(createChart(backgroundData, mins, maxs, nDivisions, createDataSetTrain(ds.getFeatures(), ds.getLabels())));
-//
-// JFrame f = new JFrame();
-// f.add(panel);
-// f.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
-// f.pack();
-// f.setTitle("Training Data");
-//
-// f.setVisible(true);
-// f.setLocation(0, 0);
-// }
-//
-// /**
-// * Plot the training data. Assume 2d input, classification output
-// *
-// * @param model Model to use to get predictions
-// * @param testIter Test Iterator
-// * @param backgroundIn sets of x,y points in input space, plotted in the background
-// * @param nDivisions Number of points (per axis, for the backgroundIn/backgroundOut arrays)
-// */
-// public static void plotTestData(MultiLayerNetwork model, DataSetIterator testIter, INDArray backgroundIn, int nDivisions) {
-//
-// double[] mins = backgroundIn.min(0).data().asDouble();
-// double[] maxs = backgroundIn.max(0).data().asDouble();
-//
-// INDArray backgroundOut = model.output(backgroundIn);
-// XYZDataset backgroundData = createBackgroundData(backgroundIn, backgroundOut);
-// DataSet ds = allBatches(testIter);
-// INDArray predicted = model.output(ds.getFeatures());
-// JPanel panel = new ChartPanel(createChart(backgroundData, mins, maxs, nDivisions, createDataSetTest(ds.getFeatures(), ds.getLabels(), predicted)));
-//
-// JFrame f = new JFrame();
-// f.add(panel);
-// f.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
-// f.pack();
-// f.setTitle("Test Data");
-//
-// f.setVisible(true);
-// f.setLocationRelativeTo(null);
-// //f.setLocation(100,100);
-//
-// }
-//
-//
-// /**
-// * Create data for the background data set
-// */
-// private static XYZDataset createBackgroundData(INDArray backgroundIn, INDArray backgroundOut) {
-// int nRows = backgroundIn.rows();
-// double[] xValues = new double[nRows];
-// double[] yValues = new double[nRows];
-// double[] zValues = new double[nRows];
-// for (int i = 0; i < nRows; i++) {
-// xValues[i] = backgroundIn.getDouble(i, 0);
-// yValues[i] = backgroundIn.getDouble(i, 1);
-// zValues[i] = backgroundOut.getDouble(i, 0);
-//
-// }
-//
-// DefaultXYZDataset dataset = new DefaultXYZDataset();
-// dataset.addSeries("Series 1",
-// new double[][]{xValues, yValues, zValues});
-// return dataset;
-// }
-//
-// //Training data
-// private static XYDataset createDataSetTrain(INDArray features, INDArray labels) {
-// int nRows = features.rows();
-//
-// int nClasses = 2; // Binary classification using one output call end sigmoid.
-//
-// XYSeries[] series = new XYSeries[nClasses];
-// for (int i = 0; i < series.length; i++) series[i] = new XYSeries("Class " + i);
-// INDArray argMax = Nd4j.getExecutioner().exec(new ArgMax(new INDArray[]{labels},false,new int[]{1}))[0];
-// for (int i = 0; i < nRows; i++) {
-// int classIdx = (int) argMax.getDouble(i);
-// series[classIdx].add(features.getDouble(i, 0), features.getDouble(i, 1));
-// }
-//
-// XYSeriesCollection c = new XYSeriesCollection();
-// for (XYSeries s : series) c.addSeries(s);
-// return c;
-// }
-//
-// //Test data
-// private static XYDataset createDataSetTest(INDArray features, INDArray labels, INDArray predicted) {
-// int nRows = features.rows();
-//
-// int nClasses = 2; // Binary classification using one output call end sigmoid.
-//
-// XYSeries[] series = new XYSeries[nClasses * nClasses];
-// int[] series_index = new int[]{0, 3, 2, 1}; //little hack to make the charts look consistent.
-// for (int i = 0; i < nClasses * nClasses; i++) {
-// int trueClass = i / nClasses;
-// int predClass = i % nClasses;
-// String label = "actual=" + trueClass + ", pred=" + predClass;
-// series[series_index[i]] = new XYSeries(label);
-// }
-// INDArray actualIdx = labels.argMax(1);
-// INDArray predictedIdx = predicted.argMax(1);
-// for (int i = 0; i < nRows; i++) {
-// int classIdx = actualIdx.getInt(i);
-// int predIdx = predictedIdx.getInt(i);
-// int idx = series_index[classIdx * nClasses + predIdx];
-// series[idx].add(features.getDouble(i, 0), features.getDouble(i, 1));
-// }
-//
-// XYSeriesCollection c = new XYSeriesCollection();
-// for (XYSeries s : series) c.addSeries(s);
-// return c;
-// }
-//
-// private static JFreeChart createChart(XYZDataset dataset, double[] mins, double[] maxs, int nPoints, XYDataset xyData) {
-// NumberAxis xAxis = new NumberAxis("X");
-// xAxis.setRange(mins[0], maxs[0]);
-//
-//
-// NumberAxis yAxis = new NumberAxis("Y");
-// yAxis.setRange(mins[1], maxs[1]);
-//
-// XYBlockRenderer renderer = new XYBlockRenderer();
-// renderer.setBlockWidth((maxs[0] - mins[0]) / (nPoints - 1));
-// renderer.setBlockHeight((maxs[1] - mins[1]) / (nPoints - 1));
-// PaintScale scale = new GrayPaintScale(0, 1.0);
-// renderer.setPaintScale(scale);
-// XYPlot plot = new XYPlot(dataset, xAxis, yAxis, renderer);
-// plot.setBackgroundPaint(Color.lightGray);
-// plot.setDomainGridlinesVisible(false);
-// plot.setRangeGridlinesVisible(false);
-// plot.setAxisOffset(new RectangleInsets(5, 5, 5, 5));
-// JFreeChart chart = new JFreeChart("", plot);
-// chart.getXYPlot().getRenderer().setSeriesVisibleInLegend(0, false);
-//
-//
-// NumberAxis scaleAxis = new NumberAxis("Probability (class 1)");
-// scaleAxis.setAxisLinePaint(Color.white);
-// scaleAxis.setTickMarkPaint(Color.white);
-// scaleAxis.setTickLabelFont(new Font("Dialog", Font.PLAIN, 7));
-// PaintScaleLegend legend = new PaintScaleLegend(new GrayPaintScale(),
-// scaleAxis);
-// legend.setStripOutlineVisible(false);
-// legend.setSubdivisionCount(20);
-// legend.setAxisLocation(AxisLocation.BOTTOM_OR_LEFT);
-// legend.setAxisOffset(5.0);
-// legend.setMargin(new RectangleInsets(5, 5, 5, 5));
-// legend.setFrame(new BlockBorder(Color.red));
-// legend.setPadding(new RectangleInsets(10, 10, 10, 10));
-// legend.setStripWidth(10);
-// legend.setPosition(RectangleEdge.LEFT);
-// chart.addSubtitle(legend);
-//
-// ChartUtilities.applyCurrentTheme(chart);
-//
-// plot.setDataset(1, xyData);
-// XYLineAndShapeRenderer renderer2 = new XYLineAndShapeRenderer();
-// renderer2.setBaseLinesVisible(false);
-// plot.setRenderer(1, renderer2);
-//
-// plot.setDatasetRenderingOrder(DatasetRenderingOrder.FORWARD);
-//
-// return chart;
-// }
-//
-// public static INDArray generatePointsOnGraph(double xMin, double xMax, double yMin, double yMax, int nPointsPerAxis) {
-// //generate all the x,y points
-// double[][] evalPoints = new double[nPointsPerAxis * nPointsPerAxis][2];
-// int count = 0;
-// for (int i = 0; i < nPointsPerAxis; i++) {
-// for (int j = 0; j < nPointsPerAxis; j++) {
-// double x = i * (xMax - xMin) / (nPointsPerAxis - 1) + xMin;
-// double y = j * (yMax - yMin) / (nPointsPerAxis - 1) + yMin;
-//
-// evalPoints[count][0] = x;
-// evalPoints[count][1] = y;
-//
-// count++;
-// }
-// }
-//
-// return Nd4j.create(evalPoints);
-// }
-//
-// /**
-// * This is to collect all the data and return it as one minibatch. Obviously only for use here with small datasets
-// * @param iter
-// * @return
-// */
-// private static DataSet allBatches(DataSetIterator iter) {
-//
-// List fullSet = new ArrayList<>();
-// iter.reset();
-// while (iter.hasNext()) {
-// List miniBatchList = iter.next().asList();
-// fullSet.addAll(miniBatchList);
-// }
-// iter.reset();
-// return new ListDataSetIterator<>(fullSet,fullSet.size()).next();
-// }
-//
-//}
diff --git a/dnet-feature-extraction/src/main/java/eu/dnetlib/example/Example.scala b/dnet-feature-extraction/src/main/java/eu/dnetlib/example/Example.scala
new file mode 100644
index 0000000..38fe366
--- /dev/null
+++ b/dnet-feature-extraction/src/main/java/eu/dnetlib/example/Example.scala
@@ -0,0 +1,76 @@
+//package eu.dnetlib.example
+//
+//import com.intel.analytics.bigdl.dllib.NNContext
+//import com.intel.analytics.bigdl.dllib.keras.Model
+//import com.intel.analytics.bigdl.dllib.keras.models.Models
+//import com.intel.analytics.bigdl.dllib.keras.optimizers.Adam
+//import com.intel.analytics.bigdl.dllib.nn.ClassNLLCriterion
+//import com.intel.analytics.bigdl.dllib.utils.Shape
+//import com.intel.analytics.bigdl.dllib.keras.layers._
+//import com.intel.analytics.bigdl.numeric.NumericFloat
+//import org.apache.spark.ml.feature.VectorAssembler
+//import org.apache.spark._
+//import org.apache.spark.sql.{SQLContext, SparkSession}
+//import org.apache.spark.sql.functions._
+//import org.apache.spark.sql.types.DoubleType
+//object Example {
+//
+//
+// def main(args: Array[String]): Unit = {
+//
+// val conf = new SparkConf().setMaster("local[2]").setAppName("dllib_demo")
+// val sc = NNContext.initNNContext(conf)
+//
+//// val spark = new SQLContext(sc) //deprecated
+// val spark = SparkSession
+// .builder()
+// .config(sc.getConf)
+// .getOrCreate()
+//
+// val path = "/Users/miconis/Desktop/example_dataset.csv"
+// val df = spark.read.options(Map("inferSchema"->"true","delimiter"->",")).csv(path)
+// .toDF("num_times_pregrant", "plasma_glucose", "blood_pressure", "skin_fold_thickness", "2-hour_insulin", "body_mass_index", "diabetes_pedigree_function", "age", "class")
+//
+// val assembler = new VectorAssembler()
+// .setInputCols(Array("num_times_pregrant", "plasma_glucose", "blood_pressure", "skin_fold_thickness", "2-hour_insulin", "body_mass_index", "diabetes_pedigree_function", "age"))
+// .setOutputCol("features")
+// val assembleredDF = assembler.transform(df)
+// val df2 = assembleredDF.withColumn("label", col("class").cast(DoubleType) + lit(1))
+//
+// val Array(trainDF, valDF) = df2.randomSplit(Array(0.8, 0.2))
+//
+// val x1 = Input(Shape(8))
+// val merge = Merge.merge(inputs = List(x1, x1), mode = "dot")
+// val dense1 = Dense(12, activation="relu").inputs(x1)
+// val dense2 = Dense(8, activation="relu").inputs(dense1)
+// val dense3 = Dense(2, activation="relu").inputs(dense2)
+// val dmodel = Model(x1, dense3)
+//
+// dmodel.compile(optimizer = new Adam(), loss = ClassNLLCriterion())
+//
+//
+// //training
+// dmodel.fit(x = trainDF, batchSize = 4, nbEpoch = 2, featureCols = Array("features"), labelCols = Array("label"), valX = valDF)
+//
+//
+//// //save model
+//// val modelPath = "/tmp/demo/keras.model"
+//// dmodel.saveModel(modelPath)
+////
+////
+//// //load model
+//// val loadModel = Models.loadModel(modelPath)
+// val loadModel = dmodel
+//
+// //inference
+// val preDF2 = loadModel.predict(valDF, featureCols = Array("features"), predictionCol = "predict")
+//
+// preDF2.show(false)
+//
+// //evaluation
+// val ret = dmodel.evaluate(trainDF, batchSize = 4, featureCols = Array("features"), labelCols = Array("label"))
+//
+// ret.foreach(println)
+//
+// }
+//}
diff --git a/dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/FeatureTransformer.java b/dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/FeatureTransformer.java
index e55a99a..1f5c8a2 100644
--- a/dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/FeatureTransformer.java
+++ b/dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/FeatureTransformer.java
@@ -1,5 +1,11 @@
package eu.dnetlib.featureextraction;
+import com.google.common.collect.Lists;
+import com.johnsnowlabs.nlp.*;
+import com.johnsnowlabs.nlp.annotators.sbd.pragmatic.SentenceDetector;
+import com.johnsnowlabs.nlp.embeddings.BertSentenceEmbeddings;
+import org.apache.spark.ml.Pipeline;
+import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.clustering.LDA;
import org.apache.spark.ml.clustering.LDAModel;
import org.apache.spark.ml.feature.CountVectorizer;
@@ -8,16 +14,20 @@ import org.apache.spark.ml.feature.StopWordsRemover;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
import scala.Tuple2;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Serializable;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.Map;
-import java.util.Set;
+import java.net.URISyntaxException;
+import java.nio.file.Paths;
+import java.util.*;
public class FeatureTransformer implements Serializable {
@@ -161,4 +171,5 @@ public class FeatureTransformer implements Serializable {
public static Dataset ldaInference(Dataset inputDS, LDAModel ldaModel) {
return ldaModel.transform(inputDS).select(ID_COL, LDA_INFERENCE_OUTPUT_COL);
}
+
}
diff --git a/dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/ScalaFeatureTransformer.scala b/dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/ScalaFeatureTransformer.scala
new file mode 100644
index 0000000..1bcf09b
--- /dev/null
+++ b/dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/ScalaFeatureTransformer.scala
@@ -0,0 +1,157 @@
+package eu.dnetlib.featureextraction
+
+import com.johnsnowlabs.nlp.EmbeddingsFinisher
+import com.johnsnowlabs.nlp.annotator.SentenceDetector
+import com.johnsnowlabs.nlp.annotators.Tokenizer
+import com.johnsnowlabs.nlp.base.DocumentAssembler
+import com.johnsnowlabs.nlp.embeddings.{BertEmbeddings, BertSentenceEmbeddings, WordEmbeddingsModel}
+import org.apache.spark.ml.Pipeline
+import org.apache.spark.sql.functions.{array, col, explode}
+import org.apache.spark.sql.{Dataset, Row}
+
+import java.nio.file.Paths
+
+object ScalaFeatureTransformer {
+
+ val DOCUMENT_COL = "document"
+ val SENTENCE_COL = "sentence"
+ val BERT_SENTENCE_EMBEDDINGS_COL = "bert_sentence"
+ val BERT_EMBEDDINGS_COL = "bert"
+ val TOKENIZER_COL = "tokens"
+ val WORD_EMBEDDINGS_COL = "word"
+
+ //models path
+ private val bertSentenceModelPath = Paths.get(getClass.getResource("/eu/dnetlib/featureextraction/support/sent_small_bert_L6_512_en_2.6.0_2.4_1598350624049").toURI).toFile.getAbsolutePath
+ private val bertModelPath = Paths.get(getClass.getResource("/eu/dnetlib/featureextraction/support/small_bert_L2_128_en_2.6.0_2.4_1598344320681").toURI).toFile.getAbsolutePath
+ private val wordModelPath = Paths.get(getClass.getResource("/eu/dnetlib/featureextraction/support/glove_100d_en_2.4.0_2.4_1579690104032").toURI).toFile.getAbsolutePath
+
+ /**
+ * Extract the SentenceBERT embeddings for the given field.
+ *
+ * @param inputData: the input data
+ * @param inputField: the input field
+ * @return the dataset with the embeddings
+ * */
+ def bertSentenceEmbeddings(inputData: Dataset[Row], inputField: String, modelPath: String): Dataset[Row] = {
+
+ val documentAssembler = new DocumentAssembler()
+ .setInputCol(inputField)
+ .setOutputCol(DOCUMENT_COL)
+
+ val sentence = new SentenceDetector()
+ .setInputCols(DOCUMENT_COL)
+ .setOutputCol(SENTENCE_COL)
+
+ val bertSentenceEmbeddings = BertSentenceEmbeddings
+ .load(modelPath)
+ .setInputCols(SENTENCE_COL)
+ .setOutputCol("raw_" + BERT_SENTENCE_EMBEDDINGS_COL)
+ .setCaseSensitive(false)
+
+ val bertSentenceEmbeddingsFinisher = new EmbeddingsFinisher()
+ .setInputCols("raw_" + BERT_SENTENCE_EMBEDDINGS_COL)
+ .setOutputCols(BERT_SENTENCE_EMBEDDINGS_COL)
+ .setOutputAsVector(true)
+ .setCleanAnnotations(false)
+
+ val pipeline = new Pipeline()
+ .setStages(Array(
+ documentAssembler,
+ sentence,
+ bertSentenceEmbeddings,
+ bertSentenceEmbeddingsFinisher
+ ))
+
+ val result = pipeline.fit(inputData).transform(inputData).withColumn(BERT_SENTENCE_EMBEDDINGS_COL, explode(col(BERT_SENTENCE_EMBEDDINGS_COL)))
+
+ result
+ }
+
+ /**
+ * Extract the BERT embeddings for the given field.
+ *
+ * @param inputData : the input data
+ * @param inputField : the input field
+ * @return the dataset with the embeddings
+ * */
+ def bertEmbeddings(inputData: Dataset[Row], inputField: String, modelPath: String): Dataset[Row] = {
+
+ val documentAssembler = new DocumentAssembler()
+ .setInputCol(inputField)
+ .setOutputCol(DOCUMENT_COL)
+
+ val tokenizer = new Tokenizer()
+ .setInputCols(DOCUMENT_COL)
+ .setOutputCol(TOKENIZER_COL)
+
+ val bertEmbeddings = BertEmbeddings
+ .load(modelPath)
+ .setInputCols(TOKENIZER_COL, DOCUMENT_COL)
+ .setOutputCol("raw_" + BERT_EMBEDDINGS_COL)
+ .setCaseSensitive(false)
+
+ val bertEmbeddingsFinisher = new EmbeddingsFinisher()
+ .setInputCols("raw_" + BERT_EMBEDDINGS_COL)
+ .setOutputCols(BERT_EMBEDDINGS_COL)
+ .setOutputAsVector(true)
+ .setCleanAnnotations(false)
+
+ val pipeline = new Pipeline()
+ .setStages(Array(
+ documentAssembler,
+ tokenizer,
+ bertEmbeddings,
+ bertEmbeddingsFinisher
+ ))
+
+ val result = pipeline.fit(inputData).transform(inputData).withColumn(BERT_EMBEDDINGS_COL, explode(col(BERT_EMBEDDINGS_COL)))
+
+ result
+ }
+
+ /**
+ * Extract the Word2Vec embeddings for the given field.
+ *
+ * @param inputData : the input data
+ * @param inputField : the input field
+ * @return the dataset with the embeddings
+ * */
+ def wordEmbeddings(inputData: Dataset[Row], inputField: String, modelPath: String): Dataset[Row] = {
+
+ val documentAssembler = new DocumentAssembler()
+ .setInputCol(inputField)
+ .setOutputCol(DOCUMENT_COL)
+
+ val tokenizer = new Tokenizer()
+ .setInputCols(DOCUMENT_COL)
+ .setOutputCol(TOKENIZER_COL)
+
+ val wordEmbeddings = WordEmbeddingsModel
+ .load(modelPath)
+ .setInputCols(DOCUMENT_COL, TOKENIZER_COL)
+ .setOutputCol("raw_" + WORD_EMBEDDINGS_COL)
+
+ val wordEmbeddingsFinisher = new EmbeddingsFinisher()
+ .setInputCols("raw_" + WORD_EMBEDDINGS_COL)
+ .setOutputCols(WORD_EMBEDDINGS_COL)
+ .setOutputAsVector(true)
+ .setCleanAnnotations(false)
+
+ val pipeline = new Pipeline()
+ .setStages(Array(
+ documentAssembler,
+ tokenizer,
+ wordEmbeddings,
+ wordEmbeddingsFinisher
+ ))
+
+ val result = pipeline.fit(inputData).transform(inputData).withColumn(WORD_EMBEDDINGS_COL, explode(col(WORD_EMBEDDINGS_COL)))
+
+ result
+ }
+
+ //bert on the title
+ //bert sentence: on the abstract
+ //word2vec: on the subjects
+
+}
diff --git a/dnet-feature-extraction/src/main/java/eu/dnetlib/support/Author.java b/dnet-feature-extraction/src/main/java/eu/dnetlib/support/Author.java
index 9d65a7d..aee3b1e 100644
--- a/dnet-feature-extraction/src/main/java/eu/dnetlib/support/Author.java
+++ b/dnet-feature-extraction/src/main/java/eu/dnetlib/support/Author.java
@@ -5,6 +5,7 @@ import org.codehaus.jackson.annotate.JsonIgnore;
import java.io.Serializable;
import java.util.List;
+import java.util.Map;
public class Author implements Serializable {
@@ -12,24 +13,32 @@ public class Author implements Serializable {
public String firstname;
public String lastname;
public List coAuthors;
- public double[] topics;
public String orcid;
public String id;
+ public Map embeddings;
- public String pubId;
-
- public Author() {
+ public Map getEmbeddings() {
+ return embeddings;
}
- public Author(String fullname, String firstname, String lastname, List coAuthors, double[] topics, String id, String pubId, String orcid) {
+ public Author(String fullname, String firstname, String lastname, List coAuthors, String orcid, String id, Map embeddings, String pubId) {
this.fullname = fullname;
this.firstname = firstname;
this.lastname = lastname;
this.coAuthors = coAuthors;
- this.topics = topics;
- this.id = id;
- this.pubId = pubId;
this.orcid = orcid;
+ this.id = id;
+ this.embeddings = embeddings;
+ this.pubId = pubId;
+ }
+
+ public void setEmbeddings(Map embeddings) {
+ this.embeddings = embeddings;
+ }
+
+ public String pubId;
+
+ public Author() {
}
public String getFullname() {
@@ -64,14 +73,6 @@ public class Author implements Serializable {
this.coAuthors = coAuthors;
}
- public double[] getTopics() {
- return topics;
- }
-
- public void setTopics(double[] topics) {
- this.topics = topics;
- }
-
public String getId() {
return id;
}
diff --git a/dnet-feature-extraction/src/main/java/eu/dnetlib/support/AuthorsFactory.java b/dnet-feature-extraction/src/main/java/eu/dnetlib/support/AuthorsFactory.java
index 78e6e66..6540f4e 100644
--- a/dnet-feature-extraction/src/main/java/eu/dnetlib/support/AuthorsFactory.java
+++ b/dnet-feature-extraction/src/main/java/eu/dnetlib/support/AuthorsFactory.java
@@ -14,34 +14,32 @@ import javax.rmi.CORBA.Util;
import java.math.BigInteger;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Iterator;
-import java.util.List;
+import java.util.*;
import java.util.stream.Collectors;
public class AuthorsFactory {
- public static JavaRDD extractAuthorsFromPublications(JavaRDD entities, JavaPairRDD topics) {
+ public static JavaRDD extractAuthorsFromPublications(JavaRDD publications, JavaPairRDD> topics) {
- JavaPairRDD publicationWithTopics = entities.map(x -> new ObjectMapper().configure(DeserializationConfig.Feature.FAIL_ON_UNKNOWN_PROPERTIES, false).readValue(x, Publication.class))
+ //read topics
+ JavaPairRDD> publicationWithEmbeddings = publications
.mapToPair(p -> new Tuple2<>(p.getId(), p))
.join(topics)
.mapToPair(Tuple2::_2);
- return publicationWithTopics.flatMap(p -> createAuthors(p));
+ return publicationWithEmbeddings.flatMap(AuthorsFactory::createAuthors);
}
- public static Iterator createAuthors(Tuple2 publicationWithTopic){
- List baseCoAuthors = publicationWithTopic._1()
+ public static Iterator createAuthors(Tuple2> publicationWithEmbeddings){
+ List baseCoAuthors = publicationWithEmbeddings._1()
.getAuthor()
.stream()
.map(a -> new CoAuthor(a.getFullname(), a.getName()!=null?a.getName():"", a.getSurname()!=null?a.getSurname():"", a.getPid().size()>0? a.getPid().get(0).getValue():""))
.collect(Collectors.toList());
List authors = new ArrayList<>();
- for(eu.dnetlib.dhp.schema.oaf.Author a : publicationWithTopic._1().getAuthor()) {
+ for(eu.dnetlib.dhp.schema.oaf.Author a : publicationWithEmbeddings._1().getAuthor()) {
//prepare orcid
String orcid = a.getPid().size()>0? a.getPid().get(0).getValue() : "";
@@ -50,9 +48,19 @@ public class AuthorsFactory {
coAuthors.remove(new CoAuthor(a.getFullname(), a.getName() != null ? a.getName() : "", a.getSurname() != null ? a.getSurname() : "", a.getPid().size() > 0 ? a.getPid().get(0).getValue() : ""));
//prepare raw author id
- String id = "author::" + getMd5(a.getFullname().concat(publicationWithTopic._1().getId()));
+ String id = "author::" + getMd5(a.getFullname().concat(publicationWithEmbeddings._1().getId()));
- authors.add(new Author(a.getFullname(), a.getName(), a.getSurname(), coAuthors, publicationWithTopic._2().toArray(), id, publicationWithTopic._1().getId(), orcid));
+ //prepare embeddings
+ authors.add(new Author(
+ a.getFullname(),
+ a.getName(),
+ a.getSurname(),
+ coAuthors,
+ orcid,
+ id,
+ publicationWithEmbeddings._2(),
+ publicationWithEmbeddings._1().getId())
+ );
}
return authors.iterator();
diff --git a/dnet-feature-extraction/src/main/java/eu/dnetlib/support/ConnectedComponent.java b/dnet-feature-extraction/src/main/java/eu/dnetlib/support/ConnectedComponent.java
new file mode 100644
index 0000000..4276be6
--- /dev/null
+++ b/dnet-feature-extraction/src/main/java/eu/dnetlib/support/ConnectedComponent.java
@@ -0,0 +1,70 @@
+package eu.dnetlib.support;
+
+import com.google.common.collect.Sets;
+import org.codehaus.jackson.map.ObjectMapper;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.HashSet;
+import java.util.Set;
+
+public class ConnectedComponent implements Serializable {
+
+ private HashSet docs;
+ private String ccId;
+ private HashSet simrels;
+
+ public ConnectedComponent() {
+ }
+
+ public ConnectedComponent(String ccId, Set docs, Set simrels) {
+ this.docs = new HashSet<>(docs);
+ this.ccId = ccId;
+ this.simrels = new HashSet<>(simrels);
+ }
+
+ public ConnectedComponent(Set docs) {
+ this.docs = new HashSet<>(docs);
+ //initialization of id and relations missing
+ }
+
+ public ConnectedComponent(String ccId, Iterable docs, Iterable simrels) {
+ this.ccId = ccId;
+ this.docs = Sets.newHashSet(docs);
+ this.simrels = Sets.newHashSet(simrels);
+ }
+
+ @Override
+ public String toString() {
+ ObjectMapper mapper = new ObjectMapper();
+ try {
+ return mapper.writeValueAsString(this);
+ } catch (IOException e) {
+ throw new RuntimeException("Failed to create Json: ", e);
+ }
+ }
+
+ public Set getDocs() {
+ return docs;
+ }
+
+ public void setDocs(HashSet docs) {
+ this.docs = docs;
+ }
+
+ public String getCcId() {
+ return ccId;
+ }
+
+ public void setCcId(String ccId) {
+ this.ccId = ccId;
+ }
+
+ public void setSimrels(HashSet simrels) {
+ this.simrels = simrels;
+ }
+
+ public HashSet getSimrels() {
+ return simrels;
+ }
+}
diff --git a/dnet-feature-extraction/src/main/java/eu/dnetlib/support/Relation.java b/dnet-feature-extraction/src/main/java/eu/dnetlib/support/Relation.java
new file mode 100644
index 0000000..f741e9d
--- /dev/null
+++ b/dnet-feature-extraction/src/main/java/eu/dnetlib/support/Relation.java
@@ -0,0 +1,52 @@
+package eu.dnetlib.support;
+
+import java.io.Serializable;
+
+public class Relation implements Serializable {
+
+ String source;
+ String target;
+ String type;
+
+ public Relation() {
+ }
+
+ public Relation(String source, String target, String type) {
+ this.source = source;
+ this.target = target;
+ this.type = type;
+ }
+
+ public String getSource() {
+ return source;
+ }
+
+ public void setSource(String source) {
+ this.source = source;
+ }
+
+ public String getTarget() {
+ return target;
+ }
+
+ public void setTarget(String target) {
+ this.target = target;
+ }
+
+ public String getType() {
+ return type;
+ }
+
+ public void setType(String type) {
+ this.type = type;
+ }
+
+ @Override
+ public String toString() {
+ return "Relation{" +
+ "source='" + source + '\'' +
+ ", target='" + target + '\'' +
+ ", type='" + type + '\'' +
+ '}';
+ }
+}
diff --git a/dnet-feature-extraction/src/test/java/UtilityTest.java b/dnet-feature-extraction/src/test/java/UtilityTest.java
index 386a41b..6b59e29 100644
--- a/dnet-feature-extraction/src/test/java/UtilityTest.java
+++ b/dnet-feature-extraction/src/test/java/UtilityTest.java
@@ -7,6 +7,7 @@ import org.junit.jupiter.api.Test;
import java.lang.annotation.Target;
import java.util.ArrayList;
+import java.util.HashMap;
public class UtilityTest {
@@ -24,7 +25,7 @@ public class UtilityTest {
@Test
public void lnfiTest() throws Exception {
- Author a = new Author("De Bonis, Michele", "Æ", "De Bonis", new ArrayList(), new double[]{0.0, 1.0}, "author::id", "pub::id", "orcid");
+ Author a = new Author("De Bonis, Michele", "Æ", "De Bonis", new ArrayList(), "orcid", "author::id", new HashMap(), "pub::id");
System.out.println("a = " + a.isAccurate());
System.out.println(AuthorsFactory.getLNFI(a));
}
diff --git a/dnet-feature-extraction/src/test/java/eu/dnetlib/deeplearning/DataSetProcessorTest.java b/dnet-feature-extraction/src/test/java/eu/dnetlib/deeplearning/DataSetProcessorTest.java
deleted file mode 100644
index 0b0f2c7..0000000
--- a/dnet-feature-extraction/src/test/java/eu/dnetlib/deeplearning/DataSetProcessorTest.java
+++ /dev/null
@@ -1,47 +0,0 @@
-package eu.dnetlib.deeplearning;
-
-import com.beust.jcommander.internal.Sets;
-import com.google.common.collect.Lists;
-import eu.dnetlib.deeplearning.support.DataSetProcessor;
-import eu.dnetlib.support.Relation;
-import org.junit.jupiter.api.BeforeAll;
-import org.junit.jupiter.api.Test;
-import org.nd4j.linalg.dataset.MultiDataSet;
-
-import java.util.*;
-import java.util.stream.Collectors;
-
-public class DataSetProcessorTest {
-
- static Map features;
- static Set relations;
- static List groundTruth;
-
- @BeforeAll
- public static void init(){
- //initialize example features
- features = new HashMap<>();
- features.put("0", new double[]{0.0,0.0});
- features.put("1", new double[]{1.0,1.0});
- features.put("2", new double[]{2.0,2.0});
-
- //initialize example relations
- relations = new HashSet<>(Lists.newArrayList(
- new Relation("0", "1", "simrel"),
- new Relation("1", "2", "simrel")
- ));
-
- //initialize example ground truth
- groundTruth = Lists.newArrayList("class1", "class1", "class2");
-
- }
-
- @Test
- public void getMultiDataSetTest() throws Exception {
- MultiDataSet multiDataSet = DataSetProcessor.getMultiDataSet(features, relations, groundTruth);
- System.out.println("multiDataSet = " + multiDataSet);
-
- multiDataSet.asList();
- }
-
-}
diff --git a/dnet-feature-extraction/src/test/java/eu/dnetlib/deeplearning/NetworkConfigurationTests.java b/dnet-feature-extraction/src/test/java/eu/dnetlib/deeplearning/NetworkConfigurationTests.java
deleted file mode 100644
index ad791cb..0000000
--- a/dnet-feature-extraction/src/test/java/eu/dnetlib/deeplearning/NetworkConfigurationTests.java
+++ /dev/null
@@ -1,33 +0,0 @@
-package eu.dnetlib.deeplearning;
-
-import eu.dnetlib.deeplearning.support.NetworkConfigurations;
-import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
-import org.deeplearning4j.nn.graph.ComputationGraph;
-import org.junit.jupiter.api.Test;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.factory.Nd4j;
-
-public class NetworkConfigurationTests {
-
- public final static int N = 3; //number of nodes
- public final static int K = 7; //number of features
-
- public static INDArray[] exampleGraph = new INDArray[]{
- Nd4j.zeros(N, K), //features
- Nd4j.ones(N, N), //adjacency
- Nd4j.ones(N, N) //degree
- };
-
- @Test
- public void simpleGCNTest() {
-
- ComputationGraphConfiguration simpleGCNConf = NetworkConfigurations.getSimpleGCN(3, K, 5, 2);
- ComputationGraph simpleGCN = new ComputationGraph(simpleGCNConf);
- simpleGCN.init();
-
- INDArray[] output = simpleGCN.output(exampleGraph);
- System.out.println("output = " + output[0]);
-
- }
-
-}
diff --git a/dnet-feature-extraction/src/test/java/eu/dnetlib/deeplearning/featureextraction/FeatureTransformerTest.java b/dnet-feature-extraction/src/test/java/eu/dnetlib/deeplearning/featureextraction/FeatureTransformerTest.java
new file mode 100644
index 0000000..6a1534e
--- /dev/null
+++ b/dnet-feature-extraction/src/test/java/eu/dnetlib/deeplearning/featureextraction/FeatureTransformerTest.java
@@ -0,0 +1,53 @@
+package eu.dnetlib.deeplearning.featureextraction;
+
+import eu.dnetlib.featureextraction.ScalaFeatureTransformer;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.linalg.DenseVector;
+import org.apache.spark.ml.linalg.DenseVector$;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import scala.collection.JavaConversions;
+import scala.collection.mutable.WrappedArray;
+
+import javax.xml.crypto.Data;
+import java.io.IOException;
+import java.util.Arrays;
+
+public class FeatureTransformerTest {
+
+ static SparkSession spark;
+ static JavaSparkContext context;
+ static Dataset inputData;
+ static StructType inputSchema = new StructType(new StructField[]{
+ new StructField("title", DataTypes.StringType, false, Metadata.empty()),
+ new StructField("abstract", DataTypes.StringType, false, Metadata.empty())
+ });
+
+ @BeforeAll
+ public static void setup() throws IOException {
+
+ spark = SparkSession
+ .builder()
+ .appName("Testing")
+ .master("local[*]")
+ .getOrCreate();
+
+ context = JavaSparkContext.fromSparkContext(spark.sparkContext());
+
+ inputData = spark.createDataFrame(Arrays.asList(
+ RowFactory.create("article title 1", "article description 1"),
+ RowFactory.create("article title 2", "article description 2")
+ ), inputSchema);
+
+ }
+
+}
diff --git a/pom.xml b/pom.xml
index 82b42eb..61f468f 100644
--- a/pom.xml
+++ b/pom.xml
@@ -94,6 +94,8 @@
false
+
+
target
@@ -305,31 +307,31 @@
com.fasterxml.jackson.core
jackson-databind
${jackson.version}
- provided
+
com.fasterxml.jackson.dataformat
jackson-dataformat-xml
${jackson.version}
- provided
+
com.fasterxml.jackson.module
jackson-module-jsonSchema
${jackson.version}
- provided
+
com.fasterxml.jackson.core
jackson-core
${jackson.version}
- provided
+
com.fasterxml.jackson.core
jackson-annotations
${jackson.version}
- provided
+
@@ -388,25 +390,25 @@
org.apache.spark
spark-core_2.11
${spark.version}
- provided
+
org.apache.spark
spark-graphx_2.11
${spark.version}
- provided
+
org.apache.spark
spark-sql_2.11
${spark.version}
- provided
+
org.apache.spark
spark-mllib_2.11
${spark.version}
- provided
+
org.junit.jupiter
@@ -451,92 +453,79 @@
-
- org.nd4j
- ${nd4j.backend}
- ${dl4j-master.version}
-
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
- org.datavec
- datavec-api
- ${dl4j-master.version}
-
-
- org.datavec
- datavec-data-image
- ${dl4j-master.version}
-
-
- org.datavec
- datavec-local
- ${dl4j-master.version}
-
-
- org.deeplearning4j
- deeplearning4j-datasets
- ${dl4j-master.version}
-
-
- org.deeplearning4j
- deeplearning4j-core
- ${dl4j-master.version}
+ com.johnsnowlabs.nlp
+ spark-nlp_${scala.binary.version}
+ 2.7.5
-
- org.deeplearning4j
- resources
- ${dl4j-master.version}
-
-
-
- org.deeplearning4j
- deeplearning4j-ui
- ${dl4j-master.version}
-
-
- org.deeplearning4j
- deeplearning4j-zoo
- ${dl4j-master.version}
-
-
- org.deeplearning4j
- dl4j-spark-parameterserver_2.11
- ${dl4j-master.version}
-
-
- org.deeplearning4j
- dl4j-spark_2.11
- ${dl4j-master.version}
-
-
-
-
- jfree
- jfreechart
- 1.0.13
-
-
- org.jfree
- jcommon
- 1.0.23
-
-
- org.deeplearning4j
- deeplearning4j-datasets
- ${dl4j-master.version}
-
-
-
-
- eu.dnetlib
- dnet-dedup-test
- 4.1.13-SNAPSHOT
-
+
+
+
+
+
-