added test to verify the relation pre-processing
This commit is contained in:
parent
8d59fdf34e
commit
7817338e05
|
@ -115,7 +115,7 @@ public class PrepareRelationsJob {
|
||||||
isSparkSessionManaged,
|
isSparkSessionManaged,
|
||||||
spark -> {
|
spark -> {
|
||||||
removeOutputDir(spark, outputPath);
|
removeOutputDir(spark, outputPath);
|
||||||
prepareRelationsDataset(
|
prepareRelationsRDD(
|
||||||
spark, inputRelationsPath, outputPath, relationFilter, maxRelations, relPartitions);
|
spark, inputRelationsPath, outputPath, relationFilter, maxRelations, relPartitions);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -148,21 +148,8 @@ public class PrepareRelationsJob {
|
||||||
.map(Tuple2::_2)
|
.map(Tuple2::_2)
|
||||||
.rdd();
|
.rdd();
|
||||||
|
|
||||||
// group by TARGET and apply limit
|
|
||||||
RDD<Relation> byTarget = readPathRelationRDD(spark, inputRelationsPath)
|
|
||||||
.filter(rel -> rel.getDataInfo().getDeletedbyinference() == false)
|
|
||||||
.filter(rel -> relationFilter.contains(rel.getRelClass()) == false)
|
|
||||||
.mapToPair(r -> new Tuple2<>(SortableRelationKey.create(r, r.getTarget()), r))
|
|
||||||
.repartitionAndSortWithinPartitions(new RelationPartitioner(relPartitions))
|
|
||||||
.groupBy(Tuple2::_1)
|
|
||||||
.map(Tuple2::_2)
|
|
||||||
.map(t -> Iterables.limit(t, maxRelations))
|
|
||||||
.flatMap(Iterable::iterator)
|
|
||||||
.map(Tuple2::_2)
|
|
||||||
.rdd();
|
|
||||||
|
|
||||||
spark
|
spark
|
||||||
.createDataset(bySource.union(byTarget), Encoders.bean(Relation.class))
|
.createDataset(bySource, Encoders.bean(Relation.class))
|
||||||
.repartition(relPartitions)
|
.repartition(relPartitions)
|
||||||
.write()
|
.write()
|
||||||
.mode(SaveMode.Overwrite)
|
.mode(SaveMode.Overwrite)
|
||||||
|
@ -172,41 +159,7 @@ public class PrepareRelationsJob {
|
||||||
private static void prepareRelationsDataset(
|
private static void prepareRelationsDataset(
|
||||||
SparkSession spark, String inputRelationsPath, String outputPath, Set<String> relationFilter, int maxRelations,
|
SparkSession spark, String inputRelationsPath, String outputPath, Set<String> relationFilter, int maxRelations,
|
||||||
int relPartitions) {
|
int relPartitions) {
|
||||||
|
spark
|
||||||
Dataset<Relation> bySource = pruneRelations(
|
|
||||||
spark, inputRelationsPath, relationFilter, maxRelations, relPartitions,
|
|
||||||
(Function<Relation, String>) r -> r.getSource());
|
|
||||||
Dataset<Relation> byTarget = pruneRelations(
|
|
||||||
spark, inputRelationsPath, relationFilter, maxRelations, relPartitions,
|
|
||||||
(Function<Relation, String>) r -> r.getTarget());
|
|
||||||
|
|
||||||
bySource
|
|
||||||
.union(byTarget)
|
|
||||||
.repartition(relPartitions)
|
|
||||||
.write()
|
|
||||||
.mode(SaveMode.Overwrite)
|
|
||||||
.parquet(outputPath);
|
|
||||||
}
|
|
||||||
|
|
||||||
private static Dataset<Relation> pruneRelations(SparkSession spark, String inputRelationsPath,
|
|
||||||
Set<String> relationFilter, int maxRelations, int relPartitions,
|
|
||||||
Function<Relation, String> idFn) {
|
|
||||||
return readRelations(spark, inputRelationsPath, relationFilter, relPartitions)
|
|
||||||
.groupByKey(
|
|
||||||
(MapFunction<Relation, String>) r -> idFn.call(r),
|
|
||||||
Encoders.STRING())
|
|
||||||
.agg(new RelationAggregator(maxRelations).toColumn())
|
|
||||||
.flatMap(
|
|
||||||
(FlatMapFunction<Tuple2<String, RelationList>, Relation>) t -> t
|
|
||||||
._2()
|
|
||||||
.getRelations()
|
|
||||||
.iterator(),
|
|
||||||
Encoders.bean(Relation.class));
|
|
||||||
}
|
|
||||||
|
|
||||||
private static Dataset<Relation> readRelations(SparkSession spark, String inputRelationsPath,
|
|
||||||
Set<String> relationFilter, int relPartitions) {
|
|
||||||
return spark
|
|
||||||
.read()
|
.read()
|
||||||
.textFile(inputRelationsPath)
|
.textFile(inputRelationsPath)
|
||||||
.repartition(relPartitions)
|
.repartition(relPartitions)
|
||||||
|
@ -214,7 +167,20 @@ public class PrepareRelationsJob {
|
||||||
(MapFunction<String, Relation>) s -> OBJECT_MAPPER.readValue(s, Relation.class),
|
(MapFunction<String, Relation>) s -> OBJECT_MAPPER.readValue(s, Relation.class),
|
||||||
Encoders.kryo(Relation.class))
|
Encoders.kryo(Relation.class))
|
||||||
.filter((FilterFunction<Relation>) rel -> rel.getDataInfo().getDeletedbyinference() == false)
|
.filter((FilterFunction<Relation>) rel -> rel.getDataInfo().getDeletedbyinference() == false)
|
||||||
.filter((FilterFunction<Relation>) rel -> relationFilter.contains(rel.getRelClass()) == false);
|
.filter((FilterFunction<Relation>) rel -> relationFilter.contains(rel.getRelClass()) == false)
|
||||||
|
.groupByKey(
|
||||||
|
(MapFunction<Relation, String>) Relation::getSource,
|
||||||
|
Encoders.STRING())
|
||||||
|
.agg(new RelationAggregator(maxRelations).toColumn())
|
||||||
|
.flatMap(
|
||||||
|
(FlatMapFunction<Tuple2<String, RelationList>, Relation>) t -> Iterables
|
||||||
|
.limit(t._2().getRelations(), maxRelations)
|
||||||
|
.iterator(),
|
||||||
|
Encoders.bean(Relation.class))
|
||||||
|
.repartition(relPartitions)
|
||||||
|
.write()
|
||||||
|
.mode(SaveMode.Overwrite)
|
||||||
|
.parquet(outputPath);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class RelationAggregator
|
public static class RelationAggregator
|
||||||
|
|
|
@ -133,6 +133,7 @@
|
||||||
</spark-opts>
|
</spark-opts>
|
||||||
<arg>--inputRelationsPath</arg><arg>${inputGraphRootPath}/relation</arg>
|
<arg>--inputRelationsPath</arg><arg>${inputGraphRootPath}/relation</arg>
|
||||||
<arg>--outputPath</arg><arg>${workingDir}/relation</arg>
|
<arg>--outputPath</arg><arg>${workingDir}/relation</arg>
|
||||||
|
<arg>--maxRelations</arg><arg>${maxRelations}</arg>
|
||||||
<arg>--relPartitions</arg><arg>5000</arg>
|
<arg>--relPartitions</arg><arg>5000</arg>
|
||||||
</spark>
|
</spark>
|
||||||
<ok to="fork_join_related_entities"/>
|
<ok to="fork_join_related_entities"/>
|
||||||
|
|
|
@ -0,0 +1,93 @@
|
||||||
|
|
||||||
|
package eu.dnetlib.dhp.oa.provision;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import eu.dnetlib.dhp.oa.provision.model.ProvisionModelSupport;
|
||||||
|
import eu.dnetlib.dhp.schema.oaf.Relation;
|
||||||
|
import org.apache.commons.io.FileUtils;
|
||||||
|
import org.apache.spark.SparkConf;
|
||||||
|
import org.apache.spark.api.java.function.FilterFunction;
|
||||||
|
import org.apache.spark.sql.Dataset;
|
||||||
|
import org.apache.spark.sql.Encoders;
|
||||||
|
import org.apache.spark.sql.Row;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
import org.junit.jupiter.api.AfterAll;
|
||||||
|
import org.junit.jupiter.api.Assertions;
|
||||||
|
import org.junit.jupiter.api.BeforeAll;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.nio.file.Files;
|
||||||
|
import java.nio.file.Path;
|
||||||
|
|
||||||
|
public class PrepareRelationsJobTest {
|
||||||
|
|
||||||
|
private static final Logger log = LoggerFactory.getLogger(PrepareRelationsJobTest.class);
|
||||||
|
|
||||||
|
public static final String SUBRELTYPE = "subRelType";
|
||||||
|
public static final String OUTCOME = "outcome";
|
||||||
|
public static final String SUPPLEMENT = "supplement";
|
||||||
|
|
||||||
|
private static SparkSession spark;
|
||||||
|
|
||||||
|
private static Path workingDir;
|
||||||
|
|
||||||
|
@BeforeAll
|
||||||
|
public static void setUp() throws IOException {
|
||||||
|
workingDir = Files.createTempDirectory(PrepareRelationsJobTest.class.getSimpleName());
|
||||||
|
log.info("using work dir {}", workingDir);
|
||||||
|
|
||||||
|
SparkConf conf = new SparkConf();
|
||||||
|
|
||||||
|
conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
|
||||||
|
conf.registerKryoClasses(ProvisionModelSupport.getModelClasses());
|
||||||
|
|
||||||
|
spark = SparkSession
|
||||||
|
.builder()
|
||||||
|
.appName(PrepareRelationsJobTest.class.getSimpleName())
|
||||||
|
.master("local[*]")
|
||||||
|
.config(conf)
|
||||||
|
.getOrCreate();
|
||||||
|
}
|
||||||
|
|
||||||
|
@AfterAll
|
||||||
|
public static void afterAll() throws IOException {
|
||||||
|
FileUtils.deleteDirectory(workingDir.toFile());
|
||||||
|
spark.stop();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testRunPrepareRelationsJob(@TempDir Path testPath) throws Exception {
|
||||||
|
|
||||||
|
final int maxRelations = 10;
|
||||||
|
PrepareRelationsJob
|
||||||
|
.main(
|
||||||
|
new String[] {
|
||||||
|
"-isSparkSessionManaged", Boolean.FALSE.toString(),
|
||||||
|
"-inputRelationsPath", getClass().getResource("relations.gz").getPath(),
|
||||||
|
"-outputPath", testPath.toString(),
|
||||||
|
"-relPartitions", "10",
|
||||||
|
"-relationFilter", "asd",
|
||||||
|
"-maxRelations", String.valueOf(maxRelations)
|
||||||
|
});
|
||||||
|
|
||||||
|
Dataset<Relation> out = spark.read()
|
||||||
|
.parquet(testPath.toString())
|
||||||
|
.as(Encoders.bean(Relation.class))
|
||||||
|
.cache();
|
||||||
|
|
||||||
|
Assertions.assertEquals(10, out.count());
|
||||||
|
|
||||||
|
Dataset<Row> freq = out.toDF().cube(SUBRELTYPE).count().filter((FilterFunction<Row>) value -> !value.isNullAt(0));
|
||||||
|
long outcome = freq.filter(freq.col(SUBRELTYPE).equalTo(OUTCOME)).collectAsList().get(0).getAs("count");
|
||||||
|
long supplement = freq.filter(freq.col(SUBRELTYPE).equalTo(SUPPLEMENT)).collectAsList().get(0).getAs("count");
|
||||||
|
|
||||||
|
Assertions.assertTrue(outcome > supplement);
|
||||||
|
Assertions.assertEquals(7, outcome);
|
||||||
|
Assertions.assertEquals(3, supplement);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Binary file not shown.
|
@ -0,0 +1,11 @@
|
||||||
|
# Set root logger level to DEBUG and its only appender to A1.
|
||||||
|
log4j.rootLogger=INFO, A1
|
||||||
|
|
||||||
|
# A1 is set to be a ConsoleAppender.
|
||||||
|
log4j.appender.A1=org.apache.log4j.ConsoleAppender
|
||||||
|
|
||||||
|
# A1 uses PatternLayout.
|
||||||
|
log4j.logger.org = ERROR
|
||||||
|
log4j.logger.eu.dnetlib = DEBUG
|
||||||
|
log4j.appender.A1.layout=org.apache.log4j.PatternLayout
|
||||||
|
log4j.appender.A1.layout.ConversionPattern=%-4r [%t] %-5p %c %x - %m%n
|
Loading…
Reference in New Issue