Use UDAF and Aggregation class for testing
This commit is contained in:
parent
df19548c56
commit
dcc08cc512
|
@ -6,10 +6,12 @@ import eu.dnetlib.pace.tree.support.TreeProcessor
|
|||
import eu.dnetlib.pace.util.MapDocumentUtil.truncateValue
|
||||
import eu.dnetlib.pace.util.{BlockProcessor, MapDocumentUtil, SparkReporter}
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.sql.{Column, Dataset, Row, functions}
|
||||
import org.apache.spark.rdd.RDD.rddToPairRDDFunctions
|
||||
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
|
||||
import org.apache.spark.sql.{Column, Dataset, Encoder, Encoders, Row, functions}
|
||||
import org.apache.spark.sql.catalyst.expressions.{GenericRowWithSchema, Literal}
|
||||
import org.apache.spark.sql.expressions.{UserDefinedFunction, Window}
|
||||
import org.apache.spark.sql.types.{DataTypes, Metadata, StructField, StructType}
|
||||
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction, UserDefinedFunction, Window}
|
||||
import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, Metadata, StructField, StructType}
|
||||
|
||||
import java.util
|
||||
import java.util.function.Predicate
|
||||
|
@ -18,6 +20,8 @@ import scala.collection.JavaConverters._
|
|||
import scala.collection.mutable
|
||||
import org.apache.spark.sql.functions.{col, lit, udf}
|
||||
|
||||
import java.util.stream.Collectors
|
||||
|
||||
case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Serializable {
|
||||
|
||||
private val URL_REGEX: Pattern = Pattern.compile("^\\s*(http|https|ftp)\\://.*")
|
||||
|
@ -74,11 +78,12 @@ case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Seria
|
|||
|
||||
val keys = conf.clusterings().asScala.map(_.getName + "_clustered").mkString(",")
|
||||
val fields = rowDataType.fieldNames.mkString(",")
|
||||
// // Using SQL because GROUPING SETS are not available through Scala/Java DSL// // Using SQL because GROUPING SETS are not available through Scala/Java DSL
|
||||
|
||||
// Using SQL because GROUPING SETS are not available through Scala/Java DSL
|
||||
df_with_keys.sqlContext.sql(
|
||||
("SELECT coalesce(" + keys + ") as key, slice(sort_array(collect_set(struct(" + fields + "))), 1, " + conf.getWf.getQueueMaxSize + ") as block FROM " + tempTable + " WHERE coalesce(" + keys + ") IS NOT NULL GROUP BY GROUPING SETS (" + keys + ") ")
|
||||
("SELECT coalesce(" + keys + ") as key, collect_sort_slice(" + fields + ") as block FROM " + tempTable + " WHERE coalesce(" + keys + ") IS NOT NULL GROUP BY GROUPING SETS (" + keys + ") HAVING size(block) > 1")
|
||||
)
|
||||
|
||||
}
|
||||
|
||||
val generateClustersWithDFAPI: (Dataset[Row] => Dataset[Row]) = df => {
|
||||
|
@ -91,7 +96,7 @@ case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Seria
|
|||
} else {
|
||||
res
|
||||
}
|
||||
}).checkpoint()
|
||||
})
|
||||
|
||||
var relBlocks: Dataset[Row] = null
|
||||
|
||||
|
@ -108,8 +113,110 @@ case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Seria
|
|||
}
|
||||
|
||||
val ds: Dataset[Row] = df_with_filters.withColumn("key", functions.explode(clusterValuesUDF(cd).apply(functions.array(columns.asScala: _*))))
|
||||
.groupBy(new Column("key"))
|
||||
.agg(functions.slice(functions.sort_array(functions.collect_set(functions.struct(rowDataType.fieldNames.map(col): _*))), 1, conf.getWf.getQueueMaxSize).as("block"))
|
||||
.select((Seq(rowDataType.fieldNames: _*) ++ Seq("key")).map(col): _*)
|
||||
.groupByKey(r => r.getAs[String]("key"))(Encoders.STRING)
|
||||
.agg(collectSortSliceAggregator.toColumn)
|
||||
.toDF("key", "block")
|
||||
.select(col("block.block").as("block"))
|
||||
|
||||
/*.groupBy("key")
|
||||
.agg(collectSortSliceUDAF(rowDataType.fieldNames.map(col): _*).as("block"))*/
|
||||
.filter(functions.size(new Column("block")).geq(new Literal(2, DataTypes.IntegerType)))
|
||||
|
||||
if (relBlocks == null) relBlocks = ds
|
||||
else relBlocks = relBlocks.union(ds)
|
||||
}
|
||||
|
||||
relBlocks
|
||||
}
|
||||
|
||||
val generateClustersWithDFAPIMerged: (Dataset[Row] => Dataset[Row]) = df => {
|
||||
val df_with_filters = conf.getPace.getModel.asScala.foldLeft(df)((res, fdef) => {
|
||||
if (conf.blacklists.containsKey(fdef.getName)) {
|
||||
res.withColumn(
|
||||
fdef.getName + "_filtered",
|
||||
filterColumnUDF(fdef).apply(new Column(fdef.getName))
|
||||
)
|
||||
} else {
|
||||
res
|
||||
}
|
||||
})
|
||||
|
||||
import scala.collection.JavaConversions._
|
||||
|
||||
val keys = conf.clusterings().foldLeft(null : Column)((res, cd) => {
|
||||
val columns: util.List[Column] = new util.ArrayList[Column](cd.getFields().size)
|
||||
|
||||
for (fName <- cd.getFields()) {
|
||||
if (conf.blacklists.containsKey(fName))
|
||||
columns.add(new Column(fName + "_filtered"))
|
||||
else
|
||||
columns.add(new Column(fName))
|
||||
}
|
||||
|
||||
if (res != null)
|
||||
functions.array_union(res, clusterValuesUDF(cd).apply(functions.array(columns.asScala: _*)))
|
||||
else
|
||||
clusterValuesUDF(cd).apply(functions.array(columns.asScala: _*))
|
||||
})
|
||||
|
||||
val ds: Dataset[Row] = df_with_filters.withColumn("key", functions.explode(keys))
|
||||
.select((Seq(rowDataType.fieldNames: _*) ++ Seq("key")).map(col): _*)
|
||||
.groupByKey(r => r.getAs[String]("key"))(Encoders.STRING)
|
||||
.agg(collectSortSliceAggregator.toColumn)
|
||||
.toDF("key", "block")
|
||||
.select(col("block.block").as("block"))
|
||||
|
||||
/*.groupBy("key")
|
||||
.agg(collectSortSliceUDAF(rowDataType.fieldNames.map(col): _*).as("block"))*/
|
||||
.filter(functions.size(new Column("block")).geq(new Literal(2, DataTypes.IntegerType)))
|
||||
|
||||
ds
|
||||
}
|
||||
|
||||
val generateClustersWithRDDReduction: (Dataset[Row] => Dataset[Row]) = df => {
|
||||
val df_with_filters = conf.getPace.getModel.asScala.foldLeft(df)((res, fdef) => {
|
||||
if (conf.blacklists.containsKey(fdef.getName)) {
|
||||
res.withColumn(
|
||||
fdef.getName + "_filtered",
|
||||
filterColumnUDF(fdef).apply(new Column(fdef.getName))
|
||||
)
|
||||
} else {
|
||||
res
|
||||
}
|
||||
})
|
||||
|
||||
var relBlocks: Dataset[Row] = null
|
||||
|
||||
import scala.collection.JavaConversions._
|
||||
|
||||
for (cd <- conf.clusterings()) {
|
||||
val columns: util.List[Column] = new util.ArrayList[Column](cd.getFields().size)
|
||||
|
||||
for (fName <- cd.getFields()) {
|
||||
if (conf.blacklists.containsKey(fName))
|
||||
columns.add(new Column(fName + "_filtered"))
|
||||
else
|
||||
columns.add(new Column(fName))
|
||||
}
|
||||
|
||||
val ds: Dataset[Row] = df.sparkSession.createDataFrame(df_with_filters.withColumn("key", functions.explode(clusterValuesUDF(cd).apply(functions.array(columns.asScala: _*))))
|
||||
.select(col("key"), functions.array(functions.struct(rowDataType.fieldNames.map(col): _*).as("value")))
|
||||
.rdd.keyBy(_.getString(0))
|
||||
.reduceByKey((a, b) => {
|
||||
val b1 = a.getSeq[Row](1)
|
||||
val b2 = b.getSeq[Row](1)
|
||||
|
||||
if (b1.size + b2.size > conf.getWf.getQueueMaxSize)
|
||||
Row(a.get(0), b1.union(b2).sortBy(_.getString(orderingFieldPosition)).slice(0, conf.getWf.getQueueMaxSize))
|
||||
else
|
||||
Row(a.get(0), b1.union(b2))
|
||||
})
|
||||
.map(_._2)
|
||||
.filter(k => k.getSeq(1).size > 1),
|
||||
new StructType().add(StructField("key", DataTypes.StringType)).add(StructField("block", ArrayType(rowDataType)))
|
||||
)
|
||||
|
||||
if (relBlocks == null) relBlocks = ds
|
||||
else relBlocks = relBlocks.union(ds)
|
||||
}
|
||||
|
@ -167,6 +274,8 @@ case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Seria
|
|||
|
||||
treeProcessor.compare(a, b)
|
||||
}).apply(functions.struct(rowDataType.fieldNames.map(s => col("l.".concat(s))): _*), functions.struct(rowDataType.fieldNames.map(s => col("r.".concat(s))): _*)))
|
||||
.filter(col("match").equalTo(true))
|
||||
.select(col("l.identifier").as("from"), col("r.identifier").as("to"))
|
||||
|
||||
// dsWithMatch.show(false)
|
||||
|
||||
|
@ -176,9 +285,9 @@ case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Seria
|
|||
relBlocks = relBlocks.union(dsWithMatch)
|
||||
}
|
||||
|
||||
val res = relBlocks.filter(col("match").equalTo(true))
|
||||
.select(col("l.identifier").as("from"), col("r.identifier").as("to"))
|
||||
//.repartition()
|
||||
val res = relBlocks
|
||||
//.select(col("l.identifier").as("from"), col("r.identifier").as("to"))
|
||||
.repartition()
|
||||
.distinct()
|
||||
|
||||
// res.show(false)
|
||||
|
@ -301,7 +410,7 @@ case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Seria
|
|||
|
||||
def clusterValuesUDF(cd: ClusteringDef) = {
|
||||
udf[mutable.WrappedArray[String], mutable.WrappedArray[Object]](values => {
|
||||
values.flatMap(f => cd.clusteringFunction().apply(conf, Seq(f.toString).asJava).asScala)
|
||||
values.flatMap(f => cd.clusteringFunction().apply(conf, Seq(f.toString).asJava).asScala).map(cd.getName.concat(_))
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -311,17 +420,78 @@ case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Seria
|
|||
udf[Array[Tuple2[String, String]], mutable.WrappedArray[Row]](block => {
|
||||
val reporter = new SparkReporter(accumulators)
|
||||
|
||||
//now done by spark
|
||||
// val mapDocuments = block.asJava.stream
|
||||
// .sorted(new RowDataOrderingComparator(rowDataType.fieldIndex(conf.getWf.getOrderField)))
|
||||
// .limit(conf.getWf.getQueueMaxSize)
|
||||
// .collect(Collectors.toList[Row]())
|
||||
val mapDocuments = block.asJava.stream
|
||||
.sorted(new RowDataOrderingComparator(orderingFieldPosition))
|
||||
.limit(conf.getWf.getQueueMaxSize)
|
||||
.collect(Collectors.toList[Row]())
|
||||
|
||||
|
||||
new BlockProcessor(conf, identityFieldPosition, orderingFieldPosition).processSortedRows(block.asJava, reporter)
|
||||
new BlockProcessor(conf, identityFieldPosition, orderingFieldPosition).processSortedRows(mapDocuments, reporter)
|
||||
|
||||
reporter.getRelations.asScala.toArray
|
||||
})
|
||||
}
|
||||
|
||||
val collectSortSliceAggregator : Aggregator[Row,Seq[Row], Row] = new Aggregator[Row, Seq[Row], Row] () {
|
||||
override def zero: Seq[Row] = Seq[Row]()
|
||||
|
||||
|
||||
override def reduce(buffer: Seq[Row], input: Row): Seq[Row] = {
|
||||
merge(buffer, Seq(input))
|
||||
}
|
||||
|
||||
override def merge(buffer: Seq[Row], toMerge: Seq[Row]): Seq[Row] = {
|
||||
val newBlock = buffer ++ toMerge
|
||||
|
||||
if (newBlock.size > conf.getWf.getQueueMaxSize)
|
||||
newBlock.sortBy(_.getString(orderingFieldPosition)).slice(0, conf.getWf.getQueueMaxSize)
|
||||
else
|
||||
newBlock
|
||||
}
|
||||
|
||||
override def finish(reduction: Seq[Row]): Row = {
|
||||
Row(reduction.toArray)
|
||||
}
|
||||
|
||||
override def bufferEncoder: Encoder[Seq[Row]] = Encoders.kryo[Seq[Row]]
|
||||
|
||||
override def outputEncoder: Encoder[Row] = RowEncoder.apply(new StructType().add("block", DataTypes.createArrayType(rowDataType), nullable = true))
|
||||
}
|
||||
|
||||
val collectSortSliceUDAF : UserDefinedAggregateFunction = new UserDefinedAggregateFunction {
|
||||
override def inputSchema: StructType = rowDataType
|
||||
|
||||
override def bufferSchema: StructType = {
|
||||
new StructType().add("block", DataTypes.createArrayType(rowDataType), nullable = true)
|
||||
}
|
||||
|
||||
override def dataType: DataType = DataTypes.createArrayType(rowDataType)
|
||||
|
||||
override def deterministic: Boolean = true
|
||||
|
||||
override def initialize(buffer: MutableAggregationBuffer): Unit = {
|
||||
buffer(0) = Seq[Row]()
|
||||
}
|
||||
|
||||
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
|
||||
val newBlock = buffer.getSeq[Row](0) ++ Seq(input)
|
||||
|
||||
if (newBlock.size > conf.getWf.getQueueMaxSize)
|
||||
buffer(0) = newBlock.sortBy(_.getString(orderingFieldPosition)).slice(0, conf.getWf.getQueueMaxSize)
|
||||
else
|
||||
buffer(0) = newBlock
|
||||
}
|
||||
|
||||
override def merge(buffer: MutableAggregationBuffer, row: Row): Unit = {
|
||||
val newBlock = buffer.getSeq[Row](0) ++ row.getSeq[Row](0)
|
||||
|
||||
if (newBlock.size > conf.getWf.getQueueMaxSize)
|
||||
buffer(0) = newBlock.sortBy(_.getString(orderingFieldPosition)).slice(0, conf.getWf.getQueueMaxSize)
|
||||
else
|
||||
buffer(0) = newBlock
|
||||
}
|
||||
|
||||
override def evaluate(buffer: Row): Any = {
|
||||
buffer.getSeq[Row](0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -86,15 +86,17 @@ public class SparkCreateSimRels extends AbstractSparkAction {
|
|||
|
||||
SparkDedupConfig sparkConfig = new SparkDedupConfig(dedupConf, numPartitions);
|
||||
|
||||
spark.udf().register("collect_sort_slice", sparkConfig.collectSortSliceUDAF());
|
||||
|
||||
Dataset<?> simRels = spark
|
||||
.read()
|
||||
.textFile(DedupUtility.createEntityPath(graphBasePath, subEntity))
|
||||
.transform(sparkConfig.modelExtractor()) // Extract fields from input json column according to model
|
||||
// definition
|
||||
.transform(sparkConfig.generateAndProcessClustersWithJoins()) // generate <key,block> pairs according to
|
||||
.transform(sparkConfig.generateClustersWithDFAPIMerged()) // generate <key,block> pairs according to
|
||||
// filters, clusters, and model
|
||||
// definition
|
||||
// .transform(sparkConfig.processClusters()) // process blocks and emits <from,to> pairs of found
|
||||
.transform(sparkConfig.processClusters()) // process blocks and emits <from,to> pairs of found
|
||||
// similarities
|
||||
.map(
|
||||
(MapFunction<Row, Relation>) t -> DedupUtility
|
||||
|
|
Loading…
Reference in New Issue