JanetBackEnd/DM.py

57 lines
2.0 KiB
Python
Raw Normal View History

2023-03-30 15:17:54 +02:00
import time
class DM:
    """Dialogue state tracker that records turns and chooses the next action.

    Every state passed to :meth:`update` is appended to ``chat_history`` and
    becomes ``curr_state``. States are dicts; the keys read by this class are
    ``'modified_query'`` (history strings) and ``'help'``, ``'inactive'``,
    ``'is_clear'``, ``'is_offensive'``, ``'intent'`` and ``'entities'``
    (in :meth:`next_action`).
    """

    # Intent label -> action label, used for clear, non-offensive states.
    _INTENT_ACTIONS = {
        'QA': "RetGen",
        'CHITCHAT': "ConvGen",
        'FINDPAPER': "findPaper",
        'FINDDATASET': "findDataset",
        'SUMMARIZEPAPER': "sumPaper",
    }
    # Intents that yield "ClarifyResource" when the state is unclear and has
    # no entities.
    _RESOURCE_INTENTS = frozenset({'FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'})

    def __init__(self, max_history_length=3):
        """Create a tracker with no recorded states.

        Args:
            max_history_length: Controls the history window used by
                :meth:`update_history`; the window covers the last
                ``2 * max_history_length`` recorded states. Values <= 0
                produce empty history strings.
        """
        self.working_history_sep = ""     # recent queries joined with " ||| "
        self.working_history_consec = ""  # recent queries joined with " . "
        self.max_history_length = max_history_length
        self.chat_history = []            # every state passed to update(), oldest first
        self.curr_state = None            # latest state; None until update() is called

    def update_history(self):
        """Rebuild both cached history strings from the most recent states.

        Collects ``'modified_query'`` from the last ``2 * max_history_length``
        entries of ``chat_history`` and joins them with ``" . "`` (consecutive
        form) and ``" ||| "`` (separated form).
        """
        # NOTE(review): the factor of 2 is kept from the original code; its
        # intent (e.g. counting two entries per exchange) is not visible here.
        window = max(self.max_history_length, 0) * 2
        # Guard window == 0: chat_history[-0:] equals chat_history[0:], i.e. the
        # whole list, not an empty slice.
        recent = self.chat_history[-window:] if window else []
        queries = [state['modified_query'] for state in recent]
        self.working_history_consec = " . ".join(queries)
        self.working_history_sep = " ||| ".join(queries)

    def get_consec_history(self):
        """Return the recent queries joined with ``" . "``."""
        return self.working_history_consec

    def get_sep_history(self):
        """Return the recent queries joined with ``" ||| "``."""
        return self.working_history_sep

    def get_recent_state(self):
        """Return the most recently recorded state, or None if there is none."""
        return self.curr_state

    def get_dialogue_history(self):
        """Return the list of all recorded states, oldest first (not a copy)."""
        return self.chat_history

    def update(self, new_state):
        """Record *new_state* as the current state and refresh the history strings."""
        self.chat_history.append(new_state)
        self.curr_state = new_state
        self.update_history()

    def next_action(self):
        """Return the action label for the current state.

        Decision order:
          1. ``'help'`` truthy -> ``"Help"``
          2. ``'inactive'`` truthy -> ``"Recommend"``
          3. ``'is_clear'`` truthy -> ``"OffenseReject"`` if ``'is_offensive'``
             is truthy, otherwise the label mapped from ``'intent'``
             (``None`` if the intent is not mapped).
          4. otherwise -> ``"ClarifyResource"`` for a resource intent with an
             empty ``'entities'``, else ``"GenClarify"``.

        Raises:
            TypeError: if called before any :meth:`update` (``curr_state`` is None).
        """
        state = self.curr_state
        if state['help']:
            return "Help"
        if state['inactive']:
            return "Recommend"
        if state['is_clear']:
            if state['is_offensive']:
                return "OffenseReject"
            # NOTE(review): unmapped intents return None (same as the former
            # if/elif chain falling through); callers must handle it.
            return self._INTENT_ACTIONS.get(state['intent'])
        # 'entities' is only read for resource intents (short-circuit `and`).
        if state['intent'] in self._RESOURCE_INTENTS and len(state['entities']) == 0:
            return "ClarifyResource"
        return "GenClarify"