dnet-and/dnet-feature-extraction/src/main/java/eu/dnetlib/featureextraction/ScalaFeatureTransformer.scala

158 lines
5.1 KiB
Scala

package eu.dnetlib.featureextraction
import com.johnsnowlabs.nlp.EmbeddingsFinisher
import com.johnsnowlabs.nlp.annotator.SentenceDetector
import com.johnsnowlabs.nlp.annotators.Tokenizer
import com.johnsnowlabs.nlp.base.DocumentAssembler
import com.johnsnowlabs.nlp.embeddings.{BertEmbeddings, BertSentenceEmbeddings, WordEmbeddingsModel}
import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.functions.{array, col, explode}
import org.apache.spark.sql.{Dataset, Row}
import java.nio.file.Paths
object ScalaFeatureTransformer {

  import org.apache.spark.ml.PipelineStage

  // Annotation / output column names shared by all pipelines.
  val DOCUMENT_COL = "document"
  val SENTENCE_COL = "sentence"
  val BERT_SENTENCE_EMBEDDINGS_COL = "bert_sentence"
  val BERT_EMBEDDINGS_COL = "bert"
  val TOKENIZER_COL = "tokens"
  val WORD_EMBEDDINGS_COL = "word"

  // Bundled model locations, resolved from the classpath when the object is first touched.
  // NOTE(review): these vals are currently unused — every public method takes an explicit
  // modelPath. They are kept so object-initialization behavior is unchanged (resource
  // resolution, and any failure if a resource is missing, still happens here); consider
  // wiring them in as default modelPath values or removing them. TODO confirm with callers.
  private val bertSentenceModelPath = Paths.get(getClass.getResource("/eu/dnetlib/featureextraction/support/sent_small_bert_L6_512_en_2.6.0_2.4_1598350624049").toURI).toFile.getAbsolutePath
  private val bertModelPath = Paths.get(getClass.getResource("/eu/dnetlib/featureextraction/support/small_bert_L2_128_en_2.6.0_2.4_1598344320681").toURI).toFile.getAbsolutePath
  private val wordModelPath = Paths.get(getClass.getResource("/eu/dnetlib/featureextraction/support/glove_100d_en_2.4.0_2.4_1579690104032").toURI).toFile.getAbsolutePath

  /** Builds the DocumentAssembler that turns `inputField` into the `document` annotation column. */
  private def newAssembler(inputField: String): DocumentAssembler =
    new DocumentAssembler()
      .setInputCol(inputField)
      .setOutputCol(DOCUMENT_COL)

  /** Builds a Tokenizer over the `document` column, writing token annotations to `tokens`. */
  private def newTokenizer(): Tokenizer =
    new Tokenizer()
      .setInputCols(DOCUMENT_COL)
      .setOutputCol(TOKENIZER_COL)

  /**
   * Builds an EmbeddingsFinisher that converts the `raw_&lt;embeddingsCol&gt;` annotations
   * into a Spark vector column named `embeddingsCol`, keeping intermediate annotation
   * columns (`setCleanAnnotations(false)`), as the original pipelines did.
   */
  private def newFinisher(embeddingsCol: String): EmbeddingsFinisher =
    new EmbeddingsFinisher()
      .setInputCols("raw_" + embeddingsCol)
      .setOutputCols(embeddingsCol)
      .setOutputAsVector(true)
      .setCleanAnnotations(false)

  /**
   * Fits the staged pipeline on `inputData`, transforms it, and explodes the finished
   * embeddings column so each embedding vector lands on its own row.
   */
  private def runPipeline(inputData: Dataset[Row], stages: Array[PipelineStage], embeddingsCol: String): Dataset[Row] =
    new Pipeline()
      .setStages(stages)
      .fit(inputData)
      .transform(inputData)
      .withColumn(embeddingsCol, explode(col(embeddingsCol)))

  /**
   * Extract the SentenceBERT embeddings for the given field.
   *
   * @param inputData  the input data
   * @param inputField the input field holding the raw text
   * @param modelPath  path to a saved BertSentenceEmbeddings model
   * @return the dataset with one row per sentence embedding in column `bert_sentence`
   * */
  def bertSentenceEmbeddings(inputData: Dataset[Row], inputField: String, modelPath: String): Dataset[Row] = {
    val sentence = new SentenceDetector()
      .setInputCols(DOCUMENT_COL)
      .setOutputCol(SENTENCE_COL)

    val embeddings = BertSentenceEmbeddings
      .load(modelPath)
      .setInputCols(SENTENCE_COL)
      .setOutputCol("raw_" + BERT_SENTENCE_EMBEDDINGS_COL)
      .setCaseSensitive(false)

    runPipeline(
      inputData,
      Array[PipelineStage](newAssembler(inputField), sentence, embeddings, newFinisher(BERT_SENTENCE_EMBEDDINGS_COL)),
      BERT_SENTENCE_EMBEDDINGS_COL)
  }

  /**
   * Extract the token-level BERT embeddings for the given field.
   *
   * @param inputData  the input data
   * @param inputField the input field holding the raw text
   * @param modelPath  path to a saved BertEmbeddings model
   * @return the dataset with one row per token embedding in column `bert`
   * */
  def bertEmbeddings(inputData: Dataset[Row], inputField: String, modelPath: String): Dataset[Row] = {
    val embeddings = BertEmbeddings
      .load(modelPath)
      // input column order kept exactly as before (tokens, document)
      .setInputCols(TOKENIZER_COL, DOCUMENT_COL)
      .setOutputCol("raw_" + BERT_EMBEDDINGS_COL)
      .setCaseSensitive(false)

    runPipeline(
      inputData,
      Array[PipelineStage](newAssembler(inputField), newTokenizer(), embeddings, newFinisher(BERT_EMBEDDINGS_COL)),
      BERT_EMBEDDINGS_COL)
  }

  /**
   * Extract pre-trained word embeddings for the given field via WordEmbeddingsModel
   * (the bundled resource is GloVe 100d, not Word2Vec as previously documented).
   *
   * @param inputData  the input data
   * @param inputField the input field holding the raw text
   * @param modelPath  path to a saved WordEmbeddingsModel
   * @return the dataset with one row per word embedding in column `word`
   * */
  def wordEmbeddings(inputData: Dataset[Row], inputField: String, modelPath: String): Dataset[Row] = {
    val embeddings = WordEmbeddingsModel
      .load(modelPath)
      // input column order kept exactly as before (document, tokens)
      .setInputCols(DOCUMENT_COL, TOKENIZER_COL)
      .setOutputCol("raw_" + WORD_EMBEDDINGS_COL)

    runPipeline(
      inputData,
      Array[PipelineStage](newAssembler(inputField), newTokenizer(), embeddings, newFinisher(WORD_EMBEDDINGS_COL)),
      WORD_EMBEDDINGS_COL)
  }

  // Intended usage per record field (from the original notes):
  //   bertEmbeddings         -> title
  //   bertSentenceEmbeddings -> abstract
  //   wordEmbeddings         -> subjects
}