diff --git a/dhp-workflows/dhp-actionmanager/src/main/java/eu/dnetlib/dhp/actionmanager/PromoteActionSetFromHDFSFunctions.java b/dhp-workflows/dhp-actionmanager/src/main/java/eu/dnetlib/dhp/actionmanager/PromoteActionSetFromHDFSFunctions.java
index c1f2e4c11..95b231a70 100644
--- a/dhp-workflows/dhp-actionmanager/src/main/java/eu/dnetlib/dhp/actionmanager/PromoteActionSetFromHDFSFunctions.java
+++ b/dhp-workflows/dhp-actionmanager/src/main/java/eu/dnetlib/dhp/actionmanager/PromoteActionSetFromHDFSFunctions.java
@@ -1,45 +1,134 @@
 package eu.dnetlib.dhp.actionmanager;
 
+import eu.dnetlib.dhp.schema.oaf.Oaf;
 import eu.dnetlib.dhp.schema.oaf.OafEntity;
+import eu.dnetlib.dhp.schema.oaf.Relation;
+import org.apache.spark.api.java.function.FilterFunction;
 import org.apache.spark.api.java.function.MapFunction;
 import org.apache.spark.api.java.function.ReduceFunction;
-import org.apache.spark.sql.Column;
 import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
 import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.TypedColumn;
+import org.apache.spark.sql.expressions.Aggregator;
 import scala.Tuple2;
 
+import java.util.Objects;
 import java.util.Optional;
 import java.util.function.BiFunction;
+import java.util.function.Function;
 
 public class PromoteActionSetFromHDFSFunctions {
 
-    public static <T extends OafEntity> Dataset<T> groupEntitiesByIdAndMerge(Dataset<T> entityDS,
-                                                                             Class<T> clazz) {
-        return entityDS
-                .groupByKey((MapFunction<T, String>) OafEntity::getId, Encoders.STRING())
-                .reduceGroups((ReduceFunction<T>) (x1, x2) -> {
-                    x1.mergeFrom(x2);
-                    return x1;
-                })
-                .map((MapFunction<Tuple2<String, T>, T>) pair -> pair._2, Encoders.bean(clazz));
+    public static <T extends Oaf> Dataset<T> joinOafEntityWithActionPayloadAndMerge(Dataset<T> oafDS,
+                                                                                    Dataset<String> actionPayloadDS,
+                                                                                    SerializableSupplier<Function<T, String>> oafIdFn,
+                                                                                    SerializableSupplier<BiFunction<String, Class<T>, T>> actionPayloadToOafFn,
+                                                                                    SerializableSupplier<BiFunction<T, T, T>> mergeAndGetFn,
+                                                                                    Class<T> clazz) {
+        Dataset<Tuple2<String, T>> oafWithIdDS = oafDS
+                .map((MapFunction<T, Tuple2<String, T>>) value -> new Tuple2<>(oafIdFn.get().apply(value), value),
+                        Encoders.tuple(Encoders.STRING(), Encoders.kryo(clazz)));
+
+        Dataset<Tuple2<String, T>> actionPayloadWithIdDS = actionPayloadDS
+                .map((MapFunction<String, T>) value -> actionPayloadToOafFn.get().apply(value, clazz), Encoders.kryo(clazz))
+                .filter((FilterFunction<T>) Objects::nonNull)
+                .map((MapFunction<T, Tuple2<String, T>>) value -> new Tuple2<>(oafIdFn.get().apply(value), value),
+                        Encoders.tuple(Encoders.STRING(), Encoders.kryo(clazz)));
+
+        return oafWithIdDS
+                .joinWith(actionPayloadWithIdDS, oafWithIdDS.col("_1").equalTo(actionPayloadWithIdDS.col("_1")), "left_outer")
+                .map((MapFunction<Tuple2<Tuple2<String, T>, Tuple2<String, T>>, T>) value -> {
+                    T left = value._1()._2();
+                    return Optional
+                            .ofNullable(value._2())
+                            .map(Tuple2::_2)
+                            .map(x -> mergeAndGetFn.get().apply(left, x))
+                            .orElse(left);
+                }, Encoders.kryo(clazz));
     }
 
-    public static <T extends OafEntity, S> Dataset<T> joinEntitiesWithActionPayloadAndMerge(Dataset<T> entityDS,
-                                                                                            Dataset<S> actionPayloadDS,
-                                                                                            BiFunction<Dataset<T>, Dataset<S>, Column> entityToActionPayloadJoinExpr,
-                                                                                            BiFunction<S, Class<T>, T> actionPayloadToEntityFn,
-                                                                                            Class<T> clazz) {
-        return entityDS
-                .joinWith(actionPayloadDS, entityToActionPayloadJoinExpr.apply(entityDS, actionPayloadDS), "left_outer")
-                .map((MapFunction<Tuple2<T, S>, T>) pair -> Optional
-                        .ofNullable(pair._2())
-                        .map(x -> {
-                            T entity = actionPayloadToEntityFn.apply(x, clazz);
-                            pair._1().mergeFrom(entity);
-                            return pair._1();
-                        })
-                        .orElse(pair._1()), Encoders.bean(clazz));
+    public static <T extends Oaf> Dataset<T> groupOafByIdAndMerge(Dataset<T> oafDS,
+                                                                  SerializableSupplier<Function<T, String>> oafIdFn,
+                                                                  SerializableSupplier<BiFunction<T, T, T>> mergeAndGetFn,
+                                                                  Class<T> clazz) {
+        return oafDS
+                .groupByKey((MapFunction<T, String>) x -> oafIdFn.get().apply(x), Encoders.STRING())
+                .reduceGroups((ReduceFunction<T>) (v1, v2) -> mergeAndGetFn.get().apply(v1, v2))
+                .map((MapFunction<Tuple2<String, T>, T>) Tuple2::_2, Encoders.kryo(clazz));
     }
 
+    public static <T extends Oaf> Dataset<T> groupOafByIdAndMergeUsingAggregator(Dataset<T> oafDS,
+                                                                                 SerializableSupplier<T> zeroFn,
+                                                                                 SerializableSupplier<Function<T, String>> idFn,
+                                                                                 Class<T> clazz) {
+        TypedColumn<T, T> aggregator = new OafAggregator<>(zeroFn, clazz).toColumn();
+        return oafDS
+                .groupByKey((MapFunction<T, String>) x -> idFn.get().apply(x), Encoders.STRING())
+                .agg(aggregator)
+                .map((MapFunction<Tuple2<String, T>, T>) Tuple2::_2, Encoders.kryo(clazz));
+    }
 
-}
+    public static class OafAggregator<T extends Oaf> extends Aggregator<T, T, T> {
+        private SerializableSupplier<T> zero;
+        private Class<T> clazz;
+
+        public OafAggregator(SerializableSupplier<T> zero, Class<T> clazz) {
+            this.zero = zero;
+            this.clazz = clazz;
+        }
+
+        @Override
+        public T zero() {
+            return zero.get();
+        }
+
+        @Override
+        public T reduce(T b, T a) {
+            return mergeFrom(b, a);
+        }
+
+        @Override
+        public T merge(T b1, T b2) {
+            return mergeFrom(b1, b2);
+        }
+
+        private T mergeFrom(T left, T right) {
+            if (isNonNull(left)) {
+                if (left instanceof Relation) {
+                    ((Relation) left).mergeFrom((Relation) right);
+                    return left;
+                }
+                ((OafEntity) left).mergeFrom((OafEntity) right);
+                return left;
+            }
+
+            if (right instanceof Relation) {
+                ((Relation) right).mergeFrom((Relation) left);
+                return right;
+            }
+            ((OafEntity) right).mergeFrom((OafEntity) left);
+            return right;
+        }
+
+        private Boolean isNonNull(T a) {
+            return Objects.nonNull(a.getLastupdatetimestamp());
+        }
+
+        @Override
+        public T finish(T reduction) {
+            return reduction;
+        }
+
+        @Override
+        public Encoder<T> bufferEncoder() {
+            return Encoders.kryo(clazz);
+        }
+
+        @Override
+        public Encoder<T> outputEncoder() {
+            return Encoders.kryo(clazz);
+        }
+
+    }
+}
\ No newline at end of file
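
Usage note (editor's sketch, not part of the patch): the snippet below shows how the new joinOafEntityWithActionPayloadAndMerge helper might be wired up. It assumes the Software entity from eu.dnetlib.dhp.schema.oaf, Jackson for payload parsing, that SerializableSupplier is the module's Serializable flavour of java.util.function.Supplier, and that graph and action-set data live as JSON text files; all paths and names are illustrative, not taken from the diff.

    import com.fasterxml.jackson.databind.ObjectMapper;
    import eu.dnetlib.dhp.schema.oaf.OafEntity;
    import eu.dnetlib.dhp.schema.oaf.Software;
    import org.apache.spark.api.java.function.MapFunction;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Encoders;
    import org.apache.spark.sql.SparkSession;

    import java.io.IOException;

    public class PromoteSoftwareExample {
        public static void main(String[] args) {
            SparkSession spark = SparkSession.builder().getOrCreate();

            // Graph table parsed from JSON text files; path is illustrative.
            Dataset<Software> graphDS = spark.read()
                    .textFile("/graph/software")
                    .map((MapFunction<String, Software>) json ->
                            new ObjectMapper().readValue(json, Software.class), Encoders.kryo(Software.class));

            // Action payloads kept as raw JSON strings; path is illustrative.
            Dataset<String> payloadDS = spark.read().textFile("/actionsets/software");

            Dataset<Software> promoted = PromoteActionSetFromHDFSFunctions
                    .joinOafEntityWithActionPayloadAndMerge(
                            graphDS,
                            payloadDS,
                            () -> OafEntity::getId,            // oafIdFn: key both sides by entity id
                            () -> (json, clazz) -> {           // actionPayloadToOafFn: parse payload, null on failure
                                try {
                                    return new ObjectMapper().readValue(json, clazz);
                                } catch (IOException e) {
                                    return null;               // nulls are filtered out before the join
                                }
                            },
                            () -> (graph, action) -> {         // mergeAndGetFn: fold the payload into the graph record
                                graph.mergeFrom(action);
                                return graph;
                            },
                            Software.class);

            System.out.println(promoted.count());              // force evaluation; real code would write the result
        }
    }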
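A second sketch for the Aggregator-based variant, continuing from the example above. The design point worth noting, visible in isNonNull, is the zero-element convention: zeroFn must supply an instance whose lastupdatetimestamp is still null, because that field is how OafAggregator tells the neutral element apart from a real record during reduce/merge. This assumes a no-arg Software constructor leaves lastupdatetimestamp null.

    // Hypothetical caller: deduplicate Software records by id via the typed Aggregator.
    Dataset<Software> merged = PromoteActionSetFromHDFSFunctions
            .groupOafByIdAndMergeUsingAggregator(
                    graphDS,
                    Software::new,            // zeroFn: fresh instance with null lastupdatetimestamp,
                                              // i.e. the neutral element isNonNull tests for
                    () -> OafEntity::getId,   // idFn: group records sharing the same entity id
                    Software.class);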