janet
This commit is contained in:
parent
c4839cc743
commit
87037105d7
|
@ -84,15 +84,14 @@ class ResponseGenerator:
|
||||||
def _search_index(self, index_type, db_type, query, multi=False):
|
def _search_index(self, index_type, db_type, query, multi=False):
|
||||||
self.index[index_type].add_faiss_index(column="embeddings")
|
self.index[index_type].add_faiss_index(column="embeddings")
|
||||||
scores, samples = self.index[index_type].get_nearest_examples(
|
scores, samples = self.index[index_type].get_nearest_examples(
|
||||||
"embeddings", retriever.encode([query]), k=3
|
"embeddings", self.retriever.encode([query]), k=self.num_retrieved
|
||||||
)
|
)
|
||||||
samples_df = pd.DataFrame.from_dict(samples)
|
samples_df = pd.DataFrame.from_dict(samples)
|
||||||
samples_df["scores"] = scores
|
samples_df["scores"] = scores
|
||||||
samples_df.sort_values("scores", ascending=False, inplace=True)
|
samples_df.sort_values("scores", ascending=False, inplace=True)
|
||||||
|
|
||||||
if multi:
|
if multi:
|
||||||
return samples_df.reset_index(drop=True)
|
return samples_df.reset_index(drop=True)
|
||||||
return samples_df.iloc[0].reset_index(drop=True)
|
return samples_df.reset_index(drop=True).iloc[0]
|
||||||
|
|
||||||
|
|
||||||
def gen_response(self, action, utterance=None, name=None, username=None, vrename=None, state=None, consec_history=None, chitchat_history=None):
|
def gen_response(self, action, utterance=None, name=None, username=None, vrename=None, state=None, consec_history=None, chitchat_history=None):
|
||||||
|
@ -275,8 +274,11 @@ class ResponseGenerator:
|
||||||
self.paper = paper
|
self.paper = paper
|
||||||
break
|
break
|
||||||
if (entity['entity'] == 'TOPIC'):
|
if (entity['entity'] == 'TOPIC'):
|
||||||
self.paper = self._get_matching_topics('paper_db', entity['value'])
|
cand_paper = self._get_matching_topics('paper_db', entity['value'])
|
||||||
|
if len(cand_paper) > 0:
|
||||||
|
self.paper = cand_paper
|
||||||
if len(self.paper) == 0:
|
if len(self.paper) == 0:
|
||||||
|
print(self.paper)
|
||||||
return "I cannot seem to find the requested paper. Try again by specifying the title of the paper."
|
return "I cannot seem to find the requested paper. Try again by specifying the title of the paper."
|
||||||
#implement that
|
#implement that
|
||||||
df = self.db['content_db'][self.db['content_db']['paperid'] == self.paper['id']]
|
df = self.db['content_db'][self.db['content_db']['paperid'] == self.paper['id']]
|
||||||
|
|
13
main.py
13
main.py
|
@ -209,7 +209,18 @@ if __name__ == "__main__":
|
||||||
index = vre.get_index()
|
index = vre.get_index()
|
||||||
db = vre.get_db()
|
db = vre.get_db()
|
||||||
rg = ResponseGenerator(index,db, rec, generators, retriever)
|
rg = ResponseGenerator(index,db, rec, generators, retriever)
|
||||||
|
del retriever
|
||||||
|
del generators
|
||||||
|
del qa_generator
|
||||||
|
del chat_generator
|
||||||
|
del summ_generator
|
||||||
|
del amb_generator
|
||||||
|
del query_rewriter
|
||||||
|
del intent_classifier
|
||||||
|
del entity_extractor
|
||||||
|
del offensive_classifier
|
||||||
|
del ambig_classifier
|
||||||
|
del coref_resolver
|
||||||
threading.Thread(target=vre_fetch, name='updatevre').start()
|
threading.Thread(target=vre_fetch, name='updatevre').start()
|
||||||
threading.Thread(target=clear_inactive, name='clear').start()
|
threading.Thread(target=clear_inactive, name='clear').start()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue