diff --git a/dhp-workflows/dhp-graph-provision/src/main/java/eu/dnetlib/dhp/oa/provision/PrepareRelationsJob.java b/dhp-workflows/dhp-graph-provision/src/main/java/eu/dnetlib/dhp/oa/provision/PrepareRelationsJob.java index fdf397ad7..c2eb8c408 100644 --- a/dhp-workflows/dhp-graph-provision/src/main/java/eu/dnetlib/dhp/oa/provision/PrepareRelationsJob.java +++ b/dhp-workflows/dhp-graph-provision/src/main/java/eu/dnetlib/dhp/oa/provision/PrepareRelationsJob.java @@ -1,43 +1,31 @@ package eu.dnetlib.dhp.oa.provision; -import static eu.dnetlib.dhp.common.SparkSessionSupport.runWithSparkSession; - -import java.util.HashSet; -import java.util.Optional; -import java.util.PriorityQueue; -import java.util.Set; -import java.util.stream.Collectors; - -import org.apache.commons.io.IOUtils; -import org.apache.commons.lang3.StringUtils; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.FilterFunction; -import org.apache.spark.api.java.function.FlatMapFunction; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.MapFunction; -import org.apache.spark.sql.Encoder; -import org.apache.spark.sql.Encoders; -import org.apache.spark.sql.SaveMode; -import org.apache.spark.sql.SparkSession; -import org.apache.spark.sql.expressions.Aggregator; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.base.Joiner; import com.google.common.base.Splitter; -import com.google.common.collect.Iterables; import com.google.common.collect.Sets; - import eu.dnetlib.dhp.application.ArgumentApplicationParser; import eu.dnetlib.dhp.common.HdfsSupport; import eu.dnetlib.dhp.oa.provision.model.ProvisionModelSupport; -import eu.dnetlib.dhp.oa.provision.model.SortableRelationKey; -import eu.dnetlib.dhp.oa.provision.utils.RelationPartitioner; import eu.dnetlib.dhp.schema.oaf.Relation; -import scala.Tuple2; +import org.apache.commons.io.IOUtils; +import org.apache.spark.SparkConf; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.expressions.Window; +import org.apache.spark.sql.expressions.WindowSpec; +import org.apache.spark.sql.functions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.HashSet; +import java.util.Optional; +import java.util.Set; + +import static eu.dnetlib.dhp.common.SparkSessionSupport.runWithSparkSession; +import static org.apache.spark.sql.functions.col; /** * PrepareRelationsJob prunes the relationships: only consider relationships that are not virtually deleted @@ -130,130 +118,28 @@ public class PrepareRelationsJob { private static void prepareRelationsRDD(SparkSession spark, String inputRelationsPath, String outputPath, Set relationFilter, int sourceMaxRelations, int targetMaxRelations, int relPartitions) { - JavaRDD rels = readPathRelationRDD(spark, inputRelationsPath) - .filter(rel -> !(rel.getSource().startsWith("unresolved") || rel.getTarget().startsWith("unresolved"))) - .filter(rel -> !rel.getDataInfo().getDeletedbyinference()) - .filter(rel -> !relationFilter.contains(StringUtils.lowerCase(rel.getRelClass()))); + WindowSpec source_w = Window + .partitionBy("source", "subRelType") + .orderBy(col("target").desc_nulls_last()); - JavaRDD pruned = pruneRels( - pruneRels( - rels, - sourceMaxRelations, relPartitions, (Function) Relation::getSource), - targetMaxRelations, relPartitions, (Function) Relation::getTarget); - spark - .createDataset(pruned.rdd(), Encoders.bean(Relation.class)) - .repartition(relPartitions) - .write() - .mode(SaveMode.Overwrite) - .parquet(outputPath); - } + WindowSpec target_w = Window + .partitionBy("target", "subRelType") + .orderBy(col("source").desc_nulls_last()); - private static JavaRDD pruneRels(JavaRDD rels, int maxRelations, - int relPartitions, Function idFn) { - return rels - .mapToPair(r -> new Tuple2<>(SortableRelationKey.create(r, idFn.call(r)), r)) - .repartitionAndSortWithinPartitions(new RelationPartitioner(relPartitions)) - .groupBy(Tuple2::_1) - .map(Tuple2::_2) - .map(t -> Iterables.limit(t, maxRelations)) - .flatMap(Iterable::iterator) - .map(Tuple2::_2); - } - - // experimental - private static void prepareRelationsDataset( - SparkSession spark, String inputRelationsPath, String outputPath, Set relationFilter, int maxRelations, - int relPartitions) { - spark - .read() - .textFile(inputRelationsPath) - .repartition(relPartitions) - .map( - (MapFunction) s -> OBJECT_MAPPER.readValue(s, Relation.class), - Encoders.kryo(Relation.class)) - .filter((FilterFunction) rel -> !rel.getDataInfo().getDeletedbyinference()) - .filter((FilterFunction) rel -> !relationFilter.contains(rel.getRelClass())) - .groupByKey( - (MapFunction) Relation::getSource, - Encoders.STRING()) - .agg(new RelationAggregator(maxRelations).toColumn()) - .flatMap( - (FlatMapFunction, Relation>) t -> Iterables - .limit(t._2().getRelations(), maxRelations) - .iterator(), - Encoders.bean(Relation.class)) - .repartition(relPartitions) - .write() - .mode(SaveMode.Overwrite) - .parquet(outputPath); - } - - public static class RelationAggregator - extends Aggregator { - - private final int maxRelations; - - public RelationAggregator(int maxRelations) { - this.maxRelations = maxRelations; - } - - @Override - public RelationList zero() { - return new RelationList(); - } - - @Override - public RelationList reduce(RelationList b, Relation a) { - b.getRelations().add(a); - return getSortableRelationList(b); - } - - @Override - public RelationList merge(RelationList b1, RelationList b2) { - b1.getRelations().addAll(b2.getRelations()); - return getSortableRelationList(b1); - } - - @Override - public RelationList finish(RelationList r) { - return getSortableRelationList(r); - } - - private RelationList getSortableRelationList(RelationList b1) { - RelationList sr = new RelationList(); - sr - .setRelations( - b1 - .getRelations() - .stream() - .limit(maxRelations) - .collect(Collectors.toCollection(() -> new PriorityQueue<>(new RelationComparator())))); - return sr; - } - - @Override - public Encoder bufferEncoder() { - return Encoders.kryo(RelationList.class); - } - - @Override - public Encoder outputEncoder() { - return Encoders.kryo(RelationList.class); - } - } - - /** - * Reads a JavaRDD of eu.dnetlib.dhp.oa.provision.model.SortableRelation objects from a newline delimited json text - * file, - * - * @param spark - * @param inputPath - * @return the JavaRDD containing all the relationships - */ - private static JavaRDD readPathRelationRDD( - SparkSession spark, final String inputPath) { - JavaSparkContext sc = JavaSparkContext.fromSparkContext(spark.sparkContext()); - return sc.textFile(inputPath).map(s -> OBJECT_MAPPER.readValue(s, Relation.class)); + spark.read().schema(Encoders.bean(Relation.class).schema()).json(inputRelationsPath) + .where("source NOT LIKE 'unresolved%' AND target NOT LIKE 'unresolved%'") + .where("datainfo.deletedbyinference != true") + .where(relationFilter.isEmpty() ? "" : "lower(relClass) NOT IN ("+ Joiner.on(',').join(relationFilter) +")") + .withColumn("source_w_pos", functions.row_number().over(source_w)) + .where("source_w_pos < " + sourceMaxRelations ) + .drop("source_w_pos") + .withColumn("target_w_pos", functions.row_number().over(target_w)) + .where("target_w_pos < " + targetMaxRelations) + .drop( "target_w_pos") + .coalesce(relPartitions) + .write() + .mode(SaveMode.Overwrite) + .parquet(outputPath); } private static void removeOutputDir(SparkSession spark, String path) {