feature transformer implementation: lda model, count vectorizer and tokenizer
parent be20c4e67e
commit 5aa559cb42
dnet-feature-extraction/pom.xml
@@ -7,14 +7,33 @@
        <groupId>eu.dnetlib</groupId>
        <artifactId>dnet-and</artifactId>
        <version>1.0.0-SNAPSHOT</version>
        <relativePath>../pom.xml</relativePath>
    </parent>

    <artifactId>dnet-feature-extraction</artifactId>
    <packaging>jar</packaging>

    <properties>
        <maven.compiler.source>8</maven.compiler.source>
        <maven.compiler.target>8</maven.compiler.target>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    </properties>
    <dependencies>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_2.11</artifactId>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-graphx_2.11</artifactId>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_2.11</artifactId>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.11</artifactId>
        </dependency>
        <dependency>
            <groupId>com.jayway.jsonpath</groupId>
            <artifactId>json-path</artifactId>
        </dependency>
    </dependencies>

</project>
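The Spark artifacts above are declared without versions, presumably inherited from the parent dnet-and pom (an assumption: its dependencyManagement is not shown in this diff). As orientation for the classes that follow, here is a minimal, hypothetical bootstrap, not part of the commit; the class name and package are illustrative only.

package eu.dnetlib.featureextraction.examples; // hypothetical package, not in the commit

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;

public class SparkBootstrap {

    public static void main(String[] args) {
        // Local session for experimentation; in a real deployment the master is set by spark-submit.
        SparkSession spark = SparkSession.builder()
                .appName("dnet-feature-extraction")
                .master("local[*]")
                .getOrCreate();

        // Entry points used by the code in this commit: Utilities.prepareDataset (further below)
        // takes an SQLContext and a JavaRDD<String> of JSON records.
        JavaSparkContext jsc = JavaSparkContext.fromSparkContext(spark.sparkContext());
        SQLContext sqlContext = new SQLContext(jsc); // deprecated in Spark 2.x but still available

        System.out.println("Spark version: " + spark.version());
        spark.stop();
    }
}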
dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/FeatureTransformer.java
@@ -1,2 +1,141 @@
package eu.dnetlib.featureextraction;public class FeatureTransformer {
package eu.dnetlib.featureextraction;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.io.Serializable;
import java.util.*;

import eu.dnetlib.featureextraction.util.Utilities;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.clustering.LDA;
import org.apache.spark.ml.clustering.LDAModel;
import org.apache.spark.ml.evaluation.Evaluator;
import org.apache.spark.ml.feature.*;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.ml.tuning.TrainValidationSplit;
import org.apache.spark.ml.tuning.TrainValidationSplitModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import scala.Tuple2;

public class FeatureTransformer implements Serializable {

    public static String ID_COL = "id";
    public static String TOKENIZER_INPUT_COL = "sentence";
    public static String TOKENIZER_OUTPUT_COL = "rawTokens";
    public static String STOPWORDREMOVER_OUTPUT_COL = "tokens";
    public static String COUNTVECTORIZER_OUTPUT_COL = "features";
    public static String LDA_OPTIMIZER = "online";

    /**
     * Returns the tokenization of the data without stopwords
     *
     * @param inputDS: the input dataset of the form (id, sentence)
     * @return the tokenized data (id, tokens)
     */
    public static Dataset<Row> tokenizeData(Dataset<Row> inputDS) {
        Tokenizer tokenizer = new Tokenizer().setInputCol(TOKENIZER_INPUT_COL).setOutputCol(TOKENIZER_OUTPUT_COL);
        StopWordsRemover remover = new StopWordsRemover().setInputCol(TOKENIZER_OUTPUT_COL).setOutputCol(STOPWORDREMOVER_OUTPUT_COL);
        //TODO consider implementing stemming with SparkNLP library from johnsnowlab
        Dataset<Row> rawTokensDS = tokenizer.transform(inputDS).select(ID_COL, TOKENIZER_OUTPUT_COL);
        return remover.transform(rawTokensDS).select(ID_COL, STOPWORDREMOVER_OUTPUT_COL);
    }

    /**
     * Create the vocabulary from the given data.
     *
     * @param inputDS: the input dataset of the form (id, tokens)
     * @param minDF: minimum number of different documents a term could appear in to be included in the vocabulary
     * @param minTF: filter to ignore rare words in a document
     * @param vocabSize: maximum size of the vocabulary (number of terms)
     * @return the vocabulary
     */
    public static CountVectorizerModel createVocabularyFromTokens(Dataset<Row> inputDS, double minDF, double minTF, int vocabSize) {
        return new CountVectorizer()
                .setInputCol(STOPWORDREMOVER_OUTPUT_COL)
                .setOutputCol(COUNTVECTORIZER_OUTPUT_COL)
                .setMinDF(minDF)
                .setMinTF(minTF)
                .setVocabSize(vocabSize)
                .fit(inputDS);
        //TODO setMaxDF not found, try to add it
    }

    /**
     * Create the vocabulary from file.
     *
     * @param inputFilePath: the input file with the vocabulary elements (one element for line)
     * @return the vocabulary
     */
    public static CountVectorizerModel createVocabularyFromFile(String inputFilePath) throws IOException {
        Set<String> fileLines = new HashSet<>();
        BufferedReader bf = new BufferedReader(new FileReader(inputFilePath));
        String line = bf.readLine();
        while(line != null) {
            fileLines.add(line);
            line = bf.readLine();
        }
        bf.close();

        return new CountVectorizerModel(fileLines.toArray(new String[0])).setInputCol(STOPWORDREMOVER_OUTPUT_COL).setOutputCol(COUNTVECTORIZER_OUTPUT_COL);

    }

    /**
     * Load an existing vocabulary
     *
     * @param vocabularyPath: location of the vocabulary
     * @return the vocabulary
     */
    public static CountVectorizerModel loadVocabulary(String vocabularyPath) {
        return CountVectorizerModel.load(vocabularyPath);
    }

    /**
     * Count vectorize data.
     *
     * @param inputDS: the input dataset of the form (id, tokens)
     * @param vocabulary: the vocabulary to be used for the transformation
     * @return the count vectorized data
     */
    public static Dataset<Row> countVectorizeData(Dataset<Row> inputDS, CountVectorizerModel vocabulary) {
        return vocabulary.transform(inputDS).select(ID_COL, COUNTVECTORIZER_OUTPUT_COL);
    }

    /**
     * Train LDA model with the given parameters
     *
     * @param inputDS: the input dataset
     * @param k: number of topics
     * @param maxIter: maximum number of iterations
     * @return the LDA model
     */
    public static LDAModel trainLDAModel(Dataset<Row> inputDS, int k, int maxIter) {

        LDA lda = new LDA()
                .setK(k)
                .setMaxIter(maxIter)
                .setFeaturesCol(COUNTVECTORIZER_OUTPUT_COL)
                .setOptimizer(LDA_OPTIMIZER);

        return lda.fit(inputDS);
    }

    public static Map<Integer, Tuple2<LDAModel, Double>> ldaTuning(Dataset<Row> dataDS, double trainRatio, int[] numTopics, int maxIter) {
        Dataset<Row>[] setsDS = dataDS.randomSplit(new double[]{trainRatio, 1 - trainRatio});
        Dataset<Row> trainDS = setsDS[0];
        Dataset<Row> testDS = setsDS[1];
        Map<Integer, Tuple2<LDAModel, Double>> ldaModels = new HashMap<>();

        for(int k: numTopics) {
            LDAModel ldaModel = trainLDAModel(trainDS, k, maxIter);
            double perplexity = ldaModel.logPerplexity(testDS);
            ldaModels.put(k, new Tuple2<>(ldaModel, perplexity));
        }

        return ldaModels;

    }
}
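The methods above compose into a tokenize, build-vocabulary, count-vectorize, train-LDA pipeline. A minimal usage sketch follows, not part of the commit: it assumes an input Dataset<Row> with columns (id, sentence), e.g. the output of Utilities.prepareDataset shown further below, and the parameter values (minDF, minTF, vocabSize, k, maxIter, trainRatio) are illustrative only.

package eu.dnetlib.featureextraction.examples; // hypothetical package, not in the commit

import java.util.Map;

import org.apache.spark.ml.clustering.LDAModel;
import org.apache.spark.ml.feature.CountVectorizerModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import scala.Tuple2;

import eu.dnetlib.featureextraction.FeatureTransformer;

public class FeatureTransformerExample {

    // inputDS must expose the columns (id, sentence)
    public static void run(Dataset<Row> inputDS) {
        // 1. tokenize and remove stopwords: (id, sentence) -> (id, tokens)
        Dataset<Row> tokensDS = FeatureTransformer.tokenizeData(inputDS);

        // 2. build the vocabulary from the corpus itself (example parameter values)
        CountVectorizerModel vocabulary =
                FeatureTransformer.createVocabularyFromTokens(tokensDS, 1.0, 1.0, 10000);

        // 3. count-vectorize: (id, tokens) -> (id, features)
        Dataset<Row> featuresDS = FeatureTransformer.countVectorizeData(tokensDS, vocabulary);

        // 4a. train a single LDA model with a fixed number of topics
        LDAModel ldaModel = FeatureTransformer.trainLDAModel(featuresDS, 10, 50);
        ldaModel.describeTopics(5).show(false);

        // 4b. or grid over several k values, each scored by log-perplexity on a held-out split
        Map<Integer, Tuple2<LDAModel, Double>> tuned =
                FeatureTransformer.ldaTuning(featuresDS, 0.8, new int[]{5, 10, 20}, 50);
        tuned.forEach((k, pair) -> System.out.println("k=" + k + " perplexity=" + pair._2()));
    }
}

ldaTuning returns one (model, log-perplexity) pair per k and leaves the choice of the best k, typically the one with the lowest perplexity, to the caller.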
dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/lda/LDAModeler.java
@@ -1,2 +1,9 @@
package eu.dnetlib.featureextraction.lda;public class LDAModeler {
package eu.dnetlib.featureextraction.lda;

public class LDAModeler {


    public static void main(String[] args) {
        System.out.println("prova");
    }
}
dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/util/Utilities.java
@@ -1,2 +1,102 @@
package eu.dnetlib.featureextraction.util;public class Utilities {
package eu.dnetlib.featureextraction.util;

import com.jayway.jsonpath.JsonPath;
import net.minidev.json.JSONArray;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.io.IOException;
import java.io.Serializable;
import java.text.Normalizer;
import java.util.List;

public class Utilities implements Serializable {

    public static String DATA_ID_FIELD = "$.id";

    static StructType inputSchema = new StructType(new StructField[]{
            new StructField("id", DataTypes.StringType, false, Metadata.empty()),
            new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
    });

    /**
     * Returns a view of the dataset including the id and the chosen field.
     *
     * @param sqlContext: the spark sql context
     * @param jsonRDD: the input dataset
     * @param inputFieldJPath: the input field jpath
     * @return the view of the dataset with normalized data of the inputField (id, inputField)
     */
    public static Dataset<Row> prepareDataset(SQLContext sqlContext, JavaRDD<String> jsonRDD, String inputFieldJPath) {

        JavaRDD<Row> rowRDD = jsonRDD
                .map(json ->
                        RowFactory.create(getJPathString(DATA_ID_FIELD, json), Utilities.normalize(getJPathString(inputFieldJPath, json))));
        return sqlContext.createDataFrame(rowRDD, inputSchema);
    }

    //returns the string value of the jpath in the given input json
    public static String getJPathString(final String jsonPath, final String inputJson) {
        try {
            Object o = JsonPath.read(inputJson, jsonPath);
            if (o instanceof String)
                return (String)o;
            if (o instanceof JSONArray && ((JSONArray)o).size()>0)
                return (String)((JSONArray)o).get(0);
            return "";
        }
        catch (Exception e) {
            return "";
        }
    }

    public static String normalize(final String s) {
        return Normalizer.normalize(s, Normalizer.Form.NFD)
                .replaceAll("[^\\w\\s-]", "") // Remove all non-word, non-space or non-dash characters
                .replace('-', ' ') // Replace dashes with spaces
                .trim() // trim leading/trailing whitespace (including what used to be leading/trailing dashes)
                .toLowerCase(); // Lowercase the final results
    }

    public static void writeLinesToHDFSFile(List<String> lines, String filePath) throws IOException {
        Configuration conf = new Configuration();

        FileSystem fs = FileSystem.get(conf);
        fs.delete(new Path(filePath), true);

        try {
            fs = FileSystem.get(conf);

            Path outFile = new Path(filePath);
            // Verification
            if (fs.exists(outFile)) {
                System.out.println("Output file already exists");
                throw new IOException("Output file already exists");
            }

            // Create file to write
            FSDataOutputStream out = fs.create(outFile);
            try{
                for (String line: lines) {
                    out.writeBytes(line + "\n");
                }
            }
            finally {
                out.close();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
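Utilities.prepareDataset turns a JavaRDD of JSON records into the (id, sentence) dataset consumed by FeatureTransformer.tokenizeData, extracting the chosen field with a JSONPath expression and normalizing it (NFD decomposition, accents and punctuation stripped, dashes turned into spaces, lowercased). A small, hypothetical example, not part of the commit; the JSON layout and JSONPath are illustrative only.

package eu.dnetlib.featureextraction.examples; // hypothetical package, not in the commit

import java.io.IOException;
import java.util.Arrays;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;

import eu.dnetlib.featureextraction.util.Utilities;

public class UtilitiesExample {

    public static void main(String[] args) throws IOException {
        SparkSession spark = SparkSession.builder().appName("utilities-example").master("local[*]").getOrCreate();
        JavaSparkContext jsc = JavaSparkContext.fromSparkContext(spark.sparkContext());
        SQLContext sqlContext = new SQLContext(jsc); // deprecated in Spark 2.x but still available

        // One JSON record per line; the "abstract" field is made up for the example.
        JavaRDD<String> jsonRDD = jsc.parallelize(Arrays.asList(
                "{\"id\":\"doc-1\",\"abstract\":\"Feature éxtraction, LDA-based!\"}",
                "{\"id\":\"doc-2\",\"abstract\":\"Count vectorizer and tokenizer\"}"));

        // (id, sentence) view; "Feature éxtraction, LDA-based!" becomes "feature extraction lda based"
        Dataset<Row> inputDS = Utilities.prepareDataset(sqlContext, jsonRDD, "$.abstract");
        inputDS.show(false);

        // Persist a vocabulary (one term per line) via the configured Hadoop FileSystem.
        Utilities.writeLinesToHDFSFile(Arrays.asList("feature", "extraction", "lda"), "/tmp/vocabulary.txt");

        spark.stop();
    }
}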