Fixes and small enhancements to the Neural Networks model

git-svn-id: https://svn.d4science.research-infrastructures.eu/gcube/trunk/data-analysis/EcologicalEngine@101825 82a268e6-3cf1-43bd-a215-b396298e98cf
This commit is contained in:
Gianpaolo Coro 2014-12-02 11:15:23 +00:00
parent f803af301b
commit a693fa7530
5 changed files with 52 additions and 13 deletions

View File

@ -164,6 +164,7 @@ public class DiscrepancyAnalysis extends DataAnalysis {
if (maxdiscrepancyPoint==null) if (maxdiscrepancyPoint==null)
maxdiscrepancyPoint="-"; maxdiscrepancyPoint="-";
AnalysisLogger.getLogger().debug("Discrepancy Calculation - Kappa values: " + "agreementA1B1 "+agreementA1B1 +" agreementA1B0 " + agreementA1B0 +" agreementA0B1 "+agreementA0B1+" agreementA0B0 "+agreementA0B0);
double kappa = MathFunctions.cohensKappaForDichotomy(agreementA1B1, agreementA1B0, agreementA0B1, agreementA0B0); double kappa = MathFunctions.cohensKappaForDichotomy(agreementA1B1, agreementA1B0, agreementA0B1, agreementA0B0);
AnalysisLogger.getLogger().debug("Discrepancy Calculation - Calculated Cohen's Kappa:" + kappa); AnalysisLogger.getLogger().debug("Discrepancy Calculation - Calculated Cohen's Kappa:" + kappa);
@ -228,7 +229,7 @@ public class DiscrepancyAnalysis extends DataAnalysis {
if (elements[3] != null) if (elements[3] != null)
probabilityPoint2 = Float.parseFloat(""+elements[3]); probabilityPoint2 = Float.parseFloat(""+elements[3]);
float discrepancy = Math.abs(probabilityPoint2 - probabilityPoint1); float discrepancy = Math.abs(probabilityPoint2 - probabilityPoint1);
if (discrepancy > threshold) { if (discrepancy > threshold) {
errors.add(Math.abs(probabilityPoint2 - probabilityPoint1)); errors.add(Math.abs(probabilityPoint2 - probabilityPoint1));
numberoferrors++; numberoferrors++;

View File

@ -192,7 +192,8 @@ public class FeedForwardNN extends ModelAquamapsNN{
AnalysisLogger.getLogger().debug("Features were correctly preprocessed - Training"); AnalysisLogger.getLogger().debug("Features were correctly preprocessed - Training");
// train the NN // train the NN
nn.train(in, out); nn.train(in, out);
learningscore=nn.en;
AnalysisLogger.getLogger().error("Final learning error: "+nn.en);
AnalysisLogger.getLogger().debug("Saving Network"); AnalysisLogger.getLogger().debug("Saving Network");
save(fileName, nn); save(fileName, nn);
AnalysisLogger.getLogger().debug("Done"); AnalysisLogger.getLogger().debug("Done");

View File

@ -6,6 +6,7 @@ import java.io.IOException;
import java.io.ObjectOutputStream; import java.io.ObjectOutputStream;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -200,6 +201,8 @@ public class ModelAquamapsNN implements Model {
// train the NN // train the NN
nn.train(in, out); nn.train(in, out);
learningscore=nn.en;
AnalysisLogger.getLogger().error("Final learning error: "+nn.en);
save(fileName, nn); save(fileName, nn);
} catch (Exception e) { } catch (Exception e) {
@ -209,12 +212,16 @@ public class ModelAquamapsNN implements Model {
status = 100f; status = 100f;
} }
double learningscore =0;
@Override @Override
public StatisticalType getOutput() { public StatisticalType getOutput() {
LinkedHashMap<String, StatisticalType> map = new LinkedHashMap<String, StatisticalType>();
PrimitiveType p = new PrimitiveType(File.class.getName(), new File(fileName), PrimitiveTypes.FILE, "NeuralNetwork","Trained Neural Network"); PrimitiveType p = new PrimitiveType(File.class.getName(), new File(fileName), PrimitiveTypes.FILE, "NeuralNetwork","Trained Neural Network");
return p; map.put("Learning", new PrimitiveType(String.class.getName(), "" + learningscore, PrimitiveTypes.STRING, "Learning Score", ""));
map.put("NeuralNetwork", p);
PrimitiveType outputm = new PrimitiveType(LinkedHashMap.class.getName(), map, PrimitiveTypes.MAP, "ResultsMap", "Results Map");
return outputm;
} }
@Override @Override

View File

@ -363,7 +363,7 @@ public class Neural_Network implements Serializable {
enprec=en; enprec=en;
} }
System.out.println("Scarto Finale: " + en); System.out.println("Final Error: " + en);
if (counter >= maxcycle) if (counter >= maxcycle)
AnalysisLogger.getLogger().debug("training incomplete: didn't manage to reduce the error under the thr!"); AnalysisLogger.getLogger().debug("training incomplete: didn't manage to reduce the error under the thr!");
else else
@ -384,9 +384,9 @@ public class Neural_Network implements Serializable {
public void writeout(double numero, double soglia) { public void writeout(double numero, double soglia) {
if (numero < soglia) if (numero < soglia)
System.out.println("Uscita : " + 0); System.out.println("Output : " + 0);
else else
System.out.println("Uscita : " + 1); System.out.println("Output : " + 1);
} }
//classify //classify
@ -451,7 +451,7 @@ public class Neural_Network implements Serializable {
nn.train(in, out); nn.train(in, out);
double[] dummy = { 0, 0 }; double[] dummy = { 0, 0 };
System.out.println("responso sul dummy: " + nn.propagate(dummy)[0]); System.out.println("dummy test " + nn.propagate(dummy)[0]);
nn.writeout(nn.propagate(dummy)[0], 0.5); nn.writeout(nn.propagate(dummy)[0], 0.5);

View File

@ -2,9 +2,9 @@ package org.gcube.dataanalysis.ecoengine.test.regression;
import java.util.List; import java.util.List;
import org.gcube.contentmanagement.lexicalmatcher.utils.AnalysisLogger;
import org.gcube.dataanalysis.ecoengine.configuration.AlgorithmConfiguration; import org.gcube.dataanalysis.ecoengine.configuration.AlgorithmConfiguration;
import org.gcube.dataanalysis.ecoengine.interfaces.ComputationalAgent; import org.gcube.dataanalysis.ecoengine.interfaces.ComputationalAgent;
import org.gcube.dataanalysis.ecoengine.interfaces.Evaluator;
import org.gcube.dataanalysis.ecoengine.processing.factories.EvaluatorsFactory; import org.gcube.dataanalysis.ecoengine.processing.factories.EvaluatorsFactory;
public class RegressionTestEvaluators { public class RegressionTestEvaluators {
@ -14,16 +14,18 @@ public class RegressionTestEvaluators {
*/ */
public static void main(String[] args) throws Exception { public static void main(String[] args) throws Exception {
AnalysisLogger.setLogger("./cfg/ALog.properties");
List<ComputationalAgent> evaluators = null;
/*
List<ComputationalAgent> evaluators = EvaluatorsFactory.getEvaluators(testConfig1()); List<ComputationalAgent> evaluators = EvaluatorsFactory.getEvaluators(testConfig1());
evaluators.get(0).init(); evaluators.get(0).init();
Regressor.process(evaluators.get(0)); Regressor.process(evaluators.get(0));
evaluators = null; evaluators = null;
*/
System.out.println("\n**********-------************\n"); System.out.println("\n**********-------************\n");
//test Discrepancy //test Discrepancy
evaluators = EvaluatorsFactory.getEvaluators(testConfig2()); evaluators = EvaluatorsFactory.getEvaluators(testMapsComparison());
evaluators.get(0).init(); evaluators.get(0).init();
Regressor.process(evaluators.get(0)); Regressor.process(evaluators.get(0));
evaluators = null; evaluators = null;
@ -67,4 +69,32 @@ public class RegressionTestEvaluators {
} }
/**
 * Builds the configuration for a discrepancy analysis between two
 * probability-distribution tables hosted on the statistical-manager
 * test database.
 *
 * @return an {@link AlgorithmConfiguration} set up for the
 *         DISCREPANCY_ANALYSIS agent with one computational resource
 */
private static AlgorithmConfiguration testMapsComparison() {
	AlgorithmConfiguration config = Regressor.getConfig();
	// fix: setNumberOfResources(1) was previously called twice; once is enough
	config.setNumberOfResources(1);
	config.setAgent("DISCREPANCY_ANALYSIS");
	// NOTE(review): hard-coded test-database credentials — acceptable in a
	// regression test, but must not be copied into production code.
	config.setParam("DatabaseUserName","utente");
	config.setParam("DatabasePassword","d4science");
	config.setParam("DatabaseURL","jdbc:postgresql://statistical-manager.d.d4science.org/testdb");
	config.setParam("DatabaseDriver","org.postgresql.Driver");
	// the two probability maps to compare, keyed by c-square code
	config.setParam("FirstTable", "rstr280e0453e8c7408c96edd49a8dcb5986");
	config.setParam("SecondTable", "rstr11b3b436ddaf4ae5ae5227ea8e0658ba");
	config.setParam("FirstTableCsquareColumn", "csquarecode");
	config.setParam("SecondTableCsquareColumn", "csquarecode");
	config.setParam("FirstTableProbabilityColumn", "f_probability");
	config.setParam("SecondTableProbabilityColumn", "f_probability");
	// thresholds for the point-wise discrepancy and the kappa evaluation,
	// plus a cap on the number of sampled rows
	config.setParam("ComparisonThreshold", "0.5");
	config.setParam("KThreshold", "0.5");
	config.setParam("MaxSamples", "45000");
	return config;
}
} }