package eu.dnetlib.jobs.featureextraction;

import eu.dnetlib.dhp.schema.oaf.Publication;
import eu.dnetlib.jobs.AbstractSparkJob;
import eu.dnetlib.jobs.SparkTokenizer;
import eu.dnetlib.support.ArgumentApplicationParser;
import eu.dnetlib.support.Author;
import eu.dnetlib.support.AuthorsFactory;
import org.apache.hadoop.io.compress.GzipCodec;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.sql.SparkSession;
import org.codehaus.jackson.map.DeserializationConfig;
import org.codehaus.jackson.map.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

public class SparkAuthorExtractor extends AbstractSparkJob {

    private static final Logger log = LoggerFactory.getLogger(SparkAuthorExtractor.class);

    public SparkAuthorExtractor(ArgumentApplicationParser parser, SparkSession spark) {
        super(parser, spark);
    }

    public static void main(String[] args) throws Exception {
        ArgumentApplicationParser parser = new ArgumentApplicationParser(
                readResource("/jobs/parameters/authorExtractor_parameters.json", SparkTokenizer.class)
        );

        parser.parseArgument(args);

        SparkConf conf = new SparkConf();

        new SparkAuthorExtractor(
                parser,
                getSparkSession(conf)
        ).run();
    }

    @Override
    public void run() throws IOException {

        // read oozie parameters
        final String topicsPath = parser.get("topicsPath");
        final String featuresPath = parser.get("featuresPath");
        final String publicationsPath = parser.get("publicationsPath");
        final String workingPath = parser.get("workingPath");
        final String outputPath = parser.get("outputPath");
        final int numPartitions = Optional
                .ofNullable(parser.get("numPartitions"))
                .map(Integer::valueOf)
                .orElse(NUM_PARTITIONS);

        log.info("publicationsPath: '{}'", publicationsPath);
        log.info("topicsPath:       '{}'", topicsPath);
        log.info("featuresPath:     '{}'", featuresPath);
        log.info("workingPath:      '{}'", workingPath);
        log.info("outputPath:       '{}'", outputPath);
        log.info("numPartitions:    '{}'", numPartitions);

        JavaSparkContext context = JavaSparkContext.fromSparkContext(spark.sparkContext());

        // parse publications from the input JSON dump, ignoring unknown fields
        JavaRDD<Publication> publications = context
                .textFile(publicationsPath)
                .map(x -> new ObjectMapper()
                        .configure(DeserializationConfig.Feature.FAIL_ON_UNKNOWN_PROPERTIES, false)
                        .readValue(x, Publication.class));

        // join publications with topics: LDA topic vectors keyed by publication id
        JavaPairRDD<String, double[]> topics = spark.read().load(topicsPath)
                .toJavaRDD()
                .mapToPair(t -> new Tuple2<>(t.getString(0), ((DenseVector) t.get(1)).toArray()));

        // merge topics with the other embeddings (word, BERT and BERT sentence embeddings)
        JavaPairRDD<String, Map<String, double[]>> publicationEmbeddings = spark.read().load(featuresPath)
                .toJavaRDD()
                .mapToPair(t -> {
                    Map<String, double[]> embeddings = new HashMap<>();
                    embeddings.put("word_embeddings", ((DenseVector) t.get(1)).toArray());
                    embeddings.put("bert_embeddings", ((DenseVector) t.get(2)).toArray());
                    embeddings.put("bert_sentence_embeddings", ((DenseVector) t.get(3)).toArray());
                    return new Tuple2<>(t.getString(0), embeddings);
                })
                .join(topics)
                .mapToPair(e -> {
                    e._2()._1().put("lda_topics", e._2()._2());
                    return new Tuple2<>(e._1(), e._2()._1());
                });

        // extract authors from publications, attaching the publication-level embeddings to each author
        JavaRDD<Author> authors = AuthorsFactory.extractAuthorsFromPublications(publications, publicationEmbeddings);

        authors
                .map(a -> new ObjectMapper().writeValueAsString(a))
                .saveAsTextFile(outputPath, GzipCodec.class);
    }
}