103 lines
3.2 KiB
Java
103 lines
3.2 KiB
Java
package eu.dnetlib.jobs.deeplearning;
|
|
|
|
import eu.dnetlib.jobs.AbstractSparkJob;
import eu.dnetlib.support.ArgumentApplicationParser;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SparkSession;
import org.junit.jupiter.api.*;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Paths;
|
|
|
|
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
|
|
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
|
|
public class GNNTrainingTest {
|
|
|
|
static SparkSession spark;
|
|
static JavaSparkContext context;
|
|
final static String workingPath = "/tmp/working_dir";
|
|
|
|
final static String numPartitions = "20";
|
|
final String inputDataPath = Paths
|
|
.get(getClass().getResource("/eu/dnetlib/jobs/examples/authors.groups.example.json").toURI())
|
|
.toFile()
|
|
.getAbsolutePath();
|
|
final static String groundTruthJPath = "$.orcid";
|
|
final static String idJPath = "$.id";
|
|
final static String featuresJPath = "$.topics";
|
|
|
|
public GNNTrainingTest() throws URISyntaxException {}
|
|
|
|
public static void cleanup() throws IOException {
|
|
//remove directories and clean workspace
|
|
FileUtils.deleteDirectory(new File(workingPath));
|
|
}
|
|
|
|
@BeforeAll
|
|
public void setup() throws IOException {
|
|
cleanup();
|
|
|
|
spark = SparkSession
|
|
.builder()
|
|
.appName("Testing")
|
|
.master("local[*]")
|
|
.getOrCreate();
|
|
|
|
context = JavaSparkContext.fromSparkContext(spark.sparkContext());
|
|
}
|
|
|
|
@AfterAll
|
|
public static void finalCleanUp() throws IOException {
|
|
cleanup();
|
|
}
|
|
|
|
@Test
|
|
@Order(1)
|
|
public void createGroupDataSetTest() throws Exception {
|
|
ArgumentApplicationParser parser = new ArgumentApplicationParser(readResource("/jobs/parameters/createGroupDataset_parameters.json", SparkCreateGroupDataSet.class));
|
|
|
|
parser.parseArgument(
|
|
new String[] {
|
|
"-i", inputDataPath,
|
|
"-gt", groundTruthJPath,
|
|
"-id", idJPath,
|
|
"-f", featuresJPath,
|
|
"-w", workingPath,
|
|
"-np", numPartitions
|
|
}
|
|
);
|
|
|
|
new SparkCreateGroupDataSet(
|
|
parser,
|
|
spark
|
|
).run();
|
|
|
|
}
|
|
|
|
@Test
|
|
@Order(2)
|
|
public void graphClassificationTrainingTest() throws Exception{
|
|
ArgumentApplicationParser parser = new ArgumentApplicationParser(readResource("/jobs/parameters/graphClassificationTraining_parameters.json", SparkGraphClassificationTraining.class));
|
|
|
|
parser.parseArgument(
|
|
new String[] {
|
|
"-w", workingPath,
|
|
"-np", numPartitions
|
|
}
|
|
);
|
|
|
|
new SparkGraphClassificationTraining(
|
|
parser,
|
|
spark
|
|
).run();
|
|
}
|
|
|
|
public static String readResource(String path, Class<? extends AbstractSparkJob> clazz) throws IOException {
|
|
return IOUtils.toString(clazz.getResourceAsStream(path));
|
|
}
|
|
}
|