Use UDAF and Aggregation class for testing

This commit is contained in:
Giambattista Bloisi 2023-07-07 12:35:30 +02:00
parent df19548c56
commit dcc08cc512
2 changed files with 193 additions and 21 deletions

View File

@ -6,10 +6,12 @@ import eu.dnetlib.pace.tree.support.TreeProcessor
import eu.dnetlib.pace.util.MapDocumentUtil.truncateValue
import eu.dnetlib.pace.util.{BlockProcessor, MapDocumentUtil, SparkReporter}
import org.apache.spark.SparkContext
import org.apache.spark.sql.{Column, Dataset, Row, functions}
import org.apache.spark.rdd.RDD.rddToPairRDDFunctions
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.{Column, Dataset, Encoder, Encoders, Row, functions}
import org.apache.spark.sql.catalyst.expressions.{GenericRowWithSchema, Literal}
import org.apache.spark.sql.expressions.{UserDefinedFunction, Window}
import org.apache.spark.sql.types.{DataTypes, Metadata, StructField, StructType}
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction, UserDefinedFunction, Window}
import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, Metadata, StructField, StructType}
import java.util
import java.util.function.Predicate
@ -18,6 +20,8 @@ import scala.collection.JavaConverters._
import scala.collection.mutable
import org.apache.spark.sql.functions.{col, lit, udf}
import java.util.stream.Collectors
case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Serializable {
private val URL_REGEX: Pattern = Pattern.compile("^\\s*(http|https|ftp)\\://.*")
@ -74,11 +78,12 @@ case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Seria
val keys = conf.clusterings().asScala.map(_.getName + "_clustered").mkString(",")
val fields = rowDataType.fieldNames.mkString(",")
// // Using SQL because GROUPING SETS are not available through Scala/Java DSL// // Using SQL because GROUPING SETS are not available through Scala/Java DSL
// Using SQL because GROUPING SETS are not available through Scala/Java DSL
df_with_keys.sqlContext.sql(
("SELECT coalesce(" + keys + ") as key, slice(sort_array(collect_set(struct(" + fields + "))), 1, " + conf.getWf.getQueueMaxSize + ") as block FROM " + tempTable + " WHERE coalesce(" + keys + ") IS NOT NULL GROUP BY GROUPING SETS (" + keys + ") ")
("SELECT coalesce(" + keys + ") as key, collect_sort_slice(" + fields + ") as block FROM " + tempTable + " WHERE coalesce(" + keys + ") IS NOT NULL GROUP BY GROUPING SETS (" + keys + ") HAVING size(block) > 1")
)
}
val generateClustersWithDFAPI: (Dataset[Row] => Dataset[Row]) = df => {
@ -91,7 +96,7 @@ case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Seria
} else {
res
}
}).checkpoint()
})
var relBlocks: Dataset[Row] = null
@ -108,8 +113,110 @@ case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Seria
}
val ds: Dataset[Row] = df_with_filters.withColumn("key", functions.explode(clusterValuesUDF(cd).apply(functions.array(columns.asScala: _*))))
.groupBy(new Column("key"))
.agg(functions.slice(functions.sort_array(functions.collect_set(functions.struct(rowDataType.fieldNames.map(col): _*))), 1, conf.getWf.getQueueMaxSize).as("block"))
.select((Seq(rowDataType.fieldNames: _*) ++ Seq("key")).map(col): _*)
.groupByKey(r => r.getAs[String]("key"))(Encoders.STRING)
.agg(collectSortSliceAggregator.toColumn)
.toDF("key", "block")
.select(col("block.block").as("block"))
/*.groupBy("key")
.agg(collectSortSliceUDAF(rowDataType.fieldNames.map(col): _*).as("block"))*/
.filter(functions.size(new Column("block")).geq(new Literal(2, DataTypes.IntegerType)))
if (relBlocks == null) relBlocks = ds
else relBlocks = relBlocks.union(ds)
}
relBlocks
}
val generateClustersWithDFAPIMerged: (Dataset[Row] => Dataset[Row]) = df => {
val df_with_filters = conf.getPace.getModel.asScala.foldLeft(df)((res, fdef) => {
if (conf.blacklists.containsKey(fdef.getName)) {
res.withColumn(
fdef.getName + "_filtered",
filterColumnUDF(fdef).apply(new Column(fdef.getName))
)
} else {
res
}
})
import scala.collection.JavaConversions._
val keys = conf.clusterings().foldLeft(null : Column)((res, cd) => {
val columns: util.List[Column] = new util.ArrayList[Column](cd.getFields().size)
for (fName <- cd.getFields()) {
if (conf.blacklists.containsKey(fName))
columns.add(new Column(fName + "_filtered"))
else
columns.add(new Column(fName))
}
if (res != null)
functions.array_union(res, clusterValuesUDF(cd).apply(functions.array(columns.asScala: _*)))
else
clusterValuesUDF(cd).apply(functions.array(columns.asScala: _*))
})
val ds: Dataset[Row] = df_with_filters.withColumn("key", functions.explode(keys))
.select((Seq(rowDataType.fieldNames: _*) ++ Seq("key")).map(col): _*)
.groupByKey(r => r.getAs[String]("key"))(Encoders.STRING)
.agg(collectSortSliceAggregator.toColumn)
.toDF("key", "block")
.select(col("block.block").as("block"))
/*.groupBy("key")
.agg(collectSortSliceUDAF(rowDataType.fieldNames.map(col): _*).as("block"))*/
.filter(functions.size(new Column("block")).geq(new Literal(2, DataTypes.IntegerType)))
ds
}
val generateClustersWithRDDReduction: (Dataset[Row] => Dataset[Row]) = df => {
val df_with_filters = conf.getPace.getModel.asScala.foldLeft(df)((res, fdef) => {
if (conf.blacklists.containsKey(fdef.getName)) {
res.withColumn(
fdef.getName + "_filtered",
filterColumnUDF(fdef).apply(new Column(fdef.getName))
)
} else {
res
}
})
var relBlocks: Dataset[Row] = null
import scala.collection.JavaConversions._
for (cd <- conf.clusterings()) {
val columns: util.List[Column] = new util.ArrayList[Column](cd.getFields().size)
for (fName <- cd.getFields()) {
if (conf.blacklists.containsKey(fName))
columns.add(new Column(fName + "_filtered"))
else
columns.add(new Column(fName))
}
val ds: Dataset[Row] = df.sparkSession.createDataFrame(df_with_filters.withColumn("key", functions.explode(clusterValuesUDF(cd).apply(functions.array(columns.asScala: _*))))
.select(col("key"), functions.array(functions.struct(rowDataType.fieldNames.map(col): _*).as("value")))
.rdd.keyBy(_.getString(0))
.reduceByKey((a, b) => {
val b1 = a.getSeq[Row](1)
val b2 = b.getSeq[Row](1)
if (b1.size + b2.size > conf.getWf.getQueueMaxSize)
Row(a.get(0), b1.union(b2).sortBy(_.getString(orderingFieldPosition)).slice(0, conf.getWf.getQueueMaxSize))
else
Row(a.get(0), b1.union(b2))
})
.map(_._2)
.filter(k => k.getSeq(1).size > 1),
new StructType().add(StructField("key", DataTypes.StringType)).add(StructField("block", ArrayType(rowDataType)))
)
if (relBlocks == null) relBlocks = ds
else relBlocks = relBlocks.union(ds)
}
@ -167,6 +274,8 @@ case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Seria
treeProcessor.compare(a, b)
}).apply(functions.struct(rowDataType.fieldNames.map(s => col("l.".concat(s))): _*), functions.struct(rowDataType.fieldNames.map(s => col("r.".concat(s))): _*)))
.filter(col("match").equalTo(true))
.select(col("l.identifier").as("from"), col("r.identifier").as("to"))
// dsWithMatch.show(false)
@ -176,9 +285,9 @@ case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Seria
relBlocks = relBlocks.union(dsWithMatch)
}
val res = relBlocks.filter(col("match").equalTo(true))
.select(col("l.identifier").as("from"), col("r.identifier").as("to"))
//.repartition()
val res = relBlocks
//.select(col("l.identifier").as("from"), col("r.identifier").as("to"))
.repartition()
.distinct()
// res.show(false)
@ -301,7 +410,7 @@ case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Seria
def clusterValuesUDF(cd: ClusteringDef) = {
udf[mutable.WrappedArray[String], mutable.WrappedArray[Object]](values => {
values.flatMap(f => cd.clusteringFunction().apply(conf, Seq(f.toString).asJava).asScala)
values.flatMap(f => cd.clusteringFunction().apply(conf, Seq(f.toString).asJava).asScala).map(cd.getName.concat(_))
})
}
@ -311,17 +420,78 @@ case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Seria
udf[Array[Tuple2[String, String]], mutable.WrappedArray[Row]](block => {
val reporter = new SparkReporter(accumulators)
//now done by spark
// val mapDocuments = block.asJava.stream
// .sorted(new RowDataOrderingComparator(rowDataType.fieldIndex(conf.getWf.getOrderField)))
// .limit(conf.getWf.getQueueMaxSize)
// .collect(Collectors.toList[Row]())
val mapDocuments = block.asJava.stream
.sorted(new RowDataOrderingComparator(orderingFieldPosition))
.limit(conf.getWf.getQueueMaxSize)
.collect(Collectors.toList[Row]())
new BlockProcessor(conf, identityFieldPosition, orderingFieldPosition).processSortedRows(block.asJava, reporter)
new BlockProcessor(conf, identityFieldPosition, orderingFieldPosition).processSortedRows(mapDocuments, reporter)
reporter.getRelations.asScala.toArray
})
}
val collectSortSliceAggregator : Aggregator[Row,Seq[Row], Row] = new Aggregator[Row, Seq[Row], Row] () {
override def zero: Seq[Row] = Seq[Row]()
override def reduce(buffer: Seq[Row], input: Row): Seq[Row] = {
merge(buffer, Seq(input))
}
override def merge(buffer: Seq[Row], toMerge: Seq[Row]): Seq[Row] = {
val newBlock = buffer ++ toMerge
if (newBlock.size > conf.getWf.getQueueMaxSize)
newBlock.sortBy(_.getString(orderingFieldPosition)).slice(0, conf.getWf.getQueueMaxSize)
else
newBlock
}
override def finish(reduction: Seq[Row]): Row = {
Row(reduction.toArray)
}
override def bufferEncoder: Encoder[Seq[Row]] = Encoders.kryo[Seq[Row]]
override def outputEncoder: Encoder[Row] = RowEncoder.apply(new StructType().add("block", DataTypes.createArrayType(rowDataType), nullable = true))
}
val collectSortSliceUDAF : UserDefinedAggregateFunction = new UserDefinedAggregateFunction {
override def inputSchema: StructType = rowDataType
override def bufferSchema: StructType = {
new StructType().add("block", DataTypes.createArrayType(rowDataType), nullable = true)
}
override def dataType: DataType = DataTypes.createArrayType(rowDataType)
override def deterministic: Boolean = true
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = Seq[Row]()
}
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
val newBlock = buffer.getSeq[Row](0) ++ Seq(input)
if (newBlock.size > conf.getWf.getQueueMaxSize)
buffer(0) = newBlock.sortBy(_.getString(orderingFieldPosition)).slice(0, conf.getWf.getQueueMaxSize)
else
buffer(0) = newBlock
}
override def merge(buffer: MutableAggregationBuffer, row: Row): Unit = {
val newBlock = buffer.getSeq[Row](0) ++ row.getSeq[Row](0)
if (newBlock.size > conf.getWf.getQueueMaxSize)
buffer(0) = newBlock.sortBy(_.getString(orderingFieldPosition)).slice(0, conf.getWf.getQueueMaxSize)
else
buffer(0) = newBlock
}
override def evaluate(buffer: Row): Any = {
buffer.getSeq[Row](0)
}
}
}

View File

@ -86,15 +86,17 @@ public class SparkCreateSimRels extends AbstractSparkAction {
SparkDedupConfig sparkConfig = new SparkDedupConfig(dedupConf, numPartitions);
spark.udf().register("collect_sort_slice", sparkConfig.collectSortSliceUDAF());
Dataset<?> simRels = spark
.read()
.textFile(DedupUtility.createEntityPath(graphBasePath, subEntity))
.transform(sparkConfig.modelExtractor()) // Extract fields from input json column according to model
// definition
.transform(sparkConfig.generateAndProcessClustersWithJoins()) // generate <key,block> pairs according to
.transform(sparkConfig.generateClustersWithDFAPIMerged()) // generate <key,block> pairs according to
// filters, clusters, and model
// definition
// .transform(sparkConfig.processClusters()) // process blocks and emits <from,to> pairs of found
.transform(sparkConfig.processClusters()) // process blocks and emits <from,to> pairs of found
// similarities
.map(
(MapFunction<Row, Relation>) t -> DedupUtility