# JanetBackEnd/NLU.py
#
# (original listing: 113 lines, 5.3 KiB, Python)

import spacy
import spacy_transformers
import torch
class NLU:
    """Natural-language-understanding pipeline for a single user utterance.

    All models are injected as callables. :meth:`process_utterance` chains
    them: coreference resolution against the conversation history, intent
    classification, and -- when the intent is uncertain and the query looks
    ambiguous -- a history-aware query rewrite followed by re-classification.
    Entity extraction and offensive-language detection run on whichever query
    text is finally used.

    Model outputs are consumed as follows (inferred from usage in this class;
    presumably Hugging Face ``pipeline``-style classifiers and spaCy-style
    ``Doc`` objects -- confirm against the construction site):

    * ``intent_classifier(text)[0]`` -> dict with ``'label'`` and ``'score'``
    * ``ambig_classifier(text)[0]['label']``
    * ``offensive_classifier(text)[0]['label']``
    * ``query_rewriter(text)[0]['generated_text']``
    * ``entity_extractor(text).ents`` -> items with ``label_`` and ``text``
    * ``coref_resolver(text)`` -> iterable of tokens (``idx``, ``text``,
      ``whitespace_``) plus a ``spans`` mapping whose ``coref_cluster*``
      entries are sequences of mention spans.
    """

    # Intents that cannot be acted on without at least one extracted entity;
    # such a query is reported as unclear when no entity was found.
    ENTITY_REQUIRED_INTENTS = frozenset(
        {'FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER', 'EXPLAINPOST'})
    # An intent prediction is trusted only when its score is strictly above this.
    INTENT_CONFIDENCE_THRESHOLD = 0.5
    # Ambiguity-classifier labels that mean "not ambiguous".
    CLEAR_LABELS = frozenset({'clear', 'somewhat_clear'})
    # Offensive-classifier label that means "not offensive".
    NON_OFFENSIVE_LABEL = "neither"
    # Entity texts that are discarded (punctuation-only entities).
    IGNORED_ENTITY_TEXTS = frozenset({'.', ',', '?', ';'})
    # Separator placed between the history and the query for coref resolution.
    COREF_SEP = ' <COREF_SEP_TOKEN> '
    # Separator placed between the history and the query for the rewriter.
    REWRITE_SEP = " ||| "

    def __init__(self, query_rewriter, coref_resolver, intent_classifier,
                 offensive_classifier, entity_extractor, ambig_classifier):
        """Store the injected models (see the class docstring for their contracts)."""
        self.intent_classifier = intent_classifier
        self.entity_extractor = entity_extractor
        self.offensive_classifier = offensive_classifier
        self.coref_resolver = coref_resolver
        self.query_rewriter = query_rewriter
        self.ambig_classifier = ambig_classifier
        # Query text the helper methods operate on; set by process_utterance.
        self.to_process = None

    def _resolve_coref(self, history):
        """Replace coreferent mentions in ``self.to_process`` with their antecedents.

        ``history`` and the query are joined with ``COREF_SEP`` so the resolver
        can link mentions in the query to antecedents in the history. Every
        non-first mention of a ``coref_cluster*`` span group is replaced by the
        text of that cluster's first mention, and the text after the separator
        is returned.

        :param history: prior conversation text.
        :return: the resolved query, or the unmodified ``self.to_process`` if
            the separator did not survive the substitution (resolution is
            best-effort and must not crash the pipeline).
        """
        doc = self.coref_resolver(history + self.COREF_SEP + self.to_process)
        # Maps a token's character offset to the text that replaces it.
        replacements = {}
        clusters = [spans for key, spans in doc.spans.items()
                    if key.startswith("coref_cluster")]
        for cluster in clusters:
            antecedent = cluster[0]
            for mention in list(cluster)[1:]:
                # The whitespace that followed the mention belongs to its LAST
                # token; taking the first token's whitespace could glue the
                # replacement onto the next word or onto the separator.
                replacements[mention[0].idx] = antecedent.text + mention[-1].whitespace_
                for token in mention[1:]:
                    replacements[token.idx] = ""
        resolved = "".join(
            replacements.get(token.idx, token.text + token.whitespace_)
            for token in doc
        )
        _, sep, query = resolved.partition(self.COREF_SEP)
        if not sep:
            return self.to_process
        return query

    def _intentpredictor(self):
        """Classify ``self.to_process``; return ``(label, score)`` of the first prediction."""
        pred = self.intent_classifier(self.to_process)[0]
        return pred['label'], pred['score']

    def _ambigpredictor(self):
        """Return True when ``self.to_process`` is not labelled as clear."""
        label = self.ambig_classifier(self.to_process)[0]['label']
        return label not in self.CLEAR_LABELS

    def _entityextractor(self):
        """Return entities of ``self.to_process`` as ``{'entity': label, 'value': text}`` dicts.

        Punctuation-only entities listed in ``IGNORED_ENTITY_TEXTS`` are dropped.
        """
        doc = self.entity_extractor(self.to_process)
        return [{'entity': ent.label_, 'value': ent.text}
                for ent in doc.ents
                if ent.text not in self.IGNORED_ENTITY_TEXTS]

    def _offensepredictor(self):
        """Return True unless ``self.to_process`` gets the non-offensive label."""
        label = self.offensive_classifier(self.to_process)[0]['label']
        return label != self.NON_OFFENSIVE_LABEL

    def _rewrite_query(self, history):
        """Rewrite ``self.to_process`` in the context of ``history``; return the generated text."""
        text = history + self.REWRITE_SEP + self.to_process
        return self.query_rewriter(text)[0]['generated_text']

    def _build_result(self, intent, entities, offense, is_clear):
        """Assemble the response dict for the current ``self.to_process``.

        ``is_clear`` is forced to False when ``intent`` is in
        ``ENTITY_REQUIRED_INTENTS`` and no entity was extracted.
        """
        if intent in self.ENTITY_REQUIRED_INTENTS and not entities:
            is_clear = False
        return {"modified_query": self.to_process, "intent": intent,
                "entities": entities, "is_offensive": offense,
                "is_clear": is_clear}

    def process_utterance(self, utterance, history_consec, history_sep):
        """Analyse one user utterance in the context of the conversation.

        Flow: resolve coreferences against ``history_consec`` and classify the
        intent. If the intent score is above ``INTENT_CONFIDENCE_THRESHOLD``,
        or the ambiguity classifier considers the query clear, the resolved
        query is used as-is. Otherwise (unconfident AND ambiguous) the query is
        rewritten with ``history_sep`` and re-classified; it is reported as
        clear only if the new intent is confident or the rewritten query is no
        longer ambiguous. In every case, a query whose intent requires an
        entity but has none is reported as unclear, so the caller can ask a
        clarifying question.

        :param utterance: raw user text.
        :param history_consec: history text used for coreference resolution.
        :param history_sep: history text used for query rewriting.
        :return: dict with keys ``modified_query`` (the text finally analysed),
            ``intent``, ``entities``, ``is_offensive`` and ``is_clear``.
        """
        self.to_process = utterance
        self.to_process = self._resolve_coref(history_consec)
        intent, score = self._intentpredictor()
        # Short-circuit: the ambiguity model only runs for unconfident intents.
        if score > self.INTENT_CONFIDENCE_THRESHOLD or not self._ambigpredictor():
            entities = self._entityextractor()
            offense = self._offensepredictor()
            return self._build_result(intent, entities, offense, True)

        # Unconfident and ambiguous: rewrite with the history and try again.
        self.to_process = self._rewrite_query(history_sep)
        intent, score = self._intentpredictor()
        entities = self._entityextractor()
        offense = self._offensepredictor()
        is_clear = (score > self.INTENT_CONFIDENCE_THRESHOLD
                    or not self._ambigpredictor())
        return self._build_result(intent, entities, offense, is_clear)