dnet-and/dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/FeatureTransformer.java

package eu.dnetlib.featureextraction;

import org.apache.spark.ml.clustering.LDA;
import org.apache.spark.ml.clustering.LDAModel;
import org.apache.spark.ml.feature.CountVectorizer;
import org.apache.spark.ml.feature.CountVectorizerModel;
import org.apache.spark.ml.feature.StopWordsRemover;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import scala.Tuple2;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Serializable;
import java.util.*;

public class FeatureTransformer implements Serializable {

    public static final String ID_COL = "id";
    public static final String TOKENIZER_INPUT_COL = "sentence";
    public static final String TOKENIZER_OUTPUT_COL = "rawTokens";
    public static final String STOPWORDREMOVER_OUTPUT_COL = "tokens";
    public static final String COUNTVECTORIZER_OUTPUT_COL = "features";
    public static final String LDA_INFERENCE_OUTPUT_COL = "topicDistribution";
    public static final String LDA_OPTIMIZER = "online";

    /**
     * Tokenizes the input data and removes stop words.
     *
     * @param inputDS: the input dataset of the form (id, sentence)
     * @return the tokenized data without stop words (id, tokens)
     */
    public static Dataset<Row> tokenizeData(Dataset<Row> inputDS) {
        Tokenizer tokenizer = new Tokenizer()
                .setInputCol(TOKENIZER_INPUT_COL)
                .setOutputCol(TOKENIZER_OUTPUT_COL);
        StopWordsRemover remover = new StopWordsRemover()
                .setInputCol(TOKENIZER_OUTPUT_COL)
                .setOutputCol(STOPWORDREMOVER_OUTPUT_COL);
        // TODO: consider implementing stemming with the Spark NLP library from John Snow Labs
        Dataset<Row> rawTokensDS = tokenizer.transform(inputDS).select(ID_COL, TOKENIZER_OUTPUT_COL);
        return remover.transform(rawTokensDS).select(ID_COL, STOPWORDREMOVER_OUTPUT_COL);
    }
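
    // Usage sketch (illustrative, not part of the API; assumes an existing
    // Dataset<Row> "inputDS" with the (id, sentence) columns expected above):
    //   Dataset<Row> tokensDS = FeatureTransformer.tokenizeData(inputDS);
    //   tokensDS.show(5, false);  // rows of the form (id, tokens)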

    /**
     * Create the vocabulary from the given data.
     *
     * @param inputDS: the input dataset of the form (id, tokens)
     * @param minDF: minimum number (or fraction, if < 1.0) of different documents a term must appear in to be included in the vocabulary
     * @param minTF: per-document filter ignoring terms whose frequency (or fraction, if < 1.0) is below the given threshold
     * @param vocabSize: maximum size of the vocabulary (number of terms)
     * @return the vocabulary
     */
    public static CountVectorizerModel createVocabularyFromTokens(Dataset<Row> inputDS, double minDF, double minTF, int vocabSize) {
        // TODO: setMaxDF not found; add it once the Spark version in use supports it
        return new CountVectorizer()
                .setInputCol(STOPWORDREMOVER_OUTPUT_COL)
                .setOutputCol(COUNTVECTORIZER_OUTPUT_COL)
                .setMinDF(minDF)
                .setMinTF(minTF)
                .setVocabSize(vocabSize)
                .fit(inputDS);
    }
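
    // Usage sketch (illustrative parameter values; "tokensDS" is assumed to be
    // the (id, tokens) output of tokenizeData):
    //   CountVectorizerModel vocabulary =
    //           FeatureTransformer.createVocabularyFromTokens(tokensDS, 2.0, 1.0, 10000);
    //   vocabulary.save("/path/to/vocabulary");  // persist for later use with loadVocabulary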

    /**
     * Create the vocabulary from the bundled resource file.
     *
     * @return the vocabulary
     */
    public static CountVectorizerModel createVocabularyFromFile() throws IOException {
        Set<String> fileLines = new HashSet<>();
        try (BufferedReader bf = new BufferedReader(
                new InputStreamReader(
                        FeatureTransformer.class.getResourceAsStream("/eu/dnetlib/featureextraction/support/dewey_vocabulary.txt")
                ))) {
            String line;
            while ((line = bf.readLine()) != null) {
                fileLines.add(line);
            }
        }
        return new CountVectorizerModel(fileLines.toArray(new String[0]))
                .setInputCol(STOPWORDREMOVER_OUTPUT_COL)
                .setOutputCol(COUNTVECTORIZER_OUTPUT_COL);
    }
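
    // Usage sketch (no fitting is needed here: the vocabulary is fixed by the
    // bundled resource file, so the model can be applied directly):
    //   CountVectorizerModel deweyVocabulary = FeatureTransformer.createVocabularyFromFile();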

    /**
     * Load an existing vocabulary.
     *
     * @param vocabularyPath: location of the vocabulary
     * @return the vocabulary
     */
    public static CountVectorizerModel loadVocabulary(String vocabularyPath) {
        return CountVectorizerModel.load(vocabularyPath);
    }

    /**
     * Count vectorize data.
     *
     * @param inputDS: the input dataset of the form (id, tokens)
     * @param vocabulary: the vocabulary to be used for the transformation
     * @return the count vectorized data
     */
    public static Dataset<Row> countVectorizeData(Dataset<Row> inputDS, CountVectorizerModel vocabulary) {
        return vocabulary.transform(inputDS).select(ID_COL, COUNTVECTORIZER_OUTPUT_COL);
    }
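
    // Usage sketch ("vocabulary" may come from createVocabularyFromTokens,
    // createVocabularyFromFile or loadVocabulary):
    //   Dataset<Row> featuresDS = FeatureTransformer.countVectorizeData(tokensDS, vocabulary);
    //   featuresDS.show(5, false);  // rows of the form (id, features) with sparse term counts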

    /**
     * Train an LDA model with the given parameters.
     *
     * @param inputDS: the input dataset of the form (id, features)
     * @param k: number of topics
     * @param maxIter: maximum number of iterations
     * @return the LDA model
     */
    public static LDAModel trainLDAModel(Dataset<Row> inputDS, int k, int maxIter) {
        LDA lda = new LDA()
                .setK(k)
                .setMaxIter(maxIter)
                .setFeaturesCol(COUNTVECTORIZER_OUTPUT_COL)
                .setOptimizer(LDA_OPTIMIZER);
        return lda.fit(inputDS);
    }
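
    // Usage sketch (illustrative values: 10 topics, 100 iterations; "featuresDS"
    // is assumed to be the (id, features) output of countVectorizeData):
    //   LDAModel ldaModel = FeatureTransformer.trainLDAModel(featuresDS, 10, 100);
    //   ldaModel.describeTopics(5).show(false);  // top-5 term indices and weights per topic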

    /**
     * Tune the LDA model by varying the number of topics.
     *
     * @param dataDS: the input data in the form (id, features)
     * @param trainRatio: fraction of the input data to be used as training set
     * @param numTopics: the topic counts for which to evaluate the LDA model
     * @param maxIter: maximum number of iterations of the algorithm
     * @return map from topic count to the trained model and its log-perplexity on the test set
     */
    public static Map<Integer, Tuple2<LDAModel, Double>> ldaTuning(Dataset<Row> dataDS, double trainRatio, int[] numTopics, int maxIter) {
        Dataset<Row>[] setsDS = dataDS.randomSplit(new double[]{trainRatio, 1 - trainRatio});
        Dataset<Row> trainDS = setsDS[0];
        Dataset<Row> testDS = setsDS[1];

        Map<Integer, Tuple2<LDAModel, Double>> ldaModels = new HashMap<>();
        for (int k : numTopics) {
            LDAModel ldaModel = trainLDAModel(trainDS, k, maxIter);
            double perplexity = ldaModel.logPerplexity(testDS);
            ldaModels.put(k, new Tuple2<>(ldaModel, perplexity));
        }
        return ldaModels;
    }
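
    // Usage sketch (illustrative: 80/20 split, candidate topic counts 5, 10, 20;
    // lower log-perplexity on the held-out split indicates a better fit):
    //   Map<Integer, Tuple2<LDAModel, Double>> models =
    //           FeatureTransformer.ldaTuning(featuresDS, 0.8, new int[]{5, 10, 20}, 100);
    //   int bestK = Collections.min(models.entrySet(),
    //           Comparator.comparingDouble(e -> e.getValue()._2())).getKey();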

    /**
     * Generate the LDA topic inference of the given data.
     *
     * @param inputDS: input data in the form (id, features)
     * @param ldaModel: the LDA model to be used for the inference
     * @return the LDA inference of the input data in the form (id, topicDistribution)
     */
    public static Dataset<Row> ldaInference(Dataset<Row> inputDS, LDAModel ldaModel) {
        return ldaModel.transform(inputDS).select(ID_COL, LDA_INFERENCE_OUTPUT_COL);
    }
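
    // Usage sketch (end of the pipeline; each document id is paired with its
    // per-topic probability vector):
    //   Dataset<Row> topicsDS = FeatureTransformer.ldaInference(featuresDS, ldaModel);
    //   topicsDS.show(5, false);  // rows of the form (id, topicDistribution)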
}