From dcc08cc512f54af958fe49f7ac0c2baef32bcbc3 Mon Sep 17 00:00:00 2001 From: Giambattista Bloisi Date: Fri, 7 Jul 2023 12:35:30 +0200 Subject: [PATCH] Use UDAF and Aggregation class for testing --- .../dnetlib/pace/model/SparkDedupConfig.scala | 208 ++++++++++++++++-- .../dhp/oa/dedup/SparkCreateSimRels.java | 6 +- 2 files changed, 193 insertions(+), 21 deletions(-) diff --git a/dhp-pace-core/src/main/java/eu/dnetlib/pace/model/SparkDedupConfig.scala b/dhp-pace-core/src/main/java/eu/dnetlib/pace/model/SparkDedupConfig.scala index 286b256ef..9a89de57f 100644 --- a/dhp-pace-core/src/main/java/eu/dnetlib/pace/model/SparkDedupConfig.scala +++ b/dhp-pace-core/src/main/java/eu/dnetlib/pace/model/SparkDedupConfig.scala @@ -6,10 +6,12 @@ import eu.dnetlib.pace.tree.support.TreeProcessor import eu.dnetlib.pace.util.MapDocumentUtil.truncateValue import eu.dnetlib.pace.util.{BlockProcessor, MapDocumentUtil, SparkReporter} import org.apache.spark.SparkContext -import org.apache.spark.sql.{Column, Dataset, Row, functions} +import org.apache.spark.rdd.RDD.rddToPairRDDFunctions +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.{Column, Dataset, Encoder, Encoders, Row, functions} import org.apache.spark.sql.catalyst.expressions.{GenericRowWithSchema, Literal} -import org.apache.spark.sql.expressions.{UserDefinedFunction, Window} -import org.apache.spark.sql.types.{DataTypes, Metadata, StructField, StructType} +import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction, UserDefinedFunction, Window} +import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, Metadata, StructField, StructType} import java.util import java.util.function.Predicate @@ -18,6 +20,8 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import org.apache.spark.sql.functions.{col, lit, udf} +import java.util.stream.Collectors + case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Serializable { private val URL_REGEX: Pattern = Pattern.compile("^\\s*(http|https|ftp)\\://.*") @@ -74,11 +78,12 @@ case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Seria val keys = conf.clusterings().asScala.map(_.getName + "_clustered").mkString(",") val fields = rowDataType.fieldNames.mkString(",") - // // Using SQL because GROUPING SETS are not available through Scala/Java DSL// // Using SQL because GROUPING SETS are not available through Scala/Java DSL + // Using SQL because GROUPING SETS are not available through Scala/Java DSL df_with_keys.sqlContext.sql( - ("SELECT coalesce(" + keys + ") as key, slice(sort_array(collect_set(struct(" + fields + "))), 1, " + conf.getWf.getQueueMaxSize + ") as block FROM " + tempTable + " WHERE coalesce(" + keys + ") IS NOT NULL GROUP BY GROUPING SETS (" + keys + ") ") + ("SELECT coalesce(" + keys + ") as key, collect_sort_slice(" + fields + ") as block FROM " + tempTable + " WHERE coalesce(" + keys + ") IS NOT NULL GROUP BY GROUPING SETS (" + keys + ") HAVING size(block) > 1") ) + } val generateClustersWithDFAPI: (Dataset[Row] => Dataset[Row]) = df => { @@ -91,7 +96,7 @@ case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Seria } else { res } - }).checkpoint() + }) var relBlocks: Dataset[Row] = null @@ -108,8 +113,110 @@ case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Seria } val ds: Dataset[Row] = df_with_filters.withColumn("key", functions.explode(clusterValuesUDF(cd).apply(functions.array(columns.asScala: _*)))) - .groupBy(new Column("key")) - .agg(functions.slice(functions.sort_array(functions.collect_set(functions.struct(rowDataType.fieldNames.map(col): _*))), 1, conf.getWf.getQueueMaxSize).as("block")) + .select((Seq(rowDataType.fieldNames: _*) ++ Seq("key")).map(col): _*) + .groupByKey(r => r.getAs[String]("key"))(Encoders.STRING) + .agg(collectSortSliceAggregator.toColumn) + .toDF("key", "block") + .select(col("block.block").as("block")) + + /*.groupBy("key") + .agg(collectSortSliceUDAF(rowDataType.fieldNames.map(col): _*).as("block"))*/ + .filter(functions.size(new Column("block")).geq(new Literal(2, DataTypes.IntegerType))) + + if (relBlocks == null) relBlocks = ds + else relBlocks = relBlocks.union(ds) + } + + relBlocks + } + + val generateClustersWithDFAPIMerged: (Dataset[Row] => Dataset[Row]) = df => { + val df_with_filters = conf.getPace.getModel.asScala.foldLeft(df)((res, fdef) => { + if (conf.blacklists.containsKey(fdef.getName)) { + res.withColumn( + fdef.getName + "_filtered", + filterColumnUDF(fdef).apply(new Column(fdef.getName)) + ) + } else { + res + } + }) + + import scala.collection.JavaConversions._ + + val keys = conf.clusterings().foldLeft(null : Column)((res, cd) => { + val columns: util.List[Column] = new util.ArrayList[Column](cd.getFields().size) + + for (fName <- cd.getFields()) { + if (conf.blacklists.containsKey(fName)) + columns.add(new Column(fName + "_filtered")) + else + columns.add(new Column(fName)) + } + + if (res != null) + functions.array_union(res, clusterValuesUDF(cd).apply(functions.array(columns.asScala: _*))) + else + clusterValuesUDF(cd).apply(functions.array(columns.asScala: _*)) + }) + + val ds: Dataset[Row] = df_with_filters.withColumn("key", functions.explode(keys)) + .select((Seq(rowDataType.fieldNames: _*) ++ Seq("key")).map(col): _*) + .groupByKey(r => r.getAs[String]("key"))(Encoders.STRING) + .agg(collectSortSliceAggregator.toColumn) + .toDF("key", "block") + .select(col("block.block").as("block")) + + /*.groupBy("key") + .agg(collectSortSliceUDAF(rowDataType.fieldNames.map(col): _*).as("block"))*/ + .filter(functions.size(new Column("block")).geq(new Literal(2, DataTypes.IntegerType))) + + ds + } + + val generateClustersWithRDDReduction: (Dataset[Row] => Dataset[Row]) = df => { + val df_with_filters = conf.getPace.getModel.asScala.foldLeft(df)((res, fdef) => { + if (conf.blacklists.containsKey(fdef.getName)) { + res.withColumn( + fdef.getName + "_filtered", + filterColumnUDF(fdef).apply(new Column(fdef.getName)) + ) + } else { + res + } + }) + + var relBlocks: Dataset[Row] = null + + import scala.collection.JavaConversions._ + + for (cd <- conf.clusterings()) { + val columns: util.List[Column] = new util.ArrayList[Column](cd.getFields().size) + + for (fName <- cd.getFields()) { + if (conf.blacklists.containsKey(fName)) + columns.add(new Column(fName + "_filtered")) + else + columns.add(new Column(fName)) + } + + val ds: Dataset[Row] = df.sparkSession.createDataFrame(df_with_filters.withColumn("key", functions.explode(clusterValuesUDF(cd).apply(functions.array(columns.asScala: _*)))) + .select(col("key"), functions.array(functions.struct(rowDataType.fieldNames.map(col): _*).as("value"))) + .rdd.keyBy(_.getString(0)) + .reduceByKey((a, b) => { + val b1 = a.getSeq[Row](1) + val b2 = b.getSeq[Row](1) + + if (b1.size + b2.size > conf.getWf.getQueueMaxSize) + Row(a.get(0), b1.union(b2).sortBy(_.getString(orderingFieldPosition)).slice(0, conf.getWf.getQueueMaxSize)) + else + Row(a.get(0), b1.union(b2)) + }) + .map(_._2) + .filter(k => k.getSeq(1).size > 1), + new StructType().add(StructField("key", DataTypes.StringType)).add(StructField("block", ArrayType(rowDataType))) + ) + if (relBlocks == null) relBlocks = ds else relBlocks = relBlocks.union(ds) } @@ -167,6 +274,8 @@ case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Seria treeProcessor.compare(a, b) }).apply(functions.struct(rowDataType.fieldNames.map(s => col("l.".concat(s))): _*), functions.struct(rowDataType.fieldNames.map(s => col("r.".concat(s))): _*))) + .filter(col("match").equalTo(true)) + .select(col("l.identifier").as("from"), col("r.identifier").as("to")) // dsWithMatch.show(false) @@ -176,9 +285,9 @@ case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Seria relBlocks = relBlocks.union(dsWithMatch) } - val res = relBlocks.filter(col("match").equalTo(true)) - .select(col("l.identifier").as("from"), col("r.identifier").as("to")) - //.repartition() + val res = relBlocks + //.select(col("l.identifier").as("from"), col("r.identifier").as("to")) + .repartition() .distinct() // res.show(false) @@ -301,7 +410,7 @@ case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Seria def clusterValuesUDF(cd: ClusteringDef) = { udf[mutable.WrappedArray[String], mutable.WrappedArray[Object]](values => { - values.flatMap(f => cd.clusteringFunction().apply(conf, Seq(f.toString).asJava).asScala) + values.flatMap(f => cd.clusteringFunction().apply(conf, Seq(f.toString).asJava).asScala).map(cd.getName.concat(_)) }) } @@ -311,17 +420,78 @@ case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Seria udf[Array[Tuple2[String, String]], mutable.WrappedArray[Row]](block => { val reporter = new SparkReporter(accumulators) - //now done by spark - // val mapDocuments = block.asJava.stream - // .sorted(new RowDataOrderingComparator(rowDataType.fieldIndex(conf.getWf.getOrderField))) - // .limit(conf.getWf.getQueueMaxSize) - // .collect(Collectors.toList[Row]()) + val mapDocuments = block.asJava.stream + .sorted(new RowDataOrderingComparator(orderingFieldPosition)) + .limit(conf.getWf.getQueueMaxSize) + .collect(Collectors.toList[Row]()) - - new BlockProcessor(conf, identityFieldPosition, orderingFieldPosition).processSortedRows(block.asJava, reporter) + new BlockProcessor(conf, identityFieldPosition, orderingFieldPosition).processSortedRows(mapDocuments, reporter) reporter.getRelations.asScala.toArray }) } + val collectSortSliceAggregator : Aggregator[Row,Seq[Row], Row] = new Aggregator[Row, Seq[Row], Row] () { + override def zero: Seq[Row] = Seq[Row]() + + + override def reduce(buffer: Seq[Row], input: Row): Seq[Row] = { + merge(buffer, Seq(input)) + } + + override def merge(buffer: Seq[Row], toMerge: Seq[Row]): Seq[Row] = { + val newBlock = buffer ++ toMerge + + if (newBlock.size > conf.getWf.getQueueMaxSize) + newBlock.sortBy(_.getString(orderingFieldPosition)).slice(0, conf.getWf.getQueueMaxSize) + else + newBlock + } + + override def finish(reduction: Seq[Row]): Row = { + Row(reduction.toArray) + } + + override def bufferEncoder: Encoder[Seq[Row]] = Encoders.kryo[Seq[Row]] + + override def outputEncoder: Encoder[Row] = RowEncoder.apply(new StructType().add("block", DataTypes.createArrayType(rowDataType), nullable = true)) + } + + val collectSortSliceUDAF : UserDefinedAggregateFunction = new UserDefinedAggregateFunction { + override def inputSchema: StructType = rowDataType + + override def bufferSchema: StructType = { + new StructType().add("block", DataTypes.createArrayType(rowDataType), nullable = true) + } + + override def dataType: DataType = DataTypes.createArrayType(rowDataType) + + override def deterministic: Boolean = true + + override def initialize(buffer: MutableAggregationBuffer): Unit = { + buffer(0) = Seq[Row]() + } + + override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { + val newBlock = buffer.getSeq[Row](0) ++ Seq(input) + + if (newBlock.size > conf.getWf.getQueueMaxSize) + buffer(0) = newBlock.sortBy(_.getString(orderingFieldPosition)).slice(0, conf.getWf.getQueueMaxSize) + else + buffer(0) = newBlock + } + + override def merge(buffer: MutableAggregationBuffer, row: Row): Unit = { + val newBlock = buffer.getSeq[Row](0) ++ row.getSeq[Row](0) + + if (newBlock.size > conf.getWf.getQueueMaxSize) + buffer(0) = newBlock.sortBy(_.getString(orderingFieldPosition)).slice(0, conf.getWf.getQueueMaxSize) + else + buffer(0) = newBlock + } + + override def evaluate(buffer: Row): Any = { + buffer.getSeq[Row](0) + } + } } diff --git a/dhp-workflows/dhp-dedup-openaire/src/main/java/eu/dnetlib/dhp/oa/dedup/SparkCreateSimRels.java b/dhp-workflows/dhp-dedup-openaire/src/main/java/eu/dnetlib/dhp/oa/dedup/SparkCreateSimRels.java index 7c4ab5265..cd914e2df 100644 --- a/dhp-workflows/dhp-dedup-openaire/src/main/java/eu/dnetlib/dhp/oa/dedup/SparkCreateSimRels.java +++ b/dhp-workflows/dhp-dedup-openaire/src/main/java/eu/dnetlib/dhp/oa/dedup/SparkCreateSimRels.java @@ -86,15 +86,17 @@ public class SparkCreateSimRels extends AbstractSparkAction { SparkDedupConfig sparkConfig = new SparkDedupConfig(dedupConf, numPartitions); + spark.udf().register("collect_sort_slice", sparkConfig.collectSortSliceUDAF()); + Dataset simRels = spark .read() .textFile(DedupUtility.createEntityPath(graphBasePath, subEntity)) .transform(sparkConfig.modelExtractor()) // Extract fields from input json column according to model // definition - .transform(sparkConfig.generateAndProcessClustersWithJoins()) // generate pairs according to + .transform(sparkConfig.generateClustersWithDFAPIMerged()) // generate pairs according to // filters, clusters, and model // definition - // .transform(sparkConfig.processClusters()) // process blocks and emits pairs of found + .transform(sparkConfig.processClusters()) // process blocks and emits pairs of found // similarities .map( (MapFunction) t -> DedupUtility