import spacy
import spacy_transformers
import torch

# NOTE(review): `spacy` and `torch` are not referenced in this file's visible
# code; `spacy_transformers` is presumably imported for its side effect of
# registering transformer components with spaCy. Confirm before removing any.


class NLU:
    """Natural-language-understanding pipeline for a single user utterance.

    Combines coreference resolution, intent classification, ambiguity
    detection, entity extraction, offensiveness detection and query rewriting.
    Every model is injected as a callable. From how they are used below:

    * ``intent_classifier``, ``ambig_classifier``, ``offensive_classifier``
      are called with a string and return a list whose first element is a
      dict with ``'label'`` (and, for intents, ``'score'``);
    * ``query_rewriter`` is called with a string and returns a list whose
      first element has ``'generated_text'``;
    * ``coref_resolver`` returns an object with ``.spans`` that iterates
      over tokens, and ``entity_extractor`` returns an object with ``.ents``.

    These presumably are Hugging Face ``pipeline`` objects and spaCy
    ``Language`` pipelines respectively -- TODO confirm with the caller.

    NOTE(review): the utterance currently being processed is kept on
    ``self.to_process`` between the helper calls, so a single instance should
    not process several utterances concurrently.
    """

    # Intents that cannot be acted on without at least one extracted entity;
    # if none is found the result is reported as not clear.
    ENTITY_REQUIRED_INTENTS = frozenset({'FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'})
    # An intent prediction with a score strictly above this is trusted as-is.
    INTENT_CONFIDENCE_THRESHOLD = 0.5
    # Entity texts consisting of bare punctuation are discarded.
    _PUNCTUATION_ENTITIES = frozenset({'.', ',', '?', ';'})
    # Ambiguity-classifier labels that count as "not ambiguous".
    _CLEAR_LABELS = ('clear', 'somewhat_clear')

    def __init__(self, query_rewriter, coref_resolver, intent_classifier,
                 offensive_classifier, entity_extractor, ambig_classifier):
        self.intent_classifier = intent_classifier
        self.entity_extractor = entity_extractor
        self.offensive_classifier = offensive_classifier
        self.coref_resolver = coref_resolver
        self.query_rewriter = query_rewriter
        self.ambig_classifier = ambig_classifier

    def _resolve_coref(self, history):
        """Substitute coreferent mentions in ``self.to_process``.

        The history and the current query are joined with a single space so
        the resolver can link mentions in the query to antecedents in the
        history. Every non-first mention of a ``coref_cluster*`` span group
        is replaced by the text of that cluster's first mention, and only
        the query portion of the rewritten text is returned.

        Args:
            history: Text preceding the query, used as coreference context.

        Returns:
            The query text with coreferent mentions substituted.
        """
        to_resolve = history + ' ' + self.to_process
        # Character offset at which the query begins inside `to_resolve`.
        query_start = len(history) + 1
        doc = self.coref_resolver(to_resolve)

        clusters = [
            val for key, val in doc.spans.items() if key.startswith("coref_cluster")
        ]

        # Token character offset -> text that replaces that token.
        replacements = {}
        for cluster in clusters:
            first_mention = cluster[0]
            for mention_span in list(cluster)[1:]:
                # The whole mention collapses onto its first token. Keep the
                # whitespace that followed the mention's *last* token so the
                # spacing before the next word or punctuation is preserved.
                replacements[mention_span[0].idx] = (
                    first_mention.text + mention_span[-1].whitespace_
                )
                for token in mention_span[1:]:
                    replacements[token.idx] = ""

        # Rebuild only the query. History tokens are dropped by character
        # offset; splitting on the first space would remove just the
        # history's first word and leak the rest of it into the query.
        return "".join(
            replacements.get(token.idx, token.text + token.whitespace_)
            for token in doc
            if token.idx >= query_start
        )

    def _intentpredictor(self):
        """Classify the intent of ``self.to_process``.

        Returns:
            ``(label, score)`` taken from the classifier's first prediction.
        """
        pred = self.intent_classifier(self.to_process)[0]
        return pred['label'], pred['score']

    def _ambigpredictor(self):
        """Return True if ``self.to_process`` is judged ambiguous.

        Any label other than ``'clear'`` or ``'somewhat_clear'`` counts as
        ambiguous.
        """
        pred = self.ambig_classifier(self.to_process)[0]
        return pred['label'] not in self._CLEAR_LABELS

    def _entityextractor(self):
        """Extract entities from ``self.to_process``.

        Returns:
            A list of ``{'entity': <label>, 'value': <text>}`` dicts, in the
            extractor's order, skipping entities whose text is bare
            punctuation.
        """
        doc = self.entity_extractor(self.to_process)
        return [
            {'entity': entity.label_, 'value': entity.text}
            for entity in doc.ents
            if entity.text not in self._PUNCTUATION_ENTITIES
        ]

    def _offensepredictor(self):
        """Return True unless the offensive classifier's label is ``"neither"``."""
        return self.offensive_classifier(self.to_process)[0]['label'] != "neither"

    def _rewrite_query(self, history):
        """Rewrite ``self.to_process`` given ``history``.

        The rewriter receives ``"<history> ||| <query>"`` and its first
        output's ``'generated_text'`` is returned.
        """
        text = history + " ||| " + self.to_process
        return self.query_rewriter(text)[0]['generated_text']

    def _has_required_entities(self, intent, entities):
        """Return False when ``intent`` needs an entity but ``entities`` is empty."""
        return not (intent in self.ENTITY_REQUIRED_INTENTS and not entities)

    def _build_result(self, intent, entities, offense, is_clear):
        """Package the pipeline outputs into the dict returned to callers."""
        return {
            "modified_prompt": self.to_process,
            "intent": intent,
            "entities": entities,
            "is_offensive": offense,
            "is_clear": is_clear,
        }

    def process_utterance(self, utterance, history_consec, history_sep):
        """Run the full NLU pipeline on one user utterance.

        Flow:
          1. Resolve coreferences in ``utterance`` using ``history_consec``.
          2. Predict the intent. If its score is above
             ``INTENT_CONFIDENCE_THRESHOLD``, or the query is not judged
             ambiguous, extract entities and offensiveness and return.
          3. Otherwise rewrite the query using ``history_sep`` and re-run
             intent, entity and offensiveness prediction. If the rewritten
             query is still low-confidence *and* still ambiguous, the result
             is marked not clear.

        Whenever the result would otherwise be clear, ``is_clear`` is still
        False if the intent is in ``ENTITY_REQUIRED_INTENTS`` and no entity
        was extracted.

        Args:
            utterance: The raw user query.
            history_consec: History text used as coreference context.
            history_sep: History text passed to the query rewriter.

        Returns:
            dict with keys ``modified_prompt`` (the processed query),
            ``intent``, ``entities``, ``is_offensive`` and ``is_clear``.
        """
        self.to_process = utterance
        self.to_process = self._resolve_coref(history_consec)
        intent, score = self._intentpredictor()
        threshold = self.INTENT_CONFIDENCE_THRESHOLD

        # The ambiguity model is only consulted when the intent is uncertain.
        if score <= threshold and self._ambigpredictor():
            self.to_process = self._rewrite_query(history_sep)
            intent, score = self._intentpredictor()
            entities = self._entityextractor()
            offense = self._offensepredictor()
            # Still uncertain and still ambiguous after rewriting: give up.
            if score <= threshold and self._ambigpredictor():
                return self._build_result(intent, entities, offense, False)
            return self._build_result(
                intent, entities, offense,
                self._has_required_entities(intent, entities),
            )

        entities = self._entityextractor()
        offense = self._offensepredictor()
        return self._build_result(
            intent, entities, offense,
            self._has_required_entities(intent, entities),
        )