// NOTE(review): removed page-scrape residue ("158 lines / 5.1 KiB / Scala")
// that preceded the package clause and made the file uncompilable.
package eu.dnetlib.featureextraction

import com.johnsnowlabs.nlp.EmbeddingsFinisher
import com.johnsnowlabs.nlp.annotator.SentenceDetector
import com.johnsnowlabs.nlp.annotators.Tokenizer
import com.johnsnowlabs.nlp.base.DocumentAssembler
import com.johnsnowlabs.nlp.embeddings.{BertEmbeddings, BertSentenceEmbeddings, WordEmbeddingsModel}
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.sql.functions.{array, col, explode}
import org.apache.spark.sql.{Dataset, Row}

import java.nio.file.Paths
object ScalaFeatureTransformer {

  // Annotation/output column names shared by the three embedding pipelines.
  val DOCUMENT_COL = "document"
  val SENTENCE_COL = "sentence"
  val BERT_SENTENCE_EMBEDDINGS_COL = "bert_sentence"
  val BERT_EMBEDDINGS_COL = "bert"
  val TOKENIZER_COL = "tokens"
  val WORD_EMBEDDINGS_COL = "word"

  // Absolute filesystem paths of the pretrained models bundled as resources.
  // Lazy on purpose: with plain vals a missing resource would throw a
  // NullPointerException (getResource returns null) while the object itself
  // is initializing, breaking even callers that pass their own modelPath.
  private lazy val bertSentenceModelPath: String =
    Paths.get(getClass.getResource("/eu/dnetlib/featureextraction/support/sent_small_bert_L6_512_en_2.6.0_2.4_1598350624049").toURI).toFile.getAbsolutePath
  private lazy val bertModelPath: String =
    Paths.get(getClass.getResource("/eu/dnetlib/featureextraction/support/small_bert_L2_128_en_2.6.0_2.4_1598344320681").toURI).toFile.getAbsolutePath
  private lazy val wordModelPath: String =
    Paths.get(getClass.getResource("/eu/dnetlib/featureextraction/support/glove_100d_en_2.4.0_2.4_1579690104032").toURI).toFile.getAbsolutePath

  /**
    * Runs the pipeline shape shared by all three extractors:
    * DocumentAssembler -> model-specific stages -> EmbeddingsFinisher,
    * then explodes the finished vector column into one row per embedding.
    *
    * @param inputData    the input data
    * @param inputField   the text column to assemble into DOCUMENT_COL
    * @param outputCol    the finished embeddings column; the embedding stage
    *                     must write its annotations to "raw_" + outputCol
    * @param middleStages the model-specific stages placed between the
    *                     assembler and the finisher, in order
    * @return the transformed dataset with the exploded embeddings column
    */
  private def embed(inputData: Dataset[Row],
                    inputField: String,
                    outputCol: String,
                    middleStages: PipelineStage*): Dataset[Row] = {

    val documentAssembler = new DocumentAssembler()
      .setInputCol(inputField)
      .setOutputCol(DOCUMENT_COL)

    val finisher = new EmbeddingsFinisher()
      .setInputCols("raw_" + outputCol)
      .setOutputCols(outputCol)
      .setOutputAsVector(true)
      .setCleanAnnotations(false)  // keep intermediate annotation columns

    val pipeline = new Pipeline()
      .setStages((documentAssembler +: middleStages :+ finisher).toArray)

    // The finisher emits an array of vectors; explode to one row per vector.
    pipeline
      .fit(inputData)
      .transform(inputData)
      .withColumn(outputCol, explode(col(outputCol)))
  }

  /**
    * Extract the SentenceBERT embeddings for the given field.
    *
    * @param inputData  the input data
    * @param inputField the input field
    * @param modelPath  filesystem path of the pretrained BertSentenceEmbeddings model
    * @return the dataset with the embeddings in BERT_SENTENCE_EMBEDDINGS_COL
    */
  def bertSentenceEmbeddings(inputData: Dataset[Row], inputField: String, modelPath: String): Dataset[Row] = {

    val sentence = new SentenceDetector()
      .setInputCols(DOCUMENT_COL)
      .setOutputCol(SENTENCE_COL)

    val bertSentenceEmbeddings = BertSentenceEmbeddings
      .load(modelPath)
      .setInputCols(SENTENCE_COL)
      .setOutputCol("raw_" + BERT_SENTENCE_EMBEDDINGS_COL)
      .setCaseSensitive(false)

    embed(inputData, inputField, BERT_SENTENCE_EMBEDDINGS_COL, sentence, bertSentenceEmbeddings)
  }

  /**
    * Extract the BERT embeddings for the given field.
    *
    * @param inputData  the input data
    * @param inputField the input field
    * @param modelPath  filesystem path of the pretrained BertEmbeddings model
    * @return the dataset with the embeddings in BERT_EMBEDDINGS_COL
    */
  def bertEmbeddings(inputData: Dataset[Row], inputField: String, modelPath: String): Dataset[Row] = {

    val tokenizer = new Tokenizer()
      .setInputCols(DOCUMENT_COL)
      .setOutputCol(TOKENIZER_COL)

    val bertEmbeddings = BertEmbeddings
      .load(modelPath)
      .setInputCols(TOKENIZER_COL, DOCUMENT_COL)
      .setOutputCol("raw_" + BERT_EMBEDDINGS_COL)
      .setCaseSensitive(false)

    embed(inputData, inputField, BERT_EMBEDDINGS_COL, tokenizer, bertEmbeddings)
  }

  /**
    * Extract the word (GloVe) embeddings for the given field.
    *
    * @param inputData  the input data
    * @param inputField the input field
    * @param modelPath  filesystem path of the pretrained WordEmbeddingsModel
    * @return the dataset with the embeddings in WORD_EMBEDDINGS_COL
    */
  def wordEmbeddings(inputData: Dataset[Row], inputField: String, modelPath: String): Dataset[Row] = {

    val tokenizer = new Tokenizer()
      .setInputCols(DOCUMENT_COL)
      .setOutputCol(TOKENIZER_COL)

    val wordEmbeddings = WordEmbeddingsModel
      .load(modelPath)
      .setInputCols(DOCUMENT_COL, TOKENIZER_COL)
      .setOutputCol("raw_" + WORD_EMBEDDINGS_COL)

    embed(inputData, inputField, WORD_EMBEDDINGS_COL, tokenizer, wordEmbeddings)
  }

  // Intended usage (per original notes):
  //   bert:          on the title
  //   bert sentence: on the abstract
  //   word2vec:      on the subjects
}
|