diff --git a/dhp-workflows/dhp-graph-provision/src/main/java/eu/dnetlib/dhp/oa/provision/PrepareRelationsJob.java b/dhp-workflows/dhp-graph-provision/src/main/java/eu/dnetlib/dhp/oa/provision/PrepareRelationsJob.java index 6b34899c80..4ae822df7f 100644 --- a/dhp-workflows/dhp-graph-provision/src/main/java/eu/dnetlib/dhp/oa/provision/PrepareRelationsJob.java +++ b/dhp-workflows/dhp-graph-provision/src/main/java/eu/dnetlib/dhp/oa/provision/PrepareRelationsJob.java @@ -11,6 +11,8 @@ import org.apache.commons.io.IOUtils; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.rdd.RDD; import org.apache.spark.sql.Encoders; import org.apache.spark.sql.SaveMode; @@ -30,6 +32,7 @@ import eu.dnetlib.dhp.common.HdfsSupport; import eu.dnetlib.dhp.oa.provision.model.SortableRelationKey; import eu.dnetlib.dhp.oa.provision.utils.RelationPartitioner; import eu.dnetlib.dhp.schema.oaf.Relation; +import scala.Tuple2; /** * Joins the graph nodes by resolving the links of distance = 1 to create an adjacency list of linked objects. The @@ -63,21 +66,6 @@ public class PrepareRelationsJob { public static final int DEFAULT_NUM_PARTITIONS = 3000; - private static final Map weights = Maps.newHashMap(); - - static { - weights.put("outcome", 0); - weights.put("supplement", 1); - weights.put("affiliation", 2); - weights.put("relationship", 3); - weights.put("publicationDataset", 4); - weights.put("similarity", 5); - - weights.put("provision", 6); - weights.put("participation", 7); - weights.put("dedup", 8); - } - public static void main(String[] args) throws Exception { String jsonConfiguration = IOUtils .toString( @@ -146,21 +134,26 @@ public class PrepareRelationsJob { int relPartitions) { RDD cappedRels = readPathRelationRDD(spark, inputRelationsPath) - .repartition(relPartitions) - .filter(rel -> !rel.getDataInfo().getDeletedbyinference()) - .filter(rel -> !relationFilter.contains(rel.getRelClass())) + .filter(rel -> rel.getDataInfo().getDeletedbyinference() == false) + .filter(rel -> relationFilter.contains(rel.getRelClass()) == false) + // group by SOURCE and apply limit - .groupBy(r -> SortableRelationKey.create(r, r.getSource())) - .repartitionAndSortWithinPartitions( - new RelationPartitioner(relPartitions), - (SerializableComparator) (o1, o2) -> compare(o1, o2)) - .flatMap(t -> Iterables.limit(t._2(), maxRelations).iterator()) + .mapToPair(r -> new Tuple2<>(SortableRelationKey.create(r, r.getSource()), r)) + .repartitionAndSortWithinPartitions(new RelationPartitioner(relPartitions)) + .groupBy(Tuple2::_1) + .map(Tuple2::_2) + .map(t -> Iterables.limit(t, maxRelations)) + .flatMap(Iterable::iterator) + .map(Tuple2::_2) + // group by TARGET and apply limit - .groupBy(r -> SortableRelationKey.create(r, r.getTarget())) - .repartitionAndSortWithinPartitions( - new RelationPartitioner(relPartitions), - (SerializableComparator) (o1, o2) -> compare(o1, o2)) - .flatMap(t -> Iterables.limit(t._2(), maxRelations).iterator()) + .mapToPair(r -> new Tuple2<>(SortableRelationKey.create(r, r.getTarget()), r)) + .repartitionAndSortWithinPartitions(new RelationPartitioner(relPartitions)) + .groupBy(Tuple2::_1) + .map(Tuple2::_2) + .map(t -> Iterables.limit(t, maxRelations)) + .flatMap(Iterable::iterator) + .map(Tuple2::_2) .rdd(); spark @@ -170,24 +163,6 @@ public class PrepareRelationsJob { .parquet(outputPath); } - private static int compare(SortableRelationKey o1, SortableRelationKey o2) { - final Integer w1 = Optional.ofNullable(weights.get(o1.getSubRelType())).orElse(Integer.MAX_VALUE); - final Integer w2 = Optional.ofNullable(weights.get(o2.getSubRelType())).orElse(Integer.MAX_VALUE); - return ComparisonChain - .start() - .compare(w1, w2) - .compare(o1.getSource(), o2.getSource()) - .compare(o1.getTarget(), o2.getTarget()) - .result(); - } - - @FunctionalInterface - public interface SerializableComparator extends Comparator, Serializable { - - @Override - int compare(T o1, T o2); - } - /** * Reads a JavaRDD of eu.dnetlib.dhp.oa.provision.model.SortableRelation objects from a newline delimited json text * file, diff --git a/dhp-workflows/dhp-graph-provision/src/main/java/eu/dnetlib/dhp/oa/provision/model/SortableRelationKey.java b/dhp-workflows/dhp-graph-provision/src/main/java/eu/dnetlib/dhp/oa/provision/model/SortableRelationKey.java index ab6518809b..e96c4ca5cf 100644 --- a/dhp-workflows/dhp-graph-provision/src/main/java/eu/dnetlib/dhp/oa/provision/model/SortableRelationKey.java +++ b/dhp-workflows/dhp-graph-provision/src/main/java/eu/dnetlib/dhp/oa/provision/model/SortableRelationKey.java @@ -11,25 +11,34 @@ import com.google.common.collect.Maps; import eu.dnetlib.dhp.schema.oaf.Relation; -public class SortableRelationKey implements Serializable { +public class SortableRelationKey implements Comparable, Serializable { + + private static final Map weights = Maps.newHashMap(); + + static { + weights.put("outcome", 0); + weights.put("supplement", 1); + weights.put("review", 2); + weights.put("citation", 3); + weights.put("affiliation", 4); + weights.put("relationship", 5); + weights.put("publicationDataset", 6); + weights.put("similarity", 7); + + weights.put("provision", 8); + weights.put("participation", 9); + weights.put("dedup", 10); + } + + private static final long serialVersionUID = 3232323; private String groupingKey; - private String source; - - private String target; - private String subRelType; - public String getSource() { - return source; - } - public static SortableRelationKey create(Relation r, String groupingKey) { SortableRelationKey sr = new SortableRelationKey(); sr.setGroupingKey(groupingKey); - sr.setSource(r.getSource()); - sr.setTarget(r.getTarget()); sr.setSubRelType(r.getSubRelType()); return sr; } @@ -49,16 +58,16 @@ public class SortableRelationKey implements Serializable { return Objects.hashCode(getGroupingKey()); } - public void setSource(String source) { - this.source = source; + @Override + public int compareTo(SortableRelationKey o) { + return ComparisonChain + .start() + .compare(getWeight(this), getWeight(o)) + .result() * -1; } - public String getTarget() { - return target; - } - - public void setTarget(String target) { - this.target = target; + private Integer getWeight(SortableRelationKey o) { + return Optional.ofNullable(weights.get(o.getSubRelType())).orElse(Integer.MAX_VALUE); } public String getSubRelType() { @@ -76,4 +85,5 @@ public class SortableRelationKey implements Serializable { public void setGroupingKey(String groupingKey) { this.groupingKey = groupingKey; } + } diff --git a/dhp-workflows/dhp-graph-provision/src/main/java/eu/dnetlib/dhp/oa/provision/utils/RelationPartitioner.java b/dhp-workflows/dhp-graph-provision/src/main/java/eu/dnetlib/dhp/oa/provision/utils/RelationPartitioner.java index bdece36ab0..7bd8b92171 100644 --- a/dhp-workflows/dhp-graph-provision/src/main/java/eu/dnetlib/dhp/oa/provision/utils/RelationPartitioner.java +++ b/dhp-workflows/dhp-graph-provision/src/main/java/eu/dnetlib/dhp/oa/provision/utils/RelationPartitioner.java @@ -12,6 +12,8 @@ import eu.dnetlib.dhp.oa.provision.model.SortableRelationKey; */ public class RelationPartitioner extends Partitioner { + private static final long serialVersionUID = 343434456L; + private final int numPartitions; public RelationPartitioner(int numPartitions) { @@ -29,4 +31,14 @@ public class RelationPartitioner extends Partitioner { return Utils.nonNegativeMod(partitionKey.getGroupingKey().hashCode(), numPartitions()); } + @Override + public boolean equals(Object obj) { + if (obj instanceof RelationPartitioner) { + RelationPartitioner p = (RelationPartitioner) obj; + if (p.numPartitions() == numPartitions()) + return true; + } + return false; + } + }