#import spacy
#import spacy_transformers
import torch
import logging


class NLU:
    def __init__(self, LLM_tokenizer, LLM_model):
        """
        Legacy classifier components, kept for reference:
        self.intent_classifier = intent_classifier
        self.entity_extractor = entity_extractor
        self.offensive_classifier = offensive_classifier
        self.coref_resolver = coref_resolver
        self.query_rewriter = query_rewriter
        self.ambig_classifier = ambig_classifier
        """
        self.tokenizer = LLM_tokenizer
        self.model = LLM_model

    """
    Legacy classifier-based pipeline, disabled in favour of the LLM-based
    process_utterance below:

    def _resolve_coref(self, history):
        to_resolve = history + ' ' + self.to_process
        doc = self.coref_resolver(to_resolve)
        token_mention_mapper = {}
        output_string = ""
        cand_clusters = [
            val for key, val in doc.spans.items() if key.startswith("coref_cluster")
        ]
        # Drop clusters anchored on the user themselves ("I")
        clusters = []
        for cluster in cand_clusters:
            if cluster[0].text == "I":
                continue
            clusters.append(cluster)

        # Iterate through every found cluster
        for cluster in clusters:
            first_mention = cluster[0]
            # Iterate through every other span in the cluster
            for mention_span in list(cluster)[1:]:
                # Set first_mention as value for the first token of mention_span
                # in the token_mention_mapper
                token_mention_mapper[mention_span[0].idx] = (
                    first_mention.text + mention_span[0].whitespace_
                )
                for token in mention_span[1:]:
                    # Set empty string for all the other tokens in mention_span
                    token_mention_mapper[token.idx] = ""

        # Iterate through every token in the Doc
        for token in doc:
            # Check if token exists in token_mention_mapper
            if token.idx in token_mention_mapper:
                output_string += token_mention_mapper[token.idx]
            # Else add original token text
            else:
                output_string += token.text + token.whitespace_

        # Strip the leading history token before returning the cleaned query
        if len(output_string.split(" ", 1)) > 1:
            return output_string.split(" ", 1)[1]
        return output_string.split(" ", 1)[0]

    def _intentpredictor(self):
        pred = self.intent_classifier(self.to_process)[0]
        return pred['label'], pred['score']

    def _ambigpredictor(self):
        pred = self.ambig_classifier(self.to_process)[0]
        # The query counts as ambiguous unless classified (somewhat) clear
        return pred['label'] not in ['clear', 'somewhat_clear']

    def _entityextractor(self):
        entities = []
        doc = self.entity_extractor(self.to_process)
        for entity in doc.ents:
            if entity.text not in ['.', ',', '?', ';']:
                entities.append({'entity': entity.label_, 'value': entity.text})
        return entities

    def _offensepredictor(self):
        pred = self.offensive_classifier(self.to_process)[0]['label']
        return pred != "neither"

    def _rewrite_query(self, history):
        text = history + " ||| " + self.to_process
        return self.query_rewriter(text)[0]['generated_text']
    """
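    # Worked illustration of the legacy _resolve_coref above (hypothetical
    # input; assumes spaCy's experimental coreference component, which stores
    # clusters in doc.spans under keys starting with "coref_cluster"). For
    # "Janet found a paper. Summarize it.", the cluster ["a paper", "it"]
    # makes the token map replace the first token of the later mention "it"
    # with "a paper" and blank its remaining tokens, yielding
    # "Janet found a paper. Summarize a paper."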
    def process_utterance(self, history):
        """
        Ask the LLM to state the user's goal given the dialogue history.

        Legacy flow (disabled below): query -> coref resolution & intent
        extraction -> if the intent is not confident or the query is
        ambiguous -> rewrite the query and recheck -> if still ambiguous,
        ask a clarifying question.
        """
        #if utterance.lower() in ["help", "list resources", "list papers", "list datasets", "list topics"]:
        #    return {"modified_query": utterance.lower(), "intent": "COMMAND", "entities": [], "is_offensive": False, "is_clear": True}
        #self.to_process = utterance
        prompt = f"""You are Janet, the virtual assistant for users of the virtual research environment. What does the user eventually want given this dialogue, which is delimited with triple quotes? Give your answer in one single sentence.
Dialogue: '''{history}'''"""
        chat = [{"role": "user", "content": prompt}]
        prompt_chat = self.tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True
        )
        inputs = self.tokenizer.encode(
            prompt_chat, add_special_tokens=False, return_tensors="pt"
        )
        outputs = self.model.generate(input_ids=inputs, max_new_tokens=150)
        goal = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        logging.info("User's goal is: " + goal)
        # The decoded output still contains the prompt; with a Gemma-style
        # chat template the assistant turn starts after "model\n", so keep
        # only the completion.
        goal = goal.split("model\n")[-1].strip()
        return {"modified_query": goal, "intent": "QA", "entities": [], "is_offensive": False, "is_clear": True}

        # Legacy pipeline, disabled:
        #self.to_process = self._resolve_coref(history_consec)
        #intent, score = self._intentpredictor()
        #if score > 0.5:
        #    if intent == 'CHITCHAT':
        #        self.to_process = utterance
        #    entities = self._entityextractor()
        #    offense = self._offensepredictor()
        #    if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
        #        return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
        #    return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": True}
        #else:
        #    if self._ambigpredictor():
        #        self.to_process = self._rewrite_query(history_sep)
        #        intent, score = self._intentpredictor()
        #        entities = self._entityextractor()
        #        offense = self._offensepredictor()
        #        if score > 0.5 or not self._ambigpredictor():
        #            if intent == 'CHITCHAT':
        #                self.to_process = utterance
        #            if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
        #                return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
        #            return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": True}
        #        else:
        #            return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
        #    else:
        #        entities = self._entityextractor()
        #        offense = self._offensepredictor()
        #        if intent == 'CHITCHAT':
        #            self.to_process = utterance
        #        if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
        #            return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
        #        return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": True}
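
# Minimal usage sketch (an illustration, not part of the module's API). The
# checkpoint name is an assumption: any instruction-tuned chat model whose
# template opens the assistant turn with "model\n" (Gemma-style) matches the
# output parsing in process_utterance.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")  # assumed checkpoint
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it")
    nlu = NLU(tokenizer, model)
    result = nlu.process_utterance("User: Can you find me a paper on food security?")
    print(result["modified_query"])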