dnet-and/dnet-and-test/src/main/java/eu/dnetlib/jobs/deeplearning/SparkGraphClassificationTra...

100 lines
3.9 KiB
Java

package eu.dnetlib.jobs.deeplearning;
import eu.dnetlib.deeplearning.support.DataSetProcessor;
import eu.dnetlib.deeplearning.support.NetworkConfigurations;
import eu.dnetlib.jobs.AbstractSparkJob;
import eu.dnetlib.jobs.SparkLDATuning;
import eu.dnetlib.support.ArgumentApplicationParser;
import eu.dnetlib.support.ConnectedComponent;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
import org.codehaus.jackson.map.ObjectMapper;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.optimize.listeners.PerformanceListener;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.threshold.AdaptiveThresholdAlgorithm;
import org.deeplearning4j.spark.api.RDDTrainingApproach;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.impl.graph.SparkComputationGraph;
import org.deeplearning4j.spark.parameterserver.training.SharedTrainingMaster;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.Optional;
/**
 * Spark job that trains a graph-classification GCN (via DeepLearning4j's parameter-server
 * training) on a dataset of {@link MultiDataSet} objects previously serialized under
 * {@code <workingPath>/groupDataset}.
 *
 * <p>Reads two job parameters: {@code workingPath} (required) and {@code numPartitions}
 * (optional, defaults to {@code NUM_PARTITIONS} from {@link AbstractSparkJob}).
 */
public class SparkGraphClassificationTraining extends AbstractSparkJob {

    private static final Logger log = LoggerFactory.getLogger(SparkGraphClassificationTraining.class);

    // Training hyper-parameters, previously inlined as magic numbers.
    private static final int UNICAST_PORT = 40123;
    private static final long RNG_SEED = 12345;
    private static final int BATCH_SIZE_PER_WORKER = 32;
    private static final int WORKERS_PER_NODE = 4;
    private static final int NUM_EPOCHS = 20;
    private static final int LISTENER_FREQUENCY = 10;

    /**
     * @param parser parsed job arguments
     * @param spark  active Spark session
     */
    public SparkGraphClassificationTraining(ArgumentApplicationParser parser, SparkSession spark) {
        super(parser, spark);
    }

    public static void main(String[] args) throws Exception {
        ArgumentApplicationParser parser = new ArgumentApplicationParser(
                // FIX: resolve the parameter spec against THIS job class, not SparkLDATuning
                // (the path is absolute, so behavior is identical, but this removes the
                // accidental coupling to a sibling job class).
                readResource("/jobs/parameters/graphClassificationTraining_parameters.json",
                        SparkGraphClassificationTraining.class)
        );

        parser.parseArgument(args);

        SparkConf conf = new SparkConf();

        new SparkGraphClassificationTraining(
                parser,
                getSparkSession(conf)
        ).run();
    }

    @Override
    public void run() throws IOException {

        // read oozie parameters
        final String workingPath = parser.get("workingPath");
        final int numPartitions = Optional
                .ofNullable(parser.get("numPartitions"))
                .map(Integer::valueOf)
                .orElse(NUM_PARTITIONS);

        log.info("workingPath: '{}'", workingPath);
        log.info("numPartitions: '{}'", numPartitions);

        JavaSparkContext context = JavaSparkContext.fromSparkContext(spark.sparkContext());

        // Parameter-server configuration for gradient sharing across executors.
        // NOTE(review): controllerAddress is hard-coded to loopback — this only works when
        // driver and workers share a host; confirm before running on a real cluster.
        VoidConfiguration conf = VoidConfiguration.builder()
                .unicastPort(UNICAST_PORT)
                // .networkMask("255.255.148.0/22")
                .controllerAddress("127.0.0.1")
                .build();

        TrainingMaster trainingMaster = new SharedTrainingMaster.Builder(conf, 1)
                .rngSeed(RNG_SEED)
                .collectTrainingStats(false)
                .thresholdAlgorithm(new AdaptiveThresholdAlgorithm(1e-3))
                .batchSizePerWorker(BATCH_SIZE_PER_WORKER)
                .workersPerNode(WORKERS_PER_NODE)
                .rddTrainingApproach(RDDTrainingApproach.Direct)
                .build();

        // FIX: numPartitions was parsed and logged but never used — honor it when loading
        // the serialized training set.
        JavaRDD<MultiDataSet> trainData = context.objectFile(workingPath + "/groupDataset", numPartitions);

        SparkComputationGraph sparkComputationGraph = new SparkComputationGraph(
                context,
                NetworkConfigurations.getSimpleGCN(3, 2, 5, 2),
                trainingMaster);
        sparkComputationGraph.setListeners(new PerformanceListener(LISTENER_FREQUENCY, true));

        // execute training: one fitMultiDataSet call per epoch
        for (int epoch = 0; epoch < NUM_EPOCHS; epoch++) {
            sparkComputationGraph.fitMultiDataSet(trainData);
        }

        // FIX: report the trained configuration through the job logger instead of System.out,
        // consistent with the rest of the class.
        ComputationGraph network = sparkComputationGraph.getNetwork();
        log.info("network = {}", network.getConfiguration().toJson());
    }
}