import spacy
import spacy_transformers  # registers the transformer factories used by the spaCy pipelines
import torch  # backend for the transformer-based components
class NLU:
    # Intents that act on a specific artifact, and therefore need at least
    # one extracted entity before a query can be considered clear.
    ENTITY_INTENTS = ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER', 'EXPLAINPOST']

    def __init__(self, query_rewriter, coref_resolver, intent_classifier,
                 offensive_classifier, entity_extractor, ambig_classifier):
        self.intent_classifier = intent_classifier
        self.entity_extractor = entity_extractor
        self.offensive_classifier = offensive_classifier
        self.coref_resolver = coref_resolver
        self.query_rewriter = query_rewriter
        self.ambig_classifier = ambig_classifier
    def _resolve_coref(self, history):
        """Resolve coreferences in the current query against the dialogue
        history, replacing each mention with its cluster's first mention."""
        to_resolve = history + ' <COREF_SEP_TOKEN> ' + self.to_process
        doc = self.coref_resolver(to_resolve)
        token_mention_mapper = {}
        output_string = ""
        clusters = [
            val for key, val in doc.spans.items() if key.startswith("coref_cluster")
        ]

        # Iterate through every found cluster
        for cluster in clusters:
            first_mention = cluster[0]
            # Iterate through every other span in the cluster
            for mention_span in list(cluster)[1:]:
                # Map the first token of mention_span to the first mention's text
                token_mention_mapper[mention_span[0].idx] = first_mention.text + mention_span[0].whitespace_
                for token in mention_span[1:]:
                    # Blank out the remaining tokens of mention_span
                    token_mention_mapper[token.idx] = ""

        # Rebuild the text, substituting mapped tokens (keyed by character offset)
        for token in doc:
            if token.idx in token_mention_mapper:
                output_string += token_mention_mapper[token.idx]
            else:
                output_string += token.text + token.whitespace_

        # Drop the history prefix; keep only the resolved query
        cleaned_query = output_string.split(" <COREF_SEP_TOKEN> ", 1)[1]
        return cleaned_query
    def _intentpredictor(self):
        """Return the predicted intent label and its confidence score."""
        pred = self.intent_classifier(self.to_process)[0]
        return pred['label'], pred['score']
    def _ambigpredictor(self):
        """Return True if the current query is predicted to be ambiguous."""
        pred = self.ambig_classifier(self.to_process)[0]
        return pred['label'] not in ['clear', 'somewhat_clear']
    def _entityextractor(self):
        """Extract named entities from the current query, skipping bare punctuation."""
        entities = []
        doc = self.entity_extractor(self.to_process)
        for entity in doc.ents:
            if entity.text not in ['.', ',', '?', ';']:
                entities.append({'entity': entity.label_, 'value': entity.text})
        return entities
    def _offensepredictor(self):
        """Return True if the current query is classified as offensive."""
        pred = self.offensive_classifier(self.to_process)[0]['label']
        return pred != "neither"
    def _rewrite_query(self, history):
        """Rewrite the query into a self-contained form, conditioning on the
        ' ||| '-separated dialogue history."""
        text = history + " ||| " + self.to_process
        return self.query_rewriter(text)[0]['generated_text']
    def _build_result(self, intent, entities, offense, is_clear):
        """Package the NLU output. Entity-dependent intents with no
        extracted entities are downgraded to unclear."""
        if is_clear and intent in self.ENTITY_INTENTS and len(entities) == 0:
            is_clear = False
        return {"modified_query": self.to_process, "intent": intent,
                "entities": entities, "is_offensive": offense,
                "is_clear": is_clear}

    def process_utterance(self, utterance, history_consec, history_sep):
        """
        Query -> coreference resolution & intent prediction. If the intent
        prediction is not confident and the query is ambiguous, rewrite the
        query and re-check; if it is still ambiguous, return is_clear=False
        so the caller can ask a clarifying question.
        """
        self.to_process = utterance
        self.to_process = self._resolve_coref(history_consec)

        intent, score = self._intentpredictor()
        if score > 0.5:
            # Confident intent prediction: extract entities, screen for offense
            entities = self._entityextractor()
            offense = self._offensepredictor()
            return self._build_result(intent, entities, offense, is_clear=True)
        elif self._ambigpredictor():
            # Low confidence and ambiguous: rewrite the query and re-check
            self.to_process = self._rewrite_query(history_sep)
            intent, score = self._intentpredictor()
            entities = self._entityextractor()
            offense = self._offensepredictor()
            is_clear = score > 0.5 or not self._ambigpredictor()
            return self._build_result(intent, entities, offense, is_clear=is_clear)
        else:
            # Low confidence but unambiguous: accept the prediction as-is
            entities = self._entityextractor()
            offense = self._offensepredictor()
            return self._build_result(intent, entities, offense, is_clear=True)
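
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of the original module). The NLU
# class only assumes its six components are callables with the interfaces
# used above: spaCy pipelines for coreference and NER, and Hugging Face
# `transformers` pipelines for the classifiers and the rewriter. The
# "your-org/..." model names below are placeholders, not the checkpoints the
# authors used; the " ||| " separator in _rewrite_query matches CANARD-style
# query rewriters such as castorini/t5-base-canard.
if __name__ == "__main__":
    from transformers import pipeline

    # spaCy components: en_coreference_web_trf is the experimental coref
    # pipeline from spacy-experimental; any NER-capable pipeline works for NER
    coref_resolver = spacy.load("en_coreference_web_trf")
    entity_extractor = spacy.load("en_core_web_trf")

    # Hugging Face pipelines; swap in the fine-tuned checkpoints you deploy
    intent_classifier = pipeline("text-classification", model="your-org/intent-model")
    ambig_classifier = pipeline("text-classification", model="your-org/ambiguity-model")
    offensive_classifier = pipeline("text-classification", model="your-org/offense-model")
    query_rewriter = pipeline("text2text-generation", model="castorini/t5-base-canard")

    nlu = NLU(query_rewriter, coref_resolver, intent_classifier,
              offensive_classifier, entity_extractor, ambig_classifier)

    out = nlu.process_utterance(
        "Can you summarize it?",
        history_consec="I was reading the BERT paper.",
        history_sep="I was reading the BERT paper.",
    )
    print(out)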