181 lines
9.3 KiB
Python
181 lines
9.3 KiB
Python
from sentence_transformers import models, SentenceTransformer
|
|
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
|
|
import faiss
|
|
from sklearn.metrics.pairwise import cosine_similarity
|
|
import numpy as np
|
|
import pandas as pd
|
|
from datetime import datetime
|
|
|
|
class ResponseGenerator:
    """Turn a dialogue action into Janet's textual reply.

    Holds the lookup tables (``db``), vector indexes (``index``), text
    generators and the sentence encoder used to answer VRE users, and
    remembers the last paper / dataset it surfaced so that follow-up
    actions (e.g. "sumPaper") can refer back to it.
    """

    def __init__(self, index, db, recommender, generators, retriever, num_retrieved=1):
        """Store the collaborators used by ``gen_response``.

        Args:
            index: mapping from index name (e.g. 'paper_titles_index',
                'content_index') to an object whose ``search(xq, k)``
                returns ``(distances, ids)``.
            db: mapping from table name ('paper_db', 'dataset_db',
                'content_db') to a pandas DataFrame.
            recommender: object exposing ``make_recommendation(username)``.
            generators: mapping with keys 'chat', 'qa', 'summ' and 'amb';
                each value is called as ``gen(text, **kwargs)`` and must
                return ``[{'generated_text': ...}, ...]``.
            retriever: encoder whose ``encode(list_of_str)`` produces the
                query vectors fed to the indexes and to topic matching.
            num_retrieved: number of neighbours requested from an index;
                only the top one is used.
        """
        self.generators = generators
        self.retriever = retriever
        self.recommender = recommender
        self.db = db
        self.index = index
        self.num_retrieved = num_retrieved
        # Last resources surfaced to the user; "sumPaper" reuses self.paper.
        self.paper = {}
        self.dataset = {}

    def update_index(self, index):
        """Replace the mapping of search indexes."""
        self.index = index

    def update_db(self, db):
        """Replace the mapping of resource tables."""
        self.db = db

    # ------------------------------------------------------------------
    # Lookup helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _entities(state):
        """Return ``state['entities']``, treating a missing state (None) as having none."""
        return state['entities'] if state is not None else []

    def _get_resources_links(self, item):
        """Return the download URLs listed in ``item['resources']``.

        ``item`` is a table row or an empty container (meaning "no match");
        an empty item yields an empty list.
        """
        if len(item) == 0:
            return []
        return [rsrc['url'] for rsrc in item['resources']]

    def _first_match(self, rsrc, column, value):
        """Return the first row of ``self.db[rsrc]`` whose *column* equals ``value.lower()``.

        Returns an empty dict when nothing matches. The comparison is exact
        against the stored value, so the table is presumably stored
        lower-cased — verify against the loader.
        """
        table = self.db[rsrc]
        matches = table[table[column] == value.lower()]
        if matches.empty:
            return {}
        return matches.iloc[0]

    def _get_matching_titles(self, rsrc, title):
        """Return the first row of ``self.db[rsrc]`` with this (lower-cased) title, else {}."""
        return self._first_match(rsrc, 'title', title)

    def _get_matching_authors(self, rsrc, author):
        """Return the first row of ``self.db[rsrc]`` with this (lower-cased) author, else {}."""
        return self._first_match(rsrc, 'author', author)

    def _get_matching_topics(self, rsrc, topic, threshold=0.7):
        """Return the row whose tag is most similar to *topic*, or [] if none beats *threshold*.

        Similarity is the cosine between the retriever embeddings of each tag
        in the row's 'tags' column and of the lower-cased topic. A row only
        replaces the current best when it is strictly more similar.
        """
        # The topic embedding does not depend on the row/tag: encode it once.
        topic_vec = np.array(self.retriever.encode([topic.lower()]))
        best, best_score = [], threshold
        for _, cand in self.db[rsrc].iterrows():
            for tag in cand['tags']:
                tag_vec = np.array(self.retriever.encode([tag]))
                sim = float(cosine_similarity(tag_vec, topic_vec)[0][0])
                if sim > best_score:
                    best, best_score = cand, sim
        return best

    def _search_index(self, index_type, db_type, query):
        """Return the row of ``self.db[db_type]`` nearest to *query* in ``self.index[index_type]``."""
        xq = self.retriever.encode([query])
        _, ids = self.index[index_type].search(xq, self.num_retrieved)
        # ids holds one row of neighbour positions per query; take the top
        # hit for our single query. NOTE(review): faiss reports a missing
        # neighbour as -1, which iloc would silently map to the last row.
        return self.db[db_type].iloc[int(ids[0][0])]

    def _describe(self, prefix, item):
        """Format ``'<prefix><title>. It can be downloaded at <first url>'`` for *item*.

        When the item lists no resources, the download sentence is omitted
        instead of failing with IndexError.
        """
        links = self._get_resources_links(item)
        message = prefix + str(item['title']) + '.'
        if links:
            return message + ' ' + "It can be downloaded at " + links[0]
        return message

    def _find_resource(self, entities, utterance, db_type, titles_index, desc_index,
                       title_prefix, author_prefix, relevant_prefix):
        """Resolve a paper/dataset request and return ``(item, message)``.

        Entities are tried in order:
          * TITLE  - an exact title match that has a download link; otherwise
                     the nearest title in *titles_index*. Either way the
                     search stops at the first TITLE entity.
          * TOPIC  - the best tag match, accepted only if it has a link.
          * AUTHOR - an exact author match, accepted only if it has a link.
        If nothing is accepted, *utterance* is searched in *desc_index*.
        """
        for entity in entities:
            label, value = entity['entity'], entity['value']
            if label == 'TITLE':
                item = self._get_matching_titles(db_type, value)
                if len(item) > 0 and self._get_resources_links(item):
                    return item, self._describe(title_prefix, item)
                item = self._search_index(titles_index, db_type, value)
                return item, self._describe(relevant_prefix, item)
            if label == 'TOPIC':
                item = self._get_matching_topics(db_type, value)
                if len(item) > 0 and self._get_resources_links(item):
                    return item, self._describe(relevant_prefix, item)
            if label == 'AUTHOR':
                item = self._get_matching_authors(db_type, value)
                if len(item) > 0 and self._get_resources_links(item):
                    return item, self._describe(author_prefix, item)
        item = self._search_index(desc_index, db_type, utterance)
        return item, self._describe(relevant_prefix, item)

    def _summarize_paper(self, state):
        """Summarize every paragraph of the current (or TITLE-named) paper.

        When no paper is remembered, the first TITLE entity with an exact
        match is used. Each paragraph in 'content_db' belonging to the paper
        is summarized separately, and the summaries are joined with spaces.
        """
        if len(self.paper) == 0:
            for entity in self._entities(state):
                if entity['entity'] == 'TITLE':
                    self.paper = self._get_matching_titles('paper_db', entity['value'])
                    if len(self.paper) > 0:
                        break
        if len(self.paper) == 0:
            return "I cannot seem to find the requested paper. Try again by specifying the title of the paper."

        content = self.db['content_db']
        paragraphs = content[content['paperid'] == self.paper['id']]['content']
        gen_kwargs = {"length_penalty": 1.5, "num_beams": 6, "max_length": 120}
        # Accumulate one summary per paragraph (previously each iteration
        # overwrote the last, so only the final paragraph was returned).
        summaries = [
            self.generators['summ']('summarize: ' + paragraph, **gen_kwargs)[0]['generated_text']
            for paragraph in paragraphs
        ]
        return ' '.join(summaries)

    # ------------------------------------------------------------------
    # Public entry point
    # ------------------------------------------------------------------

    def gen_response(self, action, utterance=None, username=None, state=None, consec_history=None):
        """Return Janet's reply for the dialogue *action*.

        Args:
            action: one of "Help", "Recommend", "OffenseReject", "ConvGen",
                "findPaper", "findDataset", "RetGen", "sumPaper",
                "ClarifyResource", "GenClarify".
            utterance: the user's latest message.
            username: user passed to the recommender ("Recommend").
            state: dialogue state; 'entities' (list of {'entity', 'value'}
                dicts) and 'intent' are read.
            consec_history: prior conversation text used as model context.

        Returns:
            The reply string, or None for an unrecognised action.
        """
        if action == "Help":
            return "Hey it's Janet! I am here to help you make use of the datasets and papers in the VRE. I can answer questions whose answers may be inside the papers. I can summarize papers for you. I can also chat with you. So, whichever it is, I am ready to chat!"

        if action == "Recommend":
            prompt = self.recommender.make_recommendation(username)
            if prompt != "":
                return prompt
            return "I can help you with exploiting the contents of the VRE, just let me know!"

        if action == "OffenseReject":
            return "I am sorry, I cannot answer to this kind of language"

        if action == "ConvGen":
            gen_kwargs = {"length_penalty": 2.5, "num_beams": 2, "max_length": 30}
            # None history/utterance (the defaults) are treated as empty text.
            prompt = ('history: ' + (consec_history or '') + ' ' + (utterance or '')
                      + ' persona: ' + 'I am Janet. My name is Janet. I am an AI developed by CNR to help VRE users.')
            return self.generators['chat'](prompt, **gen_kwargs)[0]['generated_text']

        if action == "findPaper":
            self.paper, message = self._find_resource(
                self._entities(state), utterance, 'paper_db',
                'paper_titles_index', 'paper_desc_index',
                title_prefix="Here is the paper you want: ",
                author_prefix="Here is the paper you want: ",
                relevant_prefix="This paper could be relevant: ",
            )
            return message

        if action == "findDataset":
            self.dataset, message = self._find_resource(
                self._entities(state), utterance, 'dataset_db',
                'dataset_titles_index', 'dataset_desc_index',
                title_prefix="Here is the dataset you wanted: ",
                author_prefix="Here is the dataset you want: ",
                relevant_prefix="This dataset could be relevant: ",
            )
            return message

        if action == "RetGen":
            # Retrieve the most relevant paragraph, then let the QA model
            # answer the question from it.
            context = str(self._search_index('content_index', 'content_db', utterance)['content'])
            gen_seq = 'question: ' + utterance + " context: " + context
            # TODO (carried over from the original note): "handle return random 2 answers".
            gen_kwargs = {"length_penalty": 0.5, "num_beams": 2, "max_length": 60}
            answer = self.generators['qa'](gen_seq, **gen_kwargs)[0]['generated_text']
            return str(answer)

        if action == "sumPaper":
            return self._summarize_paper(state)

        if action == "ClarifyResource":
            if state['intent'] in ['FINDPAPER', 'SUMMARIZEPAPER']:
                return 'Please specify the title, the topic or the paper of interest.'
            return 'Please specify the title, the topic or the dataset of interest.'

        if action == "GenClarify":
            gen_kwargs = {"length_penalty": 2.5, "num_beams": 8, "max_length": 120}
            prompt = 'question: ' + (utterance or '') + ' context: ' + (consec_history or '')
            return self.generators['amb'](prompt, **gen_kwargs)[0]['generated_text']

        return None
|