2023-04-08 03:33:28 +02:00
|
|
|
import os
|
2023-04-08 04:04:24 +02:00
|
|
|
import warnings
|
|
|
|
import faiss
|
|
|
|
import torch
|
|
|
|
from flask import Flask, render_template, request, jsonify
|
|
|
|
from flask_cors import CORS, cross_origin
|
|
|
|
import psycopg2
|
|
|
|
import spacy
|
|
|
|
import spacy_transformers
|
|
|
|
import torch
|
|
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
|
|
|
from User import User
|
|
|
|
from VRE import VRE
|
|
|
|
from NLU import NLU
|
|
|
|
from DM import DM
|
|
|
|
from Recommender import Recommender
|
|
|
|
from ResponseGenerator import ResponseGenerator
|
|
|
|
import pandas as pd
|
|
|
|
import time
|
|
|
|
import threading
|
|
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
|
|
|
|
|
2023-03-30 20:14:13 +02:00
|
|
|
|
2023-03-30 15:17:54 +02:00
|
|
|
# --- Flask application and CORS policy ---------------------------------------
app = Flask(__name__)

# Origin (scheme://host:port) of the web front end that may call the API.
# Read from the environment so every deployment can configure its own.
url = os.getenv("FRONTEND_URL_WITH_PORT")

# The two API routes are restricted to the front-end origin; the health probe
# is open to any origin.
cors = CORS(app, resources={r"/api/predict": {"origins": url},
                            r"/api/feedback": {"origins": url},
                            r"/health": {"origins": "*"}
                            })

# --- PostgreSQL connection used to persist user feedback ---------------------
# Database name and credentials are taken from the environment.
conn = psycopg2.connect(
    host="janet-pg",
    database=os.getenv("POSTGRES_DB"),
    user=os.getenv("POSTGRES_USER"),
    password=os.getenv("POSTGRES_PASSWORD"))

# NOTE(review): a disabled connect() call with a literal password was kept here
# as a bare string. It had no runtime effect and has been removed; credentials
# must only come from the environment, and the exposed password should be
# rotated.

# Module-level cursor on the connection above.
cur = conn.cursor()
|
2023-04-08 22:51:44 +02:00
|
|
|
|
2023-04-08 04:04:24 +02:00
|
|
|
|
|
|
|
def vre_fetch():
    """Background worker: periodically refresh the VRE material. Never returns.

    Each cycle sleeps 1000 seconds, then asks the module-level ``vre`` object
    to fetch its update and update its index, and finally hands the resulting
    index and database to the module-level ``rg`` object.
    """
    refresh_interval = 1000  # seconds to wait before each refresh cycle
    while True:
        time.sleep(refresh_interval)
        print('getting new material')

        vre.get_vre_update()
        vre.index_periodic_update()

        # Push the refreshed index first, then the refreshed database.
        latest_index = vre.get_index()
        rg.update_index(latest_index)
        latest_db = vre.get_db()
        rg.update_db(latest_db)
|
|
|
|
|
|
|
|
def user_interest_decay():
    """Background worker: call ``user.decay_interests()`` every 180 seconds.

    Never returns. The log line is printed *before* each wait, so the first
    decay happens 180 seconds after the worker starts.
    """
    pause_seconds = 180  # three minutes between two decay steps
    while True:
        print("decaying interests after 3 minutes")
        time.sleep(pause_seconds)
        user.decay_interests()
|
2023-03-30 15:17:54 +02:00
|
|
|
|
2023-04-07 17:35:51 +02:00
|
|
|
@app.route("/health", methods=['GET'])
def health():
    """Liveness probe: always answer ``Success`` with HTTP status 200."""
    body = "Success"
    status_code = 200
    return body, status_code
|
|
|
|
|
2023-04-07 18:27:07 +02:00
|
|
|
def _has_changed(old, new):
    """Return True when *old* and *new* differ.

    pandas objects are compared with ``.equals()``: ``!=`` on them is
    element-wise and its result cannot be used as a boolean condition.
    Everything else falls back to the ordinary ``!=``.
    """
    pandas_types = (pd.DataFrame, pd.Series)
    if isinstance(old, pandas_types):
        return not old.equals(new)
    if isinstance(new, pandas_types):
        return not new.equals(old)
    return old != new


@app.route("/api/predict", methods=['POST'])
def predict():
    """Answer one chat message posted by the front end.

    Expects a JSON body with a ``message`` field. Two sentinel messages are
    handled specially:

    * ``<HELP_ON_START>``     - produce the help/greeting response.
    * ``<RECOMMEND_ON_IDLE>`` - produce a recommendation for the idle user.

    Any other text goes through the NLU, updates the user's interests,
    regenerates recommendations when interests or VRE material changed, and
    then produces a response for the dialogue manager's next action.

    Returns:
        A JSON response that always contains ``answer``; for regular messages
        it also contains ``query``, ``cand``, ``history`` and ``modQuery``.
        A 400 JSON error is returned when the body has no ``message``.
    """
    payload = request.get_json(silent=True)
    text = payload.get("message") if isinstance(payload, dict) else None
    if text is None:
        # Without this guard a missing body/field surfaced as an
        # AttributeError (HTTP 500) further down.
        return jsonify({"error": "missing 'message' field"}), 400

    message = {}
    if text == "<HELP_ON_START>":
        state = {'help': True, 'inactive': False, 'modified_query':""}
        dm.update(state)
        action = dm.next_action()
        response = rg.gen_response(action)
        message = {"answer": response}
    elif text == "<RECOMMEND_ON_IDLE>":
        state = {'help': False, 'inactive': True, 'modified_query':"recommed: "}
        dm.update(state)
        action = dm.next_action()
        response = rg.gen_response(action, username=user.username)
        message = {"answer": response}
        # Record the system's own reply in the dialogue state.
        new_state = {'modified_query': response}
        dm.update(new_state)
    else:
        state = nlu.process_utterance(text, dm.get_consec_history(), dm.get_sep_history())
        state['help'] = False
        state['inactive'] = False

        # Snapshot interests and material before applying this utterance.
        # NOTE(review): if get_user_interests() returns the same mutable object
        # that update_interests() modifies in place, this snapshot cannot
        # detect a change - confirm against the User class.
        old_user_interests = user.get_user_interests()
        old_vre_material = pd.concat([vre.db['paper_db'], vre.db['dataset_db']]).reset_index(drop=True)

        # Every TOPIC entity found in the utterance counts as an interest.
        user_interests = [entity['value']
                          for entity in state['entities']
                          if entity['entity'] == 'TOPIC']
        user.update_interests(user_interests)

        new_user_interests = user.get_user_interests()
        new_vre_material = pd.concat([vre.db['paper_db'], vre.db['dataset_db']]).reset_index(drop=True)

        # Fixed: the original called ``new_user_interests()`` (the returned
        # value, not a function) and used ``!=`` on two DataFrames as a
        # boolean, which raises ValueError.
        if (_has_changed(old_user_interests, new_user_interests)
                or _has_changed(old_vre_material, new_vre_material)):
            rec.generate_recommendations(new_user_interests, new_vre_material)

        dm.update(state)
        action = dm.next_action()
        response = rg.gen_response(action=action, utterance=state['modified_query'], state=dm.get_recent_state(), consec_history=dm.get_consec_history())
        message = {"answer": response, "query": text, "cand": "candidate", "history": dm.get_consec_history(), "modQuery": state['modified_query']}
        # Record the system's own reply in the dialogue state.
        new_state = {'modified_query': response}
        dm.update(new_state)

    reply = jsonify(message)
    return reply
|
|
|
|
|
2023-04-07 18:27:07 +02:00
|
|
|
@app.route('/api/feedback', methods = ['POST'])
def feedback():
    """Store one feedback form in the ``feedback_trial`` table.

    Expects a JSON body whose ``feedback`` object carries the keys listed in
    ``payload_keys`` below (one per table column, in column order).

    Returns:
        ``{"status": "done"}`` as JSON once the row is committed.

    Raises:
        KeyError: if a required feedback key is missing.
        psycopg2.Error: if the insert fails; the transaction is rolled back
            first so that later requests can still use the connection.
    """
    data = request.get_json().get("feedback")
    print(data)

    # Payload keys, ordered to match the column list of the INSERT statement.
    payload_keys = ('query', 'history', 'modQuery',
                    'queryModCorrect', 'correctQuery',
                    'janetResponse', 'preferredResponse', 'length',
                    'fluency', 'truthfulness', 'usefulness',
                    'speed', 'intent')
    values = tuple(data[key] for key in payload_keys)

    try:
        # A per-request cursor: psycopg2 cursors are not thread safe, so the
        # module-level one must not be shared between request threads.
        with conn.cursor() as cursor:
            cursor.execute('INSERT INTO feedback_trial (query, history, janet_modified_query, is_modified_query_correct, user_modified_query, response, preferred_response, response_length_feedback, response_fluency_feedback, response_truth_feedback, response_useful_feedback, response_time_feedback, response_intent) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)',
                           values)
        conn.commit()
    except psycopg2.Error:
        # Without a rollback the connection stays in an aborted transaction
        # and every following statement fails.
        conn.rollback()
        raise

    reply = jsonify({"status": "done"})
    return reply
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: load every model, build the dialogue components as
    # module-level globals (the route handlers and background workers read
    # them), make sure the feedback table exists, then start the web server.
    warnings.filterwarnings("ignore")

    # `device` is passed to `.to()`; `device_flag` is the value passed as
    # `device=` to the transformers pipelines (-1 when CUDA is unavailable).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    device_flag = torch.cuda.current_device() if torch.cuda.is_available() else -1

    # --- NLU models ----------------------------------------------------------
    query_rewriter = pipeline("text2text-generation", model="castorini/t5-base-canard")
    intent_classifier = pipeline("sentiment-analysis", model='/models/intent_classifier', device=device_flag)
    entity_extractor = spacy.load("/models/entity_extractor")
    offensive_classifier = pipeline("sentiment-analysis", model='/models/offensive_classifier', device=device_flag)
    ambig_classifier = pipeline("sentiment-analysis", model='/models/ambig_classifier', device=device_flag)
    coref_resolver = spacy.load("en_coreference_web_trf")

    nlu = NLU(query_rewriter, coref_resolver, intent_classifier, offensive_classifier, entity_extractor, ambig_classifier)

    # --- Retriever and generators --------------------------------------------
    retriever = SentenceTransformer('/models/BigRetriever/').to(device)
    qa_generator = pipeline("text2text-generation", model="/models/train_qa", device=device_flag)
    summ_generator = pipeline("text2text-generation", model="/models/train_summ", device=device_flag)
    chat_generator = pipeline("text2text-generation", model="/models/train_chat", device=device_flag)
    amb_generator = pipeline("text2text-generation", model="/models/train_amb_gen", device=device_flag)
    generators = {'qa': qa_generator,
                  'chat': chat_generator,
                  'amb': amb_generator,
                  'summ': summ_generator}

    # --- VRE and user --------------------------------------------------------
    # NOTE(review): an access token is committed in source. It can now be
    # overridden through the VRE_TOKEN environment variable; the literal is
    # kept only as a backward-compatible fallback and should be rotated and
    # removed.
    token = os.getenv("VRE_TOKEN", '2c1e8f88-461c-42c0-8cc1-b7660771c9a3-843339462')
    vre = VRE("assistedlab", token, retriever)
    vre.init()
    index = vre.get_index()
    db = vre.get_db()
    user = User("ahmed", token)

    # --- Dialogue components -------------------------------------------------
    rec = Recommender(retriever)
    dm = DM()
    rg = ResponseGenerator(index,db, rec, generators, retriever)

    # --- Background workers --------------------------------------------------
    # Started only after `rg` exists because vre_fetch() uses it. They are
    # daemon threads: both loop forever, and as non-daemon threads they would
    # keep the process alive after the server has stopped.
    threading.Thread(target=vre_fetch, name='updatevre', daemon=True).start()
    threading.Thread(target=user_interest_decay, name='decayinterest', daemon=True).start()

    # --- Feedback storage ----------------------------------------------------
    cur.execute('CREATE TABLE IF NOT EXISTS feedback_trial (id serial PRIMARY KEY,'
                'query text NOT NULL,'
                'history text NOT NULL,'
                'janet_modified_query text NOT NULL,'
                'is_modified_query_correct text NOT NULL,'
                'user_modified_query text NOT NULL,'
                'response text NOT NULL,'
                'preferred_response text,'
                'response_length_feedback text NOT NULL,'
                'response_fluency_feedback text NOT NULL,'
                'response_truth_feedback text NOT NULL,'
                'response_useful_feedback text NOT NULL,'
                'response_time_feedback text NOT NULL,'
                'response_intent text NOT NULL);'
                )
    conn.commit()

    # Listen on all interfaces, port 4000.
    app.run(host='0.0.0.0', port=4000)
|