version_1
This commit is contained in:
parent
657583274c
commit
e1dfb0da74
2
DM.py
2
DM.py
|
@ -9,7 +9,7 @@ class DM:
|
|||
self.curr_state = None
|
||||
|
||||
def update_history(self):
|
||||
to_consider = [x['modified_query'] for x in self.chat_history[-max_history_length*2:]]
|
||||
to_consider = [x['modified_query'] for x in self.chat_history[-self.max_history_length*2:]]
|
||||
self.working_history_consec = " . ".join(to_consider)
|
||||
self.working_history_sep = " ||| ".join(to_consider)
|
||||
|
||||
|
|
14
NLU.py
14
NLU.py
|
@ -88,8 +88,8 @@ class NLU:
|
|||
entities = self._entityextractor()
|
||||
offense = self._offensepredictor()
|
||||
if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
|
||||
return {"modified_prompt": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
|
||||
return {"modified_prompt": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": True}
|
||||
return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
|
||||
return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": True}
|
||||
else:
|
||||
if self._ambigpredictor():
|
||||
self.to_process = self._rewrite_query(history_sep)
|
||||
|
@ -98,15 +98,15 @@ class NLU:
|
|||
offense = self._offensepredictor()
|
||||
if score > 0.5 or not self._ambigpredictor():
|
||||
if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
|
||||
return {"modified_prompt": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
|
||||
return {"modified_prompt": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense,
|
||||
return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
|
||||
return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense,
|
||||
"is_clear": True}
|
||||
else:
|
||||
return {"modified_prompt": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense,
|
||||
return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense,
|
||||
"is_clear": False}
|
||||
else:
|
||||
entities = self._entityextractor()
|
||||
offense = self._offensepredictor()
|
||||
if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
|
||||
return {"modified_prompt": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
|
||||
return {"modified_prompt": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": True}
|
||||
return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
|
||||
return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": True}
|
||||
|
|
17
main.py
17
main.py
|
@ -25,7 +25,7 @@ from sentence_transformers import SentenceTransformer
|
|||
app = Flask(__name__)
|
||||
url = os.getenv("FRONTEND_URL_WITH_PORT")
|
||||
cors = CORS(app, resources={r"/predict": {"origins": url}, r"/feedback": {"origins": url}})
|
||||
"""
|
||||
|
||||
conn = psycopg2.connect(
|
||||
host="https://janet-app-db.d4science.org",
|
||||
database=os.getenv("POSTGRES_DB"),
|
||||
|
@ -36,7 +36,7 @@ conn = psycopg2.connect(host="https://janet-app-db.d4science.org",
|
|||
database="janet",
|
||||
user="janet_user",
|
||||
password="2fb5e81fec5a2d906a04")
|
||||
|
||||
"""
|
||||
cur = conn.cursor()
|
||||
|
||||
|
||||
|
@ -60,17 +60,19 @@ def predict():
|
|||
text = request.get_json().get("message")
|
||||
message = {}
|
||||
if text == "<HELP_ON_START>":
|
||||
state = {'help': True, 'inactive': False}
|
||||
state = {'help': True, 'inactive': False, 'modified_query':""}
|
||||
dm.update(state)
|
||||
action = dm.next_action()
|
||||
response = rg.gen_response(action)
|
||||
message = {"answer": response}
|
||||
elif text == "<RECOMMEND_ON_IDLE>":
|
||||
state = {'help': False, 'inactive': True}
|
||||
state = {'help': False, 'inactive': True, 'modified_query':"recommed: "}
|
||||
dm.update(state)
|
||||
action = dm.next_action()
|
||||
response = rg.gen_response(action, username=user.username)
|
||||
message = {"answer": response}
|
||||
new_state = {'modified_query': response}
|
||||
dm.update(new_state)
|
||||
else:
|
||||
state = nlu.process_utterance(text, dm.get_consec_history(), dm.get_sep_history())
|
||||
state['help'] = False
|
||||
|
@ -93,6 +95,7 @@ def predict():
|
|||
def feedback():
|
||||
data = request.get_json()['feedback']
|
||||
print(data)
|
||||
|
||||
cur.execute('INSERT INTO feedback (query, history, janet_modified_query, is_modified_query_correct, user_modified_query, response, preferred_response, response_length_feedback, response_fluency_feedback, response_truth_feedback, response_useful_feedback, response_time_feedback, response_intent) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)',
|
||||
(data['query'], data['history'], data['modQuery'],
|
||||
data['queryModCorrect'], data['correctQuery'],
|
||||
|
@ -100,6 +103,7 @@ def feedback():
|
|||
data['fluency'], data['truthfulness'], data['usefulness'],
|
||||
data['speed'], data['intent'])
|
||||
)
|
||||
|
||||
reply = jsonify({"status": "done"})
|
||||
return reply
|
||||
|
||||
|
@ -130,7 +134,7 @@ if __name__ == "__main__":
|
|||
|
||||
#load vre
|
||||
token = '2c1e8f88-461c-42c0-8cc1-b7660771c9a3-843339462'
|
||||
vre = VRE("assistedlab", token, def_retriever)
|
||||
vre = VRE("assistedlab", token, retriever)
|
||||
vre.init()
|
||||
index = vre.get_index()
|
||||
db = vre.get_db()
|
||||
|
@ -144,7 +148,7 @@ if __name__ == "__main__":
|
|||
|
||||
dm = DM()
|
||||
|
||||
rg = ResponseGenerator(index,db, recommender, generators, retriever)
|
||||
rg = ResponseGenerator(index,db, rec, generators, retriever)
|
||||
|
||||
|
||||
cur.execute('CREATE TABLE IF NOT EXISTS feedback_trial (id serial PRIMARY KEY,'
|
||||
|
@ -162,4 +166,5 @@ if __name__ == "__main__":
|
|||
'response_time_feedback text NOT NULL,'
|
||||
'response_intent text NOT NULL);'
|
||||
)
|
||||
|
||||
app.run(host='127.0.0.1', port=4000)
|
||||
|
|
Loading…
Reference in New Issue