From 821be1f8b66ac13bc79105cb17bc86ca5fbf4f85 Mon Sep 17 00:00:00 2001 From: Claudio Atzori Date: Thu, 28 May 2020 13:53:13 +0200 Subject: [PATCH] experimental implementation of custom aggregation using kryo encoders --- .../oa/provision/AdjacencyListBuilderJob.java | 94 +++++++++++++++++-- 1 file changed, 85 insertions(+), 9 deletions(-) diff --git a/dhp-workflows/dhp-graph-provision/src/main/java/eu/dnetlib/dhp/oa/provision/AdjacencyListBuilderJob.java b/dhp-workflows/dhp-graph-provision/src/main/java/eu/dnetlib/dhp/oa/provision/AdjacencyListBuilderJob.java index 9f221ae45c..63b90be7c5 100644 --- a/dhp-workflows/dhp-graph-provision/src/main/java/eu/dnetlib/dhp/oa/provision/AdjacencyListBuilderJob.java +++ b/dhp-workflows/dhp-graph-provision/src/main/java/eu/dnetlib/dhp/oa/provision/AdjacencyListBuilderJob.java @@ -4,20 +4,20 @@ package eu.dnetlib.dhp.oa.provision; import static eu.dnetlib.dhp.common.SparkSessionSupport.runWithSparkSession; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Optional; +import java.util.stream.Collectors; import org.apache.commons.io.IOUtils; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.MapFunction; import org.apache.spark.api.java.function.MapGroupsFunction; import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.rdd.RDD; -import org.apache.spark.sql.Encoders; -import org.apache.spark.sql.SaveMode; -import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.*; +import org.apache.spark.sql.expressions.Aggregator; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -25,10 +25,11 @@ import com.google.common.collect.Lists; import eu.dnetlib.dhp.application.ArgumentApplicationParser; import eu.dnetlib.dhp.common.HdfsSupport; -import eu.dnetlib.dhp.oa.provision.model.EntityRelEntity; -import eu.dnetlib.dhp.oa.provision.model.JoinedEntity; -import eu.dnetlib.dhp.oa.provision.model.Tuple2; +import eu.dnetlib.dhp.oa.provision.model.*; import eu.dnetlib.dhp.schema.common.ModelSupport; +import eu.dnetlib.dhp.schema.oaf.Oaf; +import scala.Function1; +import scala.Function2; /** * Joins the graph nodes by resolving the links of distance = 1 to create an adjacency list of linked objects. The @@ -82,17 +83,92 @@ public class AdjacencyListBuilderJob { SparkConf conf = new SparkConf(); conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer"); - conf.registerKryoClasses(ModelSupport.getOafModelClasses()); + List> modelClasses = Arrays.asList(ModelSupport.getOafModelClasses()); + modelClasses + .addAll( + Lists + .newArrayList( + TypedRow.class, + EntityRelEntity.class, + JoinedEntity.class, + RelatedEntity.class, + Tuple2.class, + SortableRelation.class)); + conf.registerKryoClasses(modelClasses.toArray(new Class[] {})); runWithSparkSession( conf, isSparkSessionManaged, spark -> { removeOutputDir(spark, outputPath); - createAdjacencyListsRDD(spark, inputPath, outputPath); + createAdjacencyListsKryo(spark, inputPath, outputPath); }); } + private static void createAdjacencyListsKryo( + SparkSession spark, String inputPath, String outputPath) { + + TypedColumn aggregator = new AdjacencyListAggregator().toColumn(); + log.info("Reading joined entities from: {}", inputPath); + spark + .read() + .load(inputPath) + .as(Encoders.kryo(EntityRelEntity.class)) + .groupByKey( + (MapFunction) value -> value.getEntity().getId(), + Encoders.STRING()) + .agg(aggregator) + .write() + .mode(SaveMode.Overwrite) + .parquet(outputPath); + } + + public static class AdjacencyListAggregator extends Aggregator { + + @Override + public JoinedEntity zero() { + return new JoinedEntity(); + } + + @Override + public JoinedEntity reduce(JoinedEntity j, EntityRelEntity e) { + j.setEntity(e.getEntity()); + if (j.getLinks().size() <= MAX_LINKS) { + j.getLinks().add(new Tuple2(e.getRelation(), e.getTarget())); + } + return j; + } + + @Override + public JoinedEntity merge(JoinedEntity j1, JoinedEntity j2) { + j1.getLinks().addAll(j2.getLinks()); + return j1; + } + + @Override + public JoinedEntity finish(JoinedEntity j) { + if (j.getLinks().size() > MAX_LINKS) { + ArrayList links = j + .getLinks() + .stream() + .limit(MAX_LINKS) + .collect(Collectors.toCollection(ArrayList::new)); + j.setLinks(links); + } + return j; + } + + @Override + public Encoder bufferEncoder() { + return Encoders.kryo(JoinedEntity.class); + } + + @Override + public Encoder outputEncoder() { + return Encoders.kryo(JoinedEntity.class); + } + } + private static void createAdjacencyLists( SparkSession spark, String inputPath, String outputPath) {