Use UDAF and Aggregation class for testing
This commit is contained in:
parent
df19548c56
commit
dcc08cc512
|
@ -6,10 +6,12 @@ import eu.dnetlib.pace.tree.support.TreeProcessor
|
||||||
import eu.dnetlib.pace.util.MapDocumentUtil.truncateValue
|
import eu.dnetlib.pace.util.MapDocumentUtil.truncateValue
|
||||||
import eu.dnetlib.pace.util.{BlockProcessor, MapDocumentUtil, SparkReporter}
|
import eu.dnetlib.pace.util.{BlockProcessor, MapDocumentUtil, SparkReporter}
|
||||||
import org.apache.spark.SparkContext
|
import org.apache.spark.SparkContext
|
||||||
import org.apache.spark.sql.{Column, Dataset, Row, functions}
|
import org.apache.spark.rdd.RDD.rddToPairRDDFunctions
|
||||||
|
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
|
||||||
|
import org.apache.spark.sql.{Column, Dataset, Encoder, Encoders, Row, functions}
|
||||||
import org.apache.spark.sql.catalyst.expressions.{GenericRowWithSchema, Literal}
|
import org.apache.spark.sql.catalyst.expressions.{GenericRowWithSchema, Literal}
|
||||||
import org.apache.spark.sql.expressions.{UserDefinedFunction, Window}
|
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction, UserDefinedFunction, Window}
|
||||||
import org.apache.spark.sql.types.{DataTypes, Metadata, StructField, StructType}
|
import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, Metadata, StructField, StructType}
|
||||||
|
|
||||||
import java.util
|
import java.util
|
||||||
import java.util.function.Predicate
|
import java.util.function.Predicate
|
||||||
|
@ -18,6 +20,8 @@ import scala.collection.JavaConverters._
|
||||||
import scala.collection.mutable
|
import scala.collection.mutable
|
||||||
import org.apache.spark.sql.functions.{col, lit, udf}
|
import org.apache.spark.sql.functions.{col, lit, udf}
|
||||||
|
|
||||||
|
import java.util.stream.Collectors
|
||||||
|
|
||||||
case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Serializable {
|
case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Serializable {
|
||||||
|
|
||||||
private val URL_REGEX: Pattern = Pattern.compile("^\\s*(http|https|ftp)\\://.*")
|
private val URL_REGEX: Pattern = Pattern.compile("^\\s*(http|https|ftp)\\://.*")
|
||||||
|
@ -74,11 +78,12 @@ case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Seria
|
||||||
|
|
||||||
val keys = conf.clusterings().asScala.map(_.getName + "_clustered").mkString(",")
|
val keys = conf.clusterings().asScala.map(_.getName + "_clustered").mkString(",")
|
||||||
val fields = rowDataType.fieldNames.mkString(",")
|
val fields = rowDataType.fieldNames.mkString(",")
|
||||||
// // Using SQL because GROUPING SETS are not available through Scala/Java DSL// // Using SQL because GROUPING SETS are not available through Scala/Java DSL
|
|
||||||
|
|
||||||
|
// Using SQL because GROUPING SETS are not available through Scala/Java DSL
|
||||||
df_with_keys.sqlContext.sql(
|
df_with_keys.sqlContext.sql(
|
||||||
("SELECT coalesce(" + keys + ") as key, slice(sort_array(collect_set(struct(" + fields + "))), 1, " + conf.getWf.getQueueMaxSize + ") as block FROM " + tempTable + " WHERE coalesce(" + keys + ") IS NOT NULL GROUP BY GROUPING SETS (" + keys + ") ")
|
("SELECT coalesce(" + keys + ") as key, collect_sort_slice(" + fields + ") as block FROM " + tempTable + " WHERE coalesce(" + keys + ") IS NOT NULL GROUP BY GROUPING SETS (" + keys + ") HAVING size(block) > 1")
|
||||||
)
|
)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
val generateClustersWithDFAPI: (Dataset[Row] => Dataset[Row]) = df => {
|
val generateClustersWithDFAPI: (Dataset[Row] => Dataset[Row]) = df => {
|
||||||
|
@ -91,7 +96,7 @@ case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Seria
|
||||||
} else {
|
} else {
|
||||||
res
|
res
|
||||||
}
|
}
|
||||||
}).checkpoint()
|
})
|
||||||
|
|
||||||
var relBlocks: Dataset[Row] = null
|
var relBlocks: Dataset[Row] = null
|
||||||
|
|
||||||
|
@ -108,8 +113,110 @@ case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Seria
|
||||||
}
|
}
|
||||||
|
|
||||||
val ds: Dataset[Row] = df_with_filters.withColumn("key", functions.explode(clusterValuesUDF(cd).apply(functions.array(columns.asScala: _*))))
|
val ds: Dataset[Row] = df_with_filters.withColumn("key", functions.explode(clusterValuesUDF(cd).apply(functions.array(columns.asScala: _*))))
|
||||||
.groupBy(new Column("key"))
|
.select((Seq(rowDataType.fieldNames: _*) ++ Seq("key")).map(col): _*)
|
||||||
.agg(functions.slice(functions.sort_array(functions.collect_set(functions.struct(rowDataType.fieldNames.map(col): _*))), 1, conf.getWf.getQueueMaxSize).as("block"))
|
.groupByKey(r => r.getAs[String]("key"))(Encoders.STRING)
|
||||||
|
.agg(collectSortSliceAggregator.toColumn)
|
||||||
|
.toDF("key", "block")
|
||||||
|
.select(col("block.block").as("block"))
|
||||||
|
|
||||||
|
/*.groupBy("key")
|
||||||
|
.agg(collectSortSliceUDAF(rowDataType.fieldNames.map(col): _*).as("block"))*/
|
||||||
|
.filter(functions.size(new Column("block")).geq(new Literal(2, DataTypes.IntegerType)))
|
||||||
|
|
||||||
|
if (relBlocks == null) relBlocks = ds
|
||||||
|
else relBlocks = relBlocks.union(ds)
|
||||||
|
}
|
||||||
|
|
||||||
|
relBlocks
|
||||||
|
}
|
||||||
|
|
||||||
|
val generateClustersWithDFAPIMerged: (Dataset[Row] => Dataset[Row]) = df => {
|
||||||
|
val df_with_filters = conf.getPace.getModel.asScala.foldLeft(df)((res, fdef) => {
|
||||||
|
if (conf.blacklists.containsKey(fdef.getName)) {
|
||||||
|
res.withColumn(
|
||||||
|
fdef.getName + "_filtered",
|
||||||
|
filterColumnUDF(fdef).apply(new Column(fdef.getName))
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
res
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
import scala.collection.JavaConversions._
|
||||||
|
|
||||||
|
val keys = conf.clusterings().foldLeft(null : Column)((res, cd) => {
|
||||||
|
val columns: util.List[Column] = new util.ArrayList[Column](cd.getFields().size)
|
||||||
|
|
||||||
|
for (fName <- cd.getFields()) {
|
||||||
|
if (conf.blacklists.containsKey(fName))
|
||||||
|
columns.add(new Column(fName + "_filtered"))
|
||||||
|
else
|
||||||
|
columns.add(new Column(fName))
|
||||||
|
}
|
||||||
|
|
||||||
|
if (res != null)
|
||||||
|
functions.array_union(res, clusterValuesUDF(cd).apply(functions.array(columns.asScala: _*)))
|
||||||
|
else
|
||||||
|
clusterValuesUDF(cd).apply(functions.array(columns.asScala: _*))
|
||||||
|
})
|
||||||
|
|
||||||
|
val ds: Dataset[Row] = df_with_filters.withColumn("key", functions.explode(keys))
|
||||||
|
.select((Seq(rowDataType.fieldNames: _*) ++ Seq("key")).map(col): _*)
|
||||||
|
.groupByKey(r => r.getAs[String]("key"))(Encoders.STRING)
|
||||||
|
.agg(collectSortSliceAggregator.toColumn)
|
||||||
|
.toDF("key", "block")
|
||||||
|
.select(col("block.block").as("block"))
|
||||||
|
|
||||||
|
/*.groupBy("key")
|
||||||
|
.agg(collectSortSliceUDAF(rowDataType.fieldNames.map(col): _*).as("block"))*/
|
||||||
|
.filter(functions.size(new Column("block")).geq(new Literal(2, DataTypes.IntegerType)))
|
||||||
|
|
||||||
|
ds
|
||||||
|
}
|
||||||
|
|
||||||
|
val generateClustersWithRDDReduction: (Dataset[Row] => Dataset[Row]) = df => {
|
||||||
|
val df_with_filters = conf.getPace.getModel.asScala.foldLeft(df)((res, fdef) => {
|
||||||
|
if (conf.blacklists.containsKey(fdef.getName)) {
|
||||||
|
res.withColumn(
|
||||||
|
fdef.getName + "_filtered",
|
||||||
|
filterColumnUDF(fdef).apply(new Column(fdef.getName))
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
res
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
var relBlocks: Dataset[Row] = null
|
||||||
|
|
||||||
|
import scala.collection.JavaConversions._
|
||||||
|
|
||||||
|
for (cd <- conf.clusterings()) {
|
||||||
|
val columns: util.List[Column] = new util.ArrayList[Column](cd.getFields().size)
|
||||||
|
|
||||||
|
for (fName <- cd.getFields()) {
|
||||||
|
if (conf.blacklists.containsKey(fName))
|
||||||
|
columns.add(new Column(fName + "_filtered"))
|
||||||
|
else
|
||||||
|
columns.add(new Column(fName))
|
||||||
|
}
|
||||||
|
|
||||||
|
val ds: Dataset[Row] = df.sparkSession.createDataFrame(df_with_filters.withColumn("key", functions.explode(clusterValuesUDF(cd).apply(functions.array(columns.asScala: _*))))
|
||||||
|
.select(col("key"), functions.array(functions.struct(rowDataType.fieldNames.map(col): _*).as("value")))
|
||||||
|
.rdd.keyBy(_.getString(0))
|
||||||
|
.reduceByKey((a, b) => {
|
||||||
|
val b1 = a.getSeq[Row](1)
|
||||||
|
val b2 = b.getSeq[Row](1)
|
||||||
|
|
||||||
|
if (b1.size + b2.size > conf.getWf.getQueueMaxSize)
|
||||||
|
Row(a.get(0), b1.union(b2).sortBy(_.getString(orderingFieldPosition)).slice(0, conf.getWf.getQueueMaxSize))
|
||||||
|
else
|
||||||
|
Row(a.get(0), b1.union(b2))
|
||||||
|
})
|
||||||
|
.map(_._2)
|
||||||
|
.filter(k => k.getSeq(1).size > 1),
|
||||||
|
new StructType().add(StructField("key", DataTypes.StringType)).add(StructField("block", ArrayType(rowDataType)))
|
||||||
|
)
|
||||||
|
|
||||||
if (relBlocks == null) relBlocks = ds
|
if (relBlocks == null) relBlocks = ds
|
||||||
else relBlocks = relBlocks.union(ds)
|
else relBlocks = relBlocks.union(ds)
|
||||||
}
|
}
|
||||||
|
@ -167,6 +274,8 @@ case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Seria
|
||||||
|
|
||||||
treeProcessor.compare(a, b)
|
treeProcessor.compare(a, b)
|
||||||
}).apply(functions.struct(rowDataType.fieldNames.map(s => col("l.".concat(s))): _*), functions.struct(rowDataType.fieldNames.map(s => col("r.".concat(s))): _*)))
|
}).apply(functions.struct(rowDataType.fieldNames.map(s => col("l.".concat(s))): _*), functions.struct(rowDataType.fieldNames.map(s => col("r.".concat(s))): _*)))
|
||||||
|
.filter(col("match").equalTo(true))
|
||||||
|
.select(col("l.identifier").as("from"), col("r.identifier").as("to"))
|
||||||
|
|
||||||
// dsWithMatch.show(false)
|
// dsWithMatch.show(false)
|
||||||
|
|
||||||
|
@ -176,9 +285,9 @@ case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Seria
|
||||||
relBlocks = relBlocks.union(dsWithMatch)
|
relBlocks = relBlocks.union(dsWithMatch)
|
||||||
}
|
}
|
||||||
|
|
||||||
val res = relBlocks.filter(col("match").equalTo(true))
|
val res = relBlocks
|
||||||
.select(col("l.identifier").as("from"), col("r.identifier").as("to"))
|
//.select(col("l.identifier").as("from"), col("r.identifier").as("to"))
|
||||||
//.repartition()
|
.repartition()
|
||||||
.distinct()
|
.distinct()
|
||||||
|
|
||||||
// res.show(false)
|
// res.show(false)
|
||||||
|
@ -301,7 +410,7 @@ case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Seria
|
||||||
|
|
||||||
def clusterValuesUDF(cd: ClusteringDef) = {
|
def clusterValuesUDF(cd: ClusteringDef) = {
|
||||||
udf[mutable.WrappedArray[String], mutable.WrappedArray[Object]](values => {
|
udf[mutable.WrappedArray[String], mutable.WrappedArray[Object]](values => {
|
||||||
values.flatMap(f => cd.clusteringFunction().apply(conf, Seq(f.toString).asJava).asScala)
|
values.flatMap(f => cd.clusteringFunction().apply(conf, Seq(f.toString).asJava).asScala).map(cd.getName.concat(_))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -311,17 +420,78 @@ case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Seria
|
||||||
udf[Array[Tuple2[String, String]], mutable.WrappedArray[Row]](block => {
|
udf[Array[Tuple2[String, String]], mutable.WrappedArray[Row]](block => {
|
||||||
val reporter = new SparkReporter(accumulators)
|
val reporter = new SparkReporter(accumulators)
|
||||||
|
|
||||||
//now done by spark
|
val mapDocuments = block.asJava.stream
|
||||||
// val mapDocuments = block.asJava.stream
|
.sorted(new RowDataOrderingComparator(orderingFieldPosition))
|
||||||
// .sorted(new RowDataOrderingComparator(rowDataType.fieldIndex(conf.getWf.getOrderField)))
|
.limit(conf.getWf.getQueueMaxSize)
|
||||||
// .limit(conf.getWf.getQueueMaxSize)
|
.collect(Collectors.toList[Row]())
|
||||||
// .collect(Collectors.toList[Row]())
|
|
||||||
|
|
||||||
|
new BlockProcessor(conf, identityFieldPosition, orderingFieldPosition).processSortedRows(mapDocuments, reporter)
|
||||||
new BlockProcessor(conf, identityFieldPosition, orderingFieldPosition).processSortedRows(block.asJava, reporter)
|
|
||||||
|
|
||||||
reporter.getRelations.asScala.toArray
|
reporter.getRelations.asScala.toArray
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
val collectSortSliceAggregator : Aggregator[Row,Seq[Row], Row] = new Aggregator[Row, Seq[Row], Row] () {
|
||||||
|
override def zero: Seq[Row] = Seq[Row]()
|
||||||
|
|
||||||
|
|
||||||
|
override def reduce(buffer: Seq[Row], input: Row): Seq[Row] = {
|
||||||
|
merge(buffer, Seq(input))
|
||||||
|
}
|
||||||
|
|
||||||
|
override def merge(buffer: Seq[Row], toMerge: Seq[Row]): Seq[Row] = {
|
||||||
|
val newBlock = buffer ++ toMerge
|
||||||
|
|
||||||
|
if (newBlock.size > conf.getWf.getQueueMaxSize)
|
||||||
|
newBlock.sortBy(_.getString(orderingFieldPosition)).slice(0, conf.getWf.getQueueMaxSize)
|
||||||
|
else
|
||||||
|
newBlock
|
||||||
|
}
|
||||||
|
|
||||||
|
override def finish(reduction: Seq[Row]): Row = {
|
||||||
|
Row(reduction.toArray)
|
||||||
|
}
|
||||||
|
|
||||||
|
override def bufferEncoder: Encoder[Seq[Row]] = Encoders.kryo[Seq[Row]]
|
||||||
|
|
||||||
|
override def outputEncoder: Encoder[Row] = RowEncoder.apply(new StructType().add("block", DataTypes.createArrayType(rowDataType), nullable = true))
|
||||||
|
}
|
||||||
|
|
||||||
|
val collectSortSliceUDAF : UserDefinedAggregateFunction = new UserDefinedAggregateFunction {
|
||||||
|
override def inputSchema: StructType = rowDataType
|
||||||
|
|
||||||
|
override def bufferSchema: StructType = {
|
||||||
|
new StructType().add("block", DataTypes.createArrayType(rowDataType), nullable = true)
|
||||||
|
}
|
||||||
|
|
||||||
|
override def dataType: DataType = DataTypes.createArrayType(rowDataType)
|
||||||
|
|
||||||
|
override def deterministic: Boolean = true
|
||||||
|
|
||||||
|
override def initialize(buffer: MutableAggregationBuffer): Unit = {
|
||||||
|
buffer(0) = Seq[Row]()
|
||||||
|
}
|
||||||
|
|
||||||
|
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
|
||||||
|
val newBlock = buffer.getSeq[Row](0) ++ Seq(input)
|
||||||
|
|
||||||
|
if (newBlock.size > conf.getWf.getQueueMaxSize)
|
||||||
|
buffer(0) = newBlock.sortBy(_.getString(orderingFieldPosition)).slice(0, conf.getWf.getQueueMaxSize)
|
||||||
|
else
|
||||||
|
buffer(0) = newBlock
|
||||||
|
}
|
||||||
|
|
||||||
|
override def merge(buffer: MutableAggregationBuffer, row: Row): Unit = {
|
||||||
|
val newBlock = buffer.getSeq[Row](0) ++ row.getSeq[Row](0)
|
||||||
|
|
||||||
|
if (newBlock.size > conf.getWf.getQueueMaxSize)
|
||||||
|
buffer(0) = newBlock.sortBy(_.getString(orderingFieldPosition)).slice(0, conf.getWf.getQueueMaxSize)
|
||||||
|
else
|
||||||
|
buffer(0) = newBlock
|
||||||
|
}
|
||||||
|
|
||||||
|
override def evaluate(buffer: Row): Any = {
|
||||||
|
buffer.getSeq[Row](0)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -86,15 +86,17 @@ public class SparkCreateSimRels extends AbstractSparkAction {
|
||||||
|
|
||||||
SparkDedupConfig sparkConfig = new SparkDedupConfig(dedupConf, numPartitions);
|
SparkDedupConfig sparkConfig = new SparkDedupConfig(dedupConf, numPartitions);
|
||||||
|
|
||||||
|
spark.udf().register("collect_sort_slice", sparkConfig.collectSortSliceUDAF());
|
||||||
|
|
||||||
Dataset<?> simRels = spark
|
Dataset<?> simRels = spark
|
||||||
.read()
|
.read()
|
||||||
.textFile(DedupUtility.createEntityPath(graphBasePath, subEntity))
|
.textFile(DedupUtility.createEntityPath(graphBasePath, subEntity))
|
||||||
.transform(sparkConfig.modelExtractor()) // Extract fields from input json column according to model
|
.transform(sparkConfig.modelExtractor()) // Extract fields from input json column according to model
|
||||||
// definition
|
// definition
|
||||||
.transform(sparkConfig.generateAndProcessClustersWithJoins()) // generate <key,block> pairs according to
|
.transform(sparkConfig.generateClustersWithDFAPIMerged()) // generate <key,block> pairs according to
|
||||||
// filters, clusters, and model
|
// filters, clusters, and model
|
||||||
// definition
|
// definition
|
||||||
// .transform(sparkConfig.processClusters()) // process blocks and emits <from,to> pairs of found
|
.transform(sparkConfig.processClusters()) // process blocks and emits <from,to> pairs of found
|
||||||
// similarities
|
// similarities
|
||||||
.map(
|
.map(
|
||||||
(MapFunction<Row, Relation>) t -> DedupUtility
|
(MapFunction<Row, Relation>) t -> DedupUtility
|
||||||
|
|
Loading…
Reference in New Issue