dnet-and/dnet-and-test/src/test/java/eu/dnetlib/jobs/deeplearning/GNNTrainingTest.java

package eu.dnetlib.jobs.deeplearning;

import eu.dnetlib.jobs.AbstractSparkJob;
import eu.dnetlib.support.ArgumentApplicationParser;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
import org.junit.jupiter.api.*;

import java.io.File;
import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Paths;

@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class GNNTrainingTest {

    static SparkSession spark;
    static JavaSparkContext context;

    final static String workingPath = "/tmp/working_dir";
    final static String numPartitions = "20";

    final String inputDataPath = Paths
            .get(getClass().getResource("/eu/dnetlib/jobs/examples/authors.groups.example.json").toURI())
            .toFile()
            .getAbsolutePath();

    final static String groundTruthJPath = "$.orcid";
    final static String idJPath = "$.id";
    final static String featuresJPath = "$.topics";
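
    // A hedged sketch of the input shape: given the JSONPath expressions above,
    // each record in authors.groups.example.json presumably looks something like
    // (illustrative only, not copied from the actual resource file):
    //
    //   {"id": "author::0001", "orcid": "0000-...", "topics": [0.1, 0.4, ...]}
    //
    // where "orcid" serves as the ground truth label, "id" identifies the record
    // and "topics" carries the feature vector.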

    public GNNTrainingTest() throws URISyntaxException {}

    public static void cleanup() throws IOException {
        // remove directories and clean the workspace
        FileUtils.deleteDirectory(new File(workingPath));
    }

    @BeforeAll
    public void setup() throws IOException {
        cleanup();

        // local, in-process Spark session shared by both test stages;
        // a non-static @BeforeAll is allowed here thanks to Lifecycle.PER_CLASS
        spark = SparkSession
                .builder()
                .appName("Testing")
                .master("local[*]")
                .getOrCreate();
        context = JavaSparkContext.fromSparkContext(spark.sparkContext());
    }

    @AfterAll
    public static void finalCleanUp() throws IOException {
        cleanup();
    }

    @Test
    @Order(1)
    public void createGroupDataSetTest() throws Exception {
        ArgumentApplicationParser parser = new ArgumentApplicationParser(
                readResource("/jobs/parameters/createGroupDataset_parameters.json", SparkCreateGroupDataSet.class));

        parser.parseArgument(
                new String[] {
                        "-i", inputDataPath,
                        "-gt", groundTruthJPath,
                        "-id", idJPath,
                        "-f", featuresJPath,
                        "-w", workingPath,
                        "-np", numPartitions
                });

        new SparkCreateGroupDataSet(parser, spark).run();
    }
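
    // The two tests form a small pipeline: createGroupDataSetTest materialises
    // the grouped data set under workingPath, and graphClassificationTrainingTest
    // (ordered second via @Order) reads it back from there, which is why it only
    // needs the -w and -np arguments.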

    @Test
    @Order(2)
    public void graphClassificationTrainingTest() throws Exception {
        ArgumentApplicationParser parser = new ArgumentApplicationParser(
                readResource("/jobs/parameters/graphClassificationTraining_parameters.json", SparkGraphClassificationTraining.class));

        parser.parseArgument(
                new String[] {
                        "-w", workingPath,
                        "-np", numPartitions
                });

        new SparkGraphClassificationTraining(parser, spark).run();
    }

    public static String readResource(String path, Class<? extends AbstractSparkJob> clazz) throws IOException {
        // read a classpath resource with an explicit charset
        // (the charset-less IOUtils.toString overload is deprecated)
        return IOUtils.toString(clazz.getResourceAsStream(path), StandardCharsets.UTF_8);
    }
}
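
// How one might run just this class, assuming a standard Maven Surefire setup
// and that dnet-and-test is a module of the dnet-and reactor (both inferred
// from the file path above, not confirmed by the source):
//
//   mvn -pl dnet-and-test test -Dtest=GNNTrainingTest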