package eu.dnetlib.jobs.featureextraction;

import eu.dnetlib.dhp.schema.oaf.Publication;
import eu.dnetlib.dhp.schema.oaf.StructuredProperty;
import eu.dnetlib.featureextraction.ScalaFeatureTransformer;
import eu.dnetlib.jobs.AbstractSparkJob;
import eu.dnetlib.support.ArgumentApplicationParser;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.*;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.codehaus.jackson.map.DeserializationConfig;
import org.codehaus.jackson.map.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.Optional;
import java.util.stream.Collectors;

public class SparkPublicationFeatureExtractor extends AbstractSparkJob {
|
|
private static final Logger log = LoggerFactory.getLogger(SparkPublicationFeatureExtractor.class);
|
|
|
|
public SparkPublicationFeatureExtractor(ArgumentApplicationParser parser, SparkSession spark) {
|
|
super(parser, spark);
|
|
}
|
|
|
|
public static void main(String[] args) throws Exception {
|
|
|
|
ArgumentApplicationParser parser = new ArgumentApplicationParser(
|
|
readResource("/jobs/parameters/publicationFeatureExtractor_parameters.json", SparkPublicationFeatureExtractor.class)
|
|
);
|
|
|
|
parser.parseArgument(args);
|
|
|
|
SparkConf conf = new SparkConf();
|
|
|
|
new SparkAuthorExtractor(
|
|
parser,
|
|
getSparkSession(conf)
|
|
).run();
|
|
}
|
|
|
|
@Override
|
|
public void run() throws IOException {
|
|
// read oozie parameters
|
|
final String publicationsPath = parser.get("publicationsPath");
|
|
final String workingPath = parser.get("workingPath");
|
|
final String featuresPath = parser.get("featuresPath");
|
|
final String bertModel = parser.get("bertModelPath");
|
|
final String bertSentenceModel = parser.get("bertSentenceModel");
|
|
final String wordEmbeddingModel = parser.get("wordEmbeddingModel");
|
|
final int numPartitions = Optional
|
|
.ofNullable(parser.get("numPartitions"))
|
|
.map(Integer::valueOf)
|
|
.orElse(NUM_PARTITIONS);
|
|
|
|
log.info("publicationsPath: '{}'", publicationsPath);
|
|
log.info("workingPath: '{}'", workingPath);
|
|
log.info("numPartitions: '{}'", numPartitions);
|
|
|
|
JavaSparkContext context = JavaSparkContext.fromSparkContext(spark.sparkContext());
|
|
|
|
JavaRDD<Publication> publications = context
|
|
.textFile(publicationsPath)
|
|
.map(x -> new ObjectMapper()
|
|
.configure(DeserializationConfig.Feature.FAIL_ON_UNKNOWN_PROPERTIES, false)
|
|
.readValue(x, Publication.class));
|
|
|
|
StructType inputSchema = new StructType(new StructField[]{
|
|
new StructField("id", DataTypes.StringType, false, Metadata.empty()),
|
|
new StructField("title", DataTypes.StringType, false, Metadata.empty()),
|
|
new StructField("abstract", DataTypes.StringType, false, Metadata.empty()),
|
|
new StructField("subjects", DataTypes.StringType, false, Metadata.empty())
|
|
});
|
|
|
|
//prepare Rows
|
|
Dataset<Row> inputData = spark.createDataFrame(
|
|
publications.map(p -> RowFactory.create(
|
|
p.getId(),
|
|
p.getTitle().get(0).getValue(),
|
|
p.getDescription().size()>0? p.getDescription().get(0).getValue(): "",
|
|
p.getSubject().stream().map(StructuredProperty::getValue).collect(Collectors.joining(" ")))),
|
|
inputSchema);
|
|
|
|
log.info("Generating word embeddings");
|
|
Dataset<Row> wordEmbeddingsData = ScalaFeatureTransformer.wordEmbeddings(inputData, "subjects", wordEmbeddingModel);
|
|
|
|
log.info("Generating bert embeddings");
|
|
Dataset<Row> bertEmbeddingsData = ScalaFeatureTransformer.bertEmbeddings(wordEmbeddingsData, "title", bertModel);
|
|
|
|
log.info("Generating bert sentence embeddings");
|
|
Dataset<Row> bertSentenceEmbeddingsData = ScalaFeatureTransformer.bertSentenceEmbeddings(bertEmbeddingsData, "abstract", bertSentenceModel);
|
|
|
|
Dataset<Row> features = bertSentenceEmbeddingsData.select("id", ScalaFeatureTransformer.WORD_EMBEDDINGS_COL(), ScalaFeatureTransformer.BERT_EMBEDDINGS_COL(), ScalaFeatureTransformer.BERT_SENTENCE_EMBEDDINGS_COL());
|
|
|
|
features
|
|
.write()
|
|
.mode(SaveMode.Overwrite)
|
|
.save(featuresPath);
|
|
|
|
}
|
|
}
|