dnet-and/dnet-and-test/src/main/java/eu/dnetlib/jobs/featureextraction/SparkPublicationFeatureExtr...

108 lines
4.6 KiB
Java

package eu.dnetlib.jobs.featureextraction;
import eu.dnetlib.dhp.schema.oaf.Publication;
import eu.dnetlib.dhp.schema.oaf.StructuredProperty;
import eu.dnetlib.featureextraction.ScalaFeatureTransformer;
import eu.dnetlib.jobs.AbstractSparkJob;
import eu.dnetlib.support.ArgumentApplicationParser;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.*;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.codehaus.jackson.map.DeserializationConfig;
import org.codehaus.jackson.map.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.Optional;
import java.util.stream.Collectors;
/**
 * Spark job that extracts feature vectors (word embeddings, BERT embeddings and
 * BERT sentence embeddings) from publication records and writes them to disk.
 *
 * Input: JSON-lines dump of {@link Publication} at {@code publicationsPath}.
 * Output: a dataset with columns (id, word-, bert-, bert-sentence-embeddings)
 * written to {@code featuresPath}.
 */
public class SparkPublicationFeatureExtractor extends AbstractSparkJob {

    private static final Logger log = LoggerFactory.getLogger(SparkPublicationFeatureExtractor.class);

    // Shared, statically-initialized mapper: avoids allocating a new ObjectMapper
    // for every input record inside the Spark map closure. Static fields are
    // re-initialized per executor JVM, so referencing this from the lambda is safe.
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper()
            .configure(DeserializationConfig.Feature.FAIL_ON_UNKNOWN_PROPERTIES, false);

    public SparkPublicationFeatureExtractor(ArgumentApplicationParser parser, SparkSession spark) {
        super(parser, spark);
    }

    public static void main(String[] args) throws Exception {
        ArgumentApplicationParser parser = new ArgumentApplicationParser(
                readResource("/jobs/parameters/publicationFeatureExtractor_parameters.json", SparkPublicationFeatureExtractor.class)
        );
        parser.parseArgument(args);
        SparkConf conf = new SparkConf();
        // BUG FIX: the original instantiated SparkAuthorExtractor here, so this
        // job's own run() was never executed; instantiate this class instead.
        new SparkPublicationFeatureExtractor(
                parser,
                getSparkSession(conf)
        ).run();
    }

    /**
     * Reads publications as JSON lines, builds an (id, title, abstract, subjects)
     * dataframe, runs the three embedding transformers over it, and persists the
     * selected embedding columns to {@code featuresPath} (overwriting).
     *
     * @throws IOException if reading job parameters fails
     */
    @Override
    public void run() throws IOException {
        // read oozie parameters
        final String publicationsPath = parser.get("publicationsPath");
        final String workingPath = parser.get("workingPath");
        final String featuresPath = parser.get("featuresPath");
        final String bertModel = parser.get("bertModelPath");
        final String bertSentenceModel = parser.get("bertSentenceModel");
        final String wordEmbeddingModel = parser.get("wordEmbeddingModel");
        final int numPartitions = Optional
                .ofNullable(parser.get("numPartitions"))
                .map(Integer::valueOf)
                .orElse(NUM_PARTITIONS);

        log.info("publicationsPath: '{}'", publicationsPath);
        log.info("workingPath: '{}'", workingPath);
        log.info("featuresPath: '{}'", featuresPath);
        log.info("numPartitions: '{}'", numPartitions);

        JavaSparkContext context = JavaSparkContext.fromSparkContext(spark.sparkContext());

        // Tolerant JSON parsing: unknown properties in the dump are ignored.
        JavaRDD<Publication> publications = context
                .textFile(publicationsPath)
                .map(x -> OBJECT_MAPPER.readValue(x, Publication.class));

        StructType inputSchema = new StructType(new StructField[]{
                new StructField("id", DataTypes.StringType, false, Metadata.empty()),
                new StructField("title", DataTypes.StringType, false, Metadata.empty()),
                new StructField("abstract", DataTypes.StringType, false, Metadata.empty()),
                new StructField("subjects", DataTypes.StringType, false, Metadata.empty())
        });

        // Prepare Rows. Fall back to "" for a missing title or abstract instead of
        // failing the whole job on one incomplete record (the original guarded only
        // the description, so an empty title list crashed with IndexOutOfBounds).
        Dataset<Row> inputData = spark.createDataFrame(
                publications.map(p -> RowFactory.create(
                        p.getId(),
                        p.getTitle().isEmpty() ? "" : p.getTitle().get(0).getValue(),
                        p.getDescription().isEmpty() ? "" : p.getDescription().get(0).getValue(),
                        p.getSubject().stream().map(StructuredProperty::getValue).collect(Collectors.joining(" ")))),
                inputSchema);

        log.info("Generating word embeddings");
        Dataset<Row> wordEmbeddingsData = ScalaFeatureTransformer.wordEmbeddings(inputData, "subjects", wordEmbeddingModel);

        log.info("Generating bert embeddings");
        Dataset<Row> bertEmbeddingsData = ScalaFeatureTransformer.bertEmbeddings(wordEmbeddingsData, "title", bertModel);

        log.info("Generating bert sentence embeddings");
        Dataset<Row> bertSentenceEmbeddingsData = ScalaFeatureTransformer.bertSentenceEmbeddings(bertEmbeddingsData, "abstract", bertSentenceModel);

        // Keep only the id plus the three generated embedding columns.
        Dataset<Row> features = bertSentenceEmbeddingsData.select(
                "id",
                ScalaFeatureTransformer.WORD_EMBEDDINGS_COL(),
                ScalaFeatureTransformer.BERT_EMBEDDINGS_COL(),
                ScalaFeatureTransformer.BERT_SENTENCE_EMBEDDINGS_COL());

        features
                .write()
                .mode(SaveMode.Overwrite)
                .save(featuresPath);
    }
}