This commit is contained in:
ahmed531998 2023-04-16 20:30:32 +02:00
parent 8bfa2e49fe
commit 615e9f191e
1 changed files with 13 additions and 3 deletions

View File

@ -53,6 +53,15 @@ class ResponseGenerator:
return cand.loc[0]
else:
return {}
def _get_most_recent(self, rsrc):
cand = self.db[rsrc]
index = 0
curr = 0
for i, row in cand.iterrows():
if row['time'] > curr:
index = i
curr = row['time']
return cand.loc[index]
def _get_matching_topics(self, rsrc, topic):
matches = []
@ -104,7 +113,7 @@ class ResponseGenerator:
self.post = self._get_matching_authors('post_db', entity['value'], recent=True)
if len(self.post) > 0:
if len(self.post['tags']) > 0:
return str("Here is the most recent post by: " + self.post['author'] + ', which is about ' + ', '.join(self.post['tags']) + self.post['content'])
return str("Here is the most recent post by: " + self.post['author'] + ', which is about ' + ', '.join(self.post['tags']) + '. ' + self.post['content'])
else:
return str("Here is the most recent post by: " + self.post['author'] + ', ' + self.post['content'])
if len(self.post) > 0:
@ -117,7 +126,8 @@ class ResponseGenerator:
return "The post is about: " + answer + " \n There is a special focus on " + ', '.join(self.post['tags'])
else:
return "The post is about: " + answer
return "I could not find the post you are looking for."
self.post = self._get_most_recent('post_db')
return "This is the most recent post. " + self.post['content'] + '\n If you want another post, please rewrite the query specifying either the author or the topic.'
elif action == "ConvGen":
gen_kwargs = {"length_penalty": 2.5, "num_beams":2, "max_length": 30, "repetition_penalty": 2.5, "temperature": 2}