added test to verify the relation pre-processing

This commit is contained in:
Claudio Atzori 2020-06-26 17:58:33 +02:00
parent 8d59fdf34e
commit 7817338e05
5 changed files with 122 additions and 51 deletions

View File

@ -115,7 +115,7 @@ public class PrepareRelationsJob {
isSparkSessionManaged,
spark -> {
removeOutputDir(spark, outputPath);
prepareRelationsDataset(
prepareRelationsRDD(
spark, inputRelationsPath, outputPath, relationFilter, maxRelations, relPartitions);
});
}
@ -148,21 +148,8 @@ public class PrepareRelationsJob {
.map(Tuple2::_2)
.rdd();
// group by TARGET and apply limit
RDD<Relation> byTarget = readPathRelationRDD(spark, inputRelationsPath)
.filter(rel -> rel.getDataInfo().getDeletedbyinference() == false)
.filter(rel -> relationFilter.contains(rel.getRelClass()) == false)
.mapToPair(r -> new Tuple2<>(SortableRelationKey.create(r, r.getTarget()), r))
.repartitionAndSortWithinPartitions(new RelationPartitioner(relPartitions))
.groupBy(Tuple2::_1)
.map(Tuple2::_2)
.map(t -> Iterables.limit(t, maxRelations))
.flatMap(Iterable::iterator)
.map(Tuple2::_2)
.rdd();
spark
.createDataset(bySource.union(byTarget), Encoders.bean(Relation.class))
.createDataset(bySource, Encoders.bean(Relation.class))
.repartition(relPartitions)
.write()
.mode(SaveMode.Overwrite)
@ -172,41 +159,7 @@ public class PrepareRelationsJob {
private static void prepareRelationsDataset(
SparkSession spark, String inputRelationsPath, String outputPath, Set<String> relationFilter, int maxRelations,
int relPartitions) {
Dataset<Relation> bySource = pruneRelations(
spark, inputRelationsPath, relationFilter, maxRelations, relPartitions,
(Function<Relation, String>) r -> r.getSource());
Dataset<Relation> byTarget = pruneRelations(
spark, inputRelationsPath, relationFilter, maxRelations, relPartitions,
(Function<Relation, String>) r -> r.getTarget());
bySource
.union(byTarget)
.repartition(relPartitions)
.write()
.mode(SaveMode.Overwrite)
.parquet(outputPath);
}
private static Dataset<Relation> pruneRelations(SparkSession spark, String inputRelationsPath,
Set<String> relationFilter, int maxRelations, int relPartitions,
Function<Relation, String> idFn) {
return readRelations(spark, inputRelationsPath, relationFilter, relPartitions)
.groupByKey(
(MapFunction<Relation, String>) r -> idFn.call(r),
Encoders.STRING())
.agg(new RelationAggregator(maxRelations).toColumn())
.flatMap(
(FlatMapFunction<Tuple2<String, RelationList>, Relation>) t -> t
._2()
.getRelations()
.iterator(),
Encoders.bean(Relation.class));
}
private static Dataset<Relation> readRelations(SparkSession spark, String inputRelationsPath,
Set<String> relationFilter, int relPartitions) {
return spark
spark
.read()
.textFile(inputRelationsPath)
.repartition(relPartitions)
@ -214,7 +167,20 @@ public class PrepareRelationsJob {
(MapFunction<String, Relation>) s -> OBJECT_MAPPER.readValue(s, Relation.class),
Encoders.kryo(Relation.class))
.filter((FilterFunction<Relation>) rel -> rel.getDataInfo().getDeletedbyinference() == false)
.filter((FilterFunction<Relation>) rel -> relationFilter.contains(rel.getRelClass()) == false);
.filter((FilterFunction<Relation>) rel -> relationFilter.contains(rel.getRelClass()) == false)
.groupByKey(
(MapFunction<Relation, String>) Relation::getSource,
Encoders.STRING())
.agg(new RelationAggregator(maxRelations).toColumn())
.flatMap(
(FlatMapFunction<Tuple2<String, RelationList>, Relation>) t -> Iterables
.limit(t._2().getRelations(), maxRelations)
.iterator(),
Encoders.bean(Relation.class))
.repartition(relPartitions)
.write()
.mode(SaveMode.Overwrite)
.parquet(outputPath);
}
public static class RelationAggregator

View File

@ -133,6 +133,7 @@
</spark-opts>
<arg>--inputRelationsPath</arg><arg>${inputGraphRootPath}/relation</arg>
<arg>--outputPath</arg><arg>${workingDir}/relation</arg>
<arg>--maxRelations</arg><arg>${maxRelations}</arg>
<arg>--relPartitions</arg><arg>5000</arg>
</spark>
<ok to="fork_join_related_entities"/>

View File

@ -0,0 +1,93 @@
package eu.dnetlib.dhp.oa.provision;
import com.fasterxml.jackson.databind.ObjectMapper;
import eu.dnetlib.dhp.oa.provision.model.ProvisionModelSupport;
import eu.dnetlib.dhp.schema.oaf.Relation;
import org.apache.commons.io.FileUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.function.FilterFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
public class PrepareRelationsJobTest {
private static final Logger log = LoggerFactory.getLogger(PrepareRelationsJobTest.class);
public static final String SUBRELTYPE = "subRelType";
public static final String OUTCOME = "outcome";
public static final String SUPPLEMENT = "supplement";
private static SparkSession spark;
private static Path workingDir;
@BeforeAll
public static void setUp() throws IOException {
workingDir = Files.createTempDirectory(PrepareRelationsJobTest.class.getSimpleName());
log.info("using work dir {}", workingDir);
SparkConf conf = new SparkConf();
conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
conf.registerKryoClasses(ProvisionModelSupport.getModelClasses());
spark = SparkSession
.builder()
.appName(PrepareRelationsJobTest.class.getSimpleName())
.master("local[*]")
.config(conf)
.getOrCreate();
}
@AfterAll
public static void afterAll() throws IOException {
FileUtils.deleteDirectory(workingDir.toFile());
spark.stop();
}
@Test
public void testRunPrepareRelationsJob(@TempDir Path testPath) throws Exception {
final int maxRelations = 10;
PrepareRelationsJob
.main(
new String[] {
"-isSparkSessionManaged", Boolean.FALSE.toString(),
"-inputRelationsPath", getClass().getResource("relations.gz").getPath(),
"-outputPath", testPath.toString(),
"-relPartitions", "10",
"-relationFilter", "asd",
"-maxRelations", String.valueOf(maxRelations)
});
Dataset<Relation> out = spark.read()
.parquet(testPath.toString())
.as(Encoders.bean(Relation.class))
.cache();
Assertions.assertEquals(10, out.count());
Dataset<Row> freq = out.toDF().cube(SUBRELTYPE).count().filter((FilterFunction<Row>) value -> !value.isNullAt(0));
long outcome = freq.filter(freq.col(SUBRELTYPE).equalTo(OUTCOME)).collectAsList().get(0).getAs("count");
long supplement = freq.filter(freq.col(SUBRELTYPE).equalTo(SUPPLEMENT)).collectAsList().get(0).getAs("count");
Assertions.assertTrue(outcome > supplement);
Assertions.assertEquals(7, outcome);
Assertions.assertEquals(3, supplement);
}
}

View File

@ -0,0 +1,11 @@
# Set root logger level to DEBUG and its only appender to A1.
log4j.rootLogger=INFO, A1
# A1 is set to be a ConsoleAppender.
log4j.appender.A1=org.apache.log4j.ConsoleAppender
# A1 uses PatternLayout.
log4j.logger.org = ERROR
log4j.logger.eu.dnetlib = DEBUG
log4j.appender.A1.layout=org.apache.log4j.PatternLayout
log4j.appender.A1.layout.ConversionPattern=%-4r [%t] %-5p %c %x - %m%n