From dd24d723bd9e7088591fb038a89a1eab8f764771 Mon Sep 17 00:00:00 2001 From: ahmed531998 Date: Fri, 21 Apr 2023 07:30:33 +0200 Subject: [PATCH] janet --- ResponseGenerator.py | 10 ++++++---- main.py | 13 ++++++++++++- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/ResponseGenerator.py b/ResponseGenerator.py index 452c423..3fcd7b4 100644 --- a/ResponseGenerator.py +++ b/ResponseGenerator.py @@ -84,15 +84,14 @@ class ResponseGenerator: def _search_index(self, index_type, db_type, query, multi=False): self.index[index_type].add_faiss_index(column="embeddings") scores, samples = self.index[index_type].get_nearest_examples( - "embeddings", retriever.encode([query]), k=3 + "embeddings", self.retriever.encode([query]), k=self.num_retrieved ) samples_df = pd.DataFrame.from_dict(samples) samples_df["scores"] = scores samples_df.sort_values("scores", ascending=False, inplace=True) - if multi: return samples_df.reset_index(drop=True) - return samples_df.iloc[0].reset_index(drop=True) + return samples_df.reset_index(drop=True).iloc[0] def gen_response(self, action, utterance=None, name=None, username=None, vrename=None, state=None, consec_history=None, chitchat_history=None): @@ -275,8 +274,11 @@ class ResponseGenerator: self.paper = paper break if (entity['entity'] == 'TOPIC'): - self.paper = self._get_matching_topics('paper_db', entity['value']) + cand_paper = self._get_matching_topics('paper_db', entity['value']) + if len(cand_paper) > 0: + self.paper = cand_paper if len(self.paper) == 0: + print(self.paper) return "I cannot seem to find the requested paper. Try again by specifying the title of the paper." #implement that df = self.db['content_db'][self.db['content_db']['paperid'] == self.paper['id']] diff --git a/main.py b/main.py index cfe307c..21e93b7 100644 --- a/main.py +++ b/main.py @@ -209,7 +209,18 @@ if __name__ == "__main__": index = vre.get_index() db = vre.get_db() rg = ResponseGenerator(index,db, rec, generators, retriever) - + del retriever + del generators + del qa_generator + del chat_generator + del summ_generator + del amb_generator + del query_rewriter + del intent_classifier + del entity_extractor + del offensive_classifier + del ambig_classifier + del coref_resolver threading.Thread(target=vre_fetch, name='updatevre').start() threading.Thread(target=clear_inactive, name='clear').start()