diff --git a/dnet-feature-extraction/pom.xml b/dnet-feature-extraction/pom.xml
index fb8c7ca..5b27eef 100644
--- a/dnet-feature-extraction/pom.xml
+++ b/dnet-feature-extraction/pom.xml
@@ -7,14 +7,33 @@
         <groupId>eu.dnetlib</groupId>
         <artifactId>dnet-and</artifactId>
         <version>1.0.0-SNAPSHOT</version>
+        <relativePath>../pom.xml</relativePath>
     </parent>
 
     <artifactId>dnet-feature-extraction</artifactId>
+    <packaging>jar</packaging>
 
-    <properties>
-        <maven.compiler.source>8</maven.compiler.source>
-        <maven.compiler.target>8</maven.compiler.target>
-        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
-    </properties>
+    <dependencies>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-core_2.11</artifactId>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-graphx_2.11</artifactId>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-sql_2.11</artifactId>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-mllib_2.11</artifactId>
+        </dependency>
+        <dependency>
+            <groupId>com.jayway.jsonpath</groupId>
+            <artifactId>json-path</artifactId>
+        </dependency>
+    </dependencies>
 
-</project>
+</project>
\ No newline at end of file
diff --git a/dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/FeatureTransformer.java b/dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/FeatureTransformer.java
index 3e7a521..5c3defc 100644
--- a/dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/FeatureTransformer.java
+++ b/dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/FeatureTransformer.java
@@ -1,2 +1,150 @@
-package eu.dnetlib.featureextraction;public class FeatureTransformer {
+package eu.dnetlib.featureextraction;
+
+import java.io.BufferedReader;
+import java.io.FileReader;
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.*;
+
+import eu.dnetlib.featureextraction.util.Utilities;
+import org.apache.spark.ml.Model;
+import org.apache.spark.ml.clustering.LDA;
+import org.apache.spark.ml.clustering.LDAModel;
+import org.apache.spark.ml.evaluation.Evaluator;
+import org.apache.spark.ml.feature.*;
+import org.apache.spark.ml.param.ParamMap;
+import org.apache.spark.ml.tuning.ParamGridBuilder;
+import org.apache.spark.ml.tuning.TrainValidationSplit;
+import org.apache.spark.ml.tuning.TrainValidationSplitModel;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import scala.Tuple2;
+
+public class FeatureTransformer implements Serializable {
+
+ public static final String ID_COL = "id";
+ public static final String TOKENIZER_INPUT_COL = "sentence";
+ public static final String TOKENIZER_OUTPUT_COL = "rawTokens";
+ public static final String STOPWORDREMOVER_OUTPUT_COL = "tokens";
+ public static final String COUNTVECTORIZER_OUTPUT_COL = "features";
+ public static final String LDA_OPTIMIZER = "online";
+
+ /**
+ * Returns the tokenization of the data without stopwords
+ *
+ * @param inputDS: the input dataset of the form (id, sentence)
+ * @return the tokenized data (id, tokens)
+ */
+ public static Dataset<Row> tokenizeData(Dataset<Row> inputDS) {
+ Tokenizer tokenizer = new Tokenizer().setInputCol(TOKENIZER_INPUT_COL).setOutputCol(TOKENIZER_OUTPUT_COL);
+ StopWordsRemover remover = new StopWordsRemover().setInputCol(TOKENIZER_OUTPUT_COL).setOutputCol(STOPWORDREMOVER_OUTPUT_COL);
+ //TODO consider implementing stemming with SparkNLP library from johnsnowlab
+ Dataset rawTokensDS = tokenizer.transform(inputDS).select(ID_COL, TOKENIZER_OUTPUT_COL);
+ return remover.transform(rawTokensDS).select(ID_COL, STOPWORDREMOVER_OUTPUT_COL);
+ }
+
+ /**
+ * Create the vocabulary from the given data.
+ *
+ * @param inputDS: the input dataset of the form (id, tokens)
+ * @param minDF: minimum number of different documents a term must appear in to be included in the vocabulary
+ * @param minTF: filter to ignore rare words in a document
+ * @param vocabSize: maximum size of the vocabulary (number of terms)
+ * @return the vocabulary
+ */
+ public static CountVectorizerModel createVocabularyFromTokens(Dataset<Row> inputDS, double minDF, double minTF, int vocabSize) {
+ return new CountVectorizer()
+ .setInputCol(STOPWORDREMOVER_OUTPUT_COL)
+ .setOutputCol(COUNTVECTORIZER_OUTPUT_COL)
+ .setMinDF(minDF)
+ .setMinTF(minTF)
+ .setVocabSize(vocabSize)
+ .fit(inputDS);
+ //TODO setMaxDF not found in this CountVectorizer version, try to add it
+ }
+
+ /**
+ * Create the vocabulary from file.
+ *
+ * @param inputFilePath: the input file with the vocabulary elements (one element per line)
+ * @return the vocabulary
+ */
+ public static CountVectorizerModel createVocabularyFromFile(String inputFilePath) throws IOException {
+ Set<String> fileLines = new HashSet<>();
+ BufferedReader bf = new BufferedReader(new FileReader(inputFilePath));
+ String line = bf.readLine();
+ while(line != null) {
+ fileLines.add(line);
+ line = bf.readLine();
+ }
+ bf.close();
+
+ return new CountVectorizerModel(fileLines.toArray(new String[0])).setInputCol(STOPWORDREMOVER_OUTPUT_COL).setOutputCol(COUNTVECTORIZER_OUTPUT_COL);
+
+ }
+
+ /**
+ * Load an existing vocabulary
+ *
+ * @param vocabularyPath: location of the vocabulary
+ * @return the vocabulary
+ */
+ public static CountVectorizerModel loadVocabulary(String vocabularyPath) {
+ return CountVectorizerModel.load(vocabularyPath);
+ }
+
+ /**
+ * Count vectorize data.
+ *
+ * @param inputDS: the input dataset of the form (id, tokens)
+ * @param vocabulary: the vocabulary to be used for the transformation
+ * @return the count vectorized data
+ */
+ public static Dataset<Row> countVectorizeData(Dataset<Row> inputDS, CountVectorizerModel vocabulary) {
+ return vocabulary.transform(inputDS).select(ID_COL, COUNTVECTORIZER_OUTPUT_COL);
+ }
+
+ /**
+ * Train LDA model with the given parameters
+ *
+ * @param inputDS: the input dataset
+ * @param k: number of topics
+ * @param maxIter: maximum number of iterations
+ * @return the LDA model
+ */
+ public static LDAModel trainLDAModel(Dataset<Row> inputDS, int k, int maxIter) {
+
+ LDA lda = new LDA()
+ .setK(k)
+ .setMaxIter(maxIter)
+ .setFeaturesCol(COUNTVECTORIZER_OUTPUT_COL)
+ .setOptimizer(LDA_OPTIMIZER);
+
+ return lda.fit(inputDS);
+ }
+
+ /**
+ * Tune the LDA model: for each number of topics, train on a random split and measure perplexity on the held-out part.
+ *
+ * @param dataDS: the input dataset of the form (id, features)
+ * @param trainRatio: ratio of the data to be used for training (the rest measures perplexity)
+ * @param numTopics: the numbers of topics to try
+ * @param maxIter: maximum number of iterations
+ * @return map from number of topics to the trained model and its perplexity
+ */
+ public static Map<Integer, Tuple2<LDAModel, Double>> ldaTuning(Dataset<Row> dataDS, double trainRatio, int[] numTopics, int maxIter) {
+ Dataset<Row>[] setsDS = dataDS.randomSplit(new double[]{trainRatio, 1 - trainRatio});
+ Dataset<Row> trainDS = setsDS[0];
+ Dataset<Row> testDS = setsDS[1];
+ Map<Integer, Tuple2<LDAModel, Double>> ldaModels = new HashMap<>();
+
+ for(int k: numTopics) {
+ LDAModel ldaModel = trainLDAModel(trainDS, k, maxIter);
+ double perplexity = ldaModel.logPerplexity(testDS);
+ ldaModels.put(k, new Tuple2<>(ldaModel, perplexity));
+ }
+
+ return ldaModels;
+
+ }
}
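
For context, here is a minimal usage sketch (not part of the commit) showing how these methods chain together. It assumes the same imports as FeatureTransformer.java, a running Spark session, and an existing Dataset<Row> named inputDS with the (id, sentence) schema the transformer expects; all parameter values are illustrative.

```java
// Tokenize and remove stopwords: (id, sentence) -> (id, tokens).
Dataset<Row> tokensDS = FeatureTransformer.tokenizeData(inputDS);

// Build the vocabulary: keep terms occurring in at least 2 documents,
// with an in-document count of at least 1, capped at 5000 terms.
CountVectorizerModel vocabulary =
        FeatureTransformer.createVocabularyFromTokens(tokensDS, 2.0, 1.0, 5000);

// Replace token lists with count vectors and fit an LDA model
// with 10 topics and at most 100 iterations.
Dataset<Row> featuresDS = FeatureTransformer.countVectorizeData(tokensDS, vocabulary);
LDAModel ldaModel = FeatureTransformer.trainLDAModel(featuresDS, 10, 100);

// Or compare several topic counts by held-out perplexity (lower is better),
// training each candidate model on 70% of the data.
Map<Integer, Tuple2<LDAModel, Double>> modelsByK =
        FeatureTransformer.ldaTuning(featuresDS, 0.7, new int[]{5, 10, 20}, 100);
```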
diff --git a/dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/lda/LDAModeler.java b/dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/lda/LDAModeler.java
index fa4c6f8..a01e748 100644
--- a/dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/lda/LDAModeler.java
+++ b/dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/lda/LDAModeler.java
@@ -1,2 +1,9 @@
-package eu.dnetlib.featureextraction.lda;public class LDAModeler {
+package eu.dnetlib.featureextraction.lda;
+
+public class LDAModeler {
+
+
+ public static void main(String[] args) {
+ System.out.println("prova");
+ }
}
diff --git a/dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/util/Utilities.java b/dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/util/Utilities.java
index f7416fb..2e82e8d 100644
--- a/dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/util/Utilities.java
+++ b/dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/util/Utilities.java
@@ -1,2 +1,102 @@
-package eu.dnetlib.featureextraction.util;public class Utilities {
+package eu.dnetlib.featureextraction.util;
+
+import com.jayway.jsonpath.JsonPath;
+import net.minidev.json.JSONArray;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.text.Normalizer;
+import java.util.List;
+
+public class Utilities implements Serializable {
+
+ public static String DATA_ID_FIELD = "$.id";
+
+ static StructType inputSchema = new StructType(new StructField[]{
+ new StructField("id", DataTypes.StringType, false, Metadata.empty()),
+ new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
+ });
+
+ /**
+ * Returns a view of the dataset including the id and the chosen field.
+ *
+ * @param sqlContext: the spark sql context
+ * @param jsonRDD: the input dataset
+ * @param inputFieldJPath: the input field jpath
+ * @return the view of the dataset with the normalized value of the input field, in the form (id, sentence)
+ */
+ public static Dataset<Row> prepareDataset(SQLContext sqlContext, JavaRDD<String> jsonRDD, String inputFieldJPath) {
+
+ JavaRDD<Row> rowRDD = jsonRDD
+ .map(json ->
+ RowFactory.create(getJPathString(DATA_ID_FIELD, json), Utilities.normalize(getJPathString(inputFieldJPath, json))));
+ return sqlContext.createDataFrame(rowRDD, inputSchema);
+ }
+
+ //returns the string value of the jpath in the given input json
+ public static String getJPathString(final String jsonPath, final String inputJson) {
+ try {
+ Object o = JsonPath.read(inputJson, jsonPath);
+ if (o instanceof String)
+ return (String)o;
+ if (o instanceof JSONArray && ((JSONArray)o).size()>0)
+ return (String)((JSONArray)o).get(0);
+ return "";
+ }
+ catch (Exception e) {
+ return "";
+ }
+ }
+
+ public static String normalize(final String s) {
+ return Normalizer.normalize(s, Normalizer.Form.NFD)
+ .replaceAll("[^\\w\\s-]", "") // Remove all non-word, non-space or non-dash characters
+ .replace('-', ' ') // Replace dashes with spaces
+ .trim() // trim leading/trailing whitespace (including what used to be leading/trailing dashes)
+ .toLowerCase(); // Lowercase the final results
+ }
+
+ public static void writeLinesToHDFSFile(List<String> lines, String filePath) throws IOException {
+ Configuration conf = new Configuration();
+
+ FileSystem fs = FileSystem.get(conf);
+ fs.delete(new Path(filePath), true);
+
+ try {
+ fs = FileSystem.get(conf);
+
+ Path outFile = new Path(filePath);
+ // Verification: should never trigger, since the file was just deleted
+ if (fs.exists(outFile)) {
+ System.out.println("Output file already exists");
+ throw new IOException("Output file already exists");
+ }
+
+ // Create file to write
+ FSDataOutputStream out = fs.create(outFile);
+ try{
+ for (String line: lines) {
+ out.writeBytes(line + "\n");
+ }
+ }
+ finally {
+ out.close();
+ }
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
}
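
A matching sketch (again not part of the commit) of how Utilities.prepareDataset feeds the pipeline above; the sparkContext and sqlContext handles, the input path, and the JSONPath are all hypothetical.

```java
// Read one JSON object per line and project it onto the (id, sentence) schema,
// extracting and normalizing the field addressed by the JSONPath.
JavaRDD<String> jsonRDD = sparkContext.textFile("hdfs:///data/publications.json");
Dataset<Row> inputDS = Utilities.prepareDataset(sqlContext, jsonRDD, "$.abstract");

// The result has exactly the schema FeatureTransformer.tokenizeData expects.
Dataset<Row> tokensDS = FeatureTransformer.tokenizeData(inputDS);
```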