package eu.dnetlib.pace.model

import eu.dnetlib.pace.config.{DedupConfig, Type}
import eu.dnetlib.pace.util.{BlockProcessor, SparkReporter}
import org.apache.spark.SparkContext
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.expressions._
import org.apache.spark.sql.functions.{col, desc, expr, lit, udf}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, Dataset, Row, SaveMode, functions}

import java.util.function.Predicate
import java.util.stream.Collectors
import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
import scala.collection.mutable

case class SparkDeduper(conf: DedupConfig) extends Serializable {

  val model: SparkModel = SparkModel(conf)

  // Deduplication pipeline: clean up blacklisted values, group records into clustering blocks,
  // then compare the records within each block to produce candidate duplicate relations.
  val dedup: (Dataset[Row] => Dataset[Row]) = df => {
    df.transform(filterAndCleanup)
      .transform(generateClustersWithCollect)
      .transform(processBlocks)
  }
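
  // Usage sketch (illustrative, not part of the original source): the input Dataset[Row] is
  // expected to match model.schema, and the output holds one "relation" row per candidate
  // duplicate pair found by the comparisons below, e.g.:
  //
  //   val deduper   = SparkDeduper(dedupConf)   // dedupConf: a parsed DedupConfig
  //   val relations = deduper.dedup(inputDf)    // inputDf: Dataset[Row] shaped like model.schema
  //   relations.select("relation").show(false)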

  val filterAndCleanup: (Dataset[Row] => Dataset[Row]) = df => {
    val df_with_filters = conf.getPace.getModel.asScala.foldLeft(df)((res, fdef) => {
      if (conf.blacklists.containsKey(fdef.getName)) {
        res.withColumn(
          fdef.getName + "_filtered",
          filterColumnUDF(fdef).apply(new Column(fdef.getName))
        )
      } else {
        res
      }
    })

    df_with_filters
  }
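
  // Note: the filtered value does not replace the original column; a sibling "<field>_filtered"
  // column is added (e.g. a blacklisted "title" field gains a "title_filtered" twin), and the
  // clustering step below prefers the filtered variant whenever one exists.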

  def filterColumnUDF(fdef: FieldDef): UserDefinedFunction = {
    val blacklist: Predicate[String] = conf.blacklists().get(fdef.getName)

    if (blacklist == null) {
      throw new IllegalArgumentException("Column: " + fdef.getName + " does not have any filter")
    } else {
      fdef.getType match {
        case Type.List | Type.JSON =>
          udf[Array[String], Array[String]](values => {
            values.filter((v: String) => !blacklist.test(v))
          })

        case _ =>
          udf[String, String](v => {
            if (blacklist.test(v)) ""
            else v
          })
      }
    }
  }
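
  // For illustration (assumed blacklist): with a predicate matching the literal value "Untitled",
  // a list field ["Untitled", "Deep Learning"] becomes ["Deep Learning"], while a scalar field
  // holding "Untitled" is blanked to "".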

  val generateClustersWithCollect: (Dataset[Row] => Dataset[Row]) = df_with_filters => {
    var df_with_clustering_keys: Dataset[Row] = null

    for ((cd, idx) <- conf.clusterings().zipWithIndex) {
      val inputColumns = cd.getFields().foldLeft(Seq[Column]())((acc, fName) => {
        val column = if (conf.blacklists.containsKey(fName))
          Seq(col(fName + "_filtered"))
        else
          Seq(col(fName))

        acc ++ column
      })

      // Add 'key' column with the value generated by the given clustering definition
      val ds: Dataset[Row] = df_with_filters
        .withColumn("clustering", lit(cd.getName + "::" + idx))
        .withColumn("key", functions.explode(clusterValuesUDF(cd).apply(functions.array(inputColumns: _*))))
        // Add 'position' column with the position of the row within the set of rows sharing the same key, ordered by the sorting value
        .withColumn("position", functions.row_number().over(Window.partitionBy("key").orderBy(col(model.orderingFieldName), col(model.identifierFieldName))))
        // .withColumn("count", functions.max("position").over(Window.partitionBy("key").orderBy(col(model.orderingFieldName), col(model.identifierFieldName)).rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)))
        // .filter("count > 1")

      if (df_with_clustering_keys == null)
        df_with_clustering_keys = ds
      else
        df_with_clustering_keys = df_with_clustering_keys.union(ds)
    }

    // TODO: analytics
    /*df_with_clustering_keys.groupBy(col("clustering"), col("key"))
      .agg(expr("max(count) AS size"))
      .orderBy(desc("size"))
      .show*/

    val df_with_blocks = df_with_clustering_keys
      // split each clustering block into smaller blocks of at most queueMaxSize elements
      .groupBy(col("clustering"), col("key"), functions.floor(col("position").divide(lit(conf.getWf.getQueueMaxSize))))
      .agg(functions.collect_set(functions.struct(model.schema.fieldNames.map(col): _*)).as("block"))
      .filter(functions.size(new Column("block")).gt(1))
      .union(
        // adjacency blocks
        df_with_clustering_keys
          // filter out leading and trailing elements
          .filter(col("position").gt(conf.getWf.getSlidingWindowSize / 2))
          //.filter(col("position").lt(col("count").minus(conf.getWf.getSlidingWindowSize/2)))
          // create small blocks of records on "the border" of queueMaxSize: slidingWindowSize/2 elements before and after the boundary
          .filter(
            col("position").mod(conf.getWf.getQueueMaxSize).lt(conf.getWf.getSlidingWindowSize / 2) // slice at the start of the block
              || col("position").mod(conf.getWf.getQueueMaxSize).gt(conf.getWf.getQueueMaxSize - (conf.getWf.getSlidingWindowSize / 2)) // slice at the end of the block
          )
          .groupBy(col("clustering"), col("key"), functions.floor((col("position") + lit(conf.getWf.getSlidingWindowSize / 2)).divide(lit(conf.getWf.getQueueMaxSize))))
          .agg(functions.collect_set(functions.struct(model.schema.fieldNames.map(col): _*)).as("block"))
          .filter(functions.size(new Column("block")).gt(1))
      )

    df_with_blocks
  }
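
  // Worked example with illustrative numbers: with queueMaxSize = 100 and slidingWindowSize = 20,
  // a key holding 250 rows yields primary blocks for positions 1-99, 100-199 and 200-250
  // (grouped by floor(position / 100)), plus adjacency blocks straddling each boundary
  // (positions 91-109 and 191-209), so records near a block border still get compared.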

  def clusterValuesUDF(cd: ClusteringDef) = {
    udf[mutable.WrappedArray[String], mutable.WrappedArray[Any]](values => {
      val valueList = values.flatMap {
        case a: mutable.WrappedArray[Any] => a.map(_.toString)
        case s: Any => Seq(s.toString)
      }.asJava

      mutable.WrappedArray.make(cd.clusteringFunction().apply(conf, valueList).toArray())
    })
  }
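
  // The UDF flattens the selected field values (scalars and lists alike) into a single list of
  // strings and hands it to the configured clustering function, which may return several keys
  // per record; the explode() above then fans the record out into one row per key.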

  val processBlocks: (Dataset[Row] => Dataset[Row]) = df => {
    df.filter(functions.size(new Column("block")).geq(new Literal(2, DataTypes.IntegerType)))
      .withColumn("relations", processBlock(df.sqlContext.sparkContext).apply(new Column("block")))
      .select(functions.explode(new Column("relations")).as("relation"))
  }
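
  // Each exploded "relation" is a pair of record identifiers (the Array[(String, String)]
  // returned by the UDF below), i.e. one row per candidate duplicate pair found in a block.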

  def processBlock(implicit sc: SparkContext) = {
    val accumulators = SparkReporter.constructAccumulator(conf, sc)

    udf[Array[(String, String)], mutable.WrappedArray[Row]](block => {
      val reporter = new SparkReporter(accumulators)

      val mapDocuments = block.asJava.stream()
        .sorted(new RowDataOrderingComparator(model.orderingFieldPosition, model.identityFieldPosition))
        .limit(conf.getWf.getQueueMaxSize)
        .collect(Collectors.toList[Row]())

      new BlockProcessor(conf, model.identityFieldPosition, model.orderingFieldPosition).processSortedRows(mapDocuments, reporter)

      reporter.getRelations.asScala.toArray
    }).asNondeterministic()
  }
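
  // Note: the comparison UDF is flagged asNondeterministic(), presumably because it updates the
  // Spark accumulators behind SparkReporter as a side effect and its invocations should not be
  // duplicated, reordered or optimized away by Catalyst.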
}