ecological-engine/src/main/java/org/gcube/dataanalysis/ecoengine/evaluation/DistributionQualityAnalysis...

package org.gcube.dataanalysis.ecoengine.evaluation;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import org.gcube.contentmanagement.graphtools.utils.MathFunctions;
import org.gcube.dataanalysis.ecoengine.datatypes.ColumnType;
import org.gcube.dataanalysis.ecoengine.datatypes.DatabaseType;
import org.gcube.dataanalysis.ecoengine.datatypes.InputTable;
import org.gcube.dataanalysis.ecoengine.datatypes.PrimitiveType;
import org.gcube.dataanalysis.ecoengine.datatypes.StatisticalType;
import org.gcube.dataanalysis.ecoengine.datatypes.enumtypes.PrimitiveTypes;
import org.gcube.dataanalysis.ecoengine.datatypes.enumtypes.TableTemplates;
import org.gcube.dataanalysis.ecoengine.interfaces.DataAnalysis;
import org.gcube.dataanalysis.ecoengine.utils.DatabaseFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Attributes;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.example.table.AttributeFactory;
import com.rapidminer.example.table.BinominalMapping;
import com.rapidminer.example.table.DoubleArrayDataRow;
import com.rapidminer.example.table.MemoryExampleTable;
import com.rapidminer.tools.Ontology;
import com.rapidminer.tools.math.ROCData;
import com.rapidminer.tools.math.ROCDataGenerator;
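
/**
 * Evaluates the quality of a probability distribution model against tables of known positive
 * and negative cases. It derives the confusion matrix (true/false positives and negatives) using
 * user-defined acceptance and rejection thresholds, and computes AUC, accuracy, sensitivity,
 * omission rate, specificity and the best probability threshold via the RapidMiner ROC utilities.
 */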
public class DistributionQualityAnalysis extends DataAnalysis {

    private static Logger logger = LoggerFactory.getLogger(DistributionQualityAnalysis.class);

    // counts the distinct cases whose associated probability satisfies the given operator/threshold
    static String getProbabilititesQuery = "select count(*) from (select distinct * from %1$s as a join %2$s as b on a.%3$s=b.%4$s and b.%5$s %6$s %7$s) as aa";
    // counts the rows of a table
    static String getNumberOfElementsQuery = "select count(*) from %1$s";
    // retrieves the probability values associated to the cases of a table
    static String getValuesQuery = "select %5$s as distribprob from (select distinct * from %1$s as a join %2$s as b on a.%3$s=b.%4$s) as b";

    float threshold = 0.1f;
    String configPath = "./cfg/";
    float acceptanceThreshold = 0.8f;
    float rejectionThreshold = 0.3f;
    double bestThreshold = 0.5d;
    private LinkedHashMap<String, String> output;
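
    /**
     * Declares the input parameters: the positive and negative cases tables, the distribution
     * table with its key and probability columns, the acceptance/rejection thresholds and the
     * default database connection parameters.
     */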
    public List<StatisticalType> getInputParameters() {
        List<StatisticalType> parameters = new ArrayList<StatisticalType>();
        List<TableTemplates> templates = new ArrayList<TableTemplates>();
        templates.add(TableTemplates.HSPEC);
        templates.add(TableTemplates.TRAININGSET);
        templates.add(TableTemplates.TESTSET);
        List<TableTemplates> templatesOccurrences = new ArrayList<TableTemplates>();
        templatesOccurrences.add(TableTemplates.HCAF);
        InputTable p1 = new InputTable(templatesOccurrences, "PositiveCasesTable", "A Table containing positive cases");
        InputTable p2 = new InputTable(templatesOccurrences, "NegativeCasesTable", "A Table containing negative cases");
        InputTable p5 = new InputTable(templates, "DistributionTable", "A probability distribution table");
        ColumnType p3 = new ColumnType("PositiveCasesTable", "PositiveCasesTableKeyColumn", "Positive Cases Table Key Column", "csquarecode", false);
        ColumnType p4 = new ColumnType("NegativeCasesTable", "NegativeCasesTableKeyColumn", "Negative Cases Table Key Column", "csquarecode", false);
        ColumnType p6 = new ColumnType("DistributionTable", "DistributionTableKeyColumn", "Distribution Table Key Column", "csquarecode", false);
        ColumnType p7 = new ColumnType("DistributionTable", "DistributionTableProbabilityColumn", "Distribution Table Probability Column", "probability", false);
        PrimitiveType p8 = new PrimitiveType(String.class.getName(), null, PrimitiveTypes.STRING, "PositiveThreshold", "Positive acceptance threshold", "0.8");
        PrimitiveType p9 = new PrimitiveType(String.class.getName(), null, PrimitiveTypes.STRING, "NegativeThreshold", "Negative acceptance threshold", "0.3");
        parameters.add(p1);
        parameters.add(p2);
        parameters.add(p3);
        parameters.add(p4);
        parameters.add(p5);
        parameters.add(p6);
        parameters.add(p7);
        parameters.add(p8);
        parameters.add(p9);
        DatabaseType.addDefaultDBPars(parameters);
        return parameters;
    }
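
    /**
     * Names of the quality indicators produced by {@link #analyze()}.
     */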
    public List<String> getOutputParameters() {
        List<String> outputs = new ArrayList<String>();
        outputs.add("TRUE_POSITIVES");
        outputs.add("TRUE_NEGATIVES");
        outputs.add("FALSE_POSITIVES");
        outputs.add("FALSE_NEGATIVES");
        outputs.add("AUC");
        outputs.add("ACCURACY");
        outputs.add("SENSITIVITY");
        outputs.add("OMISSIONRATE");
        outputs.add("SPECIFICITY");
        outputs.add("BESTTHRESHOLD");
        return outputs;
    }
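
    // counts the number of rows (cases) in the given table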
    private int calculateNumberOfPoints(String table) {
        String numberOfPositiveCasesQuery = String.format(getNumberOfElementsQuery, table);
        List<Object> totalPoints = DatabaseFactory.executeSQLQuery(numberOfPositiveCasesQuery, connection);
        int points = Integer.parseInt("" + totalPoints.get(0));
        return points;
    }
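
    // counts the cases of casesTable whose probability in distributionTable satisfies "operator threshold" (e.g. "> 0.8")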
    private int calculateCaughtPoints(String casesTable, String distributionTable, String casesTableKeyColumn, String distributionTableKeyColumn, String distributionTableProbabilityColumn, String operator, String threshold) {
        String query = String.format(getProbabilititesQuery, casesTable, distributionTable, casesTableKeyColumn, distributionTableKeyColumn, distributionTableProbabilityColumn, operator, threshold);
        logger.trace("Compare - Query to perform for caught cases:" + query);
        List<Object> caughtpoints = DatabaseFactory.executeSQLQuery(query, connection);
        int points = Integer.parseInt("" + caughtpoints.get(0));
        return points;
    }
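
    // retrieves the probability values assigned to the cases of casesTable; missing values are left at 0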
    private double[] getPoints(String casesTable, String distributionTable, String casesTableKeyColumn, String distributionTableKeyColumn, String distributionTableProbabilityColumn, int numberOfExpectedPoints) {
        String query = String.format(getValuesQuery, casesTable, distributionTable, casesTableKeyColumn, distributionTableKeyColumn, distributionTableProbabilityColumn);
        logger.trace("Compare - Query to perform for case values:" + query);
        List<Object> caughtpoints = DatabaseFactory.executeSQLQuery(query, connection);
        int size = 0;
        if (caughtpoints != null)
            size = caughtpoints.size();
        double[] points = new double[numberOfExpectedPoints];
        for (int i = 0; i < size; i++) {
            double element = 0;
            if (caughtpoints.get(i) != null)
                element = Double.parseDouble("" + caughtpoints.get(i));
            points[i] = element;
        }
        return points;
    }
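
    /**
     * Runs the evaluation: reads the configuration parameters, computes the confusion matrix
     * using the acceptance and rejection thresholds, calculates AUC, accuracy, sensitivity,
     * omission rate and specificity, and returns the rounded indicators as a name-value map.
     */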
    public LinkedHashMap<String, String> analyze() throws Exception {
        try {
            acceptanceThreshold = Float.parseFloat(config.getParam("PositiveThreshold"));
        } catch (Exception e) {
            logger.debug("ERROR : " + e.getLocalizedMessage());
        }
        try {
            rejectionThreshold = Float.parseFloat(config.getParam("NegativeThreshold"));
        } catch (Exception e) {
            logger.debug("ERROR : " + e.getLocalizedMessage());
        }
        String positiveCasesTable = config.getParam("PositiveCasesTable");
        String negativeCasesTable = config.getParam("NegativeCasesTable");
        String distributionTable = config.getParam("DistributionTable");
        String positiveCasesTableKeyColumn = config.getParam("PositiveCasesTableKeyColumn");
        String negativeCasesTableKeyColumn = config.getParam("NegativeCasesTableKeyColumn");
        String distributionTableKeyColumn = config.getParam("DistributionTableKeyColumn");
        String distributionTableProbabilityColumn = config.getParam("DistributionTableProbabilityColumn");
        // string versions of the thresholds, injected directly into the SQL comparisons
        String acceptanceThreshold = config.getParam("PositiveThreshold");
        String rejectionThreshold = config.getParam("NegativeThreshold");
        // confusion matrix: positive cases above the acceptance threshold are true positives,
        // negative cases above the rejection threshold are false positives
        int numberOfPositiveCases = calculateNumberOfPoints(positiveCasesTable);
        int truePositives = calculateCaughtPoints(positiveCasesTable, distributionTable, positiveCasesTableKeyColumn, distributionTableKeyColumn, distributionTableProbabilityColumn, ">", acceptanceThreshold);
        int falseNegatives = numberOfPositiveCases - truePositives;
        int numberOfNegativeCases = calculateNumberOfPoints(negativeCasesTable);
        super.processedRecords = numberOfPositiveCases + numberOfNegativeCases;
        int falsePositives = calculateCaughtPoints(negativeCasesTable, distributionTable, negativeCasesTableKeyColumn, distributionTableKeyColumn, distributionTableProbabilityColumn, ">", rejectionThreshold);
        int trueNegatives = numberOfNegativeCases - falsePositives;
        double[] positivePoints = getPoints(positiveCasesTable, distributionTable, positiveCasesTableKeyColumn, distributionTableKeyColumn, distributionTableProbabilityColumn, numberOfPositiveCases);
        double[] negativePoints = getPoints(negativeCasesTable, distributionTable, negativeCasesTableKeyColumn, distributionTableKeyColumn, distributionTableProbabilityColumn, numberOfNegativeCases);
        double auc = calculateAUC(positivePoints, negativePoints, false);
        double accuracy = calculateAccuracy(truePositives, trueNegatives, falsePositives, falseNegatives);
        double sensitivity = calculateSensitivity(truePositives, falseNegatives);
        double omissionrate = calculateOmissionRate(truePositives, falseNegatives);
        double specificity = calculateSpecificity(trueNegatives, falsePositives);
        output = new LinkedHashMap<String, String>();
        output.put("TRUE_POSITIVES", "" + truePositives);
        output.put("TRUE_NEGATIVES", "" + trueNegatives);
        output.put("FALSE_POSITIVES", "" + falsePositives);
        output.put("FALSE_NEGATIVES", "" + falseNegatives);
        output.put("AUC", "" + MathFunctions.roundDecimal(auc, 2));
        output.put("ACCURACY", "" + MathFunctions.roundDecimal(accuracy, 2));
        output.put("SENSITIVITY", "" + MathFunctions.roundDecimal(sensitivity, 2));
        output.put("OMISSIONRATE", "" + MathFunctions.roundDecimal(omissionrate, 2));
        output.put("SPECIFICITY", "" + MathFunctions.roundDecimal(specificity, 2));
        output.put("BESTTHRESHOLD", "" + MathFunctions.roundDecimal(bestThreshold, 2));
        return output;
    }
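
    // Standard confusion-matrix indicators: sensitivity = TP/(TP+FN), omission rate = FN/(TP+FN),
    // specificity = TN/(TN+FP), accuracy = (TP+TN)/(TP+TN+FP+FN).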
    public double calculateSensitivity(int TP, int FN) {
        return (double) (TP) / (double) (TP + FN);
    }

    public double calculateOmissionRate(int TP, int FN) {
        return (double) (FN) / (double) (TP + FN);
    }

    public double calculateSpecificity(int TN, int FP) {
        return (double) (TN) / (double) (TN + FP);
    }

    public double calculateAccuracy(int TP, int TN, int FP, int FN) {
        return (double) (TP + TN) / (double) (TP + TN + FP + FN);
    }
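
    /**
     * Computes the Area Under the ROC Curve with the RapidMiner ROC utilities: the presence and
     * absence scores are loaded into a {@link MemoryExampleTable} with a binominal label and a
     * confidence attribute, a {@link ROCDataGenerator} derives the ROC curve, the AUC and the
     * best threshold, and the curve can optionally be displayed as a plot.
     */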
    public double calculateAUC(double[] scoresOnPresence, double[] scoresOnAbsence, boolean produceChart) {
        List<Attribute> attributes = new LinkedList<Attribute>();
        Attribute labelAtt = AttributeFactory.createAttribute("LABEL", Ontology.BINOMINAL);
        BinominalMapping bm = (BinominalMapping) labelAtt.getMapping();
        bm.setMapping("1", 1);
        bm.setMapping("0", 0);
        Attribute confidenceAtt1 = AttributeFactory.createAttribute(Attributes.CONFIDENCE_NAME + "_1", Ontology.REAL);
        attributes.add(confidenceAtt1);
        attributes.add(labelAtt);
        MemoryExampleTable table = new MemoryExampleTable(attributes);
        int numOfPoints = scoresOnPresence.length + scoresOnAbsence.length;
        int numOfPresence = scoresOnPresence.length;
        int numOfAttributes = attributes.size();
        double pos = labelAtt.getMapping().mapString("1");
        double neg = labelAtt.getMapping().mapString("0");
        // presence scores are labelled as positive examples
        for (int i = 0; i < numOfPresence; i++) {
            double[] data = new double[numOfAttributes];
            data[0] = scoresOnPresence[i];
            data[1] = pos;
            table.addDataRow(new DoubleArrayDataRow(data));
        }
        // absence scores are labelled as negative examples
        for (int i = numOfPresence; i < numOfPoints; i++) {
            double[] data = new double[numOfAttributes];
            data[0] = scoresOnAbsence[i - numOfPresence];
            data[1] = neg;
            table.addDataRow(new DoubleArrayDataRow(data));
        }
        ROCDataGenerator roc = new ROCDataGenerator(acceptanceThreshold, rejectionThreshold);
        ExampleSet exampleSet = table.createExampleSet(labelAtt);
        exampleSet.getAttributes().setSpecialAttribute(confidenceAtt1, Attributes.CONFIDENCE_NAME + "_1");
        ROCData dataROC = roc.createROCData(exampleSet, false);
        double auc = roc.calculateAUC(dataROC);
        // optionally display the ROC plot
        if (produceChart)
            roc.createROCPlotDialog(dataROC);
        bestThreshold = roc.getBestThreshold();
        return auc;
    }
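
    // utility method to print an analysis result to the standard output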
    public static void visualizeResults(HashMap<String, Object> results) {
        for (String key : results.keySet()) {
            System.out.println(key + ":" + results.get(key));
        }
    }
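
    // wraps the output map into a StatisticalType for the ecological engine framework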
    @Override
    public StatisticalType getOutput() {
        PrimitiveType p = new PrimitiveType(Map.class.getName(), PrimitiveType.stringMap2StatisticalMap(output), PrimitiveTypes.MAP, "AnalysisResult", "Analysis of the probability distribution quality");
        return p;
    }

    @Override
    public String getDescription() {
        return "An evaluator algorithm that assesses the effectiveness of a distribution model by computing the Receiver Operating Characteristic (ROC), the Area Under the Curve (AUC) and the accuracy of the model";
    }
}