diff --git a/main.py b/main.py
index 761bfda..cacff11 100644
--- a/main.py
+++ b/main.py
@@ -1,10 +1,69 @@
-from flask import Flask, request, jsonify
-from flask_cors import CORS, cross_origin
 import os
+import warnings
+import faiss
+import torch
+from flask import Flask, render_template, request, jsonify
+from flask_cors import CORS, cross_origin
+import psycopg2
+import spacy
+import spacy_transformers
+import torch
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
+from User import User
+from VRE import VRE
+from NLU import NLU
+from DM import DM
+from Recommender import Recommender
+from ResponseGenerator import ResponseGenerator
+import pandas as pd
+import time
+import threading
+from sentence_transformers import SentenceTransformer
+
+
 
 app = Flask(__name__)
 url = os.getenv("FRONTEND_URL_WITH_PORT")
-cors = CORS(app, resources={r"/predict": {"origins": url}, r"/feedback": {"origins": url}})
+cors = CORS(app, resources={r"/api/predict": {"origins": url},
+                            r"/api/feedback": {"origins": url},
+                            r"/health": {"origins": "*"}
+                            })
+
+conn = psycopg2.connect(
+    host="janet-pg",
+    database=os.getenv("POSTGRES_DB"),
+    user=os.getenv("POSTGRES_USER"),
+    password=os.getenv("POSTGRES_PASSWORD"))
+
+cur = conn.cursor()
+
+
+def vre_fetch():
+    """Background loop that periodically refreshes the VRE material.
+
+    Every 1000 seconds it calls ``vre.get_vre_update()`` and
+    ``vre.index_periodic_update()``, then passes the current VRE index and
+    database to ``rg``. Uses the module-level ``vre`` and ``rg`` objects
+    assigned in the ``__main__`` block. Never returns.
+    """
+    while True:
+        time.sleep(1000)
+        print('getting new material')
+        vre.get_vre_update()
+        vre.index_periodic_update()
+        rg.update_index(vre.get_index())
+        rg.update_db(vre.get_db())
+
+def user_interest_decay():
+    """Background loop that calls ``user.decay_interests()`` every 180 seconds.
+
+    Uses the module-level ``user`` object assigned in the ``__main__``
+    block. Never returns.
+    """
+    while True:
+        print("decaying interests after 3 minutes")
+        time.sleep(180)
+        user.decay_interests()
 
 @app.route("/health", methods=['GET'])
 def check_health():
@@ -13,15 +72,140 @@ def check_health():
 @app.route("/api/predict", methods=['POST'])
 def predict():
+    """Answer one chat message posted as JSON under the ``message`` key.
+
+    The reply is built by the module-level ``nlu``, ``dm`` and ``rg``
+    objects and returned as a JSON object whose ``answer`` key holds the
+    generated response.
+    """
     text = request.get_json().get("message")
-    message = {"answer": "answer", "query": "text", "cand": "candidate", "history": "history", "modQuery": "modQuery"}
+    message = {}
+    if text == "":
+        state = {'help': True, 'inactive': False, 'modified_query':""}
+        dm.update(state)
+        action = dm.next_action()
+        response = rg.gen_response(action)
+        message = {"answer": response}
+    elif text == "":
+        # NOTE(review): same test as the branch above, so this branch is
+        # unreachable as written -- confirm the intended "inactive" sentinel.
+        state = {'help': False, 'inactive': True, 'modified_query':"recommed: "}
+        dm.update(state)
+        action = dm.next_action()
+        response = rg.gen_response(action, username=user.username)
+        message = {"answer": response}
+        new_state = {'modified_query': response}
+        dm.update(new_state)
+    else:
+        state = nlu.process_utterance(text, dm.get_consec_history(), dm.get_sep_history())
+        state['help'] = False
+        state['inactive'] = False
+        old_user_interests = user.get_user_interests()
+        old_vre_material = pd.concat([vre.db['paper_db'], vre.db['dataset_db']]).reset_index(drop=True)
+        user_interests = [entity['value'] for entity in state['entities']
+                          if entity['entity'] == 'TOPIC']
+        user.update_interests(user_interests)
+        new_user_interests = user.get_user_interests()
+        new_vre_material = pd.concat([vre.db['paper_db'], vre.db['dataset_db']]).reset_index(drop=True)
+        # new_user_interests holds the value returned by
+        # get_user_interests(), so it is compared directly (not called), and
+        # DataFrames are compared with .equals() since `!=` is element-wise.
+        if (new_user_interests != old_user_interests
+                or not old_vre_material.equals(new_vre_material)):
+            rec.generate_recommendations(new_user_interests, new_vre_material)
+        dm.update(state)
+        action = dm.next_action()
+        response = rg.gen_response(action=action, utterance=state['modified_query'], state=dm.get_recent_state(), consec_history=dm.get_consec_history())
+        message = {"answer": response, "query": text, "cand": "candidate", "history": dm.get_consec_history(), "modQuery": state['modified_query']}
+        new_state = {'modified_query': response}
+        dm.update(new_state)
     reply = jsonify(message)
     return reply
 
 @app.route('/api/feedback', methods = ['POST'])
 def feedback():
+    """Store one feedback record posted as JSON under the ``feedback`` key.
+
+    Inserts the record into the ``feedback_trial`` table and returns the
+    JSON object ``{"status": "done"}``.
+    """
     data = request.get_json().get("feedback")
+    print(data)
+
+    # Use a short-lived cursor per request instead of the shared module-level
+    # one; `with conn` commits on success and rolls back on error, so a
+    # failed INSERT does not leave the connection in an aborted transaction.
+    with conn, conn.cursor() as feedback_cur:
+        feedback_cur.execute('INSERT INTO feedback_trial (query, history, janet_modified_query, is_modified_query_correct, user_modified_query, response, preferred_response, response_length_feedback, response_fluency_feedback, response_truth_feedback, response_useful_feedback, response_time_feedback, response_intent) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)',
+            (data['query'], data['history'], data['modQuery'],
+             data['queryModCorrect'], data['correctQuery'],
+             data['janetResponse'], data['preferredResponse'], data['length'],
+             data['fluency'], data['truthfulness'], data['usefulness'],
+             data['speed'], data['intent'])
+        )
     reply = jsonify({"status": "done"})
     return reply
 
 if __name__ == "__main__":
+    warnings.filterwarnings("ignore")
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    device_flag = torch.cuda.current_device() if torch.cuda.is_available() else -1
+
+    query_rewriter = pipeline("text2text-generation", model="castorini/t5-base-canard")
+    intent_classifier = pipeline("sentiment-analysis", model='/models/intent_classifier', device=device_flag)
+    entity_extractor = spacy.load("/models/entity_extractor")
+    offensive_classifier = pipeline("sentiment-analysis", model='/models/offensive_classifier', device=device_flag)
+    ambig_classifier = pipeline("sentiment-analysis", model='/models/ambig_classifier', device=device_flag)
+    coref_resolver = spacy.load("en_coreference_web_trf")
+
+    nlu = NLU(query_rewriter, coref_resolver, intent_classifier, offensive_classifier, entity_extractor, ambig_classifier)
+
+    #load retriever and generator
+    retriever = SentenceTransformer('/models/BigRetriever/').to(device)
+    qa_generator = pipeline("text2text-generation", model="/models/train_qa", device=device_flag)
+    summ_generator = pipeline("text2text-generation", model="/models/train_summ", device=device_flag)
+    chat_generator = pipeline("text2text-generation", model="/models/train_chat", device=device_flag)
+    amb_generator = pipeline("text2text-generation", model="/models/train_amb_gen", device=device_flag)
+    generators = {'qa': qa_generator,
+                  'chat': chat_generator,
+                  'amb': amb_generator,
+                  'summ': summ_generator}
+
+    #load vre
+    # NOTE(review): hard-coded credential -- move it to an environment
+    # variable or secret store and rotate this value.
+    token = '2c1e8f88-461c-42c0-8cc1-b7660771c9a3-843339462'
+    vre = VRE("assistedlab", token, retriever)
+    vre.init()
+    index = vre.get_index()
+    db = vre.get_db()
+    user = User("ahmed", token)
+
+    rec = Recommender(retriever)
+
+    dm = DM()
+
+    rg = ResponseGenerator(index,db, rec, generators, retriever)
+
+    # Start the background loops only once every object they use exists.
+    # daemon=True lets the process exit even though the loops never return.
+    threading.Thread(target=vre_fetch, name='updatevre', daemon=True).start()
+    threading.Thread(target=user_interest_decay, name='decayinterest', daemon=True).start()
+
+    cur.execute('CREATE TABLE IF NOT EXISTS feedback_trial (id serial PRIMARY KEY,'
+                'query text NOT NULL,'
+                'history text NOT NULL,'
+                'janet_modified_query text NOT NULL,'
+                'is_modified_query_correct text NOT NULL,'
+                'user_modified_query text NOT NULL,'
+                'response text NOT NULL,'
+                'preferred_response text,'
+                'response_length_feedback text NOT NULL,'
+                'response_fluency_feedback text NOT NULL,'
+                'response_truth_feedback text NOT NULL,'
+                'response_useful_feedback text NOT NULL,'
+                'response_time_feedback text NOT NULL,'
+                'response_intent text NOT NULL);'
+                )
+    conn.commit()
+
     app.run(host='0.0.0.0', port=4000)
diff --git a/main_simple.py b/main_simple.py
new file mode 100644
index 0000000..761bfda
--- /dev/null
+++ b/main_simple.py
@@ -0,0 +1,28 @@
+from flask import Flask, request, jsonify
+from flask_cors import CORS, cross_origin
+import os
+
+app = Flask(__name__)
+url = os.getenv("FRONTEND_URL_WITH_PORT")
+# Resource patterns must match the /api/... routes below for CORS to apply.
+cors = CORS(app, resources={r"/api/predict": {"origins": url}, r"/api/feedback": {"origins": url}})
+
+@app.route("/health", methods=['GET'])
+def check_health():
+    return "Success", 200
+
+@app.route("/api/predict", methods=['POST'])
+def predict():
+    text = request.get_json().get("message")
+    message = {"answer": "answer", "query": "text", "cand": "candidate", "history": "history", "modQuery": "modQuery"}
+    reply = jsonify(message)
+    return reply
+
+@app.route('/api/feedback', methods = ['POST'])
+def feedback():
+    data = request.get_json().get("feedback")
+    reply = jsonify({"status": "done"})
+    return reply
+
+if __name__ == "__main__":
+    app.run(host='0.0.0.0', port=4000)