diff --git a/dhp-workflows/dhp-actionmanager/src/main/java/eu/dnetlib/dhp/actionmanager/PromoteActionSetFromHDFSFunctions.java b/dhp-workflows/dhp-actionmanager/src/main/java/eu/dnetlib/dhp/actionmanager/PromoteActionSetFromHDFSFunctions.java
index 95b231a702..994989e45a 100644
--- a/dhp-workflows/dhp-actionmanager/src/main/java/eu/dnetlib/dhp/actionmanager/PromoteActionSetFromHDFSFunctions.java
+++ b/dhp-workflows/dhp-actionmanager/src/main/java/eu/dnetlib/dhp/actionmanager/PromoteActionSetFromHDFSFunctions.java
@@ -3,7 +3,6 @@ package eu.dnetlib.dhp.actionmanager;
 import eu.dnetlib.dhp.schema.oaf.Oaf;
 import eu.dnetlib.dhp.schema.oaf.OafEntity;
 import eu.dnetlib.dhp.schema.oaf.Relation;
-import org.apache.spark.api.java.function.FilterFunction;
 import org.apache.spark.api.java.function.MapFunction;
 import org.apache.spark.api.java.function.ReduceFunction;
 import org.apache.spark.sql.Dataset;
@@ -21,20 +20,12 @@ import java.util.function.Function;
 public class PromoteActionSetFromHDFSFunctions {
 
     public static <T extends Oaf> Dataset<T> joinOafEntityWithActionPayloadAndMerge(Dataset<T> oafDS,
-                                                                                    Dataset<String> actionPayloadDS,
+                                                                                    Dataset<T> actionPayloadOafDS,
                                                                                     SerializableSupplier<Function<T, String>> oafIdFn,
-                                                                                    SerializableSupplier<BiFunction<String, Class<T>, T>> actionPayloadToOafFn,
                                                                                     SerializableSupplier<BiFunction<T, T, T>> mergeAndGetFn,
                                                                                     Class<T> clazz) {
-        Dataset<Tuple2<String, T>> oafWithIdDS = oafDS
-                .map((MapFunction<T, Tuple2<String, T>>) value -> new Tuple2<>(oafIdFn.get().apply(value), value),
-                        Encoders.tuple(Encoders.STRING(), Encoders.kryo(clazz)));
-
-        Dataset<Tuple2<String, T>> actionPayloadWithIdDS = actionPayloadDS
-                .map((MapFunction<String, T>) value -> actionPayloadToOafFn.get().apply(value, clazz), Encoders.kryo(clazz))
-                .filter((FilterFunction<T>) Objects::nonNull)
-                .map((MapFunction<T, Tuple2<String, T>>) value -> new Tuple2<>(oafIdFn.get().apply(value), value),
-                        Encoders.tuple(Encoders.STRING(), Encoders.kryo(clazz)));
+        Dataset<Tuple2<String, T>> oafWithIdDS = mapToTupleWithId(oafDS, oafIdFn, clazz);
+        Dataset<Tuple2<String, T>> actionPayloadWithIdDS = mapToTupleWithId(actionPayloadOafDS, oafIdFn, clazz);
 
         return oafWithIdDS
                 .joinWith(actionPayloadWithIdDS, oafWithIdDS.col("_1").equalTo(actionPayloadWithIdDS.col("_1")), "left_outer")
@@ -48,6 +39,14 @@ public class PromoteActionSetFromHDFSFunctions {
                 }, Encoders.kryo(clazz));
     }
 
+    private static <T extends Oaf> Dataset<Tuple2<String, T>> mapToTupleWithId(Dataset<T> ds,
+                                                                               SerializableSupplier<Function<T, String>> idFn,
+                                                                               Class<T> clazz) {
+        return ds
+                .map((MapFunction<T, Tuple2<String, T>>) value -> new Tuple2<>(idFn.get().apply(value), value),
+                        Encoders.tuple(Encoders.STRING(), Encoders.kryo(clazz)));
+    }
+
     public static <T extends Oaf> Dataset<T> groupOafByIdAndMerge(Dataset<T> oafDS,
                                                                   SerializableSupplier<Function<T, String>> oafIdFn,
                                                                   SerializableSupplier<BiFunction<T, T, T>> mergeAndGetFn,
diff --git a/dhp-workflows/dhp-actionmanager/src/test/java/eu/dnetlib/dhp/actionmanager/PromoteActionSetFromHDFSFunctionsTest.java b/dhp-workflows/dhp-actionmanager/src/test/java/eu/dnetlib/dhp/actionmanager/PromoteActionSetFromHDFSFunctionsTest.java
index a4db01f7ad..a05dac9107 100644
--- a/dhp-workflows/dhp-actionmanager/src/test/java/eu/dnetlib/dhp/actionmanager/PromoteActionSetFromHDFSFunctionsTest.java
+++ b/dhp-workflows/dhp-actionmanager/src/test/java/eu/dnetlib/dhp/actionmanager/PromoteActionSetFromHDFSFunctionsTest.java
@@ -1,7 +1,5 @@
 package eu.dnetlib.dhp.actionmanager;
 
-import com.fasterxml.jackson.databind.JsonNode;
-import com.fasterxml.jackson.databind.ObjectMapper;
 import eu.dnetlib.dhp.schema.oaf.Oaf;
 import org.apache.spark.SparkConf;
 import org.apache.spark.sql.Dataset;
@@ -11,7 +9,6 @@ import org.junit.AfterClass;
 import org.junit.BeforeClass;
 import org.junit.Test;
 
-import java.io.IOException;
 import java.util.Arrays;
 import java.util.List;
 import java.util.function.BiFunction;
@@ -50,38 +47,24 @@ public class PromoteActionSetFromHDFSFunctionsTest {
         );
         Dataset<OafImpl> oafDS = spark.createDataset(oafData, Encoders.bean(OafImpl.class));
 
-        List<String> actionPayloadData = Arrays.asList(
-                createActionPayload(id1),
-                createActionPayload(id2), createActionPayload(id2),
-                createActionPayload(id3), createActionPayload(id3), createActionPayload(id3)
+        List<OafImpl> actionPayloadData = Arrays.asList(
+                createOafImpl(id1),
+                createOafImpl(id2), createOafImpl(id2),
+                createOafImpl(id3), createOafImpl(id3), createOafImpl(id3)
         );
-        Dataset<String> actionPayloadDS = spark.createDataset(actionPayloadData, Encoders.STRING());
+        Dataset<OafImpl> actionPayloadDS = spark.createDataset(actionPayloadData, Encoders.bean(OafImpl.class));
 
         SerializableSupplier<Function<OafImpl, String>> oafIdFn = () -> OafImpl::getId;
-        SerializableSupplier<BiFunction<String, Class<OafImpl>, OafImpl>> actionPayloadToOafFn = () -> (s, clazz) -> {
-            try {
-                JsonNode jsonNode = new ObjectMapper().readTree(s);
-                String id = jsonNode.at("/id").asText();
-                return createOafImpl(id);
-            } catch (IOException e) {
-                throw new RuntimeException(e);
-            }
-        };
-        SerializableSupplier<BiFunction<OafImpl, OafImpl, OafImpl>> mergeAndGetFn = () -> (x, y) -> {
-            x.mergeFrom(y);
-            return x;
-        };
+        SerializableSupplier<BiFunction<OafImpl, OafImpl, OafImpl>> mergeAndGetFn = () -> OafImpl::mergeAndGet;
 
         // when
         List<OafImpl> results = PromoteActionSetFromHDFSFunctions
                 .joinOafEntityWithActionPayloadAndMerge(oafDS,
                         actionPayloadDS,
                         oafIdFn,
-                        actionPayloadToOafFn,
                         mergeAndGetFn,
                         OafImpl.class)
                 .collectAsList();
-//        System.out.println(results.stream().map(x -> String.format("%s:%d", x.getId(), x.merged)).collect(Collectors.joining(",")));
 
         // then
         assertEquals(7, results.size());
@@ -95,6 +78,8 @@ public class PromoteActionSetFromHDFSFunctionsTest {
                 case "id4":
                     assertEquals(1, result.merged);
                     break;
+                default:
+                    throw new RuntimeException();
             }
         });
     }
@@ -112,10 +97,7 @@ public class PromoteActionSetFromHDFSFunctionsTest {
         );
         Dataset<OafImpl> oafDS = spark.createDataset(oafData, Encoders.bean(OafImpl.class));
         SerializableSupplier<Function<OafImpl, String>> idFn = () -> OafImpl::getId;
-        SerializableSupplier<BiFunction<OafImpl, OafImpl, OafImpl>> mergeAndGetFn = () -> (x, y) -> {
-            x.mergeFrom(y);
-            return x;
-        };
+        SerializableSupplier<BiFunction<OafImpl, OafImpl, OafImpl>> mergeAndGetFn = () -> OafImpl::mergeAndGet;
 
         // when
         List<OafImpl> results = PromoteActionSetFromHDFSFunctions
@@ -124,7 +106,6 @@ public class PromoteActionSetFromHDFSFunctionsTest {
                         mergeAndGetFn,
                         OafImpl.class)
                 .collectAsList();
-//        System.out.println(results.stream().map(x -> String.format("%s:%d", x.getId(), x.merged)).collect(Collectors.joining(",")));
 
         // then
         assertEquals(3, results.size());
@@ -139,6 +120,8 @@
                 case "id3":
                     assertEquals(3, result.merged);
                     break;
+                default:
+                    throw new RuntimeException();
             }
         });
     }
@@ -147,8 +130,9 @@
         private String id;
         private int merged = 1;
 
-        public void mergeFrom(Oaf e) {
-            merged += ((OafImpl) e).merged;
+        public OafImpl mergeAndGet(OafImpl e) {
+            merged += e.merged;
+            return this;
         }
 
         public String getId() {
@@ -174,7 +158,4 @@
         return x;
     }
 
-    private static String createActionPayload(String id) {
-        return String.format("{\"id\":\"%s\"}", id);
-    }
 }
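
Usage sketch (not part of the patch): with actionPayloadToOafFn removed, callers are expected to deserialize action payloads into Oaf beans before invoking the join. Below is a minimal sketch of the new call shape, assuming the OafImpl bean, the createOafImpl helper, and the SparkSession spark from the test above; graph, payload, and merged are illustrative names.

    // Graph entities and pre-mapped action payloads now share the same bean type,
    // so no payload-to-Oaf conversion function is passed to the join anymore.
    Dataset<OafImpl> graph = spark.createDataset(
            Arrays.asList(createOafImpl("id1"), createOafImpl("id2")),
            Encoders.bean(OafImpl.class));
    Dataset<OafImpl> payload = spark.createDataset(
            Arrays.asList(createOafImpl("id1")),
            Encoders.bean(OafImpl.class));

    Dataset<OafImpl> merged = PromoteActionSetFromHDFSFunctions
            .joinOafEntityWithActionPayloadAndMerge(graph,
                    payload,                    // already Dataset<OafImpl>
                    () -> OafImpl::getId,       // oafIdFn
                    () -> OafImpl::mergeAndGet, // mergeAndGetFn
                    OafImpl.class);
    // merged holds the graph entities joined by id, with each matching payload
    // row folded in via mergeAndGet; unmatched graph entities pass through as-is.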