diff --git a/main.py b/main.py
index cacff11..761bfda 100644
--- a/main.py
+++ b/main.py
@@ -1,63 +1,10 @@
-import os
-import warnings
-import faiss
-import torch
-from flask import Flask, render_template, request, jsonify
+from flask import Flask, request, jsonify
 from flask_cors import CORS, cross_origin
-import psycopg2
-import spacy
-import spacy_transformers
-import torch
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
-from User import User
-from VRE import VRE
-from NLU import NLU
-from DM import DM
-from Recommender import Recommender
-from ResponseGenerator import ResponseGenerator
-import pandas as pd
-import time
-import threading
-from sentence_transformers import SentenceTransformer
-
-
+import os
 app = Flask(__name__)
 url = os.getenv("FRONTEND_URL_WITH_PORT")
 
-cors = CORS(app, resources={r"/api/predict": {"origins": url},
-                            r"/api/feedback": {"origins": url},
-                            r"/health": {"origins": "*"}
-                            })
-
-conn = psycopg2.connect(
-    host="janet-pg",
-    database=os.getenv("POSTGRES_DB"),
-    user=os.getenv("POSTGRES_USER"),
-    password=os.getenv("POSTGRES_PASSWORD"))
-
-"""
-conn = psycopg2.connect(host="https://janet-app-db.d4science.org",
-                        database="janet",
-                        user="janet_user",
-                        password="2fb5e81fec5a2d906a04")
-"""
-cur = conn.cursor()
-
-
-def vre_fetch():
-    while True:
-        time.sleep(1000)
-        print('getting new material')
-        vre.get_vre_update()
-        vre.index_periodic_update()
-        rg.update_index(vre.get_index())
-        rg.update_db(vre.get_db())
-
-def user_interest_decay():
-    while True:
-        print("decaying interests after 3 minutes")
-        time.sleep(180)
-        user.decay_interests()
+cors = CORS(app, resources={r"/predict": {"origins": url}, r"/feedback": {"origins": url}})
 
 @app.route("/health", methods=['GET'])
 def check_health():
@@ -66,121 +13,15 @@ def check_health():
 @app.route("/api/predict", methods=['POST'])
 def predict():
     text = request.get_json().get("message")
-    message = {}
-    if text == "":
-        state = {'help': True, 'inactive': False, 'modified_query':""}
-        dm.update(state)
-        action = dm.next_action()
-        response = rg.gen_response(action)
-        message = {"answer": response}
-    elif text == "":
-        state = {'help': False, 'inactive': True, 'modified_query':"recommed: "}
-        dm.update(state)
-        action = dm.next_action()
-        response = rg.gen_response(action, username=user.username)
-        message = {"answer": response}
-        new_state = {'modified_query': response}
-        dm.update(new_state)
-    else:
-        state = nlu.process_utterance(text, dm.get_consec_history(), dm.get_sep_history())
-        state['help'] = False
-        state['inactive'] = False
-        old_user_interests = user.get_user_interests()
-        old_vre_material = pd.concat([vre.db['paper_db'], vre.db['dataset_db']]).reset_index(drop=True)
-        user_interests = []
-        for entity in state['entities']:
-            if entity['entity'] == 'TOPIC':
-                user_interests.append(entity['value'])
-        user.update_interests(user_interests)
-        new_user_interests = user.get_user_interests()
-        new_vre_material = pd.concat([vre.db['paper_db'], vre.db['dataset_db']]).reset_index(drop=True)
-        if (new_user_interests() != old_user_interests or old_vre_material != new_vre_material):
-            rec.generate_recommendations(new_user_interests, new_vre_material)
-        dm.update(state)
-        action = dm.next_action()
-        response = rg.gen_response(action=action, utterance=state['modified_query'], state=dm.get_recent_state(), consec_history=dm.get_consec_history())
-        message = {"answer": response, "query": text, "cand": "candidate", "history": dm.get_consec_history(), "modQuery": state['modified_query']}
-        new_state = {'modified_query': response}
-        dm.update(new_state)
+    message = {"answer": "answer", "query": "text", "cand": "candidate", "history": "history", "modQuery": "modQuery"}
     reply = jsonify(message)
     return reply
 
 @app.route('/api/feedback', methods = ['POST'])
 def feedback():
     data = request.get_json().get("feedback")
-    print(data)
-
-    cur.execute('INSERT INTO feedback_trial (query, history, janet_modified_query, is_modified_query_correct, user_modified_query, response, preferred_response, response_length_feedback, response_fluency_feedback, response_truth_feedback, response_useful_feedback, response_time_feedback, response_intent) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)',
-                (data['query'], data['history'], data['modQuery'],
-                 data['queryModCorrect'], data['correctQuery'],
-                 data['janetResponse'], data['preferredResponse'], data['length'],
-                 data['fluency'], data['truthfulness'], data['usefulness'],
-                 data['speed'], data['intent'])
-                )
-    conn.commit()
-
     reply = jsonify({"status": "done"})
     return reply
 
 if __name__ == "__main__":
-    warnings.filterwarnings("ignore")
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    device_flag = torch.cuda.current_device() if torch.cuda.is_available() else -1
-
-    query_rewriter = pipeline("text2text-generation", model="castorini/t5-base-canard")
-    intent_classifier = pipeline("sentiment-analysis", model='/models/intent_classifier', device=device_flag)
-    entity_extractor = spacy.load("/models/entity_extractor")
-    offensive_classifier = pipeline("sentiment-analysis", model='/models/offensive_classifier', device=device_flag)
-    ambig_classifier = pipeline("sentiment-analysis", model='/models/ambig_classifier', device=device_flag)
-    coref_resolver = spacy.load("en_coreference_web_trf")
-
-    nlu = NLU(query_rewriter, coref_resolver, intent_classifier, offensive_classifier, entity_extractor, ambig_classifier)
-
-    #load retriever and generator
-    retriever = SentenceTransformer('/models/BigRetriever/').to(device)
-    qa_generator = pipeline("text2text-generation", model="/models/train_qa", device=device_flag)
-    summ_generator = pipeline("text2text-generation", model="/models/train_summ", device=device_flag)
-    chat_generator = pipeline("text2text-generation", model="/models/train_chat", device=device_flag)
-    amb_generator = pipeline("text2text-generation", model="/models/train_amb_gen", device=device_flag)
-    generators = {'qa': qa_generator,
-                  'chat': chat_generator,
-                  'amb': amb_generator,
-                  'summ': summ_generator}
-
-    #load vre
-    token = '2c1e8f88-461c-42c0-8cc1-b7660771c9a3-843339462'
-    vre = VRE("assistedlab", token, retriever)
-    vre.init()
-    index = vre.get_index()
-    db = vre.get_db()
-    user = User("ahmed", token)
-
-    threading.Thread(target=vre_fetch, name='updatevre').start()
-
-    threading.Thread(target=user_interest_decay, name='decayinterest').start()
-
-    rec = Recommender(retriever)
-
-    dm = DM()
-
-    rg = ResponseGenerator(index,db, rec, generators, retriever)
-
-
-    cur.execute('CREATE TABLE IF NOT EXISTS feedback_trial (id serial PRIMARY KEY,'
-                'query text NOT NULL,'
-                'history text NOT NULL,'
-                'janet_modified_query text NOT NULL,'
-                'is_modified_query_correct text NOT NULL,'
-                'user_modified_query text NOT NULL,'
-                'response text NOT NULL,'
-                'preferred_response text,'
-                'response_length_feedback text NOT NULL,'
-                'response_fluency_feedback text NOT NULL,'
-                'response_truth_feedback text NOT NULL,'
-                'response_useful_feedback text NOT NULL,'
-                'response_time_feedback text NOT NULL,'
-                'response_intent text NOT NULL);'
-                )
-
-    conn.commit()
     app.run(host='0.0.0.0', port=4000)
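
A minimal smoke test of the stubbed endpoints left by this change could look like the sketch below. It is not part of the diff: it assumes the simplified service is reachable at localhost on port 4000 (as configured in app.run), uses the third-party requests package, and the payload values are placeholders.

    # Hypothetical smoke test for the simplified main.py (requires the `requests` package).
    import requests

    BASE = "http://localhost:4000"

    # GET /health should respond once the app is up.
    print(requests.get(f"{BASE}/health").status_code)

    # POST /api/predict expects a JSON body with a "message" field and
    # currently returns the hard-coded stub dictionary.
    print(requests.post(f"{BASE}/api/predict", json={"message": "hello"}).json())

    # POST /api/feedback expects a JSON body with a "feedback" field and
    # returns {"status": "done"}.
    print(requests.post(f"{BASE}/api/feedback", json={"feedback": {}}).json())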