unit test for SparkGraphImporterJob
This commit is contained in:
parent
abcd3f5bf5
commit
43cbcda7ef
|
@ -19,6 +19,11 @@
|
||||||
<groupId>org.apache.spark</groupId>
|
<groupId>org.apache.spark</groupId>
|
||||||
<artifactId>spark-sql_2.11</artifactId>
|
<artifactId>spark-sql_2.11</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.apache.spark</groupId>
|
||||||
|
<artifactId>spark-hive_2.11</artifactId>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>eu.dnetlib.dhp</groupId>
|
<groupId>eu.dnetlib.dhp</groupId>
|
||||||
|
|
|
@ -18,29 +18,38 @@ public class SparkGraphImporterJob {
|
||||||
"/eu/dnetlib/dhp/graph/input_graph_parameters.json")));
|
"/eu/dnetlib/dhp/graph/input_graph_parameters.json")));
|
||||||
parser.parseArgument(args);
|
parser.parseArgument(args);
|
||||||
|
|
||||||
|
new SparkGraphImporterJob().run(parser);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void run(ArgumentApplicationParser parser) {
|
||||||
try(SparkSession spark = getSparkSession(parser)) {
|
try(SparkSession spark = getSparkSession(parser)) {
|
||||||
|
|
||||||
final JavaSparkContext sc = new JavaSparkContext(spark.sparkContext());
|
|
||||||
final String inputPath = parser.get("sourcePath");
|
final String inputPath = parser.get("sourcePath");
|
||||||
final String hiveDbName = parser.get("hive_db_name");
|
final String hiveDbName = parser.get("hive_db_name");
|
||||||
|
|
||||||
spark.sql(String.format("DROP DATABASE IF EXISTS %s CASCADE", hiveDbName));
|
runWith(spark, inputPath, hiveDbName);
|
||||||
spark.sql(String.format("CREATE DATABASE IF NOT EXISTS %s", hiveDbName));
|
|
||||||
|
|
||||||
// Read the input file and convert it into RDD of serializable object
|
|
||||||
GraphMappingUtils.types.forEach((name, clazz) -> spark.createDataset(sc.textFile(inputPath + "/" + name)
|
|
||||||
.map(s -> new ObjectMapper().readValue(s, clazz))
|
|
||||||
.rdd(), Encoders.bean(clazz))
|
|
||||||
.write()
|
|
||||||
.mode(SaveMode.Overwrite)
|
|
||||||
.saveAsTable(hiveDbName + "." + name));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// public for testing
|
||||||
|
public void runWith(SparkSession spark, String inputPath, String hiveDbName) {
|
||||||
|
|
||||||
|
spark.sql(String.format("DROP DATABASE IF EXISTS %s CASCADE", hiveDbName));
|
||||||
|
spark.sql(String.format("CREATE DATABASE IF NOT EXISTS %s", hiveDbName));
|
||||||
|
|
||||||
|
final JavaSparkContext sc = new JavaSparkContext(spark.sparkContext());
|
||||||
|
// Read the input file and convert it into RDD of serializable object
|
||||||
|
GraphMappingUtils.types.forEach((name, clazz) -> spark.createDataset(sc.textFile(inputPath + "/" + name)
|
||||||
|
.map(s -> new ObjectMapper().readValue(s, clazz))
|
||||||
|
.rdd(), Encoders.bean(clazz))
|
||||||
|
.write()
|
||||||
|
.mode(SaveMode.Overwrite)
|
||||||
|
.saveAsTable(hiveDbName + "." + name));
|
||||||
|
}
|
||||||
|
|
||||||
private static SparkSession getSparkSession(ArgumentApplicationParser parser) {
|
private static SparkSession getSparkSession(ArgumentApplicationParser parser) {
|
||||||
SparkConf conf = new SparkConf();
|
SparkConf conf = new SparkConf();
|
||||||
conf.set("hive.metastore.uris", parser.get("hive_metastore_uris"));
|
conf.set("hive.metastore.uris", parser.get("hive_metastore_uris"));
|
||||||
|
|
||||||
return SparkSession
|
return SparkSession
|
||||||
.builder()
|
.builder()
|
||||||
.appName(SparkGraphImporterJob.class.getSimpleName())
|
.appName(SparkGraphImporterJob.class.getSimpleName())
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[
|
[
|
||||||
{"paramName":"mt", "paramLongName":"master", "paramDescription": "should be local or yarn", "paramRequired": true},
|
{"paramName":"mt", "paramLongName":"master", "paramDescription": "should be local or yarn", "paramRequired": true},
|
||||||
{"paramName":"s", "paramLongName":"sourcePath", "paramDescription": "the path of the sequential file to read", "paramRequired": true},
|
{"paramName":"s", "paramLongName":"sourcePath", "paramDescription": "the path of the sequential file to read", "paramRequired": true},
|
||||||
{"paramName":"h", "paramLongName":"hive_metastore_uris","paramDescription": "the hive metastore uris", "paramRequired": true},
|
{"paramName":"h", "paramLongName":"hive_metastore_uris","paramDescription": "the hive metastore uris", "paramRequired": true},
|
||||||
{"paramName":"db", "paramLongName":"hive_db_name", "paramDescription": "the target hive database name", "paramRequired": true}
|
{"paramName":"db", "paramLongName":"hive_db_name", "paramDescription": "the target hive database name", "paramRequired": true}
|
||||||
]
|
]
|
|
@ -59,10 +59,10 @@
|
||||||
--conf spark.sql.queryExecutionListeners="com.cloudera.spark.lineage.NavigatorQueryListener"
|
--conf spark.sql.queryExecutionListeners="com.cloudera.spark.lineage.NavigatorQueryListener"
|
||||||
--conf spark.sql.warehouse.dir="/user/hive/warehouse"
|
--conf spark.sql.warehouse.dir="/user/hive/warehouse"
|
||||||
</spark-opts>
|
</spark-opts>
|
||||||
<arg>-mt</arg> <arg>yarn-cluster</arg>
|
<arg>-mt</arg> <arg>yarn</arg>
|
||||||
<arg>--sourcePath</arg><arg>${sourcePath}</arg>
|
<arg>-s</arg><arg>${sourcePath}</arg>
|
||||||
<arg>--hive_db_name</arg><arg>${hive_db_name}</arg>
|
<arg>-db</arg><arg>${hive_db_name}</arg>
|
||||||
<arg>--hive_metastore_uris</arg><arg>${hive_metastore_uris}</arg>
|
<arg>-h</arg><arg>${hive_metastore_uris}</arg>
|
||||||
</spark>
|
</spark>
|
||||||
<ok to="PostProcessing"/>
|
<ok to="PostProcessing"/>
|
||||||
<error to="Kill"/>
|
<error to="Kill"/>
|
||||||
|
|
|
@ -1,52 +1,54 @@
|
||||||
package eu.dnetlib.dhp.graph;
|
package eu.dnetlib.dhp.graph;
|
||||||
|
|
||||||
import org.apache.spark.api.java.JavaSparkContext;
|
import org.apache.spark.SparkConf;
|
||||||
import org.apache.spark.sql.Encoders;
|
|
||||||
import org.apache.spark.sql.SparkSession;
|
import org.apache.spark.sql.SparkSession;
|
||||||
import org.junit.jupiter.api.Assertions;
|
import org.junit.jupiter.api.Assertions;
|
||||||
import org.junit.jupiter.api.Disabled;
|
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.io.TempDir;
|
import org.junit.jupiter.api.io.TempDir;
|
||||||
import scala.Tuple2;
|
|
||||||
|
|
||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import java.util.List;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
|
|
||||||
public class SparkGraphImporterJobTest {
|
public class SparkGraphImporterJobTest {
|
||||||
|
|
||||||
private static final long MAX = 1000L;
|
private final static String TEST_DB_NAME = "test";
|
||||||
|
|
||||||
@Disabled("must be parametrized to run locally")
|
@Test
|
||||||
public void testImport(@TempDir Path outPath) throws Exception {
|
public void testImport(@TempDir Path outPath) {
|
||||||
SparkGraphImporterJob.main(new String[] {
|
try(SparkSession spark = testSparkSession(outPath.toString())) {
|
||||||
"-mt", "local[*]",
|
|
||||||
"-s", getClass().getResource("/eu/dnetlib/dhp/graph/sample").getPath(),
|
|
||||||
"-h", "",
|
|
||||||
"-db", "test"
|
|
||||||
});
|
|
||||||
|
|
||||||
countEntities(outPath.toString()).forEach(t -> {
|
new SparkGraphImporterJob().runWith(
|
||||||
System.out.println(t);
|
spark,
|
||||||
Assertions.assertEquals(MAX, t._2().longValue(), String.format("mapped %s must be %s", t._1(), MAX));
|
getClass().getResource("/eu/dnetlib/dhp/graph/sample").getPath(),
|
||||||
});
|
TEST_DB_NAME);
|
||||||
|
|
||||||
|
GraphMappingUtils.types.forEach((name, clazz) -> {
|
||||||
|
final long count = spark.read().table(TEST_DB_NAME + "." + name).count();
|
||||||
|
if (name.equals("relation")) {
|
||||||
|
Assertions.assertEquals(100, count, String.format("%s should be 100", name));
|
||||||
|
} else {
|
||||||
|
Assertions.assertEquals(10, count, String.format("%s should be 10", name));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static List<Tuple2<String, Long>> countEntities(final String inputPath) {
|
private SparkSession testSparkSession(final String inputPath) {
|
||||||
|
SparkConf conf = new SparkConf();
|
||||||
|
|
||||||
final SparkSession spark = SparkSession
|
conf.set("spark.driver.host", "localhost");
|
||||||
|
conf.set("hive.metastore.local", "true");
|
||||||
|
conf.set("hive.metastore.warehouse.dir", inputPath + "/warehouse");
|
||||||
|
conf.set("spark.sql.warehouse.dir", inputPath);
|
||||||
|
conf.set("javax.jdo.option.ConnectionURL", String.format("jdbc:derby:;databaseName=%s/junit_metastore_db;create=true", inputPath));
|
||||||
|
conf.set("spark.ui.enabled", "false");
|
||||||
|
|
||||||
|
return SparkSession
|
||||||
.builder()
|
.builder()
|
||||||
.appName(SparkGraphImporterJobTest.class.getSimpleName())
|
.appName(SparkGraphImporterJobTest.class.getSimpleName())
|
||||||
.master("local[*]")
|
.master("local[*]")
|
||||||
|
.config(conf)
|
||||||
|
.enableHiveSupport()
|
||||||
.getOrCreate();
|
.getOrCreate();
|
||||||
//final JavaSparkContext sc = new JavaSparkContext(spark.sparkContext());
|
|
||||||
|
|
||||||
return GraphMappingUtils.types.entrySet()
|
|
||||||
.stream()
|
|
||||||
.map(entry -> {
|
|
||||||
final Long count = spark.read().load(inputPath + "/" + entry.getKey()).as(Encoders.bean(entry.getValue())).count();
|
|
||||||
return new Tuple2<String, Long>(entry.getKey(), count);
|
|
||||||
})
|
|
||||||
.collect(Collectors.toList());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Binary file not shown.
6
pom.xml
6
pom.xml
|
@ -143,6 +143,12 @@
|
||||||
<version>${dhp.spark.version}</version>
|
<version>${dhp.spark.version}</version>
|
||||||
<scope>provided</scope>
|
<scope>provided</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.apache.spark</groupId>
|
||||||
|
<artifactId>spark-hive_2.11</artifactId>
|
||||||
|
<version>${dhp.spark.version}</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.slf4j</groupId>
|
<groupId>org.slf4j</groupId>
|
||||||
|
|
Loading…
Reference in New Issue