// dnet-and/dnet-and-test/src/main/java/eu/dnetlib/jobs/featureextraction/lda/SparkLDAAnalysis.java
package eu.dnetlib.jobs.featureextraction.lda;
import com.clearspring.analytics.util.Lists;
import eu.dnetlib.featureextraction.Utilities;
import eu.dnetlib.jobs.AbstractSparkJob;
import eu.dnetlib.support.ArgumentApplicationParser;
import eu.dnetlib.support.Author;
import eu.dnetlib.support.AuthorsFactory;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
import org.codehaus.jackson.map.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
/**
 * Spark job that evaluates how well the cosine similarity between authors' LDA topic
 * embeddings ({@code "lda_topics"}) separates same-person pairs (identical ORCID) from
 * different-person pairs, across a fixed grid of candidate thresholds. Writes a CSV of
 * confusion-matrix counts per threshold to {@code workingPath/threshold_analysis.csv}.
 */
public class SparkLDAAnalysis extends AbstractSparkJob {

    private static final Logger log = LoggerFactory.getLogger(SparkLDAAnalysis.class);

    /**
     * @param parser parsed job arguments (authorsPath, workingPath, numPartitions)
     * @param spark  active Spark session
     */
    public SparkLDAAnalysis(ArgumentApplicationParser parser, SparkSession spark) {
        super(parser, spark);
    }

    public static void main(String[] args) throws Exception {
        ArgumentApplicationParser parser = new ArgumentApplicationParser(
                // FIX: load the parameter spec relative to this class, not SparkLDATuning
                // (copy-paste inconsistency; the absolute resource path itself is unchanged)
                readResource("/jobs/parameters/ldaAnalysis_parameters.json", SparkLDAAnalysis.class)
        );

        parser.parseArgument(args);

        SparkConf conf = new SparkConf();

        new SparkLDAAnalysis(
                parser,
                getSparkSession(conf)
        ).run();
    }

    @Override
    protected void run() throws IOException {
        // read oozie parameters
        final String authorsPath = parser.get("authorsPath");
        final String workingPath = parser.get("workingPath");
        final int numPartitions = Optional
                .ofNullable(parser.get("numPartitions"))
                .map(Integer::valueOf)
                .orElse(NUM_PARTITIONS);

        log.info("authorsPath: '{}'", authorsPath);
        log.info("workingPath: '{}'", workingPath);
        log.info("numPartitions: '{}'", numPartitions);

        JavaSparkContext context = JavaSparkContext.fromSparkContext(spark.sparkContext());

        JavaRDD<Author> authors = context
                .textFile(authorsPath)
                .map(s -> new ObjectMapper().readValue(s, Author.class))
                .filter(a -> !a.getOrcid().isEmpty()); //don't need authors without orcid for the threshold analysis

        // group candidate authors under their LNFI blocking key(s), then emit
        // (sameOrcid, cosineSimilarity) for every unordered pair inside each block
        JavaRDD<Tuple2<Boolean, Double>> groundTruthThreshold = authors
                .mapToPair(a -> new Tuple2<>(AuthorsFactory.getLNFI(a), a))
                .flatMapToPair(a -> a._1().stream().map(k -> new Tuple2<>(k, a._2())).collect(Collectors.toList()).iterator())
                .groupByKey()
                .flatMap(a -> thresholdAnalysis(a._2()));

        // FIX: cache both similarity RDDs — each is scanned once for count() plus twice per
        // threshold below (11 thresholds => 23 actions each). Without caching, the whole
        // lineage (textFile, JSON parse, groupByKey, pairwise cosine) is recomputed per action.
        JavaDoubleRDD groundTruthTrue = groundTruthThreshold.filter(Tuple2::_1).mapToDouble(Tuple2::_2).cache();
        long totalPositives = groundTruthTrue.count();
        JavaDoubleRDD groundTruthFalse = groundTruthThreshold.filter(x -> !x._1()).mapToDouble(Tuple2::_2).cache();
        long totalNegatives = groundTruthFalse.count();

        double[] thresholds = new double[]{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0};

        List<String> stats = new ArrayList<>();
        stats.add("th,fp,fn,tp,tn,total_positives,total_negatives");
        for (double threshold : thresholds) {
            // NOTE: NaN similarities (zero-norm embeddings) fail both comparisons and are
            // silently excluded from every bucket; totals above still include them.
            long truePositive = groundTruthTrue.filter(d -> d >= threshold).count();
            long falsePositive = groundTruthFalse.filter(d -> d >= threshold).count();
            long trueNegative = groundTruthFalse.filter(d -> d < threshold).count();
            long falseNegative = groundTruthTrue.filter(d -> d < threshold).count();

            stats.add(threshold + "," + falsePositive + "," + falseNegative + "," + truePositive + "," + trueNegative + "," + totalPositives + "," + totalNegatives);
        }

        Utilities.writeLinesToHDFSFile(stats, workingPath + "/threshold_analysis.csv");
    }

    /**
     * Emits, for every unordered pair of authors in the group, a tuple of
     * (true iff both ORCIDs are non-empty and equal, cosine similarity of their
     * {@code "lda_topics"} embeddings).
     *
     * @param a authors sharing the same LNFI blocking key
     * @return iterator over (groundTruth, similarity) pairs; empty for groups of size < 2
     */
    public Iterator<Tuple2<Boolean, Double>> thresholdAnalysis(Iterable<Author> a) {
        List<Author> authors = Lists.newArrayList(a);
        List<Tuple2<Boolean, Double>> results = new ArrayList<>();

        // FIX: standard triangular pair enumeration replaces the fragile hand-rolled
        // while-loops whose inner index was initialized outside the outer loop
        for (int i = 0; i < authors.size(); i++) {
            for (int j = i + 1; j < authors.size(); j++) {
                boolean bRes;
                // defensive: upstream already filters empty ORCIDs, but keep the guard
                if (authors.get(i).getOrcid().isEmpty() || authors.get(j).getOrcid().isEmpty()) {
                    bRes = false;
                } else {
                    bRes = authors.get(i).getOrcid().equals(authors.get(j).getOrcid());
                }
                results.add(new Tuple2<>(bRes, cosineSimilarity(authors.get(i).getEmbeddings().get("lda_topics"), authors.get(j).getEmbeddings().get("lda_topics"))));
            }
        }
        return results.iterator();
    }

    /**
     * Cosine similarity between two vectors of equal length.
     * Returns NaN when either vector has zero norm (0/0); callers treat NaN as
     * falling outside every threshold bucket.
     *
     * @param a first vector
     * @param b second vector (assumed same length as {@code a} — TODO confirm upstream)
     * @return dot(a,b) / (||a|| * ||b||)
     */
    double cosineSimilarity(double[] a, double[] b) {
        double dotProduct = 0;
        double normASum = 0;
        double normBSum = 0;

        for (int i = 0; i < a.length; i++) {
            dotProduct += a[i] * b[i];
            normASum += a[i] * a[i];
            normBSum += b[i] * b[i];
        }

        // FIX: renamed from "eucledianDist" — this is the product of the two L2 norms,
        // not a euclidean distance
        double normProduct = Math.sqrt(normASum) * Math.sqrt(normBSum);
        return dotProduct / normProduct;
    }
}