Refactoring and implementation of the Scala feature extractor
This commit is contained in:
parent 4c9d33171d
commit 12fb11c8c1
@@ -1,2 +1,2 @@
-# Thu Apr 13 16:22:22 CEST 2023
+# Thu Apr 27 21:12:07 CEST 2023
 projectPropertyKey=projectPropertyValue
@@ -11,11 +11,31 @@
 #outputModelPath = /user/michele.debonis/lda_experiments/lda_dewey2.model
 
 #LDA INFERENCE
+#numPartitions = 1000
+#inputFieldJPath = $.description[0].value
+#vocabularyPath = /user/michele.debonis/lda_experiments/dewey_vocabulary
+#entitiesPath = /tmp/publications_with_pid_pubmed
+#workingPath = /user/michele.debonis/lda_experiments/lda_inference_working_dir
+#ldaInferencePath = /user/michele.debonis/lda_experiments/publications_pubmed_topics
+#ldaModelPath = /user/michele.debonis/lda_experiments/lda_dewey.model
+#authorsPath = /user/michele.debonis/lda_experiments/authors_pubmed
+
+#GNN TRAINING
+#groupsPath = /user/michele.debonis/authors_dedup/gt_dedup/groupentities
+#workingPath = /user/michele.debonis/gnn_experiments
+#numPartitions = 1000
+#numEpochs = 100
+#groundTruthJPath = $.orcid
+#idJPath = $.id
+#featuresJPath = $.topics
+
+#FEATURE EXTRACTION
+publicationsPath = /tmp/publications_with_pid_pubmed
+workingPath = /user/michele.debonis/feature_extraction
 numPartitions = 1000
-inputFieldJPath = $.description[0].value
-vocabularyPath = /user/michele.debonis/lda_experiments/dewey_vocabulary
-entitiesPath = /tmp/publications_with_pid_pubmed
-workingPath = /user/michele.debonis/lda_experiments/lda_inference_working_dir
-ldaInferencePath = /user/michele.debonis/lda_experiments/publications_pubmed_topics
-ldaModelPath = /user/michele.debonis/lda_experiments/lda_dewey.model
-authorsPath = /user/michele.debonis/lda_experiments/authors_pubmed
+featuresPath = /user/michele.debonis/feature_extraction/publications_pubmed_features
+topicsPath = /user/michele.debonis/lda_experiments/publications_pubmed_topics
+outputPath = /user/michele.debonis/feature_extraction/authors_pubmed
+wordEmbeddingsModel = /user/michele.debonis/nlp_models/glove_100d_en_2.4.0_2.4_1579690104032
+bertSentenceModel = /user/michele.debonis/nlp_models/sent_small_bert_L6_512_en_2.6.0_2.4_1598350624049
+bertModel = /user/michele.debonis/nlp_models/small_bert_L2_128_en_2.6.0_2.4_1598344320681
@@ -131,23 +131,11 @@
 		<artifactId>json-path</artifactId>
 	</dependency>
 
 	<dependency>
 		<groupId>eu.dnetlib.dhp</groupId>
 		<artifactId>dhp-schemas</artifactId>
 	</dependency>
 
-<!-- <dependency>-->
-<!-- <groupId>org.mockito</groupId>-->
-<!-- <artifactId>mockito-core</artifactId>-->
-<!-- <scope>test</scope>-->
-<!-- </dependency>-->
-<!-- <dependency>-->
-<!-- <groupId>org.mockito</groupId>-->
-<!-- <artifactId>mockito-junit-jupiter</artifactId>-->
-<!-- <scope>test</scope>-->
-<!-- </dependency>-->
-
-
 	</dependencies>
 
 	<profiles>
@@ -36,7 +36,7 @@ public abstract class AbstractSparkJob implements Serializable {
         this.spark = spark;
     }
 
-    protected abstract void run() throws IOException;
+    protected abstract void run() throws IOException, InterruptedException;
 
     protected static SparkSession getSparkSession(SparkConf conf) {
         return SparkSession.builder().config(conf).getOrCreate();
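All job classes touched by this commit follow the same life cycle around this base class. A minimal sketch of a concrete job, assuming only members visible elsewhere in this diff (parser, spark, getSparkSession, readResource); the class name and parameter file here are hypothetical:

    import eu.dnetlib.support.ArgumentApplicationParser;
    import org.apache.spark.SparkConf;
    import org.apache.spark.sql.SparkSession;

    import java.io.IOException;

    // Hypothetical subclass illustrating the pattern: parse the oozie-style
    // parameter file, build a session, then run(). The widened signature now
    // also admits InterruptedException.
    public class SparkNoopJob extends AbstractSparkJob {

        public SparkNoopJob(ArgumentApplicationParser parser, SparkSession spark) {
            super(parser, spark);
        }

        public static void main(String[] args) throws Exception {
            ArgumentApplicationParser parser = new ArgumentApplicationParser(
                    readResource("/jobs/parameters/noopJob_parameters.json", SparkNoopJob.class));
            parser.parseArgument(args);
            new SparkNoopJob(parser, getSparkSession(new SparkConf())).run();
        }

        @Override
        public void run() throws IOException, InterruptedException {
            // a real job reads its oozie parameters here and then does the work
            final String workingPath = parser.get("workingPath");
        }
    }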
@@ -2,7 +2,6 @@ package eu.dnetlib.jobs;
 
 import eu.dnetlib.featureextraction.FeatureTransformer;
 import eu.dnetlib.support.ArgumentApplicationParser;
-import org.apache.hadoop.fs.shell.Count;
 import org.apache.spark.SparkConf;
 import org.apache.spark.ml.feature.CountVectorizerModel;
 import org.apache.spark.sql.Dataset;
@@ -9,8 +9,6 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
-import java.net.URISyntaxException;
-import java.nio.file.Paths;
 import java.util.Optional;
 
 public class SparkCreateVocabulary extends AbstractSparkJob{
@@ -1,75 +0,0 @@
-package eu.dnetlib.jobs.deeplearning;
-
-import eu.dnetlib.deeplearning.support.DataSetProcessor;
-import eu.dnetlib.jobs.AbstractSparkJob;
-import eu.dnetlib.jobs.SparkLDATuning;
-import eu.dnetlib.support.ArgumentApplicationParser;
-import eu.dnetlib.support.ConnectedComponent;
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.SparkSession;
-import org.codehaus.jackson.map.ObjectMapper;
-import org.deeplearning4j.spark.data.BatchAndExportDataSetsFunction;
-import org.deeplearning4j.spark.data.BatchAndExportMultiDataSetsFunction;
-import org.deeplearning4j.spark.datavec.iterator.IteratorUtils;
-import org.nd4j.linalg.dataset.MultiDataSet;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.io.IOException;
-import java.util.List;
-import java.util.Optional;
-
-public class SparkCreateGroupDataSet extends AbstractSparkJob {
-
-    private static final Logger log = LoggerFactory.getLogger(SparkCreateGroupDataSet.class);
-
-    public SparkCreateGroupDataSet(ArgumentApplicationParser parser, SparkSession spark) {
-        super(parser, spark);
-    }
-    public static void main(String[] args) throws Exception {
-        ArgumentApplicationParser parser = new ArgumentApplicationParser(
-                readResource("/jobs/parameters/createGroupDataset_parameters.json", SparkLDATuning.class)
-        );
-
-        parser.parseArgument(args);
-
-        SparkConf conf = new SparkConf();
-
-        new SparkCreateGroupDataSet(
-                parser,
-                getSparkSession(conf)
-        ).run();
-    }
-
-    @Override
-    public void run() throws IOException {
-        // read oozie parameters
-        final String groupsPath = parser.get("groupsPath");
-        final String workingPath = parser.get("workingPath");
-        final String groundTruthJPath = parser.get("groundTruthJPath");
-        final String idJPath = parser.get("idJPath");
-        final String featuresJPath = parser.get("featuresJPath");
-        final int numPartitions = Optional
-                .ofNullable(parser.get("numPartitions"))
-                .map(Integer::valueOf)
-                .orElse(NUM_PARTITIONS);
-
-        log.info("groupsPath: '{}'", groupsPath);
-        log.info("workingPath: '{}'", workingPath);
-        log.info("groundTruthJPath: '{}'", groundTruthJPath);
-        log.info("idJPath: '{}'", idJPath);
-        log.info("featuresJPath: '{}'", featuresJPath);
-        log.info("numPartitions: '{}'", numPartitions);
-
-        JavaSparkContext context = JavaSparkContext.fromSparkContext(spark.sparkContext());
-
-        JavaRDD<ConnectedComponent> groups = context.textFile(groupsPath).map(g -> new ObjectMapper().readValue(g, ConnectedComponent.class));
-
-        JavaRDD<MultiDataSet> dataset = DataSetProcessor.entityGroupToMultiDataset(groups, idJPath, featuresJPath, groundTruthJPath);
-
-        dataset.saveAsObjectFile(workingPath + "/groupDataset");
-    }
-
-}
@@ -1,99 +0,0 @@
-package eu.dnetlib.jobs.deeplearning;
-
-import eu.dnetlib.deeplearning.support.DataSetProcessor;
-import eu.dnetlib.deeplearning.support.NetworkConfigurations;
-import eu.dnetlib.jobs.AbstractSparkJob;
-import eu.dnetlib.jobs.SparkLDATuning;
-import eu.dnetlib.support.ArgumentApplicationParser;
-import eu.dnetlib.support.ConnectedComponent;
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.SparkSession;
-import org.codehaus.jackson.map.ObjectMapper;
-import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
-import org.deeplearning4j.nn.graph.ComputationGraph;
-import org.deeplearning4j.optimize.listeners.PerformanceListener;
-import org.deeplearning4j.optimize.solvers.accumulation.encoding.threshold.AdaptiveThresholdAlgorithm;
-import org.deeplearning4j.spark.api.RDDTrainingApproach;
-import org.deeplearning4j.spark.api.TrainingMaster;
-import org.deeplearning4j.spark.impl.graph.SparkComputationGraph;
-import org.deeplearning4j.spark.parameterserver.training.SharedTrainingMaster;
-import org.nd4j.linalg.dataset.api.MultiDataSet;
-import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.io.IOException;
-import java.util.Optional;
-
-public class SparkGraphClassificationTraining extends AbstractSparkJob {
-
-    private static final Logger log = LoggerFactory.getLogger(SparkGraphClassificationTraining.class);
-
-    public SparkGraphClassificationTraining(ArgumentApplicationParser parser, SparkSession spark) {
-        super(parser, spark);
-    }
-    public static void main(String[] args) throws Exception {
-        ArgumentApplicationParser parser = new ArgumentApplicationParser(
-                readResource("/jobs/parameters/graphClassificationTraining_parameters.json", SparkLDATuning.class)
-        );
-
-        parser.parseArgument(args);
-
-        SparkConf conf = new SparkConf();
-
-        new SparkGraphClassificationTraining(
-                parser,
-                getSparkSession(conf)
-        ).run();
-    }
-
-    @Override
-    public void run() throws IOException {
-        // read oozie parameters
-        final String workingPath = parser.get("workingPath");
-        final int numPartitions = Optional
-                .ofNullable(parser.get("numPartitions"))
-                .map(Integer::valueOf)
-                .orElse(NUM_PARTITIONS);
-        log.info("workingPath: '{}'", workingPath);
-        log.info("numPartitions: '{}'", numPartitions);
-
-        JavaSparkContext context = JavaSparkContext.fromSparkContext(spark.sparkContext());
-
-        VoidConfiguration conf = VoidConfiguration.builder()
-                .unicastPort(40123)
-                // .networkMask("255.255.148.0/22")
-                .controllerAddress("127.0.0.1")
-                .build();
-
-        TrainingMaster trainingMaster = new SharedTrainingMaster.Builder(conf,1)
-                .rngSeed(12345)
-                .collectTrainingStats(false)
-                .thresholdAlgorithm(new AdaptiveThresholdAlgorithm(1e-3))
-                .batchSizePerWorker(32)
-                .workersPerNode(4)
-                .rddTrainingApproach(RDDTrainingApproach.Direct)
-                .build();
-
-        JavaRDD<MultiDataSet> trainData = context.objectFile(workingPath + "/groupDataset");
-
-        SparkComputationGraph sparkComputationGraph = new SparkComputationGraph(
-                context,
-                NetworkConfigurations.getSimpleGCN(3, 2, 5, 2),
-                trainingMaster);
-        sparkComputationGraph.setListeners(new PerformanceListener(10, true));
-
-        //execute training
-        for (int i = 0; i < 20; i ++) {
-            sparkComputationGraph.fitMultiDataSet(trainData);
-        }
-
-        ComputationGraph network = sparkComputationGraph.getNetwork();
-
-        System.out.println("network = " + network.getConfiguration().toJson());
-
-
-    }
-}
@@ -1,5 +1,8 @@
-package eu.dnetlib.jobs;
+package eu.dnetlib.jobs.featureextraction;
 
+import eu.dnetlib.dhp.schema.oaf.Publication;
+import eu.dnetlib.jobs.AbstractSparkJob;
+import eu.dnetlib.jobs.SparkTokenizer;
 import eu.dnetlib.support.ArgumentApplicationParser;
 import eu.dnetlib.support.Author;
 import eu.dnetlib.support.AuthorsFactory;
@@ -10,15 +13,18 @@ import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.ml.linalg.DenseVector;
 import org.apache.spark.sql.SparkSession;
+import org.codehaus.jackson.map.DeserializationConfig;
 import org.codehaus.jackson.map.ObjectMapper;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import scala.Tuple2;
 
 import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
 import java.util.Optional;
 
-public class SparkAuthorExtractor extends AbstractSparkJob{
+public class SparkAuthorExtractor extends AbstractSparkJob {
     private static final Logger log = LoggerFactory.getLogger(SparkAuthorExtractor.class);
 
     public SparkAuthorExtractor(ArgumentApplicationParser parser, SparkSession spark) {
@@ -45,7 +51,8 @@ public class SparkAuthorExtractor extends AbstractSparkJob{
     public void run() throws IOException {
         // read oozie parameters
         final String topicsPath = parser.get("topicsPath");
-        final String entitiesPath = parser.get("entitiesPath");
+        final String featuresPath = parser.get("featuresPath");
+        final String publicationsPath = parser.get("publicationsPath");
         final String workingPath = parser.get("workingPath");
         final String outputPath = parser.get("outputPath");
         final int numPartitions = Optional
@@ -53,8 +60,9 @@ public class SparkAuthorExtractor extends AbstractSparkJob{
             .map(Integer::valueOf)
             .orElse(NUM_PARTITIONS);
 
-        log.info("entitiesPath: '{}'", entitiesPath);
+        log.info("publicationsPath: '{}'", publicationsPath);
         log.info("topicsPath: '{}'", topicsPath);
+        log.info("featuresPath: '{}'", featuresPath);
         log.info("workingPath: '{}'", workingPath);
         log.info("outputPath: '{}'", outputPath);
         log.info("numPartitions: '{}'", numPartitions);
@@ -62,16 +70,34 @@ public class SparkAuthorExtractor extends AbstractSparkJob{
         //join publications with topics
         JavaSparkContext context = JavaSparkContext.fromSparkContext(spark.sparkContext());
 
-        JavaRDD<String> entities = context.textFile(entitiesPath);
+        JavaRDD<Publication> publications = context
+                .textFile(publicationsPath)
+                .map(x -> new ObjectMapper()
+                        .configure(DeserializationConfig.Feature.FAIL_ON_UNKNOWN_PROPERTIES, false)
+                        .readValue(x, Publication.class));
 
-        JavaPairRDD<String, DenseVector> topics = spark.read().load(topicsPath).toJavaRDD()
-                .mapToPair(t -> new Tuple2<>(t.getString(0), (DenseVector) t.get(1)));
+        JavaPairRDD<String, double[]> topics = spark.read().load(topicsPath).toJavaRDD()
+                .mapToPair(t -> new Tuple2<>(t.getString(0), ((DenseVector) t.get(1)).toArray()));
 
-        JavaRDD<Author> authors = AuthorsFactory.extractAuthorsFromPublications(entities, topics);
+        //merge topics with other embeddings
+        JavaPairRDD<String, Map<String, double[]>> publicationEmbeddings = spark.read().load(featuresPath).toJavaRDD().mapToPair(t -> {
+            Map<String, double[]> embeddings = new HashMap<>();
+            embeddings.put("word_embeddings", ((DenseVector) t.get(1)).toArray());
+            embeddings.put("bert_embeddings", ((DenseVector) t.get(2)).toArray());
+            embeddings.put("bert_sentence_embeddings", ((DenseVector) t.get(3)).toArray());
+            return new Tuple2<>(t.getString(0), embeddings);
+        })
+                .join(topics).mapToPair(e -> {
+                    e._2()._1().put("lda_topics", e._2()._2());
+                    return new Tuple2<>(e._1(), e._2()._1());
+                });
+
+        JavaRDD<Author> authors = AuthorsFactory.extractAuthorsFromPublications(publications, publicationEmbeddings);
 
         authors
                 .map(a -> new ObjectMapper().writeValueAsString(a))
                 .saveAsTextFile(outputPath, GzipCodec.class);
 
     }
 }
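The reworked job builds, per publication id, a map of named embedding vectors and folds the LDA topic vector into it through a pair-RDD join. A toy, local-mode sketch of that merge shape (ids and vectors are fabricated; only the join/merge pattern is taken from the code above):

    import org.apache.spark.api.java.JavaPairRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.sql.SparkSession;
    import scala.Tuple2;

    import java.util.Arrays;
    import java.util.Collections;
    import java.util.HashMap;
    import java.util.Map;

    public class EmbeddingJoinSketch {
        public static void main(String[] args) {
            SparkSession spark = SparkSession.builder()
                    .master("local[1]").appName("embedding-join-sketch").getOrCreate();
            JavaSparkContext sc = JavaSparkContext.fromSparkContext(spark.sparkContext());

            // stand-in for the saved features: id -> named embedding vectors
            Map<String, double[]> m = new HashMap<>();
            m.put("word_embeddings", new double[]{0.1, 0.2});
            JavaPairRDD<String, Map<String, double[]>> features =
                    sc.parallelizePairs(Collections.singletonList(new Tuple2<>("pub1", m)));

            // stand-in for the LDA topics: id -> topic distribution
            JavaPairRDD<String, double[]> topics =
                    sc.parallelizePairs(Collections.singletonList(new Tuple2<>("pub1", new double[]{0.7, 0.3})));

            // same pattern as above: join on the id, then store the topic
            // vector in the map under the "lda_topics" key
            JavaPairRDD<String, Map<String, double[]>> merged = features.join(topics)
                    .mapToPair(e -> {
                        e._2()._1().put("lda_topics", e._2()._2());
                        return new Tuple2<>(e._1(), e._2()._1());
                    });

            System.out.println(Arrays.toString(merged.first()._2().get("lda_topics")));
            spark.stop();
        }
    }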
@@ -0,0 +1,107 @@
+package eu.dnetlib.jobs.featureextraction;
+
+import eu.dnetlib.dhp.schema.oaf.Publication;
+import eu.dnetlib.dhp.schema.oaf.StructuredProperty;
+import eu.dnetlib.featureextraction.ScalaFeatureTransformer;
+import eu.dnetlib.jobs.AbstractSparkJob;
+import eu.dnetlib.support.ArgumentApplicationParser;
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.*;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.codehaus.jackson.map.DeserializationConfig;
+import org.codehaus.jackson.map.ObjectMapper;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+public class SparkPublicationFeatureExtractor extends AbstractSparkJob {
+    private static final Logger log = LoggerFactory.getLogger(SparkPublicationFeatureExtractor.class);
+
+    public SparkPublicationFeatureExtractor(ArgumentApplicationParser parser, SparkSession spark) {
+        super(parser, spark);
+    }
+
+    public static void main(String[] args) throws Exception {
+
+        ArgumentApplicationParser parser = new ArgumentApplicationParser(
+                readResource("/jobs/parameters/publicationFeatureExtractor_parameters.json", SparkPublicationFeatureExtractor.class)
+        );
+
+        parser.parseArgument(args);
+
+        SparkConf conf = new SparkConf();
+
+        new SparkAuthorExtractor(
+                parser,
+                getSparkSession(conf)
+        ).run();
+    }
+
+    @Override
+    public void run() throws IOException {
+        // read oozie parameters
+        final String publicationsPath = parser.get("publicationsPath");
+        final String workingPath = parser.get("workingPath");
+        final String featuresPath = parser.get("featuresPath");
+        final String bertModel = parser.get("bertModelPath");
+        final String bertSentenceModel = parser.get("bertSentenceModel");
+        final String wordEmbeddingModel = parser.get("wordEmbeddingModel");
+        final int numPartitions = Optional
+                .ofNullable(parser.get("numPartitions"))
+                .map(Integer::valueOf)
+                .orElse(NUM_PARTITIONS);
+
+        log.info("publicationsPath: '{}'", publicationsPath);
+        log.info("workingPath: '{}'", workingPath);
+        log.info("numPartitions: '{}'", numPartitions);
+
+        JavaSparkContext context = JavaSparkContext.fromSparkContext(spark.sparkContext());
+
+        JavaRDD<Publication> publications = context
+                .textFile(publicationsPath)
+                .map(x -> new ObjectMapper()
+                        .configure(DeserializationConfig.Feature.FAIL_ON_UNKNOWN_PROPERTIES, false)
+                        .readValue(x, Publication.class));
+
+        StructType inputSchema = new StructType(new StructField[]{
+                new StructField("id", DataTypes.StringType, false, Metadata.empty()),
+                new StructField("title", DataTypes.StringType, false, Metadata.empty()),
+                new StructField("abstract", DataTypes.StringType, false, Metadata.empty()),
+                new StructField("subjects", DataTypes.StringType, false, Metadata.empty())
+        });
+
+        //prepare Rows
+        Dataset<Row> inputData = spark.createDataFrame(
+                publications.map(p -> RowFactory.create(
+                        p.getId(),
+                        p.getTitle().get(0).getValue(),
+                        p.getDescription().size()>0? p.getDescription().get(0).getValue(): "",
+                        p.getSubject().stream().map(StructuredProperty::getValue).collect(Collectors.joining(" ")))),
+                inputSchema);
+
+        log.info("Generating word embeddings");
+        Dataset<Row> wordEmbeddingsData = ScalaFeatureTransformer.wordEmbeddings(inputData, "subjects", wordEmbeddingModel);
+
+        log.info("Generating bert embeddings");
+        Dataset<Row> bertEmbeddingsData = ScalaFeatureTransformer.bertEmbeddings(wordEmbeddingsData, "title", bertModel);
+
+        log.info("Generating bert sentence embeddings");
+        Dataset<Row> bertSentenceEmbeddingsData = ScalaFeatureTransformer.bertSentenceEmbeddings(bertEmbeddingsData, "abstract", bertSentenceModel);
+
+        Dataset<Row> features = bertSentenceEmbeddingsData.select("id", ScalaFeatureTransformer.WORD_EMBEDDINGS_COL(), ScalaFeatureTransformer.BERT_EMBEDDINGS_COL(), ScalaFeatureTransformer.BERT_SENTENCE_EMBEDDINGS_COL());
+
+        features
+                .write()
+                .mode(SaveMode.Overwrite)
+                .save(featuresPath);
+
+    }
+}
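ScalaFeatureTransformer itself is not part of this diff, so the following Java sketch is only a guess at the kind of Spark NLP pipeline its bertEmbeddings helper wraps (the spark-nlp dependency is introduced in the dnet-feature-extraction pom below); every API detail here is an assumption, not a description of the actual class:

    import com.johnsnowlabs.nlp.DocumentAssembler;
    import com.johnsnowlabs.nlp.annotators.Tokenizer;
    import com.johnsnowlabs.nlp.embeddings.BertEmbeddings;
    import org.apache.spark.ml.Pipeline;
    import org.apache.spark.ml.PipelineStage;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;

    public class BertEmbeddingsSketch {

        // assumed shape of ScalaFeatureTransformer.bertEmbeddings(inputData, "title", bertModel)
        public static Dataset<Row> bertEmbeddings(Dataset<Row> inputData, String column, String modelPath) {
            DocumentAssembler document = new DocumentAssembler();
            document.setInputCol(column);
            document.setOutputCol("document");

            Tokenizer tokenizer = new Tokenizer();
            tokenizer.setInputCols(new String[]{"document"});
            tokenizer.setOutputCol("token");

            // offline model directory, e.g. the small_bert_L2_128 path from job.properties
            BertEmbeddings bert = BertEmbeddings.load(modelPath);
            bert.setInputCols(new String[]{"document", "token"});
            bert.setOutputCol("bert_embeddings");

            return new Pipeline()
                    .setStages(new PipelineStage[]{document, tokenizer, bert})
                    .fit(inputData)
                    .transform(inputData);
        }
    }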
@@ -1,7 +1,8 @@
-package eu.dnetlib.jobs;
+package eu.dnetlib.jobs.featureextraction.lda;
 
 import com.clearspring.analytics.util.Lists;
 import eu.dnetlib.featureextraction.Utilities;
+import eu.dnetlib.jobs.AbstractSparkJob;
 import eu.dnetlib.support.ArgumentApplicationParser;
 import eu.dnetlib.support.Author;
 import eu.dnetlib.support.AuthorsFactory;
@@ -108,7 +109,7 @@ public class SparkLDAAnalysis extends AbstractSparkJob {
             else {
                 bRes = authors.get(i).getOrcid().equals(authors.get(j).getOrcid());
             }
-            results.add(new Tuple2<>(bRes, cosineSimilarity(authors.get(i).getTopics(), authors.get(j).getTopics())));
+            results.add(new Tuple2<>(bRes, cosineSimilarity(authors.get(i).getEmbeddings().get("lda_topics"), authors.get(j).getEmbeddings().get("lda_topics"))));
             j++;
         }
         i++;
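cosineSimilarity is defined elsewhere in SparkLDAAnalysis and is not shown in this diff; for reference, the conventional definition over the dense double[] vectors now fetched from getEmbeddings().get("lda_topics") would be a sketch like:

    // standard cosine similarity over dense vectors; assumes equal, non-zero length
    static double cosineSimilarity(double[] a, double[] b) {
        double dot = 0.0, normA = 0.0, normB = 0.0;
        for (int k = 0; k < a.length; k++) {
            dot += a[k] * b[k];
            normA += a[k] * a[k];
            normB += b[k] * b[k];
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }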
@@ -1,6 +1,7 @@
-package eu.dnetlib.jobs;
+package eu.dnetlib.jobs.featureextraction.lda;
 
 import eu.dnetlib.featureextraction.FeatureTransformer;
+import eu.dnetlib.jobs.AbstractSparkJob;
 import eu.dnetlib.support.ArgumentApplicationParser;
 import org.apache.spark.SparkConf;
 import org.apache.spark.ml.clustering.LDAModel;
@@ -15,7 +16,7 @@ import org.slf4j.LoggerFactory;
 import java.io.IOException;
 import java.util.*;
 
-public class SparkLDAInference extends AbstractSparkJob{
+public class SparkLDAInference extends AbstractSparkJob {
 
     private static final Logger log = LoggerFactory.getLogger(SparkLDAInference.class);
 
@@ -1,7 +1,8 @@
-package eu.dnetlib.jobs;
+package eu.dnetlib.jobs.featureextraction.lda;
 
 import eu.dnetlib.featureextraction.FeatureTransformer;
 import eu.dnetlib.featureextraction.Utilities;
+import eu.dnetlib.jobs.AbstractSparkJob;
 import eu.dnetlib.support.ArgumentApplicationParser;
 import org.apache.spark.SparkConf;
 import org.apache.spark.ml.clustering.LDAModel;
@@ -15,7 +16,7 @@ import scala.Tuple2;
 import java.io.IOException;
 import java.util.*;
 
-public class SparkLDATuning extends AbstractSparkJob{
+public class SparkLDATuning extends AbstractSparkJob {
 
     private static final Logger log = LoggerFactory.getLogger(SparkLDATuning.class);
 
@@ -0,0 +1,18 @@
+<configuration>
+    <property>
+        <name>jobTracker</name>
+        <value>yarnRM</value>
+    </property>
+    <property>
+        <name>nameNode</name>
+        <value>hdfs://nameservice1</value>
+    </property>
+    <property>
+        <name>oozie.use.system.libpath</name>
+        <value>true</value>
+    </property>
+    <property>
+        <name>oozie.action.sharelib.for.spark</name>
+        <value>spark2</value>
+    </property>
+</configuration>
@@ -0,0 +1,172 @@
+<workflow-app name="Publication Features Extraction" xmlns="uri:oozie:workflow:0.5">
+    <parameters>
+        <property>
+            <name>publicationsPath</name>
+            <description>the input entity path</description>
+        </property>
+        <property>
+            <name>workingPath</name>
+            <description>path for the working directory</description>
+        </property>
+        <property>
+            <name>numPartitions</name>
+            <description>number of partitions for the spark files</description>
+        </property>
+        <property>
+            <name>featuresPath</name>
+            <description>location of the embeddings</description>
+        </property>
+        <property>
+            <name>topicsPath</name>
+            <description>location of the topics</description>
+        </property>
+        <property>
+            <name>outputPath</name>
+            <description>location of the output authors</description>
+        </property>
+        <property>
+            <name>bertModel</name>
+            <description>location of the bert model</description>
+        </property>
+        <property>
+            <name>bertSentenceModel</name>
+            <description>location of the bert sentence model</description>
+        </property>
+        <property>
+            <name>wordEmbeddingsModel</name>
+            <description>location of the word embeddings model</description>
+        </property>
+        <property>
+            <name>sparkDriverMemory</name>
+            <description>memory for driver process</description>
+        </property>
+        <property>
+            <name>sparkExecutorMemory</name>
+            <description>memory for individual executor</description>
+        </property>
+        <property>
+            <name>sparkExecutorCores</name>
+            <description>number of cores used by single executor</description>
+        </property>
+        <property>
+            <name>oozieActionShareLibForSpark2</name>
+            <description>oozie action sharelib for spark 2.*</description>
+        </property>
+        <property>
+            <name>spark2ExtraListeners</name>
+            <value>com.cloudera.spark.lineage.NavigatorAppListener</value>
+            <description>spark 2.* extra listeners classname</description>
+        </property>
+        <property>
+            <name>spark2SqlQueryExecutionListeners</name>
+            <value>com.cloudera.spark.lineage.NavigatorQueryListener</value>
+            <description>spark 2.* sql query execution listeners classname</description>
+        </property>
+        <property>
+            <name>spark2YarnHistoryServerAddress</name>
+            <description>spark 2.* yarn history server address</description>
+        </property>
+        <property>
+            <name>spark2EventLogDir</name>
+            <description>spark 2.* event log dir location</description>
+        </property>
+    </parameters>
+
+    <global>
+        <job-tracker>${jobTracker}</job-tracker>
+        <name-node>${nameNode}</name-node>
+        <configuration>
+            <property>
+                <name>mapreduce.job.queuename</name>
+                <value>${queueName}</value>
+            </property>
+            <property>
+                <name>oozie.launcher.mapred.job.queue.name</name>
+                <value>${oozieLauncherQueueName}</value>
+            </property>
+            <property>
+                <name>oozie.action.sharelib.for.spark</name>
+                <value>${oozieActionShareLibForSpark2}</value>
+            </property>
+        </configuration>
+    </global>
+
+    <start to="resetWorkingPath"/>
+
+    <kill name="Kill">
+        <message>Action failed, error message[${wf:errorMessage(wf:lastErrorNode())}]</message>
+    </kill>
+
+    <action name="resetWorkingPath">
+        <fs>
+            <delete path="${workingPath}"/>
+        </fs>
+        <ok to="PublicationFeatureExtractor"/>
+        <error to="Kill"/>
+    </action>
+
+    <!--TODO reimplement LDA procedure and put inference here as the first step-->
+
+    <action name="PublicationFeatureExtractor">
+        <spark xmlns="uri:oozie:spark-action:0.2">
+            <master>yarn</master>
+            <mode>cluster</mode>
+            <name>Publication Feature Extraction</name>
+            <class>eu.dnetlib.jobs.featureextraction.SparkPublicationFeatureExtractor</class>
+            <jar>dnet-and-test-${projectVersion}.jar</jar>
+            <spark-opts>
+                --num-executors=32
+                --executor-memory=${sparkExecutorMemory}
+                --executor-cores=${sparkExecutorCores}
+                --driver-memory=${sparkDriverMemory}
+                --conf spark.extraListeners=${spark2ExtraListeners}
+                --conf spark.sql.queryExecutionListeners=${spark2SqlQueryExecutionListeners}
+                --conf spark.yarn.historyServer.address=${spark2YarnHistoryServerAddress}
+                --conf spark.eventLog.dir=${nameNode}${spark2EventLogDir}
+                --conf spark.sql.shuffle.partitions=3840
+                --conf spark.dynamicAllocation.enabled=false
+            </spark-opts>
+            <arg>--publicationsPath</arg><arg>${publicationsPath}</arg>
+            <arg>--workingPath</arg><arg>${workingPath}</arg>
+            <arg>--numPartitions</arg><arg>${numPartitions}</arg>
+            <arg>--featuresPath</arg><arg>${featuresPath}</arg>
+            <arg>--wordEmbeddingsModel</arg><arg>${wordEmbeddingsModel}</arg>
+            <arg>--bertModel</arg><arg>${bertModel}</arg>
+            <arg>--bertSentenceModel</arg><arg>${bertSentenceModel}</arg>
+        </spark>
+        <ok to="AuthorExtractor"/>
+        <error to="Kill"/>
+    </action>
+
+    <action name="AuthorExtractor">
+        <spark xmlns="uri:oozie:spark-action:0.2">
+            <master>yarn</master>
+            <mode>cluster</mode>
+            <name>Author Extraction</name>
+            <class>eu.dnetlib.jobs.featureextraction.SparkAuthorExtractor</class>
+            <jar>dnet-and-test-${projectVersion}.jar</jar>
+            <spark-opts>
+                --num-executors=32
+                --executor-memory=${sparkExecutorMemory}
+                --executor-cores=${sparkExecutorCores}
+                --driver-memory=${sparkDriverMemory}
+                --conf spark.extraListeners=${spark2ExtraListeners}
+                --conf spark.sql.queryExecutionListeners=${spark2SqlQueryExecutionListeners}
+                --conf spark.yarn.historyServer.address=${spark2YarnHistoryServerAddress}
+                --conf spark.eventLog.dir=${nameNode}${spark2EventLogDir}
+                --conf spark.sql.shuffle.partitions=3840
+                --conf spark.dynamicAllocation.enabled=true
+            </spark-opts>
+            <arg>--workingPath</arg><arg>${workingPath}</arg>
+            <arg>--numPartitions</arg><arg>${numPartitions}</arg>
+            <arg>--publicationsPath</arg><arg>${publicationsPath}</arg>
+            <arg>--topicsPath</arg><arg>${topicsPath}</arg>
+            <arg>--featuresPath</arg><arg>${featuresPath}</arg>
+            <arg>--outputPath</arg><arg>${outputPath}</arg>
+        </spark>
+        <ok to="End"/>
+        <error to="Kill"/>
+    </action>
+
+    <end name="End"/>
+</workflow-app>
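Each <arg> element above becomes one entry of the args array handed to the job's main method. With the values set in the job.properties touched at the top of this commit, the PublicationFeatureExtractor action is therefore equivalent to an invocation along these lines (a sketch; the workflow substitutes the ${...} expressions at run time):

    String[] args = {
            "--publicationsPath", "/tmp/publications_with_pid_pubmed",
            "--workingPath", "/user/michele.debonis/feature_extraction",
            "--numPartitions", "1000",
            "--featuresPath", "/user/michele.debonis/feature_extraction/publications_pubmed_features",
            "--wordEmbeddingsModel", "/user/michele.debonis/nlp_models/glove_100d_en_2.4.0_2.4_1579690104032",
            "--bertModel", "/user/michele.debonis/nlp_models/small_bert_L2_128_en_2.6.0_2.4_1598344320681",
            "--bertSentenceModel", "/user/michele.debonis/nlp_models/sent_small_bert_L6_512_en_2.6.0_2.4_1598350624049"
    };
    SparkPublicationFeatureExtractor.main(args);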
@@ -6,11 +6,23 @@
     "paramRequired": true
   },
   {
-    "paramName": "e",
-    "paramLongName": "entitiesPath",
+    "paramName": "p",
+    "paramLongName": "publicationsPath",
     "paramDescription": "location of the input entities",
     "paramRequired": true
   },
+  {
+    "paramName": "t",
+    "paramLongName": "topicsPath",
+    "paramDescription": "location of the lda topics",
+    "paramRequired": true
+  },
+  {
+    "paramName": "f",
+    "paramLongName": "featuresPath",
+    "paramDescription": "location of the features",
+    "paramRequired": true
+  },
   {
     "paramName": "np",
     "paramLongName": "numPartitions",
@@ -22,11 +34,5 @@
     "paramLongName": "outputPath",
     "paramDescription": "location of the output author extracted",
     "paramRequired": false
-  },
-  {
-    "paramName": "t",
-    "paramLongName": "topicsPath",
-    "paramDescription": "location of the lda topics",
-    "paramRequired": false
   }
 ]
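The reworked parameter list matches how FeatureExtractionJobTest below drives the job; a sketch of an equivalent direct invocation, with placeholder paths:

    ArgumentApplicationParser parser = new ArgumentApplicationParser(
            readResource("/jobs/parameters/authorExtractor_parameters.json", SparkAuthorExtractor.class));
    parser.parseArgument(new String[]{
            "-p", "/tmp/publications",           // publicationsPath (replaces -e/entitiesPath)
            "-t", "/tmp/lda_topics",             // topicsPath, now required
            "-f", "/tmp/publication_features",   // featuresPath, new in this commit
            "-w", "/tmp/working_dir",            // workingPath
            "-np", "20",                         // numPartitions
            "-o", "/tmp/authors"                 // outputPath
    });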
@@ -1,14 +0,0 @@
-[
-  {
-    "paramName": "w",
-    "paramLongName": "workingPath",
-    "paramDescription": "path of the working directory",
-    "paramRequired": true
-  },
-  {
-    "paramName": "np",
-    "paramLongName": "numPartitions",
-    "paramDescription": "number of partitions for the similarity relations intermediate phases",
-    "paramRequired": false
-  }
-]
@@ -0,0 +1,44 @@
+[
+  {
+    "paramName": "w",
+    "paramLongName": "workingPath",
+    "paramDescription": "path of the working directory",
+    "paramRequired": true
+  },
+  {
+    "paramName": "np",
+    "paramLongName": "numPartitions",
+    "paramDescription": "number of partitions for the similarity relations intermediate phases",
+    "paramRequired": false
+  },
+  {
+    "paramName": "p",
+    "paramLongName": "publicationsPath",
+    "paramDescription": "location of the publications",
+    "paramRequired": true
+  },
+  {
+    "paramName": "f",
+    "paramLongName": "featuresPath",
+    "paramDescription": "location of the features",
+    "paramRequired": true
+  },
+  {
+    "paramName": "we",
+    "paramLongName": "wordEmbeddingsModel",
+    "paramDescription": "path of the word embeddings model",
+    "paramRequired": true
+  },
+  {
+    "paramName": "bm",
+    "paramLongName": "bertModel",
+    "paramDescription": "path of the bert model",
+    "paramRequired": true
+  },
+  {
+    "paramName": "bs",
+    "paramLongName": "bertSentenceModel",
+    "paramDescription": "path of the bert sentence model",
+    "paramRequired": true
+  }
+]
@@ -155,7 +155,7 @@
             <master>yarn</master>
             <mode>cluster</mode>
             <name>LDA Inference</name>
-            <class>eu.dnetlib.jobs.SparkLDAInference</class>
+            <class>eu.dnetlib.jobs.featureextraction.lda.SparkLDAInference</class>
             <jar>dnet-and-test-${projectVersion}.jar</jar>
             <spark-opts>
                 --executor-memory=${sparkExecutorMemory}
@@ -172,63 +172,8 @@
             <arg>--ldaModelPath</arg><arg>${ldaModelPath}</arg>
             <arg>--numPartitions</arg><arg>${numPartitions}</arg>
         </spark>
-        <ok to="AuthorExtraction"/>
-        <error to="Kill"/>
-    </action>
-
-    <action name="AuthorExtraction">
-        <spark xmlns="uri:oozie:spark-action:0.2">
-            <master>yarn</master>
-            <mode>cluster</mode>
-            <name>LDA Inference</name>
-            <class>eu.dnetlib.jobs.SparkAuthorExtractor</class>
-            <jar>dnet-and-test-${projectVersion}.jar</jar>
-            <spark-opts>
-                --executor-memory=${sparkExecutorMemory}
-                --executor-cores=${sparkExecutorCores}
-                --driver-memory=${sparkDriverMemory}
-                --conf spark.extraListeners=${spark2ExtraListeners}
-                --conf spark.sql.queryExecutionListeners=${spark2SqlQueryExecutionListeners}
-                --conf spark.yarn.historyServer.address=${spark2YarnHistoryServerAddress}
-                --conf spark.eventLog.dir=${nameNode}${spark2EventLogDir}
-                --conf spark.sql.shuffle.partitions=3840
-            </spark-opts>
-            <arg>--entitiesPath</arg><arg>${entitiesPath}</arg>
-            <arg>--workingPath</arg><arg>${workingPath}</arg>
-            <arg>--outputPath</arg><arg>${authorsPath}</arg>
-            <arg>--numPartitions</arg><arg>${numPartitions}</arg>
-            <arg>--topicsPath</arg><arg>${ldaInferencePath}</arg>
-        </spark>
-        <ok to="ThresholdAnalysis"/>
-        <error to="Kill"/>
-    </action>
-
-    <action name="ThresholdAnalysis">
-        <spark xmlns="uri:oozie:spark-action:0.2">
-            <master>yarn</master>
-            <mode>cluster</mode>
-            <name>LDA Threshold Analysis</name>
-            <class>eu.dnetlib.jobs.SparkLDAAnalysis</class>
-            <jar>dnet-and-test-${projectVersion}.jar</jar>
-            <spark-opts>
-                --num-executors=32
-                --executor-memory=${sparkExecutorMemory}
-                --executor-cores=${sparkExecutorCores}
-                --driver-memory=${sparkDriverMemory}
-                --conf spark.extraListeners=${spark2ExtraListeners}
-                --conf spark.sql.queryExecutionListeners=${spark2SqlQueryExecutionListeners}
-                --conf spark.yarn.historyServer.address=${spark2YarnHistoryServerAddress}
-                --conf spark.eventLog.dir=${nameNode}${spark2EventLogDir}
-                --conf spark.sql.shuffle.partitions=3840
-                --conf spark.dynamicAllocation.enabled=false
-            </spark-opts>
-            <arg>--authorsPath</arg><arg>${authorsPath}</arg>
-            <arg>--workingPath</arg><arg>${workingPath}</arg>
-            <arg>--numPartitions</arg><arg>${numPartitions}</arg>
-        </spark>
         <ok to="End"/>
         <error to="Kill"/>
     </action>
 
     <end name="End"/>
 </workflow-app>
@@ -195,7 +195,7 @@
             <master>yarn</master>
             <mode>cluster</mode>
             <name>LDA Tuning</name>
-            <class>eu.dnetlib.jobs.SparkLDATuning</class>
+            <class>eu.dnetlib.jobs.featureextraction.lda.SparkLDATuning</class>
             <jar>dnet-and-test-${projectVersion}.jar</jar>
             <spark-opts>
                 --executor-memory=${sparkExecutorMemory}
@@ -1,6 +1,7 @@
-package eu.dnetlib.jobs.deeplearning;
+package eu.dnetlib.jobs.featureextraction;
 
 import eu.dnetlib.jobs.AbstractSparkJob;
+import eu.dnetlib.jobs.SparkTokenizer;
 import eu.dnetlib.support.ArgumentApplicationParser;
 import org.apache.commons.io.FileUtils;
 import org.apache.commons.io.IOUtils;
@@ -15,22 +16,22 @@ import java.nio.file.Paths;
 
 @TestMethodOrder(MethodOrderer.OrderAnnotation.class)
 @TestInstance(TestInstance.Lifecycle.PER_CLASS)
-public class GNNTrainingTest {
+public class FeatureExtractionJobTest {
 
     static SparkSession spark;
     static JavaSparkContext context;
     final static String workingPath = "/tmp/working_dir";
-
-    final static String numPartitions = "20";
     final String inputDataPath = Paths
-            .get(getClass().getResource("/eu/dnetlib/jobs/examples/authors.groups.example.json").toURI())
+            .get(getClass().getResource("/eu/dnetlib/jobs/examples/publications.subset.json").toURI())
             .toFile()
             .getAbsolutePath();
-    final static String groundTruthJPath = "$.orcid";
-    final static String idJPath = "$.id";
-    final static String featuresJPath = "$.topics";
 
-    public GNNTrainingTest() throws URISyntaxException {}
+    final String ldaTopicsPath = Paths
+            .get(getClass().getResource("/eu/dnetlib/jobs/examples/publications_lda_topics_subset").toURI())
+            .toFile()
+            .getAbsolutePath();
+
+    public FeatureExtractionJobTest() throws URISyntaxException {}
 
     public static void cleanup() throws IOException {
         //remove directories and clean workspace
@@ -57,43 +58,43 @@ public class GNNTrainingTest {
 
     @Test
     @Order(1)
-    public void createGroupDataSetTest() throws Exception {
-        ArgumentApplicationParser parser = new ArgumentApplicationParser(readResource("/jobs/parameters/createGroupDataset_parameters.json", SparkCreateGroupDataSet.class));
+    public void publicationFeatureExtractionTest() throws Exception {
+        ArgumentApplicationParser parser = new ArgumentApplicationParser(readResource("/jobs/parameters/publicationFeatureExtractor_parameters.json", SparkTokenizer.class));
 
         parser.parseArgument(
                 new String[] {
-                        "-i", inputDataPath,
-                        "-gt", groundTruthJPath,
-                        "-id", idJPath,
-                        "-f", featuresJPath,
+                        "-p", inputDataPath,
                         "-w", workingPath,
-                        "-np", numPartitions
+                        "-np", "20"
                 }
         );
 
-        new SparkCreateGroupDataSet(
+        new SparkPublicationFeatureExtractor(
                 parser,
                 spark
         ).run();
 
     }
 
     @Test
     @Order(2)
-    public void graphClassificationTrainingTest() throws Exception{
-        ArgumentApplicationParser parser = new ArgumentApplicationParser(readResource("/jobs/parameters/graphClassificationTraining_parameters.json", SparkGraphClassificationTraining.class));
+    public void authorExtractionTest() throws Exception {
+        ArgumentApplicationParser parser = new ArgumentApplicationParser(readResource("/jobs/parameters/authorExtractor_parameters.json", SparkAuthorExtractor.class));
 
         parser.parseArgument(
-                new String[] {
+                new String[]{
+                        "-p", inputDataPath,
                         "-w", workingPath,
-                        "-np", numPartitions
-                }
-        );
+                        "-np", "20",
+                        "-t", ldaTopicsPath,
+                        "-f", workingPath + "/publication_features",
+                        "-o", workingPath + "/authors"
+                });
 
-        new SparkGraphClassificationTraining(
+        new SparkAuthorExtractor(
                 parser,
                 spark
         ).run();
 
     }
 
     public static String readResource(String path, Class<? extends AbstractSparkJob> clazz) throws IOException {
@@ -1,5 +1,9 @@
-package eu.dnetlib.jobs;
+package eu.dnetlib.jobs.featureextraction.lda;
 
+import eu.dnetlib.jobs.AbstractSparkJob;
+import eu.dnetlib.jobs.SparkCountVectorizer;
+import eu.dnetlib.jobs.SparkCreateVocabulary;
+import eu.dnetlib.jobs.SparkTokenizer;
 import eu.dnetlib.support.ArgumentApplicationParser;
 import org.apache.commons.io.FileUtils;
 import org.apache.commons.io.IOUtils;
@@ -157,74 +161,12 @@ public class LDAAnalysisTest {
                 parser,
                 spark
         ).run();
-    }
-
-    @Test
-    @Order(6)
-    public void authorExtractorTest() throws Exception {
-        ArgumentApplicationParser parser = new ArgumentApplicationParser(readResource("/jobs/parameters/authorExtractor_parameters.json", SparkLDAInference.class));
-
-        parser.parseArgument(
-                new String[]{
-                        "-e", inputDataPath,
-                        "-o", authorsPath,
-                        "-t", topicsPath,
-                        "-w", workingPath,
-                        "-np", numPartitions
-                });
-
-        new SparkAuthorExtractor(
-                parser,
-                spark
-        ).run();
-    }
-
-    @Test
-    @Order(7)
-    public void ldaAnalysis() throws Exception {
-        ArgumentApplicationParser parser = new ArgumentApplicationParser(readResource("/jobs/parameters/ldaAnalysis_parameters.json", SparkLDAAnalysis.class));
-
-        parser.parseArgument(
-                new String[]{
-                        "-i", authorsPath,
-                        "-w", workingPath,
-                        "-np", numPartitions
-                });
-
-        new SparkLDAAnalysis(
-                parser,
-                spark
-        ).run();
-
-        Thread.sleep(1000000000);
+
+        Thread.sleep(100000);
     }
 
     public static String readResource(String path, Class<? extends AbstractSparkJob> clazz) throws IOException {
         return IOUtils.toString(clazz.getResourceAsStream(path));
     }
-
-//    @Test
-//    public void createVocabulary() {
-//
-//        StructType inputSchema = new StructType(new StructField[]{
-//                new StructField("id", DataTypes.StringType, false, Metadata.empty()),
-//                new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
-//        });
-//
-//        JavaSparkContext sc = JavaSparkContext.fromSparkContext(spark.sparkContext());
-//        JavaRDD<Row> rows = sc.textFile("/Users/miconis/Desktop/dewey").map(s -> s.substring(4)).map(s -> Utilities.normalize(s).replaceAll("  ", " ")).filter(s -> !s.contains("unassigned")).map(s -> RowFactory.create("id", s));
-//
-//        Dataset<Row> dataFrame = spark.createDataFrame(rows, inputSchema);
-//
-//        dataFrame = FeatureTransformer.tokenizeData(dataFrame);
-//
-//        JavaRDD<String> map = dataFrame.toJavaRDD().map(r -> r.getList(1)).flatMap(l -> l.iterator()).map(s -> s.toString()).distinct();
-//
-//        map.coalesce(1).saveAsTextFile("/tmp/vocab_raw");
-//        System.out.println("map = " + map.count());
-//        System.out.println("dataFrame = " + map.first());
-//    }
 
 }
Binary files not shown (21 files).
@@ -13,6 +13,36 @@
 	<artifactId>dnet-feature-extraction</artifactId>
 	<packaging>jar</packaging>
 
+	<build>
+		<plugins>
+			<plugin>
+				<groupId>net.alchim31.maven</groupId>
+				<artifactId>scala-maven-plugin</artifactId>
+				<version>4.0.1</version>
+				<executions>
+					<execution>
+						<id>scala-compile-first</id>
+						<phase>initialize</phase>
+						<goals>
+							<goal>add-source</goal>
+							<goal>compile</goal>
+						</goals>
+					</execution>
+					<execution>
+						<id>scala-test-compile</id>
+						<phase>process-test-resources</phase>
+						<goals>
+							<goal>testCompile</goal>
+						</goals>
+					</execution>
+				</executions>
+				<configuration>
+					<scalaVersion>${scala.version}</scalaVersion>
+				</configuration>
+			</plugin>
+		</plugins>
+	</build>
+
 	<dependencies>
 		<dependency>
 			<groupId>org.apache.spark</groupId>
@@ -53,42 +83,36 @@
 		</dependency>
 
 		<!--DEEPLEARNING4J -->
+<!-- <dependency>-->
+<!-- <groupId>org.nd4j</groupId>-->
+<!-- <artifactId>${nd4j.backend}</artifactId>-->
+<!-- </dependency>-->
+<!-- <dependency>-->
+<!-- <groupId>org.deeplearning4j</groupId>-->
+<!-- <artifactId>deeplearning4j-core</artifactId>-->
+<!-- </dependency>-->
+<!-- <dependency>-->
+<!-- <groupId>org.deeplearning4j</groupId>-->
+<!-- <artifactId>deeplearning4j-datasets</artifactId>-->
+<!-- </dependency>-->
+<!-- <dependency>-->
+<!-- <groupId>org.deeplearning4j</groupId>-->
+<!-- <artifactId>dl4j-spark-parameterserver_2.11</artifactId>-->
+<!-- </dependency>-->
+<!-- <dependency>-->
+<!-- <groupId>org.deeplearning4j</groupId>-->
+<!-- <artifactId>dl4j-spark_2.11</artifactId>-->
+<!-- </dependency>-->
+
 		<dependency>
-			<groupId>org.nd4j</groupId>
-			<artifactId>${nd4j.backend}</artifactId>
-		</dependency>
-		<dependency>
-			<groupId>org.deeplearning4j</groupId>
-			<artifactId>deeplearning4j-core</artifactId>
-		</dependency>
-		<dependency>
-			<groupId>org.deeplearning4j</groupId>
-			<artifactId>deeplearning4j-datasets</artifactId>
-		</dependency>
-		<dependency>
-			<groupId>org.deeplearning4j</groupId>
-			<artifactId>dl4j-spark-parameterserver_2.11</artifactId>
-		</dependency>
-		<dependency>
-			<groupId>org.deeplearning4j</groupId>
-			<artifactId>dl4j-spark_2.11</artifactId>
+			<groupId>com.johnsnowlabs.nlp</groupId>
+			<artifactId>spark-nlp_${scala.binary.version}</artifactId>
 		</dependency>
 
-		<!--PLOT-->
-		<dependency>
-			<groupId>jfree</groupId>
-			<artifactId>jfreechart</artifactId>
-		</dependency>
-		<dependency>
-			<groupId>org.jfree</groupId>
-			<artifactId>jcommon</artifactId>
-		</dependency>
-
-		<!--DNET DEDUP-->
-		<dependency>
-			<groupId>eu.dnetlib</groupId>
-			<artifactId>dnet-dedup-test</artifactId>
-		</dependency>
+<!-- <dependency>-->
+<!-- <groupId>org.json4s</groupId>-->
+<!-- <artifactId>json4s-jackson_${scala.binary.version}</artifactId>-->
+<!-- </dependency>-->
 
 	</dependencies>
@@ -1,130 +0,0 @@
//package eu.dnetlib.deeplearning;
//
///* *****************************************************************************
// *
// *
// *
// * This program and the accompanying materials are made available under the
// * terms of the Apache License, Version 2.0 which is available at
// * https://www.apache.org/licenses/LICENSE-2.0.
// * See the NOTICE file distributed with this work for additional
// * information regarding copyright ownership.
// *
// * Unless required by applicable law or agreed to in writing, software
// * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// * License for the specific language governing permissions and limitations
// * under the License.
// *
// * SPDX-License-Identifier: Apache-2.0
// ******************************************************************************/
//
//import org.datavec.api.records.reader.RecordReader;
//import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
//import org.datavec.api.split.FileSplit;
//import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
//import org.deeplearning4j.examples.utils.DownloaderUtility;
//import org.deeplearning4j.examples.utils.PlotUtil;
//import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
//import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
//import org.deeplearning4j.nn.conf.layers.DenseLayer;
//import org.deeplearning4j.nn.conf.layers.OutputLayer;
//import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
//import org.deeplearning4j.nn.weights.WeightInit;
//import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
//import org.nd4j.evaluation.classification.Evaluation;
//import org.nd4j.linalg.activations.Activation;
//import org.nd4j.linalg.api.ndarray.INDArray;
//import org.nd4j.linalg.dataset.DataSet;
//import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
//import org.nd4j.linalg.learning.config.Nesterovs;
//import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
//
//import java.io.File;
//import java.util.concurrent.TimeUnit;
//
//public class GroupClassifier {
//
//    public static boolean visualize = true;
//    public static String dataLocalPath;
//
//    public static void main(String[] args) throws Exception {
//        int seed = 123;
//        double learningRate = 0.01;
//        int batchSize = 50;
//        int nEpochs = 30;
//
//        int numInputs = 2;
//        int numOutputs = 2;
//        int numHiddenNodes = 20;
//
//        dataLocalPath = DownloaderUtility.CLASSIFICATIONDATA.Download();
//        //Load the training data:
//        RecordReader rr = new CSVRecordReader();
//        rr.initialize(new FileSplit(new File(dataLocalPath, "linear_data_train.csv")));
//        DataSetIterator trainIter = new RecordReaderDataSetIterator(rr, batchSize, 0, 2);
//
//        //Load the test/evaluation data:
//        RecordReader rrTest = new CSVRecordReader();
//        rrTest.initialize(new FileSplit(new File(dataLocalPath, "linear_data_eval.csv")));
//        DataSetIterator testIter = new RecordReaderDataSetIterator(rrTest, batchSize, 0, 2);
//
//        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
//                .seed(seed)
//                .weightInit(WeightInit.XAVIER)
//                .updater(new Nesterovs(learningRate, 0.9))
//                .list()
//                .layer(new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
//                        .activation(Activation.RELU)
//                        .build())
//                .layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
//                        .activation(Activation.SOFTMAX)
//                        .nIn(numHiddenNodes).nOut(numOutputs).build())
//                .build();
//
//        MultiLayerNetwork model = new MultiLayerNetwork(conf);
//        model.init();
//        model.setListeners(new ScoreIterationListener(10)); //Print score every 10 parameter updates
//
//        model.fit(trainIter, nEpochs);
//
//        System.out.println("Evaluate model....");
//        Evaluation eval = new Evaluation(numOutputs);
//        while (testIter.hasNext()) {
//            DataSet t = testIter.next();
//            INDArray features = t.getFeatures();
//            INDArray labels = t.getLabels();
//            INDArray predicted = model.output(features, false);
//            eval.eval(labels, predicted);
//        }
//        //An alternate way to do the above loop
//        //Evaluation evalResults = model.evaluate(testIter);
//
//        //Print the evaluation statistics
//        System.out.println(eval.stats());
//
//        System.out.println("\n****************Example finished********************");
//        //Training is complete. Code that follows is for plotting the data & predictions only
//        generateVisuals(model, trainIter, testIter);
//    }
//
//    public static void generateVisuals(MultiLayerNetwork model, DataSetIterator trainIter, DataSetIterator testIter) throws Exception {
//        if (visualize) {
//            double xMin = 0;
//            double xMax = 1.0;
//            double yMin = -0.2;
//            double yMax = 0.8;
//            int nPointsPerAxis = 100;
//
//            //Generate x,y points that span the whole range of features
//            INDArray allXYPoints = PlotUtil.generatePointsOnGraph(xMin, xMax, yMin, yMax, nPointsPerAxis);
//            //Get train data and plot with predictions
//            PlotUtil.plotTrainingData(model, trainIter, allXYPoints, nPointsPerAxis);
//            TimeUnit.SECONDS.sleep(3);
//            //Get test data, run the test data through the network to generate predictions, and plot those predictions:
//            PlotUtil.plotTestData(model, testIter, allXYPoints, nPointsPerAxis);
//        }
//    }
//}
//
@@ -1,24 +0,0 @@
package eu.dnetlib.deeplearning.layers;

import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaVertex;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.util.Map;

public class GraphConvolutionVertex extends SameDiffLambdaVertex {

    @Override
    public SDVariable defineVertex(SameDiff sameDiff, VertexInputs inputs) {
        SDVariable features = inputs.getInput(0);
        SDVariable adjacency = inputs.getInput(1);
        SDVariable degree = inputs.getInput(2).pow(0.5);

        //result: DegreeMatrix^-0.5 x Adjacency x DegreeMatrix^-0.5 x Features
        return degree.mmul(adjacency).mmul(degree).mmul(features);
    }
}
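For readers tracing the removed vertex: assuming the third input already carries the (inverse) node degrees, which the source does not state explicitly, the computation above is the standard graph-convolution propagation rule of Kipf and Welling:

$$H' = \tilde{D}^{-\frac{1}{2}} \, \tilde{A} \, \tilde{D}^{-\frac{1}{2}} \, X$$

where $\tilde{A}$ is the adjacency matrix with self-loops and $\tilde{D}$ its diagonal degree matrix. Note that the code applies pow(0.5) to its degree input, so the effective exponent depends on how the degree matrix is prepared upstream.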
@@ -1,21 +0,0 @@
package eu.dnetlib.deeplearning.layers;

import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaLayer;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer;
import org.nd4j.autodiff.samediff.SDIndex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;

import java.util.Map;

public class GraphGlobalAddPool extends SameDiffLambdaLayer {

    int size;

    public GraphGlobalAddPool(int size) {
        this.size = size;
    }

    @Override
    public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput) {
        return layerInput.mean(0).reshape(1, size); //reshape because the output layer expects 2-dimensional arrays
    }
}
@@ -1,88 +0,0 @@
package eu.dnetlib.deeplearning.support;

import eu.dnetlib.featureextraction.Utilities;
import eu.dnetlib.support.Author;
import eu.dnetlib.support.ConnectedComponent;
import eu.dnetlib.support.Relation;
import org.apache.spark.api.java.JavaRDD;
import org.codehaus.jackson.map.ObjectMapper;
import org.jetbrains.annotations.NotNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;

import java.io.IOException;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class DataSetProcessor {

    public static JavaRDD<MultiDataSet> entityGroupToMultiDataset(JavaRDD<ConnectedComponent> groupEntity, String idJPath, String featureJPath, String groundTruthJPath) {

        return groupEntity.map(g -> {
            Map<String, double[]> featuresMap = new HashMap<>();
            List<String> groundTruth = new ArrayList<>();
            Set<String> entities = g.getDocs();
            for(String json:entities) {
                featuresMap.put(
                        Utilities.getJPathString(idJPath, json),
                        Utilities.getJPathArray(featureJPath, json)
                );
                groundTruth.add(Utilities.getJPathString(groundTruthJPath, json));
            }

            Set<Relation> relations = g.getSimrels();

            return getMultiDataSet(featuresMap, relations, groundTruth);
        });
    }

    public static MultiDataSet getMultiDataSet(Map<String, double[]> featuresMap, Set<Relation> relations, List<String> groundTruth) {

        List<String> identifiers = new ArrayList<>(featuresMap.keySet());

        int numNodes = identifiers.size();

        //initialize arrays
        INDArray adjacency = Nd4j.zeros(numNodes, numNodes);
        INDArray features = Nd4j.zeros(numNodes, featuresMap.get(identifiers.get(0)).length); //feature size taken from the first element (it's equal for every element)
        INDArray degree = Nd4j.zeros(numNodes, numNodes);

        //create adjacency
        for(Relation r: relations) {
            adjacency.put(identifiers.indexOf(r.getSource()), identifiers.indexOf(r.getTarget()), 1);
            adjacency.put(identifiers.indexOf(r.getTarget()), identifiers.indexOf(r.getSource()), 1);
        }
        adjacency.addi(Nd4j.eye(numNodes));

        //create degree and features
        List<String> degreeSupport = relations.stream().flatMap(r -> Stream.of(r.getSource(), r.getTarget())).collect(Collectors.toList());
        for(int i=0; i< identifiers.size(); i++) {
            degree.put(i, i, Collections.frequency(degreeSupport, identifiers.get(i)));
            features.putRow(i, Nd4j.create(featuresMap.get(identifiers.get(i))));
        }

        //infer label
        INDArray label = Nd4j.zeros(1, 2);
        if (groundTruth.stream().distinct().count()==1) {
            //correct (same elements)
            label.put(0, 0, 1.0);
        }
        else {
            //wrong (different elements)
            label.put(0, 1, 1.0);
        }

        return new MultiDataSet(
                new INDArray[]{
                        features,
                        adjacency,
                        degree
                },
                new INDArray[]{
                        label
                }
        );
    }
}
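To make the removed packing concrete, here is a minimal, self-contained sketch (not project code; shapes and values are illustrative) of the features/adjacency/degree layout that getMultiDataSet produced for a three-node group with simrels 0-1 and 1-2:

import org.nd4j.linalg.dataset.MultiDataSet
import org.nd4j.linalg.factory.Nd4j

object ToyGroupDataSet extends App {
  val n = 3
  // one feature row per node
  val features = Nd4j.create(Array(Array(0.0, 0.0), Array(1.0, 1.0), Array(2.0, 2.0)))
  // symmetric adjacency for the two simrels, plus self-loops (as in the removed code)
  val adjacency = Nd4j.zeros(n, n)
  adjacency.putScalar(0, 1, 1.0); adjacency.putScalar(1, 0, 1.0)
  adjacency.putScalar(1, 2, 1.0); adjacency.putScalar(2, 1, 1.0)
  adjacency.addi(Nd4j.eye(n))
  // diagonal degree matrix counting relation endpoints: node 1 appears twice
  val degree = Nd4j.zeros(n, n)
  degree.putScalar(0, 0, 1.0); degree.putScalar(1, 1, 2.0); degree.putScalar(2, 2, 1.0)
  // one-hot label: (1,0) if all ground-truth values match, (0,1) otherwise
  val label = Nd4j.create(Array(Array(0.0, 1.0)))

  val mds = new MultiDataSet(Array(features, adjacency, degree), Array(label))
  println(mds)
}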
@@ -1,11 +0,0 @@
package eu.dnetlib.deeplearning.support;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;

import java.io.*;
import java.util.List;

public class GroupMultiDataSet extends MultiDataSet {

}
@@ -1,97 +0,0 @@
package eu.dnetlib.deeplearning.support;

import eu.dnetlib.deeplearning.layers.GraphConvolutionVertex;
import eu.dnetlib.deeplearning.layers.GraphGlobalAddPool;
import org.bytedeco.opencv.opencv_dnn.PoolingLayer;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class NetworkConfigurations {

    //parameters default values
    protected static final int SEED = 12345;
    protected static final double LEARNING_RATE = 1e-3;
    protected static final String ADJACENCY_MATRIX = "adjacency";
    protected static final String FEATURES_MATRIX = "features";
    protected static final String DEGREE_MATRIX = "degrees";

    public static MultiLayerConfiguration getLinearDataClassifier(int numInputs, int numHiddenNodes, int numOutputs) {
        return new NeuralNetConfiguration.Builder()
                .seed(SEED)
                .weightInit(WeightInit.XAVIER)
                .updater(new Nesterovs(LEARNING_RATE, 0.9))
                .list()
                .layer(new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                        .activation(Activation.RELU)
                        .build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .activation(Activation.SOFTMAX)
                        .nIn(numHiddenNodes).nOut(numOutputs).build())
                .build();
    }

    public static ComputationGraphConfiguration getSimpleGCN(int numLayers, int numInputs, int numHiddenNodes, int numClasses) {

        ComputationGraphConfiguration.GraphBuilder baseConfig = new NeuralNetConfiguration.Builder()
                .seed(SEED)
                .updater(new Adam(LEARNING_RATE))
                .weightInit(WeightInit.XAVIER)
                .graphBuilder()
                .addInputs(FEATURES_MATRIX, ADJACENCY_MATRIX, DEGREE_MATRIX)
                //first convolution layer
                .addVertex("layer1",
                        new GraphConvolutionVertex(),
                        FEATURES_MATRIX, ADJACENCY_MATRIX, DEGREE_MATRIX)
                .layer("conv1",
                        new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                                .activation(Activation.RELU)
                                .build(),
                        "layer1")
                .layer("batch1",
                        new BatchNormalization.Builder().nOut(numHiddenNodes).build(),
                        "conv1");

        //add as many layers as requested
        for(int i=2; i<=numLayers; i++) {
            baseConfig = baseConfig.addVertex("layer" + i,
                    new GraphConvolutionVertex(),
                    "batch" + (i-1), ADJACENCY_MATRIX, DEGREE_MATRIX)
                    .layer("conv" + i,
                            new DenseLayer.Builder().nIn(numHiddenNodes).nOut(numHiddenNodes)
                                    .activation(Activation.RELU)
                                    .build(),
                            "layer" + i)
                    .layer("batch" + i,
                            new BatchNormalization.Builder().nOut(numHiddenNodes).build(),
                            "conv" + i);
        }

        baseConfig = baseConfig
                .layer("pool",
                        new GraphGlobalAddPool(numHiddenNodes),
                        "batch" + numLayers)
                .layer("fc1",
                        new DenseLayer.Builder().nIn(numHiddenNodes).nOut(numHiddenNodes)
                                .activation(Activation.RELU)
                                .weightInit(WeightInit.XAVIER)
                                .build(),
                        "pool")
                .layer("out",
                        new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                                .activation(Activation.SOFTMAX)
                                .nIn(numHiddenNodes).nOut(numClasses).build(),
                        "fc1");

        return baseConfig.setOutputs("out").build();
    }
}
@@ -1,253 +0,0 @@
//package eu.dnetlib.deeplearning.support;
//
//import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
//import org.jfree.chart.ChartPanel;
//import org.jfree.chart.ChartUtilities;
//import org.jfree.chart.JFreeChart;
//import org.jfree.chart.axis.AxisLocation;
//import org.jfree.chart.axis.NumberAxis;
//import org.jfree.chart.block.BlockBorder;
//import org.jfree.chart.plot.DatasetRenderingOrder;
//import org.jfree.chart.plot.XYPlot;
//import org.jfree.chart.renderer.GrayPaintScale;
//import org.jfree.chart.renderer.PaintScale;
//import org.jfree.chart.renderer.xy.XYBlockRenderer;
//import org.jfree.chart.renderer.xy.XYLineAndShapeRenderer;
//import org.jfree.chart.title.PaintScaleLegend;
//import org.jfree.data.xy.*;
//import org.jfree.ui.RectangleEdge;
//import org.jfree.ui.RectangleInsets;
//import org.nd4j.linalg.api.ndarray.INDArray;
//import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
//import org.nd4j.linalg.dataset.DataSet;
//import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
//import org.nd4j.linalg.factory.Nd4j;
//
//import javax.swing.*;
//import java.awt.*;
//import java.util.ArrayList;
//import java.util.List;
//
///**
// * Simple plotting methods for the MLPClassifier quickstart examples
// *
// * @author Alex Black
// */
//public class PlotUtils {
//
//    /**
//     * Plot the training data. Assume 2d input, classification output
//     *
//     * @param model        Model to use to get predictions
//     * @param trainIter    DataSet Iterator
//     * @param backgroundIn sets of x,y points in input space, plotted in the background
//     * @param nDivisions   Number of points (per axis, for the backgroundIn/backgroundOut arrays)
//     */
//    public static void plotTrainingData(MultiLayerNetwork model, DataSetIterator trainIter, INDArray backgroundIn, int nDivisions) {
//        double[] mins = backgroundIn.min(0).data().asDouble();
//        double[] maxs = backgroundIn.max(0).data().asDouble();
//
//        DataSet ds = allBatches(trainIter);
//        INDArray backgroundOut = model.output(backgroundIn);
//
//        XYZDataset backgroundData = createBackgroundData(backgroundIn, backgroundOut);
//        JPanel panel = new ChartPanel(createChart(backgroundData, mins, maxs, nDivisions, createDataSetTrain(ds.getFeatures(), ds.getLabels())));
//
//        JFrame f = new JFrame();
//        f.add(panel);
//        f.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
//        f.pack();
//        f.setTitle("Training Data");
//
//        f.setVisible(true);
//        f.setLocation(0, 0);
//    }
//
//    /**
//     * Plot the test data. Assume 2d input, classification output
//     *
//     * @param model        Model to use to get predictions
//     * @param testIter     Test Iterator
//     * @param backgroundIn sets of x,y points in input space, plotted in the background
//     * @param nDivisions   Number of points (per axis, for the backgroundIn/backgroundOut arrays)
//     */
//    public static void plotTestData(MultiLayerNetwork model, DataSetIterator testIter, INDArray backgroundIn, int nDivisions) {
//
//        double[] mins = backgroundIn.min(0).data().asDouble();
//        double[] maxs = backgroundIn.max(0).data().asDouble();
//
//        INDArray backgroundOut = model.output(backgroundIn);
//        XYZDataset backgroundData = createBackgroundData(backgroundIn, backgroundOut);
//        DataSet ds = allBatches(testIter);
//        INDArray predicted = model.output(ds.getFeatures());
//        JPanel panel = new ChartPanel(createChart(backgroundData, mins, maxs, nDivisions, createDataSetTest(ds.getFeatures(), ds.getLabels(), predicted)));
//
//        JFrame f = new JFrame();
//        f.add(panel);
//        f.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
//        f.pack();
//        f.setTitle("Test Data");
//
//        f.setVisible(true);
//        f.setLocationRelativeTo(null);
//        //f.setLocation(100,100);
//
//    }
//
//    /**
//     * Create data for the background data set
//     */
//    private static XYZDataset createBackgroundData(INDArray backgroundIn, INDArray backgroundOut) {
//        int nRows = backgroundIn.rows();
//        double[] xValues = new double[nRows];
//        double[] yValues = new double[nRows];
//        double[] zValues = new double[nRows];
//        for (int i = 0; i < nRows; i++) {
//            xValues[i] = backgroundIn.getDouble(i, 0);
//            yValues[i] = backgroundIn.getDouble(i, 1);
//            zValues[i] = backgroundOut.getDouble(i, 0);
//        }
//
//        DefaultXYZDataset dataset = new DefaultXYZDataset();
//        dataset.addSeries("Series 1", new double[][]{xValues, yValues, zValues});
//        return dataset;
//    }
//
//    //Training data
//    private static XYDataset createDataSetTrain(INDArray features, INDArray labels) {
//        int nRows = features.rows();
//
//        int nClasses = 2; //Binary classification using one output call end sigmoid.
//
//        XYSeries[] series = new XYSeries[nClasses];
//        for (int i = 0; i < series.length; i++) series[i] = new XYSeries("Class " + i);
//        INDArray argMax = Nd4j.getExecutioner().exec(new ArgMax(new INDArray[]{labels}, false, new int[]{1}))[0];
//        for (int i = 0; i < nRows; i++) {
//            int classIdx = (int) argMax.getDouble(i);
//            series[classIdx].add(features.getDouble(i, 0), features.getDouble(i, 1));
//        }
//
//        XYSeriesCollection c = new XYSeriesCollection();
//        for (XYSeries s : series) c.addSeries(s);
//        return c;
//    }
//
//    //Test data
//    private static XYDataset createDataSetTest(INDArray features, INDArray labels, INDArray predicted) {
//        int nRows = features.rows();
//
//        int nClasses = 2; //Binary classification using one output call end sigmoid.
//
//        XYSeries[] series = new XYSeries[nClasses * nClasses];
//        int[] series_index = new int[]{0, 3, 2, 1}; //little hack to make the charts look consistent.
//        for (int i = 0; i < nClasses * nClasses; i++) {
//            int trueClass = i / nClasses;
//            int predClass = i % nClasses;
//            String label = "actual=" + trueClass + ", pred=" + predClass;
//            series[series_index[i]] = new XYSeries(label);
//        }
//        INDArray actualIdx = labels.argMax(1);
//        INDArray predictedIdx = predicted.argMax(1);
//        for (int i = 0; i < nRows; i++) {
//            int classIdx = actualIdx.getInt(i);
//            int predIdx = predictedIdx.getInt(i);
//            int idx = series_index[classIdx * nClasses + predIdx];
//            series[idx].add(features.getDouble(i, 0), features.getDouble(i, 1));
//        }
//
//        XYSeriesCollection c = new XYSeriesCollection();
//        for (XYSeries s : series) c.addSeries(s);
//        return c;
//    }
//
//    private static JFreeChart createChart(XYZDataset dataset, double[] mins, double[] maxs, int nPoints, XYDataset xyData) {
//        NumberAxis xAxis = new NumberAxis("X");
//        xAxis.setRange(mins[0], maxs[0]);
//
//        NumberAxis yAxis = new NumberAxis("Y");
//        yAxis.setRange(mins[1], maxs[1]);
//
//        XYBlockRenderer renderer = new XYBlockRenderer();
//        renderer.setBlockWidth((maxs[0] - mins[0]) / (nPoints - 1));
//        renderer.setBlockHeight((maxs[1] - mins[1]) / (nPoints - 1));
//        PaintScale scale = new GrayPaintScale(0, 1.0);
//        renderer.setPaintScale(scale);
//        XYPlot plot = new XYPlot(dataset, xAxis, yAxis, renderer);
//        plot.setBackgroundPaint(Color.lightGray);
//        plot.setDomainGridlinesVisible(false);
//        plot.setRangeGridlinesVisible(false);
//        plot.setAxisOffset(new RectangleInsets(5, 5, 5, 5));
//        JFreeChart chart = new JFreeChart("", plot);
//        chart.getXYPlot().getRenderer().setSeriesVisibleInLegend(0, false);
//
//        NumberAxis scaleAxis = new NumberAxis("Probability (class 1)");
//        scaleAxis.setAxisLinePaint(Color.white);
//        scaleAxis.setTickMarkPaint(Color.white);
//        scaleAxis.setTickLabelFont(new Font("Dialog", Font.PLAIN, 7));
//        PaintScaleLegend legend = new PaintScaleLegend(new GrayPaintScale(), scaleAxis);
//        legend.setStripOutlineVisible(false);
//        legend.setSubdivisionCount(20);
//        legend.setAxisLocation(AxisLocation.BOTTOM_OR_LEFT);
//        legend.setAxisOffset(5.0);
//        legend.setMargin(new RectangleInsets(5, 5, 5, 5));
//        legend.setFrame(new BlockBorder(Color.red));
//        legend.setPadding(new RectangleInsets(10, 10, 10, 10));
//        legend.setStripWidth(10);
//        legend.setPosition(RectangleEdge.LEFT);
//        chart.addSubtitle(legend);
//
//        ChartUtilities.applyCurrentTheme(chart);
//
//        plot.setDataset(1, xyData);
//        XYLineAndShapeRenderer renderer2 = new XYLineAndShapeRenderer();
//        renderer2.setBaseLinesVisible(false);
//        plot.setRenderer(1, renderer2);
//
//        plot.setDatasetRenderingOrder(DatasetRenderingOrder.FORWARD);
//
//        return chart;
//    }
//
//    public static INDArray generatePointsOnGraph(double xMin, double xMax, double yMin, double yMax, int nPointsPerAxis) {
//        //generate all the x,y points
//        double[][] evalPoints = new double[nPointsPerAxis * nPointsPerAxis][2];
//        int count = 0;
//        for (int i = 0; i < nPointsPerAxis; i++) {
//            for (int j = 0; j < nPointsPerAxis; j++) {
//                double x = i * (xMax - xMin) / (nPointsPerAxis - 1) + xMin;
//                double y = j * (yMax - yMin) / (nPointsPerAxis - 1) + yMin;
//
//                evalPoints[count][0] = x;
//                evalPoints[count][1] = y;
//
//                count++;
//            }
//        }
//
//        return Nd4j.create(evalPoints);
//    }
//
//    /**
//     * This is to collect all the data and return it as one minibatch. Obviously only for use here with small datasets
//     * @param iter
//     * @return
//     */
//    private static DataSet allBatches(DataSetIterator iter) {
//
//        List<DataSet> fullSet = new ArrayList<>();
//        iter.reset();
//        while (iter.hasNext()) {
//            List<DataSet> miniBatchList = iter.next().asList();
//            fullSet.addAll(miniBatchList);
//        }
//        iter.reset();
//        return new ListDataSetIterator<>(fullSet, fullSet.size()).next();
//    }
//
//}
@@ -0,0 +1,76 @@
//package eu.dnetlib.example
//
//import com.intel.analytics.bigdl.dllib.NNContext
//import com.intel.analytics.bigdl.dllib.keras.Model
//import com.intel.analytics.bigdl.dllib.keras.models.Models
//import com.intel.analytics.bigdl.dllib.keras.optimizers.Adam
//import com.intel.analytics.bigdl.dllib.nn.ClassNLLCriterion
//import com.intel.analytics.bigdl.dllib.utils.Shape
//import com.intel.analytics.bigdl.dllib.keras.layers._
//import com.intel.analytics.bigdl.numeric.NumericFloat
//import org.apache.spark.ml.feature.VectorAssembler
//import org.apache.spark._
//import org.apache.spark.sql.{SQLContext, SparkSession}
//import org.apache.spark.sql.functions._
//import org.apache.spark.sql.types.DoubleType
//object Example {
//
//
//  def main(args: Array[String]): Unit = {
//
//    val conf = new SparkConf().setMaster("local[2]").setAppName("dllib_demo")
//    val sc = NNContext.initNNContext(conf)
//
////    val spark = new SQLContext(sc) //deprecated
//    val spark = SparkSession
//      .builder()
//      .config(sc.getConf)
//      .getOrCreate()
//
//    val path = "/Users/miconis/Desktop/example_dataset.csv"
//    val df = spark.read.options(Map("inferSchema"->"true","delimiter"->",")).csv(path)
//      .toDF("num_times_pregrant", "plasma_glucose", "blood_pressure", "skin_fold_thickness", "2-hour_insulin", "body_mass_index", "diabetes_pedigree_function", "age", "class")
//
//    val assembler = new VectorAssembler()
//      .setInputCols(Array("num_times_pregrant", "plasma_glucose", "blood_pressure", "skin_fold_thickness", "2-hour_insulin", "body_mass_index", "diabetes_pedigree_function", "age"))
//      .setOutputCol("features")
//    val assembleredDF = assembler.transform(df)
//    val df2 = assembleredDF.withColumn("label", col("class").cast(DoubleType) + lit(1))
//
//    val Array(trainDF, valDF) = df2.randomSplit(Array(0.8, 0.2))
//
//    val x1 = Input(Shape(8))
//    val merge = Merge.merge(inputs = List(x1, x1), mode = "dot")
//    val dense1 = Dense(12, activation="relu").inputs(x1)
//    val dense2 = Dense(8, activation="relu").inputs(dense1)
//    val dense3 = Dense(2, activation="relu").inputs(dense2)
//    val dmodel = Model(x1, dense3)
//
//    dmodel.compile(optimizer = new Adam(), loss = ClassNLLCriterion())
//
//    //training
//    dmodel.fit(x = trainDF, batchSize = 4, nbEpoch = 2, featureCols = Array("features"), labelCols = Array("label"), valX = valDF)
//
////    //save model
////    val modelPath = "/tmp/demo/keras.model"
////    dmodel.saveModel(modelPath)
////
////    //load model
////    val loadModel = Models.loadModel(modelPath)
//    val loadModel = dmodel
//
//    //inference
//    val preDF2 = loadModel.predict(valDF, featureCols = Array("features"), predictionCol = "predict")
//
//    preDF2.show(false)
//
//    //evaluation
//    val ret = dmodel.evaluate(trainDF, batchSize = 4, featureCols = Array("features"), labelCols = Array("label"))
//
//    ret.foreach(println)
//
//  }
//}
@@ -1,5 +1,11 @@
 package eu.dnetlib.featureextraction;

+import com.google.common.collect.Lists;
+import com.johnsnowlabs.nlp.*;
+import com.johnsnowlabs.nlp.annotators.sbd.pragmatic.SentenceDetector;
+import com.johnsnowlabs.nlp.embeddings.BertSentenceEmbeddings;
+import org.apache.spark.ml.Pipeline;
+import org.apache.spark.ml.PipelineStage;
 import org.apache.spark.ml.clustering.LDA;
 import org.apache.spark.ml.clustering.LDAModel;
 import org.apache.spark.ml.feature.CountVectorizer;
@@ -8,16 +14,20 @@ import org.apache.spark.ml.feature.StopWordsRemover;
 import org.apache.spark.ml.feature.Tokenizer;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
 import scala.Tuple2;

 import java.io.BufferedReader;
 import java.io.IOException;
 import java.io.InputStreamReader;
 import java.io.Serializable;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.Map;
-import java.util.Set;
+import java.net.URISyntaxException;
+import java.nio.file.Paths;
+import java.util.*;

 public class FeatureTransformer implements Serializable {
@@ -161,4 +171,5 @@ public class FeatureTransformer implements Serializable {
     public static Dataset<Row> ldaInference(Dataset<Row> inputDS, LDAModel ldaModel) {
         return ldaModel.transform(inputDS).select(ID_COL, LDA_INFERENCE_OUTPUT_COL);
     }

 }
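For context, the unchanged LDA entry point above is typically driven as follows (a sketch; the model path and the pre-vectorized input are assumptions, not project code):

// Sketch: score vectorized documents with a previously trained Spark ML LDA model.
import org.apache.spark.ml.clustering.LocalLDAModel
import eu.dnetlib.featureextraction.FeatureTransformer

val ldaModel = LocalLDAModel.load("/path/to/lda.model") // placeholder path
val topics = FeatureTransformer.ldaInference(vectorizedDocs, ldaModel) // vectorizedDocs assumed prepared upstream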
@@ -0,0 +1,157 @@
package eu.dnetlib.featureextraction

import com.johnsnowlabs.nlp.EmbeddingsFinisher
import com.johnsnowlabs.nlp.annotator.SentenceDetector
import com.johnsnowlabs.nlp.annotators.Tokenizer
import com.johnsnowlabs.nlp.base.DocumentAssembler
import com.johnsnowlabs.nlp.embeddings.{BertEmbeddings, BertSentenceEmbeddings, WordEmbeddingsModel}
import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.functions.{array, col, explode}
import org.apache.spark.sql.{Dataset, Row}

import java.nio.file.Paths

object ScalaFeatureTransformer {

  val DOCUMENT_COL = "document"
  val SENTENCE_COL = "sentence"
  val BERT_SENTENCE_EMBEDDINGS_COL = "bert_sentence"
  val BERT_EMBEDDINGS_COL = "bert"
  val TOKENIZER_COL = "tokens"
  val WORD_EMBEDDINGS_COL = "word"

  //models path
  private val bertSentenceModelPath = Paths.get(getClass.getResource("/eu/dnetlib/featureextraction/support/sent_small_bert_L6_512_en_2.6.0_2.4_1598350624049").toURI).toFile.getAbsolutePath
  private val bertModelPath = Paths.get(getClass.getResource("/eu/dnetlib/featureextraction/support/small_bert_L2_128_en_2.6.0_2.4_1598344320681").toURI).toFile.getAbsolutePath
  private val wordModelPath = Paths.get(getClass.getResource("/eu/dnetlib/featureextraction/support/glove_100d_en_2.4.0_2.4_1579690104032").toURI).toFile.getAbsolutePath

  /**
   * Extract the SentenceBERT embeddings for the given field.
   *
   * @param inputData: the input data
   * @param inputField: the input field
   * @return the dataset with the embeddings
   * */
  def bertSentenceEmbeddings(inputData: Dataset[Row], inputField: String, modelPath: String): Dataset[Row] = {

    val documentAssembler = new DocumentAssembler()
      .setInputCol(inputField)
      .setOutputCol(DOCUMENT_COL)

    val sentence = new SentenceDetector()
      .setInputCols(DOCUMENT_COL)
      .setOutputCol(SENTENCE_COL)

    val bertSentenceEmbeddings = BertSentenceEmbeddings
      .load(modelPath)
      .setInputCols(SENTENCE_COL)
      .setOutputCol("raw_" + BERT_SENTENCE_EMBEDDINGS_COL)
      .setCaseSensitive(false)

    val bertSentenceEmbeddingsFinisher = new EmbeddingsFinisher()
      .setInputCols("raw_" + BERT_SENTENCE_EMBEDDINGS_COL)
      .setOutputCols(BERT_SENTENCE_EMBEDDINGS_COL)
      .setOutputAsVector(true)
      .setCleanAnnotations(false)

    val pipeline = new Pipeline()
      .setStages(Array(
        documentAssembler,
        sentence,
        bertSentenceEmbeddings,
        bertSentenceEmbeddingsFinisher
      ))

    val result = pipeline.fit(inputData).transform(inputData).withColumn(BERT_SENTENCE_EMBEDDINGS_COL, explode(col(BERT_SENTENCE_EMBEDDINGS_COL)))

    result
  }

  /**
   * Extract the BERT embeddings for the given field.
   *
   * @param inputData : the input data
   * @param inputField : the input field
   * @return the dataset with the embeddings
   * */
  def bertEmbeddings(inputData: Dataset[Row], inputField: String, modelPath: String): Dataset[Row] = {

    val documentAssembler = new DocumentAssembler()
      .setInputCol(inputField)
      .setOutputCol(DOCUMENT_COL)

    val tokenizer = new Tokenizer()
      .setInputCols(DOCUMENT_COL)
      .setOutputCol(TOKENIZER_COL)

    val bertEmbeddings = BertEmbeddings
      .load(modelPath)
      .setInputCols(TOKENIZER_COL, DOCUMENT_COL)
      .setOutputCol("raw_" + BERT_EMBEDDINGS_COL)
      .setCaseSensitive(false)

    val bertEmbeddingsFinisher = new EmbeddingsFinisher()
      .setInputCols("raw_" + BERT_EMBEDDINGS_COL)
      .setOutputCols(BERT_EMBEDDINGS_COL)
      .setOutputAsVector(true)
      .setCleanAnnotations(false)

    val pipeline = new Pipeline()
      .setStages(Array(
        documentAssembler,
        tokenizer,
        bertEmbeddings,
        bertEmbeddingsFinisher
      ))

    val result = pipeline.fit(inputData).transform(inputData).withColumn(BERT_EMBEDDINGS_COL, explode(col(BERT_EMBEDDINGS_COL)))

    result
  }

  /**
   * Extract the Word2Vec embeddings for the given field.
   *
   * @param inputData : the input data
   * @param inputField : the input field
   * @return the dataset with the embeddings
   * */
  def wordEmbeddings(inputData: Dataset[Row], inputField: String, modelPath: String): Dataset[Row] = {

    val documentAssembler = new DocumentAssembler()
      .setInputCol(inputField)
      .setOutputCol(DOCUMENT_COL)

    val tokenizer = new Tokenizer()
      .setInputCols(DOCUMENT_COL)
      .setOutputCol(TOKENIZER_COL)

    val wordEmbeddings = WordEmbeddingsModel
      .load(modelPath)
      .setInputCols(DOCUMENT_COL, TOKENIZER_COL)
      .setOutputCol("raw_" + WORD_EMBEDDINGS_COL)

    val wordEmbeddingsFinisher = new EmbeddingsFinisher()
      .setInputCols("raw_" + WORD_EMBEDDINGS_COL)
      .setOutputCols(WORD_EMBEDDINGS_COL)
      .setOutputAsVector(true)
      .setCleanAnnotations(false)

    val pipeline = new Pipeline()
      .setStages(Array(
        documentAssembler,
        tokenizer,
        wordEmbeddings,
        wordEmbeddingsFinisher
      ))

    val result = pipeline.fit(inputData).transform(inputData).withColumn(WORD_EMBEDDINGS_COL, explode(col(WORD_EMBEDDINGS_COL)))

    result
  }

  //bert: on the title
  //bert sentence: on the abstract
  //word2vec: on the subjects

}
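A minimal sketch of how the new object can be driven end to end (the column name, toy rows, and model location are assumptions, not taken from the project):

import org.apache.spark.sql.SparkSession
import eu.dnetlib.featureextraction.ScalaFeatureTransformer

val spark = SparkSession.builder().appName("feature-extraction-demo").master("local[*]").getOrCreate()
import spark.implicits._

// toy input with a plain-text column; real input would come from the publications dump
val docs = Seq("first abstract text", "second abstract text").toDF("abstract")

// modelPath must point at a pre-downloaded Spark NLP BERT sentence model
val withEmbeddings = ScalaFeatureTransformer.bertSentenceEmbeddings(docs, "abstract", "/path/to/sent_bert_model")

// one exploded embedding vector per sentence, in the "bert_sentence" column
withEmbeddings.select("bert_sentence").show(false)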
@@ -5,6 +5,7 @@ import org.codehaus.jackson.annotate.JsonIgnore;
 import java.io.Serializable;
 import java.util.List;
+import java.util.Map;

 public class Author implements Serializable {
@@ -12,24 +13,32 @@ public class Author implements Serializable {
     public String firstname;
     public String lastname;
     public List<CoAuthor> coAuthors;
-    public double[] topics;
     public String orcid;
     public String id;
+    public Map<String, double[]> embeddings;

-    public String pubId;
+    public Map<String, double[]> getEmbeddings() {
+        return embeddings;
+    }

-    public Author() {
-    }

-    public Author(String fullname, String firstname, String lastname, List<CoAuthor> coAuthors, double[] topics, String id, String pubId, String orcid) {
+    public Author(String fullname, String firstname, String lastname, List<CoAuthor> coAuthors, String orcid, String id, Map<String, double[]> embeddings, String pubId) {
         this.fullname = fullname;
         this.firstname = firstname;
         this.lastname = lastname;
         this.coAuthors = coAuthors;
-        this.topics = topics;
-        this.id = id;
-        this.pubId = pubId;
         this.orcid = orcid;
+        this.id = id;
+        this.embeddings = embeddings;
+        this.pubId = pubId;
+    }
+
+    public void setEmbeddings(Map<String, double[]> embeddings) {
+        this.embeddings = embeddings;
+    }
+
+    public String pubId;
+
+    public Author() {
     }

     public String getFullname() {
@@ -64,14 +73,6 @@ public class Author implements Serializable {
         this.coAuthors = coAuthors;
     }

-    public double[] getTopics() {
-        return topics;
-    }
-
-    public void setTopics(double[] topics) {
-        this.topics = topics;
-    }
-
     public String getId() {
         return id;
     }
@@ -14,34 +14,32 @@ import javax.rmi.CORBA.Util;
 import java.math.BigInteger;
 import java.security.MessageDigest;
 import java.security.NoSuchAlgorithmException;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Iterator;
-import java.util.List;
+import java.util.*;
 import java.util.stream.Collectors;

 public class AuthorsFactory {

-    public static JavaRDD<Author> extractAuthorsFromPublications(JavaRDD<String> entities, JavaPairRDD<String, DenseVector> topics) {
+    public static JavaRDD<Author> extractAuthorsFromPublications(JavaRDD<Publication> publications, JavaPairRDD<String, Map<String, double[]>> topics) {

-        JavaPairRDD<Publication, DenseVector> publicationWithTopics = entities.map(x -> new ObjectMapper().configure(DeserializationConfig.Feature.FAIL_ON_UNKNOWN_PROPERTIES, false).readValue(x, Publication.class))
+        //read topics
+        JavaPairRDD<Publication, Map<String, double[]>> publicationWithEmbeddings = publications
                 .mapToPair(p -> new Tuple2<>(p.getId(), p))
                 .join(topics)
                 .mapToPair(Tuple2::_2);

-        return publicationWithTopics.flatMap(p -> createAuthors(p));
+        return publicationWithEmbeddings.flatMap(AuthorsFactory::createAuthors);
     }

-    public static Iterator<Author> createAuthors(Tuple2<Publication, DenseVector> publicationWithTopic){
-        List<CoAuthor> baseCoAuthors = publicationWithTopic._1()
+    public static Iterator<Author> createAuthors(Tuple2<Publication, Map<String, double[]>> publicationWithEmbeddings){
+        List<CoAuthor> baseCoAuthors = publicationWithEmbeddings._1()
                 .getAuthor()
                 .stream()
                 .map(a -> new CoAuthor(a.getFullname(), a.getName()!=null?a.getName():"", a.getSurname()!=null?a.getSurname():"", a.getPid().size()>0? a.getPid().get(0).getValue():""))
                 .collect(Collectors.toList());

         List<Author> authors = new ArrayList<>();
-        for(eu.dnetlib.dhp.schema.oaf.Author a : publicationWithTopic._1().getAuthor()) {
+        for(eu.dnetlib.dhp.schema.oaf.Author a : publicationWithEmbeddings._1().getAuthor()) {

             //prepare orcid
             String orcid = a.getPid().size()>0? a.getPid().get(0).getValue() : "";
@@ -50,9 +48,19 @@ public class AuthorsFactory {
             coAuthors.remove(new CoAuthor(a.getFullname(), a.getName() != null ? a.getName() : "", a.getSurname() != null ? a.getSurname() : "", a.getPid().size() > 0 ? a.getPid().get(0).getValue() : ""));

             //prepare raw author id
-            String id = "author::" + getMd5(a.getFullname().concat(publicationWithTopic._1().getId()));
+            String id = "author::" + getMd5(a.getFullname().concat(publicationWithEmbeddings._1().getId()));

-            authors.add(new Author(a.getFullname(), a.getName(), a.getSurname(), coAuthors, publicationWithTopic._2().toArray(), id, publicationWithTopic._1().getId(), orcid));
+            //prepare embeddings
+            authors.add(new Author(
+                    a.getFullname(),
+                    a.getName(),
+                    a.getSurname(),
+                    coAuthors,
+                    orcid,
+                    id,
+                    publicationWithEmbeddings._2(),
+                    publicationWithEmbeddings._1().getId())
+            );
         }

         return authors.iterator();
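Roughly, the refactored factory is now wired like this (a sketch; loading of the two inputs is assumed to happen elsewhere in the pipeline):

// publications: JavaRDD[Publication] already parsed from JSON;
// embeddings: JavaPairRDD[String, java.util.Map[String, Array[Double]]] keyed by publication id.
val authors = AuthorsFactory.extractAuthorsFromPublications(publications, embeddings)
authors.saveAsTextFile("/path/to/authors_output") // placeholder output path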
@@ -0,0 +1,70 @@
package eu.dnetlib.support;

import com.google.common.collect.Sets;
import org.codehaus.jackson.map.ObjectMapper;

import java.io.IOException;
import java.io.Serializable;
import java.util.HashSet;
import java.util.Set;

public class ConnectedComponent implements Serializable {

    private HashSet<String> docs;
    private String ccId;
    private HashSet<Relation> simrels;

    public ConnectedComponent() {
    }

    public ConnectedComponent(String ccId, Set<String> docs, Set<Relation> simrels) {
        this.docs = new HashSet<>(docs);
        this.ccId = ccId;
        this.simrels = new HashSet<>(simrels);
    }

    public ConnectedComponent(Set<String> docs) {
        this.docs = new HashSet<>(docs);
        //initialization of id and relations missing
    }

    public ConnectedComponent(String ccId, Iterable<String> docs, Iterable<Relation> simrels) {
        this.ccId = ccId;
        this.docs = Sets.newHashSet(docs);
        this.simrels = Sets.newHashSet(simrels);
    }

    @Override
    public String toString() {
        ObjectMapper mapper = new ObjectMapper();
        try {
            return mapper.writeValueAsString(this);
        } catch (IOException e) {
            throw new RuntimeException("Failed to create Json: ", e);
        }
    }

    public Set<String> getDocs() {
        return docs;
    }

    public void setDocs(HashSet<String> docs) {
        this.docs = docs;
    }

    public String getCcId() {
        return ccId;
    }

    public void setCcId(String ccId) {
        this.ccId = ccId;
    }

    public void setSimrels(HashSet<Relation> simrels) {
        this.simrels = simrels;
    }

    public HashSet<Relation> getSimrels() {
        return simrels;
    }
}
@@ -0,0 +1,52 @@
package eu.dnetlib.support;

import java.io.Serializable;

public class Relation implements Serializable {

    String source;
    String target;
    String type;

    public Relation() {
    }

    public Relation(String source, String target, String type) {
        this.source = source;
        this.target = target;
        this.type = type;
    }

    public String getSource() {
        return source;
    }

    public void setSource(String source) {
        this.source = source;
    }

    public String getTarget() {
        return target;
    }

    public void setTarget(String target) {
        this.target = target;
    }

    public String getType() {
        return type;
    }

    public void setType(String type) {
        this.type = type;
    }

    @Override
    public String toString() {
        return "Relation{" +
                "source='" + source + '\'' +
                ", target='" + target + '\'' +
                ", type='" + type + '\'' +
                '}';
    }
}
@@ -7,6 +7,7 @@ import org.junit.jupiter.api.Test;
 import java.lang.annotation.Target;
 import java.util.ArrayList;
+import java.util.HashMap;

 public class UtilityTest {
@@ -24,7 +25,7 @@ public class UtilityTest {
     @Test
     public void lnfiTest() throws Exception {
-        Author a = new Author("De Bonis, Michele", "Æ", "De Bonis", new ArrayList<CoAuthor>(), new double[]{0.0, 1.0}, "author::id", "pub::id", "orcid");
+        Author a = new Author("De Bonis, Michele", "Æ", "De Bonis", new ArrayList<CoAuthor>(), "orcid", "author::id", new HashMap<String, double[]>(), "pub::id");
         System.out.println("a = " + a.isAccurate());
         System.out.println(AuthorsFactory.getLNFI(a));
     }
@@ -1,47 +0,0 @@
package eu.dnetlib.deeplearning;

import com.beust.jcommander.internal.Sets;
import com.google.common.collect.Lists;
import eu.dnetlib.deeplearning.support.DataSetProcessor;
import eu.dnetlib.support.Relation;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.dataset.MultiDataSet;

import java.util.*;
import java.util.stream.Collectors;

public class DataSetProcessorTest {

    static Map<String, double[]> features;
    static Set<Relation> relations;
    static List<String> groundTruth;

    @BeforeAll
    public static void init(){
        //initialize example features
        features = new HashMap<>();
        features.put("0", new double[]{0.0,0.0});
        features.put("1", new double[]{1.0,1.0});
        features.put("2", new double[]{2.0,2.0});

        //initialize example relations
        relations = new HashSet<>(Lists.newArrayList(
                new Relation("0", "1", "simrel"),
                new Relation("1", "2", "simrel")
        ));

        //initialize example ground truth
        groundTruth = Lists.newArrayList("class1", "class1", "class2");

    }

    @Test
    public void getMultiDataSetTest() throws Exception {
        MultiDataSet multiDataSet = DataSetProcessor.getMultiDataSet(features, relations, groundTruth);
        System.out.println("multiDataSet = " + multiDataSet);

        multiDataSet.asList();
    }

}
@@ -1,33 +0,0 @@
package eu.dnetlib.deeplearning;

import eu.dnetlib.deeplearning.support.NetworkConfigurations;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class NetworkConfigurationTests {

    public final static int N = 3; //number of nodes
    public final static int K = 7; //number of features

    public static INDArray[] exampleGraph = new INDArray[]{
            Nd4j.zeros(N, K), //features
            Nd4j.ones(N, N),  //adjacency
            Nd4j.ones(N, N)   //degree
    };

    @Test
    public void simpleGCNTest() {

        ComputationGraphConfiguration simpleGCNConf = NetworkConfigurations.getSimpleGCN(3, K, 5, 2);
        ComputationGraph simpleGCN = new ComputationGraph(simpleGCNConf);
        simpleGCN.init();

        INDArray[] output = simpleGCN.output(exampleGraph);
        System.out.println("output = " + output[0]);

    }

}
@@ -0,0 +1,53 @@
+package eu.dnetlib.deeplearning.featureextraction;
+
+import eu.dnetlib.featureextraction.ScalaFeatureTransformer;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.linalg.DenseVector;
+import org.apache.spark.ml.linalg.DenseVector$;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import scala.collection.JavaConversions;
+import scala.collection.mutable.WrappedArray;
+
+import java.io.IOException;
+import java.util.Arrays;
+
+public class FeatureTransformerTest {
+
+    static SparkSession spark;
+    static JavaSparkContext context;
+    static Dataset<Row> inputData;
+    static StructType inputSchema = new StructType(new StructField[]{
+            new StructField("title", DataTypes.StringType, false, Metadata.empty()),
+            new StructField("abstract", DataTypes.StringType, false, Metadata.empty())
+    });
+
+    @BeforeAll
+    public static void setup() throws IOException {
+        spark = SparkSession
+                .builder()
+                .appName("Testing")
+                .master("local[*]")
+                .getOrCreate();
+
+        context = JavaSparkContext.fromSparkContext(spark.sparkContext());
+
+        inputData = spark.createDataFrame(Arrays.asList(
+                RowFactory.create("article title 1", "article description 1"),
+                RowFactory.create("article title 2", "article description 2")
+        ), inputSchema);
+    }
+}
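The new test class sets up the Spark fixture but does not yet contain a test method. A minimal body one might add, checking only the fixture itself (ScalaFeatureTransformer's method signatures are not visible in this commit), could look like the following sketch.

// Hypothetical test method to slot into FeatureTransformerTest above; it
// verifies only the Dataset built in setup() and assumes nothing about
// ScalaFeatureTransformer. Requires: import org.junit.jupiter.api.Assertions;
@Test
public void fixtureIsWellFormed() {
    // two example rows with the two declared columns
    Assertions.assertEquals(2, inputData.count());
    Assertions.assertArrayEquals(new String[]{"title", "abstract"}, inputData.columns());
}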
pom.xml
@@ -94,6 +94,8 @@
 <enabled>false</enabled>
 </snapshots>
 </repository>
+
+
 </repositories>
 <build>
 <directory>target</directory>
@@ -305,31 +307,31 @@
 <groupId>com.fasterxml.jackson.core</groupId>
 <artifactId>jackson-databind</artifactId>
 <version>${jackson.version}</version>
-<scope>provided</scope>
+<!-- <scope>provided</scope>-->
 </dependency>
 <dependency>
 <groupId>com.fasterxml.jackson.dataformat</groupId>
 <artifactId>jackson-dataformat-xml</artifactId>
 <version>${jackson.version}</version>
-<scope>provided</scope>
+<!-- <scope>provided</scope>-->
 </dependency>
 <dependency>
 <groupId>com.fasterxml.jackson.module</groupId>
 <artifactId>jackson-module-jsonSchema</artifactId>
 <version>${jackson.version}</version>
-<scope>provided</scope>
+<!-- <scope>provided</scope>-->
 </dependency>
 <dependency>
 <groupId>com.fasterxml.jackson.core</groupId>
 <artifactId>jackson-core</artifactId>
 <version>${jackson.version}</version>
-<scope>provided</scope>
+<!-- <scope>provided</scope>-->
 </dependency>
 <dependency>
 <groupId>com.fasterxml.jackson.core</groupId>
 <artifactId>jackson-annotations</artifactId>
 <version>${jackson.version}</version>
-<scope>provided</scope>
+<!-- <scope>provided</scope>-->
 </dependency>

 <dependency>
@@ -388,25 +390,25 @@
 <groupId>org.apache.spark</groupId>
 <artifactId>spark-core_2.11</artifactId>
 <version>${spark.version}</version>
-<scope>provided</scope>
+<!-- <scope>provided</scope>-->
 </dependency>
 <dependency>
 <groupId>org.apache.spark</groupId>
 <artifactId>spark-graphx_2.11</artifactId>
 <version>${spark.version}</version>
-<scope>provided</scope>
+<!-- <scope>provided</scope>-->
 </dependency>
 <dependency>
 <groupId>org.apache.spark</groupId>
 <artifactId>spark-sql_2.11</artifactId>
 <version>${spark.version}</version>
-<scope>provided</scope>
+<!-- <scope>provided</scope>-->
 </dependency>
 <dependency>
 <groupId>org.apache.spark</groupId>
 <artifactId>spark-mllib_2.11</artifactId>
 <version>${spark.version}</version>
-<scope>provided</scope>
+<!-- <scope>provided</scope>-->
 </dependency>
 <dependency>
 <groupId>org.junit.jupiter</groupId>
@@ -451,92 +453,79 @@
 </dependency>

 <!--DEEPLEARNING4J-->
-<dependency>
-<groupId>org.nd4j</groupId>
-<artifactId>${nd4j.backend}</artifactId>
-<version>${dl4j-master.version}</version>
-</dependency>
+<!-- <dependency>-->
+<!-- <groupId>org.nd4j</groupId>-->
+<!-- <artifactId>${nd4j.backend}</artifactId>-->
+<!-- <version>${dl4j-master.version}</version>-->
+<!-- </dependency>-->

+<!-- <dependency>-->
+<!-- <groupId>org.datavec</groupId>-->
+<!-- <artifactId>datavec-api</artifactId>-->
+<!-- <version>${dl4j-master.version}</version>-->
+<!-- </dependency>-->
+<!-- <dependency>-->
+<!-- <groupId>org.datavec</groupId>-->
+<!-- <artifactId>datavec-data-image</artifactId>-->
+<!-- <version>${dl4j-master.version}</version>-->
+<!-- </dependency>-->
+<!-- <dependency>-->
+<!-- <groupId>org.datavec</groupId>-->
+<!-- <artifactId>datavec-local</artifactId>-->
+<!-- <version>${dl4j-master.version}</version>-->
+<!-- </dependency>-->
+<!-- <dependency>-->
+<!-- <groupId>org.deeplearning4j</groupId>-->
+<!-- <artifactId>deeplearning4j-datasets</artifactId>-->
+<!-- <version>${dl4j-master.version}</version>-->
+<!-- </dependency>-->
+<!-- <dependency>-->
+<!-- <groupId>org.deeplearning4j</groupId>-->
+<!-- <artifactId>deeplearning4j-core</artifactId>-->
+<!-- <version>${dl4j-master.version}</version>-->
+<!-- </dependency>-->
+
+<!-- <dependency>-->
+<!-- <groupId>org.deeplearning4j</groupId>-->
+<!-- <artifactId>resources</artifactId>-->
+<!-- <version>${dl4j-master.version}</version>-->
+<!-- </dependency>-->
+
+<!-- <dependency>-->
+<!-- <groupId>org.deeplearning4j</groupId>-->
+<!-- <artifactId>deeplearning4j-ui</artifactId>-->
+<!-- <version>${dl4j-master.version}</version>-->
+<!-- </dependency>-->
+<!-- <dependency>-->
+<!-- <groupId>org.deeplearning4j</groupId>-->
+<!-- <artifactId>deeplearning4j-zoo</artifactId>-->
+<!-- <version>${dl4j-master.version}</version>-->
+<!-- </dependency>-->
+<!-- <dependency>-->
+<!-- <groupId>org.deeplearning4j</groupId>-->
+<!-- <artifactId>dl4j-spark-parameterserver_2.11</artifactId>-->
+<!-- <version>${dl4j-master.version}</version>-->
+<!-- </dependency>-->
+<!-- <dependency>-->
+<!-- <groupId>org.deeplearning4j</groupId>-->
+<!-- <artifactId>dl4j-spark_2.11</artifactId>-->
+<!-- <version>${dl4j-master.version}</version>-->
+<!-- </dependency>-->

 <dependency>
-<groupId>org.datavec</groupId>
-<artifactId>datavec-api</artifactId>
-<version>${dl4j-master.version}</version>
-</dependency>
-<dependency>
-<groupId>org.datavec</groupId>
-<artifactId>datavec-data-image</artifactId>
-<version>${dl4j-master.version}</version>
-</dependency>
-<dependency>
-<groupId>org.datavec</groupId>
-<artifactId>datavec-local</artifactId>
-<version>${dl4j-master.version}</version>
-</dependency>
-<dependency>
-<groupId>org.deeplearning4j</groupId>
-<artifactId>deeplearning4j-datasets</artifactId>
-<version>${dl4j-master.version}</version>
-</dependency>
-<dependency>
-<groupId>org.deeplearning4j</groupId>
-<artifactId>deeplearning4j-core</artifactId>
-<version>${dl4j-master.version}</version>
+<groupId>com.johnsnowlabs.nlp</groupId>
+<artifactId>spark-nlp_${scala.binary.version}</artifactId>
+<version>2.7.5</version>
 </dependency>

-<dependency>
-<groupId>org.deeplearning4j</groupId>
-<artifactId>resources</artifactId>
-<version>${dl4j-master.version}</version>
-</dependency>
+<!-- <dependency>-->
+<!-- <groupId>org.json4s</groupId>-->
+<!-- <artifactId>json4s-jackson_${scala.binary.version}</artifactId>-->
+<!-- <version>3.5.3</version>-->
+<!-- </dependency>-->

-<dependency>
-<groupId>org.deeplearning4j</groupId>
-<artifactId>deeplearning4j-ui</artifactId>
-<version>${dl4j-master.version}</version>
-</dependency>
-<dependency>
-<groupId>org.deeplearning4j</groupId>
-<artifactId>deeplearning4j-zoo</artifactId>
-<version>${dl4j-master.version}</version>
-</dependency>
-<dependency>
-<groupId>org.deeplearning4j</groupId>
-<artifactId>dl4j-spark-parameterserver_2.11</artifactId>
-<version>${dl4j-master.version}</version>
-</dependency>
-<dependency>
-<groupId>org.deeplearning4j</groupId>
-<artifactId>dl4j-spark_2.11</artifactId>
-<version>${dl4j-master.version}</version>
-</dependency>
-
-<!--PLOT-->
-<dependency>
-<groupId>jfree</groupId>
-<artifactId>jfreechart</artifactId>
-<version>1.0.13</version>
-</dependency>
-<dependency>
-<groupId>org.jfree</groupId>
-<artifactId>jcommon</artifactId>
-<version>1.0.23</version>
-</dependency>
-<dependency>
-<groupId>org.deeplearning4j</groupId>
-<artifactId>deeplearning4j-datasets</artifactId>
-<version>${dl4j-master.version}</version>
-</dependency>
-
-<!--DNET DEDUP-->
-<dependency>
-<groupId>eu.dnetlib</groupId>
-<artifactId>dnet-dedup-test</artifactId>
-<version>4.1.13-SNAPSHOT</version>
-</dependency>

 </dependencies>

 </dependencyManagement>

 <profiles>
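The dependency swap above replaces the DL4J stack with Spark NLP 2.7.5. As a rough sketch of how that library is typically wired into a feature-extraction pipeline (the column names and model path below are illustrative placeholders, not taken from this repository):

// Sketch only: a Spark NLP document -> BERT sentence-embedding pipeline.
// "abstract", "sentence_embeddings" and modelPath are assumed names.
import com.johnsnowlabs.nlp.DocumentAssembler;
import com.johnsnowlabs.nlp.embeddings.BertSentenceEmbeddings;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineStage;

public class SparkNlpPipelineSketch {
    public static Pipeline bertSentencePipeline(String modelPath) {
        // turns a raw text column into Spark NLP's document annotation
        DocumentAssembler document = new DocumentAssembler();
        document.setInputCol("abstract");
        document.setOutputCol("document");

        // loads a locally saved BertSentenceEmbeddings model from modelPath
        BertSentenceEmbeddings embeddings = BertSentenceEmbeddings.load(modelPath);
        embeddings.setInputCols(new String[]{"document"});
        embeddings.setOutputCol("sentence_embeddings");

        return new Pipeline().setStages(new PipelineStage[]{document, embeddings});
    }
}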