This commit is contained in:
ahmed531998 2023-04-21 07:30:33 +02:00
parent 6a0c033b79
commit dd24d723bd
2 changed files with 18 additions and 5 deletions

View File

@ -84,15 +84,14 @@ class ResponseGenerator:
def _search_index(self, index_type, db_type, query, multi=False): def _search_index(self, index_type, db_type, query, multi=False):
self.index[index_type].add_faiss_index(column="embeddings") self.index[index_type].add_faiss_index(column="embeddings")
scores, samples = self.index[index_type].get_nearest_examples( scores, samples = self.index[index_type].get_nearest_examples(
"embeddings", retriever.encode([query]), k=3 "embeddings", self.retriever.encode([query]), k=self.num_retrieved
) )
samples_df = pd.DataFrame.from_dict(samples) samples_df = pd.DataFrame.from_dict(samples)
samples_df["scores"] = scores samples_df["scores"] = scores
samples_df.sort_values("scores", ascending=False, inplace=True) samples_df.sort_values("scores", ascending=False, inplace=True)
if multi: if multi:
return samples_df.reset_index(drop=True) return samples_df.reset_index(drop=True)
return samples_df.iloc[0].reset_index(drop=True) return samples_df.reset_index(drop=True).iloc[0]
def gen_response(self, action, utterance=None, name=None, username=None, vrename=None, state=None, consec_history=None, chitchat_history=None): def gen_response(self, action, utterance=None, name=None, username=None, vrename=None, state=None, consec_history=None, chitchat_history=None):
@ -275,8 +274,11 @@ class ResponseGenerator:
self.paper = paper self.paper = paper
break break
if (entity['entity'] == 'TOPIC'): if (entity['entity'] == 'TOPIC'):
self.paper = self._get_matching_topics('paper_db', entity['value']) cand_paper = self._get_matching_topics('paper_db', entity['value'])
if len(cand_paper) > 0:
self.paper = cand_paper
if len(self.paper) == 0: if len(self.paper) == 0:
print(self.paper)
return "I cannot seem to find the requested paper. Try again by specifying the title of the paper." return "I cannot seem to find the requested paper. Try again by specifying the title of the paper."
#implement that #implement that
df = self.db['content_db'][self.db['content_db']['paperid'] == self.paper['id']] df = self.db['content_db'][self.db['content_db']['paperid'] == self.paper['id']]

13
main.py
View File

@ -209,7 +209,18 @@ if __name__ == "__main__":
index = vre.get_index() index = vre.get_index()
db = vre.get_db() db = vre.get_db()
rg = ResponseGenerator(index,db, rec, generators, retriever) rg = ResponseGenerator(index,db, rec, generators, retriever)
del retriever
del generators
del qa_generator
del chat_generator
del summ_generator
del amb_generator
del query_rewriter
del intent_classifier
del entity_extractor
del offensive_classifier
del ambig_classifier
del coref_resolver
threading.Thread(target=vre_fetch, name='updatevre').start() threading.Thread(target=vre_fetch, name='updatevre').start()
threading.Thread(target=clear_inactive, name='clear').start() threading.Thread(target=clear_inactive, name='clear').start()