This commit is contained in:
ahmed531998 2023-04-04 05:34:47 +02:00
parent 5c328ce7df
commit 423853ad65
6 changed files with 373 additions and 355 deletions

63
DM.py
View File

@ -1,43 +1,56 @@
import time import time
class DM:
    """Dialogue manager: stores the conversation turns and picks the next action.

    Every turn is a dict ("state"). Turns that carry text expose it under
    ``'modified_query'`` (``'modified_prompt'`` is accepted as an alias);
    control turns such as help / idle triggers may carry no text at all.
    """

    # Intents that need at least one extracted entity to be actionable.
    _RESOURCE_INTENTS = ('FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER')

    # Action chosen for a clear, non-offensive utterance, keyed by intent.
    _INTENT_ACTIONS = {
        'QA': "RetGen",
        'CHITCHAT': "ConvGen",
        'FINDPAPER': "findPaper",
        'FINDDATASET': "findDataset",
        'SUMMARIZEPAPER': "sumPaper",
    }

    def __init__(self, max_history_length=3):
        """Create an empty dialogue.

        max_history_length: number of exchanges kept in the working history;
        twice that many textual turns are retained.
        """
        self.working_history_sep = ""
        self.working_history_consec = ""
        self.max_history_length = max_history_length
        self.chat_history = []
        self.curr_state = None

    @staticmethod
    def _turn_text(state):
        """Return the text stored in a turn, or None if the turn has none."""
        text = state.get('modified_query')
        if text is None:
            text = state.get('modified_prompt')
        return text

    def update_history(self):
        """Rebuild both working-history strings from the most recent turns."""
        window = self.max_history_length * 2
        texts = [t for t in map(self._turn_text, self.chat_history) if t is not None]
        # A slice of [-0:] would select everything, so guard a zero window.
        to_consider = texts[-window:] if window > 0 else []
        self.working_history_consec = " . ".join(to_consider)
        self.working_history_sep = " ||| ".join(to_consider)

    def get_consec_history(self):
        """Return the recent turns joined with ' . '."""
        return self.working_history_consec

    def get_sep_history(self):
        """Return the recent turns joined with ' ||| '."""
        return self.working_history_sep

    def get_recent_state(self):
        """Return the last state passed to update(), or None."""
        return self.curr_state

    def get_dialogue_history(self):
        """Return the list of every state recorded so far."""
        return self.chat_history

    def update(self, new_state):
        """Record a new turn and refresh the working history."""
        self.chat_history.append(new_state)
        self.curr_state = new_state
        self.update_history()

    def next_action(self):
        """Map the current state to the name of the next action.

        Returns None when the utterance is clear and inoffensive but its
        intent is not one of the known intents.
        """
        state = self.curr_state
        # .get: states built without the help/inactive flags count as regular turns.
        if state.get('help'):
            return "Help"
        if state.get('inactive'):
            return "Recommend"
        if state['is_clear']:
            if state['is_offensive']:
                return "OffenseReject"
            return self._INTENT_ACTIONS.get(state['intent'])
        if state['intent'] in self._RESOURCE_INTENTS and len(state['entities']) == 0:
            return "ClarifyResource"
        return "GenClarify"

221
NLU.py
View File

@ -1,143 +1,112 @@
"""
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import torch
class NLU:
def_tokenizer = AutoTokenizer.from_pretrained("castorini/t5-base-canard")
def_model = AutoModelForSeq2SeqLM.from_pretrained("castorini/t5-base-canard")
def_intent_classifier = pipeline("sentiment-analysis", model="/home/ahmed/PycharmProjects/Janet/JanetBackend/intent_classifier")
def __init__(self, model=def_model, tokenizer=def_tokenizer, intent_classifier=def_intent_classifier,
max_history_length=1024, num_gen_seq=2, score_threshold=0.5):
self.input = ""
self.output = ""
self.model = model
self.tokenizer = tokenizer
self.max_length = max_history_length
self.num_return_sequences = num_gen_seq
self.score_threshold = score_threshold
self.label2id = {'Greet': 0, 'Bye': 1, 'GetKnowledge': 2, 'ChitChat': 3}
self.id2label = {0: 'Greet', 1: 'Bye', 2: 'GetKnowledge', 3: 'ChitChat'}
self.intent_classifier = intent_classifier
def process_utterance(self, utterance, history):
if len(history) > 0:
# crop history
while len(history.split(" ")) > self.max_length:
index = history.find("|||")
history = history[index + 4:]
self.input = history + " ||| " + utterance
inputs = self.tokenizer(self.input, max_length=self.max_length, truncation=True, padding="max_length",
return_tensors="pt")
candidates = self.model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"],
return_dict_in_generate=True, output_scores=True,
num_return_sequences=self.num_return_sequences,
num_beams=self.num_return_sequences)
for i in range(candidates["sequences"].shape[0]):
generated_sentence = self.tokenizer.decode(candidates["sequences"][i], skip_special_tokens=True,
clean_up_tokenization_spaces=True)
log_scores = candidates['sequences_scores']
norm_prob = (torch.exp(log_scores[i]) / torch.exp(log_scores).sum()).item()
if norm_prob >= self.score_threshold:
self.score_threshold = norm_prob
self.output = generated_sentence
else:
self.output = utterance
intent = self.label2id[self.intent_classifier(self.output)[0]['label']]
intent_conf = self.intent_classifier(self.output)[0]['score']
return {"modified_prompt": self.output, "mod_confidence": self.score_threshold, "prompt_intent": intent,
"intent_confidence": intent_conf}
"""
import threading
import spacy import spacy
import spacy_transformers import spacy_transformers
import torch import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
class NLU:
    """Natural-language understanding pipeline for one user utterance.

    Flow: coreference resolution -> intent prediction -> if the intent is not
    confident and the query is ambiguous, rewrite the query with the history
    and re-check -> extract entities and detect offensive language.
    """

    # Intents that cannot be served without at least one extracted entity.
    _RESOURCE_INTENTS = ('FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER')
    # Marker placed between the history and the utterance before coref resolution.
    _COREF_SEP = ' <COREF_SEP_TOKEN> '
    # Intent scores above this value are trusted without an ambiguity check.
    _INTENT_CONFIDENCE = 0.5

    def __init__(self, query_rewriter, coref_resolver, intent_classifier, offensive_classifier, entity_extractor, ambig_classifier):
        """Store the injected models; each is called with a text string."""
        self.intent_classifier = intent_classifier
        self.entity_extractor = entity_extractor
        self.offensive_classifier = offensive_classifier
        self.coref_resolver = coref_resolver
        self.query_rewriter = query_rewriter
        self.ambig_classifier = ambig_classifier

    def _resolve_coref(self, history):
        """Replace later mentions in the utterance with the first mention of their cluster.

        The history and the utterance are resolved together; only the part
        after the separator (the utterance) is returned.
        """
        to_resolve = history + self._COREF_SEP + self.to_process
        doc = self.coref_resolver(to_resolve)
        token_mention_mapper = {}
        clusters = [
            val for key, val in doc.spans.items() if key.startswith("coref_cluster")
        ]
        for cluster in clusters:
            first_mention = cluster[0]
            for mention_span in list(cluster)[1:]:
                # The first token of a later mention carries the replacement text...
                token_mention_mapper[mention_span[0].idx] = first_mention.text + mention_span[0].whitespace_
                # ...and the remaining tokens of that mention are dropped.
                for token in mention_span[1:]:
                    token_mention_mapper[token.idx] = ""
        output_string = "".join(
            token_mention_mapper[token.idx] if token.idx in token_mention_mapper
            else token.text + token.whitespace_
            for token in doc
        )
        _, found, cleaned_query = output_string.partition(self._COREF_SEP)
        # If the separator did not survive resolution, keep the utterance as it was.
        return cleaned_query if found else self.to_process

    def _intentpredictor(self):
        """Return (label, score) of the top intent prediction."""
        pred = self.intent_classifier(self.to_process)[0]
        return pred['label'], pred['score']

    def _ambigpredictor(self):
        """Return True unless the query is labelled 'clear' or 'somewhat_clear'."""
        pred = self.ambig_classifier(self.to_process)[0]
        return pred['label'] not in ['clear', 'somewhat_clear']

    def _entityextractor(self):
        """Return the extracted entities as {'entity': label, 'value': text} dicts."""
        doc = self.entity_extractor(self.to_process)
        return [
            {'entity': entity.label_, 'value': entity.text}
            for entity in doc.ents
            if entity.text not in ['.', ',', '?', ';']  # skip bare punctuation
        ]

    def _offensepredictor(self):
        """Return True when the offensive classifier predicts anything but 'neither'."""
        pred = self.offensive_classifier(self.to_process)[0]['label']
        return pred != "neither"

    def _rewrite_query(self, history):
        """Rewrite the utterance into a self-contained query using the ' ||| ' history."""
        text = history + " ||| " + self.to_process
        return self.query_rewriter(text)[0]['generated_text']

    def process_utterance(self, utterance, history_consec, history_sep):
        """Refine the utterance in context and analyse it.

        utterance: raw user text.
        history_consec: history joined for coreference resolution.
        history_sep: ' ||| '-separated history used for query rewriting.

        Returns a dict with the refined query (under both 'modified_query' and
        the older 'modified_prompt' key), 'intent', 'entities', 'is_offensive'
        and 'is_clear'. 'is_clear' is False when the query stays ambiguous
        after rewriting, or when a resource intent has no entities.
        """
        self.to_process = utterance
        self.to_process = self._resolve_coref(history_consec)

        intent, score = self._intentpredictor()
        resolved = True
        if score <= self._INTENT_CONFIDENCE and self._ambigpredictor():
            # Low confidence and ambiguous: rewrite with the history, then re-check.
            self.to_process = self._rewrite_query(history_sep)
            intent, score = self._intentpredictor()
            resolved = score > self._INTENT_CONFIDENCE or not self._ambigpredictor()

        entities = self._entityextractor()
        offense = self._offensepredictor()
        missing_entities = intent in self._RESOURCE_INTENTS and len(entities) == 0
        return {
            "modified_prompt": self.to_process,
            # Same text under the key the dialogue manager and /predict read.
            "modified_query": self.to_process,
            "intent": intent,
            "entities": entities,
            "is_offensive": offense,
            "is_clear": resolved and not missing_entities,
        }

View File

@ -4,140 +4,177 @@ import faiss
from sklearn.metrics.pairwise import cosine_similarity from sklearn.metrics.pairwise import cosine_similarity
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from datetime import datetime
class ResponseGenerator:
    """Turns a dialogue action into the text reply sent to the user."""

    def __init__(self, index, db, recommender, generators, retriever, num_retrieved=1):
        """Store the search indexes, data frames and models used to answer.

        index: mapping of index name -> object exposing ``search(xq, k)``.
        db: mapping of table name -> pandas DataFrame.
        recommender: object exposing ``make_recommendation(username)``.
        generators: mapping with the 'chat', 'qa', 'summ' and 'amb' generators.
        retriever: encoder exposing ``encode(list_of_texts)``.
        num_retrieved: number of neighbours requested from an index.
        """
        self.generators = generators
        self.retriever = retriever
        self.recommender = recommender
        self.db = db
        self.index = index
        self.num_retrieved = num_retrieved
        # Last paper / dataset found; the paper is reused by "sumPaper".
        self.paper = {}
        self.dataset = {}

    def update_index(self, index):
        """Replace the search indexes."""
        self.index = index

    def update_db(self, db):
        """Replace the data frames."""
        self.db = db

    def _get_resources_links(self, item):
        """Return the 'url' of every resource attached to item ([] for an empty item)."""
        if len(item) == 0:
            return []
        return [rsrc['url'] for rsrc in item['resources']]

    def _get_matching_titles(self, rsrc, title):
        """Return the first row whose 'title' equals the lower-cased title, or {}."""
        cand = self.db[rsrc].loc[self.db[rsrc]['title'] == title.lower()].reset_index(drop=True)
        if not cand.empty:
            return cand.loc[0]
        return {}

    def _get_matching_authors(self, rsrc, author):
        """Return the first row whose 'author' equals the lower-cased author, or {}."""
        cand = self.db[rsrc].loc[self.db[rsrc]['author'] == author.lower()].reset_index(drop=True)
        if not cand.empty:
            return cand.loc[0]
        return {}

    def _get_matching_topics(self, rsrc, topic):
        """Return the row with the tag most similar to topic (similarity above 0.7), or []."""
        matches = []
        score = 0.7
        # The topic embedding does not change across rows/tags: encode it once.
        topic_vec = np.array(self.retriever.encode([topic.lower()]))
        for i, cand in self.db[rsrc].iterrows():
            for tag in cand['tags']:
                sim = cosine_similarity(np.array(self.retriever.encode([tag])), topic_vec)
                if sim > score:
                    if len(matches) > 0:
                        matches[0] = cand
                    else:
                        matches.append(cand)
                    score = sim
        if len(matches) > 0:
            return matches[0]
        return []

    def _search_index(self, index_type, db_type, query):
        """Return the top row of db_type retrieved from index_type for query."""
        xq = self.retriever.encode([query])
        D, I = self.index[index_type].search(xq, self.num_retrieved)
        return self.db[db_type].iloc[I[0]].reset_index(drop=True).loc[0]

    @staticmethod
    def _describe(prefix, item, links):
        """Build '<prefix><title>. It can be downloaded at <link>'; no link part if there are none."""
        if links:
            return str(prefix + item['title'] + '. ' + "It can be downloaded at " + links[0])
        return str(prefix + item['title'] + '.')

    def gen_response(self, action, utterance=None, username=None, state=None, consec_history=None):
        """Produce the reply for action.

        action: action name chosen by the dialogue manager.
        utterance: refined user query (needed by the generation/search actions).
        username: user to recommend for (needed by "Recommend").
        state: current dialogue state with 'entities' / 'intent'.
        consec_history: recent history string used as generation context.

        Returns the reply text, or None for an unknown action.
        """
        if action == "Help":
            return "Hey it's Janet! I am here to help you make use of the datasets and papers in the VRE. I can answer questions whose answers may be inside the papers. I can summarize papers for you. I can also chat with you. So, whichever it is, I am ready to chat!"

        elif action == "Recommend":
            prompt = self.recommender.make_recommendation(username)
            if prompt != "":
                return prompt
            else:
                return "I can help you with exploiting the contents of the VRE, just let me know!"

        elif action == "OffenseReject":
            return "I am sorry, I cannot answer to this kind of language"

        elif action == "ConvGen":
            gen_kwargs = {"length_penalty": 2.5, "num_beams": 2, "max_length": 30}
            answer = self.generators['chat']('history: ' + consec_history + ' ' + utterance + ' persona: ' + 'I am Janet. My name is Janet. I am an AI developed by CNR to help VRE users.', **gen_kwargs)[0]['generated_text']
            return answer

        elif action == "findPaper":
            for entity in state['entities']:
                if entity['entity'] == 'TITLE':
                    self.paper = self._get_matching_titles('paper_db', entity['value'])
                    links = self._get_resources_links(self.paper)
                    if len(self.paper) > 0 and len(links) > 0:
                        return self._describe("Here is the paper you want: ", self.paper, links)
                    # No exact title: fall back to the nearest title in the index.
                    self.paper = self._search_index('paper_titles_index', 'paper_db', entity['value'])
                    links = self._get_resources_links(self.paper)
                    return self._describe("This paper could be relevant: ", self.paper, links)
                if entity['entity'] == 'TOPIC':
                    self.paper = self._get_matching_topics('paper_db', entity['value'])
                    links = self._get_resources_links(self.paper)
                    if len(self.paper) > 0 and len(links) > 0:
                        return self._describe("This paper could be relevant: ", self.paper, links)
                if entity['entity'] == 'AUTHOR':
                    self.paper = self._get_matching_authors('paper_db', entity['value'])
                    links = self._get_resources_links(self.paper)
                    if len(self.paper) > 0 and len(links) > 0:
                        return self._describe("Here is the paper you want: ", self.paper, links)
            # Nothing matched an entity: search the descriptions with the whole query.
            self.paper = self._search_index('paper_desc_index', 'paper_db', utterance)
            links = self._get_resources_links(self.paper)
            return self._describe("This paper could be relevant: ", self.paper, links)

        elif action == "findDataset":
            for entity in state['entities']:
                if entity['entity'] == 'TITLE':
                    self.dataset = self._get_matching_titles('dataset_db', entity['value'])
                    links = self._get_resources_links(self.dataset)
                    if len(self.dataset) > 0 and len(links) > 0:
                        return self._describe("Here is the dataset you wanted: ", self.dataset, links)
                    self.dataset = self._search_index('dataset_titles_index', 'dataset_db', entity['value'])
                    links = self._get_resources_links(self.dataset)
                    return self._describe("This dataset could be relevant: ", self.dataset, links)
                if entity['entity'] == 'TOPIC':
                    self.dataset = self._get_matching_topics('dataset_db', entity['value'])
                    links = self._get_resources_links(self.dataset)
                    if len(self.dataset) > 0 and len(links) > 0:
                        return self._describe("This dataset could be relevant: ", self.dataset, links)
                if entity['entity'] == 'AUTHOR':
                    self.dataset = self._get_matching_authors('dataset_db', entity['value'])
                    links = self._get_resources_links(self.dataset)
                    if len(self.dataset) > 0 and len(links) > 0:
                        return self._describe("Here is the dataset you want: ", self.dataset, links)
            self.dataset = self._search_index('dataset_desc_index', 'dataset_db', utterance)
            links = self._get_resources_links(self.dataset)
            return self._describe("This dataset could be relevant: ", self.dataset, links)

        elif action == "RetGen":
            # Retrieve the most relevant paragraph, then generate the answer from it.
            content = str(self._search_index('content_index', 'content_db', utterance)['content'])
            gen_seq = 'question: ' + utterance + " context: " + content
            gen_kwargs = {"length_penalty": 0.5, "num_beams": 2, "max_length": 60}
            answer = self.generators['qa'](gen_seq, **gen_kwargs)[0]['generated_text']
            return str(answer)

        elif action == "sumPaper":
            if len(self.paper) == 0:
                for entity in state['entities']:
                    if entity['entity'] == 'TITLE':
                        self.paper = self._get_matching_titles('paper_db', entity['value'])
                        if len(self.paper) > 0:
                            break
            if len(self.paper) == 0:
                return "I cannot seem to find the requested paper. Try again by specifying the title of the paper."
            df = self.db['content_db'][self.db['content_db']['paperid'] == self.paper['id']]
            gen_kwargs = {"length_penalty": 1.5, "num_beams": 6, "max_length": 120}
            answer = ""
            for i, row in df.iterrows():
                gen_seq = 'summarize: ' + row['content']
                # Accumulate: every content chunk contributes to the summary.
                answer += self.generators['summ'](gen_seq, **gen_kwargs)[0]['generated_text'] + ' '
            return answer

        elif action == "ClarifyResource":
            if state['intent'] in ['FINDPAPER', 'SUMMARIZEPAPER']:
                return 'Please specify the title, the topic or the paper of interest.'
            else:
                return 'Please specify the title, the topic or the dataset of interest.'

        elif action == "GenClarify":
            gen_kwargs = {"length_penalty": 2.5, "num_beams": 8, "max_length": 120}
            question = self.generators['amb']('question: ' + utterance + ' context: ' + consec_history, **gen_kwargs)[0]['generated_text']
            return question

View File

@ -20,8 +20,7 @@ class User:
if len(index) > 0: if len(index) > 0:
self.interests.at[index[0], 'frequency'] += 1 self.interests.at[index[0], 'frequency'] += 1
else: else:
self.interests = self.interests.append({'interest': topic, 'frequency': max( self.interests = self.interests.append({'interest': topic, 'frequency': max(self.interests['frequency']) if not self.interests.empty else 6}, ignore_index=True)
self.interests['frequency']) if not self.interests.empty else 6}, ignore_index=True)
self.interests = self.interests.sort_values(by='frequency', ascending=False, ignore_index=True) self.interests = self.interests.sort_values(by='frequency', ascending=False, ignore_index=True)
self.interests.to_json(self.interests_file) self.interests.to_json(self.interests_file)

148
main.py
View File

@ -1,6 +1,5 @@
import os import os
import warnings import warnings
import faiss import faiss
import torch import torch
from flask import Flask, render_template, request, jsonify from flask import Flask, render_template, request, jsonify
@ -10,131 +9,125 @@ import spacy
import spacy_transformers import spacy_transformers
import torch import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from User import User from User import User
from VRE import VRE from VRE import VRE
from NLU import NLU from NLU import NLU
from DM import DM from DM import DM
from Recommender import Recommender from Recommender import Recommender
from ResponseGenerator import ResponseGenerator from ResponseGenerator import ResponseGenerator
import pandas as pd import pandas as pd
import time import time
import threading import threading
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
app = Flask(__name__)
url = os.getenv("FRONTEND_URL_WITH_PORT")
# Only the configured frontend origin may call the two API routes.
cors = CORS(app, resources={r"/predict": {"origins": url}, r"/feedback": {"origins": url}})

# Connection settings come from the environment so that no credential lives in
# the source tree. POSTGRES_PASSWORD has no default on purpose.
# The host is a bare host name: a URL scheme ("https://") is not a valid host.
conn = psycopg2.connect(
    host=os.getenv("POSTGRES_HOST", "janet-app-db.d4science.org"),
    database=os.getenv("POSTGRES_DB", "janet"),
    user=os.getenv("POSTGRES_USER", "janet_user"),
    password=os.getenv("POSTGRES_PASSWORD"))
cur = conn.cursor()
#rg = ResponseGenerator(index)
def get_response(text):
# get response from janet itself
return text, 'candAnswer'
def vre_fetch():
    """Background loop: every 1000 s pull new VRE material and refresh the response generator."""
    while True:
        time.sleep(1000)
        print('getting new material')
        vre.get_vre_update()
        vre.index_periodic_update()
        # Hand the refreshed index first, then the refreshed tables, to the generator.
        refreshed_index = vre.get_index()
        rg.update_index(refreshed_index)
        refreshed_db = vre.get_db()
        rg.update_db(refreshed_db)
def user_interest_decay():
    """Background loop: announce, wait 180 s, then decay the user's interests; repeat forever."""
    decay_period_seconds = 180
    while True:
        print("decaying interests after 3 minutes")
        time.sleep(decay_period_seconds)
        user.decay_interests()
def recommend():
while True:
if time.time() - dm.get_recent_state()['time'] > 1000:
print("Making Recommendation: ")
prompt = rec.make_recommendation(user.username)
if prompt != "":
print(prompt)
time.sleep(1000)
@app.route("/predict", methods=['POST'])
def predict():
    """Answer one chat message posted as JSON ``{"message": ...}``.

    Two sentinel messages trigger control actions instead of the NLU pipeline:
    "<HELP_ON_START>" (help text) and "<RECOMMEND_ON_IDLE>" (recommendation).
    Returns a JSON object that always carries "answer"; regular messages also
    carry the query, the history and the refined query.
    """
    text = request.get_json().get("message")
    message = {}
    if text == "<HELP_ON_START>":
        state = {'help': True, 'inactive': False}
        dm.update(state)
        action = dm.next_action()
        response = rg.gen_response(action)
        message = {"answer": response}
    elif text == "<RECOMMEND_ON_IDLE>":
        state = {'help': False, 'inactive': True}
        dm.update(state)
        action = dm.next_action()
        response = rg.gen_response(action, username=user.username)
        message = {"answer": response}
    else:
        state = nlu.process_utterance(text, dm.get_consec_history(), dm.get_sep_history())
        state['help'] = False
        state['inactive'] = False
        # The NLU may expose the refined query under either key; mirror it so
        # the dialogue manager finds 'modified_query'.
        refined_query = state.get('modified_query', state.get('modified_prompt'))
        state['modified_query'] = refined_query
        user_interests = []
        for entity in state['entities']:
            if entity['entity'] == 'TOPIC':
                user_interests.append(entity['value'])
        user.update_interests(user_interests)
        dm.update(state)
        action = dm.next_action()
        response = rg.gen_response(action, utterance=refined_query, state=dm.get_recent_state(), consec_history=dm.get_consec_history())
        message = {"answer": response, "query": text, "cand": "candidate", "history": dm.get_consec_history(), "modQuery": refined_query}
    # Record the bot's reply as a turn of its own.
    # NOTE(review): assumed to run for every branch (original indentation is lost) - confirm.
    new_state = {'modified_query': response}
    dm.update(new_state)
    reply = jsonify(message)
    return reply
@app.route('/feedback', methods = ['POST'])
def feedback():
    """Store one feedback record posted as JSON ``{"feedback": {...}}``.

    Returns ``{"status": "done"}`` once the row has been committed. A missing
    field raises KeyError; a database error propagates after the transaction
    has been rolled back.
    """
    data = request.get_json()['feedback']
    print(data)
    insert_feedback = (
        'INSERT INTO feedback (query, history, janet_modified_query, '
        'is_modified_query_correct, user_modified_query, response, '
        'preferred_response, response_length_feedback, '
        'response_fluency_feedback, response_truth_feedback, '
        'response_useful_feedback, response_time_feedback, response_intent) '
        'VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)'
    )
    values = (data['query'], data['history'], data['modQuery'],
              data['queryModCorrect'], data['correctQuery'],
              data['janetResponse'], data['preferredResponse'], data['length'],
              data['fluency'], data['truthfulness'], data['usefulness'],
              data['speed'], data['intent'])
    # The connection block commits on success and rolls back on error, so the
    # row is actually persisted and a failed insert cannot leave the shared
    # connection stuck in an aborted transaction.
    with conn:
        cur.execute(insert_feedback, values)
    reply = jsonify({"status": "done"})
    return reply
if __name__ == "__main__": if __name__ == "__main__":
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
#load NLU
def_tokenizer = AutoTokenizer.from_pretrained("castorini/t5-base-canard")
def_reference_resolver = AutoModelForSeq2SeqLM.from_pretrained("castorini/t5-base-canard")
def_intent_classifier_dir = "./IntentClassifier/"
def_entity_extractor = spacy.load("./EntityExtraction/BestModel")
def_offense_filter_dir ="./OffensiveClassifier"
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
device_flag = torch.cuda.current_device() if torch.cuda.is_available() else -1 device_flag = torch.cuda.current_device() if torch.cuda.is_available() else -1
nlu = NLU(device, device_flag, def_reference_resolver, def_tokenizer, def_intent_classifier_dir, def_offense_filter_dir, def_entity_extractor)
query_rewriter = pipeline("text2text-generation", model="castorini/t5-base-canard")
intent_classifier = pipeline("sentiment-analysis", model='./intent_classifier', device=device_flag)
entity_extractor = spacy.load("./entity_extractor")
offensive_classifier = pipeline("sentiment-analysis", model='./offensive_classifier', device=device_flag)
ambig_classifier = pipeline("sentiment-analysis", model='./ambig_classifier', device=device_flag)
coref_resolver = spacy.load("en_coreference_web_trf")
nlu = NLU(query_rewriter, coref_resolver, intent_classifier, offensive_classifier, entity_extractor, ambig_classifier)
#load retriever and generator #load retriever and generator
def_retriever = SentenceTransformer('./BigRetriever/').to(device) retriever = SentenceTransformer('./BigRetriever/').to(device)
def_generator = pipeline("text2text-generation", model="./generator", device=device_flag) qa_generator = pipeline("text2text-generation", model="./train_qa", device=device_flag)
summ_generator = pipeline("text2text-generation", model="./train_summ", device=device_flag)
chat_generator = pipeline("text2text-generation", model="./train_chat", device=device_flag)
amb_generator = pipeline("text2text-generation", model="./train_amb_gen", device=device_flag)
generators = {'qa': qa_generator,
'chat': chat_generator,
'amb': amb_generator,
'summ': summ_generator}
#load vre #load vre
token = '2c1e8f88-461c-42c0-8cc1-b7660771c9a3-843339462' token = '2c1e8f88-461c-42c0-8cc1-b7660771c9a3-843339462'
@ -148,15 +141,14 @@ if __name__ == "__main__":
threading.Thread(target=user_interest_decay, name='decayinterest').start() threading.Thread(target=user_interest_decay, name='decayinterest').start()
rec = Recommender(retriever)
rec = Recommender(def_retriever)
dm = DM() dm = DM()
rg = ResponseGenerator(index,db,def_generator,def_retriever)
threading.Thread(target=recommend, name='recommend').start() rg = ResponseGenerator(index,db, recommender, generators, retriever)
cur.execute('CREATE TABLE IF NOT EXISTS feedback (id serial PRIMARY KEY,' cur.execute('CREATE TABLE IF NOT EXISTS feedback_trial (id serial PRIMARY KEY,'
'query text NOT NULL,' 'query text NOT NULL,'
'history text NOT NULL,' 'history text NOT NULL,'
'janet_modified_query text NOT NULL,' 'janet_modified_query text NOT NULL,'

View File

@ -11,18 +11,26 @@ regex==2022.6.2
requests==2.25.1 requests==2.25.1
scikit-learn==1.0.2 scikit-learn==1.0.2
scipy==1.7.3 scipy==1.7.3
sentence-transformers==2.2.2
sentencepiece==0.1.97 sentencepiece==0.1.97
sklearn-pandas==1.8.0 sklearn-pandas==1.8.0
spacy==3.5.0 spacy==3.4.4
spacy-transformers==1.2.2 spacy-alignments==0.9.0
spacy-legacy==3.0.12
spacy-loggers==1.0.4
spacy-transformers==1.1.9
spacy-experimental==0.6.2
torch @ https://download.pytorch.org/whl/cu116/torch-1.13.1%2Bcu116-cp38-cp38-linux_x86_64.whl torch @ https://download.pytorch.org/whl/cu116/torch-1.13.1%2Bcu116-cp38-cp38-linux_x86_64.whl
torchaudio @ https://download.pytorch.org/whl/cu116/torchaudio-0.13.1%2Bcu116-cp38-cp38-linux_x86_64.whl torchaudio @ https://download.pytorch.org/whl/cu116/torchaudio-0.13.1%2Bcu116-cp38-cp38-linux_x86_64.whl
torchsummary==1.5.1 torchsummary==1.5.1
torchtext==0.14.1 torchtext==0.14.1
sentence-transformers
torchvision @ https://download.pytorch.org/whl/cu116/torchvision-0.14.1%2Bcu116-cp38-cp38-linux_x86_64.whl torchvision @ https://download.pytorch.org/whl/cu116/torchvision-0.14.1%2Bcu116-cp38-cp38-linux_x86_64.whl
tqdm==4.64.1 tqdm==4.64.1
transformers==4.26.1 transformers
markupsafe==2.0.1 markupsafe==2.0.1
psycopg2==2.9.5 psycopg2==2.9.5
en-coreference-web-trf @ https://github.com/explosion/spacy-experimental/releases/download/v0.6.1/en_coreference_web_trf-3.4.0a2-py3-none-any.whl
Werkzeug==1.0.1 Werkzeug==1.0.1