package eu.dnetlib.pace.model

import eu.dnetlib.pace.config.{DedupConfig, Type}
import eu.dnetlib.pace.util.{BlockProcessor, SparkReporter}
import org.apache.spark.SparkContext
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.expressions._
import org.apache.spark.sql.functions.{col, lit, udf}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, Dataset, Row, functions}

import java.util.function.Predicate
import java.util.stream.Collectors
import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
import scala.collection.mutable

case class SparkDeduper(conf: DedupConfig) extends Serializable {

  val model: SparkModel = SparkModel(conf)

  // Full dedup pipeline: blacklist filtering -> clustering/blocking -> pairwise comparison within blocks
  val dedup: (Dataset[Row] => Dataset[Row]) = df => {
    df.transform(filterAndCleanup)
      .transform(generateClustersWithCollect)
      .transform(processBlocks)
  }

  // For every field with a configured blacklist, add a "<field>_filtered" column with blacklisted values removed
  val filterAndCleanup: (Dataset[Row] => Dataset[Row]) = df => {
    val df_with_filters = conf.getPace.getModel.asScala.foldLeft(df)((res, fdef) => {
      if (conf.blacklists.containsKey(fdef.getName)) {
        res.withColumn(
          fdef.getName + "_filtered",
          filterColumnUDF(fdef).apply(new Column(fdef.getName))
        )
      } else {
        res
      }
    })

    df_with_filters
  }

  def filterColumnUDF(fdef: FieldDef): UserDefinedFunction = {
    val blacklist: Predicate[String] = conf.blacklists().get(fdef.getName)

    if (blacklist == null) {
      throw new IllegalArgumentException("Column: " + fdef.getName + " does not have any filter")
    } else {
      fdef.getType match {
        case Type.List | Type.JSON =>
          udf[Array[String], Array[String]](values => {
            values.filter((v: String) => !blacklist.test(v))
          })
        case _ =>
          udf[String, String](v => {
            if (blacklist.test(v)) "" else v
          })
      }
    }
  }

  val generateClustersWithCollect: (Dataset[Row] => Dataset[Row]) = df_with_filters => {
    var df_with_clustering_keys: Dataset[Row] = null

    for ((cd, idx) <- conf.clusterings().zipWithIndex) {
      val inputColumns = cd.getFields().foldLeft(Seq[Column]())((acc, fName) => {
        val column = if (conf.blacklists.containsKey(fName))
          Seq(col(fName + "_filtered"))
        else
          Seq(col(fName))

        acc ++ column
      })

      // Add 'key' column with the value generated by the given clustering definition
      val ds: Dataset[Row] = df_with_filters
        .withColumn("clustering", lit(cd.getName + "::" + idx))
        .withColumn("key", functions.explode(clusterValuesUDF(cd).apply(functions.array(inputColumns: _*))))
        // Add position column having the position of the row within the set of rows having the same key value ordered by the sorting value
        .withColumn("position", functions.row_number().over(Window.partitionBy("key").orderBy(col(model.orderingFieldName), col(model.identifierFieldName))))

      if (df_with_clustering_keys == null)
        df_with_clustering_keys = ds
      else
        df_with_clustering_keys = df_with_clustering_keys.union(ds)
    }

    //TODO: analytics

    val df_with_blocks = df_with_clustering_keys
      // filter out rows with position exceeding the maxqueuesize parameter
      .filter(col("position").leq(conf.getWf.getQueueMaxSize))
      .groupBy("clustering", "key")
      .agg(functions.collect_set(functions.struct(model.schema.fieldNames.map(col): _*)).as("block"))
      .filter(functions.size(new Column("block")).gt(1))

    df_with_blocks
  }

  // Apply the clustering function of the given definition to every input value, producing the blocking keys
  def clusterValuesUDF(cd: ClusteringDef) = {
    udf[mutable.WrappedArray[String], mutable.WrappedArray[Any]](values => {
      values.flatMap(f => cd.clusteringFunction().apply(conf, Seq(f.toString).asJava).asScala)
    })
  }

  // Run the pairwise comparisons inside each block and explode the resulting similarity relations
  val processBlocks: (Dataset[Row] => Dataset[Row]) = df => {
    df.filter(functions.size(new Column("block")).geq(new Literal(2, DataTypes.IntegerType)))
      .withColumn("relations", processBlock(df.sqlContext.sparkContext).apply(new Column("block")))
      .select(functions.explode(new Column("relations")).as("relation"))
  }

  def processBlock(implicit sc: SparkContext) = {
    val accumulators = SparkReporter.constructAccumulator(conf, sc)

    udf[Array[(String, String)], mutable.WrappedArray[Row]](block => {
      val reporter = new SparkReporter(accumulators)

      // Sort the block by ordering field (identifier as tie-breaker) and cap it at the configured queue size
      val mapDocuments = block.asJava.stream()
        .sorted(new RowDataOrderingComparator(model.orderingFieldPosition, model.identityFieldPosition))
        .limit(conf.getWf.getQueueMaxSize)
        .collect(Collectors.toList[Row]())

      new BlockProcessor(conf, model.identityFieldPosition, model.orderingFieldPosition)
        .processSortedRows(mapDocuments, reporter)

      reporter.getRelations.asScala.toArray
    }).asNondeterministic()
  }
}
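
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of the upstream module). It shows
// how SparkDeduper might be driven end to end. Assumptions: `DedupConfig.load`
// is the JSON-based factory used elsewhere in the pace module, the input rows
// already conform to `model.schema`, and the app name, `confJson` argument and
// input path are placeholders.
// ---------------------------------------------------------------------------
object SparkDeduperUsageSketch {

  import org.apache.spark.sql.SparkSession

  def run(confJson: String, entitiesPath: String): Unit = {
    val spark = SparkSession.builder().appName("pace-dedup-sketch").getOrCreate()

    // Parse the dedup configuration and build the deduper
    val conf = DedupConfig.load(confJson)
    val deduper = SparkDeduper(conf)

    // Input rows are assumed to already match the model schema (e.g. pre-extracted JSON)
    val entities = spark.read.schema(deduper.model.schema).json(entitiesPath)

    // Each output row carries a "relation" struct holding a pair of matching identifiers
    val relations = deduper.dedup(entities)
    relations.select("relation._1", "relation._2").show(false)
  }
}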