package eu.dnetlib.jobs;

import eu.dnetlib.featureextraction.FeatureTransformer;
import eu.dnetlib.support.ArgumentApplicationParser;
import org.apache.spark.SparkConf;
import org.apache.spark.ml.feature.CountVectorizerModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.SparkSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.Optional;

/**
 * Spark job that applies a previously trained CountVectorizer vocabulary
 * to tokenized input data, producing term-frequency vectors.
 */
public class SparkCountVectorizer extends AbstractSparkJob {

    private static final Logger log = LoggerFactory.getLogger(SparkCountVectorizer.class);

    public SparkCountVectorizer(ArgumentApplicationParser parser, SparkSession spark) {
        super(parser, spark);
    }

    public static void main(String[] args) throws Exception {
        ArgumentApplicationParser parser = new ArgumentApplicationParser(
                readResource("/jobs/parameters/countVectorizer_parameters.json", SparkCountVectorizer.class)
        );

        parser.parseArgument(args);

        SparkConf conf = new SparkConf();

        new SparkCountVectorizer(
                parser,
                getSparkSession(conf)
        ).run();
    }

    @Override
    public void run() throws IOException {

        // read oozie parameters
        final String workingPath = parser.get("workingPath");
        final String vocabularyPath = parser.get("vocabularyPath");
        final int numPartitions = Optional
                .ofNullable(parser.get("numPartitions"))
                .map(Integer::valueOf)
                .orElse(NUM_PARTITIONS);

        log.info("workingPath:    '{}'", workingPath);
        log.info("vocabularyPath: '{}'", vocabularyPath);
        log.info("numPartitions:  '{}'", numPartitions);

        // read input tokens
        Dataset<Row> inputTokensDS = spark.read().load(workingPath + "/tokens").repartition(numPartitions);

        // read vocabulary
        CountVectorizerModel vocabulary = FeatureTransformer.loadVocabulary(vocabularyPath);

        // map tokens to term-frequency vectors using the loaded vocabulary
        Dataset<Row> countVectorizedData = FeatureTransformer.countVectorizeData(inputTokensDS, vocabulary);

        countVectorizedData
                .write()
                .mode(SaveMode.Overwrite)
                .save(workingPath + "/countVectorized");
    }
}
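
// Illustrative invocation sketch (kept as a comment so the file stays valid Java).
// The real argument names are defined in /jobs/parameters/countVectorizer_parameters.json,
// which is not shown here; the long-option flags and the jar placeholder below are
// assumptions based on the parser keys used in run(), not confirmed CLI syntax.
//
//   spark-submit \
//     --class eu.dnetlib.jobs.SparkCountVectorizer \
//     <path-to-assembly-jar> \
//     --workingPath /tmp/working_dir \
//     --vocabularyPath /tmp/working_dir/vocabulary \
//     --numPartitions 1000
//
// The job then reads tokens from <workingPath>/tokens and writes vectors to
// <workingPath>/countVectorized, as implemented in run() above.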