From 423853ad659a489fc12e5149c759b02cd0c91bbe Mon Sep 17 00:00:00 2001 From: ahmed531998 Date: Tue, 4 Apr 2023 05:34:47 +0200 Subject: [PATCH] new_code --- DM.py | 63 ++++++---- NLU.py | 221 +++++++++++++++------------------- ResponseGenerator.py | 277 ++++++++++++++++++++++++------------------- User.py | 3 +- main.py | 148 +++++++++++------------ requirements.txt | 16 ++- 6 files changed, 373 insertions(+), 355 deletions(-) diff --git a/DM.py b/DM.py index f784bb0..e036bc3 100644 --- a/DM.py +++ b/DM.py @@ -1,43 +1,56 @@ import time class DM: - def __init__(self): - self.utt_history = "" - self.history = [] - self.state = None + def __init__(self, max_history_length=3): + self.working_history_sep = "" + self.working_history_consec = "" + self.max_history_length = max_history_length + self.chat_history = [] + self.curr_state = None + + def update_history(self): + to_consider = [x['modified_query'] for x in self.chat_history[-max_history_length*2:]] + self.working_history_consec = " . ".join(to_consider) + self.working_history_sep = " ||| ".join(to_consider) + + def get_consec_history(self): + return self.working_history_consec - def get_utt_history(self): - return self.utt_history + def get_sep_history(self): + return self.working_history_sep def get_recent_state(self): - return self.state + return self.curr_state - def get_dialogue_state(self): - return self.history + def get_dialogue_history(self): + return self.chat_history def update(self, new_state): - self.history.append(new_state) - self.utt_history = self.utt_history + " ||| " + new_state['modified_prompt'] - self.state = {'intent': new_state['intent'], - 'entities': new_state['entities'], - 'offensive': new_state['is_offensive'], - 'clear': new_state['is_clear'], - 'time': time.time()} + self.chat_history.append(new_state) + self.curr_state = new_state + self.update_history() def next_action(self): - if self.state['clear']: - if self.state['offensive']: - return "NoCanDo" + if self.curr_state['help']: + return "Help" + elif self.curr_state['inactive']: + return "Recommend" + elif self.curr_state['is_clear']: + if self.curr_state['is_offensive']: + return "OffenseReject" else: - if self.state['intent'] == 0: + if self.curr_state['intent'] == 'QA': return "RetGen" - elif self.state['intent'] == 1: + elif self.curr_state['intent'] == 'CHITCHAT': return "ConvGen" - elif self.state['intent'] == 2: + elif self.curr_state['intent'] == 'FINDPAPER': return "findPaper" - elif self.state['intent'] == 3: + elif self.curr_state['intent'] == 'FINDDATASET': return "findDataset" - elif self.state['intent'] == 4: + elif self.curr_state['intent'] == 'SUMMARIZEPAPER': return "sumPaper" else: - return "Clarify" + if self.curr_state['intent'] in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(self.curr_state['entities']) == 0: + return "ClarifyResource" + else: + return "GenClarify" diff --git a/NLU.py b/NLU.py index edf7f60..4b0b056 100644 --- a/NLU.py +++ b/NLU.py @@ -1,143 +1,112 @@ -""" -from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline -import torch - - -class NLU: - def_tokenizer = AutoTokenizer.from_pretrained("castorini/t5-base-canard") - def_model = AutoModelForSeq2SeqLM.from_pretrained("castorini/t5-base-canard") - def_intent_classifier = pipeline("sentiment-analysis", model="/home/ahmed/PycharmProjects/Janet/JanetBackend/intent_classifier") - - def __init__(self, model=def_model, tokenizer=def_tokenizer, intent_classifier=def_intent_classifier, - max_history_length=1024, num_gen_seq=2, score_threshold=0.5): - self.input = "" - self.output = "" - self.model = model - self.tokenizer = tokenizer - self.max_length = max_history_length - self.num_return_sequences = num_gen_seq - self.score_threshold = score_threshold - self.label2id = {'Greet': 0, 'Bye': 1, 'GetKnowledge': 2, 'ChitChat': 3} - self.id2label = {0: 'Greet', 1: 'Bye', 2: 'GetKnowledge', 3: 'ChitChat'} - self.intent_classifier = intent_classifier - - def process_utterance(self, utterance, history): - if len(history) > 0: - # crop history - while len(history.split(" ")) > self.max_length: - index = history.find("|||") - history = history[index + 4:] - - self.input = history + " ||| " + utterance - inputs = self.tokenizer(self.input, max_length=self.max_length, truncation=True, padding="max_length", - return_tensors="pt") - - candidates = self.model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], - return_dict_in_generate=True, output_scores=True, - num_return_sequences=self.num_return_sequences, - num_beams=self.num_return_sequences) - for i in range(candidates["sequences"].shape[0]): - generated_sentence = self.tokenizer.decode(candidates["sequences"][i], skip_special_tokens=True, - clean_up_tokenization_spaces=True) - log_scores = candidates['sequences_scores'] - norm_prob = (torch.exp(log_scores[i]) / torch.exp(log_scores).sum()).item() - if norm_prob >= self.score_threshold: - self.score_threshold = norm_prob - self.output = generated_sentence - else: - self.output = utterance - - intent = self.label2id[self.intent_classifier(self.output)[0]['label']] - intent_conf = self.intent_classifier(self.output)[0]['score'] - - return {"modified_prompt": self.output, "mod_confidence": self.score_threshold, "prompt_intent": intent, - "intent_confidence": intent_conf} -""" - -import threading - import spacy import spacy_transformers import torch -from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline - class NLU: - def __init__(self, device, device_flag, reference_resolver, tokenizer, - intent_classifier, offense_filter, entity_extractor, - max_history_length=1024): - #entity_extractor=def_entity_extractor - self.reference_resolver = reference_resolver - self.device = device - self.reference_resolver.to(device) - self.tokenizer = tokenizer - self.max_length = max_history_length - self.label2idintent = {'QA': 0, 'CHITCHAT': 1, 'FINDPAPER': 2, 'FINDDATASET': 3, 'SUMMARIZEPAPER': 4} - self.id2labelintent = {0: 'QA', 1: 'CHITCHAT', 2: 'FINDPAPER', 3: 'FINDDATASET', 4: 'SUMMARIZEPAPER'} - self.label2idoffense = {'hate': 0, 'offensive': 1, 'neither': 2} - self.id2labeloffense = {0: 'hate', 1: 'offensive', 2: 'neither'} - self.intent_classifier = pipeline("sentiment-analysis", model=intent_classifier, device=device_flag) + def __init__(self, query_rewriter, coref_resolver, intent_classifier, offensive_classifier, entity_extractor, ambig_classifier): + + self.intent_classifier = intent_classifier self.entity_extractor = entity_extractor - self.offense_filter = pipeline("sentiment-analysis", model=offense_filter, device=device_flag) + self.offensive_classifier = offensive_classifier + self.coref_resolver = coref_resolver + self.query_rewriter = query_rewriter + self.ambig_classifier = ambig_classifier + + def _resolve_coref(self, history): + to_resolve = history + ' ' + self.to_process + doc = self.coref_resolver(to_resolve) + token_mention_mapper = {} + output_string = "" + clusters = [ + val for key, val in doc.spans.items() if key.startswith("coref_cluster") + ] + + # Iterate through every found cluster + for cluster in clusters: + first_mention = cluster[0] + # Iterate through every other span in the cluster + for mention_span in list(cluster)[1:]: + # Set first_mention as value for the first token in mention_span in the token_mention_mapper + token_mention_mapper[mention_span[0].idx] = first_mention.text + mention_span[0].whitespace_ + for token in mention_span[1:]: + # Set empty string for all the other tokens in mention_span + token_mention_mapper[token.idx] = "" + + # Iterate through every token in the Doc + for token in doc: + # Check if token exists in token_mention_mapper + if token.idx in token_mention_mapper: + output_string += token_mention_mapper[token.idx] + # Else add original token text + else: + output_string += token.text + token.whitespace_ + cleaned_query = output_string.split(" ", 1)[1] + return cleaned_query - self.intents = None - self.entities = None - self.offensive = None - self.clear = True - def _intentpredictor(self): - self.intents = self.label2idintent[self.intent_classifier(self.to_process)[0]['label']] - + pred = self.intent_classifier(self.to_process)[0] + return pred['label'], pred['score'] + + def _ambigpredictor(self): + pred = self.ambig_classifier(self.to_process)[0] + if pred['label'] in ['clear', 'somewhat_clear']: + return False + else: + return True + def _entityextractor(self): - self.entities = [] + entities = [] doc = self.entity_extractor(self.to_process) for entity in doc.ents: if entity.text not in ['.', ',', '?', ';']: - self.entities.append({'entity': entity.label_, 'value': entity.text}) + entities.append({'entity': entity.label_, 'value': entity.text}) + return entities - def _inappropriatedetector(self): - self.offensive = False - is_offensive = self.label2idoffense[self.offense_filter(self.to_process)[0]['label']] - if is_offensive == 0 or is_offensive == 1: - self.offensive = True + def _offensepredictor(self): + pred = self.offensive_classifier(self.to_process)[0]['label'] + if pred != "neither": + return True + else: + return False - def process_utterance(self, utterance, history): + def _rewrite_query(self, history): + text = history + " ||| " + self.to_process + return self.query_rewriter(text)[0]['generated_text'] + + + def process_utterance(self, utterance, history_consec, history_sep): """ - Given an utterance and the history of the conversation, refine the query contextually and return a refined - utterance + Query -> coref resolution & intent extraction -> if intents are not confident or if query is ambig -> rewrite query and recheck -> if still ambig, ask a clarifying question """ self.to_process = utterance - if len(history) > 0: - # crop history - while len(history.split(" ")) > self.max_length: - index = history.find("|||") - history = history[index + 4:] - - context = history + " ||| " + utterance - inputs = self.tokenizer(context, max_length=self.max_length, truncation=True, padding="max_length", - return_tensors="pt") - - candidates = self.reference_resolver.generate(input_ids=inputs["input_ids"].to(self.device), - attention_mask=inputs["attention_mask"].to(self.device), - return_dict_in_generate=True, output_scores=True, - num_return_sequences=1, - num_beams=5) - self.to_process = self.tokenizer.decode(candidates["sequences"][0], skip_special_tokens=True, - clean_up_tokenization_spaces=True) - - t1 = threading.Thread(target=self._intentpredictor, name='intent') - t2 = threading.Thread(target=self._entityextractor, name='entity') - t3 = threading.Thread(target=self._inappropriatedetector, name='offensive') - - t3.start() - t1.start() - t2.start() - - t3.join() - t1.join() - t2.join() - return {"modified_prompt": self.to_process, - "intent": self.intents, - "entities": self.entities, - "is_offensive": self.offensive, - "is_clear": self.clear} + + self.to_process = self._resolve_coref(history_consec) + + intent, score = self._intentpredictor() + #print(score) + if score > 0.5: + entities = self._entityextractor() + offense = self._offensepredictor() + if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0: + return {"modified_prompt": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False} + return {"modified_prompt": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": True} + else: + if self._ambigpredictor(): + self.to_process = self._rewrite_query(history_sep) + intent, score = self._intentpredictor() + entities = self._entityextractor() + offense = self._offensepredictor() + if score > 0.5 or not self._ambigpredictor(): + if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0: + return {"modified_prompt": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False} + return {"modified_prompt": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, + "is_clear": True} + else: + return {"modified_prompt": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, + "is_clear": False} + else: + entities = self._entityextractor() + offense = self._offensepredictor() + if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0: + return {"modified_prompt": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False} + return {"modified_prompt": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": True} diff --git a/ResponseGenerator.py b/ResponseGenerator.py index 8d56e44..7ad414c 100644 --- a/ResponseGenerator.py +++ b/ResponseGenerator.py @@ -4,140 +4,177 @@ import faiss from sklearn.metrics.pairwise import cosine_similarity import numpy as np import pandas as pd - +from datetime import datetime class ResponseGenerator: - def __init__(self, index, db, - generator, retriever, num_retrieved=1): - self.generator = generator - self.retriever = retriever - self.db = db - self.index = index - self.num_retrieved = num_retrieved - self.paper = {} - self.dataset = {} + def __init__(self, index, db,recommender,generators, retriever, num_retrieved=1): + self.generators = generators + self.retriever = retriever + self.recommender = recommender + self.db = db + self.index = index + self.num_retrieved = num_retrieved + self.paper = {} + self.dataset = {} - def update_index(self, index): - self.index = index - def update_db(self, db): - self.db = db + def update_index(self, index): + self.index = index + def update_db(self, db): + self.db = db - def _get_resources_links(self, item): - if len(item) == 0: - return [] - links = [] - for rsrc in item['resources']: - links.append(rsrc['url']) - return links + def _get_resources_links(self, item): + if len(item) == 0: + return [] + links = [] + for rsrc in item['resources']: + links.append(rsrc['url']) + return links - def _get_matching_titles(self, rsrc, title): - cand = self.db[rsrc].loc[self.db[rsrc]['title'] == title.lower()].reset_index(drop=True) - if not cand.empty: - return cand.loc[0] - else: - return {} + def _get_matching_titles(self, rsrc, title): + cand = self.db[rsrc].loc[self.db[rsrc]['title'] == title.lower()].reset_index(drop=True) + if not cand.empty: + return cand.loc[0] + else: + return {} - def _get_matching_topics(self, rsrc, topic): - matches = [] - score = 0.7 - for i, cand in self.db[rsrc].iterrows(): - for tag in cand['tags']: - sim = cosine_similarity(np.array(self.retriever.encode([tag])), np.array(self.retriever.encode([topic.lower()]))) - if sim > score: - if(len(matches)>0): - matches[0] = cand - else: - matches.append(cand) - score = sim - if len(matches) > 0: - return matches[0] - else: - return [] + def _get_matching_authors(self, rsrc, author): + cand = self.db[rsrc].loc[self.db[rsrc]['author'] == author.lower()].reset_index(drop=True) + if not cand.empty: + return cand.loc[0] + else: + return {} - def _search_index(self, index_type, db_type, query): - xq = self.retriever.encode([query]) - D, I = self.index[index_type].search(xq, self.num_retrieved) - return self.db[db_type].iloc[[I[0]][0]].reset_index(drop=True).loc[0] + def _get_matching_topics(self, rsrc, topic): + matches = [] + score = 0.7 + for i, cand in self.db[rsrc].iterrows(): + for tag in cand['tags']: + sim = cosine_similarity(np.array(self.retriever.encode([tag])), np.array(self.retriever.encode([topic.lower()]))) + if sim > score: + if(len(matches)>0): + matches[0] = cand + else: + matches.append(cand) + score = sim + if len(matches) > 0: + return matches[0] + else: + return [] + + def _search_index(self, index_type, db_type, query): + xq = self.retriever.encode([query]) + D, I = self.index[index_type].search(xq, self.num_retrieved) + return self.db[db_type].iloc[[I[0]][0]].reset_index(drop=True).loc[0] - def gen_response(self, utterance, state, history, action): - if action == "NoCanDo": - return str("I am sorry, I cannot answer to this kind of language") + def gen_response(self, action, utterance=None, username=None, state=None, consec_history=None): + if action == "Help": + return "Hey it's Janet! I am here to help you make use of the datasets and papers in the VRE. I can answer questions whose answers may be inside the papers. I can summarize papers for you. I can also chat with you. So, whichever it is, I am ready to chat!" + elif action == "Recommend": + prompt = self.recommender.make_recommendation(username) + if prompt != "": + return prompt + else: + return "I can help you with exploiting the contents of the VRE, just let me know!" - elif action == "ConvGen": - gen_kwargs = {"length_penalty": 2.5, "num_beams":4, "max_length": 20} - answer = self.generator('question: '+ utterance + ' context: ' + history , **gen_kwargs)[0]['generated_text'] - return answer + elif action == "OffenseReject": + return "I am sorry, I cannot answer to this kind of language" - elif action == "findPaper": - for entity in state['entities']: - if (entity['entity'] == 'TITLE'): - self.paper = self._get_matching_titles('paper_db', entity['value']) - links = self._get_resources_links(self.paper) - if len(self.paper) > 0 and len(links) > 0: - return str("This paper could be helpful: " + self.paper['title'] + '. ' + "It can be downloaded at " + links[0]) - else: - self.paper = self._search_index('paper_titles_index', 'paper_db', entity['value']) + elif action == "ConvGen": + gen_kwargs = {"length_penalty": 2.5, "num_beams":2, "max_length": 30} + answer = self.generators['chat']('history: '+ consec_history + ' ' + utterance + ' persona: ' + 'I am Janet. My name is Janet. I am an AI developed by CNR to help VRE users.' , **gen_kwargs)[0]['generated_text'] + return answer + + elif action == "findPaper": + for entity in state['entities']: + if (entity['entity'] == 'TITLE'): + self.paper = self._get_matching_titles('paper_db', entity['value']) + links = self._get_resources_links(self.paper) + if len(self.paper) > 0 and len(links) > 0: + return str("Here is the paper you want: " + self.paper['title'] + '. ' + "It can be downloaded at " + links[0]) + else: + self.paper = self._search_index('paper_titles_index', 'paper_db', entity['value']) + links = self._get_resources_links(self.paper) + return str("This paper could be relevant: " + self.paper['title'] + '. ' + "It can be downloaded at " + links[0]) + if(entity['entity'] == 'TOPIC'): + self.paper = self._get_matching_topics('paper_db', entity['value']) + links = self._get_resources_links(self.paper) + if len(self.paper) > 0 and len(links) > 0: + return str("This paper could be relevant: " + self.paper['title'] + '. ' + "It can be downloaded at " + links[0]) + + if(entity['entity'] == 'AUTHOR'): + self.paper = self._get_matching_authors('paper_db', entity['value']) + links = self._get_resources_links(self.paper) + if len(self.paper) > 0 and len(links) > 0: + return str("Here is the paper you want: " + self.paper['title'] + '. ' + "It can be downloaded at " + links[0]) + + self.paper = self._search_index('paper_desc_index', 'paper_db', utterance) links = self._get_resources_links(self.paper) - return str("This paper could be helpful: " + self.paper['title'] + '. ' + "It can be downloaded at " + links[0]) - if(entity['entity'] == 'TOPIC'): - self.paper = self._get_matching_topics('paper_db', entity['value']) - links = self._get_resources_links(self.paper) - if len(self.paper) > 0 and len(links) > 0: - return str("This paper could be helpful: " + self.paper['title'] + '. ' + "It can be downloaded at " + links[0]) - self.paper = self._search_index('paper_desc_index', 'paper_db', utterance) - links = self._get_resources_links(self.paper) - return str("This paper could be helpful: " + self.paper['title'] + '. ' + "It can be downloaded at " + links[0]) + return str("This paper could be relevant: " + self.paper['title'] + '. ' + "It can be downloaded at " + links[0]) - elif action == "findDataset": - for entity in state['entities']: - if (entity['entity'] == 'TITLE'): - self.dataset = self._get_matching_titles('dataset_db', entity['value']) - links = self._get_resources_links(self.dataset) - if len(self.dataset) > 0 and len(links) > 0: - return str("This dataset could be helpful: " + self.dataset['title'] + '. ' + "It can be downloaded at " + links[0]) - else: - self.dataset = self._search_index('dataset_titles_index', 'dataset_db', entity['value']) + elif action == "findDataset": + for entity in state['entities']: + if (entity['entity'] == 'TITLE'): + self.dataset = self._get_matching_titles('dataset_db', entity['value']) + links = self._get_resources_links(self.dataset) + if len(self.dataset) > 0 and len(links) > 0: + return str("Here is the dataset you wanted: " + self.dataset['title'] + '. ' + "It can be downloaded at " + links[0]) + else: + self.dataset = self._search_index('dataset_titles_index', 'dataset_db', entity['value']) + links = self._get_resources_links(self.dataset) + return str("This dataset could be relevant: " + self.dataset['title'] + '. ' + "It can be downloaded at " + links[0]) + if(entity['entity'] == 'TOPIC'): + self.dataset = self._get_matching_topics('dataset_db', entity['value']) + links = self._get_resources_links(self.dataset) + if len(self.dataset) > 0 and len(links) > 0: + return str("This dataset could be relevant: " + self.dataset['title'] + '. ' + "It can be downloaded at " + links[0]) + + if(entity['entity'] == 'AUTHOR'): + self.dataset = self._get_matching_authors('dataset_db', entity['value']) + links = self._get_resources_links(self.dataset) + if len(self.dataset) > 0 and len(links) > 0: + return str("Here is the dataset you want: " + self.dataset['title'] + '. ' + "It can be downloaded at " + links[0]) + + self.dataset = self._search_index('dataset_desc_index', 'dataset_db', utterance) links = self._get_resources_links(self.dataset) - return str("This dataset could be helpful: " + self.dataset['title'] + '. ' + "It can be downloaded at " + links[0]) - if(entity['entity'] == 'TOPIC'): - self.dataset = self._get_matching_topics('dataset_db', entity['value']) - links = self._get_resources_links(self.dataset) - if len(self.dataset) > 0 and len(links) > 0: - return str("This dataset could be helpful: " + self.dataset['title'] + '. ' + "It can be downloaded at " + links[0]) - self.dataset = self._search_index('dataset_desc_index', 'dataset_db', utterance) - links = self._get_resources_links(self.dataset) - return str("This dataset could be helpful: " + self.dataset['title'] + '. ' + "It can be downloaded at " + links[0]) - + return str("This dataset could be relevant: " + self.dataset['title'] + '. ' + "It can be downloaded at " + links[0]) + - elif action == "RetGen": - #retrieve the most relevant paragraph - content = str(self._search_index('content_index', 'content_db', utterance)['content']) - #generate the answer - gen_seq = 'question: '+utterance+" context: "+content + elif action == "RetGen": + #retrieve the most relevant paragraph + content = str(self._search_index('content_index', 'content_db', utterance)['content']) + #generate the answer + gen_seq = 'question: '+utterance+" context: "+content + + #handle return random 2 answers + gen_kwargs = {"length_penalty": 0.5, "num_beams":2, "max_length": 60} + answer = self.generators['qa'](gen_seq, **gen_kwargs)[0]['generated_text'] + return str(answer) - #handle return random 2 answers - gen_kwargs = {"length_penalty": 0.5, "num_beams":8, "max_length": 100} - answer = self.generator(gen_seq, **gen_kwargs)[0]['generated_text'] - return str(answer) - - elif action == "sumPaper": - if len(self.paper) == 0: - for entity in state['entities']: - if (entity['entity'] == 'TITLE'): - self.paper = self._get_matching_titles('paper_db', entity['value']) - if (len(self.paper) > 0): - break - if len(self.paper) == 0: - return "I cannot seem to find the requested paper. Try again by specifying the title of the paper." - #implement that - df = self.db['content_db'][self.db['content_db']['paperid'] == self.paper['id']] - answer = "" - for i, row in df.iterrows(): - gen_seq = 'summarize: '+row['content'] - gen_kwargs = {"length_penalty": 1.5, "num_beams":8, "max_length": 100} - answer = self.generator(gen_seq, **gen_kwargs)[0]['generated_text'] + ' ' - return answer + elif action == "sumPaper": + if len(self.paper) == 0: + for entity in state['entities']: + if (entity['entity'] == 'TITLE'): + self.paper = self._get_matching_titles('paper_db', entity['value']) + if (len(self.paper) > 0): + break + if len(self.paper) == 0: + return "I cannot seem to find the requested paper. Try again by specifying the title of the paper." + #implement that + df = self.db['content_db'][self.db['content_db']['paperid'] == self.paper['id']] + answer = "" + for i, row in df.iterrows(): + gen_seq = 'summarize: '+row['content'] + gen_kwargs = {"length_penalty": 1.5, "num_beams":6, "max_length": 120} + answer = self.generators['summ'](gen_seq, **gen_kwargs)[0]['generated_text'] + ' ' + return answer - elif action == "Clarify": - return str("Can you please clarify?") + elif action == "ClarifyResource": + if state['intent'] in ['FINDPAPER', 'SUMMARIZEPAPER']: + return 'Please specify the title, the topic or the paper of interest.' + else: + return 'Please specify the title, the topic or the dataset of interest.' + elif action == "GenClarify": + gen_kwargs = {"length_penalty": 2.5, "num_beams":8, "max_length": 120} + question = self.generators['amb']('question: '+ utterance + ' context: ' + consec_history , **gen_kwargs)[0]['generated_text'] + return question diff --git a/User.py b/User.py index ccf6708..5159b2a 100644 --- a/User.py +++ b/User.py @@ -20,8 +20,7 @@ class User: if len(index) > 0: self.interests.at[index[0], 'frequency'] += 1 else: - self.interests = self.interests.append({'interest': topic, 'frequency': max( - self.interests['frequency']) if not self.interests.empty else 6}, ignore_index=True) + self.interests = self.interests.append({'interest': topic, 'frequency': max(self.interests['frequency']) if not self.interests.empty else 6}, ignore_index=True) self.interests = self.interests.sort_values(by='frequency', ascending=False, ignore_index=True) self.interests.to_json(self.interests_file) diff --git a/main.py b/main.py index 4135a8b..b3c0636 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,5 @@ import os import warnings - import faiss import torch from flask import Flask, render_template, request, jsonify @@ -10,131 +9,125 @@ import spacy import spacy_transformers import torch from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline - from User import User from VRE import VRE from NLU import NLU from DM import DM from Recommender import Recommender from ResponseGenerator import ResponseGenerator - - import pandas as pd import time import threading - from sentence_transformers import SentenceTransformer app = Flask(__name__) -#allow frontend address url = os.getenv("FRONTEND_URL_WITH_PORT") cors = CORS(app, resources={r"/predict": {"origins": url}, r"/feedback": {"origins": url}}) -#cors = CORS(app, resources={r"/predict": {"origins": "*"}, r"/feedback": {"origins": "*"}}) - +""" conn = psycopg2.connect( - host="janet-pg", + host="https://janet-app-db.d4science.org", database=os.getenv("POSTGRES_DB"), user=os.getenv("POSTGRES_USER"), password=os.getenv("POSTGRES_PASSWORD")) +""" +conn = psycopg2.connect(host="https://janet-app-db.d4science.org", + database="janet", + user="janet_user", + password="2fb5e81fec5a2d906a04") cur = conn.cursor() - -#rg = ResponseGenerator(index) - -def get_response(text): - # get response from janet itself - return text, 'candAnswer' - def vre_fetch(): - while True: - time.sleep(1000) - print('getting new material') - vre.get_vre_update() - vre.index_periodic_update() - rg.update_index(vre.get_index()) - rg.update_db(vre.get_db()) + while True: + time.sleep(1000) + print('getting new material') + vre.get_vre_update() + vre.index_periodic_update() + rg.update_index(vre.get_index()) + rg.update_db(vre.get_db()) def user_interest_decay(): - while True: - print("decaying interests after 3 minutes") - time.sleep(180) - user.decay_interests() - -def recommend(): - while True: - if time.time() - dm.get_recent_state()['time'] > 1000: - print("Making Recommendation: ") - prompt = rec.make_recommendation(user.username) - if prompt != "": - print(prompt) - time.sleep(1000) - + while True: + print("decaying interests after 3 minutes") + time.sleep(180) + user.decay_interests() @app.route("/predict", methods=['POST']) def predict(): text = request.get_json().get("message") - state = nlu.process_utterance(text, dm.get_utt_history()) - user_interests = [] - for entity in state['entities']: - if entity['entity'] == 'TOPIC': - user_interests.append(entity['value']) - user.update_interests(user_interests) - dm.update(state) - action = dm.next_action() - response = rg.gen_response(state['modified_prompt'], dm.get_recent_state(), dm.get_utt_history(), action) - message = {"answer": response, "query": text, "cand": "candidate", "history": dm.get_utt_history(), "modQuery": state['modified_prompt']} + message = {} + if text == "": + state = {'help': True, 'inactive': False} + dm.update(state) + action = dm.next_action() + response = rg.gen_response(action) + message = {"answer": response} + elif text == "": + state = {'help': False, 'inactive': True} + dm.update(state) + action = dm.next_action() + response = rg.gen_response(action, username=user.username) + message = {"answer": response} + else: + state = nlu.process_utterance(text, dm.get_consec_history(), dm.get_sep_history()) + state['help'] = False + state['inactive'] = False + user_interests = [] + for entity in state['entities']: + if entity['entity'] == 'TOPIC': + user_interests.append(entity['value']) + user.update_interests(user_interests) + dm.update(state) + action = dm.next_action() + self, action, utterance=None, username=None, state=None, consec_history=None + response = rg.gen_response(action, utterance=state['modified_query'], state=dm.get_recent_state(), consec_history=dm.get_consec_history()) + message = {"answer": response, "query": text, "cand": "candidate", "history": dm.get_consec_history(), "modQuery": state['modified_query']} + new_state = {'modified_query': response} + dm.update(new_state) reply = jsonify(message) - #reply.headers.add('Access-Control-Allow-Origin', '*') return reply @app.route('/feedback', methods = ['POST']) def feedback(): data = request.get_json()['feedback'] - # Make data frame of above data print(data) - #df = pd.DataFrame([data]) - #file_exists = os.path.isfile('feedback.csv') - - #df = pd.DataFrame(data=[data['response'], data['length'], data['fluency'], data['truthfulness'], data['usefulness'], data['speed']] - # ,columns=['response', 'length', 'fluency', 'truthfulness', 'usefulness', 'speed']) - #df.to_csv('feedback.csv', mode='a', index=False, header=(not file_exists)) - cur.execute('INSERT INTO feedback (query, history, janet_modified_query, - is_modified_query_correct, user_modified_query, - response, preferred_response, response_length_feedback, - response_fluency_feedback, response_truth_feedback, - response_useful_feedback, response_time_feedback, response_intent)' - 'VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)', + cur.execute('INSERT INTO feedback (query, history, janet_modified_query, is_modified_query_correct, user_modified_query, response, preferred_response, response_length_feedback, response_fluency_feedback, response_truth_feedback, response_useful_feedback, response_time_feedback, response_intent) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)', (data['query'], data['history'], data['modQuery'], data['queryModCorrect'], data['correctQuery'], data['janetResponse'], data['preferredResponse'], data['length'], data['fluency'], data['truthfulness'], data['usefulness'], data['speed'], data['intent']) ) - reply = jsonify({"status": "done"}) - #reply.headers.add('Access-Control-Allow-Origin', '*') return reply if __name__ == "__main__": warnings.filterwarnings("ignore") - #load NLU - def_tokenizer = AutoTokenizer.from_pretrained("castorini/t5-base-canard") - def_reference_resolver = AutoModelForSeq2SeqLM.from_pretrained("castorini/t5-base-canard") - def_intent_classifier_dir = "./IntentClassifier/" - def_entity_extractor = spacy.load("./EntityExtraction/BestModel") - def_offense_filter_dir ="./OffensiveClassifier" device = "cuda" if torch.cuda.is_available() else "cpu" device_flag = torch.cuda.current_device() if torch.cuda.is_available() else -1 - nlu = NLU(device, device_flag, def_reference_resolver, def_tokenizer, def_intent_classifier_dir, def_offense_filter_dir, def_entity_extractor) + + query_rewriter = pipeline("text2text-generation", model="castorini/t5-base-canard") + intent_classifier = pipeline("sentiment-analysis", model='./intent_classifier', device=device_flag) + entity_extractor = spacy.load("./entity_extractor") + offensive_classifier = pipeline("sentiment-analysis", model='./offensive_classifier', device=device_flag) + ambig_classifier = pipeline("sentiment-analysis", model='./ambig_classifier', device=device_flag) + coref_resolver = spacy.load("en_coreference_web_trf") + + nlu = NLU(query_rewriter, coref_resolver, intent_classifier, offensive_classifier, entity_extractor, ambig_classifier) #load retriever and generator - def_retriever = SentenceTransformer('./BigRetriever/').to(device) - def_generator = pipeline("text2text-generation", model="./generator", device=device_flag) - + retriever = SentenceTransformer('./BigRetriever/').to(device) + qa_generator = pipeline("text2text-generation", model="./train_qa", device=device_flag) + summ_generator = pipeline("text2text-generation", model="./train_summ", device=device_flag) + chat_generator = pipeline("text2text-generation", model="./train_chat", device=device_flag) + amb_generator = pipeline("text2text-generation", model="./train_amb_gen", device=device_flag) + generators = {'qa': qa_generator, + 'chat': chat_generator, + 'amb': amb_generator, + 'summ': summ_generator} #load vre token = '2c1e8f88-461c-42c0-8cc1-b7660771c9a3-843339462' @@ -148,15 +141,14 @@ if __name__ == "__main__": threading.Thread(target=user_interest_decay, name='decayinterest').start() - - rec = Recommender(def_retriever) + rec = Recommender(retriever) dm = DM() - rg = ResponseGenerator(index,db,def_generator,def_retriever) - threading.Thread(target=recommend, name='recommend').start() + + rg = ResponseGenerator(index,db, recommender, generators, retriever) - cur.execute('CREATE TABLE IF NOT EXISTS feedback (id serial PRIMARY KEY,' + cur.execute('CREATE TABLE IF NOT EXISTS feedback_trial (id serial PRIMARY KEY,' 'query text NOT NULL,' 'history text NOT NULL,' 'janet_modified_query text NOT NULL,' diff --git a/requirements.txt b/requirements.txt index 13f7bfb..1ae4e69 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,18 +11,26 @@ regex==2022.6.2 requests==2.25.1 scikit-learn==1.0.2 scipy==1.7.3 -sentence-transformers==2.2.2 sentencepiece==0.1.97 sklearn-pandas==1.8.0 -spacy==3.5.0 -spacy-transformers==1.2.2 +spacy==3.4.4 +spacy-alignments==0.9.0 +spacy-legacy==3.0.12 +spacy-loggers==1.0.4 +spacy-transformers==1.1.9 +spacy-experimental==0.6.2 torch @ https://download.pytorch.org/whl/cu116/torch-1.13.1%2Bcu116-cp38-cp38-linux_x86_64.whl torchaudio @ https://download.pytorch.org/whl/cu116/torchaudio-0.13.1%2Bcu116-cp38-cp38-linux_x86_64.whl torchsummary==1.5.1 torchtext==0.14.1 +sentence-transformers torchvision @ https://download.pytorch.org/whl/cu116/torchvision-0.14.1%2Bcu116-cp38-cp38-linux_x86_64.whl tqdm==4.64.1 -transformers==4.26.1 +transformers markupsafe==2.0.1 psycopg2==2.9.5 +en-coreference-web-trf @ https://github.com/explosion/spacy-experimental/releases/download/v0.6.1/en_coreference_web_trf-3.4.0a2-py3-none-any.whl Werkzeug==1.0.1 + + +