package eu.dnetlib.featureextraction

import com.johnsnowlabs.nlp.EmbeddingsFinisher
import com.johnsnowlabs.nlp.annotator.SentenceDetector
import com.johnsnowlabs.nlp.annotators.Tokenizer
import com.johnsnowlabs.nlp.base.DocumentAssembler
import com.johnsnowlabs.nlp.embeddings.{BertEmbeddings, BertSentenceEmbeddings, WordEmbeddingsModel}
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.sql.functions.{array, col, explode}
import org.apache.spark.sql.{Dataset, Row}

import java.nio.file.Paths

/**
 * Spark NLP feature extractors: BERT token embeddings, BERT sentence embeddings and
 * pre-trained word embeddings (e.g. GloVe), each exposed as a Dataset-to-Dataset transformation.
 *
 * Output column naming convention: each annotator writes to "raw_<col>"; an
 * [[EmbeddingsFinisher]] converts the raw annotations to Spark ML vectors in "<col>",
 * which is then exploded so that each row carries a single embedding vector.
 */
object ScalaFeatureTransformer {

  val DOCUMENT_COL = "document"
  val SENTENCE_COL = "sentence"
  val BERT_SENTENCE_EMBEDDINGS_COL = "bert_sentence"
  val BERT_EMBEDDINGS_COL = "bert"
  val TOKENIZER_COL = "tokens"
  val WORD_EMBEDDINGS_COL = "word"

  // Prefix of the intermediate (annotation-typed) embedding columns produced by the annotators.
  private val RawPrefix = "raw_"

  // Bundled default model locations.
  // NOTE: lazy so that a missing bundled resource (getResource returning null) only fails
  // when the default model is actually requested, not at object-initialization time —
  // callers supplying their own modelPath are unaffected.
  private lazy val bertSentenceModelPath: String =
    Paths.get(getClass.getResource("/eu/dnetlib/featureextraction/support/sent_small_bert_L6_512_en_2.6.0_2.4_1598350624049").toURI).toFile.getAbsolutePath
  private lazy val bertModelPath: String =
    Paths.get(getClass.getResource("/eu/dnetlib/featureextraction/support/small_bert_L2_128_en_2.6.0_2.4_1598344320681").toURI).toFile.getAbsolutePath
  private lazy val wordModelPath: String =
    Paths.get(getClass.getResource("/eu/dnetlib/featureextraction/support/glove_100d_en_2.4.0_2.4_1579690104032").toURI).toFile.getAbsolutePath

  /** Builds the DocumentAssembler reading the given input field into [[DOCUMENT_COL]]. */
  private def documentAssembler(inputField: String): DocumentAssembler =
    new DocumentAssembler()
      .setInputCol(inputField)
      .setOutputCol(DOCUMENT_COL)

  /** Builds the finisher that turns "raw_<outputCol>" annotations into vectors in <outputCol>. */
  private def embeddingsFinisher(outputCol: String): EmbeddingsFinisher =
    new EmbeddingsFinisher()
      .setInputCols(RawPrefix + outputCol)
      .setOutputCols(outputCol)
      .setOutputAsVector(true)
      .setCleanAnnotations(false)

  /** Fits and applies the pipeline, then explodes the vector column to one embedding per row. */
  private def runPipeline(inputData: Dataset[Row],
                          stages: Array[PipelineStage],
                          embeddingsCol: String): Dataset[Row] = {
    val pipeline = new Pipeline().setStages(stages)
    pipeline
      .fit(inputData)
      .transform(inputData)
      .withColumn(embeddingsCol, explode(col(embeddingsCol)))
  }

  /**
   * Extract the SentenceBERT embeddings for the given field.
   *
   * @param inputData  the input data
   * @param inputField the input field to embed
   * @param modelPath  path of the BertSentenceEmbeddings model to load
   *                   (defaults to the bundled sent_small_bert_L6_512 model)
   * @return the dataset with one exploded sentence-embedding vector per row
   *         in column [[BERT_SENTENCE_EMBEDDINGS_COL]]
   */
  def bertSentenceEmbeddings(inputData: Dataset[Row],
                             inputField: String,
                             modelPath: String = bertSentenceModelPath): Dataset[Row] = {
    val sentence = new SentenceDetector()
      .setInputCols(DOCUMENT_COL)
      .setOutputCol(SENTENCE_COL)

    val bertSentenceEmbeddings = BertSentenceEmbeddings
      .load(modelPath)
      .setInputCols(SENTENCE_COL)
      .setOutputCol(RawPrefix + BERT_SENTENCE_EMBEDDINGS_COL)
      .setCaseSensitive(false)

    runPipeline(
      inputData,
      Array[PipelineStage](
        documentAssembler(inputField),
        sentence,
        bertSentenceEmbeddings,
        embeddingsFinisher(BERT_SENTENCE_EMBEDDINGS_COL)
      ),
      BERT_SENTENCE_EMBEDDINGS_COL
    )
  }

  /**
   * Extract the BERT (token-level) embeddings for the given field.
   *
   * @param inputData  the input data
   * @param inputField the input field to embed
   * @param modelPath  path of the BertEmbeddings model to load
   *                   (defaults to the bundled small_bert_L2_128 model)
   * @return the dataset with one exploded token-embedding vector per row
   *         in column [[BERT_EMBEDDINGS_COL]]
   */
  def bertEmbeddings(inputData: Dataset[Row],
                     inputField: String,
                     modelPath: String = bertModelPath): Dataset[Row] = {
    val tokenizer = new Tokenizer()
      .setInputCols(DOCUMENT_COL)
      .setOutputCol(TOKENIZER_COL)

    val bertEmbeddings = BertEmbeddings
      .load(modelPath)
      .setInputCols(TOKENIZER_COL, DOCUMENT_COL)
      .setOutputCol(RawPrefix + BERT_EMBEDDINGS_COL)
      .setCaseSensitive(false)

    runPipeline(
      inputData,
      Array[PipelineStage](
        documentAssembler(inputField),
        tokenizer,
        bertEmbeddings,
        embeddingsFinisher(BERT_EMBEDDINGS_COL)
      ),
      BERT_EMBEDDINGS_COL
    )
  }

  /**
   * Extract pre-trained word embeddings (e.g. GloVe — the bundled default is glove_100d)
   * for the given field.
   *
   * @param inputData  the input data
   * @param inputField the input field to embed
   * @param modelPath  path of the WordEmbeddingsModel to load
   *                   (defaults to the bundled glove_100d model)
   * @return the dataset with one exploded word-embedding vector per row
   *         in column [[WORD_EMBEDDINGS_COL]]
   */
  def wordEmbeddings(inputData: Dataset[Row],
                     inputField: String,
                     modelPath: String = wordModelPath): Dataset[Row] = {
    val tokenizer = new Tokenizer()
      .setInputCols(DOCUMENT_COL)
      .setOutputCol(TOKENIZER_COL)

    val wordEmbeddings = WordEmbeddingsModel
      .load(modelPath)
      .setInputCols(DOCUMENT_COL, TOKENIZER_COL)
      .setOutputCol(RawPrefix + WORD_EMBEDDINGS_COL)

    runPipeline(
      inputData,
      Array[PipelineStage](
        documentAssembler(inputField),
        tokenizer,
        wordEmbeddings,
        embeddingsFinisher(WORD_EMBEDDINGS_COL)
      ),
      WORD_EMBEDDINGS_COL
    )
  }

  // Intended usage (from the original notes):
  //   bert:          on the title
  //   bert sentence: on the abstract
  //   word2vec/glove: on the subjects
}