diff --git a/dhp-workflows/dhp-propagation/src/main/java/eu/dnetlib/dhp/countrypropagation/SparkCountryPropagationJob2.java b/dhp-workflows/dhp-propagation/src/main/java/eu/dnetlib/dhp/countrypropagation/SparkCountryPropagationJob2.java
index 967c940b5..b4a415bd7 100644
--- a/dhp-workflows/dhp-propagation/src/main/java/eu/dnetlib/dhp/countrypropagation/SparkCountryPropagationJob2.java
+++ b/dhp-workflows/dhp-propagation/src/main/java/eu/dnetlib/dhp/countrypropagation/SparkCountryPropagationJob2.java
@@ -4,6 +4,9 @@ import com.fasterxml.jackson.databind.ObjectMapper;
 import eu.dnetlib.dhp.application.ArgumentApplicationParser;
 import eu.dnetlib.dhp.schema.oaf.*;
 import org.apache.commons.io.IOUtils;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.io.compress.CompressionCodec;
+import org.apache.hadoop.io.compress.GzipCodec;
 import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
@@ -18,6 +21,9 @@ import java.util.*;
 import static eu.dnetlib.dhp.PropagationConstant.*;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+
+import static eu.dnetlib.dhp.common.SparkSessionSupport.runWithSparkHiveSession;
+
 public class SparkCountryPropagationJob2 {
 
     private static final Logger log = LoggerFactory.getLogger(SparkCountryPropagationJob2.class);
@@ -35,30 +41,54 @@ public class SparkCountryPropagationJob2 {
         parser.parseArgument(args);
 
+        Boolean isSparkSessionManaged = isSparkSessionManaged(parser);
+        log.info("isSparkSessionManaged: {}", isSparkSessionManaged);
+
         String inputPath = parser.get("sourcePath");
         log.info("inputPath: {}", inputPath);
 
-        final String outputPath = "/tmp/provision/propagation/countrytoresultfrominstitutionalrepositories";
-        final String datasourcecountrypath = outputPath + "/prepared_datasource_country";
-        final String resultClassName = parser.get("resultTableName");
+        final String outputPath = parser.get("outputPath");
+        log.info("outputPath: {}", outputPath);
 
-        final String resultType = resultClassName.substring(resultClassName.lastIndexOf(".")+1);
+        final String datasourcecountrypath = parser.get("preparedInfoPath");
+        log.info("preparedInfoPath: {}", datasourcecountrypath);
+
+        final String resultClassName = parser.get("resultTableName");
+        log.info("resultTableName: {}", resultClassName);
+
+        final String resultType = resultClassName.substring(resultClassName.lastIndexOf(".") + 1).toLowerCase();
+        log.info("resultType: {}", resultType);
+
+        final Boolean writeUpdates = Optional
+                .ofNullable(parser.get("writeUpdate"))
+                .map(Boolean::valueOf)
+                .orElse(Boolean.TRUE);
+        log.info("writeUpdate: {}", writeUpdates);
+
+        final Boolean saveGraph = Optional
+                .ofNullable(parser.get("saveGraph"))
+                .map(Boolean::valueOf)
+                .orElse(Boolean.TRUE);
+        log.info("saveGraph: {}", saveGraph);
 
         Class<? extends Result> resultClazz = (Class<? extends Result>) Class.forName(resultClassName);
 
         SparkConf conf = new SparkConf();
         conf.set("hive.metastore.uris", parser.get("hive_metastore_uris"));
-        final SparkSession spark = SparkSession
-                .builder()
-                .appName(SparkCountryPropagationJob2.class.getSimpleName())
-                .master(parser.get("master"))
-                .config(conf)
-                .enableHiveSupport()
-                .getOrCreate();
 
-        final boolean writeUpdates = TRUE.equals(parser.get("writeUpdate"));
-        final boolean saveGraph = TRUE.equals(parser.get("saveGraph"));
+        runWithSparkHiveSession(conf, isSparkSessionManaged,
+                spark -> {
+                    //createOutputDirs(outputPath, FileSystem.get(spark.sparkContext().hadoopConfiguration()));
+                    removeOutputDir(spark, outputPath);
+                    execPropagation(spark, datasourcecountrypath, inputPath, outputPath, resultClazz, resultType,
+                            writeUpdates, saveGraph);
+                });
+    }
+
+    private static <R extends Result> void execPropagation(SparkSession spark, String datasourcecountrypath,
+                                                           String inputPath, String outputPath, Class<R> resultClazz, String resultType,
+                                                           boolean writeUpdates, boolean saveGraph){
 
         final JavaSparkContext sc = new JavaSparkContext(spark.sparkContext());
 
         //Load parquet file with preprocessed association datasource - country
@@ -74,10 +104,9 @@ public class SparkCountryPropagationJob2 {
         }
 
         if(saveGraph){
-            updateResultTable(spark, potentialUpdates, inputPath, resultClazz, outputPath + "/" + resultType);
+            updateResultTable(spark, potentialUpdates, inputPath, resultClazz, outputPath);
         }
-
     }
 
     private static <R extends Result> void updateResultTable(SparkSession spark, Dataset<ResultCountrySet> potentialUpdates,
@@ -96,33 +125,32 @@ public class SparkCountryPropagationJob2 {
                 Encoders.tuple(Encoders.STRING(), Encoders.bean(ResultCountrySet.class)));
 
         Dataset<R> new_table = result_pair
-                .joinWith(potential_update_pair, result_pair.col("_1").equalTo(potential_update_pair.col("_1")), "left")
-                .map((MapFunction<Tuple2<Tuple2<String, R>, Tuple2<String, ResultCountrySet>>, R>) value -> {
+                .joinWith(potentialUpdates, result_pair.col("_1").equalTo(potentialUpdates.col("resultId")), "left_outer")
+                .map((MapFunction<Tuple2<Tuple2<String, R>, ResultCountrySet>, R>) value -> {
                     R r = value._1()._2();
-                    Optional<ResultCountrySet> potentialNewCountries = Optional.ofNullable(value._2()).map(Tuple2::_2);
+                    Optional<ResultCountrySet> potentialNewCountries = Optional.ofNullable(value._2());
                     if (potentialNewCountries.isPresent()) {
                         HashSet<String> countries = new HashSet<>();
                         for (Qualifier country : r.getCountry()) {
                             countries.add(country.getClassid());
                         }
-                        for (Country country : potentialNewCountries.get().getCountrySet()) {
+                        for (CountrySbs country : potentialNewCountries.get().getCountrySet()) {
                             if (!countries.contains(country.getClassid())) {
-                                r.getCountry().add(getCountry(country.getClassid(),country.getClassname()));
+                                r.getCountry().add(getCountry(country.getClassid(), country.getClassname()));
                             }
                         }
                     }
                     return r;
-
                 }, Encoders.bean(resultClazz));
 
         log.info("Saving graph table to path: {}", outputPath);
+        //log.info("number of saved records: {}", new_table.count());
         new_table
-                .toJSON()
-                .write()
-                .option("compression", "gzip")
-                .text(outputPath);
+                .toJavaRDD()
+                .map(r -> OBJECT_MAPPER.writeValueAsString(r))
+                .saveAsTextFile(outputPath, GzipCodec.class);
+
     }
@@ -133,6 +161,7 @@ public class SparkCountryPropagationJob2 {
         Dataset<R> result = readPathEntity(spark, inputPath, resultClazz);
 
         result.createOrReplaceTempView("result");
+        //log.info("number of results: {}", result.count());
 
         createCfHbforresult(spark);
         return countryPropagationAssoc(spark, broadcast_datasourcecountryassoc);
     }
@@ -147,6 +176,7 @@ public class SparkCountryPropagationJob2 {
                 "LATERAL VIEW EXPLODE(instance) i AS inst";
         Dataset<Row> cfhb = spark.sql(query);
         cfhb.createOrReplaceTempView("cfhb");
+        log.info("cfhb_number : {}", cfhb.count());
     }
@@ -155,19 +185,22 @@ public class SparkCountryPropagationJob2 {
         Dataset<DatasourceCountry> datasource_country = broadcast_datasourcecountryassoc.value();
         datasource_country.createOrReplaceTempView("datasource_country");
+        log.info("datasource_country number : {}", datasource_country.count());
 
-        String query = "SELECT id, collect_set(country) country "+
+        String query = "SELECT id resultId, collect_set(country) countrySet "+
                 "FROM ( SELECT id, country " +
-                "FROM rels " +
+                "FROM datasource_country " +
                 "JOIN cfhb " +
-                " ON cf = ds " +
+                " ON cf = dataSourceId " +
                 "UNION ALL " +
                 "SELECT id , country " +
-                "FROM rels " +
+                "FROM datasource_country " +
                 "JOIN cfhb " +
-                " ON hb = ds ) tmp " +
+                " ON hb = dataSourceId ) tmp " +
                 "GROUP BY id";
 
-        return spark.sql(query);
+        Dataset<Row> potentialUpdates = spark.sql(query);
+        log.info("potential update number : {}", potentialUpdates.count());
+        return potentialUpdates;
     }
 
     private static <R extends Result> Dataset<R> readPathEntity(SparkSession spark, String inputEntityPath, Class<R> resultClazz) {
@@ -180,14 +213,15 @@ public class SparkCountryPropagationJob2 {
     }
 
     private static Dataset<DatasourceCountry> readAssocDatasourceCountry(SparkSession spark, String relationPath) {
-        return spark.read()
-                .load(relationPath)
-                .as(Encoders.bean(DatasourceCountry.class));
+        return spark
+                .read()
+                .textFile(relationPath)
+                .map(value -> OBJECT_MAPPER.readValue(value, DatasourceCountry.class), Encoders.bean(DatasourceCountry.class));
     }
 
     private static void writeUpdates(JavaRDD<ResultCountrySet> potentialUpdates, String outputPath){
         potentialUpdates.map(u -> OBJECT_MAPPER.writeValueAsString(u))
-                .saveAsTextFile(outputPath);
+                .saveAsTextFile(outputPath, GzipCodec.class);
     }