diff --git a/dhp-workflows/dhp-graph-provision/src/main/java/eu/dnetlib/dhp/oa/provision/PrepareRelationsJob.java b/dhp-workflows/dhp-graph-provision/src/main/java/eu/dnetlib/dhp/oa/provision/PrepareRelationsJob.java index d69b75b65..6b34899c8 100644 --- a/dhp-workflows/dhp-graph-provision/src/main/java/eu/dnetlib/dhp/oa/provision/PrepareRelationsJob.java +++ b/dhp-workflows/dhp-graph-provision/src/main/java/eu/dnetlib/dhp/oa/provision/PrepareRelationsJob.java @@ -3,7 +3,9 @@ package eu.dnetlib.dhp.oa.provision; import static eu.dnetlib.dhp.common.SparkSessionSupport.runWithSparkSession; +import java.io.Serializable; import java.util.*; +import java.util.function.Supplier; import org.apache.commons.io.IOUtils; import org.apache.spark.SparkConf; @@ -18,7 +20,9 @@ import org.slf4j.LoggerFactory; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.base.Splitter; +import com.google.common.collect.ComparisonChain; import com.google.common.collect.Iterables; +import com.google.common.collect.Maps; import com.google.common.collect.Sets; import eu.dnetlib.dhp.application.ArgumentApplicationParser; @@ -59,6 +63,21 @@ public class PrepareRelationsJob { public static final int DEFAULT_NUM_PARTITIONS = 3000; + private static final Map weights = Maps.newHashMap(); + + static { + weights.put("outcome", 0); + weights.put("supplement", 1); + weights.put("affiliation", 2); + weights.put("relationship", 3); + weights.put("publicationDataset", 4); + weights.put("similarity", 5); + + weights.put("provision", 6); + weights.put("participation", 7); + weights.put("dedup", 8); + } + public static void main(String[] args) throws Exception { String jsonConfiguration = IOUtils .toString( @@ -132,11 +151,15 @@ public class PrepareRelationsJob { .filter(rel -> !relationFilter.contains(rel.getRelClass())) // group by SOURCE and apply limit .groupBy(r -> SortableRelationKey.create(r, r.getSource())) - .repartitionAndSortWithinPartitions(new RelationPartitioner(relPartitions)) + .repartitionAndSortWithinPartitions( + new RelationPartitioner(relPartitions), + (SerializableComparator) (o1, o2) -> compare(o1, o2)) .flatMap(t -> Iterables.limit(t._2(), maxRelations).iterator()) // group by TARGET and apply limit .groupBy(r -> SortableRelationKey.create(r, r.getTarget())) - .repartitionAndSortWithinPartitions(new RelationPartitioner(relPartitions)) + .repartitionAndSortWithinPartitions( + new RelationPartitioner(relPartitions), + (SerializableComparator) (o1, o2) -> compare(o1, o2)) .flatMap(t -> Iterables.limit(t._2(), maxRelations).iterator()) .rdd(); @@ -147,6 +170,24 @@ public class PrepareRelationsJob { .parquet(outputPath); } + private static int compare(SortableRelationKey o1, SortableRelationKey o2) { + final Integer w1 = Optional.ofNullable(weights.get(o1.getSubRelType())).orElse(Integer.MAX_VALUE); + final Integer w2 = Optional.ofNullable(weights.get(o2.getSubRelType())).orElse(Integer.MAX_VALUE); + return ComparisonChain + .start() + .compare(w1, w2) + .compare(o1.getSource(), o2.getSource()) + .compare(o1.getTarget(), o2.getTarget()) + .result(); + } + + @FunctionalInterface + public interface SerializableComparator extends Comparator, Serializable { + + @Override + int compare(T o1, T o2); + } + /** * Reads a JavaRDD of eu.dnetlib.dhp.oa.provision.model.SortableRelation objects from a newline delimited json text * file, diff --git a/dhp-workflows/dhp-graph-provision/src/main/java/eu/dnetlib/dhp/oa/provision/model/SortableRelationKey.java b/dhp-workflows/dhp-graph-provision/src/main/java/eu/dnetlib/dhp/oa/provision/model/SortableRelationKey.java index ad61fa044..ab6518809 100644 --- a/dhp-workflows/dhp-graph-provision/src/main/java/eu/dnetlib/dhp/oa/provision/model/SortableRelationKey.java +++ b/dhp-workflows/dhp-graph-provision/src/main/java/eu/dnetlib/dhp/oa/provision/model/SortableRelationKey.java @@ -5,27 +5,13 @@ import java.io.Serializable; import java.util.Map; import java.util.Optional; +import com.google.common.base.Objects; import com.google.common.collect.ComparisonChain; import com.google.common.collect.Maps; import eu.dnetlib.dhp.schema.oaf.Relation; -public class SortableRelationKey implements Comparable, Serializable { - - private static final Map weights = Maps.newHashMap(); - - static { - weights.put("outcome", 0); - weights.put("supplement", 1); - weights.put("affiliation", 2); - weights.put("relationship", 3); - weights.put("publicationDataset", 4); - weights.put("similarity", 5); - - weights.put("provision", 6); - weights.put("participation", 7); - weights.put("dedup", 8); - } +public class SortableRelationKey implements Serializable { private String groupingKey; @@ -49,15 +35,18 @@ public class SortableRelationKey implements Comparable, Ser } @Override - public int compareTo(SortableRelationKey o) { - final Integer wt = Optional.ofNullable(weights.get(getSubRelType())).orElse(Integer.MAX_VALUE); - final Integer wo = Optional.ofNullable(weights.get(o.getSubRelType())).orElse(Integer.MAX_VALUE); - return ComparisonChain - .start() - .compare(wt, wo) - .compare(getSource(), o.getSource()) - .compare(getTarget(), o.getTarget()) - .result(); + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + SortableRelationKey that = (SortableRelationKey) o; + return Objects.equal(getGroupingKey(), that.getGroupingKey()); + } + + @Override + public int hashCode() { + return Objects.hashCode(getGroupingKey()); } public void setSource(String source) {