package eu.dnetlib.dhp.broker.oa;

import org.apache.commons.io.IOUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.function.FilterFunction;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.util.LongAccumulator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import eu.dnetlib.dhp.application.ArgumentApplicationParser;
import eu.dnetlib.dhp.broker.model.Event;
import eu.dnetlib.dhp.broker.oa.util.ClusterUtils;
import scala.Tuple2;

/**
 * Spark job that checks the generated broker events for duplicated identifiers:
 * it counts the occurrences of each eventId and saves the ids that appear more
 * than once, together with their counts, as gzipped JSON under outputDir/counts.
 */
public class CheckDuplictedIdsJob {

    private static final Logger log = LoggerFactory.getLogger(CheckDuplictedIdsJob.class);

    public static void main(final String[] args) throws Exception {
        final ArgumentApplicationParser parser = new ArgumentApplicationParser(
            IOUtils
                .toString(
                    CheckDuplictedIdsJob.class
                        .getResourceAsStream("/eu/dnetlib/dhp/broker/oa/check_duplicates.json")));
        parser.parseArgument(args);

        final SparkConf conf = new SparkConf();

        final String eventsPath = parser.get("outputDir") + "/events";
        log.info("eventsPath: {}", eventsPath);

        final String countPath = parser.get("outputDir") + "/counts";
        log.info("countPath: {}", countPath);

        final SparkSession spark = SparkSession.builder().config(conf).getOrCreate();

        // accumulator tracking how many duplicated event ids were found
        final LongAccumulator total = spark.sparkContext().longAccumulator("invalid_event_id");

        final Encoder<Tuple2<String, Long>> encoder = Encoders.tuple(Encoders.STRING(), Encoders.LONG());

        ClusterUtils
            .readPath(spark, eventsPath, Event.class)
            // map each event to (eventId, 1)
            .map((MapFunction<Event, Tuple2<String, Long>>) e -> new Tuple2<>(e.getEventId(), 1L), encoder)
            // group by eventId and sum the counters
            .groupByKey((MapFunction<Tuple2<String, Long>, String>) t -> t._1, Encoders.STRING())
            .agg(new CountAggregator().toColumn())
            .map((MapFunction<Tuple2<String, Tuple2<String, Long>>, Tuple2<String, Long>>) t -> t._2, encoder)
            // keep only the ids that occur more than once
            .filter((FilterFunction<Tuple2<String, Long>>) t -> t._2 > 1)
            // bump the accumulator for each duplicated id, passing the record through
            .map(
                (MapFunction<Tuple2<String, Long>, Tuple2<String, Long>>) o -> ClusterUtils
                    .incrementAccumulator(o, total),
                encoder)
            .write()
            .mode(SaveMode.Overwrite)
            .option("compression", "gzip")
            .json(countPath);
    }
}

/**
 * Typed aggregator that sums the per-id counters: the buffer keeps the eventId
 * (the first non-blank value seen) and the running total of its occurrences.
 */
class CountAggregator extends Aggregator<Tuple2<String, Long>, Tuple2<String, Long>, Tuple2<String, Long>> {

    private static final long serialVersionUID = 1395935985734672538L;

    @Override
    public Encoder<Tuple2<String, Long>> bufferEncoder() {
        return Encoders.tuple(Encoders.STRING(), Encoders.LONG());
    }

    @Override
    public Tuple2<String, Long> finish(final Tuple2<String, Long> arg0) {
        return arg0;
    }

    @Override
    public Tuple2<String, Long> merge(final Tuple2<String, Long> arg0, final Tuple2<String, Long> arg1) {
        return doMerge(arg0, arg1);
    }

    @Override
    public Encoder<Tuple2<String, Long>> outputEncoder() {
        return Encoders.tuple(Encoders.STRING(), Encoders.LONG());
    }

    @Override
    public Tuple2<String, Long> reduce(final Tuple2<String, Long> arg0, final Tuple2<String, Long> arg1) {
        return doMerge(arg0, arg1);
    }

    private Tuple2<String, Long> doMerge(final Tuple2<String, Long> arg0, final Tuple2<String, Long> arg1) {
        final String s = StringUtils.defaultIfBlank(arg0._1, arg1._1);
        return new Tuple2<>(s, arg0._2 + arg1._2);
    }

    @Override
    public Tuple2<String, Long> zero() {
        return new Tuple2<>(null, 0L);
    }
}
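
/*
 * For reference, a minimal sketch of the parameter descriptor this job loads from
 * "/eu/dnetlib/dhp/broker/oa/check_duplicates.json". This assumes the JSON shape
 * commonly consumed by ArgumentApplicationParser (paramName / paramLongName /
 * paramDescription / paramRequired); the actual resource in the repository may
 * declare additional parameters:
 *
 * [
 *   {
 *     "paramName": "o",
 *     "paramLongName": "outputDir",
 *     "paramDescription": "the directory containing the events and where the counts are written",
 *     "paramRequired": true
 *   }
 * ]
 */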
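
/*
 * The two ClusterUtils helpers used above are external to this file. As a rough,
 * hypothetical sketch of the behaviour the pipeline relies on (the real
 * implementations live in eu.dnetlib.dhp.broker.oa.util.ClusterUtils and may differ):
 *
 *   public static <R> Dataset<R> readPath(SparkSession spark, String path, Class<R> clazz) {
 *       // read a JSON-lines dataset and map each record to the given bean class
 *       return spark.read().textFile(path).map(
 *           (MapFunction<String, R>) s -> new ObjectMapper().readValue(s, clazz),
 *           Encoders.bean(clazz));
 *   }
 *
 *   public static <T> T incrementAccumulator(T o, LongAccumulator acc) {
 *       // bump the accumulator (if set) and pass the record through unchanged
 *       if (acc != null) { acc.add(1); }
 *       return o;
 *   }
 */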
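
/*
 * Example submission (illustrative only: the jar name and paths are placeholders,
 * not taken from the repository):
 *
 *   spark-submit \
 *     --class eu.dnetlib.dhp.broker.oa.CheckDuplictedIdsJob \
 *     dhp-broker-events.jar \
 *     --outputDir /path/to/broker/output
 *
 * The job reads the events from /path/to/broker/output/events; the event ids with
 * count > 1 are then available as gzipped JSON under /path/to/broker/output/counts.
 */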