JanetBackEnd/ResponseGenerator.py

144 lines
6.2 KiB
Python

from sentence_transformers import models, SentenceTransformer
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import faiss
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import pandas as pd
class ResponseGenerator:
    """Turn a dialogue action into a natural-language reply.

    Replies are built from exact/tag matching over pandas tables, nearest
    neighbour search over vector indices, and a text-generation callable.

    The most recently found paper and dataset rows are remembered on the
    instance (``self.paper`` / ``self.dataset``); the ``sumPaper`` action
    reuses the remembered paper when one is available.
    """

    # A tag must have cosine similarity strictly above this value to count
    # as a topic match in _get_matching_topics.
    TOPIC_MATCH_THRESHOLD = 0.7

    def __init__(self, index, db,
                 generator, retriever, num_retrieved=1):
        """
        Args:
            index: mapping from index name (e.g. 'paper_titles_index',
                'content_index') to an object whose ``search(xq, k)``
                returns ``(distances, ids)`` -- presumably faiss indices.
            db: mapping from table name ('paper_db', 'dataset_db',
                'content_db') to a pandas DataFrame.
            generator: callable ``generator(text, **gen_kwargs)`` returning a
                list whose first element has a 'generated_text' key
                (presumably a transformers pipeline -- confirm at call site).
            retriever: object whose ``encode(list_of_str)`` returns one
                embedding per string (presumably a SentenceTransformer).
            num_retrieved: number of neighbours requested per index search;
                only the top-ranked hit is used.
        """
        self.generator = generator
        self.retriever = retriever
        self.db = db
        self.index = index
        self.num_retrieved = num_retrieved
        # Last paper / dataset row found; an empty container means "none".
        self.paper = {}
        self.dataset = {}

    def update_index(self, index):
        """Replace the mapping of search indices."""
        self.index = index

    def update_db(self, db):
        """Replace the mapping of DataFrames."""
        self.db = db

    def _get_resources_links(self, item):
        """Return the 'url' of every entry in ``item['resources']``.

        ``item`` is a row (pandas Series) or an empty container meaning
        "no match"; an empty item yields ``[]``. ``len`` is used instead of
        truthiness because a Series has no unambiguous truth value.
        """
        if len(item) == 0:
            return []
        return [rsrc['url'] for rsrc in item['resources']]

    def _get_matching_titles(self, rsrc, title):
        """Return the first row of ``self.db[rsrc]`` whose 'title' equals
        ``title.lower()``, or ``{}`` when there is none.

        Stored titles are compared as-is, so this presumably expects them to
        be lower-cased already -- TODO confirm against the data loader.
        """
        table = self.db[rsrc]
        matches = table.loc[table['title'] == title.lower()]
        if matches.empty:
            return {}
        return matches.reset_index(drop=True).loc[0]

    def _get_matching_topics(self, rsrc, topic):
        """Return the row of ``self.db[rsrc]`` owning the tag most similar to
        ``topic``, or ``[]`` if no tag beats TOPIC_MATCH_THRESHOLD.

        Similarity is the cosine similarity between the retriever embedding
        of each tag and of ``topic.lower()``. A candidate must strictly
        exceed the best score so far, so on ties the earlier tag/row wins.
        """
        # The topic embedding does not depend on the row: compute it once.
        topic_vec = np.array(self.retriever.encode([topic.lower()]))
        best_row = None
        best_score = self.TOPIC_MATCH_THRESHOLD
        for _, cand in self.db[rsrc].iterrows():
            for tag in cand['tags']:
                tag_vec = np.array(self.retriever.encode([tag]))
                sim = float(cosine_similarity(tag_vec, topic_vec)[0][0])
                if sim > best_score:
                    best_row = cand
                    best_score = sim
        return best_row if best_row is not None else []

    def _search_index(self, index_type, db_type, query):
        """Embed ``query``, search ``self.index[index_type]`` and return the
        top-ranked row of ``self.db[db_type]`` as a Series.

        NOTE(review): if these are faiss indices, missing neighbours are
        reported as id -1, which ``iloc`` would silently map to the last
        row. Only relevant when an index holds fewer vectors than
        ``num_retrieved``.
        """
        xq = self.retriever.encode([query])
        _, ids = self.index[index_type].search(xq, self.num_retrieved)
        # ids has one row per query vector; we sent a single query.
        hits = self.db[db_type].iloc[ids[0]]
        return hits.reset_index(drop=True).loc[0]

    def _format_suggestion(self, kind, item):
        """Build the "This <kind> could be helpful: ..." reply for a row.

        The download sentence uses the row's first resource link and is
        omitted when the row has no links (instead of raising IndexError).
        """
        links = self._get_resources_links(item)
        message = "This " + kind + " could be helpful: " + item['title'] + '.'
        if links:
            message += " It can be downloaded at " + links[0]
        return message

    def _find_item(self, db_type, titles_index, desc_index, utterance, state):
        """Pick the row of ``self.db[db_type]`` that best answers a request.

        Entities in ``state['entities']`` are scanned in order:
        * the first TITLE entity settles the result: an exact title match
          with at least one resource link wins, otherwise the title's
          nearest neighbour in ``titles_index`` is returned;
        * a TOPIC entity is accepted only if a tag-similarity match with at
          least one resource link exists; otherwise scanning continues.
        If no entity settles it, ``utterance`` is searched in ``desc_index``.
        """
        for entity in state['entities']:
            if entity['entity'] == 'TITLE':
                item = self._get_matching_titles(db_type, entity['value'])
                # An empty (no-match) item has no links, so this covers both.
                if self._get_resources_links(item):
                    return item
                return self._search_index(titles_index, db_type, entity['value'])
            if entity['entity'] == 'TOPIC':
                item = self._get_matching_topics(db_type, entity['value'])
                if self._get_resources_links(item):
                    return item
        return self._search_index(desc_index, db_type, utterance)

    def _summarize_paper(self, state):
        """Summarise the remembered paper paragraph by paragraph.

        When no paper is remembered, the first TITLE entity with an exact
        title match is used. Returns the per-paragraph summaries joined by
        single spaces, or an apology when no paper can be identified.
        """
        if len(self.paper) == 0:
            for entity in state['entities']:
                if entity['entity'] == 'TITLE':
                    self.paper = self._get_matching_titles('paper_db', entity['value'])
                    if len(self.paper) > 0:
                        break
        if len(self.paper) == 0:
            return ("I cannot seem to find the requested paper. "
                    "Try again by specifying the title of the paper.")
        content = self.db['content_db']
        paragraphs = content.loc[content['paperid'] == self.paper['id'], 'content']
        gen_kwargs = {"length_penalty": 1.5, "num_beams": 8, "max_length": 100}
        # Keep every paragraph's summary, not just the last one.
        summaries = [
            self.generator('summarize: ' + paragraph, **gen_kwargs)[0]['generated_text']
            for paragraph in paragraphs
        ]
        return ' '.join(summaries)

    def gen_response(self, utterance, state, history, action):
        """Produce the reply text for ``action``.

        Args:
            utterance: the user's current message.
            state: dialogue state; ``state['entities']`` is a list of dicts
                with 'entity' (e.g. 'TITLE', 'TOPIC') and 'value' keys.
            history: conversation history string, concatenated into the
                generation prompt for 'ConvGen'.
            action: one of 'NoCanDo', 'ConvGen', 'findPaper', 'findDataset',
                'RetGen', 'sumPaper', 'Clarify'.

        Returns:
            The reply string, or None for an unrecognised action.
        """
        if action == "NoCanDo":
            return "I am sorry, I cannot answer to this kind of language"
        if action == "ConvGen":
            gen_kwargs = {"length_penalty": 2.5, "num_beams": 4, "max_length": 20}
            prompt = 'question: ' + utterance + ' context: ' + history
            return self.generator(prompt, **gen_kwargs)[0]['generated_text']
        if action == "findPaper":
            self.paper = self._find_item('paper_db', 'paper_titles_index',
                                         'paper_desc_index', utterance, state)
            return self._format_suggestion('paper', self.paper)
        if action == "findDataset":
            self.dataset = self._find_item('dataset_db', 'dataset_titles_index',
                                           'dataset_desc_index', utterance, state)
            return self._format_suggestion('dataset', self.dataset)
        if action == "RetGen":
            # Retrieve the most relevant paragraph, then answer from it.
            content = str(self._search_index('content_index', 'content_db', utterance)['content'])
            gen_seq = 'question: ' + utterance + " context: " + content
            # TODO(review): an old note said "handle return random 2 answers";
            # presumably returning two sampled answers -- not implemented.
            gen_kwargs = {"length_penalty": 0.5, "num_beams": 8, "max_length": 100}
            answer = self.generator(gen_seq, **gen_kwargs)[0]['generated_text']
            return str(answer)
        if action == "sumPaper":
            return self._summarize_paper(state)
        if action == "Clarify":
            return "Can you please clarify?"
        return None