From 5aa559cb42dce06b60fabbcb374e02d497e4ad39 Mon Sep 17 00:00:00 2001
From: miconis
Date: Mon, 3 Apr 2023 09:41:46 +0200
Subject: [PATCH] feature transformer implementation: lda model, count vectorizer and tokenizer

---
 dnet-feature-extraction/pom.xml                |  29 +++-
 .../featureextraction/FeatureTransformer.java  | 141 +++++++++++++++++-
 .../featureextraction/lda/LDAModeler.java      |   9 +-
 .../featureextraction/util/Utilities.java      | 102 ++++++++++++-
 4 files changed, 273 insertions(+), 8 deletions(-)

diff --git a/dnet-feature-extraction/pom.xml b/dnet-feature-extraction/pom.xml
index fb8c7ca..5b27eef 100644
--- a/dnet-feature-extraction/pom.xml
+++ b/dnet-feature-extraction/pom.xml
@@ -7,14 +7,33 @@
         <groupId>eu.dnetlib</groupId>
         <artifactId>dnet-and</artifactId>
         <version>1.0.0-SNAPSHOT</version>
+        <relativePath>../pom.xml</relativePath>
     </parent>
 
     <artifactId>dnet-feature-extraction</artifactId>
+    <packaging>jar</packaging>
 
-    <properties>
-        <maven.compiler.source>8</maven.compiler.source>
-        <maven.compiler.target>8</maven.compiler.target>
-        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
-    </properties>
+    <dependencies>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-core_2.11</artifactId>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-graphx_2.11</artifactId>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-sql_2.11</artifactId>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-mllib_2.11</artifactId>
+        </dependency>
+        <dependency>
+            <groupId>com.jayway.jsonpath</groupId>
+            <artifactId>json-path</artifactId>
+        </dependency>
+    </dependencies>
 
 </project>
\ No newline at end of file
diff --git a/dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/FeatureTransformer.java b/dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/FeatureTransformer.java
index 3e7a521..5c3defc 100644
--- a/dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/FeatureTransformer.java
+++ b/dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/FeatureTransformer.java
@@ -1,2 +1,141 @@
-package eu.dnetlib.featureextraction;public class FeatureTransformer {
+package eu.dnetlib.featureextraction;
+
+import java.io.BufferedReader;
+import java.io.FileReader;
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.*;
+
+import eu.dnetlib.featureextraction.util.Utilities;
+import org.apache.spark.ml.Model;
+import org.apache.spark.ml.clustering.LDA;
+import org.apache.spark.ml.clustering.LDAModel;
+import org.apache.spark.ml.evaluation.Evaluator;
+import org.apache.spark.ml.feature.*;
+import org.apache.spark.ml.param.ParamMap;
+import org.apache.spark.ml.tuning.ParamGridBuilder;
+import org.apache.spark.ml.tuning.TrainValidationSplit;
+import org.apache.spark.ml.tuning.TrainValidationSplitModel;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import scala.Tuple2;
+
+public class FeatureTransformer implements Serializable {
+
+    public static String ID_COL = "id";
+    public static String TOKENIZER_INPUT_COL = "sentence";
+    public static String TOKENIZER_OUTPUT_COL = "rawTokens";
+    public static String STOPWORDREMOVER_OUTPUT_COL = "tokens";
+    public static String COUNTVECTORIZER_OUTPUT_COL = "features";
+    public static String LDA_OPTIMIZER = "online";
+
+    /**
+     * Returns the tokenization of the data without stopwords.
+     *
+     * @param inputDS: the input dataset of the form (id, sentence)
+     * @return the tokenized data (id, tokens)
+     */
+    public static Dataset<Row> tokenizeData(Dataset<Row> inputDS) {
+        Tokenizer tokenizer = new Tokenizer().setInputCol(TOKENIZER_INPUT_COL).setOutputCol(TOKENIZER_OUTPUT_COL);
+        StopWordsRemover remover = new StopWordsRemover().setInputCol(TOKENIZER_OUTPUT_COL).setOutputCol(STOPWORDREMOVER_OUTPUT_COL);
+        //TODO consider implementing stemming with SparkNLP library from johnsnowlab
+        Dataset<Row> rawTokensDS = tokenizer.transform(inputDS).select(ID_COL, TOKENIZER_OUTPUT_COL);
+        return remover.transform(rawTokensDS).select(ID_COL, STOPWORDREMOVER_OUTPUT_COL);
+    }
+
+    /**
+     * Create the vocabulary from the given data.
+     *
+     * @param inputDS: the input dataset of the form (id, tokens)
+     * @param minDF: minimum number of different documents a term must appear in to be included in the vocabulary
+     * @param minTF: filter to ignore rare words in a document
+     * @param vocabSize: maximum size of the vocabulary (number of terms)
+     * @return the vocabulary
+     */
+    public static CountVectorizerModel createVocabularyFromTokens(Dataset<Row> inputDS, double minDF, double minTF, int vocabSize) {
+        return new CountVectorizer()
+                .setInputCol(STOPWORDREMOVER_OUTPUT_COL)
+                .setOutputCol(COUNTVECTORIZER_OUTPUT_COL)
+                .setMinDF(minDF)
+                .setMinTF(minTF)
+                .setVocabSize(vocabSize)
+                .fit(inputDS);
+        //TODO setMaxDF not found, try to add it
+    }
+
+    /**
+     * Create the vocabulary from file.
+     *
+     * @param inputFilePath: the input file with the vocabulary elements (one element per line)
+     * @return the vocabulary
+     */
+    public static CountVectorizerModel createVocabularyFromFile(String inputFilePath) throws IOException {
+        Set<String> fileLines = new HashSet<>();
+        BufferedReader bf = new BufferedReader(new FileReader(inputFilePath));
+        String line = bf.readLine();
+        while(line != null) {
+            fileLines.add(line);
+            line = bf.readLine();
+        }
+        bf.close();
+
+        return new CountVectorizerModel(fileLines.toArray(new String[0])).setInputCol(STOPWORDREMOVER_OUTPUT_COL).setOutputCol(COUNTVECTORIZER_OUTPUT_COL);
+
+    }
+
+    /**
+     * Load an existing vocabulary.
+     *
+     * @param vocabularyPath: location of the vocabulary
+     * @return the vocabulary
+     */
+    public static CountVectorizerModel loadVocabulary(String vocabularyPath) {
+        return CountVectorizerModel.load(vocabularyPath);
+    }
+
+    /**
+     * Count vectorize data.
+     *
+     * @param inputDS: the input dataset of the form (id, tokens)
+     * @param vocabulary: the vocabulary to be used for the transformation
+     * @return the count vectorized data
+     */
+    public static Dataset<Row> countVectorizeData(Dataset<Row> inputDS, CountVectorizerModel vocabulary) {
+        return vocabulary.transform(inputDS).select(ID_COL, COUNTVECTORIZER_OUTPUT_COL);
+    }
+
+    /**
+     * Train the LDA model with the given parameters.
+     *
+     * @param inputDS: the input dataset
+     * @param k: number of topics
+     * @param maxIter: maximum number of iterations
+     * @return the LDA model
+     */
+    public static LDAModel trainLDAModel(Dataset<Row> inputDS, int k, int maxIter) {
+
+        LDA lda = new LDA()
+                .setK(k)
+                .setMaxIter(maxIter)
+                .setFeaturesCol(COUNTVECTORIZER_OUTPUT_COL)
+                .setOptimizer(LDA_OPTIMIZER);
+
+        return lda.fit(inputDS);
+    }
+
+    public static Map<Integer, Tuple2<LDAModel, Double>> ldaTuning(Dataset<Row> dataDS, double trainRatio, int[] numTopics, int maxIter) {
+        Dataset<Row>[] setsDS = dataDS.randomSplit(new double[]{trainRatio, 1 - trainRatio});
+        Dataset<Row> trainDS = setsDS[0];
+        Dataset<Row> testDS = setsDS[1];
+        Map<Integer, Tuple2<LDAModel, Double>> ldaModels = new HashMap<>();
+
+        for(int k: numTopics) {
+            LDAModel ldaModel = trainLDAModel(trainDS, k, maxIter);
+            double perplexity = ldaModel.logPerplexity(testDS);
+            ldaModels.put(k, new Tuple2<>(ldaModel, perplexity));
+        }
+
+        return ldaModels;
+
+    }
 }
diff --git a/dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/lda/LDAModeler.java b/dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/lda/LDAModeler.java
index fa4c6f8..a01e748 100644
--- a/dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/lda/LDAModeler.java
+++ b/dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/lda/LDAModeler.java
@@ -1,2 +1,9 @@
-package eu.dnetlib.featureextraction.lda;public class LDAModeler {
+package eu.dnetlib.featureextraction.lda;
+
+public class LDAModeler {
+
+
+    public static void main(String[] args) {
+        System.out.println("prova");
+    }
 }
diff --git a/dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/util/Utilities.java b/dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/util/Utilities.java
index f7416fb..2e82e8d 100644
--- a/dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/util/Utilities.java
+++ b/dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/util/Utilities.java
@@ -1,2 +1,102 @@
-package eu.dnetlib.featureextraction.util;public class Utilities {
+package eu.dnetlib.featureextraction.util;
+
+import com.jayway.jsonpath.JsonPath;
+import net.minidev.json.JSONArray;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.text.Normalizer;
+import java.util.List;
+
+public class Utilities implements Serializable {
+
+    public static String DATA_ID_FIELD = "$.id";
+
+    static StructType inputSchema = new StructType(new StructField[]{
+            new StructField("id", DataTypes.StringType, false, Metadata.empty()),
+            new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
+    });
+
+    /**
+     * Returns a view of the dataset including the id and the chosen field.
+     *
+     * @param sqlContext: the spark sql context
+     * @param jsonRDD: the input dataset
+     * @param inputFieldJPath: the input field jpath
+     * @return the view of the dataset with normalized data of the inputField (id, inputField)
+     */
+    public static Dataset<Row> prepareDataset(SQLContext sqlContext, JavaRDD<String> jsonRDD, String inputFieldJPath) {
+
+        JavaRDD<Row> rowRDD = jsonRDD
+                .map(json ->
+                        RowFactory.create(getJPathString(DATA_ID_FIELD, json), Utilities.normalize(getJPathString(inputFieldJPath, json))));
+        return sqlContext.createDataFrame(rowRDD, inputSchema);
+    }
+
+    //returns the string value of the jpath in the given input json
+    public static String getJPathString(final String jsonPath, final String inputJson) {
+        try {
+            Object o = JsonPath.read(inputJson, jsonPath);
+            if (o instanceof String)
+                return (String)o;
+            if (o instanceof JSONArray && ((JSONArray)o).size()>0)
+                return (String)((JSONArray)o).get(0);
+            return "";
+        }
+        catch (Exception e) {
+            return "";
+        }
+    }
+
+    public static String normalize(final String s) {
+        return Normalizer.normalize(s, Normalizer.Form.NFD)
+                .replaceAll("[^\\w\\s-]", "") // Remove all non-word, non-space or non-dash characters
+                .replace('-', ' ') // Replace dashes with spaces
+                .trim() // trim leading/trailing whitespace (including what used to be leading/trailing dashes)
+                .toLowerCase(); // Lowercase the final results
+    }
+
+    public static void writeLinesToHDFSFile(List<String> lines, String filePath) throws IOException {
+        Configuration conf = new Configuration();
+
+        FileSystem fs = FileSystem.get(conf);
+        fs.delete(new Path(filePath), true);
+
+        try {
+            fs = FileSystem.get(conf);
+
+            Path outFile = new Path(filePath);
+            // Verification
+            if (fs.exists(outFile)) {
+                System.out.println("Output file already exists");
+                throw new IOException("Output file already exists");
+            }
+
+            // Create file to write
+            FSDataOutputStream out = fs.create(outFile);
+            try {
+                for (String line: lines) {
+                    out.writeBytes(line + "\n");
+                }
+            }
+            finally {
+                out.close();
+            }
+        } catch (IOException e) {
+            e.printStackTrace();
+        }
+    }
 }
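
Example usage (not part of the patch): a minimal, untested sketch of how the classes introduced here could be chained end to end, from JSON preparation through tokenization, vocabulary creation and count vectorization to LDA training. The class name LDAPipelineExample, the input path, the JSON path "$.abstract" and the numeric parameters (minDF, minTF, vocabSize, k, maxIter) are illustrative assumptions, not values taken from the patch.

import eu.dnetlib.featureextraction.FeatureTransformer;
import eu.dnetlib.featureextraction.util.Utilities;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.clustering.LDAModel;
import org.apache.spark.ml.feature.CountVectorizerModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class LDAPipelineExample {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder().appName("lda-pipeline-example").getOrCreate();
        JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());

        // build the (id, sentence) view from a JSON dump, normalizing the chosen field
        Dataset<Row> inputDS = Utilities.prepareDataset(
                spark.sqlContext(), jsc.textFile("/tmp/publications.json"), "$.abstract");

        // tokenize, build the vocabulary and count-vectorize
        Dataset<Row> tokensDS = FeatureTransformer.tokenizeData(inputDS);
        CountVectorizerModel vocabulary = FeatureTransformer.createVocabularyFromTokens(tokensDS, 1.0, 1.0, 262144);
        Dataset<Row> featuresDS = FeatureTransformer.countVectorizeData(tokensDS, vocabulary);

        // train an online-optimizer LDA model with 10 topics and 50 iterations
        LDAModel ldaModel = FeatureTransformer.trainLDAModel(featuresDS, 10, 50);
        System.out.println("log perplexity: " + ldaModel.logPerplexity(featuresDS));

        spark.stop();
    }
}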
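
A second illustrative sketch (also not part of the patch): selecting the number of topics with ldaTuning and persisting the vocabulary terms with writeLinesToHDFSFile. The candidate topic counts, the 0.8 train ratio and the output path are assumptions.

import java.io.IOException;
import java.util.Arrays;
import java.util.Map;

import eu.dnetlib.featureextraction.FeatureTransformer;
import eu.dnetlib.featureextraction.util.Utilities;
import org.apache.spark.ml.clustering.LDAModel;
import org.apache.spark.ml.feature.CountVectorizerModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import scala.Tuple2;

public class LDATuningExample {

    public static void tuneTopics(Dataset<Row> featuresDS, CountVectorizerModel vocabulary) throws IOException {
        // keep the vocabulary terms on HDFS for later inspection
        Utilities.writeLinesToHDFSFile(Arrays.asList(vocabulary.vocabulary()), "/tmp/lda/vocabulary.txt");

        // one model per candidate k, trained on 80% of the data and scored by held-out log perplexity
        Map<Integer, Tuple2<LDAModel, Double>> models =
                FeatureTransformer.ldaTuning(featuresDS, 0.8, new int[]{5, 10, 20}, 50);
        models.forEach((k, modelAndPerplexity) ->
                System.out.println("k=" + k + " log perplexity=" + modelAndPerplexity._2()));
    }
}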