// DistributionQualityAnalysis — 288 lines, 13 KiB, Java
package org.gcube.dataanalysis.ecoengine.evaluation;
|
|
|
|
import java.util.ArrayList;
|
|
import java.util.HashMap;
|
|
import java.util.LinkedHashMap;
|
|
import java.util.LinkedList;
|
|
import java.util.List;
|
|
import java.util.Map;
|
|
|
|
import org.gcube.contentmanagement.graphtools.utils.MathFunctions;
|
|
import org.gcube.dataanalysis.ecoengine.datatypes.ColumnType;
|
|
import org.gcube.dataanalysis.ecoengine.datatypes.DatabaseType;
|
|
import org.gcube.dataanalysis.ecoengine.datatypes.InputTable;
|
|
import org.gcube.dataanalysis.ecoengine.datatypes.PrimitiveType;
|
|
import org.gcube.dataanalysis.ecoengine.datatypes.StatisticalType;
|
|
import org.gcube.dataanalysis.ecoengine.datatypes.enumtypes.PrimitiveTypes;
|
|
import org.gcube.dataanalysis.ecoengine.datatypes.enumtypes.TableTemplates;
|
|
import org.gcube.dataanalysis.ecoengine.interfaces.DataAnalysis;
|
|
import org.gcube.dataanalysis.ecoengine.utils.DatabaseFactory;
|
|
import org.slf4j.Logger;
|
|
import org.slf4j.LoggerFactory;
|
|
|
|
import com.rapidminer.example.Attribute;
|
|
import com.rapidminer.example.Attributes;
|
|
import com.rapidminer.example.ExampleSet;
|
|
import com.rapidminer.example.table.AttributeFactory;
|
|
import com.rapidminer.example.table.BinominalMapping;
|
|
import com.rapidminer.example.table.DoubleArrayDataRow;
|
|
import com.rapidminer.example.table.MemoryExampleTable;
|
|
import com.rapidminer.tools.Ontology;
|
|
import com.rapidminer.tools.math.ROCData;
|
|
import com.rapidminer.tools.math.ROCDataGenerator;
|
|
|
|
public class DistributionQualityAnalysis extends DataAnalysis {
|
|
|
|
private static Logger logger = LoggerFactory.getLogger(DistributionQualityAnalysis.class);
|
|
|
|
static String getProbabilititesQuery = "select count(*) from (select distinct * from %1$s as a join %2$s as b on a.%3$s=b.%4$s and b.%5$s %6$s %7$s) as aa";
|
|
static String getNumberOfElementsQuery = "select count(*) from %1$s";
|
|
static String getValuesQuery = "select %5$s as distribprob (select distinct * from %1$s as a join %2$s as b on a.%3$s=b.%4$s) as b";
|
|
|
|
float threshold = 0.1f;
|
|
String configPath = "./cfg/";
|
|
float acceptanceThreshold = 0.8f;
|
|
float rejectionThreshold = 0.3f;
|
|
double bestThreshold = 0.5d;
|
|
private LinkedHashMap<String, String> output;
|
|
|
|
public List<StatisticalType> getInputParameters() {
|
|
List<StatisticalType> parameters = new ArrayList<StatisticalType>();
|
|
List<TableTemplates> templates = new ArrayList<TableTemplates>();
|
|
templates.add(TableTemplates.HSPEC);
|
|
templates.add(TableTemplates.TRAININGSET);
|
|
templates.add(TableTemplates.TESTSET);
|
|
|
|
List<TableTemplates> templatesOccurrences = new ArrayList<TableTemplates>();
|
|
templatesOccurrences.add(TableTemplates.HCAF);
|
|
|
|
InputTable p1 = new InputTable(templatesOccurrences,"PositiveCasesTable","A Table containing positive cases");
|
|
InputTable p2 = new InputTable(templatesOccurrences,"NegativeCasesTable","A Table containing negative cases");
|
|
InputTable p5 = new InputTable(templates,"DistributionTable","A probability distribution table");
|
|
|
|
ColumnType p3 = new ColumnType("PositiveCasesTable", "PositiveCasesTableKeyColumn", "Positive Cases Table Key Column", "csquarecode", false);
|
|
ColumnType p4 = new ColumnType("NegativeCasesTable", "NegativeCasesTableKeyColumn", "Negative Cases Table Key Column", "csquarecode", false);
|
|
ColumnType p6 = new ColumnType("DistributionTable", "DistributionTableKeyColumn", "Distribution Table Key Column", "csquarecode", false);
|
|
ColumnType p7 = new ColumnType("DistributionTable", "DistributionTableProbabilityColumn", "Distribution Table Probability Column", "probability", false);
|
|
|
|
PrimitiveType p8 = new PrimitiveType(String.class.getName(), null, PrimitiveTypes.STRING, "PositiveThreshold","Positive acceptance threshold","0.8");
|
|
PrimitiveType p9 = new PrimitiveType(String.class.getName(), null, PrimitiveTypes.STRING, "NegativeThreshold","Negative acceptance threshold","0.3");
|
|
|
|
parameters.add(p1);
|
|
parameters.add(p2);
|
|
parameters.add(p3);
|
|
parameters.add(p4);
|
|
parameters.add(p5);
|
|
parameters.add(p6);
|
|
parameters.add(p7);
|
|
parameters.add(p8);
|
|
parameters.add(p9);
|
|
|
|
DatabaseType.addDefaultDBPars(parameters);
|
|
|
|
return parameters;
|
|
}
|
|
|
|
public List<String> getOutputParameters() {
|
|
|
|
List<String> outputs = new ArrayList<String>();
|
|
|
|
outputs.add("TRUE_POSITIVES");
|
|
outputs.add("TRUE_NEGATIVES");
|
|
outputs.add("FALSE_POSITIVES");
|
|
outputs.add("FALSE_NEGATIVES");
|
|
outputs.add("AUC");
|
|
outputs.add("ACCURACY");
|
|
outputs.add("SENSITIVITY");
|
|
outputs.add("OMISSIONRATE");
|
|
outputs.add("SPECIFICITY");
|
|
outputs.add("BESTTHRESHOLD");
|
|
|
|
return outputs;
|
|
}
|
|
|
|
private int calculateNumberOfPoints(String table) {
|
|
|
|
String numberOfPositiveCasesQuery = String.format(getNumberOfElementsQuery, table);
|
|
List<Object> totalPoints = DatabaseFactory.executeSQLQuery(numberOfPositiveCasesQuery, connection);
|
|
int points = Integer.parseInt("" + totalPoints.get(0));
|
|
return points;
|
|
}
|
|
|
|
private int calculateCaughtPoints(String casesTable, String distributionTable, String casesTableKeyColumn, String distributionTableKeyColumn, String distributionTableProbabilityColumn, String operator, String threshold) {
|
|
|
|
String query = String.format(getProbabilititesQuery, casesTable, distributionTable, casesTableKeyColumn, distributionTableKeyColumn, distributionTableProbabilityColumn, operator, threshold);
|
|
logger.trace("Compare - Query to perform for caught cases:" + query);
|
|
List<Object> caughtpoints = DatabaseFactory.executeSQLQuery(query, connection);
|
|
int points = Integer.parseInt("" + caughtpoints.get(0));
|
|
return points;
|
|
}
|
|
|
|
private double[] getPoints(String casesTable, String distributionTable, String casesTableKeyColumn, String distributionTableKeyColumn, String distributionTableProbabilityColumn, int numberOfExpectedPoints) {
|
|
|
|
String query = String.format(getValuesQuery, casesTable, distributionTable, casesTableKeyColumn, distributionTableKeyColumn, distributionTableProbabilityColumn);
|
|
|
|
logger.trace("Compare - Query to perform for caught cases:" + query);
|
|
List<Object> caughtpoints = DatabaseFactory.executeSQLQuery(query, connection);
|
|
int size = 0;
|
|
if (caughtpoints != null)
|
|
size = caughtpoints.size();
|
|
double[] points = new double[numberOfExpectedPoints];
|
|
|
|
for (int i = 0; i < size; i++) {
|
|
double element = 0;
|
|
if (caughtpoints.get(i) != null)
|
|
element = Double.parseDouble("" + caughtpoints.get(i));
|
|
|
|
points[i] = element;
|
|
}
|
|
|
|
return points;
|
|
}
|
|
|
|
public LinkedHashMap<String, String> analyze() throws Exception {
|
|
|
|
try {
|
|
acceptanceThreshold = Float.parseFloat(config.getParam("PositiveThreshold"));
|
|
} catch (Exception e) {
|
|
logger.debug("ERROR : " + e.getLocalizedMessage());
|
|
}
|
|
try {
|
|
rejectionThreshold = Float.parseFloat(config.getParam("NegativeThreshold"));
|
|
} catch (Exception e) {
|
|
logger.debug("ERROR : " + e.getLocalizedMessage());
|
|
}
|
|
|
|
String positiveCasesTable = config.getParam("PositiveCasesTable");
|
|
String negativeCasesTable = config.getParam("NegativeCasesTable");
|
|
String distributionTable = config.getParam("DistributionTable");
|
|
String positiveCasesTableKeyColumn = config.getParam("PositiveCasesTableKeyColumn");
|
|
String negativeCasesTableKeyColumn = config.getParam("NegativeCasesTableKeyColumn");
|
|
String distributionTableKeyColumn = config.getParam("DistributionTableKeyColumn");
|
|
String distributionTableProbabilityColumn = config.getParam("DistributionTableProbabilityColumn");
|
|
String acceptanceThreshold = config.getParam("PositiveThreshold");
|
|
String rejectionThreshold = config.getParam("NegativeThreshold");
|
|
|
|
int numberOfPositiveCases = calculateNumberOfPoints(positiveCasesTable);
|
|
|
|
int truePositives = calculateCaughtPoints(positiveCasesTable, distributionTable, positiveCasesTableKeyColumn, distributionTableKeyColumn, distributionTableProbabilityColumn, ">", acceptanceThreshold);
|
|
|
|
int falseNegatives = numberOfPositiveCases - truePositives;
|
|
|
|
int numberOfNegativeCases = calculateNumberOfPoints(negativeCasesTable);
|
|
|
|
super.processedRecords = numberOfPositiveCases + numberOfNegativeCases;
|
|
|
|
int falsePositives = calculateCaughtPoints(negativeCasesTable, distributionTable, negativeCasesTableKeyColumn, distributionTableKeyColumn, distributionTableProbabilityColumn, ">", rejectionThreshold);
|
|
|
|
int trueNegatives = numberOfNegativeCases - falsePositives;
|
|
|
|
double[] positivePoints = getPoints(positiveCasesTable, distributionTable, positiveCasesTableKeyColumn, distributionTableKeyColumn, distributionTableProbabilityColumn, numberOfPositiveCases);
|
|
|
|
double[] negativePoints = getPoints(negativeCasesTable, distributionTable, negativeCasesTableKeyColumn, distributionTableKeyColumn, distributionTableProbabilityColumn, numberOfNegativeCases);
|
|
|
|
double auc = calculateAUC(positivePoints, negativePoints, false);
|
|
double accuracy = calculateAccuracy(truePositives, trueNegatives, falsePositives, falseNegatives);
|
|
double sensitivity = calculateSensitivity(truePositives, falseNegatives);
|
|
double omissionrate = calculateOmissionRate(truePositives, falseNegatives);
|
|
double specificity = calculateSpecificity(trueNegatives, falsePositives);
|
|
|
|
output = new LinkedHashMap<String, String>();
|
|
output.put("TRUE_POSITIVES", "" + truePositives);
|
|
output.put("TRUE_NEGATIVES", "" + trueNegatives);
|
|
output.put("FALSE_POSITIVES", "" + falsePositives);
|
|
output.put("FALSE_NEGATIVES", "" + falseNegatives);
|
|
output.put("AUC", "" + MathFunctions.roundDecimal(auc,2));
|
|
output.put("ACCURACY", "" + MathFunctions.roundDecimal(accuracy,2));
|
|
output.put("SENSITIVITY", "" + MathFunctions.roundDecimal(sensitivity,2));
|
|
output.put("OMISSIONRATE", "" + MathFunctions.roundDecimal(omissionrate,2));
|
|
output.put("SPECIFICITY", "" + MathFunctions.roundDecimal(specificity,2));
|
|
output.put("BESTTHRESHOLD", "" + MathFunctions.roundDecimal(bestThreshold,2));
|
|
|
|
return output;
|
|
}
|
|
|
|
public double calculateSensitivity(int TP, int FN) {
|
|
return (double) (TP) / (double) (TP + FN);
|
|
}
|
|
|
|
public double calculateOmissionRate(int TP, int FN) {
|
|
return (double) (FN) / (double) (TP + FN);
|
|
}
|
|
|
|
public double calculateSpecificity(int TN, int FP) {
|
|
return (double) (TN) / (double) (TN + FP);
|
|
}
|
|
|
|
public double calculateAccuracy(int TP, int TN, int FP, int FN) {
|
|
return (double) (TP + TN) / (double) (TP + TN + FP + FN);
|
|
}
|
|
|
|
public double calculateAUC(double[] scoresOnPresence, double[] scoresOnAbsence, boolean produceChart) {
|
|
|
|
List<Attribute> attributes = new LinkedList<Attribute>();
|
|
Attribute labelAtt = AttributeFactory.createAttribute("LABEL", Ontology.BINOMINAL);
|
|
BinominalMapping bm = (BinominalMapping) labelAtt.getMapping();
|
|
bm.setMapping("1", 1);
|
|
bm.setMapping("0", 0);
|
|
|
|
Attribute confidenceAtt1 = AttributeFactory.createAttribute(Attributes.CONFIDENCE_NAME + "_1", Ontology.REAL);
|
|
attributes.add(confidenceAtt1);
|
|
attributes.add(labelAtt);
|
|
|
|
MemoryExampleTable table = new MemoryExampleTable(attributes);
|
|
int numOfPoints = scoresOnPresence.length + scoresOnAbsence.length;
|
|
int numOfPresence = scoresOnPresence.length;
|
|
int numOfAttributes = attributes.size();
|
|
double pos = labelAtt.getMapping().mapString("1");
|
|
double neg = labelAtt.getMapping().mapString("0");
|
|
|
|
for (int i = 0; i < numOfPresence; i++) {
|
|
double[] data = new double[numOfAttributes];
|
|
data[0] = scoresOnPresence[i];
|
|
data[1] = pos;
|
|
table.addDataRow(new DoubleArrayDataRow(data));
|
|
}
|
|
|
|
for (int i = numOfPresence; i < numOfPoints; i++) {
|
|
double[] data = new double[numOfAttributes];
|
|
data[0] = scoresOnAbsence[i - numOfPresence];
|
|
data[1] = neg;
|
|
table.addDataRow(new DoubleArrayDataRow(data));
|
|
}
|
|
|
|
ROCDataGenerator roc = new ROCDataGenerator(acceptanceThreshold, rejectionThreshold);
|
|
ExampleSet exampleSet = table.createExampleSet(labelAtt);
|
|
exampleSet.getAttributes().setSpecialAttribute(confidenceAtt1, Attributes.CONFIDENCE_NAME + "_1");
|
|
|
|
ROCData dataROC = roc.createROCData(exampleSet, false);
|
|
double auc = roc.calculateAUC(dataROC);
|
|
|
|
// PLOTS THE ROC!!!
|
|
if (produceChart)
|
|
roc.createROCPlotDialog(dataROC);
|
|
|
|
bestThreshold = roc.getBestThreshold();
|
|
return auc;
|
|
}
|
|
|
|
public static void visualizeResults(HashMap<String, Object> results) {
|
|
|
|
for (String key : results.keySet()) {
|
|
System.out.println(key + ":" + results.get(key));
|
|
}
|
|
}
|
|
|
|
@Override
|
|
public StatisticalType getOutput() {
|
|
PrimitiveType p = new PrimitiveType(Map.class.getName(), PrimitiveType.stringMap2StatisticalMap(output), PrimitiveTypes.MAP, "AnalysisResult","Analysis of the probability distribution quality");
|
|
return p;
|
|
}
|
|
|
|
@Override
|
|
public String getDescription() {
|
|
return "An evaluator algorithm that assesses the effectiveness of a distribution model by computing the Receiver Operating Characteristics (ROC), the Area Under Curve (AUC) and the Accuracy of a model";
|
|
}
|
|
|
|
}
|