dhp-graph-dump/dump/src/main/java/eu/dnetlib/dhp/skgif/ExtendResult.java

135 lines
4.2 KiB
Java

package eu.dnetlib.dhp.skgif;
import static eu.dnetlib.dhp.common.SparkSessionSupport.runWithSparkSession;
import java.io.Serializable;
import java.util.*;
import org.apache.commons.io.IOUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.api.java.function.MapGroupsFunction;
import org.apache.spark.sql.*;
import org.apache.spark.sql.types.StructType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.fasterxml.jackson.databind.ObjectMapper;
import eu.dnetlib.dhp.application.ArgumentApplicationParser;
import eu.dnetlib.dhp.oa.graph.dump.Utils;
import eu.dnetlib.dhp.skgif.model.RelationType;
import eu.dnetlib.dhp.skgif.model.Relations;
import eu.dnetlib.dhp.skgif.model.ResearchProduct;
import scala.Tuple2;
/**
 * Spark job that extends the dumped {@link ResearchProduct}s with the relations prepared for them.
 * <p>
 * It reads the products from {@code <sourcePath>/result} and the prepared relations from
 * {@code <sourcePath>/preparedRelations}, left-joins them on the product local identifier, and writes
 * each enriched product as a gzip-compressed JSON line under {@code outputPath}. Relations whose class
 * is {@link RelationType#OUTCOME} feed the product funding, those whose class is
 * {@link RelationType#AFFILIATION} feed the relevant organizations, every other class becomes an entry
 * of the related products.
 *
 * @author miriam.baglioni
 * @Date 05/09/23
 */
public class ExtendResult implements Serializable {

    private static final Logger log = LoggerFactory.getLogger(ExtendResult.class);

    /** Shared JSON serializer; a static is initialised per JVM and is not shipped with the closures. */
    private static final ObjectMapper MAPPER = new ObjectMapper();

    /** Column of the prepared relations holding the identifier of the product owning the relation. */
    private static final String RESULT_ID_FIELD = "resultId";

    /** Column of the prepared relations holding the identifier of the related entity. */
    private static final String TARGET_FIELD = "target";

    /**
     * Column of the prepared relations holding the relation class.
     * NOTE(review): the schema declared "resultClass" while the reader asked for "relClass"; the schema
     * name is kept here so that schema and reader agree - confirm against the fields actually written
     * by the job producing {@code preparedRelations}.
     */
    private static final String REL_CLASS_FIELD = "resultClass";

    /** DDL of the schema applied when reading the prepared relations. */
    private static final String RELATIONS_DDL = String
        .format("`%s` STRING, `%s` STRING, `%s` STRING", RESULT_ID_FIELD, TARGET_FIELD, REL_CLASS_FIELD);

    /**
     * Entry point: parses the arguments described in {@code extend_result_parameters.json}, cleans the
     * output directory and runs the extension inside a Spark session.
     *
     * @param args command line arguments ({@code sourcePath}, {@code outputPath} and the optional
     *             {@code isSparkSessionManaged}, which defaults to {@code true})
     * @throws Exception if the parameter file cannot be read, the arguments cannot be parsed or the
     *                   Spark job fails
     */
    public static void main(String[] args) throws Exception {
        // Load the parameter description through this class rather than through an unrelated job class.
        String jsonConfiguration = IOUtils
            .toString(
                ExtendResult.class
                    .getResourceAsStream(
                        "/eu/dnetlib/dhp/oa/graph/dump/extend_result_parameters.json"));

        final ArgumentApplicationParser parser = new ArgumentApplicationParser(jsonConfiguration);
        parser.parseArgument(args);

        Boolean isSparkSessionManaged = Optional
            .ofNullable(parser.get("isSparkSessionManaged"))
            .map(Boolean::valueOf)
            .orElse(Boolean.TRUE);
        log.info("isSparkSessionManaged: {}", isSparkSessionManaged);

        final String inputPath = parser.get("sourcePath");
        log.info("inputPath: {}", inputPath);

        final String outputPath = parser.get("outputPath");
        log.info("outputPath: {}", outputPath);

        SparkConf conf = new SparkConf();

        runWithSparkSession(
            conf,
            isSparkSessionManaged,
            spark -> {
                Utils.removeOutputDir(spark, outputPath);
                extendResult(spark, inputPath, outputPath);
            });
    }

    /**
     * Joins every product with its prepared relations and writes the enriched products as JSON text.
     *
     * @param spark      the active Spark session
     * @param inputPath  base path containing the {@code result} and {@code preparedRelations} folders
     * @param outputPath path where the gzip-compressed JSON lines are written (overwritten if present)
     */
    private static void extendResult(SparkSession spark, String inputPath, String outputPath) {
        Dataset<ResearchProduct> result = spark
            .read()
            .json(inputPath + "/result")
            .as(Encoders.bean(ResearchProduct.class));

        // fromDDL is a static factory: no need to instantiate an empty StructType to reach it.
        final StructType relationSchema = StructType.fromDDL(RELATIONS_DDL);

        Dataset<Row> relations = spark
            .read()
            .schema(relationSchema)
            .json(inputPath + "/preparedRelations");

        result
            .joinWith(
                relations,
                result.col("localIdentifier").equalTo(relations.col(RESULT_ID_FIELD)),
                "left")
            .groupByKey(
                (MapFunction<Tuple2<ResearchProduct, Row>, String>) t2 -> t2._1().getLocalIdentifier(),
                Encoders.STRING())
            .mapGroups(
                (MapGroupsFunction<String, Tuple2<ResearchProduct, Row>, ResearchProduct>) (key,
                    it) -> mergeRelations(it),
                Encoders.bean(ResearchProduct.class))
            .map((MapFunction<ResearchProduct, String>) r -> MAPPER.writeValueAsString(r), Encoders.STRING())
            .write()
            .mode(SaveMode.Overwrite)
            .option("compression", "gzip")
            .text(outputPath);
    }

    /**
     * Folds all the (product, relation) pairs of one product into a single enriched product.
     * <p>
     * The relations that are neither funding nor affiliations are accumulated per relation class over
     * the whole group, and the related products are set once at the end, so that no entry is lost.
     *
     * @param rows the non-empty group of pairs sharing the same product local identifier
     * @return the product of the first pair, extended with the relations of every pair in the group
     */
    private static ResearchProduct mergeRelations(Iterator<Tuple2<ResearchProduct, Row>> rows) {
        Tuple2<ResearchProduct, Row> first = rows.next();
        ResearchProduct rp = first._1();

        // Insertion-ordered so that the output follows the order in which relation classes are met.
        Map<String, List<String>> relatedByType = new LinkedHashMap<>();
        addRels(rp, first._2(), relatedByType);
        rows.forEachRemaining(t2 -> addRels(rp, t2._2(), relatedByType));

        if (!relatedByType.isEmpty()) {
            List<Relations> relatedProducts = new ArrayList<>();
            relatedByType.forEach((relationType, targets) -> {
                Relations rel = new Relations();
                rel.setRelationType(relationType);
                rel.setProductList(targets);
                relatedProducts.add(rel);
            });
            rp.setRelatedProducts(relatedProducts);
        }
        return rp;
    }

    /**
     * Registers one relation on the product it belongs to.
     *
     * @param rp            the product being extended; its funding and relevant organizations lists are
     *                      created on demand
     * @param row           the relation row; {@code null} (or a row with null fields) when the left join
     *                      found no relation for the product, in which case nothing is done
     * @param relatedByType accumulator, keyed by relation class, of the targets of every relation that
     *                      is neither funding nor affiliation
     */
    private static void addRels(ResearchProduct rp, Row row, Map<String, List<String>> relatedByType) {
        // A left join yields no right side for products without relations.
        if (row == null) {
            return;
        }
        String relClass = row.getAs(REL_CLASS_FIELD);
        String target = row.getAs(TARGET_FIELD);
        if (relClass == null || target == null) {
            return;
        }

        if (relClass.equals(RelationType.OUTCOME.label)) {
            if (rp.getFunding() == null) {
                rp.setFunding(new ArrayList<>());
            }
            rp.getFunding().add(target);
        } else if (relClass.equals(RelationType.AFFILIATION.label)) {
            // Compare against the label: a String never equals the RelationType constant itself.
            if (rp.getRelevantOrganizations() == null) {
                rp.setRelevantOrganizations(new ArrayList<>());
            }
            rp.getRelevantOrganizations().add(target);
        } else {
            relatedByType.computeIfAbsent(relClass, k -> new ArrayList<>()).add(target);
        }
    }
}