package eu.dnetlib.pace.model

import com.jayway.jsonpath.{Configuration, JsonPath, Option}
import eu.dnetlib.pace.config.{ClusteringDef, DedupConfig, FieldDef, Type}
import eu.dnetlib.pace.tree.support.TreeProcessor
import eu.dnetlib.pace.util.MapDocumentUtil.truncateValue
import eu.dnetlib.pace.util.{BlockProcessor, MapDocumentUtil, SparkReporter}
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD.rddToPairRDDFunctions
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.{GenericRowWithSchema, Literal}
import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction, UserDefinedFunction, Window}
import org.apache.spark.sql.functions.{col, lit, udf}
import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, Metadata, StructField, StructType}
import org.apache.spark.sql.{Column, Dataset, Encoder, Encoders, Row, functions}

import java.util
import java.util.function.Predicate
import java.util.regex.Pattern
import java.util.stream.Collectors
import scala.collection.JavaConverters._
import scala.collection.mutable

case class SparkDedupConfig(conf: DedupConfig, numPartitions: Int) extends Serializable {

  private val URL_REGEX: Pattern = Pattern.compile("^\\s*(http|https|ftp)\\://.*")
  private val CONCAT_REGEX: Pattern = Pattern.compile("\\|\\|\\|")

  private val urlFilter = (s: String) => URL_REGEX.matcher(s).matches

  val modelExtractor: (Dataset[String] => Dataset[Row]) = df => {
    // Earlier UDF-based variant, kept for reference; its result was discarded, so it is commented out:
    // df.withColumn("mapDocument", rowFromJsonUDF.apply(df.col(df.columns(0))))
    //   .withColumn("identifier", new Column("mapDocument.identifier"))
    //   //.repartition(new Column("identifier"))
    //   .dropDuplicates("identifier")
    //   .select("mapDocument.*")

    df.map(r => rowFromJson(r))(RowEncoder(rowDataType))
      .dropDuplicates("identifier")
  }

  val generateClusters: (Dataset[Row] => Dataset[Row]) = df => {
    // filter blacklisted values
    val df_with_filters = conf.getPace.getModel.asScala.foldLeft(df)((res, fdef) => {
      if (conf.blacklists.containsKey(fdef.getName)) {
        res.withColumn(
          fdef.getName + "_filtered",
          filterColumnUDF(fdef).apply(new Column(fdef.getName))
        )
      } else {
        res
      }
    })

    // create one column per cluster prefix
    val df_with_keys = conf
      .clusterings()
      .asScala
      .foldLeft(df_with_filters)((res, cd) => {
        res.withColumn(
          cd.getName + "_clustered",
          functions.explode_outer(
            clusterValuesUDF(cd).apply(
              functions.array(
                cd.getFields.asScala
                  .map(f => res.col(if (conf.blacklists.containsKey(f)) f.concat("_filtered") else f)): _*
              )
            )
          )
        )
      })

    // GROUPING SETS approach
    val tempTable = this.getClass.getSimpleName + "__generateClusters"
    df_with_keys.createOrReplaceTempView(tempTable)

    val keys = conf.clusterings().asScala.map(_.getName + "_clustered").mkString(",")
    val fields = rowDataType.fieldNames.mkString(",")

    // Using SQL because GROUPING SETS are not available through the Scala/Java DSL
    df_with_keys.sqlContext.sql(
      "SELECT coalesce(" + keys + ") as key, sort_array(collect_sort_slice(" + fields + ")) as block" +
        " FROM " + tempTable +
        " WHERE coalesce(" + keys + ") IS NOT NULL" +
        " GROUP BY GROUPING SETS (" + keys + ")" +
        " HAVING size(block) > 1"
    )
  }
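  /*
   * Illustrative usage sketch (not part of the production flow; the SparkSession `spark`, the input
   * path and the configuration string are assumptions). modelExtractor and the generateClusters*
   * variants are plain Dataset functions, so they compose with andThen:
   *
   * {{{
   * val dedupConf = DedupConfig.load(jsonConfigString) // jsonConfigString: a PACE configuration
   * val sdc       = SparkDedupConfig(dedupConf, numPartitions = 100)
   * val entities  = spark.read.textFile("/path/to/entities") // one JSON document per line
   * val blocks    = sdc.modelExtractor.andThen(sdc.generateClustersWithWindows).apply(entities)
   * }}}
   */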
  val generateClustersWithDFAPI: (Dataset[Row] => Dataset[Row]) = df => {
    System.out.println(conf.getWf.getEntityType + "::" + conf.getWf.getSubEntityType)

    val df_with_filters = conf.getPace.getModel.asScala.foldLeft(df)((res, fdef) => {
      if (conf.blacklists.containsKey(fdef.getName)) {
        res.withColumn(
          fdef.getName + "_filtered",
          filterColumnUDF(fdef).apply(new Column(fdef.getName))
        )
      } else {
        res
      }
    })

    var relBlocks: Dataset[Row] = null

    import scala.collection.JavaConversions._
    for (cd <- conf.clusterings()) {
      val columns: util.List[Column] = new util.ArrayList[Column](cd.getFields().size)
      for (fName <- cd.getFields()) {
        if (conf.blacklists.containsKey(fName))
          columns.add(new Column(fName + "_filtered"))
        else
          columns.add(new Column(fName))
      }

      val tmp: Dataset[Row] = df_with_filters.withColumn(
        "key",
        functions.explode(clusterValuesUDF(cd).apply(functions.array(columns.asScala: _*)))
      )
      /*.select((Seq(rowDataType.fieldNames: _*) ++ Seq("key")).map(col): _*)
        .groupByKey(r => r.getAs[String]("key"))(Encoders.STRING)
        .agg(collectSortSliceAggregator.toColumn)
        .toDF("key", "block")
        .select(col("block.block").as("block"))*/

      System.out.println(cd.getName)

      val ds = tmp.groupBy("key")
        // .agg(functions.sort_array(collectSortSliceUDAF(rowDataType.fieldNames.map(col): _*)).as("block"))
        .agg(functions.collect_set(functions.struct(rowDataType.fieldNames.map(col): _*)).as("block"))
      //.filter(functions.size(new Column("block")).geq(new Literal(2, DataTypes.IntegerType)))

      //df_with_filters.printSchema()
      //ds.printSchema()

      if (relBlocks == null) relBlocks = ds
      else relBlocks = relBlocks.union(ds)
    }

    relBlocks
  }

  val generateClustersWithWindows: (Dataset[Row] => Dataset[Row]) = df => {
    val df_with_filters = conf.getPace.getModel.asScala.foldLeft(df)((res, fdef) => {
      if (conf.blacklists.containsKey(fdef.getName)) {
        res.withColumn(
          fdef.getName + "_filtered",
          filterColumnUDF(fdef).apply(new Column(fdef.getName))
        )
      } else {
        res
      }
    })

    var relBlocks: Dataset[Row] = null

    import scala.collection.JavaConversions._
    for (cd <- conf.clusterings()) {
      System.out.println(conf.getWf.getEntityType + "::" + conf.getWf.getSubEntityType + ": " + cd.getName + " " + cd.toString)

      val columns: util.List[Column] = new util.ArrayList[Column](cd.getFields().size)
      for (fName <- cd.getFields()) {
        if (conf.blacklists.containsKey(fName))
          columns.add(new Column(fName + "_filtered"))
        else
          columns.add(new Column(fName))
      }

      // Add a 'key' column with the values generated by the given clustering definition
      val ds: Dataset[Row] = df_with_filters
        .withColumn("key", functions.explode(clusterValuesUDF(cd).apply(functions.array(columns.asScala: _*))))
        // Add a 'position' column holding the rank of the row within the set of rows sharing the
        // same key value, ordered by the sorting field
        .withColumn("position", functions.row_number().over(Window.partitionBy("key").orderBy(col(conf.getWf.getOrderField))))
        // Filter out rows whose position exceeds the queueMaxSize parameter
        .filter(col("position").leq(conf.getWf.getQueueMaxSize))
        .groupBy("key")
        .agg(functions.collect_set(functions.struct(rowDataType.fieldNames.map(col): _*)).as("block"))
        .filter(functions.size(new Column("block")).geq(new Literal(2, DataTypes.IntegerType)))

      if (relBlocks == null) relBlocks = ds
      else relBlocks = relBlocks.union(ds)
    }

    relBlocks
  }
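  /*
   * A minimal, self-contained sketch of the windowing idea used above, on toy data (column names
   * and values are invented; toDF requires spark.implicits._):
   *
   * {{{
   * val toy = Seq(("k1", "b"), ("k1", "a"), ("k1", "c"), ("k2", "x")).toDF("key", "title")
   * val ranked = toy.withColumn(
   *   "position",
   *   functions.row_number().over(Window.partitionBy("key").orderBy(col("title"))))
   * ranked.filter(col("position").leq(2)).show()
   * // with queueMaxSize = 2, "k1" keeps "a" and "b" while "c" is dropped; "k2" keeps "x"
   * }}}
   */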
  val generateClustersWithDFAPIMerged: (Dataset[Row] => Dataset[Row]) = df => {
    val df_with_filters = conf.getPace.getModel.asScala.foldLeft(df)((res, fdef) => {
      if (conf.blacklists.containsKey(fdef.getName)) {
        res.withColumn(
          fdef.getName + "_filtered",
          filterColumnUDF(fdef).apply(new Column(fdef.getName))
        )
      } else {
        res
      }
    })

    import scala.collection.JavaConversions._
    // Merge the keys produced by all clustering definitions into a single array column
    val keys = conf.clusterings().foldLeft(null: Column)((res, cd) => {
      val columns: util.List[Column] = new util.ArrayList[Column](cd.getFields().size)
      for (fName <- cd.getFields()) {
        if (conf.blacklists.containsKey(fName))
          columns.add(new Column(fName + "_filtered"))
        else
          columns.add(new Column(fName))
      }

      if (res != null)
        functions.array_union(res, clusterValuesUDF(cd).apply(functions.array(columns.asScala: _*)))
      else
        clusterValuesUDF(cd).apply(functions.array(columns.asScala: _*))
    })

    val ds: Dataset[Row] = df_with_filters
      .withColumn("key", functions.explode(keys))
      .select((Seq(rowDataType.fieldNames: _*) ++ Seq("key")).map(col): _*)
      .groupByKey(r => r.getAs[String]("key"))(Encoders.STRING)
      .agg(collectSortSliceAggregator.toColumn)
      .toDF("key", "block")
      .select(col("block.block").as("block"))
      /*.groupBy("key")
        .agg(collectSortSliceUDAF(rowDataType.fieldNames.map(col): _*).as("block"))*/
      .filter(functions.size(new Column("block")).geq(new Literal(2, DataTypes.IntegerType)))

    ds
  }

  val generateClustersWithRDDReduction: (Dataset[Row] => Dataset[Row]) = df => {
    val df_with_filters = conf.getPace.getModel.asScala.foldLeft(df)((res, fdef) => {
      if (conf.blacklists.containsKey(fdef.getName)) {
        res.withColumn(
          fdef.getName + "_filtered",
          filterColumnUDF(fdef).apply(new Column(fdef.getName))
        )
      } else {
        res
      }
    })

    var relBlocks: Dataset[Row] = null

    import scala.collection.JavaConversions._
    for (cd <- conf.clusterings()) {
      val columns: util.List[Column] = new util.ArrayList[Column](cd.getFields().size)
      for (fName <- cd.getFields()) {
        if (conf.blacklists.containsKey(fName))
          columns.add(new Column(fName + "_filtered"))
        else
          columns.add(new Column(fName))
      }

      val ds: Dataset[Row] = df.sparkSession.createDataFrame(
        df_with_filters
          .withColumn("key", functions.explode(clusterValuesUDF(cd).apply(functions.array(columns.asScala: _*))))
          .select(col("key"), functions.array(functions.struct(rowDataType.fieldNames.map(col): _*).as("value")))
          .rdd
          .keyBy(_.getString(0))
          .reduceByKey((a, b) => {
            // merge two partial blocks; sort by the ordering field and cap at queueMaxSize when needed
            val b1 = a.getSeq[Row](1)
            val b2 = b.getSeq[Row](1)
            if (b1.size + b2.size > conf.getWf.getQueueMaxSize)
              Row(a.get(0), b1.union(b2).sortBy(_.getString(orderingFieldPosition)).slice(0, conf.getWf.getQueueMaxSize))
            else
              Row(a.get(0), b1.union(b2))
          })
          .map(_._2)
          .filter(k => k.getSeq(1).size > 1),
        new StructType()
          .add(StructField("key", DataTypes.StringType))
          .add(StructField("block", ArrayType(rowDataType)))
      )

      if (relBlocks == null) relBlocks = ds
      else relBlocks = relBlocks.union(ds)
    }

    relBlocks
  }
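  /*
   * The reduceByKey above maintains one invariant: a block never grows beyond queueMaxSize, and when
   * it would, the merged block is sorted by the ordering field and truncated. The same invariant on
   * plain Scala collections (a hypothetical helper, for illustration only):
   *
   * {{{
   * def mergeCapped(a: Seq[String], b: Seq[String], max: Int): Seq[String] = {
   *   val merged = a ++ b
   *   if (merged.size > max) merged.sorted.take(max) else merged
   * }
   * mergeCapped(Seq("b", "d"), Seq("a", "c"), 3) // Seq("a", "b", "c")
   * }}}
   */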
  val printAnalytics: (Dataset[Row] => Dataset[Row]) = df => {
    val df_with_filters = conf.getPace.getModel.asScala.foldLeft(df)((res, fdef) => {
      if (conf.blacklists.containsKey(fdef.getName)) {
        res.withColumn(
          fdef.getName + "_filtered",
          filterColumnUDF(fdef).apply(new Column(fdef.getName))
        )
      } else {
        res
      }
    })

    var relBlocks: Dataset[Row] = null

    import scala.collection.JavaConversions._
    for (cd <- conf.clusterings()) {
      val columns: util.List[Column] = new util.ArrayList[Column](cd.getFields().size)
      for (fName <- cd.getFields()) {
        if (conf.blacklists.containsKey(fName))
          columns.add(new Column(fName + "_filtered"))
        else
          columns.add(new Column(fName))
      }

      // Add a 'key' column with the values generated by the given clustering definition
      val ds: Dataset[Row] = df_with_filters
        .withColumn("key", functions.explode(clusterValuesUDF(cd).apply(functions.array(columns.asScala: _*))))
        // Add a 'position' column holding the rank of the row within the set of rows sharing the
        // same key value, ordered by the sorting field
        .withColumn("position", functions.row_number().over(Window.partitionBy("key").orderBy(conf.getWf.getOrderField)))
        // Filter out rows whose position exceeds the queueMaxSize parameter
        .filter(col("position").lt(conf.getWf.getQueueMaxSize))

      // Inner self-join computing all combinations of rows to compare. Note the condition on
      // position, which yields the 'windowing': a given row is compared at most with the next
      // slidingWindowSize rows in the sort order
      val dsWithMatch = ds.as("l").join(ds.as("r"), col("l.key").equalTo(col("r.key")), "inner")
        .filter((col("l.position").lt(col("r.position")))
          && (col("r.position").lt(col("l.position").plus(lit(conf.getWf.getSlidingWindowSize)))))

      // dsWithMatch.show(false)

      if (relBlocks == null) relBlocks = dsWithMatch
      else relBlocks = relBlocks.union(dsWithMatch)
    }

    System.out.println(conf.getWf.getEntityType + "::" + conf.getWf.getSubEntityType)
    System.out.println("Total number of comparisons: " + relBlocks.count())

    df
  }

  val generateAndProcessClustersWithJoins: (Dataset[Row] => Dataset[Row]) = df => {
    val df_with_filters = conf.getPace.getModel.asScala.foldLeft(df)((res, fdef) => {
      if (conf.blacklists.containsKey(fdef.getName)) {
        res.withColumn(
          fdef.getName + "_filtered",
          filterColumnUDF(fdef).apply(new Column(fdef.getName))
        )
      } else {
        res
      }
    })

    var relBlocks: Dataset[Row] = null

    import scala.collection.JavaConversions._
    for (cd <- conf.clusterings()) {
      val columns: util.List[Column] = new util.ArrayList[Column](cd.getFields().size)
      for (fName <- cd.getFields()) {
        if (conf.blacklists.containsKey(fName))
          columns.add(new Column(fName + "_filtered"))
        else
          columns.add(new Column(fName))
      }

      // Add a 'key' column with the values generated by the given clustering definition
      val ds: Dataset[Row] = df_with_filters
        .withColumn("key", functions.explode(clusterValuesUDF(cd).apply(functions.array(columns.asScala: _*))))
        // Add a 'position' column holding the rank of the row within the set of rows sharing the
        // same key value, ordered by the sorting field
        .withColumn("position", functions.row_number().over(Window.partitionBy("key").orderBy(conf.getWf.getOrderField)))
        // Filter out rows whose position exceeds the queueMaxSize parameter
        .filter(col("position").lt(conf.getWf.getQueueMaxSize))

      // Inner self-join computing all combinations of rows to compare. Note the condition on
      // position, which yields the 'windowing': a given row is compared at most with the next
      // slidingWindowSize rows in the sort order
      val dsWithMatch = ds.as("l").join(ds.as("r"), col("l.key").equalTo(col("r.key")), "inner")
        .filter((col("l.position").lt(col("r.position")))
          && (col("r.position").lt(col("l.position").plus(lit(conf.getWf.getSlidingWindowSize)))))
        // Add a 'match' column with the result of the comparison
        .withColumn("match", udf[Boolean, Row, Row]((a, b) => {
          val treeProcessor = new TreeProcessor(conf)
          treeProcessor.compare(a, b)
        }).apply(
          functions.struct(rowDataType.fieldNames.map(s => col("l.".concat(s))): _*),
          functions.struct(rowDataType.fieldNames.map(s => col("r.".concat(s))): _*)
        ))
        .filter(col("match").equalTo(true))
        .select(col("l.identifier").as("from"), col("r.identifier").as("to"))

      // dsWithMatch.show(false)

      if (relBlocks == null) relBlocks = dsWithMatch
      else relBlocks = relBlocks.union(dsWithMatch)
    }

    val res = relBlocks
      //.select(col("l.identifier").as("from"), col("r.identifier").as("to"))
      //.repartition()
      .distinct()

    // res.show(false)
    res.select(functions.struct("from", "to"))
  }
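  /*
   * Why the join condition bounds the work: with l.position < r.position < l.position + slidingWindowSize,
   * the row at position p is only compared with the rows at positions p+1 .. p+slidingWindowSize-1 of
   * the same block. A block of n rows therefore yields at most n * (slidingWindowSize - 1) comparisons
   * instead of the n * (n - 1) / 2 of a full pairwise cross-comparison.
   */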
  val processClusters: (Dataset[Row] => Dataset[Row]) = df => {
    val entity = conf.getWf.getEntityType

    df.filter(functions.size(new Column("block")).geq(new Literal(2, DataTypes.IntegerType)))
      .withColumn("relations", processBlock(df.sqlContext.sparkContext).apply(new Column("block")))
      .select(functions.explode(new Column("relations")).as("relation"))
      //.repartition(new Column("relation"))
      .dropDuplicates("relation")
  }

  val rowDataType: StructType = {
    // Alternative (commented out): build the unordered schema first, then put the ordering field
    // and the identifier in front:
    //
    // val unordered = conf.getPace.getModel.asScala.foldLeft(
    //   new StructType()
    // )((resType, fdef) => {
    //   resType.add(fdef.getType match {
    //     case Type.List | Type.JSON =>
    //       StructField(fdef.getName, DataTypes.createArrayType(DataTypes.StringType), true, Metadata.empty)
    //     case Type.DoubleArray =>
    //       StructField(fdef.getName, DataTypes.createArrayType(DataTypes.DoubleType), true, Metadata.empty)
    //     case _ =>
    //       StructField(fdef.getName, DataTypes.StringType, true, Metadata.empty)
    //   })
    // })
    //
    // conf.getPace.getModel.asScala.filterNot(_.getName.equals(conf.getWf.getOrderField)).foldLeft(
    //   new StructType()
    //     .add(unordered(conf.getWf.getOrderField))
    //     .add(StructField("identifier", DataTypes.StringType, false, Metadata.empty))
    // )((resType, fdef) => resType.add(unordered(fdef.getName)))

    val identifier = new FieldDef()
    identifier.setName("identifier")
    identifier.setType(Type.String)

    (conf.getPace.getModel.asScala ++ Seq(identifier)).sortBy(_.getName)
      .foldLeft(new StructType())((resType, fdef) => {
        resType.add(fdef.getType match {
          case Type.List | Type.JSON =>
            StructField(fdef.getName, DataTypes.createArrayType(DataTypes.StringType), true, Metadata.empty)
          case Type.DoubleArray =>
            StructField(fdef.getName, DataTypes.createArrayType(DataTypes.DoubleType), true, Metadata.empty)
          case _ =>
            StructField(fdef.getName, DataTypes.StringType, true, Metadata.empty)
        })
      })
  }

  val identityFieldPosition: Int = rowDataType.fieldIndex("identifier")
  val orderingFieldPosition: Int = rowDataType.fieldIndex(conf.getWf.getOrderField)

  def rowFromJson(json: String): Row = {
    val documentContext =
      JsonPath.using(Configuration.defaultConfiguration.addOptions(Option.SUPPRESS_EXCEPTIONS)).parse(json)
    val values = new Array[Any](rowDataType.size)

    values(identityFieldPosition) = MapDocumentUtil.getJPathString(conf.getWf.getIdPath, documentContext)

    rowDataType.fieldNames.zipWithIndex.foldLeft(values) {
      case (res, (fname, index)) =>
        val fdef = conf.getPace.getModelMap.get(fname)
        if (fdef != null) {
          res(index) = fdef.getType match {
            case Type.String | Type.Int =>
              MapDocumentUtil.truncateValue(
                MapDocumentUtil.getJPathString(fdef.getPath, documentContext),
                fdef.getLength
              )
            case Type.URL =>
              var uv = MapDocumentUtil.getJPathString(fdef.getPath, documentContext)
              if (!urlFilter(uv)) uv = ""
              uv
            case Type.List | Type.JSON =>
              MapDocumentUtil.truncateList(
                MapDocumentUtil.getJPathList(fdef.getPath, documentContext, fdef.getType),
                fdef.getSize
              ).toArray
            case Type.StringConcat =>
              val jpaths = CONCAT_REGEX.split(fdef.getPath)
              truncateValue(
                jpaths
                  .map(jpath => MapDocumentUtil.getJPathString(jpath, documentContext))
                  .mkString(" "),
                fdef.getLength
              )
            case Type.DoubleArray =>
              MapDocumentUtil.getJPathArray(fdef.getPath, json)
          }
        }
        res
    }

    new GenericRowWithSchema(values, rowDataType)
  }

  val rowFromJsonUDF = udf(rowFromJson(_), rowDataType)
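  /*
   * A hedged sketch of what rowFromJson produces, assuming a model containing a single String field
   * "title" and an idPath of "$.id" (both hypothetical; the actual shape is driven by the dedup
   * configuration):
   *
   * {{{
   * val row = rowFromJson("""{"id":"50|abc","title":"On Deduplication"}""")
   * row.getAs[String]("identifier") // "50|abc"
   * row.getAs[String]("title")      // "On Deduplication", truncated to the configured length
   * }}}
   */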
  def filterColumnUDF(fdef: FieldDef): UserDefinedFunction = {
    val blacklist: Predicate[String] = conf.blacklists().get(fdef.getName)

    if (blacklist == null) {
      throw new IllegalArgumentException("Column: " + fdef.getName + " does not have any filter")
    } else {
      fdef.getType match {
        case Type.List | Type.JSON =>
          udf[Array[String], Array[String]](values => {
            values.filter((v: String) => !blacklist.test(v))
          })
        case _ =>
          udf[String, String](v => if (blacklist.test(v)) "" else v)
      }
    }
  }

  def clusterValuesUDF(cd: ClusteringDef): UserDefinedFunction = {
    udf[mutable.WrappedArray[String], mutable.WrappedArray[Object]](values => {
      values
        .flatMap(f => cd.clusteringFunction().apply(conf, Seq(f.toString).asJava).asScala)
        .map(cd.getName.concat(_))
    })
  }

  def processBlock(implicit sc: SparkContext): UserDefinedFunction = {
    val accumulators = SparkReporter.constructAccumulator(conf, sc)

    udf[Array[Tuple2[String, String]], mutable.WrappedArray[Row]](block => {
      val reporter = new SparkReporter(accumulators)

      val mapDocuments = block.asJava.stream
        .sorted(new RowDataOrderingComparator(orderingFieldPosition))
        .limit(conf.getWf.getQueueMaxSize)
        .collect(Collectors.toList[Row]())

      new BlockProcessor(conf, identityFieldPosition, orderingFieldPosition).processSortedRows(mapDocuments, reporter)

      reporter.getRelations.asScala.toArray
    }).asNondeterministic()
  }

  val collectSortSliceAggregator: Aggregator[Row, Seq[Row], Row] = new Aggregator[Row, Seq[Row], Row]() {
    override def zero: Seq[Row] = Seq[Row]()

    override def reduce(buffer: Seq[Row], input: Row): Seq[Row] =
      merge(buffer, Seq(input))

    override def merge(buffer: Seq[Row], toMerge: Seq[Row]): Seq[Row] = {
      val newBlock = buffer ++ toMerge
      if (newBlock.size > conf.getWf.getQueueMaxSize)
        newBlock.sortBy(_.getString(orderingFieldPosition)).slice(0, conf.getWf.getQueueMaxSize)
      else
        newBlock
    }

    override def finish(reduction: Seq[Row]): Row = Row(reduction.toArray)

    override def bufferEncoder: Encoder[Seq[Row]] = Encoders.kryo[Seq[Row]]

    override def outputEncoder: Encoder[Row] =
      RowEncoder.apply(new StructType().add("block", DataTypes.createArrayType(rowDataType), nullable = true))
  }

  val collectSortSliceUDAF: UserDefinedAggregateFunction = new UserDefinedAggregateFunction {
    override def inputSchema: StructType = rowDataType

    override def bufferSchema: StructType =
      new StructType().add("block", DataTypes.createArrayType(rowDataType), nullable = true)

    override def dataType: DataType = DataTypes.createArrayType(rowDataType)

    override def deterministic: Boolean = true

    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = Seq[Row]()
    }

    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      val newBlock = buffer.getSeq[Row](0) ++ Seq(input)
      if (newBlock.size > conf.getWf.getQueueMaxSize)
        buffer(0) = newBlock.sortBy(_.getString(orderingFieldPosition)).slice(0, conf.getWf.getQueueMaxSize)
      else
        buffer(0) = newBlock
    }

    override def merge(buffer: MutableAggregationBuffer, row: Row): Unit = {
      val newBlock = buffer.getSeq[Row](0) ++ row.getSeq[Row](0)
      if (newBlock.size > conf.getWf.getQueueMaxSize)
        buffer(0) = newBlock.sortBy(_.getString(orderingFieldPosition)).slice(0, conf.getWf.getQueueMaxSize)
      else
        buffer(0) = newBlock
    }

    override def evaluate(buffer: Row): Any = buffer.getSeq[Row](0)
  }
}
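/*
 * Note on collect_sort_slice: the SQL emitted by generateClusters calls a function named
 * "collect_sort_slice", so collectSortSliceUDAF must be registered under that name before the query
 * runs. A minimal sketch, assuming a SparkSession `spark` (the registration site is not part of this
 * file):
 *
 * {{{
 * spark.udf.register("collect_sort_slice", sparkDedupConfig.collectSortSliceUDAF)
 * }}}
 */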