85 lines
3.1 KiB
Python
85 lines
3.1 KiB
Python
import time
|
|
|
|
class DM:
    """Dialogue manager: records dialogue states and chooses the next action.

    Each call to :meth:`update` appends one state dict to the dialogue
    history and refreshes the cached history strings. State dicts are read
    via the keys ``'modified_query'`` and ``'intent'`` (history building) and
    ``'help'``, ``'inactive'``, ``'is_clear'``, ``'is_offensive'`` (action
    selection) -- presumably produced by an upstream NLU step; confirm
    against the caller.
    """

    # Intent label -> action name returned by next_action() for a clear,
    # non-offensive utterance. Intents missing from this table fall back to
    # _DEFAULT_ACTION.
    _INTENT_ACTIONS = {
        'QA': "RetGen",
        'EXPLAINPOST': "findPost",
        'HELP': "getHelp",
        'CHITCHAT': "ConvGen",
        'FINDPAPER': "findPaper",
        'FINDDATASET': "findDataset",
        'SUMMARIZEPAPER': "sumPaper",
        'LISTPAPERS': "listPapers",
        'LISTDATASETS': "listDatasets",
        'LISTCOMMANDS': "listCommands",
        'LISTTOPICS': "listTopics",
        'LISTRESOURCES': "listResources",
        'COMMAND': "command",
    }
    _DEFAULT_ACTION = "RetGen"

    def __init__(self, max_history_length=2, max_chitchat_length=4):
        """Create an empty dialogue manager.

        Args:
            max_history_length: the working history keeps the last
                ``2 * max_history_length`` turns (any intent).
            max_chitchat_length: the chitchat history keeps the last
                this-many CHITCHAT turns.
        """
        self.working_history_sep = ""
        self.working_history_consec = ""
        self.chitchat_history_consec = ""
        self.max_history_length = max_history_length
        self.max_chitchat_length = max_chitchat_length
        self.chat_history = []
        self.curr_state = None

    def update_history(self):
        """Recompute the cached history strings from ``self.chat_history``.

        - Working history: ``modified_query`` of the last
          ``2 * max_history_length`` turns, joined with " . " (consecutive
          form) and " ||| " (separated form).
        - Chitchat history: ``modified_query`` of the last
          ``max_chitchat_length`` turns whose intent is CHITCHAT, joined
          with '. '.
        """
        window = self.max_history_length * 2
        # A window of 0 must yield nothing: the slice [-0:] would otherwise
        # return the whole history.
        recent = self.chat_history[-window:] if window > 0 else []
        to_consider = [turn['modified_query'] for turn in recent]
        self.working_history_consec = " . ".join(to_consider)
        self.working_history_sep = " ||| ".join(to_consider)

        chitchat = [
            turn['modified_query']
            for turn in self.chat_history
            if turn['intent'] == 'CHITCHAT'
        ]
        limit = self.max_chitchat_length
        self.chitchat_history_consec = '. '.join(
            chitchat[-limit:] if limit > 0 else []
        )

    def get_consec_history(self):
        """Return the recent working history joined with " . "."""
        return self.working_history_consec

    def get_chitchat_history(self):
        """Return the recent CHITCHAT utterances joined with '. '."""
        return self.chitchat_history_consec

    def get_sep_history(self):
        """Return the recent working history joined with " ||| "."""
        return self.working_history_sep

    def get_recent_state(self):
        """Return the most recent state passed to update(), or None."""
        return self.curr_state

    def get_dialogue_history(self):
        """Return the full list of recorded states (the live list, not a copy)."""
        return self.chat_history

    def update(self, new_state):
        """Record ``new_state`` as the latest turn and refresh cached histories."""
        self.chat_history.append(new_state)
        self.curr_state = new_state
        self.update_history()

    def next_action(self):
        """Return the name of the next system action for the latest state.

        Precedence: the 'help' flag, then the 'inactive' flag. An unclear
        utterance gets "Clarify". A clear but offensive one gets
        "OffenseReject". Otherwise the action is looked up by intent, with
        unknown intents defaulting to "RetGen".

        Must be called after at least one update(): with no state yet,
        indexing ``None`` raises TypeError.
        """
        state = self.curr_state
        if state['help']:
            return "Help"
        if state['inactive']:
            return "Recommend"
        if not state['is_clear']:
            return "Clarify"
        if state['is_offensive']:
            return "OffenseReject"
        return self._INTENT_ACTIONS.get(state['intent'], self._DEFAULT_ACTION)
|