package eu.dnetlib.featureextraction;

import com.google.common.collect.Lists;
import com.johnsnowlabs.nlp.*;
import com.johnsnowlabs.nlp.annotators.sbd.pragmatic.SentenceDetector;
import com.johnsnowlabs.nlp.embeddings.BertSentenceEmbeddings;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.clustering.LDA;
import org.apache.spark.ml.clustering.LDAModel;
import org.apache.spark.ml.feature.CountVectorizer;
import org.apache.spark.ml.feature.CountVectorizerModel;
import org.apache.spark.ml.feature.StopWordsRemover;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.Tuple2;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Serializable;
import java.net.URISyntaxException;
import java.nio.file.Paths;
import java.util.*;

public class FeatureTransformer implements Serializable {

    public static String ID_COL = "id";
    public static String TOKENIZER_INPUT_COL = "sentence";
    public static String TOKENIZER_OUTPUT_COL = "rawTokens";
    public static String STOPWORDREMOVER_OUTPUT_COL = "tokens";
    public static String COUNTVECTORIZER_OUTPUT_COL = "features";
    public static String LDA_INFERENCE_OUTPUT_COL = "topicDistribution";
    public static String LDA_OPTIMIZER = "online";

    /**
     * Tokenizes the data and removes stopwords.
     *
     * @param inputDS: the input dataset of the form (id, sentence)
     * @return the tokenized data (id, tokens)
     */
    public static Dataset<Row> tokenizeData(Dataset<Row> inputDS) {
        Tokenizer tokenizer = new Tokenizer()
                .setInputCol(TOKENIZER_INPUT_COL)
                .setOutputCol(TOKENIZER_OUTPUT_COL);
        StopWordsRemover remover = new StopWordsRemover()
                .setInputCol(TOKENIZER_OUTPUT_COL)
                .setOutputCol(STOPWORDREMOVER_OUTPUT_COL);
        //TODO consider implementing stemming with SparkNLP library from johnsnowlab

        Dataset<Row> rawTokensDS = tokenizer.transform(inputDS).select(ID_COL, TOKENIZER_OUTPUT_COL);
        return remover.transform(rawTokensDS).select(ID_COL, STOPWORDREMOVER_OUTPUT_COL);
    }

    /**
     * Creates the vocabulary from the given data.
     *
     * @param inputDS: the input dataset of the form (id, tokens)
     * @param minDF: minimum number of different documents a term must appear in to be included in the vocabulary
     * @param minTF: filter to ignore rare terms within a document
     * @param vocabSize: maximum size of the vocabulary (number of terms)
     * @return the vocabulary
     */
    public static CountVectorizerModel createVocabularyFromTokens(Dataset<Row> inputDS, double minDF, double minTF, int vocabSize) {
        return new CountVectorizer()
                .setInputCol(STOPWORDREMOVER_OUTPUT_COL)
                .setOutputCol(COUNTVECTORIZER_OUTPUT_COL)
                .setMinDF(minDF)
                .setMinTF(minTF)
                .setVocabSize(vocabSize)
                .fit(inputDS);
        //TODO setMaxDF not found, try to add it
    }
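    /**
     * Illustrative sketch, not part of the original API: shows how the methods in this class
     * can be chained to go from raw sentences to count-vectorized features. The parameter
     * values (minDF, minTF, vocabulary size) are assumptions chosen only for demonstration.
     *
     * @param sentencesDS: the input dataset of the form (id, sentence)
     * @return the count vectorized data (id, features)
     */
    public static Dataset<Row> exampleSentencesToFeatures(Dataset<Row> sentencesDS) {
        // (id, sentence) -> (id, tokens)
        Dataset<Row> tokensDS = tokenizeData(sentencesDS);
        // build the vocabulary from the tokens themselves (example parameter values)
        CountVectorizerModel vocabulary = createVocabularyFromTokens(tokensDS, 2.0, 1.0, 1 << 18);
        // (id, tokens) -> (id, features)
        return countVectorizeData(tokensDS, vocabulary);
    }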
    /**
     * Creates the vocabulary from the resource file.
     *
     * @return the vocabulary
     */
    public static CountVectorizerModel createVocabularyFromFile() throws IOException {
        Set<String> fileLines = new HashSet<>();
        BufferedReader bf = new BufferedReader(
                new InputStreamReader(
                        FeatureTransformer.class.getResourceAsStream("/eu/dnetlib/featureextraction/support/dewey_vocabulary.txt")
                )
        );

        String line = bf.readLine();
        while(line != null) {
            fileLines.add(line);
            line = bf.readLine();
        }
        bf.close();

        return new CountVectorizerModel(fileLines.toArray(new String[0]))
                .setInputCol(STOPWORDREMOVER_OUTPUT_COL)
                .setOutputCol(COUNTVECTORIZER_OUTPUT_COL);
    }

    /**
     * Loads an existing vocabulary.
     *
     * @param vocabularyPath: location of the vocabulary
     * @return the vocabulary
     */
    public static CountVectorizerModel loadVocabulary(String vocabularyPath) {
        return CountVectorizerModel.load(vocabularyPath);
    }

    /**
     * Count vectorizes the data.
     *
     * @param inputDS: the input dataset of the form (id, tokens)
     * @param vocabulary: the vocabulary to be used for the transformation
     * @return the count vectorized data
     */
    public static Dataset<Row> countVectorizeData(Dataset<Row> inputDS, CountVectorizerModel vocabulary) {
        return vocabulary.transform(inputDS).select(ID_COL, COUNTVECTORIZER_OUTPUT_COL);
    }

    /**
     * Trains an LDA model with the given parameters.
     *
     * @param inputDS: the input dataset of the form (id, features)
     * @param k: number of topics
     * @param maxIter: maximum number of iterations
     * @return the LDA model
     */
    public static LDAModel trainLDAModel(Dataset<Row> inputDS, int k, int maxIter) {
        LDA lda = new LDA()
                .setK(k)
                .setMaxIter(maxIter)
                .setFeaturesCol(COUNTVECTORIZER_OUTPUT_COL)
                .setOptimizer(LDA_OPTIMIZER);
        return lda.fit(inputDS);
    }

    /**
     * Tunes the LDA model by varying the number of topics.
     *
     * @param dataDS: the input data in the form (id, features)
     * @param trainRatio: fraction of the input data to be used as the training set
     * @param numTopics: the numbers of topics with which to test the LDA model
     * @param maxIter: maximum number of iterations of the algorithm
     * @return map from number of topics to the trained model and its perplexity
     */
    public static Map<Integer, Tuple2<LDAModel, Double>> ldaTuning(Dataset<Row> dataDS, double trainRatio, int[] numTopics, int maxIter) {
        Dataset<Row>[] setsDS = dataDS.randomSplit(new double[]{trainRatio, 1 - trainRatio});
        Dataset<Row> trainDS = setsDS[0];
        Dataset<Row> testDS = setsDS[1];

        Map<Integer, Tuple2<LDAModel, Double>> ldaModels = new HashMap<>();
        for(int k: numTopics) {
            LDAModel ldaModel = trainLDAModel(trainDS, k, maxIter);
            double perplexity = ldaModel.logPerplexity(testDS);
            ldaModels.put(k, new Tuple2<>(ldaModel, perplexity));
        }
        return ldaModels;
    }

    /**
     * Generates the LDA topic inference for the given data.
     *
     * @param inputDS: input data in the form (id, features)
     * @param ldaModel: the LDA model to be used for the inference
     * @return the LDA inference of the input data in the form (id, topicDistribution)
     */
    public static Dataset<Row> ldaInference(Dataset<Row> inputDS, LDAModel ldaModel) {
        return ldaModel.transform(inputDS).select(ID_COL, LDA_INFERENCE_OUTPUT_COL);
    }
}
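/**
 * Usage sketch (illustrative, not part of the original code): builds a tiny in-memory
 * dataset and runs the pipeline above end to end, from raw sentences to LDA topic
 * inference. The app name, master, example sentences and parameter values are assumptions
 * chosen only for demonstration.
 */
class FeatureTransformerExample {

    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
                .appName("FeatureTransformerExample")
                .master("local[*]")
                .getOrCreate();

        // (id, sentence) schema expected by FeatureTransformer.tokenizeData
        StructType schema = new StructType(new StructField[]{
                DataTypes.createStructField(FeatureTransformer.ID_COL, DataTypes.StringType, false),
                DataTypes.createStructField(FeatureTransformer.TOKENIZER_INPUT_COL, DataTypes.StringType, false)
        });
        List<Row> rows = Arrays.asList(
                RowFactory.create("doc1", "topic models describe documents as mixtures of latent topics"),
                RowFactory.create("doc2", "latent dirichlet allocation is a generative probabilistic topic model"),
                RowFactory.create("doc3", "count vectorization turns token lists into sparse term frequency vectors")
        );
        Dataset<Row> sentencesDS = spark.createDataFrame(rows, schema);

        // (id, sentence) -> (id, tokens) -> (id, features)
        Dataset<Row> tokensDS = FeatureTransformer.tokenizeData(sentencesDS);
        CountVectorizerModel vocabulary = FeatureTransformer.createVocabularyFromTokens(tokensDS, 1.0, 1.0, 1 << 10);
        Dataset<Row> featuresDS = FeatureTransformer.countVectorizeData(tokensDS, vocabulary);

        // train a small LDA model and print the topic distribution of each document
        LDAModel ldaModel = FeatureTransformer.trainLDAModel(featuresDS, 2, 10);
        FeatureTransformer.ldaInference(featuresDS, ldaModel).show(false);

        spark.stop();
    }
}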