This commit is contained in:
ahmed531998 2023-04-04 05:34:47 +02:00
parent 5c328ce7df
commit 423853ad65
6 changed files with 373 additions and 355 deletions

63
DM.py
View File

@ -1,43 +1,56 @@
import time
class DM:
def __init__(self):
self.utt_history = ""
self.history = []
self.state = None
def __init__(self, max_history_length=3):
self.working_history_sep = ""
self.working_history_consec = ""
self.max_history_length = max_history_length
self.chat_history = []
self.curr_state = None
def update_history(self):
to_consider = [x['modified_query'] for x in self.chat_history[-max_history_length*2:]]
self.working_history_consec = " . ".join(to_consider)
self.working_history_sep = " ||| ".join(to_consider)
def get_consec_history(self):
return self.working_history_consec
def get_utt_history(self):
return self.utt_history
def get_sep_history(self):
return self.working_history_sep
def get_recent_state(self):
return self.state
return self.curr_state
def get_dialogue_state(self):
return self.history
def get_dialogue_history(self):
return self.chat_history
def update(self, new_state):
self.history.append(new_state)
self.utt_history = self.utt_history + " ||| " + new_state['modified_prompt']
self.state = {'intent': new_state['intent'],
'entities': new_state['entities'],
'offensive': new_state['is_offensive'],
'clear': new_state['is_clear'],
'time': time.time()}
self.chat_history.append(new_state)
self.curr_state = new_state
self.update_history()
def next_action(self):
if self.state['clear']:
if self.state['offensive']:
return "NoCanDo"
if self.curr_state['help']:
return "Help"
elif self.curr_state['inactive']:
return "Recommend"
elif self.curr_state['is_clear']:
if self.curr_state['is_offensive']:
return "OffenseReject"
else:
if self.state['intent'] == 0:
if self.curr_state['intent'] == 'QA':
return "RetGen"
elif self.state['intent'] == 1:
elif self.curr_state['intent'] == 'CHITCHAT':
return "ConvGen"
elif self.state['intent'] == 2:
elif self.curr_state['intent'] == 'FINDPAPER':
return "findPaper"
elif self.state['intent'] == 3:
elif self.curr_state['intent'] == 'FINDDATASET':
return "findDataset"
elif self.state['intent'] == 4:
elif self.curr_state['intent'] == 'SUMMARIZEPAPER':
return "sumPaper"
else:
return "Clarify"
if self.curr_state['intent'] in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(self.curr_state['entities']) == 0:
return "ClarifyResource"
else:
return "GenClarify"

221
NLU.py
View File

@ -1,143 +1,112 @@
"""
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import torch
class NLU:
def_tokenizer = AutoTokenizer.from_pretrained("castorini/t5-base-canard")
def_model = AutoModelForSeq2SeqLM.from_pretrained("castorini/t5-base-canard")
def_intent_classifier = pipeline("sentiment-analysis", model="/home/ahmed/PycharmProjects/Janet/JanetBackend/intent_classifier")
def __init__(self, model=def_model, tokenizer=def_tokenizer, intent_classifier=def_intent_classifier,
max_history_length=1024, num_gen_seq=2, score_threshold=0.5):
self.input = ""
self.output = ""
self.model = model
self.tokenizer = tokenizer
self.max_length = max_history_length
self.num_return_sequences = num_gen_seq
self.score_threshold = score_threshold
self.label2id = {'Greet': 0, 'Bye': 1, 'GetKnowledge': 2, 'ChitChat': 3}
self.id2label = {0: 'Greet', 1: 'Bye', 2: 'GetKnowledge', 3: 'ChitChat'}
self.intent_classifier = intent_classifier
def process_utterance(self, utterance, history):
if len(history) > 0:
# crop history
while len(history.split(" ")) > self.max_length:
index = history.find("|||")
history = history[index + 4:]
self.input = history + " ||| " + utterance
inputs = self.tokenizer(self.input, max_length=self.max_length, truncation=True, padding="max_length",
return_tensors="pt")
candidates = self.model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"],
return_dict_in_generate=True, output_scores=True,
num_return_sequences=self.num_return_sequences,
num_beams=self.num_return_sequences)
for i in range(candidates["sequences"].shape[0]):
generated_sentence = self.tokenizer.decode(candidates["sequences"][i], skip_special_tokens=True,
clean_up_tokenization_spaces=True)
log_scores = candidates['sequences_scores']
norm_prob = (torch.exp(log_scores[i]) / torch.exp(log_scores).sum()).item()
if norm_prob >= self.score_threshold:
self.score_threshold = norm_prob
self.output = generated_sentence
else:
self.output = utterance
intent = self.label2id[self.intent_classifier(self.output)[0]['label']]
intent_conf = self.intent_classifier(self.output)[0]['score']
return {"modified_prompt": self.output, "mod_confidence": self.score_threshold, "prompt_intent": intent,
"intent_confidence": intent_conf}
"""
import threading
import spacy
import spacy_transformers
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
class NLU:
def __init__(self, device, device_flag, reference_resolver, tokenizer,
intent_classifier, offense_filter, entity_extractor,
max_history_length=1024):
#entity_extractor=def_entity_extractor
self.reference_resolver = reference_resolver
self.device = device
self.reference_resolver.to(device)
self.tokenizer = tokenizer
self.max_length = max_history_length
self.label2idintent = {'QA': 0, 'CHITCHAT': 1, 'FINDPAPER': 2, 'FINDDATASET': 3, 'SUMMARIZEPAPER': 4}
self.id2labelintent = {0: 'QA', 1: 'CHITCHAT', 2: 'FINDPAPER', 3: 'FINDDATASET', 4: 'SUMMARIZEPAPER'}
self.label2idoffense = {'hate': 0, 'offensive': 1, 'neither': 2}
self.id2labeloffense = {0: 'hate', 1: 'offensive', 2: 'neither'}
self.intent_classifier = pipeline("sentiment-analysis", model=intent_classifier, device=device_flag)
def __init__(self, query_rewriter, coref_resolver, intent_classifier, offensive_classifier, entity_extractor, ambig_classifier):
self.intent_classifier = intent_classifier
self.entity_extractor = entity_extractor
self.offense_filter = pipeline("sentiment-analysis", model=offense_filter, device=device_flag)
self.offensive_classifier = offensive_classifier
self.coref_resolver = coref_resolver
self.query_rewriter = query_rewriter
self.ambig_classifier = ambig_classifier
def _resolve_coref(self, history):
to_resolve = history + ' <COREF_SEP_TOKEN> ' + self.to_process
doc = self.coref_resolver(to_resolve)
token_mention_mapper = {}
output_string = ""
clusters = [
val for key, val in doc.spans.items() if key.startswith("coref_cluster")
]
# Iterate through every found cluster
for cluster in clusters:
first_mention = cluster[0]
# Iterate through every other span in the cluster
for mention_span in list(cluster)[1:]:
# Set first_mention as value for the first token in mention_span in the token_mention_mapper
token_mention_mapper[mention_span[0].idx] = first_mention.text + mention_span[0].whitespace_
for token in mention_span[1:]:
# Set empty string for all the other tokens in mention_span
token_mention_mapper[token.idx] = ""
# Iterate through every token in the Doc
for token in doc:
# Check if token exists in token_mention_mapper
if token.idx in token_mention_mapper:
output_string += token_mention_mapper[token.idx]
# Else add original token text
else:
output_string += token.text + token.whitespace_
cleaned_query = output_string.split(" <COREF_SEP_TOKEN> ", 1)[1]
return cleaned_query
self.intents = None
self.entities = None
self.offensive = None
self.clear = True
def _intentpredictor(self):
self.intents = self.label2idintent[self.intent_classifier(self.to_process)[0]['label']]
pred = self.intent_classifier(self.to_process)[0]
return pred['label'], pred['score']
def _ambigpredictor(self):
pred = self.ambig_classifier(self.to_process)[0]
if pred['label'] in ['clear', 'somewhat_clear']:
return False
else:
return True
def _entityextractor(self):
self.entities = []
entities = []
doc = self.entity_extractor(self.to_process)
for entity in doc.ents:
if entity.text not in ['.', ',', '?', ';']:
self.entities.append({'entity': entity.label_, 'value': entity.text})
entities.append({'entity': entity.label_, 'value': entity.text})
return entities
def _inappropriatedetector(self):
self.offensive = False
is_offensive = self.label2idoffense[self.offense_filter(self.to_process)[0]['label']]
if is_offensive == 0 or is_offensive == 1:
self.offensive = True
def _offensepredictor(self):
pred = self.offensive_classifier(self.to_process)[0]['label']
if pred != "neither":
return True
else:
return False
def process_utterance(self, utterance, history):
def _rewrite_query(self, history):
text = history + " ||| " + self.to_process
return self.query_rewriter(text)[0]['generated_text']
def process_utterance(self, utterance, history_consec, history_sep):
"""
Given an utterance and the history of the conversation, refine the query contextually and return a refined
utterance
Query -> coref resolution & intent extraction -> if intents are not confident or if query is ambig -> rewrite query and recheck -> if still ambig, ask a clarifying question
"""
self.to_process = utterance
if len(history) > 0:
# crop history
while len(history.split(" ")) > self.max_length:
index = history.find("|||")
history = history[index + 4:]
context = history + " ||| " + utterance
inputs = self.tokenizer(context, max_length=self.max_length, truncation=True, padding="max_length",
return_tensors="pt")
candidates = self.reference_resolver.generate(input_ids=inputs["input_ids"].to(self.device),
attention_mask=inputs["attention_mask"].to(self.device),
return_dict_in_generate=True, output_scores=True,
num_return_sequences=1,
num_beams=5)
self.to_process = self.tokenizer.decode(candidates["sequences"][0], skip_special_tokens=True,
clean_up_tokenization_spaces=True)
t1 = threading.Thread(target=self._intentpredictor, name='intent')
t2 = threading.Thread(target=self._entityextractor, name='entity')
t3 = threading.Thread(target=self._inappropriatedetector, name='offensive')
t3.start()
t1.start()
t2.start()
t3.join()
t1.join()
t2.join()
return {"modified_prompt": self.to_process,
"intent": self.intents,
"entities": self.entities,
"is_offensive": self.offensive,
"is_clear": self.clear}
self.to_process = self._resolve_coref(history_consec)
intent, score = self._intentpredictor()
#print(score)
if score > 0.5:
entities = self._entityextractor()
offense = self._offensepredictor()
if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
return {"modified_prompt": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
return {"modified_prompt": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": True}
else:
if self._ambigpredictor():
self.to_process = self._rewrite_query(history_sep)
intent, score = self._intentpredictor()
entities = self._entityextractor()
offense = self._offensepredictor()
if score > 0.5 or not self._ambigpredictor():
if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
return {"modified_prompt": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
return {"modified_prompt": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense,
"is_clear": True}
else:
return {"modified_prompt": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense,
"is_clear": False}
else:
entities = self._entityextractor()
offense = self._offensepredictor()
if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
return {"modified_prompt": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
return {"modified_prompt": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": True}

View File

@ -4,140 +4,177 @@ import faiss
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import pandas as pd
from datetime import datetime
class ResponseGenerator:
def __init__(self, index, db,
generator, retriever, num_retrieved=1):
self.generator = generator
self.retriever = retriever
self.db = db
self.index = index
self.num_retrieved = num_retrieved
self.paper = {}
self.dataset = {}
def __init__(self, index, db,recommender,generators, retriever, num_retrieved=1):
self.generators = generators
self.retriever = retriever
self.recommender = recommender
self.db = db
self.index = index
self.num_retrieved = num_retrieved
self.paper = {}
self.dataset = {}
def update_index(self, index):
self.index = index
def update_db(self, db):
self.db = db
def update_index(self, index):
self.index = index
def update_db(self, db):
self.db = db
def _get_resources_links(self, item):
if len(item) == 0:
return []
links = []
for rsrc in item['resources']:
links.append(rsrc['url'])
return links
def _get_resources_links(self, item):
if len(item) == 0:
return []
links = []
for rsrc in item['resources']:
links.append(rsrc['url'])
return links
def _get_matching_titles(self, rsrc, title):
cand = self.db[rsrc].loc[self.db[rsrc]['title'] == title.lower()].reset_index(drop=True)
if not cand.empty:
return cand.loc[0]
else:
return {}
def _get_matching_titles(self, rsrc, title):
cand = self.db[rsrc].loc[self.db[rsrc]['title'] == title.lower()].reset_index(drop=True)
if not cand.empty:
return cand.loc[0]
else:
return {}
def _get_matching_topics(self, rsrc, topic):
matches = []
score = 0.7
for i, cand in self.db[rsrc].iterrows():
for tag in cand['tags']:
sim = cosine_similarity(np.array(self.retriever.encode([tag])), np.array(self.retriever.encode([topic.lower()])))
if sim > score:
if(len(matches)>0):
matches[0] = cand
else:
matches.append(cand)
score = sim
if len(matches) > 0:
return matches[0]
else:
return []
def _get_matching_authors(self, rsrc, author):
cand = self.db[rsrc].loc[self.db[rsrc]['author'] == author.lower()].reset_index(drop=True)
if not cand.empty:
return cand.loc[0]
else:
return {}
def _search_index(self, index_type, db_type, query):
xq = self.retriever.encode([query])
D, I = self.index[index_type].search(xq, self.num_retrieved)
return self.db[db_type].iloc[[I[0]][0]].reset_index(drop=True).loc[0]
def _get_matching_topics(self, rsrc, topic):
matches = []
score = 0.7
for i, cand in self.db[rsrc].iterrows():
for tag in cand['tags']:
sim = cosine_similarity(np.array(self.retriever.encode([tag])), np.array(self.retriever.encode([topic.lower()])))
if sim > score:
if(len(matches)>0):
matches[0] = cand
else:
matches.append(cand)
score = sim
if len(matches) > 0:
return matches[0]
else:
return []
def _search_index(self, index_type, db_type, query):
xq = self.retriever.encode([query])
D, I = self.index[index_type].search(xq, self.num_retrieved)
return self.db[db_type].iloc[[I[0]][0]].reset_index(drop=True).loc[0]
def gen_response(self, utterance, state, history, action):
if action == "NoCanDo":
return str("I am sorry, I cannot answer to this kind of language")
def gen_response(self, action, utterance=None, username=None, state=None, consec_history=None):
if action == "Help":
return "Hey it's Janet! I am here to help you make use of the datasets and papers in the VRE. I can answer questions whose answers may be inside the papers. I can summarize papers for you. I can also chat with you. So, whichever it is, I am ready to chat!"
elif action == "Recommend":
prompt = self.recommender.make_recommendation(username)
if prompt != "":
return prompt
else:
return "I can help you with exploiting the contents of the VRE, just let me know!"
elif action == "ConvGen":
gen_kwargs = {"length_penalty": 2.5, "num_beams":4, "max_length": 20}
answer = self.generator('question: '+ utterance + ' context: ' + history , **gen_kwargs)[0]['generated_text']
return answer
elif action == "OffenseReject":
return "I am sorry, I cannot answer to this kind of language"
elif action == "findPaper":
for entity in state['entities']:
if (entity['entity'] == 'TITLE'):
self.paper = self._get_matching_titles('paper_db', entity['value'])
links = self._get_resources_links(self.paper)
if len(self.paper) > 0 and len(links) > 0:
return str("This paper could be helpful: " + self.paper['title'] + '. ' + "It can be downloaded at " + links[0])
else:
self.paper = self._search_index('paper_titles_index', 'paper_db', entity['value'])
elif action == "ConvGen":
gen_kwargs = {"length_penalty": 2.5, "num_beams":2, "max_length": 30}
answer = self.generators['chat']('history: '+ consec_history + ' ' + utterance + ' persona: ' + 'I am Janet. My name is Janet. I am an AI developed by CNR to help VRE users.' , **gen_kwargs)[0]['generated_text']
return answer
elif action == "findPaper":
for entity in state['entities']:
if (entity['entity'] == 'TITLE'):
self.paper = self._get_matching_titles('paper_db', entity['value'])
links = self._get_resources_links(self.paper)
if len(self.paper) > 0 and len(links) > 0:
return str("Here is the paper you want: " + self.paper['title'] + '. ' + "It can be downloaded at " + links[0])
else:
self.paper = self._search_index('paper_titles_index', 'paper_db', entity['value'])
links = self._get_resources_links(self.paper)
return str("This paper could be relevant: " + self.paper['title'] + '. ' + "It can be downloaded at " + links[0])
if(entity['entity'] == 'TOPIC'):
self.paper = self._get_matching_topics('paper_db', entity['value'])
links = self._get_resources_links(self.paper)
if len(self.paper) > 0 and len(links) > 0:
return str("This paper could be relevant: " + self.paper['title'] + '. ' + "It can be downloaded at " + links[0])
if(entity['entity'] == 'AUTHOR'):
self.paper = self._get_matching_authors('paper_db', entity['value'])
links = self._get_resources_links(self.paper)
if len(self.paper) > 0 and len(links) > 0:
return str("Here is the paper you want: " + self.paper['title'] + '. ' + "It can be downloaded at " + links[0])
self.paper = self._search_index('paper_desc_index', 'paper_db', utterance)
links = self._get_resources_links(self.paper)
return str("This paper could be helpful: " + self.paper['title'] + '. ' + "It can be downloaded at " + links[0])
if(entity['entity'] == 'TOPIC'):
self.paper = self._get_matching_topics('paper_db', entity['value'])
links = self._get_resources_links(self.paper)
if len(self.paper) > 0 and len(links) > 0:
return str("This paper could be helpful: " + self.paper['title'] + '. ' + "It can be downloaded at " + links[0])
self.paper = self._search_index('paper_desc_index', 'paper_db', utterance)
links = self._get_resources_links(self.paper)
return str("This paper could be helpful: " + self.paper['title'] + '. ' + "It can be downloaded at " + links[0])
return str("This paper could be relevant: " + self.paper['title'] + '. ' + "It can be downloaded at " + links[0])
elif action == "findDataset":
for entity in state['entities']:
if (entity['entity'] == 'TITLE'):
self.dataset = self._get_matching_titles('dataset_db', entity['value'])
links = self._get_resources_links(self.dataset)
if len(self.dataset) > 0 and len(links) > 0:
return str("This dataset could be helpful: " + self.dataset['title'] + '. ' + "It can be downloaded at " + links[0])
else:
self.dataset = self._search_index('dataset_titles_index', 'dataset_db', entity['value'])
elif action == "findDataset":
for entity in state['entities']:
if (entity['entity'] == 'TITLE'):
self.dataset = self._get_matching_titles('dataset_db', entity['value'])
links = self._get_resources_links(self.dataset)
if len(self.dataset) > 0 and len(links) > 0:
return str("Here is the dataset you wanted: " + self.dataset['title'] + '. ' + "It can be downloaded at " + links[0])
else:
self.dataset = self._search_index('dataset_titles_index', 'dataset_db', entity['value'])
links = self._get_resources_links(self.dataset)
return str("This dataset could be relevant: " + self.dataset['title'] + '. ' + "It can be downloaded at " + links[0])
if(entity['entity'] == 'TOPIC'):
self.dataset = self._get_matching_topics('dataset_db', entity['value'])
links = self._get_resources_links(self.dataset)
if len(self.dataset) > 0 and len(links) > 0:
return str("This dataset could be relevant: " + self.dataset['title'] + '. ' + "It can be downloaded at " + links[0])
if(entity['entity'] == 'AUTHOR'):
self.dataset = self._get_matching_authors('dataset_db', entity['value'])
links = self._get_resources_links(self.dataset)
if len(self.dataset) > 0 and len(links) > 0:
return str("Here is the dataset you want: " + self.dataset['title'] + '. ' + "It can be downloaded at " + links[0])
self.dataset = self._search_index('dataset_desc_index', 'dataset_db', utterance)
links = self._get_resources_links(self.dataset)
return str("This dataset could be helpful: " + self.dataset['title'] + '. ' + "It can be downloaded at " + links[0])
if(entity['entity'] == 'TOPIC'):
self.dataset = self._get_matching_topics('dataset_db', entity['value'])
links = self._get_resources_links(self.dataset)
if len(self.dataset) > 0 and len(links) > 0:
return str("This dataset could be helpful: " + self.dataset['title'] + '. ' + "It can be downloaded at " + links[0])
self.dataset = self._search_index('dataset_desc_index', 'dataset_db', utterance)
links = self._get_resources_links(self.dataset)
return str("This dataset could be helpful: " + self.dataset['title'] + '. ' + "It can be downloaded at " + links[0])
return str("This dataset could be relevant: " + self.dataset['title'] + '. ' + "It can be downloaded at " + links[0])
elif action == "RetGen":
#retrieve the most relevant paragraph
content = str(self._search_index('content_index', 'content_db', utterance)['content'])
#generate the answer
gen_seq = 'question: '+utterance+" context: "+content
elif action == "RetGen":
#retrieve the most relevant paragraph
content = str(self._search_index('content_index', 'content_db', utterance)['content'])
#generate the answer
gen_seq = 'question: '+utterance+" context: "+content
#handle return random 2 answers
gen_kwargs = {"length_penalty": 0.5, "num_beams":2, "max_length": 60}
answer = self.generators['qa'](gen_seq, **gen_kwargs)[0]['generated_text']
return str(answer)
#handle return random 2 answers
gen_kwargs = {"length_penalty": 0.5, "num_beams":8, "max_length": 100}
answer = self.generator(gen_seq, **gen_kwargs)[0]['generated_text']
return str(answer)
elif action == "sumPaper":
if len(self.paper) == 0:
for entity in state['entities']:
if (entity['entity'] == 'TITLE'):
self.paper = self._get_matching_titles('paper_db', entity['value'])
if (len(self.paper) > 0):
break
if len(self.paper) == 0:
return "I cannot seem to find the requested paper. Try again by specifying the title of the paper."
#implement that
df = self.db['content_db'][self.db['content_db']['paperid'] == self.paper['id']]
answer = ""
for i, row in df.iterrows():
gen_seq = 'summarize: '+row['content']
gen_kwargs = {"length_penalty": 1.5, "num_beams":8, "max_length": 100}
answer = self.generator(gen_seq, **gen_kwargs)[0]['generated_text'] + ' '
return answer
elif action == "sumPaper":
if len(self.paper) == 0:
for entity in state['entities']:
if (entity['entity'] == 'TITLE'):
self.paper = self._get_matching_titles('paper_db', entity['value'])
if (len(self.paper) > 0):
break
if len(self.paper) == 0:
return "I cannot seem to find the requested paper. Try again by specifying the title of the paper."
#implement that
df = self.db['content_db'][self.db['content_db']['paperid'] == self.paper['id']]
answer = ""
for i, row in df.iterrows():
gen_seq = 'summarize: '+row['content']
gen_kwargs = {"length_penalty": 1.5, "num_beams":6, "max_length": 120}
answer = self.generators['summ'](gen_seq, **gen_kwargs)[0]['generated_text'] + ' '
return answer
elif action == "Clarify":
return str("Can you please clarify?")
elif action == "ClarifyResource":
if state['intent'] in ['FINDPAPER', 'SUMMARIZEPAPER']:
return 'Please specify the title, the topic or the paper of interest.'
else:
return 'Please specify the title, the topic or the dataset of interest.'
elif action == "GenClarify":
gen_kwargs = {"length_penalty": 2.5, "num_beams":8, "max_length": 120}
question = self.generators['amb']('question: '+ utterance + ' context: ' + consec_history , **gen_kwargs)[0]['generated_text']
return question

View File

@ -20,8 +20,7 @@ class User:
if len(index) > 0:
self.interests.at[index[0], 'frequency'] += 1
else:
self.interests = self.interests.append({'interest': topic, 'frequency': max(
self.interests['frequency']) if not self.interests.empty else 6}, ignore_index=True)
self.interests = self.interests.append({'interest': topic, 'frequency': max(self.interests['frequency']) if not self.interests.empty else 6}, ignore_index=True)
self.interests = self.interests.sort_values(by='frequency', ascending=False, ignore_index=True)
self.interests.to_json(self.interests_file)

148
main.py
View File

@ -1,6 +1,5 @@
import os
import warnings
import faiss
import torch
from flask import Flask, render_template, request, jsonify
@ -10,131 +9,125 @@ import spacy
import spacy_transformers
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from User import User
from VRE import VRE
from NLU import NLU
from DM import DM
from Recommender import Recommender
from ResponseGenerator import ResponseGenerator
import pandas as pd
import time
import threading
from sentence_transformers import SentenceTransformer
app = Flask(__name__)
#allow frontend address
url = os.getenv("FRONTEND_URL_WITH_PORT")
cors = CORS(app, resources={r"/predict": {"origins": url}, r"/feedback": {"origins": url}})
#cors = CORS(app, resources={r"/predict": {"origins": "*"}, r"/feedback": {"origins": "*"}})
"""
conn = psycopg2.connect(
host="janet-pg",
host="https://janet-app-db.d4science.org",
database=os.getenv("POSTGRES_DB"),
user=os.getenv("POSTGRES_USER"),
password=os.getenv("POSTGRES_PASSWORD"))
"""
conn = psycopg2.connect(host="https://janet-app-db.d4science.org",
database="janet",
user="janet_user",
password="2fb5e81fec5a2d906a04")
cur = conn.cursor()
#rg = ResponseGenerator(index)
def get_response(text):
# get response from janet itself
return text, 'candAnswer'
def vre_fetch():
while True:
time.sleep(1000)
print('getting new material')
vre.get_vre_update()
vre.index_periodic_update()
rg.update_index(vre.get_index())
rg.update_db(vre.get_db())
while True:
time.sleep(1000)
print('getting new material')
vre.get_vre_update()
vre.index_periodic_update()
rg.update_index(vre.get_index())
rg.update_db(vre.get_db())
def user_interest_decay():
while True:
print("decaying interests after 3 minutes")
time.sleep(180)
user.decay_interests()
def recommend():
while True:
if time.time() - dm.get_recent_state()['time'] > 1000:
print("Making Recommendation: ")
prompt = rec.make_recommendation(user.username)
if prompt != "":
print(prompt)
time.sleep(1000)
while True:
print("decaying interests after 3 minutes")
time.sleep(180)
user.decay_interests()
@app.route("/predict", methods=['POST'])
def predict():
text = request.get_json().get("message")
state = nlu.process_utterance(text, dm.get_utt_history())
user_interests = []
for entity in state['entities']:
if entity['entity'] == 'TOPIC':
user_interests.append(entity['value'])
user.update_interests(user_interests)
dm.update(state)
action = dm.next_action()
response = rg.gen_response(state['modified_prompt'], dm.get_recent_state(), dm.get_utt_history(), action)
message = {"answer": response, "query": text, "cand": "candidate", "history": dm.get_utt_history(), "modQuery": state['modified_prompt']}
message = {}
if text == "<HELP_ON_START>":
state = {'help': True, 'inactive': False}
dm.update(state)
action = dm.next_action()
response = rg.gen_response(action)
message = {"answer": response}
elif text == "<RECOMMEND_ON_IDLE>":
state = {'help': False, 'inactive': True}
dm.update(state)
action = dm.next_action()
response = rg.gen_response(action, username=user.username)
message = {"answer": response}
else:
state = nlu.process_utterance(text, dm.get_consec_history(), dm.get_sep_history())
state['help'] = False
state['inactive'] = False
user_interests = []
for entity in state['entities']:
if entity['entity'] == 'TOPIC':
user_interests.append(entity['value'])
user.update_interests(user_interests)
dm.update(state)
action = dm.next_action()
self, action, utterance=None, username=None, state=None, consec_history=None
response = rg.gen_response(action, utterance=state['modified_query'], state=dm.get_recent_state(), consec_history=dm.get_consec_history())
message = {"answer": response, "query": text, "cand": "candidate", "history": dm.get_consec_history(), "modQuery": state['modified_query']}
new_state = {'modified_query': response}
dm.update(new_state)
reply = jsonify(message)
#reply.headers.add('Access-Control-Allow-Origin', '*')
return reply
@app.route('/feedback', methods = ['POST'])
def feedback():
data = request.get_json()['feedback']
# Make data frame of above data
print(data)
#df = pd.DataFrame([data])
#file_exists = os.path.isfile('feedback.csv')
#df = pd.DataFrame(data=[data['response'], data['length'], data['fluency'], data['truthfulness'], data['usefulness'], data['speed']]
# ,columns=['response', 'length', 'fluency', 'truthfulness', 'usefulness', 'speed'])
#df.to_csv('feedback.csv', mode='a', index=False, header=(not file_exists))
cur.execute('INSERT INTO feedback (query, history, janet_modified_query,
is_modified_query_correct, user_modified_query,
response, preferred_response, response_length_feedback,
response_fluency_feedback, response_truth_feedback,
response_useful_feedback, response_time_feedback, response_intent)'
'VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)',
cur.execute('INSERT INTO feedback (query, history, janet_modified_query, is_modified_query_correct, user_modified_query, response, preferred_response, response_length_feedback, response_fluency_feedback, response_truth_feedback, response_useful_feedback, response_time_feedback, response_intent) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)',
(data['query'], data['history'], data['modQuery'],
data['queryModCorrect'], data['correctQuery'],
data['janetResponse'], data['preferredResponse'], data['length'],
data['fluency'], data['truthfulness'], data['usefulness'],
data['speed'], data['intent'])
)
reply = jsonify({"status": "done"})
#reply.headers.add('Access-Control-Allow-Origin', '*')
return reply
if __name__ == "__main__":
warnings.filterwarnings("ignore")
#load NLU
def_tokenizer = AutoTokenizer.from_pretrained("castorini/t5-base-canard")
def_reference_resolver = AutoModelForSeq2SeqLM.from_pretrained("castorini/t5-base-canard")
def_intent_classifier_dir = "./IntentClassifier/"
def_entity_extractor = spacy.load("./EntityExtraction/BestModel")
def_offense_filter_dir ="./OffensiveClassifier"
device = "cuda" if torch.cuda.is_available() else "cpu"
device_flag = torch.cuda.current_device() if torch.cuda.is_available() else -1
nlu = NLU(device, device_flag, def_reference_resolver, def_tokenizer, def_intent_classifier_dir, def_offense_filter_dir, def_entity_extractor)
query_rewriter = pipeline("text2text-generation", model="castorini/t5-base-canard")
intent_classifier = pipeline("sentiment-analysis", model='./intent_classifier', device=device_flag)
entity_extractor = spacy.load("./entity_extractor")
offensive_classifier = pipeline("sentiment-analysis", model='./offensive_classifier', device=device_flag)
ambig_classifier = pipeline("sentiment-analysis", model='./ambig_classifier', device=device_flag)
coref_resolver = spacy.load("en_coreference_web_trf")
nlu = NLU(query_rewriter, coref_resolver, intent_classifier, offensive_classifier, entity_extractor, ambig_classifier)
#load retriever and generator
def_retriever = SentenceTransformer('./BigRetriever/').to(device)
def_generator = pipeline("text2text-generation", model="./generator", device=device_flag)
retriever = SentenceTransformer('./BigRetriever/').to(device)
qa_generator = pipeline("text2text-generation", model="./train_qa", device=device_flag)
summ_generator = pipeline("text2text-generation", model="./train_summ", device=device_flag)
chat_generator = pipeline("text2text-generation", model="./train_chat", device=device_flag)
amb_generator = pipeline("text2text-generation", model="./train_amb_gen", device=device_flag)
generators = {'qa': qa_generator,
'chat': chat_generator,
'amb': amb_generator,
'summ': summ_generator}
#load vre
token = '2c1e8f88-461c-42c0-8cc1-b7660771c9a3-843339462'
@ -148,15 +141,14 @@ if __name__ == "__main__":
threading.Thread(target=user_interest_decay, name='decayinterest').start()
rec = Recommender(def_retriever)
rec = Recommender(retriever)
dm = DM()
rg = ResponseGenerator(index,db,def_generator,def_retriever)
threading.Thread(target=recommend, name='recommend').start()
rg = ResponseGenerator(index,db, recommender, generators, retriever)
cur.execute('CREATE TABLE IF NOT EXISTS feedback (id serial PRIMARY KEY,'
cur.execute('CREATE TABLE IF NOT EXISTS feedback_trial (id serial PRIMARY KEY,'
'query text NOT NULL,'
'history text NOT NULL,'
'janet_modified_query text NOT NULL,'

View File

@ -11,18 +11,26 @@ regex==2022.6.2
requests==2.25.1
scikit-learn==1.0.2
scipy==1.7.3
sentence-transformers==2.2.2
sentencepiece==0.1.97
sklearn-pandas==1.8.0
spacy==3.5.0
spacy-transformers==1.2.2
spacy==3.4.4
spacy-alignments==0.9.0
spacy-legacy==3.0.12
spacy-loggers==1.0.4
spacy-transformers==1.1.9
spacy-experimental==0.6.2
torch @ https://download.pytorch.org/whl/cu116/torch-1.13.1%2Bcu116-cp38-cp38-linux_x86_64.whl
torchaudio @ https://download.pytorch.org/whl/cu116/torchaudio-0.13.1%2Bcu116-cp38-cp38-linux_x86_64.whl
torchsummary==1.5.1
torchtext==0.14.1
sentence-transformers
torchvision @ https://download.pytorch.org/whl/cu116/torchvision-0.14.1%2Bcu116-cp38-cp38-linux_x86_64.whl
tqdm==4.64.1
transformers==4.26.1
transformers
markupsafe==2.0.1
psycopg2==2.9.5
en-coreference-web-trf @ https://github.com/explosion/spacy-experimental/releases/download/v0.6.1/en_coreference_web_trf-3.4.0a2-py3-none-any.whl
Werkzeug==1.0.1