132 lines
4.7 KiB
Scala
132 lines
4.7 KiB
Scala
package eu.dnetlib.pace.model
|
|
|
|
import eu.dnetlib.pace.config.{DedupConfig, Type}
|
|
import eu.dnetlib.pace.util.{BlockProcessor, SparkReporter}
|
|
import org.apache.spark.SparkContext
|
|
import org.apache.spark.sql.catalyst.expressions.Literal
|
|
import org.apache.spark.sql.expressions._
|
|
import org.apache.spark.sql.functions.{col, lit, udf}
|
|
import org.apache.spark.sql.types._
|
|
import org.apache.spark.sql.{Column, Dataset, Row, functions}
|
|
|
|
import java.util.function.Predicate
|
|
import java.util.stream.Collectors
|
|
import scala.collection.JavaConversions._
|
|
import scala.collection.JavaConverters._
|
|
import scala.collection.mutable
|
|
/** Spark pipeline that turns a dataset of records into candidate relations,
  * driven by a [[DedupConfig]].
  *
  * The pipeline ([[dedup]]) is made of three dataset transformations applied in order:
  * [[filterAndCleanup]], [[generateClustersWithCollect]] and [[processBlocks]].
  *
  * @param conf dedup configuration (model fields, blacklists, clusterings, workflow settings)
  */
case class SparkDeduper(conf: DedupConfig) extends Serializable {

  /** Model built from the configuration; provides the schema and the ordering/identifier fields. */
  val model: SparkModel = SparkModel(conf)

  /** Full pipeline: filter blacklisted values, build the blocks, then process every block.
    * The resulting dataset has a single column named `relation`.
    */
  val dedup: (Dataset[Row] => Dataset[Row]) = df => {
    df.transform(filterAndCleanup)
      .transform(generateClustersWithCollect)
      .transform(processBlocks)
  }

  /** Adds a `<field>_filtered` column for every model field that has a blacklist configured.
    * Fields without a blacklist are left untouched and the original columns are always kept.
    */
  val filterAndCleanup: (Dataset[Row] => Dataset[Row]) = df => {
    conf.getPace.getModel.asScala
      .filter(fdef => conf.blacklists.containsKey(fdef.getName))
      .foldLeft(df) { (res, fdef) =>
        res.withColumn(
          fdef.getName + "_filtered",
          filterColumnUDF(fdef).apply(new Column(fdef.getName))
        )
      }
  }

  /** Builds the UDF that applies the blacklist of the given field.
    *
    * For `Type.List` / `Type.JSON` fields the UDF drops the blacklisted elements of the array;
    * for every other type it replaces a blacklisted value with the empty string.
    * Null cells (and null array elements) are passed through unchanged instead of being
    * handed to the predicate.
    *
    * @param fdef definition of the field to filter
    * @return the filtering UDF for the field
    * @throws IllegalArgumentException if no blacklist is configured for the field
    */
  def filterColumnUDF(fdef: FieldDef): UserDefinedFunction = {
    val blacklist: Predicate[String] = Option(conf.blacklists().get(fdef.getName))
      .getOrElse(throw new IllegalArgumentException(s"Column: ${fdef.getName} does not have any filter"))

    fdef.getType match {
      case Type.List | Type.JSON =>
        udf[Array[String], Array[String]](values => {
          // Spark passes null for a null cell: keep it null rather than failing on filter
          Option(values)
            .map(_.filter((v: String) => v == null || !blacklist.test(v)))
            .orNull
        })

      case _ =>
        udf[String, String](v => {
          Option(v)
            .map(s => if (blacklist.test(s)) "" else s)
            .orNull
        })
    }
  }

  /** Generates the blocks of candidate duplicates.
    *
    * For every clustering definition the input rows are exploded on the keys produced by
    * [[clusterValuesUDF]] and numbered within each key; the per-clustering datasets are then
    * unioned, truncated to the configured queue size and grouped by (`clustering`, `key`).
    * Only blocks with more than one element are returned, in a column named `block`.
    *
    * Throws `IllegalArgumentException` if the configuration defines no clustering.
    */
  val generateClustersWithCollect: (Dataset[Row] => Dataset[Row]) = df_with_filters => {
    val keyedByClustering: Seq[Dataset[Row]] = conf.clusterings().asScala.zipWithIndex.map { case (cd, idx) =>
      // Blacklisted fields are read from the "_filtered" column added by filterAndCleanup
      val inputColumns: Seq[Column] = cd.getFields().asScala.map { fName =>
        if (conf.blacklists.containsKey(fName)) col(fName + "_filtered")
        else col(fName)
      }.toSeq

      df_with_filters
        .withColumn("clustering", lit(cd.getName + "::" + idx))
        // Add 'key' column with the value generated by the given clustering definition
        .withColumn("key", functions.explode(clusterValuesUDF(cd).apply(functions.array(inputColumns: _*))))
        // Add position column having the position of the row within the set of rows having the same key value ordered by the sorting value
        .withColumn(
          "position",
          functions.row_number().over(
            Window.partitionBy("key").orderBy(col(model.orderingFieldName), col(model.identifierFieldName))
          )
        )
    }.toSeq

    // Union in definition order; fail with a clear message when there is nothing to union
    val df_with_clustering_keys: Dataset[Row] = keyedByClustering
      .reduceOption(_.union(_))
      .getOrElse(throw new IllegalArgumentException("No clustering definition found in the dedup configuration"))

    //TODO: analytics

    df_with_clustering_keys
      // filter out rows with position exceeding the maxqueuesize parameter
      .filter(col("position").leq(conf.getWf.getQueueMaxSize))
      .groupBy("clustering", "key")
      .agg(functions.collect_set(functions.struct(model.schema.fieldNames.map(col): _*)).as("block"))
      .filter(functions.size(new Column("block")).gt(1))
  }

  /** Builds the UDF that computes the clustering keys of a row for one clustering definition.
    *
    * The UDF receives the array of input field values, applies the clustering function to the
    * string form of each value and returns all the produced keys. Null values are skipped,
    * so a missing field contributes no key.
    *
    * @param cd clustering definition providing the clustering function
    * @return the key-generating UDF
    */
  def clusterValuesUDF(cd: ClusteringDef): UserDefinedFunction = {
    udf[mutable.WrappedArray[String], mutable.WrappedArray[Any]](values => {
      values
        .filter(_ != null) // array(...) carries a null element for every null input column
        .flatMap(f => cd.clusteringFunction().apply(conf, Seq(f.toString).asJava).asScala)
    })
  }

  /** Runs [[processBlock]] on every block with at least two elements and returns one row per
    * produced relation, in a column named `relation`.
    */
  val processBlocks: (Dataset[Row] => Dataset[Row]) = df => {
    df.filter(functions.size(new Column("block")).geq(2))
      .withColumn("relations", processBlock(df.sqlContext.sparkContext).apply(new Column("block")))
      .select(functions.explode(new Column("relations")).as("relation"))
  }

  /** Builds the UDF that compares the rows of a single block.
    *
    * The rows are sorted with [[RowDataOrderingComparator]], truncated to the configured queue
    * size and handed to a [[BlockProcessor]]; the relations collected by the reporter are
    * returned as an array of string pairs.
    *
    * @param sc Spark context used to construct the reporter accumulators
    * @return the block-processing UDF
    */
  def processBlock(implicit sc: SparkContext): UserDefinedFunction = {
    val accumulators = SparkReporter.constructAccumulator(conf, sc)

    udf[Array[(String, String)], mutable.WrappedArray[Row]](block => {
      val reporter = new SparkReporter(accumulators)

      val mapDocuments = block.asJava.stream()
        .sorted(new RowDataOrderingComparator(model.orderingFieldPosition, model.identityFieldPosition))
        .limit(conf.getWf.getQueueMaxSize)
        .collect(Collectors.toList[Row]())

      new BlockProcessor(conf, model.identityFieldPosition, model.orderingFieldPosition).processSortedRows(mapDocuments, reporter)

      reporter.getRelations.asScala.toArray
      // NOTE(review): presumably marked non-deterministic so Spark does not re-evaluate the UDF
      // (the reporter is built on accumulators) - confirm
    }).asNondeterministic()
  }

}
|