package eu.dnetlib.dhp.oa.graph.raw; import com.fasterxml.jackson.databind.ObjectMapper; import eu.dnetlib.dhp.application.ArgumentApplicationParser; import eu.dnetlib.dhp.common.HdfsSupport; import eu.dnetlib.dhp.schema.common.ModelSupport; import eu.dnetlib.dhp.schema.oaf.*; import org.apache.commons.io.IOUtils; import org.apache.commons.lang3.StringUtils; import org.apache.hadoop.io.compress.GzipCodec; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.FilterFunction; import org.apache.spark.api.java.function.MapFunction; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Encoders; import org.apache.spark.sql.SaveMode; import org.apache.spark.sql.SparkSession; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import scala.Tuple2; import scala.reflect.ClassTag; import scala.reflect.ClassTag$; import java.util.Objects; import java.util.Optional; import java.util.function.Function; import static eu.dnetlib.dhp.common.SparkSessionSupport.runWithSparkSession; import static eu.dnetlib.dhp.schema.common.ModelSupport.isSubClass; public class MergeClaimsApplication { private static final Logger log = LoggerFactory.getLogger(MergeClaimsApplication.class); private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); public static void main(final String[] args) throws Exception { final ArgumentApplicationParser parser = new ArgumentApplicationParser( IOUtils.toString(MigrateMongoMdstoresApplication.class .getResourceAsStream("/eu/dnetlib/dhp/oa/graph/merge_claims_parameters.json"))); parser.parseArgument(args); Boolean isSparkSessionManaged = Optional .ofNullable(parser.get("isSparkSessionManaged")) .map(Boolean::valueOf) .orElse(Boolean.TRUE); log.info("isSparkSessionManaged: {}", isSparkSessionManaged); final String rawGraphPath = parser.get("rawGraphPath"); log.info("rawGraphPath: {}", rawGraphPath); final String claimsGraphPath = parser.get("claimsGraphPath"); log.info("claimsGraphPath: {}", claimsGraphPath); final String outputRawGaphPath = parser.get("outputRawGaphPath"); log.info("outputRawGaphPath: {}", outputRawGaphPath); String graphTableClassName = parser.get("graphTableClassName"); log.info("graphTableClassName: {}", graphTableClassName); Class clazz = (Class) Class.forName(graphTableClassName); SparkConf conf = new SparkConf(); conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer"); conf.registerKryoClasses(ModelSupport.getOafModelClasses()); runWithSparkSession(conf, isSparkSessionManaged, spark -> { String type = clazz.getSimpleName().toLowerCase(); String rawPath = rawGraphPath + "/" + type; String claimPath = claimsGraphPath + "/" + type; String outPath = outputRawGaphPath + "/" + type; removeOutputDir(spark, outPath); mergeByType(spark, rawPath, claimPath, outPath, clazz); }); } private static void mergeByType(SparkSession spark, String rawPath, String claimPath, String outPath, Class clazz) { Dataset> raw = readFromPath(spark, rawPath, clazz) .map((MapFunction>) value -> new Tuple2<>(idFn().apply(value), value), Encoders.tuple(Encoders.STRING(), Encoders.kryo(clazz))); final JavaSparkContext jsc = JavaSparkContext.fromSparkContext(spark.sparkContext()); Dataset> claim = jsc.broadcast(readFromPath(spark, claimPath, clazz)) .getValue() .map((MapFunction>) value -> new Tuple2<>(idFn().apply(value), value), Encoders.tuple(Encoders.STRING(), Encoders.kryo(clazz))); /* Dataset> claim = readFromPath(spark, claimPath, clazz) .map((MapFunction>) value -> new Tuple2<>(idFn().apply(value), value), Encoders.tuple(Encoders.STRING(), Encoders.kryo(clazz))); */ raw.joinWith(claim, raw.col("_1").equalTo(claim.col("_1")), "full_outer") .map((MapFunction, Tuple2>, T>) value -> { Optional> opRaw = Optional.ofNullable(value._1()); Optional> opClaim = Optional.ofNullable(value._2()); return opRaw.isPresent() ? opRaw.get()._2() : opClaim.isPresent() ? opClaim.get()._2() : null; }, Encoders.bean(clazz)) .filter(Objects::nonNull) .write() .mode(SaveMode.Overwrite) .parquet(outPath); } private static Dataset readFromPath(SparkSession spark, String path, Class clazz) { return spark.read() .load(path) .as(Encoders.bean(clazz)) .filter((FilterFunction) value -> Objects.nonNull(idFn().apply(value))); } private static void removeOutputDir(SparkSession spark, String path) { HdfsSupport.remove(path, spark.sparkContext().hadoopConfiguration()); } private static Function idFn() { return x -> { if (isSubClass(x, Relation.class)) { return idFnForRelation(x); } return idFnForOafEntity(x); }; } private static String idFnForRelation(T t) { Relation r = (Relation) t; return Optional.ofNullable(r.getSource()) .map(source -> Optional.ofNullable(r.getTarget()) .map(target -> Optional.ofNullable(r.getRelType()) .map(relType -> Optional.ofNullable(r.getSubRelType()) .map(subRelType -> Optional.ofNullable(r.getRelClass()) .map(relClass -> String.join(source, target, relType, subRelType, relClass)) .orElse(String.join(source, target, relType, subRelType)) ) .orElse(String.join(source, target, relType)) ) .orElse(String.join(source, target)) ) .orElse(source) ) .orElse(null); } private static String idFnForOafEntity(T t) { return ((OafEntity) t).getId(); } }