package eu.dnetlib.dhp.common;

import java.util.Objects;

import org.apache.spark.api.java.function.FilterFunction;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.fasterxml.jackson.databind.ObjectMapper;

import eu.dnetlib.dhp.schema.common.ModelSupport;
import eu.dnetlib.dhp.schema.oaf.Oaf;

/**
 * Utility methods to read, save and delete graph tables, serialized either as gzip-compressed JSON text files or as
 * Hive tables, depending on the given {@link GraphFormat}.
 */
public class GraphSupport {

	private static final Logger log = LoggerFactory.getLogger(GraphSupport.class);

	private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

	/**
	 * Deletes the graph table corresponding to the given entity class: removes the output directory when the graph
	 * is serialized as JSON, drops (and purges) the Hive table otherwise.
	 */
	public static <T extends Oaf> void deleteGraphTable(SparkSession spark, Class<T> clazz, String outputGraph,
		GraphFormat graphFormat) {
		switch (graphFormat) {
			case JSON:
				String outPath = outputGraph + "/" + clazz.getSimpleName().toLowerCase();
				removeOutputDir(spark, outPath);
				break;
			case HIVE:
				String table = ModelSupport.tableIdentifier(outputGraph, clazz);
				String sql = String.format("DROP TABLE IF EXISTS %s PURGE", table);
				log.info("running SQL: '{}'", sql);
				spark.sql(sql);
				break;
		}
	}

	/**
	 * Saves the given dataset, either as a gzip-compressed JSON text file or as a Hive table, overwriting any
	 * existing content.
	 */
	public static <T extends Oaf> void saveGraphTable(Dataset<T> dataset, Class<T> clazz, String outputGraph,
		GraphFormat graphFormat) {
		final DataFrameWriter<T> writer = dataset.write().mode(SaveMode.Overwrite);
		switch (graphFormat) {
			case JSON:
				String type = clazz.getSimpleName().toLowerCase();
				String outPath = outputGraph + "/" + type;
				log.info("saving graph to path {}", outPath);
				writer.option("compression", "gzip").json(outPath);
				break;
			case HIVE:
				final String dbTable = ModelSupport.tableIdentifier(outputGraph, clazz);
				log.info("saving graph to '{}'", dbTable);
				writer.saveAsTable(dbTable);
				break;
		}
	}

	/**
	 * Reads the graph table corresponding to the given entity class, parsing the JSON text files or loading the Hive
	 * table, and returns it as a typed {@link Dataset}. When reading from JSON, records without an identifier are
	 * filtered out.
	 */
	public static <T extends Oaf> Dataset<T> readGraph(SparkSession spark, String graph, Class<T> clazz,
		GraphFormat format) {
		log.info("reading graph {}, format {}, class {}", graph, format, clazz);
		Encoder<T> encoder = Encoders.bean(clazz);
		switch (format) {
			case JSON:
				String path = graph + "/" + clazz.getSimpleName().toLowerCase();
				log.info("reading path {}", path);
				return spark
					.read()
					.textFile(path)
					.map((MapFunction<String, T>) value -> OBJECT_MAPPER.readValue(value, clazz), encoder)
					.filter((FilterFunction<T>) value -> Objects.nonNull(ModelSupport.idFn().apply(value)));
			case HIVE:
				String table = ModelSupport.tableIdentifier(graph, clazz);
				log.info("reading table {}", table);
				return spark.table(table).as(encoder);
			default:
				throw new IllegalStateException(String.format("format not managed: '%s'", format));
		}
	}

	/** Shorthand for {@link #readGraph} using the JSON graph format. */
	public static <T extends Oaf> Dataset<T> readGraphJSON(SparkSession spark, String graph, Class<T> clazz) {
		return readGraph(spark, graph, clazz, GraphFormat.JSON);
	}

	private static void removeOutputDir(SparkSession spark, String path) {
		HdfsSupport.remove(path, spark.sparkContext().hadoopConfiguration());
	}
}
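/*
 * Minimal usage sketch (illustrative only, not part of the original class). It assumes a graph stored under the
 * hypothetical path /data/graph/raw, a target Hive database named "graph_db", and the Publication entity from
 * eu.dnetlib.dhp.schema.oaf; adapt paths, database name and entity type to the actual deployment.
 */
class GraphSupportUsageExample {

	public static void main(String[] args) {
		// assumes Hive support is available on the classpath and configured for the cluster
		SparkSession spark = SparkSession
			.builder()
			.appName("GraphSupportUsageExample")
			.enableHiveSupport()
			.getOrCreate();

		// read the publications serialized as gzipped JSON text files under <graph>/publication
		Dataset<eu.dnetlib.dhp.schema.oaf.Publication> publications = GraphSupport
			.readGraph(spark, "/data/graph/raw", eu.dnetlib.dhp.schema.oaf.Publication.class, GraphFormat.JSON);

		// drop any pre-existing Hive table for the entity, then save the same records into the graph_db database
		GraphSupport
			.deleteGraphTable(spark, eu.dnetlib.dhp.schema.oaf.Publication.class, "graph_db", GraphFormat.HIVE);
		GraphSupport
			.saveGraphTable(publications, eu.dnetlib.dhp.schema.oaf.Publication.class, "graph_db", GraphFormat.HIVE);
	}
}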