# JanetBackEnd/NLU.py
#
# 132 lines
# 6.1 KiB
# Python

import spacy
import spacy_transformers
import torch
class NLU:
    """Natural-language-understanding front end for a single user utterance.

    The class wires together six injected callables and exposes
    :meth:`process_utterance`, which returns a dict with the keys
    ``modified_query``, ``intent``, ``entities``, ``is_offensive`` and
    ``is_clear``.

    Expected shapes of the injected callables, as used by the code below:

    * ``intent_classifier(text)[0]`` is a mapping with ``'label'`` and
      ``'score'``.
    * ``ambig_classifier(text)[0]`` and ``offensive_classifier(text)[0]``
      are mappings with ``'label'``.
    * ``query_rewriter(text)[0]`` is a mapping with ``'generated_text'``.
    * ``entity_extractor(text)`` returns an object whose ``ents`` items have
      ``label_`` and ``text``.
    * ``coref_resolver(text)`` returns an iterable of tokens (each with
      ``idx``, ``text``, ``whitespace_``) that also has a ``spans`` mapping;
      span groups whose key starts with ``"coref_cluster"`` are treated as
      coreference clusters.

    NOTE: the text currently being analysed is kept on the instance as
    ``self.to_process`` and is overwritten during every call, so a single
    instance should not be used to process several utterances concurrently.
    """

    # Utterances (case-insensitive, surrounding whitespace ignored) that are
    # answered as commands without running any model.
    _COMMANDS = frozenset(
        {"help", "list resources", "list papers", "list datasets", "list topics"}
    )
    # Intents that are only considered clear when at least one entity is found.
    _ENTITY_REQUIRED_INTENTS = frozenset({'FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'})
    # Labels of the ambiguity classifier that mean "not ambiguous".
    _CLEAR_LABELS = frozenset({'clear', 'somewhat_clear'})
    # Entity texts that are discarded because they are bare punctuation.
    _IGNORED_ENTITY_TEXTS = frozenset({'.', ',', '?', ';'})
    # Marker inserted between the history and the utterance for coref resolution.
    _COREF_SEP = " <COREF_SEP_TOKEN> "
    # Marker inserted between the history and the utterance for query rewriting.
    _REWRITE_SEP = " ||| "
    # An intent prediction is trusted when its score is strictly above this.
    _INTENT_CONFIDENCE = 0.5

    def __init__(self, query_rewriter, coref_resolver, intent_classifier, offensive_classifier, entity_extractor, ambig_classifier):
        """Store the injected model callables (see the class docstring)."""
        self.intent_classifier = intent_classifier
        self.entity_extractor = entity_extractor
        self.offensive_classifier = offensive_classifier
        self.coref_resolver = coref_resolver
        self.query_rewriter = query_rewriter
        self.ambig_classifier = ambig_classifier
        # Working text read by every helper below; initialised so the helpers
        # do not raise AttributeError when used before process_utterance().
        self.to_process = ""

    def _resolve_coref(self, history) -> str:
        """Return ``self.to_process`` with later mentions replaced by their antecedent.

        ``history`` and the current text are joined with the coref separator
        and passed to ``coref_resolver``. In every cluster, each mention after
        the first is replaced by the text of the first mention. Only the part
        after the separator is returned; if the separator cannot be found in
        the rebuilt text, the whole rebuilt text is returned.
        """
        to_resolve = (history or "") + self._COREF_SEP + self.to_process
        doc = self.coref_resolver(to_resolve)

        clusters = [
            spans for key, spans in doc.spans.items() if key.startswith("coref_cluster")
        ]

        # Maps the character offset of a token to the text emitted in its place.
        replacements = {}
        for cluster in clusters:
            mentions = list(cluster)
            if not mentions:
                continue
            first_mention = mentions[0]
            for mention_span in mentions[1:]:
                # The first token of the mention carries the antecedent text.
                # The trailing whitespace must come from the LAST token of the
                # mention, because every other token of the mention is blanked
                # out below (otherwise "that author," turns into "Ann ,").
                replacements[mention_span[0].idx] = (
                    first_mention.text + mention_span[-1].whitespace_
                )
                for token in mention_span[1:]:
                    replacements[token.idx] = ""

        pieces = []
        for token in doc:
            if token.idx in replacements:
                pieces.append(replacements[token.idx])
            else:
                pieces.append(token.text + token.whitespace_)
        resolved = "".join(pieces)

        # With the separator present this is the text after it; without it,
        # split() yields a single element, i.e. the whole string.
        return resolved.split(self._COREF_SEP, 1)[-1]

    def _intentpredictor(self):
        """Return ``(label, score)`` of the top intent for ``self.to_process``."""
        pred = self.intent_classifier(self.to_process)[0]
        return pred['label'], pred['score']

    def _ambigpredictor(self) -> bool:
        """Return True when ``self.to_process`` is classified as ambiguous."""
        pred = self.ambig_classifier(self.to_process)[0]
        return pred['label'] not in self._CLEAR_LABELS

    def _entityextractor(self):
        """Return the entities in ``self.to_process`` as ``{'entity', 'value'}`` dicts.

        Entities whose text is a bare punctuation mark are dropped.
        """
        doc = self.entity_extractor(self.to_process)
        return [
            {'entity': entity.label_, 'value': entity.text}
            for entity in doc.ents
            if entity.text not in self._IGNORED_ENTITY_TEXTS
        ]

    def _offensepredictor(self) -> bool:
        """Return True unless ``self.to_process`` is labelled ``"neither"``."""
        return self.offensive_classifier(self.to_process)[0]['label'] != "neither"

    def _rewrite_query(self, history) -> str:
        """Return the rewriter's output for ``history ||| self.to_process``."""
        text = (history or "") + self._REWRITE_SEP + self.to_process
        return self.query_rewriter(text)[0]['generated_text']

    @staticmethod
    def _make_result(query, intent, entities, offense, is_clear):
        """Assemble the result dict returned by process_utterance()."""
        return {
            "modified_query": query,
            "intent": intent,
            "entities": entities,
            "is_offensive": offense,
            "is_clear": is_clear,
        }

    def _finalize(self, utterance, intent, entities, offense):
        """Build the result for an utterance whose intent has been accepted.

        CHITCHAT is always reported with the user's original wording. Intents
        that need an entity are reported as unclear when none was extracted.
        """
        if intent == 'CHITCHAT':
            self.to_process = utterance
        is_clear = not (intent in self._ENTITY_REQUIRED_INTENTS and not entities)
        return self._make_result(self.to_process, intent, entities, offense, is_clear)

    def process_utterance(self, utterance, history_consec, history_sep):
        """Analyse one user utterance.

        Pipeline: command shortcut -> coreference resolution -> intent
        prediction. If the intent is not confident and the query is ambiguous,
        the query is rewritten and re-analysed; if it is still neither
        confident nor unambiguous, the result is flagged as unclear so the
        caller can ask a clarifying question.

        Args:
            utterance: Raw user text.
            history_consec: Dialogue history passed to coreference resolution.
            history_sep: Dialogue history passed to the query rewriter.

        Returns:
            dict with ``modified_query`` (str), ``intent`` (str),
            ``entities`` (list of ``{'entity', 'value'}`` dicts),
            ``is_offensive`` (bool) and ``is_clear`` (bool).
        """
        command = utterance.strip().lower()
        if command in self._COMMANDS:
            return self._make_result(command, "COMMAND", [], False, True)

        self.to_process = utterance
        self.to_process = self._resolve_coref(history_consec)
        intent, score = self._intentpredictor()

        if score > self._INTENT_CONFIDENCE:
            # NOTE(review): on this path CHITCHAT entities/offence are computed
            # on the raw utterance, whereas the paths below compute them on the
            # resolved/rewritten text. Kept as found; confirm which is intended.
            if intent == 'CHITCHAT':
                self.to_process = utterance
            entities = self._entityextractor()
            offense = self._offensepredictor()
            return self._finalize(utterance, intent, entities, offense)

        if self._ambigpredictor():
            # Low confidence and ambiguous: rewrite with history and re-check.
            self.to_process = self._rewrite_query(history_sep)
            intent, score = self._intentpredictor()
            entities = self._entityextractor()
            offense = self._offensepredictor()
            if score > self._INTENT_CONFIDENCE or not self._ambigpredictor():
                return self._finalize(utterance, intent, entities, offense)
            # Still unclear after rewriting.
            return self._make_result(self.to_process, intent, entities, offense, False)

        # Low confidence but not ambiguous: accept the prediction as is.
        entities = self._entityextractor()
        offense = self._offensepredictor()
        return self._finalize(utterance, intent, entities, offense)