/*
 * dnet-dedup/dnet-dedup-test/src/main/java/eu/dnetlib/jobs/SparkComputeStatistics.java
 *
 * 208 lines
 * 9.5 KiB
 * Java
 */

package eu.dnetlib.jobs;
import eu.dnetlib.Deduper;
import eu.dnetlib.pace.config.DedupConfig;
import eu.dnetlib.pace.config.Type;
import eu.dnetlib.pace.model.FieldValueImpl;
import eu.dnetlib.pace.model.MapDocument;
import eu.dnetlib.pace.util.MapDocumentUtil;
import eu.dnetlib.pace.utils.Utility;
import eu.dnetlib.support.ArgumentApplicationParser;
import eu.dnetlib.support.Block;
import eu.dnetlib.support.ConnectedComponent;
import eu.dnetlib.support.Relation;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;
import org.codehaus.jackson.map.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;
import java.io.IOException;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
/**
 * Spark job that computes summary statistics about a deduplication run
 * (number of entities, blocks, similarity relations, merge relations and groups,
 * plus a pair-based "Rand index" for blocks and groups), prints them to stdout and
 * writes them to {@code <workingPath>/stats_file.txt}.
 *
 * <p>The quality figures rely on a ground-truth label read from each entity through the
 * JPath given in the {@code groundTruthFieldJPath} parameter.
 */
public class SparkComputeStatistics extends AbstractSparkJob {

    private static final Logger log = LoggerFactory.getLogger(SparkComputeStatistics.class);

    /**
     * JSON mapper shared by all records processed in a JVM. Being a static field it is not
     * captured by the Spark closures, and it replaces the per-record instantiation.
     */
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    public SparkComputeStatistics(ArgumentApplicationParser parser, SparkSession spark) {
        super(parser, spark);
    }

    /**
     * Command line entry point: parses the job arguments and runs the job.
     *
     * @param args the job arguments, as described by computeStatistics_parameters.json
     * @throws Exception if argument parsing or the job itself fails
     */
    public static void main(String[] args) throws Exception {
        // NOTE(review): the resource is resolved through SparkCreateSimRels.class rather than
        // this class - presumably equivalent for an absolute resource path; confirm.
        ArgumentApplicationParser parser = new ArgumentApplicationParser(
                readResource("/jobs/parameters/computeStatistics_parameters.json", eu.dnetlib.jobs.SparkCreateSimRels.class)
        );

        parser.parseArgument(args);

        SparkConf conf = new SparkConf();

        new SparkComputeStatistics(parser, getSparkSession(conf)).run();
    }

    /**
     * Reads the entities and the outputs of the previous dedup steps (mergerels, simrels,
     * groupentities) from the working path, computes the statistics, prints them and
     * stores them in {@code <workingPath>/stats_file.txt}.
     *
     * @throws IOException if the statistics file cannot be written
     */
    @Override
    public void run() throws IOException {

        //https://towardsdatascience.com/7-evaluation-metrics-for-clustering-algorithms-bdc537ff54d2#:~:text=There%20are%20two%20types%20of,to%20all%20unsupervised%20learning%20results)
        // read oozie parameters
        final String entitiesPath = parser.get("entitiesPath");
        final String workingPath = parser.get("workingPath");
        final String dedupConfPath = parser.get("dedupConfPath");
        final String groundTruthFieldJPath = parser.get("groundTruthFieldJPath");
        final int numPartitions = Optional
                .ofNullable(parser.get("numPartitions"))
                .map(Integer::valueOf)
                .orElse(NUM_PARTITIONS);

        log.info("entitiesPath: '{}'", entitiesPath);
        log.info("workingPath: '{}'", workingPath);
        log.info("numPartitions: '{}'", numPartitions);
        log.info("dedupConfPath: '{}'", dedupConfPath);
        log.info("groundTruthFieldJPath: '{}'", groundTruthFieldJPath);

        JavaSparkContext sc = JavaSparkContext.fromSparkContext(spark.sparkContext());
        DedupConfig dedupConfig = loadDedupConfig(dedupConfPath);

        JavaPairRDD<String, MapDocument> mapDocuments = sc
                .textFile(entitiesPath)
                .repartition(numPartitions)
                .mapToPair(
                        (PairFunction<String, String, MapDocument>) s -> {
                            MapDocument d = MapDocumentUtil.asMapDocumentWithJPath(dedupConfig, s);
                            //put in the map the groundTruthField used to compute statistics
                            d.getFieldMap().put("groundTruth", new FieldValueImpl(Type.String, "groundTruth", MapDocumentUtil.getJPathString(groundTruthFieldJPath, s)));
                            return new Tuple2<>(d.getIdentifier(), d);
                        });

        // ground-truth label of every entity
        JavaRDD<String> entities = mapDocuments.map(d -> d._2().getFieldMap().get("groundTruth").stringValue());

        // create blocks: each block is represented by the ground-truth labels of its documents
        JavaRDD<List<String>> blocks = Deduper.createSortedBlocks(mapDocuments, dedupConfig)
                .map(b -> b._2().getDocuments().stream().map(d -> d.getFieldMap().get("groundTruth").stringValue()).collect(Collectors.toList()));

        // <source, target>: source is the dedup_id, target is the id of the mergedIn
        JavaRDD<Relation> mergerels = spark
                .read()
                .load(workingPath + "/mergerels")
                .as(Encoders.bean(Relation.class))
                .toJavaRDD();

        JavaRDD<Relation> simrels = spark
                .read()
                .load(workingPath + "/simrels")
                .as(Encoders.bean(Relation.class))
                .toJavaRDD();

        // each group is represented by the ground-truth labels of its documents
        JavaRDD<List<String>> groups = sc.textFile(workingPath + "/groupentities")
                .map(e -> OBJECT_MAPPER.readValue(e, ConnectedComponent.class))
                .map(e -> e.getDocs().stream().map(d -> MapDocumentUtil.getJPathString(groundTruthFieldJPath, d)).collect(Collectors.toList()));

        // each of these RDDs feeds more than one action below: mark them for caching so
        // that the lineage (parsing, blocking) is not recomputed for every action
        entities.cache();
        blocks.cache();
        groups.cache();

        long entitiesNumber = entities.count();
        long blocksNumber = blocks.count();
        double blocksRandIndex = randIndex(blocks);
        long simrelsNumber = simrels.count();
        long mergerelsNumber = mergerels.count();
        double groupsRandIndex = randIndex(groups);
        long groupsNumber = groups.count();
        long groundtruthNumber = entities.filter(e -> !e.isEmpty()).count();
        // a group is "correct" when all its documents share a single ground-truth label
        long correctGroups = groups.filter(x -> x.stream().distinct().count() == 1).count();
        long wrongGroups = groupsNumber - correctGroups;

        String print =
                "Entities : " + entitiesNumber + "\n" +
                        "Ground Truth : " + groundtruthNumber + "\n" +
                        "Blocks : " + blocksNumber + "\n" +
                        "Blocks RI : " + blocksRandIndex + "\n" +
                        "SimRels : " + simrelsNumber + "\n" +
                        "MergeRels : " + mergerelsNumber + "\n" +
                        "Groups : " + groupsNumber + " (correct: " + correctGroups + ", wrong: " + wrongGroups + ")\n" +
                        "Groups RI : " + groupsRandIndex;

        System.out.println(print);

        writeStatsFileToHDFS(groundtruthNumber, entitiesNumber, blocksRandIndex, groupsRandIndex, blocksNumber, simrelsNumber, mergerelsNumber, groupsNumber, workingPath + "/stats_file.txt");
    }

    /**
     * Writes the statistics as a plain text file, replacing any file already present at
     * {@code filePath}.
     *
     * @param groundtruth_number number of entities having a non-empty ground-truth label
     * @param entities_number    total number of entities
     * @param blocks_randIndex   value of {@link #randIndex} computed on the blocks
     * @param groups_randIndex   value of {@link #randIndex} computed on the groups
     * @param blocks_number      number of blocks
     * @param simrels_number     number of similarity relations
     * @param mergerels_number   number of merge relations
     * @param groups_number      number of groups
     * @param filePath           path of the file to write
     * @throws IOException if the previous file cannot be removed or the new one cannot be
     *                     written; failures are propagated instead of being only printed,
     *                     so that the job does not report success without its output
     */
    public static void writeStatsFileToHDFS(long groundtruth_number, long entities_number, double blocks_randIndex, double groups_randIndex, long blocks_number, long simrels_number, long mergerels_number, long groups_number, String filePath) throws IOException {
        // the FileSystem handle is intentionally not closed here: FileSystem.get may return
        // an instance shared with the rest of the application
        FileSystem fs = FileSystem.get(new Configuration());
        Path outFile = new Path(filePath);

        fs.delete(outFile, true);

        // Verification: the deletion above must have left the target path free
        if (fs.exists(outFile)) {
            throw new IOException("Output file already exists");
        }

        String print =
                "Entities : " + entities_number + "\n" +
                        "Ground Truth : " + groundtruth_number + "\n" +
                        "Blocks : " + blocks_number + "\n" +
                        "Blocks RI : " + blocks_randIndex + "\n" +
                        "SimRels : " + simrels_number + "\n" +
                        "MergeRels : " + mergerels_number + "\n" +
                        "Groups : " + groups_number + "\n" +
                        "Groups RI : " + groups_randIndex;

        // Create file to write; the stream is closed even when the write fails
        try (FSDataOutputStream out = fs.create(outFile)) {
            out.writeBytes(print);
        }
    }

    //TODO find another measure that takes into account all the elements outside of the group too
    /**
     * RandIndex = number of pairwise correct predictions / total number of possible pairs
     * (in the same cluster) -> bounded between 0 and 1.
     *
     * <p>For every cluster, the pairs of elements sharing the same non-empty label are counted
     * as correct; the total is the number of pairs that can be formed with all the elements
     * of the cluster. Both figures are summed over all the clusters using {@code long}
     * arithmetic, since pair counts quickly exceed the {@code int} range.
     *
     * @param clusters the clusters, each one given as the list of labels of its elements
     * @return the ratio between correct and possible pairs; {@code NaN} when no pair can be
     *         formed at all (no clusters, or only clusters with fewer than two elements)
     */
    public double randIndex(JavaRDD<List<String>> clusters) {
        // fold (instead of reduce) so that an empty RDD yields (0, 0) rather than an exception
        Tuple2<Long, Long> totals = clusters
                .map(SparkComputeStatistics::pairCounts)
                .fold(new Tuple2<>(0L, 0L), (a, b) -> new Tuple2<>(a._1() + b._1(), a._2() + b._2()));

        long correctPairs = totals._1();
        long possiblePairs = totals._2();

        return (double) correctPairs / possiblePairs;
    }

    /**
     * Counts, for a single cluster, the pairs of elements sharing the same non-empty label
     * (first value) and all the pairs that can be formed in the cluster (second value).
     */
    private static Tuple2<Long, Long> pairCounts(List<String> cluster) {
        // one pass to count the occurrences of each non-empty label
        long correctPairs = cluster
                .stream()
                .filter(label -> !label.isEmpty())
                .collect(Collectors.groupingBy(label -> label, Collectors.counting()))
                .values()
                .stream()
                .mapToLong(SparkComputeStatistics::binomialCoefficient)
                .sum();

        long possiblePairs = binomialCoefficient(cluster.size());

        return new Tuple2<>(correctPairs, possiblePairs);
    }

    /**
     * Number of unordered pairs that can be formed with {@code n} elements, i.e. C(n, 2).
     */
    private static long binomialCoefficient(long n) {
        return n * (n - 1) / 2;
    }

    //V-measure = harmonic mean of homogeneity and completeness, homogeneity = each cluster contains only members of a single class, completeness = all members of a given class are assigned to the same cluster
}