dnet-and/dnet-feature-extraction/src/main/java/eu/dnetlib/deeplearning/GroupClassifier.java

//package eu.dnetlib.deeplearning;
//
///* *****************************************************************************
// *
// *
// *
// * This program and the accompanying materials are made available under the
// * terms of the Apache License, Version 2.0 which is available at
// * https://www.apache.org/licenses/LICENSE-2.0.
// * See the NOTICE file distributed with this work for additional
// * information regarding copyright ownership.
// *
// * Unless required by applicable law or agreed to in writing, software
// * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// * License for the specific language governing permissions and limitations
// * under the License.
// *
// * SPDX-License-Identifier: Apache-2.0
// ******************************************************************************/
//
//import org.datavec.api.records.reader.RecordReader;
//import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
//import org.datavec.api.split.FileSplit;
//import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
//import org.deeplearning4j.examples.utils.DownloaderUtility;
//import org.deeplearning4j.examples.utils.PlotUtil;
//import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
//import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
//import org.deeplearning4j.nn.conf.layers.DenseLayer;
//import org.deeplearning4j.nn.conf.layers.OutputLayer;
//import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
//import org.deeplearning4j.nn.weights.WeightInit;
//import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
//import org.nd4j.evaluation.classification.Evaluation;
//import org.nd4j.linalg.activations.Activation;
//import org.nd4j.linalg.api.ndarray.INDArray;
//import org.nd4j.linalg.dataset.DataSet;
//import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
//import org.nd4j.linalg.learning.config.Nesterovs;
//import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
//
//import java.io.File;
//import java.util.concurrent.TimeUnit;
//
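///**
// * Small feed-forward classifier for grouping 2-D points, apparently adapted from the
// * DL4J classification quickstart example (the DownloaderUtility/PlotUtil imports and the
// * linear_data_* CSV files come from the dl4j-examples utilities): it trains a one-hidden-layer
// * MLP on labelled (x, y) points, evaluates it, and optionally plots the decision boundary.
// */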
//public class GroupClassifier {
//
// public static boolean visualize = true;
// public static String dataLocalPath;
//
// public static void main(String[] args) throws Exception {
// int seed = 123;
// double learningRate = 0.01;
// int batchSize = 50;
// int nEpochs = 30;
//
// int numInputs = 2;
// int numOutputs = 2;
// int numHiddenNodes = 20;
//
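//        //Download the bundled classification data set (linear_data_train.csv / linear_data_eval.csv)
//        //to a local directory via the dl4j-examples DownloaderUtility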
// dataLocalPath = DownloaderUtility.CLASSIFICATIONDATA.Download();
// //Load the training data:
// RecordReader rr = new CSVRecordReader();
// rr.initialize(new FileSplit(new File(dataLocalPath, "linear_data_train.csv")));
// DataSetIterator trainIter = new RecordReaderDataSetIterator(rr, batchSize, 0, 2);
//
// //Load the test/evaluation data:
// RecordReader rrTest = new CSVRecordReader();
// rrTest.initialize(new FileSplit(new File(dataLocalPath, "linear_data_eval.csv")));
// DataSetIterator testIter = new RecordReaderDataSetIterator(rrTest, batchSize, 0, 2);
//
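//        //Configure a simple feed-forward network: Xavier-initialised weights, one ReLU hidden layer
//        //(numHiddenNodes units) feeding a softmax output, trained with negative log-likelihood
//        //and SGD with Nesterov momentum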
// MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
// .seed(seed)
// .weightInit(WeightInit.XAVIER)
// .updater(new Nesterovs(learningRate, 0.9))
// .list()
// .layer(new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
// .activation(Activation.RELU)
// .build())
// .layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
// .activation(Activation.SOFTMAX)
// .nIn(numHiddenNodes).nOut(numOutputs).build())
// .build();
//
// MultiLayerNetwork model = new MultiLayerNetwork(conf);
// model.init();
// model.setListeners(new ScoreIterationListener(10)); //Print score every 10 parameter updates
//
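//        //Train for nEpochs complete passes over the training iterator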
// model.fit(trainIter, nEpochs);
//
// System.out.println("Evaluate model....");
// Evaluation eval = new Evaluation(numOutputs);
// while (testIter.hasNext()) {
// DataSet t = testIter.next();
// INDArray features = t.getFeatures();
// INDArray labels = t.getLabels();
// INDArray predicted = model.output(features, false);
// eval.eval(labels, predicted);
// }
// //An alternate way to do the above loop
// //Evaluation evalResults = model.evaluate(testIter);
//
// //Print the evaluation statistics
// System.out.println(eval.stats());
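//
//        //Sketch (not part of the original example): classifying a single point with the trained model.
//        //Nd4j.create builds a 1x2 feature row (the values below are purely illustrative); argMax over the
//        //softmax output picks the predicted class. Requires an import of org.nd4j.linalg.factory.Nd4j.
//        //INDArray sample = Nd4j.create(new double[][]{{0.5, 0.3}});
//        //int predictedClass = model.output(sample, false).argMax(1).getInt(0);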
//
// System.out.println("\n****************Example finished********************");
// //Training is complete. Code that follows is for plotting the data & predictions only
// generateVisuals(model, trainIter, testIter);
// }
//
// public static void generateVisuals(MultiLayerNetwork model, DataSetIterator trainIter, DataSetIterator testIter) throws Exception {
// if (visualize) {
// double xMin = 0;
// double xMax = 1.0;
// double yMin = -0.2;
// double yMax = 0.8;
// int nPointsPerAxis = 100;
//
// //Generate x,y points that span the whole range of features
// INDArray allXYPoints = PlotUtil.generatePointsOnGraph(xMin, xMax, yMin, yMax, nPointsPerAxis);
// //Get train data and plot with predictions
// PlotUtil.plotTrainingData(model, trainIter, allXYPoints, nPointsPerAxis);
// TimeUnit.SECONDS.sleep(3);
// //Get test data, run the test data through the network to generate predictions, and plot those predictions:
// PlotUtil.plotTestData(model, testIter, allXYPoints, nPointsPerAxis);
// }
// }
//}
//