package eu.dnetlib.jobs;

import eu.dnetlib.featureextraction.FeatureTransformer;
import eu.dnetlib.support.ArgumentApplicationParser;
import org.apache.spark.SparkConf;
import org.apache.spark.ml.feature.CountVectorizerModel;
import org.apache.spark.sql.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.Optional;

public class SparkCreateVocabulary extends AbstractSparkJob {

    final static int VOCAB_SIZE = 1 << 18;
    final static double MIN_DF = 0.1;
    final static double MIN_TF = 1;

    private static final Logger log = LoggerFactory.getLogger(SparkCreateVocabulary.class);

    public SparkCreateVocabulary(ArgumentApplicationParser parser, SparkSession spark) {
        super(parser, spark);
    }

    public static void main(String[] args) throws Exception {

        ArgumentApplicationParser parser = new ArgumentApplicationParser(
                readResource("/jobs/parameters/createVocabulary_parameters.json", SparkCreateVocabulary.class)
        );

        parser.parseArgument(args);

        SparkConf conf = new SparkConf();

        new SparkCreateVocabulary(
                parser,
                getSparkSession(conf)
        ).run();
    }

    @Override
    public void run() throws IOException {

        // read oozie parameters
        final String workingPath = parser.get("workingPath");
        final String vocabularyPath = parser.get("vocabularyPath");
        final String vocabularyType = parser.get("vocabularyType"); // "file" or from tokens
        final double minDF = Optional
                .ofNullable(parser.get("minDF"))
                .map(Double::valueOf)
                .orElse(MIN_DF);
        final double minTF = Optional
                .ofNullable(parser.get("minTF"))
                .map(Double::valueOf)
                .orElse(MIN_TF);
        final int numPartitions = Optional
                .ofNullable(parser.get("numPartitions"))
                .map(Integer::valueOf)
                .orElse(NUM_PARTITIONS);
        final int vocabSize = Optional
                .ofNullable(parser.get("vocabSize"))
                .map(Integer::valueOf)
                .orElse(VOCAB_SIZE);

        log.info("workingPath:    '{}'", workingPath);
        log.info("vocabularyPath: '{}'", vocabularyPath);
        log.info("vocabularyType: '{}'", vocabularyType);
        log.info("minDF:          '{}'", minDF);
        log.info("minTF:          '{}'", minTF);
        log.info("vocabSize:      '{}'", vocabSize);

        // tokenized documents produced by the previous step of the workflow
        Dataset<Row> inputTokensDS = spark.read().load(workingPath + "/tokens").repartition(numPartitions);

        // build the vocabulary either from a predefined file or by fitting it on the input tokens
        CountVectorizerModel vocabulary;
        if (vocabularyType.equals("file")) {
            vocabulary = FeatureTransformer.createVocabularyFromFile();
        }
        else {
            vocabulary = FeatureTransformer.createVocabularyFromTokens(inputTokensDS, minDF, minTF, vocabSize);
        }

        // persist the fitted CountVectorizerModel for downstream feature extraction jobs
        vocabulary.write().overwrite().save(vocabularyPath);
    }
}