JanetBackEnd/ResponseGenerator.py

308 lines
17 KiB
Python

from sentence_transformers import models, SentenceTransformer
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import faiss
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import pandas as pd
from datetime import datetime
from datasets import Dataset
class ResponseGenerator:
    """Build Janet's textual replies for the dialogue actions chosen upstream.

    Resources are looked up in ``db`` (pandas DataFrames under keys such as
    ``'paper_db'``, ``'dataset_db'``, ``'post_db'`` and ``'content_db'``),
    searched semantically through the embedding datasets in ``index``, and
    turned into text with the generation callables in ``generators``
    (``'qa'``, ``'chat'``, ``'summ'``, ``'amb'``). The last paper, dataset
    and post shown to the user are remembered on the instance so follow-up
    turns ("summarize it", "explain the post") can refer back to them.
    """

    def __init__(self, index, db, recommender, generators, retriever, num_retrieved=3):
        """Store the collaborators used to answer user requests.

        Args:
            index: mapping from index name (e.g. ``'paper_titles_index'``) to a
                dataset with an ``'embeddings'`` column, searched through the
                ``datasets`` FAISS API (``add_faiss_index`` /
                ``get_nearest_examples``).
            db: mapping from resource name to a pandas DataFrame.
            recommender: object whose ``make_recommendation(username, name)``
                returns a recommendation string ("" when there is none).
            generators: mapping from task name to a text-generation callable
                returning ``[{'generated_text': ...}]``.
            retriever: sentence encoder exposing ``encode(list_of_texts)``.
            num_retrieved: number of nearest neighbours fetched per search.
        """
        self.generators = generators
        self.retriever = retriever
        self.recommender = recommender
        self.db = db
        self.index = index
        self.num_retrieved = num_retrieved
        # Conversation memory: the last paper / dataset / post shown to the user.
        self.paper = {}
        self.dataset = {}
        self.post = {}

    def update_index(self, index):
        """Replace the mapping of searchable embedding datasets."""
        self.index = index

    def update_db(self, db):
        """Replace the mapping of resource DataFrames."""
        self.db = db

    def _get_resources_links(self, item):
        """Return the ``url`` of every entry in ``item['resources']`` ([] for an empty item)."""
        if len(item) == 0:
            return []
        return [rsrc['url'] for rsrc in item['resources']]

    @staticmethod
    def _latest_row(frame):
        """Return the row of *frame* with the greatest ``time`` value, or {} if empty.

        ``idxmax`` works for numeric and datetime columns alike (a running
        maximum seeded with the integer 0 cannot be compared with pandas
        Timestamps). Ties resolve to the first row and NaN times are
        skipped; if no row has a time at all, the first row is returned.
        """
        if frame.empty:
            return {}
        times = frame['time']
        if times.notna().any():
            return frame.loc[times.idxmax()]
        return frame.iloc[0]

    @staticmethod
    def _describe(lead, item):
        """Format a resource reply: ``lead + title + '. It can be viewed at ' + url``."""
        return lead + item['title'] + '. ' + "It can be viewed at " + item['url']

    def _get_matching_titles(self, rsrc, title):
        """Return the first row of ``self.db[rsrc]`` whose title equals ``title.lower()``, else {}."""
        frame = self.db[rsrc]
        cand = frame.loc[frame['title'] == title.lower()].reset_index(drop=True)
        if cand.empty:
            return {}
        return cand.loc[0]

    def _get_matching_authors(self, rsrc, author, recent=False):
        """Return a row of ``self.db[rsrc]`` whose author equals ``author.lower()``, else {}.

        With ``recent=True`` the matching row with the latest ``time`` is
        returned; otherwise the first matching row.
        """
        frame = self.db[rsrc]
        cand = frame.loc[frame['author'] == author.lower()].reset_index(drop=True)
        if cand.empty:
            return {}
        if recent:
            return self._latest_row(cand)
        return cand.loc[0]

    def _get_most_recent(self, rsrc):
        """Return the row of ``self.db[rsrc]`` with the latest ``time`` ({} when the table is empty)."""
        return self._latest_row(self.db[rsrc])

    def _get_matching_topics(self, rsrc, topic, threshold=0.7):
        """Return the row of ``self.db[rsrc]`` whose tags best match ``topic``.

        Each tag of each row is embedded with ``self.retriever`` and compared
        with the lower-cased topic by cosine similarity. The row owning the
        single most similar tag wins, provided that similarity is strictly
        above ``threshold``; otherwise {} is returned.
        """
        # The topic embedding is the same for every tag: compute it once.
        topic_emb = np.array(self.retriever.encode([topic.lower()]))
        best_row = {}
        best_score = threshold
        for _, cand in self.db[rsrc].iterrows():
            for tag in cand['tags']:
                tag_emb = np.array(self.retriever.encode([tag]))
                sim = float(cosine_similarity(tag_emb, topic_emb)[0, 0])
                if sim > best_score:
                    best_row = cand
                    best_score = sim
        return best_row

    def _search_index(self, index_type, db_type, query, multi=False):
        """Return the nearest neighbours of ``query`` in ``self.index[index_type]``.

        Args:
            index_type: key of the dataset to search in ``self.index``.
            db_type: unused; kept for call-site compatibility.
            query: text embedded with ``self.retriever`` before searching.
            multi: if True return all ``num_retrieved`` hits as a DataFrame
                (with a ``scores`` column), otherwise only the best hit.

        Hits keep the order produced by ``get_nearest_examples``, which is
        best match first (ascending distance for the default L2 index).
        Re-sorting by ``scores`` descending would put the farthest hit first.
        """
        dataset = self.index[index_type]
        # Building the FAISS index scans every row; do it once per dataset
        # object instead of on every query.
        if not dataset.is_index_initialized("embeddings"):
            dataset.add_faiss_index(column="embeddings")
        scores, samples = dataset.get_nearest_examples(
            "embeddings", self.retriever.encode([query]), k=self.num_retrieved
        )
        samples_df = pd.DataFrame.from_dict(samples)
        samples_df["scores"] = scores
        samples_df = samples_df.reset_index(drop=True)
        if multi:
            return samples_df
        return samples_df.iloc[0]

    def gen_response(self, action, utterance=None, name=None, username=None, vrename=None, state=None, consec_history=None, chitchat_history=None):
        """Return Janet's reply for the dialogue ``action``.

        Args:
            action: dialogue act to realise, e.g. ``"Help"``, ``"findPaper"``,
                ``"RetGen"``, ``"sumPaper"``, ``"Clarify"`` or ``"command"``.
            utterance: the user's message (the command text when
                ``action == "command"``).
            name: name used to greet the user and passed to the recommender.
            username: user identifier passed to the recommender.
            vrename: name of the VRE, interpolated into listings.
            state: dialogue state with ``'entities'`` (list of
                ``{'entity': ..., 'value': ...}`` dicts) and ``'intent'``.
            consec_history: conversation text used as context by ``"Clarify"``.
            chitchat_history: conversation text used as context by ``"ConvGen"``.

        Returns:
            The reply text. Unknown actions or commands, and ``"Clarify"``
            turns that no branch answers, yield a generic fallback message.

        Lookup actions update ``self.paper`` / ``self.dataset`` / ``self.post``.
        """
        if action == "Help":
            commands = (
                " You can choose between using one of the supported commands to explore the environment"
                " or you can use natural language to find resources and get answers and summaries. \n "
            )
            listofcommands = self.gen_response(action="listCommands")
            return (
                "Hey " + name + "! it's Janet! I am here to help you make use of the datasets and papers"
                " in the catalogue of the " + vrename + " VRE. I can answer questions whose answers may be"
                " inside the papers. I can summarize papers for you. I can also chat with you."
                " So, whichever it is, I am ready to chat!" + commands + listofcommands
            )
        elif action == "Recommend":
            prompt = self.recommender.make_recommendation(username, name)
            if prompt != "":
                return prompt
            return "I can help you with exploiting the contents of the VRE, just let me know!"
        elif action == "OffenseReject":
            return "I am sorry, I cannot answer to this kind of language"
        elif action == "getHelp":
            commands = self.gen_response(action="listCommands")
            return (
                "I can answer questions related to the papers in the VRE's catalogue. I can also get you"
                " the posts, papers and datasets from the catalogue if you specify a topic or an author."
                " I am also capable of small talk and summarizing papers to an extent. Just write to me"
                " what you want in natural language and I will try to do it. Alternatively, you may use"
                " one of the commands Janet supports. " + commands
            )
        elif action == "findPost":
            for entity in state['entities']:
                if entity['entity'] == 'TOPIC':
                    self.post = self._get_matching_topics('post_db', entity['value'])
                    if len(self.post) > 0:
                        return "This is a relevant post: " + self.post['content'] + ' by ' + self.post['author']
                if entity['entity'] == 'AUTHOR':
                    self.post = self._get_matching_authors('post_db', entity['value'], recent=True)
                    if len(self.post) > 0:
                        if len(self.post['tags']) > 0:
                            return ("Here is the most recent post by: " + self.post['author'] + ', which is about '
                                    + ', '.join(self.post['tags']) + '. ' + self.post['content'])
                        return "Here is the most recent post by: " + self.post['author'] + ', ' + self.post['content']
            if len(self.post) > 0:
                # A failed TOPIC/AUTHOR lookup clears self.post, so it is still set
                # only when no such lookup ran: explain the post remembered from an
                # earlier turn with the QA model.
                gen_seq = 'question: ' + utterance + " context: " + self.post['content']
                gen_kwargs = {"length_penalty": 0.5, "num_beams": 2, "max_length": 60, "repetition_penalty": 2.5, "temperature": 2}
                answer = self.generators['qa'](gen_seq, **gen_kwargs)[0]['generated_text']
                if len(self.post['tags']) > 0:
                    return "The post is about: " + answer + " \n There is a special focus on " + ', '.join(self.post['tags'])
                return "The post is about: " + answer
            self.post = self._get_most_recent('post_db')
            if len(self.post) == 0:
                return "I could not find any posts in this VRE."
            return ("This is the most recent post. " + self.post['content']
                    + '\n If you want another post, please rewrite the query specifying either the author or the topic.')
        elif action == "ConvGen":
            gen_kwargs = {"length_penalty": 2.5, "num_beams": 2, "max_length": 30, "repetition_penalty": 2.5, "temperature": 2}
            prompt = ('question: ' + utterance
                      + ' context: My name is Janet. I am an AI developed by CNR to help VRE users. '
                      + (chitchat_history or ''))
            return self.generators['chat'](prompt, **gen_kwargs)[0]['generated_text']
        elif action == "findPaper":
            for entity in state['entities']:
                if entity['entity'] == 'TITLE':
                    self.paper = self._get_matching_titles('paper_db', entity['value'])
                    if len(self.paper) > 0:
                        return self._describe("Here is the paper you want: ", self.paper)
                    # No exact title match: fall back to the closest title embedding.
                    self.paper = self._search_index('paper_titles_index', 'paper_db', entity['value'])
                    return self._describe("This paper could be relevant: ", self.paper)
                if entity['entity'] == 'TOPIC':
                    self.paper = self._get_matching_topics('paper_db', entity['value'])
                    if len(self.paper) > 0:
                        return self._describe("This paper could be relevant: ", self.paper)
                if entity['entity'] == 'AUTHOR':
                    self.paper = self._get_matching_authors('paper_db', entity['value'])
                    if len(self.paper) > 0:
                        return self._describe("Here is the paper you want: ", self.paper)
            self.paper = self._search_index('paper_desc_index', 'paper_db', utterance)
            return self._describe("This paper could be relevant: ", self.paper)
        elif action == "findDataset":
            for entity in state['entities']:
                if entity['entity'] == 'TITLE':
                    self.dataset = self._get_matching_titles('dataset_db', entity['value'])
                    if len(self.dataset) > 0:
                        return self._describe("Here is the dataset you wanted: ", self.dataset)
                    # No exact title match: fall back to the closest title embedding.
                    self.dataset = self._search_index('dataset_titles_index', 'dataset_db', entity['value'])
                    return self._describe("This dataset could be relevant: ", self.dataset)
                if entity['entity'] == 'TOPIC':
                    self.dataset = self._get_matching_topics('dataset_db', entity['value'])
                    if len(self.dataset) > 0:
                        return self._describe("This dataset could be relevant: ", self.dataset)
                if entity['entity'] == 'AUTHOR':
                    self.dataset = self._get_matching_authors('dataset_db', entity['value'])
                    if len(self.dataset) > 0:
                        return self._describe("Here is the dataset you want: ", self.dataset)
            self.dataset = self._search_index('dataset_desc_index', 'dataset_db', utterance)
            return self._describe("This dataset could be relevant: ", self.dataset)
        elif action == "RetGen":
            # Retrieve the most relevant paragraphs, best match first.
            content = self._search_index('content_index', 'content_db', utterance, multi=True)
            evidence = "".join(str(i + 1) + ") " + row['content'] + ' \n ' for i, row in content.iterrows())
            context = "".join(" " + row['content'] for _, row in content.iterrows())
            gen_seq = 'question: ' + utterance + " context: " + context
            gen_kwargs = {"length_penalty": 0.5, "num_beams": 2, "max_length": 60, "repetition_penalty": 2.5, "temperature": 2}
            answer = self.generators['qa'](gen_seq, **gen_kwargs)[0]['generated_text']
            return "According to the following evidence: " + evidence + " \n _______ \n " + "The answer is: " + answer
        elif action == "listPapers":
            items = "".join(
                ' ' + str(j) + ') ' + pap['title'] + ': ' + pap['notes'] + ' \n '
                for j, (_, pap) in enumerate(self.db['paper_db'].iterrows(), start=1)
            )
            return vrename + " has the following papers: \n" + items
        elif action == "listDatasets":
            items = "".join(
                ' ' + str(j) + ') ' + dset['title'] + ': ' + dset['notes'] + ' \n '
                for j, (_, dset) in enumerate(self.db['dataset_db'].iterrows(), start=1)
            )
            return vrename + " has the following datasets: \n" + items
        elif action == "listCommands":
            return (
                "Janet supports the following commands: \n"
                " 1) help : explains how to use Janet. \n"
                " 2) list resources : lists all the papers and datasets in the VRE. \n"
                " 3) list papers : lists all the papers in the VRE. \n"
                " 4) list datasets : lists all the datasets in the VRE. \n"
                " 5) list topics : lists the topics discussed in the VRE. \n"
                " 6) list commands : displays this list of commands. \n"
            )
        elif action == "listTopics":
            counts = {}
            for _, pos in self.db['post_db'].iterrows():
                for tag in pos['tags']:
                    counts[tag] = counts.get(tag, 0) + 1
            if not counts:
                return "There are no topics in " + vrename + " yet. \n "
            # Most frequent first; sorted() stays stable with reverse=True, so
            # equally popular tags keep the order in which they were first seen.
            ranked = sorted(counts, key=counts.get, reverse=True)
            return "The main topics of " + vrename + " ordered by popularity are: " + ', '.join(ranked) + '. \n '
        elif action == "listResources":
            papers = self.gen_response(action="listPapers", vrename=vrename)
            datasets = self.gen_response(action="listDatasets", vrename=vrename)
            return papers + " Also, " + datasets
        elif action == "command":
            if utterance == "help":
                return self.gen_response(action="getHelp", vrename=vrename)
            elif utterance == "list resources":
                return self.gen_response(action="listResources", vrename=vrename)
            elif utterance == "list papers":
                return self.gen_response(action="listPapers", vrename=vrename)
            elif utterance == "list datasets":
                return self.gen_response(action="listDatasets", vrename=vrename)
            elif utterance == "list topics":
                return self.gen_response(action="listTopics", vrename=vrename)
            elif utterance == "list commands":
                return self.gen_response(action="listCommands")
        elif action == "sumPaper":
            # A named paper (exact TITLE first, else best TOPIC match) replaces the
            # remembered one; with no entities the remembered paper is summarized.
            for entity in state['entities']:
                if entity['entity'] == 'TITLE':
                    paper = self._get_matching_titles('paper_db', entity['value'])
                    if len(paper) > 0:
                        self.paper = paper
                        break
                if entity['entity'] == 'TOPIC':
                    cand_paper = self._get_matching_topics('paper_db', entity['value'])
                    if len(cand_paper) > 0:
                        self.paper = cand_paper
            if len(self.paper) == 0:
                return "I cannot seem to find the requested paper. Try again by specifying the title of the paper."
            content_db = self.db['content_db']
            paragraphs = content_db[content_db['paperid'] == self.paper['id']]
            gen_kwargs = {"length_penalty": 1.5, "num_beams": 6, "max_length": 30, "repetition_penalty": 2.5, "temperature": 2}
            # Summarize paragraph by paragraph and concatenate the summaries.
            return "".join(
                self.generators['summ']('summarize: ' + row['content'], **gen_kwargs)[0]['generated_text'] + ' '
                for _, row in paragraphs.iterrows()
            )
        elif action == "Clarify":
            if state['intent'] in ['FINDPAPER', 'SUMMARIZEPAPER'] and len(state['entities']) == 0:
                if len(self.paper) == 0:
                    return 'Please specify the title or the topic of the paper of interest.'
                # NOTE(review): with a remembered paper this falls through to the
                # generic fallback below — confirm that is intended.
            elif state['intent'] == 'FINDDATASET' and len(state['entities']) == 0:
                if len(self.dataset) == 0:
                    return 'Please specify the title or the topic of the dataset of interest.'
            elif state['intent'] == 'EXPLAINPOST' and len(state['entities']) == 0:
                if len(self.post) != 0:
                    return self.gen_response(action="findPost", utterance=utterance, username=username, state=state, consec_history=consec_history)
                return 'Please specify the topic or the author of the post.'
            else:
                # Ask the ambiguity model to phrase a clarifying question.
                gen_kwargs = {"length_penalty": 2.5, "num_beams": 8, "max_length": 120, "repetition_penalty": 2.5, "temperature": 2}
                prompt = 'question: ' + utterance + ' context: ' + (consec_history or '')
                return self.generators['amb'](prompt, **gen_kwargs)[0]['generated_text']
        return "I am unable to generate the response. Can you please provide me with a preferred response in the feedback form so I can learn?"