This commit is contained in:
ahmed531998 2023-04-21 07:30:33 +02:00
parent 6a0c033b79
commit dd24d723bd
2 changed files with 18 additions and 5 deletions

View File

@ -84,15 +84,14 @@ class ResponseGenerator:
def _search_index(self, index_type, db_type, query, multi=False):
self.index[index_type].add_faiss_index(column="embeddings")
scores, samples = self.index[index_type].get_nearest_examples(
"embeddings", retriever.encode([query]), k=3
"embeddings", self.retriever.encode([query]), k=self.num_retrieved
)
samples_df = pd.DataFrame.from_dict(samples)
samples_df["scores"] = scores
samples_df.sort_values("scores", ascending=False, inplace=True)
if multi:
return samples_df.reset_index(drop=True)
return samples_df.iloc[0].reset_index(drop=True)
return samples_df.reset_index(drop=True).iloc[0]
def gen_response(self, action, utterance=None, name=None, username=None, vrename=None, state=None, consec_history=None, chitchat_history=None):
@ -275,8 +274,11 @@ class ResponseGenerator:
self.paper = paper
break
if (entity['entity'] == 'TOPIC'):
self.paper = self._get_matching_topics('paper_db', entity['value'])
cand_paper = self._get_matching_topics('paper_db', entity['value'])
if len(cand_paper) > 0:
self.paper = cand_paper
if len(self.paper) == 0:
print(self.paper)
return "I cannot seem to find the requested paper. Try again by specifying the title of the paper."
#implement that
df = self.db['content_db'][self.db['content_db']['paperid'] == self.paper['id']]

13
main.py
View File

@ -209,7 +209,18 @@ if __name__ == "__main__":
index = vre.get_index()
db = vre.get_db()
rg = ResponseGenerator(index,db, rec, generators, retriever)
del retriever
del generators
del qa_generator
del chat_generator
del summ_generator
del amb_generator
del query_rewriter
del intent_classifier
del entity_extractor
del offensive_classifier
del ambig_classifier
del coref_resolver
threading.Thread(target=vre_fetch, name='updatevre').start()
threading.Thread(target=clear_inactive, name='clear').start()