package eu.dnetlib.jobs.deeplearning; import eu.dnetlib.jobs.AbstractSparkJob; import eu.dnetlib.support.ArgumentApplicationParser; import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.SparkSession; import org.junit.jupiter.api.*; import java.io.File; import java.io.IOException; import java.net.URISyntaxException; import java.nio.file.Paths; @TestMethodOrder(MethodOrderer.OrderAnnotation.class) @TestInstance(TestInstance.Lifecycle.PER_CLASS) public class GNNTrainingTest { static SparkSession spark; static JavaSparkContext context; final static String workingPath = "/tmp/working_dir"; final static String numPartitions = "20"; final String inputDataPath = Paths .get(getClass().getResource("/eu/dnetlib/jobs/examples/authors.groups.example.json").toURI()) .toFile() .getAbsolutePath(); final static String groundTruthJPath = "$.orcid"; final static String idJPath = "$.id"; final static String featuresJPath = "$.topics"; public GNNTrainingTest() throws URISyntaxException {} public static void cleanup() throws IOException { //remove directories and clean workspace FileUtils.deleteDirectory(new File(workingPath)); } @BeforeAll public void setup() throws IOException { cleanup(); spark = SparkSession .builder() .appName("Testing") .master("local[*]") .getOrCreate(); context = JavaSparkContext.fromSparkContext(spark.sparkContext()); } @AfterAll public static void finalCleanUp() throws IOException { cleanup(); } @Test @Order(1) public void createGroupDataSetTest() throws Exception { ArgumentApplicationParser parser = new ArgumentApplicationParser(readResource("/jobs/parameters/createGroupDataset_parameters.json", SparkCreateGroupDataSet.class)); parser.parseArgument( new String[] { "-i", inputDataPath, "-gt", groundTruthJPath, "-id", idJPath, "-f", featuresJPath, "-w", workingPath, "-np", numPartitions } ); new SparkCreateGroupDataSet( parser, spark ).run(); } @Test @Order(2) public void graphClassificationTrainingTest() throws Exception{ ArgumentApplicationParser parser = new ArgumentApplicationParser(readResource("/jobs/parameters/graphClassificationTraining_parameters.json", SparkGraphClassificationTraining.class)); parser.parseArgument( new String[] { "-w", workingPath, "-np", numPartitions } ); new SparkGraphClassificationTraining( parser, spark ).run(); } public static String readResource(String path, Class clazz) throws IOException { return IOUtils.toString(clazz.getResourceAsStream(path)); } }