enrichment steps #38

Merged
claudio.atzori merged 334 commits from miriam.baglioni/dnet-hadoop:master into enrichment_wfs 2020-08-11 16:40:26 +02:00
13 changed files with 355 additions and 114 deletions
Showing only changes of commit a6ea432435 - Show all commits

View File

@ -19,7 +19,6 @@ import org.apache.spark.sql.expressions.Aggregator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.Lists;
@ -28,8 +27,6 @@ import eu.dnetlib.dhp.common.HdfsSupport;
import eu.dnetlib.dhp.oa.provision.model.JoinedEntity;
import eu.dnetlib.dhp.oa.provision.model.ProvisionModelSupport;
import eu.dnetlib.dhp.oa.provision.model.RelatedEntityWrapper;
import eu.dnetlib.dhp.oa.provision.model.TypedRow;
import eu.dnetlib.dhp.schema.common.EntityType;
import eu.dnetlib.dhp.schema.common.ModelSupport;
import eu.dnetlib.dhp.schema.oaf.*;
import scala.Tuple2;
@ -305,20 +302,6 @@ public class CreateRelatedEntitiesJob_phase2 {
private static FilterFunction<JoinedEntity> filterEmptyEntityFn() {
return (FilterFunction<JoinedEntity>) v -> Objects.nonNull(v.getEntity());
/*
* return (FilterFunction<JoinedEntity>) v -> Optional .ofNullable(v.getEntity()) .map(e ->
* StringUtils.isNotBlank(e.getId())) .orElse(false);
*/
}
private static TypedRow getTypedRow(String type, OafEntity entity)
throws JsonProcessingException {
TypedRow t = new TypedRow();
t.setType(type);
t.setDeleted(entity.getDataInfo().getDeletedbyinference());
t.setId(entity.getId());
t.setOaf(OBJECT_MAPPER.writeValueAsString(entity));
return t;
}
private static void removeOutputDir(SparkSession spark, String path) {

View File

@ -3,28 +3,33 @@ package eu.dnetlib.dhp.oa.provision;
import static eu.dnetlib.dhp.common.SparkSessionSupport.runWithSparkSession;
import java.util.HashSet;
import java.util.Optional;
import java.util.Set;
import java.util.*;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.apache.commons.io.IOUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FilterFunction;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.*;
import org.apache.spark.sql.expressions.Aggregator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Splitter;
import com.google.common.collect.Iterables;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import eu.dnetlib.dhp.application.ArgumentApplicationParser;
import eu.dnetlib.dhp.common.HdfsSupport;
import eu.dnetlib.dhp.oa.provision.model.ProvisionModelSupport;
import eu.dnetlib.dhp.oa.provision.model.SortableRelationKey;
import eu.dnetlib.dhp.oa.provision.utils.RelationPartitioner;
import eu.dnetlib.dhp.schema.oaf.Relation;
@ -102,6 +107,8 @@ public class PrepareRelationsJob {
log.info("maxRelations: {}", maxRelations);
SparkConf conf = new SparkConf();
conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
conf.registerKryoClasses(ProvisionModelSupport.getModelClasses());
runWithSparkSession(
conf,
@ -125,9 +132,8 @@ public class PrepareRelationsJob {
* @param maxRelations maximum number of allowed outgoing edges
* @param relPartitions number of partitions for the output RDD
*/
private static void prepareRelationsRDD(
SparkSession spark, String inputRelationsPath, String outputPath, Set<String> relationFilter, int maxRelations,
int relPartitions) {
private static void prepareRelationsRDD(SparkSession spark, String inputRelationsPath, String outputPath,
Set<String> relationFilter, int maxRelations, int relPartitions) {
// group by SOURCE and apply limit
RDD<Relation> bySource = readPathRelationRDD(spark, inputRelationsPath)
@ -142,27 +148,95 @@ public class PrepareRelationsJob {
.map(Tuple2::_2)
.rdd();
// group by TARGET and apply limit
RDD<Relation> byTarget = readPathRelationRDD(spark, inputRelationsPath)
.filter(rel -> rel.getDataInfo().getDeletedbyinference() == false)
.filter(rel -> relationFilter.contains(rel.getRelClass()) == false)
.mapToPair(r -> new Tuple2<>(SortableRelationKey.create(r, r.getTarget()), r))
.repartitionAndSortWithinPartitions(new RelationPartitioner(relPartitions))
.groupBy(Tuple2::_1)
.map(Tuple2::_2)
.map(t -> Iterables.limit(t, maxRelations))
.flatMap(Iterable::iterator)
.map(Tuple2::_2)
.rdd();
spark
.createDataset(bySource.union(byTarget), Encoders.bean(Relation.class))
.createDataset(bySource, Encoders.bean(Relation.class))
.repartition(relPartitions)
.write()
.mode(SaveMode.Overwrite)
.parquet(outputPath);
}
private static void prepareRelationsDataset(
SparkSession spark, String inputRelationsPath, String outputPath, Set<String> relationFilter, int maxRelations,
int relPartitions) {
spark
.read()
.textFile(inputRelationsPath)
.repartition(relPartitions)
.map(
(MapFunction<String, Relation>) s -> OBJECT_MAPPER.readValue(s, Relation.class),
Encoders.kryo(Relation.class))
.filter((FilterFunction<Relation>) rel -> rel.getDataInfo().getDeletedbyinference() == false)
.filter((FilterFunction<Relation>) rel -> relationFilter.contains(rel.getRelClass()) == false)
.groupByKey(
(MapFunction<Relation, String>) Relation::getSource,
Encoders.STRING())
.agg(new RelationAggregator(maxRelations).toColumn())
.flatMap(
(FlatMapFunction<Tuple2<String, RelationList>, Relation>) t -> Iterables
.limit(t._2().getRelations(), maxRelations)
.iterator(),
Encoders.bean(Relation.class))
.repartition(relPartitions)
.write()
.mode(SaveMode.Overwrite)
.parquet(outputPath);
}
public static class RelationAggregator
extends Aggregator<Relation, RelationList, RelationList> {
private int maxRelations;
public RelationAggregator(int maxRelations) {
this.maxRelations = maxRelations;
}
@Override
public RelationList zero() {
return new RelationList();
}
@Override
public RelationList reduce(RelationList b, Relation a) {
b.getRelations().add(a);
return getSortableRelationList(b);
}
@Override
public RelationList merge(RelationList b1, RelationList b2) {
b1.getRelations().addAll(b2.getRelations());
return getSortableRelationList(b1);
}
@Override
public RelationList finish(RelationList r) {
return getSortableRelationList(r);
}
private RelationList getSortableRelationList(RelationList b1) {
RelationList sr = new RelationList();
sr
.setRelations(
b1
.getRelations()
.stream()
.limit(maxRelations)
.collect(Collectors.toCollection(() -> new PriorityQueue<>(new RelationComparator()))));
return sr;
}
@Override
public Encoder<RelationList> bufferEncoder() {
return Encoders.kryo(RelationList.class);
}
@Override
public Encoder<RelationList> outputEncoder() {
return Encoders.kryo(RelationList.class);
}
}
/**
* Reads a JavaRDD of eu.dnetlib.dhp.oa.provision.model.SortableRelation objects from a newline delimited json text
* file,

View File

@ -0,0 +1,43 @@
package eu.dnetlib.dhp.oa.provision;
import java.util.Comparator;
import java.util.Map;
import java.util.Optional;
import com.google.common.collect.ComparisonChain;
import com.google.common.collect.Maps;
import eu.dnetlib.dhp.schema.oaf.Relation;
public class RelationComparator implements Comparator<Relation> {
private static final Map<String, Integer> weights = Maps.newHashMap();
static {
weights.put("outcome", 0);
weights.put("supplement", 1);
weights.put("review", 2);
weights.put("citation", 3);
weights.put("affiliation", 4);
weights.put("relationship", 5);
weights.put("publicationDataset", 6);
weights.put("similarity", 7);
weights.put("provision", 8);
weights.put("participation", 9);
weights.put("dedup", 10);
}
private Integer getWeight(Relation o) {
return Optional.ofNullable(weights.get(o.getSubRelType())).orElse(Integer.MAX_VALUE);
}
@Override
public int compare(Relation o1, Relation o2) {
return ComparisonChain
.start()
.compare(getWeight(o1), getWeight(o2))
.result();
}
}

View File

@ -0,0 +1,25 @@
package eu.dnetlib.dhp.oa.provision;
import java.io.Serializable;
import java.util.PriorityQueue;
import java.util.Queue;
import eu.dnetlib.dhp.schema.oaf.Relation;
public class RelationList implements Serializable {
private Queue<Relation> relations;
public RelationList() {
this.relations = new PriorityQueue<>(new RelationComparator());
}
public Queue<Relation> getRelations() {
return relations;
}
public void setRelations(Queue<Relation> relations) {
this.relations = relations;
}
}

View File

@ -0,0 +1,80 @@
package eu.dnetlib.dhp.oa.provision;
import java.io.Serializable;
import java.util.Map;
import java.util.Optional;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.google.common.collect.ComparisonChain;
import com.google.common.collect.Maps;
import eu.dnetlib.dhp.schema.oaf.Relation;
public class SortableRelation extends Relation implements Comparable<SortableRelation>, Serializable {
private static final Map<String, Integer> weights = Maps.newHashMap();
static {
weights.put("outcome", 0);
weights.put("supplement", 1);
weights.put("review", 2);
weights.put("citation", 3);
weights.put("affiliation", 4);
weights.put("relationship", 5);
weights.put("publicationDataset", 6);
weights.put("similarity", 7);
weights.put("provision", 8);
weights.put("participation", 9);
weights.put("dedup", 10);
}
private static final long serialVersionUID = 34753984579L;
private String groupingKey;
public static SortableRelation create(Relation r, String groupingKey) {
SortableRelation sr = new SortableRelation();
sr.setGroupingKey(groupingKey);
sr.setSource(r.getSource());
sr.setTarget(r.getTarget());
sr.setRelType(r.getRelType());
sr.setSubRelType(r.getSubRelType());
sr.setRelClass(r.getRelClass());
sr.setDataInfo(r.getDataInfo());
sr.setCollectedfrom(r.getCollectedfrom());
sr.setLastupdatetimestamp(r.getLastupdatetimestamp());
sr.setProperties(r.getProperties());
sr.setValidated(r.getValidated());
sr.setValidationDate(r.getValidationDate());
return sr;
}
@JsonIgnore
public Relation asRelation() {
return this;
}
@Override
public int compareTo(SortableRelation o) {
return ComparisonChain
.start()
.compare(getGroupingKey(), o.getGroupingKey())
.compare(getWeight(this), getWeight(o))
.result();
}
private Integer getWeight(SortableRelation o) {
return Optional.ofNullable(weights.get(o.getSubRelType())).orElse(Integer.MAX_VALUE);
}
public String getGroupingKey() {
return groupingKey;
}
public void setGroupingKey(String groupingKey) {
this.groupingKey = groupingKey;
}
}

View File

@ -5,6 +5,8 @@ import java.util.List;
import com.google.common.collect.Lists;
import eu.dnetlib.dhp.oa.provision.RelationList;
import eu.dnetlib.dhp.oa.provision.SortableRelation;
import eu.dnetlib.dhp.schema.common.ModelSupport;
public class ProvisionModelSupport {
@ -15,11 +17,12 @@ public class ProvisionModelSupport {
.addAll(
Lists
.newArrayList(
TypedRow.class,
RelatedEntityWrapper.class,
JoinedEntity.class,
RelatedEntity.class,
SortableRelationKey.class));
SortableRelationKey.class,
SortableRelation.class,
RelationList.class));
return modelClasses.toArray(new Class[] {});
}
}

View File

@ -16,10 +16,6 @@ public class RelatedEntityWrapper implements Serializable {
}
public RelatedEntityWrapper(Relation relation, RelatedEntity target) {
this(null, relation, target);
}
public RelatedEntityWrapper(TypedRow entity, Relation relation, RelatedEntity target) {
this.relation = relation;
this.target = target;
}

View File

@ -1,64 +0,0 @@
package eu.dnetlib.dhp.oa.provision.model;
import java.io.Serializable;
import com.google.common.base.Objects;
public class TypedRow implements Serializable {
private String id;
private Boolean deleted;
private String type;
private String oaf;
public String getId() {
return id;
}
public void setId(String id) {
this.id = id;
}
public Boolean getDeleted() {
return deleted;
}
public void setDeleted(Boolean deleted) {
this.deleted = deleted;
}
public String getType() {
return type;
}
public void setType(String type) {
this.type = type;
}
public String getOaf() {
return oaf;
}
public void setOaf(String oaf) {
this.oaf = oaf;
}
@Override
public boolean equals(Object o) {
if (this == o)
return true;
if (o == null || getClass() != o.getClass())
return false;
TypedRow typedRow2 = (TypedRow) o;
return Objects.equal(id, typedRow2.id);
}
@Override
public int hashCode() {
return Objects.hashCode(id);
}
}

View File

@ -121,10 +121,6 @@ public class XmlRecordFactory implements Serializable {
}
}
private static OafEntity toOafEntity(TypedRow typedRow) {
return parseOaf(typedRow.getOaf(), typedRow.getType());
}
private static OafEntity parseOaf(final String json, final String type) {
try {
switch (EntityType.valueOf(type)) {

View File

@ -133,6 +133,7 @@
</spark-opts>
<arg>--inputRelationsPath</arg><arg>${inputGraphRootPath}/relation</arg>
<arg>--outputPath</arg><arg>${workingDir}/relation</arg>
<arg>--maxRelations</arg><arg>${maxRelations}</arg>
<arg>--relPartitions</arg><arg>5000</arg>
</spark>
<ok to="fork_join_related_entities"/>

View File

@ -0,0 +1,93 @@
package eu.dnetlib.dhp.oa.provision;
import com.fasterxml.jackson.databind.ObjectMapper;
import eu.dnetlib.dhp.oa.provision.model.ProvisionModelSupport;
import eu.dnetlib.dhp.schema.oaf.Relation;
import org.apache.commons.io.FileUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.function.FilterFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
public class PrepareRelationsJobTest {
private static final Logger log = LoggerFactory.getLogger(PrepareRelationsJobTest.class);
public static final String SUBRELTYPE = "subRelType";
public static final String OUTCOME = "outcome";
public static final String SUPPLEMENT = "supplement";
private static SparkSession spark;
private static Path workingDir;
@BeforeAll
public static void setUp() throws IOException {
workingDir = Files.createTempDirectory(PrepareRelationsJobTest.class.getSimpleName());
log.info("using work dir {}", workingDir);
SparkConf conf = new SparkConf();
conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
conf.registerKryoClasses(ProvisionModelSupport.getModelClasses());
spark = SparkSession
.builder()
.appName(PrepareRelationsJobTest.class.getSimpleName())
.master("local[*]")
.config(conf)
.getOrCreate();
}
@AfterAll
public static void afterAll() throws IOException {
FileUtils.deleteDirectory(workingDir.toFile());
spark.stop();
}
@Test
public void testRunPrepareRelationsJob(@TempDir Path testPath) throws Exception {
final int maxRelations = 10;
PrepareRelationsJob
.main(
new String[] {
"-isSparkSessionManaged", Boolean.FALSE.toString(),
"-inputRelationsPath", getClass().getResource("relations.gz").getPath(),
"-outputPath", testPath.toString(),
"-relPartitions", "10",
"-relationFilter", "asd",
"-maxRelations", String.valueOf(maxRelations)
});
Dataset<Relation> out = spark.read()
.parquet(testPath.toString())
.as(Encoders.bean(Relation.class))
.cache();
Assertions.assertEquals(10, out.count());
Dataset<Row> freq = out.toDF().cube(SUBRELTYPE).count().filter((FilterFunction<Row>) value -> !value.isNullAt(0));
long outcome = freq.filter(freq.col(SUBRELTYPE).equalTo(OUTCOME)).collectAsList().get(0).getAs("count");
long supplement = freq.filter(freq.col(SUBRELTYPE).equalTo(SUPPLEMENT)).collectAsList().get(0).getAs("count");
Assertions.assertTrue(outcome > supplement);
Assertions.assertEquals(7, outcome);
Assertions.assertEquals(3, supplement);
}
}

View File

@ -0,0 +1,11 @@
# Set root logger level to DEBUG and its only appender to A1.
log4j.rootLogger=INFO, A1
# A1 is set to be a ConsoleAppender.
log4j.appender.A1=org.apache.log4j.ConsoleAppender
# A1 uses PatternLayout.
log4j.logger.org = ERROR
log4j.logger.eu.dnetlib = DEBUG
log4j.appender.A1.layout=org.apache.log4j.PatternLayout
log4j.appender.A1.layout.ConversionPattern=%-4r [%t] %-5p %c %x - %m%n