janet
This commit is contained in:
parent
6a0c033b79
commit
dd24d723bd
|
@ -84,15 +84,14 @@ class ResponseGenerator:
|
|||
def _search_index(self, index_type, db_type, query, multi=False):
|
||||
self.index[index_type].add_faiss_index(column="embeddings")
|
||||
scores, samples = self.index[index_type].get_nearest_examples(
|
||||
"embeddings", retriever.encode([query]), k=3
|
||||
"embeddings", self.retriever.encode([query]), k=self.num_retrieved
|
||||
)
|
||||
samples_df = pd.DataFrame.from_dict(samples)
|
||||
samples_df["scores"] = scores
|
||||
samples_df.sort_values("scores", ascending=False, inplace=True)
|
||||
|
||||
if multi:
|
||||
return samples_df.reset_index(drop=True)
|
||||
return samples_df.iloc[0].reset_index(drop=True)
|
||||
return samples_df.reset_index(drop=True).iloc[0]
|
||||
|
||||
|
||||
def gen_response(self, action, utterance=None, name=None, username=None, vrename=None, state=None, consec_history=None, chitchat_history=None):
|
||||
|
@ -275,8 +274,11 @@ class ResponseGenerator:
|
|||
self.paper = paper
|
||||
break
|
||||
if (entity['entity'] == 'TOPIC'):
|
||||
self.paper = self._get_matching_topics('paper_db', entity['value'])
|
||||
cand_paper = self._get_matching_topics('paper_db', entity['value'])
|
||||
if len(cand_paper) > 0:
|
||||
self.paper = cand_paper
|
||||
if len(self.paper) == 0:
|
||||
print(self.paper)
|
||||
return "I cannot seem to find the requested paper. Try again by specifying the title of the paper."
|
||||
#implement that
|
||||
df = self.db['content_db'][self.db['content_db']['paperid'] == self.paper['id']]
|
||||
|
|
13
main.py
13
main.py
|
@ -209,7 +209,18 @@ if __name__ == "__main__":
|
|||
index = vre.get_index()
|
||||
db = vre.get_db()
|
||||
rg = ResponseGenerator(index,db, rec, generators, retriever)
|
||||
|
||||
del retriever
|
||||
del generators
|
||||
del qa_generator
|
||||
del chat_generator
|
||||
del summ_generator
|
||||
del amb_generator
|
||||
del query_rewriter
|
||||
del intent_classifier
|
||||
del entity_extractor
|
||||
del offensive_classifier
|
||||
del ambig_classifier
|
||||
del coref_resolver
|
||||
threading.Thread(target=vre_fetch, name='updatevre').start()
|
||||
threading.Thread(target=clear_inactive, name='clear').start()
|
||||
|
||||
|
|
Loading…
Reference in New Issue