# JanetBackEnd/NLU.py

"""
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import torch
class NLU:
def_tokenizer = AutoTokenizer.from_pretrained("castorini/t5-base-canard")
def_model = AutoModelForSeq2SeqLM.from_pretrained("castorini/t5-base-canard")
def_intent_classifier = pipeline("sentiment-analysis", model="/home/ahmed/PycharmProjects/Janet/JanetBackend/intent_classifier")
def __init__(self, model=def_model, tokenizer=def_tokenizer, intent_classifier=def_intent_classifier,
max_history_length=1024, num_gen_seq=2, score_threshold=0.5):
self.input = ""
self.output = ""
self.model = model
self.tokenizer = tokenizer
self.max_length = max_history_length
self.num_return_sequences = num_gen_seq
self.score_threshold = score_threshold
self.label2id = {'Greet': 0, 'Bye': 1, 'GetKnowledge': 2, 'ChitChat': 3}
self.id2label = {0: 'Greet', 1: 'Bye', 2: 'GetKnowledge', 3: 'ChitChat'}
self.intent_classifier = intent_classifier
def process_utterance(self, utterance, history):
if len(history) > 0:
# crop history
while len(history.split(" ")) > self.max_length:
index = history.find("|||")
history = history[index + 4:]
self.input = history + " ||| " + utterance
inputs = self.tokenizer(self.input, max_length=self.max_length, truncation=True, padding="max_length",
return_tensors="pt")
candidates = self.model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"],
return_dict_in_generate=True, output_scores=True,
num_return_sequences=self.num_return_sequences,
num_beams=self.num_return_sequences)
for i in range(candidates["sequences"].shape[0]):
generated_sentence = self.tokenizer.decode(candidates["sequences"][i], skip_special_tokens=True,
clean_up_tokenization_spaces=True)
log_scores = candidates['sequences_scores']
norm_prob = (torch.exp(log_scores[i]) / torch.exp(log_scores).sum()).item()
if norm_prob >= self.score_threshold:
self.score_threshold = norm_prob
self.output = generated_sentence
else:
self.output = utterance
intent = self.label2id[self.intent_classifier(self.output)[0]['label']]
intent_conf = self.intent_classifier(self.output)[0]['score']
return {"modified_prompt": self.output, "mod_confidence": self.score_threshold, "prompt_intent": intent,
"intent_confidence": intent_conf}
"""
import threading
import spacy
import spacy_transformers
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline


class NLU:
    def __init__(self, device, device_flag, reference_resolver, tokenizer,
                 intent_classifier, offense_filter, entity_extractor,
                 max_history_length=1024):
        # entity_extractor=def_entity_extractor
        # Seq2seq model that rewrites the user utterance so that references to the
        # conversation history are resolved explicitly.
        self.reference_resolver = reference_resolver
        self.device = device
        self.reference_resolver.to(device)
        self.tokenizer = tokenizer
        self.max_length = max_history_length
        # Label maps for the intent classifier and the offense filter.
        self.label2idintent = {'QA': 0, 'CHITCHAT': 1, 'FINDPAPER': 2, 'FINDDATASET': 3, 'SUMMARIZEPAPER': 4}
        self.id2labelintent = {0: 'QA', 1: 'CHITCHAT', 2: 'FINDPAPER', 3: 'FINDDATASET', 4: 'SUMMARIZEPAPER'}
        self.label2idoffense = {'hate': 0, 'offensive': 1, 'neither': 2}
        self.id2labeloffense = {0: 'hate', 1: 'offensive', 2: 'neither'}
        # Text-classification pipelines for intent detection and offensive-language
        # detection; entity extraction uses a spaCy pipeline.
        self.intent_classifier = pipeline("sentiment-analysis", model=intent_classifier, device=device_flag)
        self.entity_extractor = entity_extractor
        self.offense_filter = pipeline("sentiment-analysis", model=offense_filter, device=device_flag)
        # Results filled in by the worker threads started in process_utterance.
        self.intents = None
        self.entities = None
        self.offensive = None
        self.clear = True
    def _intentpredictor(self):
        # Map the predicted intent label to its numeric id.
        self.intents = self.label2idintent[self.intent_classifier(self.to_process)[0]['label']]

    def _entityextractor(self):
        self.entities = []
        doc = self.entity_extractor(self.to_process)
        for entity in doc.ents:
            # Skip entities that are only punctuation.
            if entity.text not in ['.', ',', '?', ';']:
                self.entities.append({'entity': entity.label_, 'value': entity.text})

    def _inappropriatedetector(self):
        self.offensive = False
        is_offensive = self.label2idoffense[self.offense_filter(self.to_process)[0]['label']]
        if is_offensive == 0 or is_offensive == 1:
            # 0 = hate, 1 = offensive
            self.offensive = True
    def process_utterance(self, utterance, history):
        """
        Given an utterance and the conversation history, resolve references against the
        history to produce a self-contained utterance, then classify its intent, extract
        entities and check it for offensive language. Returns a dictionary with the
        rewritten prompt and the analysis results.
        """
        self.to_process = utterance
        if len(history) > 0:
            # crop history
            while len(history.split(" ")) > self.max_length:
                index = history.find("|||")
                history = history[index + 4:]
            # The reference resolver expects the history and the new utterance
            # separated by "|||".
            context = history + " ||| " + utterance
            inputs = self.tokenizer(context, max_length=self.max_length, truncation=True, padding="max_length",
                                    return_tensors="pt")
            candidates = self.reference_resolver.generate(input_ids=inputs["input_ids"].to(self.device),
                                                          attention_mask=inputs["attention_mask"].to(self.device),
                                                          return_dict_in_generate=True, output_scores=True,
                                                          num_return_sequences=1,
                                                          num_beams=5)
            self.to_process = self.tokenizer.decode(candidates["sequences"][0], skip_special_tokens=True,
                                                    clean_up_tokenization_spaces=True)
        # Run intent prediction, entity extraction and offense detection in parallel.
        t1 = threading.Thread(target=self._intentpredictor, name='intent')
        t2 = threading.Thread(target=self._entityextractor, name='entity')
        t3 = threading.Thread(target=self._inappropriatedetector, name='offensive')
        t3.start()
        t1.start()
        t2.start()
        t3.join()
        t1.join()
        t2.join()
        return {"modified_prompt": self.to_process,
                "intent": self.intents,
                "entities": self.entities,
                "is_offensive": self.offensive,
                "is_clear": self.clear}