package eu.dnetlib.jobs.featureextraction;

import eu.dnetlib.dhp.schema.oaf.Publication;
import eu.dnetlib.jobs.AbstractSparkJob;
import eu.dnetlib.support.ArgumentApplicationParser;
import eu.dnetlib.support.Author;
import eu.dnetlib.support.AuthorsFactory;
import org.apache.hadoop.io.compress.GzipCodec;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.sql.SparkSession;
import org.codehaus.jackson.map.DeserializationConfig;
import org.codehaus.jackson.map.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

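/**
 * Spark job that prepares author records for feature-based processing: it joins each
 * publication with its LDA topic distribution and its precomputed embeddings (word,
 * BERT and BERT sentence embeddings), extracts the authors via {@link AuthorsFactory},
 * and saves them as gzip-compressed JSON, one author per line.
 */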
public class SparkAuthorExtractor extends AbstractSparkJob {

    private static final Logger log = LoggerFactory.getLogger(SparkAuthorExtractor.class);

    public SparkAuthorExtractor(ArgumentApplicationParser parser, SparkSession spark) {
        super(parser, spark);
    }

    public static void main(String[] args) throws Exception {

        ArgumentApplicationParser parser = new ArgumentApplicationParser(
                readResource("/jobs/parameters/authorExtractor_parameters.json", SparkAuthorExtractor.class)
        );

        parser.parseArgument(args);

        SparkConf conf = new SparkConf();

        new SparkAuthorExtractor(
                parser,
                getSparkSession(conf)
        ).run();
    }

    @Override
    public void run() throws IOException {

        // read oozie parameters
        final String topicsPath = parser.get("topicsPath");
        final String featuresPath = parser.get("featuresPath");
        final String publicationsPath = parser.get("publicationsPath");
        final String workingPath = parser.get("workingPath");
        final String outputPath = parser.get("outputPath");
        final int numPartitions = Optional
                .ofNullable(parser.get("numPartitions"))
                .map(Integer::valueOf)
                .orElse(NUM_PARTITIONS);

        log.info("publicationsPath: '{}'", publicationsPath);
        log.info("topicsPath: '{}'", topicsPath);
        log.info("featuresPath: '{}'", featuresPath);
        log.info("workingPath: '{}'", workingPath);
        log.info("outputPath: '{}'", outputPath);
        log.info("numPartitions: '{}'", numPartitions);

        // join publications with topics
        JavaSparkContext context = JavaSparkContext.fromSparkContext(spark.sparkContext());

        // parse publications from JSON text, ignoring fields unknown to the Publication schema;
        // the ObjectMapper is created inside the lambda because it is not serializable and so
        // cannot be captured by the Spark closure
        JavaRDD<Publication> publications = context
                .textFile(publicationsPath)
                .map(x -> new ObjectMapper()
                        .configure(DeserializationConfig.Feature.FAIL_ON_UNKNOWN_PROPERTIES, false)
                        .readValue(x, Publication.class));

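        // each row of the topics dataset is assumed to hold the publication id (column 0)
        // and its LDA topic distribution as a DenseVector (column 1)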
        JavaPairRDD<String, double[]> topics = spark.read().load(topicsPath).toJavaRDD()
                .mapToPair(t -> new Tuple2<>(t.getString(0), ((DenseVector) t.get(1)).toArray()));

        // merge topics with the other embeddings
        JavaPairRDD<String, Map<String, double[]>> publicationEmbeddings = spark.read().load(featuresPath).toJavaRDD()
                .mapToPair(t -> {
                    Map<String, double[]> embeddings = new HashMap<>();
                    embeddings.put("word_embeddings", ((DenseVector) t.get(1)).toArray());
                    embeddings.put("bert_embeddings", ((DenseVector) t.get(2)).toArray());
                    embeddings.put("bert_sentence_embeddings", ((DenseVector) t.get(3)).toArray());
                    return new Tuple2<>(t.getString(0), embeddings);
                })
                .join(topics)
                .mapToPair(e -> {
                    // add the LDA topic distribution as a fourth entry of the embedding map
                    e._2()._1().put("lda_topics", e._2()._2());
                    return new Tuple2<>(e._1(), e._2()._1());
                });

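        // extract the authors from each publication, pairing them with the embedding map
        // of the publication they appear in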
        JavaRDD<Author> authors = AuthorsFactory.extractAuthorsFromPublications(publications, publicationEmbeddings);

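        // write one JSON-serialized author per line, gzip-compressed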
        authors
                .map(a -> new ObjectMapper().writeValueAsString(a))
                .saveAsTextFile(outputPath, GzipCodec.class);
    }

}