diff --git a/dhp-workflows/dhp-graph-provision/src/main/java/eu/dnetlib/dhp/oa/provision/PrepareRelationsJob.java b/dhp-workflows/dhp-graph-provision/src/main/java/eu/dnetlib/dhp/oa/provision/PrepareRelationsJob.java index bf9806787..601cf6449 100644 --- a/dhp-workflows/dhp-graph-provision/src/main/java/eu/dnetlib/dhp/oa/provision/PrepareRelationsJob.java +++ b/dhp-workflows/dhp-graph-provision/src/main/java/eu/dnetlib/dhp/oa/provision/PrepareRelationsJob.java @@ -115,7 +115,7 @@ public class PrepareRelationsJob { isSparkSessionManaged, spark -> { removeOutputDir(spark, outputPath); - prepareRelationsDataset( + prepareRelationsRDD( spark, inputRelationsPath, outputPath, relationFilter, maxRelations, relPartitions); }); } @@ -148,21 +148,8 @@ public class PrepareRelationsJob { .map(Tuple2::_2) .rdd(); - // group by TARGET and apply limit - RDD byTarget = readPathRelationRDD(spark, inputRelationsPath) - .filter(rel -> rel.getDataInfo().getDeletedbyinference() == false) - .filter(rel -> relationFilter.contains(rel.getRelClass()) == false) - .mapToPair(r -> new Tuple2<>(SortableRelationKey.create(r, r.getTarget()), r)) - .repartitionAndSortWithinPartitions(new RelationPartitioner(relPartitions)) - .groupBy(Tuple2::_1) - .map(Tuple2::_2) - .map(t -> Iterables.limit(t, maxRelations)) - .flatMap(Iterable::iterator) - .map(Tuple2::_2) - .rdd(); - spark - .createDataset(bySource.union(byTarget), Encoders.bean(Relation.class)) + .createDataset(bySource, Encoders.bean(Relation.class)) .repartition(relPartitions) .write() .mode(SaveMode.Overwrite) @@ -172,41 +159,7 @@ public class PrepareRelationsJob { private static void prepareRelationsDataset( SparkSession spark, String inputRelationsPath, String outputPath, Set relationFilter, int maxRelations, int relPartitions) { - - Dataset bySource = pruneRelations( - spark, inputRelationsPath, relationFilter, maxRelations, relPartitions, - (Function) r -> r.getSource()); - Dataset byTarget = pruneRelations( - spark, inputRelationsPath, relationFilter, maxRelations, relPartitions, - (Function) r -> r.getTarget()); - - bySource - .union(byTarget) - .repartition(relPartitions) - .write() - .mode(SaveMode.Overwrite) - .parquet(outputPath); - } - - private static Dataset pruneRelations(SparkSession spark, String inputRelationsPath, - Set relationFilter, int maxRelations, int relPartitions, - Function idFn) { - return readRelations(spark, inputRelationsPath, relationFilter, relPartitions) - .groupByKey( - (MapFunction) r -> idFn.call(r), - Encoders.STRING()) - .agg(new RelationAggregator(maxRelations).toColumn()) - .flatMap( - (FlatMapFunction, Relation>) t -> t - ._2() - .getRelations() - .iterator(), - Encoders.bean(Relation.class)); - } - - private static Dataset readRelations(SparkSession spark, String inputRelationsPath, - Set relationFilter, int relPartitions) { - return spark + spark .read() .textFile(inputRelationsPath) .repartition(relPartitions) @@ -214,7 +167,20 @@ public class PrepareRelationsJob { (MapFunction) s -> OBJECT_MAPPER.readValue(s, Relation.class), Encoders.kryo(Relation.class)) .filter((FilterFunction) rel -> rel.getDataInfo().getDeletedbyinference() == false) - .filter((FilterFunction) rel -> relationFilter.contains(rel.getRelClass()) == false); + .filter((FilterFunction) rel -> relationFilter.contains(rel.getRelClass()) == false) + .groupByKey( + (MapFunction) Relation::getSource, + Encoders.STRING()) + .agg(new RelationAggregator(maxRelations).toColumn()) + .flatMap( + (FlatMapFunction, Relation>) t -> Iterables + .limit(t._2().getRelations(), maxRelations) + .iterator(), + Encoders.bean(Relation.class)) + .repartition(relPartitions) + .write() + .mode(SaveMode.Overwrite) + .parquet(outputPath); } public static class RelationAggregator diff --git a/dhp-workflows/dhp-graph-provision/src/main/resources/eu/dnetlib/dhp/oa/provision/oozie_app/workflow.xml b/dhp-workflows/dhp-graph-provision/src/main/resources/eu/dnetlib/dhp/oa/provision/oozie_app/workflow.xml index 0d5121cf1..697a00a09 100644 --- a/dhp-workflows/dhp-graph-provision/src/main/resources/eu/dnetlib/dhp/oa/provision/oozie_app/workflow.xml +++ b/dhp-workflows/dhp-graph-provision/src/main/resources/eu/dnetlib/dhp/oa/provision/oozie_app/workflow.xml @@ -133,6 +133,7 @@ --inputRelationsPath${inputGraphRootPath}/relation --outputPath${workingDir}/relation + --maxRelations${maxRelations} --relPartitions5000 diff --git a/dhp-workflows/dhp-graph-provision/src/test/java/eu/dnetlib/dhp/oa/provision/PrepareRelationsJobTest.java b/dhp-workflows/dhp-graph-provision/src/test/java/eu/dnetlib/dhp/oa/provision/PrepareRelationsJobTest.java new file mode 100644 index 000000000..c16bbc6fb --- /dev/null +++ b/dhp-workflows/dhp-graph-provision/src/test/java/eu/dnetlib/dhp/oa/provision/PrepareRelationsJobTest.java @@ -0,0 +1,93 @@ + +package eu.dnetlib.dhp.oa.provision; + +import com.fasterxml.jackson.databind.ObjectMapper; +import eu.dnetlib.dhp.oa.provision.model.ProvisionModelSupport; +import eu.dnetlib.dhp.schema.oaf.Relation; +import org.apache.commons.io.FileUtils; +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.function.FilterFunction; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; + +public class PrepareRelationsJobTest { + + private static final Logger log = LoggerFactory.getLogger(PrepareRelationsJobTest.class); + + public static final String SUBRELTYPE = "subRelType"; + public static final String OUTCOME = "outcome"; + public static final String SUPPLEMENT = "supplement"; + + private static SparkSession spark; + + private static Path workingDir; + + @BeforeAll + public static void setUp() throws IOException { + workingDir = Files.createTempDirectory(PrepareRelationsJobTest.class.getSimpleName()); + log.info("using work dir {}", workingDir); + + SparkConf conf = new SparkConf(); + + conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer"); + conf.registerKryoClasses(ProvisionModelSupport.getModelClasses()); + + spark = SparkSession + .builder() + .appName(PrepareRelationsJobTest.class.getSimpleName()) + .master("local[*]") + .config(conf) + .getOrCreate(); + } + + @AfterAll + public static void afterAll() throws IOException { + FileUtils.deleteDirectory(workingDir.toFile()); + spark.stop(); + } + + @Test + public void testRunPrepareRelationsJob(@TempDir Path testPath) throws Exception { + + final int maxRelations = 10; + PrepareRelationsJob + .main( + new String[] { + "-isSparkSessionManaged", Boolean.FALSE.toString(), + "-inputRelationsPath", getClass().getResource("relations.gz").getPath(), + "-outputPath", testPath.toString(), + "-relPartitions", "10", + "-relationFilter", "asd", + "-maxRelations", String.valueOf(maxRelations) + }); + + Dataset out = spark.read() + .parquet(testPath.toString()) + .as(Encoders.bean(Relation.class)) + .cache(); + + Assertions.assertEquals(10, out.count()); + + Dataset freq = out.toDF().cube(SUBRELTYPE).count().filter((FilterFunction) value -> !value.isNullAt(0)); + long outcome = freq.filter(freq.col(SUBRELTYPE).equalTo(OUTCOME)).collectAsList().get(0).getAs("count"); + long supplement = freq.filter(freq.col(SUBRELTYPE).equalTo(SUPPLEMENT)).collectAsList().get(0).getAs("count"); + + Assertions.assertTrue(outcome > supplement); + Assertions.assertEquals(7, outcome); + Assertions.assertEquals(3, supplement); + } + +} diff --git a/dhp-workflows/dhp-graph-provision/src/test/resources/eu/dnetlib/dhp/oa/provision/relations.gz b/dhp-workflows/dhp-graph-provision/src/test/resources/eu/dnetlib/dhp/oa/provision/relations.gz new file mode 100644 index 000000000..13bc01c8c Binary files /dev/null and b/dhp-workflows/dhp-graph-provision/src/test/resources/eu/dnetlib/dhp/oa/provision/relations.gz differ diff --git a/dhp-workflows/dhp-graph-provision/src/test/resources/log4j.properties b/dhp-workflows/dhp-graph-provision/src/test/resources/log4j.properties new file mode 100644 index 000000000..20f56e38d --- /dev/null +++ b/dhp-workflows/dhp-graph-provision/src/test/resources/log4j.properties @@ -0,0 +1,11 @@ +# Set root logger level to DEBUG and its only appender to A1. +log4j.rootLogger=INFO, A1 + +# A1 is set to be a ConsoleAppender. +log4j.appender.A1=org.apache.log4j.ConsoleAppender + +# A1 uses PatternLayout. +log4j.logger.org = ERROR +log4j.logger.eu.dnetlib = DEBUG +log4j.appender.A1.layout=org.apache.log4j.PatternLayout +log4j.appender.A1.layout.ConversionPattern=%-4r [%t] %-5p %c %x - %m%n \ No newline at end of file