package eu.dnetlib.dhp.actionmanager.promote; import static eu.dnetlib.dhp.schema.common.ModelSupport.isSubClass; import eu.dnetlib.dhp.common.FunctionalInterfaceSupport.SerializableSupplier; import eu.dnetlib.dhp.schema.oaf.Oaf; import java.util.Objects; import java.util.Optional; import java.util.function.BiFunction; import java.util.function.Function; import org.apache.spark.api.java.function.FilterFunction; import org.apache.spark.api.java.function.MapFunction; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Encoder; import org.apache.spark.sql.Encoders; import org.apache.spark.sql.TypedColumn; import org.apache.spark.sql.expressions.Aggregator; import scala.Tuple2; /** Promote action payload functions. */ public class PromoteActionPayloadFunctions { private PromoteActionPayloadFunctions() {} /** * Joins dataset representing graph table with dataset representing action payload using * supplied functions. * * @param rowDS Dataset representing graph table * @param actionPayloadDS Dataset representing action payload * @param rowIdFn Function used to get the id of graph table row * @param actionPayloadIdFn Function used to get id of action payload instance * @param mergeAndGetFn Function used to merge graph table row and action payload instance * @param rowClazz Class of graph table * @param actionPayloadClazz Class of action payload * @param Type of graph table row * @param Type of action payload instance * @return Dataset of merged graph table rows and action payload instances */ public static Dataset joinGraphTableWithActionPayloadAndMerge( Dataset rowDS, Dataset actionPayloadDS, SerializableSupplier> rowIdFn, SerializableSupplier> actionPayloadIdFn, SerializableSupplier> mergeAndGetFn, Class rowClazz, Class actionPayloadClazz) { if (!isSubClass(rowClazz, actionPayloadClazz)) { throw new RuntimeException( "action payload type must be the same or be a super type of table row type"); } Dataset> rowWithIdDS = mapToTupleWithId(rowDS, rowIdFn, rowClazz); Dataset> actionPayloadWithIdDS = mapToTupleWithId(actionPayloadDS, actionPayloadIdFn, actionPayloadClazz); return rowWithIdDS .joinWith( actionPayloadWithIdDS, rowWithIdDS.col("_1").equalTo(actionPayloadWithIdDS.col("_1")), "full_outer") .map( (MapFunction, Tuple2>, G>) value -> { Optional rowOpt = Optional.ofNullable(value._1()).map(Tuple2::_2); Optional actionPayloadOpt = Optional.ofNullable(value._2()).map(Tuple2::_2); return rowOpt.map( row -> actionPayloadOpt .map( actionPayload -> mergeAndGetFn .get() .apply( row, actionPayload)) .orElse(row)) .orElseGet( () -> actionPayloadOpt .filter( actionPayload -> actionPayload .getClass() .equals( rowClazz)) .map(rowClazz::cast) .orElse(null)); }, Encoders.kryo(rowClazz)) .filter((FilterFunction) Objects::nonNull); } private static Dataset> mapToTupleWithId( Dataset ds, SerializableSupplier> idFn, Class clazz) { return ds.map( (MapFunction>) value -> new Tuple2<>(idFn.get().apply(value), value), Encoders.tuple(Encoders.STRING(), Encoders.kryo(clazz))); } /** * Groups graph table by id and aggregates using supplied functions. * * @param rowDS Dataset representing graph table * @param rowIdFn Function used to get the id of graph table row * @param mergeAndGetFn Function used to merge graph table rows * @param zeroFn Function to create a zero/empty instance of graph table row * @param isNotZeroFn Function to check if graph table row is not zero/empty * @param rowClazz Class of graph table * @param Type of graph table row * @return Dataset of aggregated graph table rows */ public static Dataset groupGraphTableByIdAndMerge( Dataset rowDS, SerializableSupplier> rowIdFn, SerializableSupplier> mergeAndGetFn, SerializableSupplier zeroFn, SerializableSupplier> isNotZeroFn, Class rowClazz) { TypedColumn aggregator = new TableAggregator<>(zeroFn, mergeAndGetFn, isNotZeroFn, rowClazz).toColumn(); return rowDS.groupByKey( (MapFunction) x -> rowIdFn.get().apply(x), Encoders.STRING()) .agg(aggregator) .map((MapFunction, G>) Tuple2::_2, Encoders.kryo(rowClazz)); } /** * Aggregator to be used for aggregating graph table rows during grouping. * * @param Type of graph table row */ public static class TableAggregator extends Aggregator { private SerializableSupplier zeroFn; private SerializableSupplier> mergeAndGetFn; private SerializableSupplier> isNotZeroFn; private Class rowClazz; public TableAggregator( SerializableSupplier zeroFn, SerializableSupplier> mergeAndGetFn, SerializableSupplier> isNotZeroFn, Class rowClazz) { this.zeroFn = zeroFn; this.mergeAndGetFn = mergeAndGetFn; this.isNotZeroFn = isNotZeroFn; this.rowClazz = rowClazz; } @Override public G zero() { return zeroFn.get(); } @Override public G reduce(G b, G a) { return zeroSafeMergeAndGet(b, a); } @Override public G merge(G b1, G b2) { return zeroSafeMergeAndGet(b1, b2); } private G zeroSafeMergeAndGet(G left, G right) { Function isNotZero = isNotZeroFn.get(); if (isNotZero.apply(left) && isNotZero.apply(right)) { return mergeAndGetFn.get().apply(left, right); } else if (isNotZero.apply(left) && !isNotZero.apply(right)) { return left; } else if (!isNotZero.apply(left) && isNotZero.apply(right)) { return right; } throw new RuntimeException( "internal aggregation error: left and right objects are zero"); } @Override public G finish(G reduction) { return reduction; } @Override public Encoder bufferEncoder() { return Encoders.kryo(rowClazz); } @Override public Encoder outputEncoder() { return Encoders.kryo(rowClazz); } } }