package eu.dnetlib.jobs.featureextraction;

import eu.dnetlib.dhp.schema.oaf.Publication;
import eu.dnetlib.dhp.schema.oaf.StructuredProperty;
import eu.dnetlib.featureextraction.ScalaFeatureTransformer;
import eu.dnetlib.jobs.AbstractSparkJob;
import eu.dnetlib.support.ArgumentApplicationParser;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.*;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.codehaus.jackson.map.DeserializationConfig;
import org.codehaus.jackson.map.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.Optional;
import java.util.stream.Collectors;

/**
 * Spark job that extracts publication features: word embeddings from subjects,
 * BERT embeddings from titles and BERT sentence embeddings from abstracts.
 */
public class SparkPublicationFeatureExtractor extends AbstractSparkJob {

    private static final Logger log = LoggerFactory.getLogger(SparkPublicationFeatureExtractor.class);

    public SparkPublicationFeatureExtractor(ArgumentApplicationParser parser, SparkSession spark) {
        super(parser, spark);
    }

    public static void main(String[] args) throws Exception {

        ArgumentApplicationParser parser = new ArgumentApplicationParser(
                readResource("/jobs/parameters/publicationFeatureExtractor_parameters.json", SparkPublicationFeatureExtractor.class)
        );

        parser.parseArgument(args);

        SparkConf conf = new SparkConf();

        new SparkPublicationFeatureExtractor(
                parser,
                getSparkSession(conf)
        ).run();
    }

    @Override
    public void run() throws IOException {

        // read oozie parameters
        final String publicationsPath = parser.get("publicationsPath");
        final String workingPath = parser.get("workingPath");
        final String featuresPath = parser.get("featuresPath");
        final String bertModel = parser.get("bertModelPath");
        final String bertSentenceModel = parser.get("bertSentenceModel");
        final String wordEmbeddingModel = parser.get("wordEmbeddingModel");
        final int numPartitions = Optional
                .ofNullable(parser.get("numPartitions"))
                .map(Integer::valueOf)
                .orElse(NUM_PARTITIONS);

        log.info("publicationsPath: '{}'", publicationsPath);
        log.info("workingPath: '{}'", workingPath);
        log.info("numPartitions: '{}'", numPartitions);

        JavaSparkContext context = JavaSparkContext.fromSparkContext(spark.sparkContext());

        // parse the input publications, ignoring unknown JSON properties
        JavaRDD<Publication> publications = context
                .textFile(publicationsPath)
                .map(x -> new ObjectMapper()
                        .configure(DeserializationConfig.Feature.FAIL_ON_UNKNOWN_PROPERTIES, false)
                        .readValue(x, Publication.class));

        StructType inputSchema = new StructType(new StructField[]{
                new StructField("id", DataTypes.StringType, false, Metadata.empty()),
                new StructField("title", DataTypes.StringType, false, Metadata.empty()),
                new StructField("abstract", DataTypes.StringType, false, Metadata.empty()),
                new StructField("subjects", DataTypes.StringType, false, Metadata.empty())
        });

        // prepare Rows: id, first title, first abstract (or empty string) and space-joined subjects
        Dataset<Row> inputData = spark.createDataFrame(
                publications.map(p -> RowFactory.create(
                        p.getId(),
                        p.getTitle().get(0).getValue(),
                        p.getDescription().size() > 0 ? p.getDescription().get(0).getValue() : "",
                        p.getSubject().stream().map(StructuredProperty::getValue).collect(Collectors.joining(" ")))),
                inputSchema);

        log.info("Generating word embeddings");
        Dataset<Row> wordEmbeddingsData = ScalaFeatureTransformer.wordEmbeddings(inputData, "subjects", wordEmbeddingModel);

        log.info("Generating bert embeddings");
        Dataset<Row> bertEmbeddingsData = ScalaFeatureTransformer.bertEmbeddings(wordEmbeddingsData, "title", bertModel);

        log.info("Generating bert sentence embeddings");
        Dataset<Row> bertSentenceEmbeddingsData = ScalaFeatureTransformer.bertSentenceEmbeddings(bertEmbeddingsData, "abstract", bertSentenceModel);

        Dataset<Row> features = bertSentenceEmbeddingsData.select(
                "id",
                ScalaFeatureTransformer.WORD_EMBEDDINGS_COL(),
                ScalaFeatureTransformer.BERT_EMBEDDINGS_COL(),
                ScalaFeatureTransformer.BERT_SENTENCE_EMBEDDINGS_COL());

        features
                .write()
                .mode(SaveMode.Overwrite)
                .save(featuresPath);
    }
}