dnet-and/dnet-and-test/src/main/java/eu/dnetlib/jobs/featureextraction/SparkAuthorExtractor.java

package eu.dnetlib.jobs.featureextraction;

import eu.dnetlib.dhp.schema.oaf.Publication;
import eu.dnetlib.jobs.AbstractSparkJob;
import eu.dnetlib.jobs.SparkTokenizer;
import eu.dnetlib.support.ArgumentApplicationParser;
import eu.dnetlib.support.Author;
import eu.dnetlib.support.AuthorsFactory;
import org.apache.hadoop.io.compress.GzipCodec;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.sql.SparkSession;
import org.codehaus.jackson.map.DeserializationConfig;
import org.codehaus.jackson.map.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
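
/**
 * Spark job that joins each publication with its pre-computed feature vectors
 * (word, BERT and BERT sentence embeddings plus LDA topics), extracts the authors
 * carrying those features and writes them to the output path as gzipped JSON.
 */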
public class SparkAuthorExtractor extends AbstractSparkJob {

    private static final Logger log = LoggerFactory.getLogger(SparkAuthorExtractor.class);

    public SparkAuthorExtractor(ArgumentApplicationParser parser, SparkSession spark) {
        super(parser, spark);
    }

    public static void main(String[] args) throws Exception {

        ArgumentApplicationParser parser = new ArgumentApplicationParser(
                readResource("/jobs/parameters/authorExtractor_parameters.json", SparkTokenizer.class)
        );

        parser.parseArgument(args);

        SparkConf conf = new SparkConf();

        new SparkAuthorExtractor(
                parser,
                getSparkSession(conf)
        ).run();
    }

    @Override
    public void run() throws IOException {

        // read oozie parameters
        final String topicsPath = parser.get("topicsPath");
        final String featuresPath = parser.get("featuresPath");
        final String publicationsPath = parser.get("publicationsPath");
        final String workingPath = parser.get("workingPath");
        final String outputPath = parser.get("outputPath");
        final int numPartitions = Optional
                .ofNullable(parser.get("numPartitions"))
                .map(Integer::valueOf)
                .orElse(NUM_PARTITIONS);

        log.info("publicationsPath: '{}'", publicationsPath);
        log.info("topicsPath: '{}'", topicsPath);
        log.info("featuresPath: '{}'", featuresPath);
        log.info("workingPath: '{}'", workingPath);
        log.info("outputPath: '{}'", outputPath);
        log.info("numPartitions: '{}'", numPartitions);

        // join publications with topics
        JavaSparkContext context = JavaSparkContext.fromSparkContext(spark.sparkContext());
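
        // parse publication records from JSON, ignoring properties not mapped by the Publication schema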
        JavaRDD<Publication> publications = context
                .textFile(publicationsPath)
                .map(x -> new ObjectMapper()
                        .configure(DeserializationConfig.Feature.FAIL_ON_UNKNOWN_PROPERTIES, false)
                        .readValue(x, Publication.class));
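
        // dense topic vectors keyed by record id, read from the topics dataframe (added below as "lda_topics")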
        JavaPairRDD<String, double[]> topics = spark.read().load(topicsPath).toJavaRDD()
                .mapToPair(t -> new Tuple2<>(t.getString(0), ((DenseVector) t.get(1)).toArray()));

        // merge topics with other embeddings
        JavaPairRDD<String, Map<String, double[]>> publicationEmbeddings = spark.read().load(featuresPath).toJavaRDD()
                .mapToPair(t -> {
                    Map<String, double[]> embeddings = new HashMap<>();
                    embeddings.put("word_embeddings", ((DenseVector) t.get(1)).toArray());
                    embeddings.put("bert_embeddings", ((DenseVector) t.get(2)).toArray());
                    embeddings.put("bert_sentence_embeddings", ((DenseVector) t.get(3)).toArray());
                    return new Tuple2<>(t.getString(0), embeddings);
                })
                .join(topics)
                .mapToPair(e -> {
                    e._2()._1().put("lda_topics", e._2()._2());
                    return new Tuple2<>(e._1(), e._2()._1());
                });
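
        // attach the merged per-publication feature map to each author extracted from the publications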
        JavaRDD<Author> authors = AuthorsFactory.extractAuthorsFromPublications(publications, publicationEmbeddings);
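
        // serialize each author as JSON and write it gzip-compressed to the output path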
        authors
                .map(a -> new ObjectMapper().writeValueAsString(a))
                .saveAsTextFile(outputPath, GzipCodec.class);
    }
}