# JanetBackEnd/main.py
import os
import re
import threading
import time
import warnings

import faiss
import pandas as pd
import psycopg2
import requests
import spacy
import spacy_transformers
import torch
from flask import Flask, render_template, request, jsonify
from flask_cors import CORS, cross_origin
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

from DM import DM
from NLU import NLU
from Recommender import Recommender
from ResponseGenerator import ResponseGenerator
from User import User
from VRE import VRE
app = Flask(__name__)

# Frontend origin allowed to call the user-facing API endpoints.
url = os.getenv("FRONTEND_URL_WITH_PORT")

cors = CORS(app, resources={r"/api/predict": {"origins": url},
                            r"/api/feedback": {"origins": url},
                            r"/api/dm": {"origins": url},
                            r"/health": {"origins": "*"}
                            })

# Single shared Postgres connection/cursor used by the feedback endpoint.
# NOTE(review): this cursor is shared across request-handler threads without
# locking — confirm the deployment serializes requests, or switch to a
# per-request connection/cursor.
conn = psycopg2.connect(
    host="janet-pg",
    database=os.getenv("POSTGRES_DB"),
    user=os.getenv("POSTGRES_USER"),
    password=os.getenv("POSTGRES_PASSWORD"))
cur = conn.cursor()

# Per-session state, keyed by gcube token. Each entry holds:
# 'username', 'name', 'dm', 'activity' (idle seconds), 'user',
# and 'args' ({'vre': VRE, 'rg': ResponseGenerator}).
users = {}
2023-04-18 02:59:58 +02:00
def vre_fetch(token):
    """Background worker: periodically pull new VRE material for one session.

    Runs forever; sleeps ~17 minutes between refreshes, then updates the
    session's VRE content/index and pushes the refreshed index and db into
    the session's response generator.
    """
    while True:
        time.sleep(1000)
        print('getting new material')
        # Work on this session's own VRE instance.
        vre = users[token]['args']['vre']
        vre.get_vre_update()
        vre.index_periodic_update()
        # BUG FIX: the original referenced a bare `vre` name here, which is
        # undefined in this function's scope (NameError at runtime); use the
        # session's VRE fetched above instead.
        users[token]['args']['rg'].update_index(vre.get_index())
        users[token]['args']['rg'].update_db(vre.get_db())
2023-04-18 02:59:58 +02:00
def user_interest_decay(token):
    """Background worker: decay one user's tracked interests every 3 minutes.

    Runs forever for the session identified by `token`.
    """
    while True:
        print("decaying interests after 3 minutes for " + users[token]['username'])
        time.sleep(180)
        users[token]['user'].decay_interests()
2023-04-17 06:59:02 +02:00
def clear_inactive():
    """Background worker: evict sessions idle for over an hour, age the rest.

    Ticks once per second; `activity` counts idle seconds and is reset to 0
    by the predict endpoint on each request.
    """
    while True:
        time.sleep(1)
        # BUG FIX: iterate over a snapshot of the keys — deleting from
        # `users` while iterating it directly raises RuntimeError.
        for token in list(users):
            if users[token]['activity'] > 3600:
                del users[token]
            else:
                # BUG FIX: the original incremented unconditionally after
                # `del`, raising KeyError for just-evicted sessions.
                users[token]['activity'] += 1
2023-04-07 17:35:51 +02:00
@app.route("/health", methods=['GET'])
def health():
    """Liveness probe: always returns 200 when the server is up."""
    return "Success", 200
@app.route("/api/dm", methods=['POST'])
def init_dm():
    """Initialize (or acknowledge) a dialogue session for a user.

    Expects JSON {"token": <gcube token>, "stat": "start"|"set"}.
    - "start": handshake, replies {"stat": "waiting"}.
    - "set": validates the token against the D4Science people-profile API;
      on first use builds the per-session state (VRE, response generator,
      DM, User) and starts the background refresh/decay threads.

    Returns a dict with "stat" of "waiting", "done", or "rejected".
    """
    payload = request.get_json()
    token = payload.get("token")
    status = payload.get("stat")
    if status == "start":
        message = {"stat": "waiting"}
    elif status == "set":
        headers = {"gcube-token": token, "Accept": "application/json"}
        if token not in users:
            # Local name chosen so we don't shadow the module-level `url`
            # (the CORS frontend origin).
            profile_url = 'https://api.d4science.org/rest/2/people/profile'
            response = requests.get(profile_url, headers=headers)
            if response.status_code == 200:
                result = response.json()['result']
                username = result['username']
                name = result['fullname']
                vre = VRE("assistedlab", token, retriever)
                vre.init()
                index = vre.get_index()
                db = vre.get_db()
                rg = ResponseGenerator(index, db, rec, generators, retriever)
                args = {'vre': vre, 'rg': rg}

                users[token] = {'username': username, 'name': name, 'dm': DM(), 'activity': 0, 'user': User(username, token), 'args': args}

                threading.Thread(target=vre_fetch, args=(token,), name='updatevre_'+users[token]['username']).start()
                threading.Thread(target=user_interest_decay, args=(token,), name='decayinterest_'+users[token]['username']).start()
                message = {"stat": "done"}
            else:
                # Token not accepted by the profile service.
                message = {"stat": "rejected"}
        else:
            # Session already initialized for this token.
            message = {"stat": "done"}
    else:
        # BUG FIX: the original left `message` unbound for any other "stat"
        # value, raising UnboundLocalError; reject unknown requests instead.
        message = {"stat": "rejected"}

    return message
2023-04-07 18:27:07 +02:00
@app.route("/api/predict", methods=['POST'])
def predict():
    """Main chat endpoint: turn a user utterance into a Janet response.

    Expects JSON {"message": <text>, "token": <gcube token>} for a session
    previously created via /api/dm. Two sentinel messages are handled
    specially: "<HELP_ON_START>" (greeting/help) and "<RECOMMEND_ON_IDLE>"
    (proactive recommendation). Anything else goes through the NLU ->
    DM -> ResponseGenerator pipeline. Returns a JSON reply with at least
    an "answer" field.
    """
    payload = request.get_json()
    text = payload.get("message")
    token = payload.get("token")
    dm = users[token]['dm']
    user = users[token]['user']
    rg = users[token]['args']['rg']
    vre = users[token]['args']['vre']

    message = {}
    if text == "<HELP_ON_START>":
        state = {'help': True, 'inactive': False, 'modified_query':"", 'intent':""}
        dm.update(state)
        action = dm.next_action()
        response = rg.gen_response(action, vrename=vre.name, username=users[token]['username'], name=users[token]['name'].split()[0])
        message = {"answer": response}
    elif text == "<RECOMMEND_ON_IDLE>":
        # NOTE(review): "recommed: " is kept byte-identical — downstream
        # components may match this exact prefix; confirm before fixing the typo.
        state = {'help': False, 'inactive': True, 'modified_query':"recommed: ", 'intent':""}
        dm.update(state)
        action = dm.next_action()
        response = rg.gen_response(action, username=users[token]['username'],name=users[token]['name'].split()[0], vrename=vre.name)
        message = {"answer": response}
        new_state = {'modified_query': response}
        dm.update(new_state)
    else:
        state = nlu.process_utterance(text, dm.get_consec_history(), dm.get_sep_history())
        state['help'] = False
        state['inactive'] = False
        # Snapshot interests and material so we can detect changes below.
        old_user_interests = user.get_user_interests()
        old_vre_material = pd.concat([vre.db['paper_db'], vre.db['dataset_db']]).reset_index(drop=True)
        user_interests = []
        for entity in state['entities']:
            if entity['entity'] == 'TOPIC':
                user_interests.append(entity['value'])
        user.update_interests(user_interests)
        new_user_interests = user.get_user_interests()
        new_vre_material = pd.concat([vre.db['paper_db'], vre.db['dataset_db']]).reset_index(drop=True)

        # Regenerate recommendations only when interests or material changed.
        if (new_user_interests != old_user_interests or len(old_vre_material) != len(new_vre_material)):
            rec.generate_recommendations(users[token]['username'], new_user_interests, new_vre_material)

        dm.update(state)
        action = dm.next_action()
        response = rg.gen_response(action=action, utterance=state['modified_query'], state=dm.get_recent_state(), consec_history=dm.get_consec_history(), chitchat_history=dm.get_chitchat_history(), vrename=vre.name)
        message = {"answer": response, "query": text, "cand": "candidate", "history": dm.get_consec_history(), "modQuery": state['modified_query']}
        if state['intent'] == "QA":
            # BUG FIX: guard the split — an unguarded [1] raised IndexError
            # whenever the generated answer lacked the marker string.
            parts = response.split("_______ \n The answer is: ")
            response = parts[1] if len(parts) > 1 else parts[0]
        new_state = {'modified_query': response, 'intent': state['intent']}
        dm.update(new_state)

    reply = jsonify(message)
    # Persist updated session state and reset the idle counter.
    users[token]['dm'] = dm
    users[token]['user'] = user
    users[token]['activity'] = 0
    users[token]['args']['vre'] = vre
    users[token]['args']['rg'] = rg
    return reply
@app.route('/api/feedback', methods = ['POST'])
def feedback():
    """Persist a user's feedback about a Janet response into Postgres.

    Expects JSON {"feedback": {...}} whose keys match the columns of the
    feedback_experimental table. Returns {"status": "done"}.
    """
    data = request.get_json().get("feedback")
    print(data)

    # Parameterized query: values are bound by the driver, so user-supplied
    # text cannot inject SQL.
    cur.execute('INSERT INTO feedback_experimental (query, history, janet_modified_query, is_modified_query_correct, user_modified_query, evidence_useful, response, preferred_response, response_length_feedback, response_fluency_feedback, response_truth_feedback, response_useful_feedback, response_time_feedback, response_intent) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)',
                (data['query'], data['history'], data['modQuery'],
                 data['queryModCorrect'], data['correctQuery'], data['evidence'], data['janetResponse'], data['preferredResponse'], data['length'],
                 data['fluency'], data['truthfulness'], data['usefulness'],
                 data['speed'], data['intent'])
                )
    conn.commit()

    reply = jsonify({"status": "done"})
    return reply
if __name__ == "__main__":
    # NOTE(review): the globals created here (nlu, retriever, generators,
    # rec, ...) are referenced by the route handlers above, so this app only
    # works when launched as a script — under a WSGI server (gunicorn etc.)
    # they would be undefined. Confirm the deployment runs `python main.py`.
    warnings.filterwarnings("ignore")

    # Prefer GPU when available; HF pipelines take a device index (-1 = CPU).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    device_flag = torch.cuda.current_device() if torch.cuda.is_available() else -1

    # --- NLU components ---
    query_rewriter = pipeline("text2text-generation", model="castorini/t5-base-canard")
    intent_classifier = pipeline("sentiment-analysis", model='/models/intent_classifier', device=device_flag)
    entity_extractor = spacy.load("/models/entity_extractor")
    offensive_classifier = pipeline("sentiment-analysis", model='/models/offensive_classifier', device=device_flag)
    ambig_classifier = pipeline("sentiment-analysis", model='/models/ambig_classifier', device=device_flag)
    coref_resolver = spacy.load("en_coreference_web_trf")
    nlu = NLU(query_rewriter, coref_resolver, intent_classifier, offensive_classifier, entity_extractor, ambig_classifier)

    # --- Retriever and generators ---
    retriever = SentenceTransformer('/models/retriever/').to(device)
    qa_generator = pipeline("text2text-generation", model="/models/train_qa", device=device_flag)
    summ_generator = pipeline("text2text-generation", model="/models/train_summ", device=device_flag)
    chat_generator = pipeline("text2text-generation", model="/models/train_chat", device=device_flag)
    amb_generator = pipeline("text2text-generation", model="/models/train_amb_gen", device=device_flag)
    generators = {'qa': qa_generator,
                  'chat': chat_generator,
                  'amb': amb_generator,
                  'summ': summ_generator}

    # Session-eviction worker (per-session workers start in /api/dm).
    threading.Thread(target=clear_inactive, name='clear').start()

    rec = Recommender(retriever)

    # Ensure the feedback table exists before serving requests.
    cur.execute('CREATE TABLE IF NOT EXISTS feedback_experimental (id serial PRIMARY KEY,'
                'query text NOT NULL,'
                'history text NOT NULL,'
                'janet_modified_query text NOT NULL,'
                'is_modified_query_correct text NOT NULL,'
                'user_modified_query text NOT NULL, evidence_useful text NOT NULL,'
                'response text NOT NULL,'
                'preferred_response text,'
                'response_length_feedback text NOT NULL,'
                'response_fluency_feedback text NOT NULL,'
                'response_truth_feedback text NOT NULL,'
                'response_useful_feedback text NOT NULL,'
                'response_time_feedback text NOT NULL,'
                'response_intent text NOT NULL);'
                )
    conn.commit()

    app.run(host='0.0.0.0')