# Natural-language-understanding (NLU) module for the Janet backend.
"""
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
|
import torch
|
|
|
|
|
|
class NLU:
|
|
def_tokenizer = AutoTokenizer.from_pretrained("castorini/t5-base-canard")
|
|
def_model = AutoModelForSeq2SeqLM.from_pretrained("castorini/t5-base-canard")
|
|
def_intent_classifier = pipeline("sentiment-analysis", model="/home/ahmed/PycharmProjects/Janet/JanetBackend/intent_classifier")
|
|
|
|
def __init__(self, model=def_model, tokenizer=def_tokenizer, intent_classifier=def_intent_classifier,
|
|
max_history_length=1024, num_gen_seq=2, score_threshold=0.5):
|
|
self.input = ""
|
|
self.output = ""
|
|
self.model = model
|
|
self.tokenizer = tokenizer
|
|
self.max_length = max_history_length
|
|
self.num_return_sequences = num_gen_seq
|
|
self.score_threshold = score_threshold
|
|
self.label2id = {'Greet': 0, 'Bye': 1, 'GetKnowledge': 2, 'ChitChat': 3}
|
|
self.id2label = {0: 'Greet', 1: 'Bye', 2: 'GetKnowledge', 3: 'ChitChat'}
|
|
self.intent_classifier = intent_classifier
|
|
|
|
def process_utterance(self, utterance, history):
|
|
if len(history) > 0:
|
|
# crop history
|
|
while len(history.split(" ")) > self.max_length:
|
|
index = history.find("|||")
|
|
history = history[index + 4:]
|
|
|
|
self.input = history + " ||| " + utterance
|
|
inputs = self.tokenizer(self.input, max_length=self.max_length, truncation=True, padding="max_length",
|
|
return_tensors="pt")
|
|
|
|
candidates = self.model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"],
|
|
return_dict_in_generate=True, output_scores=True,
|
|
num_return_sequences=self.num_return_sequences,
|
|
num_beams=self.num_return_sequences)
|
|
for i in range(candidates["sequences"].shape[0]):
|
|
generated_sentence = self.tokenizer.decode(candidates["sequences"][i], skip_special_tokens=True,
|
|
clean_up_tokenization_spaces=True)
|
|
log_scores = candidates['sequences_scores']
|
|
norm_prob = (torch.exp(log_scores[i]) / torch.exp(log_scores).sum()).item()
|
|
if norm_prob >= self.score_threshold:
|
|
self.score_threshold = norm_prob
|
|
self.output = generated_sentence
|
|
else:
|
|
self.output = utterance
|
|
|
|
intent = self.label2id[self.intent_classifier(self.output)[0]['label']]
|
|
intent_conf = self.intent_classifier(self.output)[0]['score']
|
|
|
|
return {"modified_prompt": self.output, "mod_confidence": self.score_threshold, "prompt_intent": intent,
|
|
"intent_confidence": intent_conf}
|
|
"""
|
|
|
|
import threading
|
|
|
|
import spacy
|
|
import spacy_transformers
|
|
import torch
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
|
|
|
|
|
class NLU:
    """Natural-language-understanding front end.

    Rewrites a user utterance in the context of the conversation history
    (reference/ellipsis resolution with a seq2seq model), then runs three
    analyses concurrently on the rewritten text: intent classification,
    named-entity extraction, and offensive-language detection.
    """

    def __init__(self, device, device_flag, reference_resolver, tokenizer,
                 intent_classifier, offense_filter, entity_extractor,
                 max_history_length=1024):
        """Build the NLU pipeline.

        Args:
            device: torch device the reference resolver is moved to.
            device_flag: device index passed to the HF ``pipeline`` helper.
            reference_resolver: seq2seq model used to rewrite the utterance.
            tokenizer: tokenizer matching ``reference_resolver``.
            intent_classifier: model name/path for the intent classifier.
            offense_filter: model name/path for the offensive-language filter.
            entity_extractor: callable (e.g. a spaCy pipeline) returning a doc
                with ``.ents``.
            max_history_length: max history length, in words, kept as context
                (also used as the tokenizer ``max_length``).
        """
        self.reference_resolver = reference_resolver
        self.device = device
        self.reference_resolver.to(device)
        self.tokenizer = tokenizer
        self.max_length = max_history_length
        self.label2idintent = {'QA': 0, 'CHITCHAT': 1, 'FINDPAPER': 2, 'FINDDATASET': 3, 'SUMMARIZEPAPER': 4}
        self.id2labelintent = {0: 'QA', 1: 'CHITCHAT', 2: 'FINDPAPER', 3: 'FINDDATASET', 4: 'SUMMARIZEPAPER'}
        self.label2idoffense = {'hate': 0, 'offensive': 1, 'neither': 2}
        self.id2labeloffense = {0: 'hate', 1: 'offensive', 2: 'neither'}
        self.intent_classifier = pipeline("sentiment-analysis", model=intent_classifier, device=device_flag)
        self.entity_extractor = entity_extractor
        self.offense_filter = pipeline("sentiment-analysis", model=offense_filter, device=device_flag)

        # Per-call working state, written by process_utterance and the worker
        # threads it spawns.  NOTE(review): this makes an instance
        # non-reentrant — two concurrent process_utterance calls on the same
        # instance would race on these fields.
        self.to_process = ""  # fix: was previously created only inside process_utterance
        self.intents = None
        self.entities = None
        self.offensive = None
        self.clear = True

    def _intentpredictor(self):
        """Classify the intent of ``self.to_process`` into ``label2idintent``."""
        self.intents = self.label2idintent[self.intent_classifier(self.to_process)[0]['label']]

    def _entityextractor(self):
        """Extract named entities from ``self.to_process``.

        Populates ``self.entities`` with ``{'entity': label, 'value': text}``
        dicts, skipping bare punctuation tokens.
        """
        doc = self.entity_extractor(self.to_process)
        self.entities = [
            {'entity': entity.label_, 'value': entity.text}
            for entity in doc.ents
            if entity.text not in ['.', ',', '?', ';']
        ]

    def _inappropriatedetector(self):
        """Flag ``self.to_process`` as offensive when labeled hate/offensive."""
        is_offensive = self.label2idoffense[self.offense_filter(self.to_process)[0]['label']]
        # ids 0 ('hate') and 1 ('offensive') both count as offensive.
        self.offensive = is_offensive in (0, 1)

    def process_utterance(self, utterance, history):
        """Given an utterance and the history of the conversation, refine the
        query contextually and analyze the refined utterance.

        Args:
            utterance: the raw user utterance.
            history: prior turns joined with the ``|||`` delimiter; empty
                string when there is no context.

        Returns:
            dict with the rewritten prompt, intent id, extracted entities,
            offensiveness flag and clarity flag.
        """
        self.to_process = utterance
        if len(history) > 0:
            # Crop the oldest turns until the history fits the word budget.
            while len(history.split(" ")) > self.max_length:
                index = history.find("|||")
                if index == -1:
                    # Fix: no delimiter left — the old code stripped 3 chars
                    # per iteration via history[-1 + 4:].  Keep only the most
                    # recent words instead.
                    history = " ".join(history.split(" ")[-self.max_length:])
                    break
                history = history[index + 4:]

            context = history + " ||| " + utterance
            inputs = self.tokenizer(context, max_length=self.max_length, truncation=True, padding="max_length",
                                    return_tensors="pt")

            # Inference only: disable autograd to avoid tracking gradients.
            with torch.no_grad():
                candidates = self.reference_resolver.generate(input_ids=inputs["input_ids"].to(self.device),
                                                              attention_mask=inputs["attention_mask"].to(self.device),
                                                              return_dict_in_generate=True, output_scores=True,
                                                              num_return_sequences=1,
                                                              num_beams=5)
            self.to_process = self.tokenizer.decode(candidates["sequences"][0], skip_special_tokens=True,
                                                    clean_up_tokenization_spaces=True)

        # Run the three analyses concurrently; each writes its own field.
        t1 = threading.Thread(target=self._intentpredictor, name='intent')
        t2 = threading.Thread(target=self._entityextractor, name='entity')
        t3 = threading.Thread(target=self._inappropriatedetector, name='offensive')

        t3.start()
        t1.start()
        t2.start()

        t3.join()
        t1.join()
        t2.join()
        return {"modified_prompt": self.to_process,
                "intent": self.intents,
                "entities": self.entities,
                "is_offensive": self.offensive,
                "is_clear": self.clear}
|