Add function for the LDA pipeline

This commit is contained in:
Michele De Bonis 2023-04-11 09:03:22 +02:00
parent 67d8a57e61
commit f4c7fc1c15
2 changed files with 55 additions and 18 deletions

View File

@ -30,10 +30,24 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_2.11</artifactId>
</dependency>
<dependency>
<groupId>com.jayway.jsonpath</groupId>
<artifactId>json-path</artifactId>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>dl4j-spark-parameterserver_2.11</artifactId>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>dl4j-spark_2.11</artifactId>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>${nd4j.backend}</artifactId>
</dependency>
</dependencies>
</project>

View File

@ -1,25 +1,24 @@
package eu.dnetlib.featureextraction;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.io.Serializable;
import java.util.*;
import eu.dnetlib.featureextraction.util.Utilities;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.clustering.LDA;
import org.apache.spark.ml.clustering.LDAModel;
import org.apache.spark.ml.evaluation.Evaluator;
import org.apache.spark.ml.feature.*;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.ml.tuning.TrainValidationSplit;
import org.apache.spark.ml.tuning.TrainValidationSplitModel;
import org.apache.spark.ml.feature.CountVectorizer;
import org.apache.spark.ml.feature.CountVectorizerModel;
import org.apache.spark.ml.feature.StopWordsRemover;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import scala.Tuple2;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Serializable;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
public class FeatureTransformer implements Serializable {
public static String ID_COL = "id";
@ -27,6 +26,7 @@ public class FeatureTransformer implements Serializable {
public static String TOKENIZER_OUTPUT_COL = "rawTokens";
public static String STOPWORDREMOVER_OUTPUT_COL = "tokens";
public static String COUNTVECTORIZER_OUTPUT_COL = "features";
public static String LDA_INFERENCE_OUTPUT_COL = "topicDistribution";
public static String LDA_OPTIMIZER = "online";
/**
@ -64,14 +64,18 @@ public class FeatureTransformer implements Serializable {
}
/**
* Create the vocabulary from file.
* Create the vocabulary from resource file.
*
* The vocabulary elements are read one per line from the bundled dewey_vocabulary.txt resource.
* @return the vocabulary
*/
public static CountVectorizerModel createVocabularyFromFile(String inputFilePath) throws IOException {
public static CountVectorizerModel createVocabularyFromFile() throws IOException {
Set<String> fileLines = new HashSet<>();
BufferedReader bf = new BufferedReader(new FileReader(inputFilePath));
BufferedReader bf = new BufferedReader(
new InputStreamReader(
FeatureTransformer.class.getResourceAsStream("/eu/dnetlib/featureextraction/support/dewey_vocabulary.txt")
)
);
String line = bf.readLine();
while(line != null) {
fileLines.add(line);
@ -123,6 +127,15 @@ public class FeatureTransformer implements Serializable {
return lda.fit(inputDS);
}
/**
* Tune the LDA model varying the number of topics.
*
* @param dataDS: the input data in the form (id, features)
* @param trainRatio: percentage of the input data to be used as training set
* @param numTopics: the numbers of topics with which to test the LDA model
* @param maxIter: maximum number of iterations of the algorithm
* @return map from number of topics to the trained model and its perplexity
*/
public static Map<Integer, Tuple2<LDAModel, Double>> ldaTuning(Dataset<Row> dataDS, double trainRatio, int[] numTopics, int maxIter) {
Dataset<Row>[] setsDS = dataDS.randomSplit(new double[]{trainRatio, 1 - trainRatio});
Dataset<Row> trainDS = setsDS[0];
@ -136,6 +149,16 @@ public class FeatureTransformer implements Serializable {
}
return ldaModels;
}
/**
 * Infer the LDA topic distribution for each row of the given data.
 *
 * @param inputDS: input data in the form (id, features)
 * @param ldaModel: the LDA model to be used for the inference
 * @return the LDA inference of the input data in the form (id, vectors)
 */
public static Dataset<Row> ldaInference(Dataset<Row> inputDS, LDAModel ldaModel) {
    // Apply the trained model, then keep only the id and topic-distribution columns.
    Dataset<Row> transformedDS = ldaModel.transform(inputDS);
    return transformedDS.select(ID_COL, LDA_INFERENCE_OUTPUT_COL);
}
}