package eu.dnetlib.jobs.deeplearning;

import eu.dnetlib.deeplearning.support.NetworkConfigurations;
import eu.dnetlib.jobs.AbstractSparkJob;
import eu.dnetlib.support.ArgumentApplicationParser;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.optimize.listeners.PerformanceListener;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.threshold.AdaptiveThresholdAlgorithm;
import org.deeplearning4j.spark.api.RDDTrainingApproach;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.impl.graph.SparkComputationGraph;
import org.deeplearning4j.spark.parameterserver.training.SharedTrainingMaster;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.Optional;

public class SparkGraphClassificationTraining extends AbstractSparkJob {

    private static final Logger log = LoggerFactory.getLogger(SparkGraphClassificationTraining.class);

    public SparkGraphClassificationTraining(ArgumentApplicationParser parser, SparkSession spark) {
        super(parser, spark);
    }

    public static void main(String[] args) throws Exception {
        ArgumentApplicationParser parser = new ArgumentApplicationParser(
                readResource("/jobs/parameters/graphClassificationTraining_parameters.json", SparkGraphClassificationTraining.class)
        );

        parser.parseArgument(args);

        SparkConf conf = new SparkConf();

        new SparkGraphClassificationTraining(
                parser,
                getSparkSession(conf)
        ).run();
    }

    @Override
    public void run() throws IOException {

        // read oozie parameters
        final String workingPath = parser.get("workingPath");
        final int numPartitions = Optional
                .ofNullable(parser.get("numPartitions"))
                .map(Integer::valueOf)
                .orElse(NUM_PARTITIONS);

        log.info("workingPath: '{}'", workingPath);
        log.info("numPartitions: '{}'", numPartitions);

        JavaSparkContext context = JavaSparkContext.fromSparkContext(spark.sparkContext());

        // parameter server configuration used for gradient sharing across workers
        VoidConfiguration conf = VoidConfiguration.builder()
                .unicastPort(40123)             // port opened on each worker for parameter updates
//                .networkMask("255.255.148.0/22")
                .controllerAddress("127.0.0.1") // address of the master node
                .build();

        // distributed training setup: gradient sharing with an adaptive encoding threshold
        TrainingMaster trainingMaster = new SharedTrainingMaster.Builder(conf, 1)
                .rngSeed(12345)
                .collectTrainingStats(false)
                .thresholdAlgorithm(new AdaptiveThresholdAlgorithm(1e-3))
                .batchSizePerWorker(32)
                .workersPerNode(4)
                .rddTrainingApproach(RDDTrainingApproach.Direct)
                .build();

        JavaRDD<MultiDataSet> trainData = context.objectFile(workingPath + "/groupDataset");

        SparkComputationGraph sparkComputationGraph = new SparkComputationGraph(
                context,
                NetworkConfigurations.getSimpleGCN(3, 2, 5, 2),
                trainingMaster);
        sparkComputationGraph.setListeners(new PerformanceListener(10, true));

        // execute training: 20 passes over the full training RDD
        for (int i = 0; i < 20; i++) {
            sparkComputationGraph.fitMultiDataSet(trainData);
        }

        ComputationGraph network = sparkComputationGraph.getNetwork();
        log.info("trained network configuration: {}", network.getConfiguration().toJson());
    }
}