"""Flask backend for a conversational assistant.

Loads the NLU, retrieval and generation models once at import time; the
HTTP endpoints and per-session background workers are defined below.
"""
import os
import re
import warnings
import faiss
import torch
from flask import Flask, render_template, request, jsonify
from flask_cors import CORS, cross_origin
import psycopg2
import spacy
import requests
import spacy_transformers  # noqa: F401 -- presumably imported for its side effect of registering spaCy transformer components
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from User import User
from VRE import VRE
from NLU import NLU
from DM import DM
from Recommender import Recommender
from ResponseGenerator import ResponseGenerator
import pandas as pd
import time
import threading
from sentence_transformers import SentenceTransformer

app = Flask(__name__)

# Origin allowed to call the browser-facing API endpoints; /health is open to any origin.
url = os.getenv("FRONTEND_URL_WITH_PORT")
cors = CORS(app, resources={r"/api/predict": {"origins": url},
                            r"/api/feedback": {"origins": url},
                            r"/api/dm": {"origins": url},
                            r"/health": {"origins": "*"}})

# Per-session state, keyed by the token the client sends with each request.
users = {}
# Health flag returned by /health; any value other than "alive" makes it answer 500.
alive = "alive"

# `device` is the torch device string (used by SentenceTransformer.to); `device_flag`
# is the integer form transformers.pipeline expects, where -1 selects the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
device_flag = torch.cuda.current_device() if torch.cuda.is_available() else -1

# --- NLU components ---
# NOTE(review): unlike every other pipeline below, query_rewriter is not passed
# device=device_flag, so it is placed on the pipeline's default device; confirm intent.
query_rewriter = pipeline("text2text-generation", model="castorini/t5-base-canard")
intent_classifier = pipeline("sentiment-analysis", model='/models/intent_classifier', device=device_flag)
entity_extractor = spacy.load("/models/entity_extractor")
offensive_classifier = pipeline("sentiment-analysis", model='/models/offensive_classifier', device=device_flag)
ambig_classifier = pipeline("sentiment-analysis", model='/models/ambig_classifier', device=device_flag)
coref_resolver = spacy.load("en_coreference_web_trf")

nlu = NLU(query_rewriter, coref_resolver, intent_classifier, offensive_classifier, entity_extractor, ambig_classifier)

# --- Retriever and response generators ---
retriever = SentenceTransformer('/models/retriever/').to(device)
qa_generator = pipeline("text2text-generation", model="/models/train_qa", device=device_flag)
summ_generator = pipeline("text2text-generation", model="/models/train_summ", device=device_flag)
chat_generator = pipeline("text2text-generation", model="/models/train_chat", device=device_flag)
amb_generator = pipeline("text2text-generation", model="/models/train_amb_gen", device=device_flag)
generators = {'qa': qa_generator, 'chat': chat_generator, 'amb': amb_generator, 'summ': summ_generator}
rec = Recommender(retriever)


def _set_dead(reason):
    """Record a background-worker failure so that /health reports it.

    Assigning ``alive`` inside a worker without ``global`` would only bind a
    local name, leaving the module-level flag that /health reads untouched.
    """
    global alive
    alive = reason


def vre_fetch(token):
    """Periodically refresh one session's VRE material and re-index it.

    Every 1000 seconds this pulls a VRE update, refreshes its index and pushes
    the new index and database into the session's ResponseGenerator. It
    returns once the session is no longer present in ``users``.

    Args:
        token: key of the session in ``users``.
    """
    while True:
        time.sleep(1000)
        session = users.get(token)
        if session is None:
            # Session was evicted (see clear_inactive); stop refreshing it.
            break
        try:
            print('getting new material')
            vre = session['vre']
            vre.get_vre_update()
            vre.index_periodic_update()
            session['rg'].update_index(vre.get_index())
            session['rg'].update_db(vre.get_db())
        except Exception as e:
            print(f"vre_fetch failed for {session.get('username')}: {e}")
            _set_dead("dead_vre_fetch")


def user_interest_decay(token):
    """Decay one session's user interests every 180 seconds while the session exists.

    Args:
        token: key of the session in ``users``.
    """
    while True:
        session = users.get(token)
        if session is None:
            break
        print("decaying interests after 3 minutes for " + session['username'])
        time.sleep(180)
        # The session may have been evicted while sleeping.
        session = users.get(token)
        if session is None:
            break
        try:
            session['user'].decay_interests()
        except Exception as e:
            print(f"user_interest_decay failed for {session.get('username')}: {e}")
            _set_dead("dead_interest_decay")


def clear_inactive():
    """Once per second, age every session and evict those idle for over 3600 ticks.

    ``activity`` is incremented here and reset to 0 by ``predict``. Iteration
    is over a snapshot of the keys because request threads add sessions
    concurrently and this loop removes them.
    """
    while True:
        time.sleep(1)
        try:
            for token in list(users):
                session = users.get(token)
                if session is None:
                    continue
                if session['activity'] > 3600:
                    users.pop(token, None)
                else:
                    session['activity'] += 1
        except Exception as e:
            print(f"clear_inactive failed: {e}")
            _set_dead("dead_clear_inactive")


@app.route("/health", methods=['GET'])
def health():
    """Return ("Success", 200) while ``alive`` is "alive", else (failure tag, 500)."""
    if alive == "alive":
        return "Success", 200
    return alive, 500


@app.route("/api/dm", methods=['POST'])
def init_dm():
    """Start or initialise a dialogue session.

    JSON body fields: ``token`` and ``stat``.

    * ``stat == "start"``: replies ``{"stat": "waiting"}`` without creating anything.
    * ``stat == "set"``: if the token has no session yet, validates it against the
      D4Science profile API; on HTTP 200 builds the session's VRE,
      ResponseGenerator, DM and User objects, stores them in ``users`` and starts
      the interest-decay and VRE-refresh worker threads. A non-200 profile
      response yields ``{"stat": "rejected"}``; an existing session yields
      ``{"stat": "done"}``.

    Any failure (including an unknown ``stat``) yields
    ``{"stat": "init_dm_error", "err": <message>}``.
    """
    try:
        payload = request.get_json()
        token = payload.get("token")
        status = payload.get("stat")
        if status == "start":
            message = {"stat": "waiting", "err": ""}
        elif status == "set":
            headers = {"gcube-token": token, "Accept": "application/json"}
            if token not in users:
                profile_url = 'https://api.d4science.org/rest/2/people/profile'
                # Bounded so a stalled profile service cannot hang this request thread.
                response = requests.get(profile_url, headers=headers, timeout=30)
                if response.status_code == 200:
                    profile = response.json()['result']
                    username = profile['username']
                    name = profile['fullname']
                    vre = VRE("assistedlab", token, retriever)
                    vre.init()
                    index = vre.get_index()
                    db = vre.get_db()
                    rg = ResponseGenerator(index, db, rec, generators, retriever)
                    users[token] = {'username': username, 'name': name, 'dm': DM(), 'activity': 0,
                                    'user': User(username, token), 'vre': vre, 'rg': rg}
                    # Daemon threads so these endless loops never block process exit.
                    threading.Thread(target=user_interest_decay, args=(token,),
                                     name='decayinterest_' + username, daemon=True).start()
                    threading.Thread(target=vre_fetch, args=(token,),
                                     name='updatevre' + username, daemon=True).start()
                    message = {"stat": "done", "err": ""}
                else:
                    message = {"stat": "rejected", "err": ""}
            else:
                message = {"stat": "done", "err": ""}
        else:
            message = {"stat": "init_dm_error", "err": "unknown stat: " + str(status)}
        return message
    except Exception as e:
        message = {"stat": "init_dm_error", "err": str(e)}
        return message


@app.route("/api/predict", methods=['POST'])
def predict():
    """Produce the assistant's reply to one user message.

    JSON body fields: ``message`` (the utterance) and ``token`` (session key).
    The utterance goes through NLU, updates the user's interests (regenerating
    recommendations when interests or VRE material changed), advances the
    dialogue manager and generates a response. Replies with JSON containing at
    least ``answer``; on any error ``answer`` carries the exception text.
    """
    message = {}
    try:
        payload = request.get_json()
        text = payload.get("message")
        token = payload.get("token")
        session = users[token]
        dm = session['dm']
        user = session['user']
        rg = session['rg']
        vre = session['vre']

        if text == "":
            state = {'help': True, 'inactive': False, 'modified_query': "", 'intent': ""}
            dm.update(state)
            action = dm.next_action()
            response = rg.gen_response(action, vrename=vre.name, username=session['username'],
                                       name=session['name'].split()[0])
            message = {"answer": response}
        elif text == "":
            # NOTE(review): this condition is identical to the branch above, so this
            # branch is unreachable. It presumably compared against a distinct
            # idle/recommendation marker sent by the frontend; restore that marker.
            state = {'help': False, 'inactive': True, 'modified_query': "recommed: ", 'intent': ""}
            dm.update(state)
            action = dm.next_action()
            response = rg.gen_response(action, username=session['username'],
                                       name=session['name'].split()[0], vrename=vre.name)
            message = {"answer": response}
            new_state = {'modified_query': response}
            dm.update(new_state)
        else:
            state = nlu.process_utterance(text, dm.get_consec_history(), dm.get_sep_history())
            state['help'] = False
            state['inactive'] = False

            old_user_interests = user.get_user_interests()
            old_vre_material = pd.concat([vre.db['paper_db'], vre.db['dataset_db']]).reset_index(drop=True)
            user_interests = [entity['value'] for entity in state['entities'] if entity['entity'] == 'TOPIC']
            user.update_interests(user_interests)
            new_user_interests = user.get_user_interests()
            new_vre_material = pd.concat([vre.db['paper_db'], vre.db['dataset_db']]).reset_index(drop=True)
            # Only recompute recommendations when their inputs actually changed.
            if (new_user_interests != old_user_interests
                    or len(old_vre_material) != len(new_vre_material)):
                rec.generate_recommendations(session['username'], new_user_interests, new_vre_material)

            dm.update(state)
            action = dm.next_action()
            response = rg.gen_response(action=action, utterance=state['modified_query'],
                                       state=dm.get_recent_state(),
                                       consec_history=dm.get_consec_history(),
                                       chitchat_history=dm.get_chitchat_history(),
                                       vrename=vre.name, username=session['username'],
                                       name=session['name'].split()[0])
            message = {"answer": response, "query": text, "cand": "candidate",
                       "history": dm.get_consec_history(), "modQuery": state['modified_query']}
            # For QA answers, store only the part after the separator in the history.
            if state['intent'] == "QA":
                split_response = response.split("_______ \n ")
                if len(split_response) > 1:
                    response = split_response[1]
            new_state = {'modified_query': response, 'intent': state['intent']}
            dm.update(new_state)

        reply = jsonify(message)
        # Mutate the captured session rather than re-indexing `users`, which may
        # have evicted or replaced this session while the request was running.
        session['activity'] = 0
        return reply
    except Exception as e:
        message = {"answer": str(e), "query": "", "cand": "candidate", "history": "", "modQuery": ""}
        return jsonify(message)


@app.route('/api/feedback', methods=['POST'])
def feedback():
    """Accept user feedback on a response.

    Writing to PostgreSQL is currently disabled (see the commented code), so the
    ``feedback`` payload is only printed and ``{"status": "done"}`` returned. On
    error the reply is ``{"status": <exception text>}``.
    """
    try:
        data = request.get_json().get("feedback")
        print(data)
        # Disabled persistence:
        # conn = psycopg2.connect(host="janet-pg", database=os.getenv("POSTGRES_DB"),
        #                         user=os.getenv("POSTGRES_USER"), password=os.getenv("POSTGRES_PASSWORD"))
        # cur = conn.cursor()
        # cur.execute('INSERT INTO feedback_experimental (query, history, janet_modified_query, '
        #             'is_modified_query_correct, user_modified_query, evidence_useful, response, '
        #             'preferred_response, response_length_feedback, response_fluency_feedback, '
        #             'response_truth_feedback, response_useful_feedback, response_time_feedback, '
        #             'response_intent) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)',
        #             (data['query'], data['history'], data['modQuery'], data['queryModCorrect'],
        #              data['correctQuery'], data['evidence'], data['janetResponse'],
        #              data['preferredResponse'], data['length'], data['fluency'],
        #              data['truthfulness'], data['usefulness'], data['speed'], data['intent']))
        # conn.commit()
        # cur.close()
        reply = jsonify({"status": "done"})
        return reply
    except Exception as e:
        return jsonify({"status": str(e)})


if __name__ == "__main__":
    warnings.filterwarnings("ignore")
    threading.Thread(target=clear_inactive, name='clear', daemon=True).start()
    # Disabled schema setup for the feedback table:
    # conn = psycopg2.connect(host="janet-pg", database=os.getenv("POSTGRES_DB"),
    #                         user=os.getenv("POSTGRES_USER"), password=os.getenv("POSTGRES_PASSWORD"))
    # cur = conn.cursor()
    # cur.execute('CREATE TABLE IF NOT EXISTS feedback_experimental (id serial PRIMARY KEY,'
    #             'query text NOT NULL,'
    #             'history text NOT NULL,'
    #             'janet_modified_query text NOT NULL,'
    #             'is_modified_query_correct text NOT NULL,'
    #             'user_modified_query text, evidence_useful text NOT NULL,'
    #             'response text NOT NULL,'
    #             'preferred_response text,'
    #             'response_length_feedback text NOT NULL,'
    #             'response_fluency_feedback text NOT NULL,'
    #             'response_truth_feedback text NOT NULL,'
    #             'response_useful_feedback text NOT NULL,'
    #             'response_time_feedback text NOT NULL,'
    #             'response_intent text NOT NULL);')
    # conn.commit()
    # cur.close()
    app.run(host='0.0.0.0')