82 lines
2.9 KiB
Java
82 lines
2.9 KiB
Java
package org.gcube.dataanalysis.ecoengine.utils;
|
|
|
|
import java.util.List;
|
|
|
|
import org.gcube.dataanalysis.ecoengine.configuration.AlgorithmConfiguration;
|
|
import org.hibernate.SessionFactory;
|
|
import org.slf4j.Logger;
|
|
import org.slf4j.LoggerFactory;
|
|
|
|
public class TrainingSetsGenerator {
|
|
|
|
private static Logger logger = LoggerFactory.getLogger(TrainingSetsGenerator.class);
|
|
|
|
static String getProbabilititesQuery = "select count(*) as distribprob from %1$s as a join %2$s as b on a.%3$s=b.%4$s and b.%5$s %6$s %7$s";
|
|
static String getNumberOfElementsQuery = "select count(*) from %1$s";
|
|
|
|
static String getRandomElements = "select * into %2$s from %1$s order by RANDOM() limit %3$s";
|
|
static String getOtherElements = "select * into %4$s from %1$s where %3$s not in (select %3$s from %2$s)";
|
|
static String dropTable = "drop table %1$s";
|
|
|
|
float threshold = 0.1f;
|
|
String configPath = "./cfg/";
|
|
|
|
|
|
|
|
private int calculateNumberOfPoints(String table) {
|
|
|
|
String numberOfPositiveCasesQuery = String.format(getNumberOfElementsQuery, table);
|
|
List<Object> totalPoints = DatabaseFactory.executeSQLQuery(numberOfPositiveCasesQuery, connection);
|
|
int points = Integer.parseInt("" + totalPoints.get(0));
|
|
return points;
|
|
}
|
|
|
|
private void deleteTable(String testTable, String trainingTable) {
|
|
try {
|
|
DatabaseFactory.executeSQLUpdate(String.format(dropTable, testTable), connection);
|
|
} catch (Exception e) {
|
|
}
|
|
try {
|
|
DatabaseFactory.executeSQLUpdate(String.format(dropTable, trainingTable), connection);
|
|
} catch (Exception e) {
|
|
}
|
|
}
|
|
|
|
private void generatePointsTable(String table, String key, int numberOfElements) {
|
|
String testTable = table + "_test" + numberOfElements;
|
|
String trainingTable = table + "_training" + numberOfElements;
|
|
deleteTable(testTable, trainingTable);
|
|
|
|
String testSetQuery = String.format(getRandomElements, table, testTable, numberOfElements);
|
|
logger.trace("TestSet Creation Query: "+testSetQuery);
|
|
try {
|
|
DatabaseFactory.executeSQLUpdate(testSetQuery, connection);
|
|
} catch (Exception e) {
|
|
e.printStackTrace();
|
|
}
|
|
|
|
String trainingSetQuery = String.format(getOtherElements, table, testTable, key, trainingTable);
|
|
logger.trace("TrainingSet Creation Query: "+trainingSetQuery);
|
|
try {
|
|
DatabaseFactory.executeSQLUpdate(trainingSetQuery, connection);
|
|
} catch (Exception e) {
|
|
e.printStackTrace();
|
|
}
|
|
logger.trace("DONE!");
|
|
}
|
|
|
|
public void generate(AlgorithmConfiguration config) throws Exception {
|
|
init(config);
|
|
int numberOfElements = calculateNumberOfPoints(config.getParam("casesTable"));
|
|
int elementsToTake = (int) (0.4 * (float) numberOfElements);
|
|
generatePointsTable(config.getParam("casesTable"), config.getParam("columnKeyName"), elementsToTake);
|
|
|
|
}
|
|
SessionFactory connection;
|
|
|
|
public void init(AlgorithmConfiguration config) throws Exception {
|
|
// init db connection
|
|
connection = AlgorithmConfiguration.getConnectionFromConfig(config);
|
|
}
|
|
}
|