function for lda pipeline
This commit is contained in:
parent
67d8a57e61
commit
f4c7fc1c15
|
@ -30,10 +30,24 @@
|
|||
<groupId>org.apache.spark</groupId>
|
||||
<artifactId>spark-mllib_2.11</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.jayway.jsonpath</groupId>
|
||||
<artifactId>json-path</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.deeplearning4j</groupId>
|
||||
<artifactId>dl4j-spark-parameterserver_2.11</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.deeplearning4j</groupId>
|
||||
<artifactId>dl4j-spark_2.11</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>${nd4j.backend}</artifactId>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
</project>
|
|
@ -1,25 +1,24 @@
|
|||
package eu.dnetlib.featureextraction;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.FileReader;
|
||||
import java.io.IOException;
|
||||
import java.io.Serializable;
|
||||
import java.util.*;
|
||||
|
||||
import eu.dnetlib.featureextraction.util.Utilities;
|
||||
import org.apache.spark.ml.Model;
|
||||
import org.apache.spark.ml.clustering.LDA;
|
||||
import org.apache.spark.ml.clustering.LDAModel;
|
||||
import org.apache.spark.ml.evaluation.Evaluator;
|
||||
import org.apache.spark.ml.feature.*;
|
||||
import org.apache.spark.ml.param.ParamMap;
|
||||
import org.apache.spark.ml.tuning.ParamGridBuilder;
|
||||
import org.apache.spark.ml.tuning.TrainValidationSplit;
|
||||
import org.apache.spark.ml.tuning.TrainValidationSplitModel;
|
||||
import org.apache.spark.ml.feature.CountVectorizer;
|
||||
import org.apache.spark.ml.feature.CountVectorizerModel;
|
||||
import org.apache.spark.ml.feature.StopWordsRemover;
|
||||
import org.apache.spark.ml.feature.Tokenizer;
|
||||
import org.apache.spark.sql.Dataset;
|
||||
import org.apache.spark.sql.Row;
|
||||
import scala.Tuple2;
|
||||
|
||||
import java.io.BufferedReader;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStreamReader;
|
||||
import java.io.Serializable;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
public class FeatureTransformer implements Serializable {
|
||||
|
||||
public static String ID_COL = "id";
|
||||
|
@ -27,6 +26,7 @@ public class FeatureTransformer implements Serializable {
|
|||
public static String TOKENIZER_OUTPUT_COL = "rawTokens";
|
||||
public static String STOPWORDREMOVER_OUTPUT_COL = "tokens";
|
||||
public static String COUNTVECTORIZER_OUTPUT_COL = "features";
|
||||
public static String LDA_INFERENCE_OUTPUT_COL = "topicDistribution";
|
||||
public static String LDA_OPTIMIZER = "online";
|
||||
|
||||
/**
|
||||
|
@ -64,14 +64,18 @@ public class FeatureTransformer implements Serializable {
|
|||
}
|
||||
|
||||
/**
|
||||
* Create the vocabulary from file.
|
||||
* Create the vocabulary from resource file.
|
||||
*
|
||||
* @param inputFilePath: the input file with the vocabulary elements (one element for line)
|
||||
* @return the vocabulary
|
||||
*/
|
||||
public static CountVectorizerModel createVocabularyFromFile(String inputFilePath) throws IOException {
|
||||
public static CountVectorizerModel createVocabularyFromFile() throws IOException {
|
||||
|
||||
Set<String> fileLines = new HashSet<>();
|
||||
BufferedReader bf = new BufferedReader(new FileReader(inputFilePath));
|
||||
BufferedReader bf = new BufferedReader(
|
||||
new InputStreamReader(
|
||||
FeatureTransformer.class.getResourceAsStream("/eu/dnetlib/featureextraction/support/dewey_vocabulary.txt")
|
||||
)
|
||||
);
|
||||
String line = bf.readLine();
|
||||
while(line != null) {
|
||||
fileLines.add(line);
|
||||
|
@ -123,6 +127,15 @@ public class FeatureTransformer implements Serializable {
|
|||
return lda.fit(inputDS);
|
||||
}
|
||||
|
||||
/**
|
||||
* Tune the LDA model varying the number of topics.
|
||||
*
|
||||
* @param dataDS: the input data in the form (id, features)
|
||||
* @param trainRatio: percentage of the input data to be used as training set
|
||||
* @param numTopics: topics to which test the LDA model
|
||||
* @param maxIter: maximum number of iterations of the algorithm
|
||||
* @return map of trained model with perplexity
|
||||
* */
|
||||
public static Map<Integer, Tuple2<LDAModel, Double>> ldaTuning(Dataset<Row> dataDS, double trainRatio, int[] numTopics, int maxIter) {
|
||||
Dataset<Row>[] setsDS = dataDS.randomSplit(new double[]{trainRatio, 1 - trainRatio});
|
||||
Dataset<Row> trainDS = setsDS[0];
|
||||
|
@ -136,6 +149,16 @@ public class FeatureTransformer implements Serializable {
|
|||
}
|
||||
|
||||
return ldaModels;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate the LDA topic inference of the given data
|
||||
*
|
||||
* @param inputDS: input data in the form (id, features)
|
||||
* @param ldaModel: the LDA model to be used for the inference
|
||||
* @return the LDA inference of the input data in the form (id, vectors)
|
||||
*/
|
||||
public static Dataset<Row> ldaInference(Dataset<Row> inputDS, LDAModel ldaModel) {
|
||||
return ldaModel.transform(inputDS).select(ID_COL, LDA_INFERENCE_OUTPUT_COL);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue