package eu.dnetlib.jobs.featureextraction.lda;

import com.clearspring.analytics.util.Lists;
import eu.dnetlib.featureextraction.Utilities;
import eu.dnetlib.jobs.AbstractSparkJob;
import eu.dnetlib.support.ArgumentApplicationParser;
import eu.dnetlib.support.Author;
import eu.dnetlib.support.AuthorsFactory;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
import org.codehaus.jackson.map.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

/**
 * Spark job that performs a threshold analysis for author disambiguation:
 * authors are grouped by their LNFI key (last name + first initial, see
 * {@link AuthorsFactory#getLNFI}), every pair inside a group is scored with the
 * cosine similarity of their "lda_topics" embeddings, and ORCID equality is used
 * as ground truth. For each candidate threshold in [0.0, 1.0] the confusion-matrix
 * counts (tp/fp/tn/fn) are written as a CSV to {@code workingPath/threshold_analysis.csv}.
 */
public class SparkLDAAnalysis extends AbstractSparkJob {

    private static final Logger log = LoggerFactory.getLogger(SparkLDAAnalysis.class);

    // ObjectMapper is thread-safe and costly to construct: build it once instead of
    // once per input record inside the map() lambda.
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    public SparkLDAAnalysis(ArgumentApplicationParser parser, SparkSession spark) {
        super(parser, spark);
    }

    public static void main(String[] args) throws Exception {
        ArgumentApplicationParser parser = new ArgumentApplicationParser(
                // resolve the spec against this class (was SparkLDATuning.class — copy/paste
                // leftover; the path is absolute so resolution is unchanged, just consistent)
                readResource("/jobs/parameters/ldaAnalysis_parameters.json", SparkLDAAnalysis.class)
        );

        parser.parseArgument(args);

        SparkConf conf = new SparkConf();

        new SparkLDAAnalysis(
                parser,
                getSparkSession(conf)
        ).run();
    }

    @Override
    protected void run() throws IOException {

        // read oozie parameters
        final String authorsPath = parser.get("authorsPath");
        final String workingPath = parser.get("workingPath");
        final int numPartitions = Optional
                .ofNullable(parser.get("numPartitions"))
                .map(Integer::valueOf)
                .orElse(NUM_PARTITIONS);

        log.info("authorsPath: '{}'", authorsPath);
        log.info("workingPath: '{}'", workingPath);
        log.info("numPartitions: '{}'", numPartitions);

        JavaSparkContext context = JavaSparkContext.fromSparkContext(spark.sparkContext());

        JavaRDD<Author> authors = context
                .textFile(authorsPath)
                .map(s -> OBJECT_MAPPER.readValue(s, Author.class))
                .filter(a -> !a.getOrcid().isEmpty()); //don't need authors without orcid for the threshold analysis

        // expand each author to (lnfi-key, author), group by key, and score all pairs in a group
        JavaRDD<Tuple2<Boolean, Double>> groundTruthThreshold = authors
                .mapToPair(a -> new Tuple2<>(AuthorsFactory.getLNFI(a), a))
                .flatMapToPair(a -> a._1().stream().map(k -> new Tuple2<>(k, a._2())).collect(Collectors.toList()).iterator())
                .groupByKey()
                .flatMap(a -> thresholdAnalysis(a._2()));

        // cache both score RDDs: each is scanned once for count() and twice per threshold below,
        // which would otherwise recompute the whole grouping pipeline every time
        JavaDoubleRDD groundTruthTrue = groundTruthThreshold.filter(Tuple2::_1).mapToDouble(Tuple2::_2).cache();
        long totalPositives = groundTruthTrue.count();
        JavaDoubleRDD groundTruthFalse = groundTruthThreshold.filter(x -> !x._1()).mapToDouble(Tuple2::_2).cache();
        long totalNegatives = groundTruthFalse.count();

        double[] thresholds = new double[]{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0};

        List<String> stats = new ArrayList<>();
        stats.add("th,fp,fn,tp,tn,total_positives,total_negatives");
        for (double threshold : thresholds) {
            // a pair is predicted "same author" when its similarity reaches the threshold
            long truePositive = groundTruthTrue.filter(d -> d >= threshold).count();
            long falsePositive = groundTruthFalse.filter(d -> d >= threshold).count();
            long trueNegative = groundTruthFalse.filter(d -> d < threshold).count();
            long falseNegative = groundTruthTrue.filter(d -> d < threshold).count();

            stats.add(threshold + "," + falsePositive + "," + falseNegative + "," + truePositive + "," + trueNegative + "," + totalPositives + "," + totalNegatives);
        }

        Utilities.writeLinesToHDFSFile(stats, workingPath + "/threshold_analysis.csv");
    }

    /**
     * Scores every unordered pair of authors in the group: the boolean is the ground
     * truth (same non-empty ORCID), the double is the cosine similarity of the two
     * "lda_topics" embedding vectors.
     *
     * @param a authors sharing the same LNFI key
     * @return (sameOrcid, similarity) tuples, one per distinct pair
     */
    public Iterator<Tuple2<Boolean, Double>> thresholdAnalysis(Iterable<Author> a) {
        List<Author> authors = Lists.newArrayList(a);
        List<Tuple2<Boolean, Double>> results = new ArrayList<>();

        for (int i = 0; i < authors.size(); i++) {
            for (int j = i + 1; j < authors.size(); j++) {
                // defensive: upstream filtering should have removed empty ORCIDs already,
                // but an empty ORCID must never count as a ground-truth match
                boolean bRes;
                if (authors.get(i).getOrcid().isEmpty() || authors.get(j).getOrcid().isEmpty()) {
                    bRes = false;
                } else {
                    bRes = authors.get(i).getOrcid().equals(authors.get(j).getOrcid());
                }
                results.add(new Tuple2<>(
                        bRes,
                        cosineSimilarity(
                                authors.get(i).getEmbeddings().get("lda_topics"),
                                authors.get(j).getEmbeddings().get("lda_topics"))));
            }
        }
        return results.iterator();
    }

    /**
     * Cosine similarity of two equal-length vectors: dot(a,b) / (||a|| * ||b||).
     *
     * @param a first vector
     * @param b second vector (assumed same length as {@code a} — TODO confirm upstream)
     * @return similarity in [-1, 1]; 0.0 when either vector has zero norm (instead of NaN)
     */
    double cosineSimilarity(double[] a, double[] b) {
        double dotProduct = 0;
        double normASum = 0;
        double normBSum = 0;

        for (int i = 0; i < a.length; i++) {
            dotProduct += a[i] * b[i];
            normASum += a[i] * a[i];
            normBSum += b[i] * b[i];
        }

        // product of the two L2 norms (the original name "eucledianDist" was a misnomer)
        double normProduct = Math.sqrt(normASum) * Math.sqrt(normBSum);
        if (normProduct == 0.0) {
            // undefined for a zero vector; report no similarity rather than NaN,
            // which would silently fall through every ">= threshold" filter above
            return 0.0;
        }
        return dotProduct / normProduct;
    }
}