142 lines
5.4 KiB
Java
142 lines
5.4 KiB
Java
package eu.dnetlib.featureextraction;
|
|
|
|
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Serializable;
import java.nio.charset.StandardCharsets;
import java.util.*;

import eu.dnetlib.featureextraction.util.Utilities;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.clustering.LDA;
import org.apache.spark.ml.clustering.LDAModel;
import org.apache.spark.ml.evaluation.Evaluator;
import org.apache.spark.ml.feature.*;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.ml.tuning.TrainValidationSplit;
import org.apache.spark.ml.tuning.TrainValidationSplitModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import scala.Tuple2;
|
|
|
|
public class FeatureTransformer implements Serializable {
|
|
|
|
public static String ID_COL = "id";
|
|
public static String TOKENIZER_INPUT_COL = "sentence";
|
|
public static String TOKENIZER_OUTPUT_COL = "rawTokens";
|
|
public static String STOPWORDREMOVER_OUTPUT_COL = "tokens";
|
|
public static String COUNTVECTORIZER_OUTPUT_COL = "features";
|
|
public static String LDA_OPTIMIZER = "online";
|
|
|
|
/**
|
|
* Returns the tokenization of the data without stopwords
|
|
*
|
|
* @param inputDS: the input dataset of the form (id, sentence)
|
|
* @return the tokenized data (id, tokens)
|
|
*/
|
|
public static Dataset<Row> tokenizeData(Dataset<Row> inputDS) {
|
|
Tokenizer tokenizer = new Tokenizer().setInputCol(TOKENIZER_INPUT_COL).setOutputCol(TOKENIZER_OUTPUT_COL);
|
|
StopWordsRemover remover = new StopWordsRemover().setInputCol(TOKENIZER_OUTPUT_COL).setOutputCol(STOPWORDREMOVER_OUTPUT_COL);
|
|
//TODO consider implementing stemming with SparkNLP library from johnsnowlab
|
|
Dataset<Row> rawTokensDS = tokenizer.transform(inputDS).select(ID_COL, TOKENIZER_OUTPUT_COL);
|
|
return remover.transform(rawTokensDS).select(ID_COL, STOPWORDREMOVER_OUTPUT_COL);
|
|
}
|
|
|
|
/**
|
|
* Create the vocabulary from the given data.
|
|
*
|
|
* @param inputDS: the input dataset of the form (id, tokens)
|
|
* @param minDF: minimum number of different documents a term could appear in to be included in the vocabulary
|
|
* @param minTF: filter to ignore rare words in a document
|
|
* @param vocabSize: maximum size of the vocabulary (number of terms)
|
|
* @return the vocabulary
|
|
*/
|
|
public static CountVectorizerModel createVocabularyFromTokens(Dataset<Row> inputDS, double minDF, double minTF, int vocabSize) {
|
|
return new CountVectorizer()
|
|
.setInputCol(STOPWORDREMOVER_OUTPUT_COL)
|
|
.setOutputCol(COUNTVECTORIZER_OUTPUT_COL)
|
|
.setMinDF(minDF)
|
|
.setMinTF(minTF)
|
|
.setVocabSize(vocabSize)
|
|
.fit(inputDS);
|
|
//TODO setMaxDF not found, try to add it
|
|
}
|
|
|
|
/**
|
|
* Create the vocabulary from file.
|
|
*
|
|
* @param inputFilePath: the input file with the vocabulary elements (one element for line)
|
|
* @return the vocabulary
|
|
*/
|
|
public static CountVectorizerModel createVocabularyFromFile(String inputFilePath) throws IOException {
|
|
Set<String> fileLines = new HashSet<>();
|
|
BufferedReader bf = new BufferedReader(new FileReader(inputFilePath));
|
|
String line = bf.readLine();
|
|
while(line != null) {
|
|
fileLines.add(line);
|
|
line = bf.readLine();
|
|
}
|
|
bf.close();
|
|
|
|
return new CountVectorizerModel(fileLines.toArray(new String[0])).setInputCol(STOPWORDREMOVER_OUTPUT_COL).setOutputCol(COUNTVECTORIZER_OUTPUT_COL);
|
|
|
|
}
|
|
|
|
/**
|
|
* Load an existing vocabulary
|
|
*
|
|
* @param vocabularyPath: location of the vocabulary
|
|
* @return the vocabulary
|
|
*/
|
|
public static CountVectorizerModel loadVocabulary(String vocabularyPath) {
|
|
return CountVectorizerModel.load(vocabularyPath);
|
|
}
|
|
|
|
/**
|
|
* Count vectorize data.
|
|
*
|
|
* @param inputDS: the input dataset of the form (id, tokens)
|
|
* @param vocabulary: the vocabulary to be used for the transformation
|
|
* @return the count vectorized data
|
|
*/
|
|
public static Dataset<Row> countVectorizeData(Dataset<Row> inputDS, CountVectorizerModel vocabulary) {
|
|
return vocabulary.transform(inputDS).select(ID_COL, COUNTVECTORIZER_OUTPUT_COL);
|
|
}
|
|
|
|
/**
|
|
* Train LDA model with the given parameters
|
|
*
|
|
* @param inputDS: the input dataset
|
|
* @param k: number of topics
|
|
* @param maxIter: maximum number of iterations
|
|
* @return the LDA model
|
|
*/
|
|
public static LDAModel trainLDAModel(Dataset<Row> inputDS, int k, int maxIter) {
|
|
|
|
LDA lda = new LDA()
|
|
.setK(k)
|
|
.setMaxIter(maxIter)
|
|
.setFeaturesCol(COUNTVECTORIZER_OUTPUT_COL)
|
|
.setOptimizer(LDA_OPTIMIZER);
|
|
|
|
return lda.fit(inputDS);
|
|
}
|
|
|
|
public static Map<Integer, Tuple2<LDAModel, Double>> ldaTuning(Dataset<Row> dataDS, double trainRatio, int[] numTopics, int maxIter) {
|
|
Dataset<Row>[] setsDS = dataDS.randomSplit(new double[]{trainRatio, 1 - trainRatio});
|
|
Dataset<Row> trainDS = setsDS[0];
|
|
Dataset<Row> testDS = setsDS[1];
|
|
Map<Integer, Tuple2<LDAModel, Double>> ldaModels = new HashMap<>();
|
|
|
|
for(int k: numTopics) {
|
|
LDAModel ldaModel = trainLDAModel(trainDS, k, maxIter);
|
|
double perplexity = ldaModel.logPerplexity(testDS);
|
|
ldaModels.put(k, new Tuple2<>(ldaModel, perplexity));
|
|
}
|
|
|
|
return ldaModels;
|
|
|
|
}
|
|
}
|