added test to verify the relation pre-processing
This commit is contained in:
parent
8d59fdf34e
commit
7817338e05
|
@ -115,7 +115,7 @@ public class PrepareRelationsJob {
|
|||
isSparkSessionManaged,
|
||||
spark -> {
|
||||
removeOutputDir(spark, outputPath);
|
||||
prepareRelationsDataset(
|
||||
prepareRelationsRDD(
|
||||
spark, inputRelationsPath, outputPath, relationFilter, maxRelations, relPartitions);
|
||||
});
|
||||
}
|
||||
|
@ -148,21 +148,8 @@ public class PrepareRelationsJob {
|
|||
.map(Tuple2::_2)
|
||||
.rdd();
|
||||
|
||||
// group by TARGET and apply limit
|
||||
RDD<Relation> byTarget = readPathRelationRDD(spark, inputRelationsPath)
|
||||
.filter(rel -> rel.getDataInfo().getDeletedbyinference() == false)
|
||||
.filter(rel -> relationFilter.contains(rel.getRelClass()) == false)
|
||||
.mapToPair(r -> new Tuple2<>(SortableRelationKey.create(r, r.getTarget()), r))
|
||||
.repartitionAndSortWithinPartitions(new RelationPartitioner(relPartitions))
|
||||
.groupBy(Tuple2::_1)
|
||||
.map(Tuple2::_2)
|
||||
.map(t -> Iterables.limit(t, maxRelations))
|
||||
.flatMap(Iterable::iterator)
|
||||
.map(Tuple2::_2)
|
||||
.rdd();
|
||||
|
||||
spark
|
||||
.createDataset(bySource.union(byTarget), Encoders.bean(Relation.class))
|
||||
.createDataset(bySource, Encoders.bean(Relation.class))
|
||||
.repartition(relPartitions)
|
||||
.write()
|
||||
.mode(SaveMode.Overwrite)
|
||||
|
@ -172,41 +159,7 @@ public class PrepareRelationsJob {
|
|||
private static void prepareRelationsDataset(
|
||||
SparkSession spark, String inputRelationsPath, String outputPath, Set<String> relationFilter, int maxRelations,
|
||||
int relPartitions) {
|
||||
|
||||
Dataset<Relation> bySource = pruneRelations(
|
||||
spark, inputRelationsPath, relationFilter, maxRelations, relPartitions,
|
||||
(Function<Relation, String>) r -> r.getSource());
|
||||
Dataset<Relation> byTarget = pruneRelations(
|
||||
spark, inputRelationsPath, relationFilter, maxRelations, relPartitions,
|
||||
(Function<Relation, String>) r -> r.getTarget());
|
||||
|
||||
bySource
|
||||
.union(byTarget)
|
||||
.repartition(relPartitions)
|
||||
.write()
|
||||
.mode(SaveMode.Overwrite)
|
||||
.parquet(outputPath);
|
||||
}
|
||||
|
||||
private static Dataset<Relation> pruneRelations(SparkSession spark, String inputRelationsPath,
|
||||
Set<String> relationFilter, int maxRelations, int relPartitions,
|
||||
Function<Relation, String> idFn) {
|
||||
return readRelations(spark, inputRelationsPath, relationFilter, relPartitions)
|
||||
.groupByKey(
|
||||
(MapFunction<Relation, String>) r -> idFn.call(r),
|
||||
Encoders.STRING())
|
||||
.agg(new RelationAggregator(maxRelations).toColumn())
|
||||
.flatMap(
|
||||
(FlatMapFunction<Tuple2<String, RelationList>, Relation>) t -> t
|
||||
._2()
|
||||
.getRelations()
|
||||
.iterator(),
|
||||
Encoders.bean(Relation.class));
|
||||
}
|
||||
|
||||
private static Dataset<Relation> readRelations(SparkSession spark, String inputRelationsPath,
|
||||
Set<String> relationFilter, int relPartitions) {
|
||||
return spark
|
||||
spark
|
||||
.read()
|
||||
.textFile(inputRelationsPath)
|
||||
.repartition(relPartitions)
|
||||
|
@ -214,7 +167,20 @@ public class PrepareRelationsJob {
|
|||
(MapFunction<String, Relation>) s -> OBJECT_MAPPER.readValue(s, Relation.class),
|
||||
Encoders.kryo(Relation.class))
|
||||
.filter((FilterFunction<Relation>) rel -> rel.getDataInfo().getDeletedbyinference() == false)
|
||||
.filter((FilterFunction<Relation>) rel -> relationFilter.contains(rel.getRelClass()) == false);
|
||||
.filter((FilterFunction<Relation>) rel -> relationFilter.contains(rel.getRelClass()) == false)
|
||||
.groupByKey(
|
||||
(MapFunction<Relation, String>) Relation::getSource,
|
||||
Encoders.STRING())
|
||||
.agg(new RelationAggregator(maxRelations).toColumn())
|
||||
.flatMap(
|
||||
(FlatMapFunction<Tuple2<String, RelationList>, Relation>) t -> Iterables
|
||||
.limit(t._2().getRelations(), maxRelations)
|
||||
.iterator(),
|
||||
Encoders.bean(Relation.class))
|
||||
.repartition(relPartitions)
|
||||
.write()
|
||||
.mode(SaveMode.Overwrite)
|
||||
.parquet(outputPath);
|
||||
}
|
||||
|
||||
public static class RelationAggregator
|
||||
|
|
|
@ -133,6 +133,7 @@
|
|||
</spark-opts>
|
||||
<arg>--inputRelationsPath</arg><arg>${inputGraphRootPath}/relation</arg>
|
||||
<arg>--outputPath</arg><arg>${workingDir}/relation</arg>
|
||||
<arg>--maxRelations</arg><arg>${maxRelations}</arg>
|
||||
<arg>--relPartitions</arg><arg>5000</arg>
|
||||
</spark>
|
||||
<ok to="fork_join_related_entities"/>
|
||||
|
|
|
@ -0,0 +1,93 @@
|
|||
|
||||
package eu.dnetlib.dhp.oa.provision;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import eu.dnetlib.dhp.oa.provision.model.ProvisionModelSupport;
|
||||
import eu.dnetlib.dhp.schema.oaf.Relation;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.apache.spark.SparkConf;
|
||||
import org.apache.spark.api.java.function.FilterFunction;
|
||||
import org.apache.spark.sql.Dataset;
|
||||
import org.apache.spark.sql.Encoders;
|
||||
import org.apache.spark.sql.Row;
|
||||
import org.apache.spark.sql.SparkSession;
|
||||
import org.junit.jupiter.api.AfterAll;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.io.TempDir;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
|
||||
public class PrepareRelationsJobTest {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(PrepareRelationsJobTest.class);
|
||||
|
||||
public static final String SUBRELTYPE = "subRelType";
|
||||
public static final String OUTCOME = "outcome";
|
||||
public static final String SUPPLEMENT = "supplement";
|
||||
|
||||
private static SparkSession spark;
|
||||
|
||||
private static Path workingDir;
|
||||
|
||||
@BeforeAll
|
||||
public static void setUp() throws IOException {
|
||||
workingDir = Files.createTempDirectory(PrepareRelationsJobTest.class.getSimpleName());
|
||||
log.info("using work dir {}", workingDir);
|
||||
|
||||
SparkConf conf = new SparkConf();
|
||||
|
||||
conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
|
||||
conf.registerKryoClasses(ProvisionModelSupport.getModelClasses());
|
||||
|
||||
spark = SparkSession
|
||||
.builder()
|
||||
.appName(PrepareRelationsJobTest.class.getSimpleName())
|
||||
.master("local[*]")
|
||||
.config(conf)
|
||||
.getOrCreate();
|
||||
}
|
||||
|
||||
@AfterAll
|
||||
public static void afterAll() throws IOException {
|
||||
FileUtils.deleteDirectory(workingDir.toFile());
|
||||
spark.stop();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRunPrepareRelationsJob(@TempDir Path testPath) throws Exception {
|
||||
|
||||
final int maxRelations = 10;
|
||||
PrepareRelationsJob
|
||||
.main(
|
||||
new String[] {
|
||||
"-isSparkSessionManaged", Boolean.FALSE.toString(),
|
||||
"-inputRelationsPath", getClass().getResource("relations.gz").getPath(),
|
||||
"-outputPath", testPath.toString(),
|
||||
"-relPartitions", "10",
|
||||
"-relationFilter", "asd",
|
||||
"-maxRelations", String.valueOf(maxRelations)
|
||||
});
|
||||
|
||||
Dataset<Relation> out = spark.read()
|
||||
.parquet(testPath.toString())
|
||||
.as(Encoders.bean(Relation.class))
|
||||
.cache();
|
||||
|
||||
Assertions.assertEquals(10, out.count());
|
||||
|
||||
Dataset<Row> freq = out.toDF().cube(SUBRELTYPE).count().filter((FilterFunction<Row>) value -> !value.isNullAt(0));
|
||||
long outcome = freq.filter(freq.col(SUBRELTYPE).equalTo(OUTCOME)).collectAsList().get(0).getAs("count");
|
||||
long supplement = freq.filter(freq.col(SUBRELTYPE).equalTo(SUPPLEMENT)).collectAsList().get(0).getAs("count");
|
||||
|
||||
Assertions.assertTrue(outcome > supplement);
|
||||
Assertions.assertEquals(7, outcome);
|
||||
Assertions.assertEquals(3, supplement);
|
||||
}
|
||||
|
||||
}
|
Binary file not shown.
|
@ -0,0 +1,11 @@
|
|||
# Set root logger level to DEBUG and its only appender to A1.
|
||||
log4j.rootLogger=INFO, A1
|
||||
|
||||
# A1 is set to be a ConsoleAppender.
|
||||
log4j.appender.A1=org.apache.log4j.ConsoleAppender
|
||||
|
||||
# A1 uses PatternLayout.
|
||||
log4j.logger.org = ERROR
|
||||
log4j.logger.eu.dnetlib = DEBUG
|
||||
log4j.appender.A1.layout=org.apache.log4j.PatternLayout
|
||||
log4j.appender.A1.layout.ConversionPattern=%-4r [%t] %-5p %c %x - %m%n
|
Loading…
Reference in New Issue