"""Flask backend for the conversational assistant.

Routes:
    /predict  -- run one user utterance through NLU -> dialogue manager ->
                 response generator and return the answer plus history.
    /feedback -- append one feedback record to ``feedback.csv``.

When executed as a script, it does the following:
    * loads all models;
    * builds the shared NLU / VRE / User / Recommender / DM /
      ResponseGenerator objects;
    * starts the background maintenance threads;
    * serves on 127.0.0.1:4000.
"""
import os
import threading
import time
import traceback
import warnings

import faiss
import torch
from flask import Flask, render_template, request, jsonify
from flask_cors import CORS, cross_origin
import spacy
import spacy_transformers
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import pandas as pd
from sentence_transformers import SentenceTransformer

from User import User
from VRE import VRE
from NLU import NLU
from DM import DM
from Recommender import Recommender
from ResponseGenerator import ResponseGenerator

app = Flask(__name__)

# Allow cross-origin requests to the two API routes only from the frontend
# address configured in the environment.
# NOTE(review): if FRONTEND_URL_WITH_PORT is unset, ``url`` is None -- confirm
# how the installed flask_cors version treats ``origins=None``.
url = os.getenv("FRONTEND_URL_WITH_PORT")
cors = CORS(app, resources={r"/predict": {"origins": url}, r"/feedback": {"origins": url}})

# Seconds to wait between periodic VRE material refreshes.
VRE_UPDATE_INTERVAL = 1000
# Seconds to wait between interest-decay passes.
INTEREST_DECAY_INTERVAL = 180
# Seconds since the most recent dialogue state after which a recommendation is
# made; also the pause between recommendation checks.
RECOMMEND_INTERVAL = 1000

# Path of the CSV file that /feedback appends to.
FEEDBACK_CSV = 'feedback.csv'
# Flask's development server handles requests in multiple threads by default.
# This lock makes the "does the file exist? -> write header -> append" step
# atomic, so two concurrent first requests cannot both write a header row.
_feedback_lock = threading.Lock()


def get_response(text):
    """Return ``text`` unchanged, paired with the tag ``'candAnswer'``.

    This is a placeholder for getting a response from the assistant itself.
    Nothing in this module calls it.
    """
    return text, 'candAnswer'


def vre_fetch():
    """Periodically pull new VRE material and refresh the response generator.

    Runs forever in a background thread. Each cycle sleeps first, then updates
    the VRE and pushes its index and db into ``rg``. A failure in one cycle is
    printed and the loop keeps going, so one transient error does not stop all
    later refreshes.
    """
    while True:
        time.sleep(VRE_UPDATE_INTERVAL)
        print('getting new material')
        try:
            vre.get_vre_update()
            vre.index_periodic_update()
            rg.update_index(vre.get_index())
            rg.update_db(vre.get_db())
        except Exception:
            traceback.print_exc()


def user_interest_decay():
    """Call ``user.decay_interests()`` every ``INTEREST_DECAY_INTERVAL`` seconds.

    Runs forever in a background thread. A failure in one pass is printed and
    the loop continues.
    """
    while True:
        print("decaying interests after 3 minutes")
        time.sleep(INTEREST_DECAY_INTERVAL)
        try:
            user.decay_interests()
        except Exception:
            traceback.print_exc()


def recommend():
    """Print a recommendation when the conversation has been idle long enough.

    Runs forever in a background thread. Every ``RECOMMEND_INTERVAL`` seconds
    it compares now with the ``'time'`` of the dialogue manager's most recent
    state. If more than ``RECOMMEND_INTERVAL`` seconds have passed, it asks the
    recommender for a prompt and prints it when that prompt is non-empty.
    A failure in one check is printed and the loop continues.
    """
    while True:
        try:
            idle_seconds = time.time() - dm.get_recent_state()['time']
            if idle_seconds > RECOMMEND_INTERVAL:
                print("Making Recommendation: ")
                prompt = rec.make_recommendation(user.username)
                if prompt != "":
                    print(prompt)
        except Exception:
            traceback.print_exc()
        time.sleep(RECOMMEND_INTERVAL)


@app.route("/predict", methods=['POST'])
def predict():
    """Answer one user utterance.

    Expects a JSON body with a string ``message`` field. Processing steps:
        1. Run the message through NLU with the current utterance history.
        2. Feed every ``TOPIC`` entity into the user's interests.
        3. Advance the dialogue manager and choose the next action.
        4. Generate a response for that action.

    Returns:
        200 with JSON keys ``answer``, ``query``, ``cand``, ``history`` and
        ``modQuery``, or 400 when the body lacks a string ``message``.
    """
    payload = request.get_json(silent=True)
    text = payload.get("message") if isinstance(payload, dict) else None
    if not isinstance(text, str):
        return jsonify({"error": "request body must be JSON with a string 'message' field"}), 400

    state = nlu.process_utterance(text, dm.get_utt_history())

    user_interests = [entity['value'] for entity in state['entities'] if entity['entity'] == 'TOPIC']
    user.update_interests(user_interests)

    dm.update(state)
    action = dm.next_action()
    response = rg.gen_response(state['modified_prompt'], dm.get_recent_state(), dm.get_utt_history(), action)

    message = {
        "answer": response,
        "query": text,
        "cand": "candidate",
        "history": dm.get_utt_history(),
        "modQuery": state['modified_prompt'],
    }
    return jsonify(message)


@app.route('/feedback', methods=['POST'])
def feedback():
    """Append one feedback record to ``FEEDBACK_CSV``.

    Expects a JSON body with a ``feedback`` field. That value becomes one
    DataFrame row. The CSV header is written only when the file does not
    exist yet.

    Returns:
        200 with ``{"status": "done"}``, or 400 when ``feedback`` is missing.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict) or 'feedback' not in payload:
        return jsonify({"error": "request body must be JSON with a 'feedback' field"}), 400

    data = payload['feedback']
    print(data)
    df = pd.DataFrame([data])
    with _feedback_lock:
        file_exists = os.path.isfile(FEEDBACK_CSV)
        df.to_csv(FEEDBACK_CSV, mode='a', index=False, header=(not file_exists))
    return jsonify({"status": "done"})


if __name__ == "__main__":
    warnings.filterwarnings("ignore")

    # --- load NLU ---
    def_tokenizer = AutoTokenizer.from_pretrained("castorini/t5-base-canard")
    def_reference_resolver = AutoModelForSeq2SeqLM.from_pretrained("castorini/t5-base-canard")
    def_intent_classifier_dir = "./IntentClassifier/"
    def_entity_extractor = spacy.load("./EntityExtraction/BestModel")
    def_offense_filter_dir = "./OffensiveClassifier"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # transformers pipelines take a CUDA device index, or -1 for CPU.
    device_flag = torch.cuda.current_device() if torch.cuda.is_available() else -1
    nlu = NLU(device, device_flag, def_reference_resolver, def_tokenizer,
              def_intent_classifier_dir, def_offense_filter_dir, def_entity_extractor)

    # --- load retriever and generator ---
    def_retriever = SentenceTransformer('./BigRetriever/').to(device)
    def_generator = pipeline("text2text-generation", model="./generator", device=device_flag)

    # --- load VRE ---
    # NOTE(review): this credential is committed to source control. Supply it
    # via VRE_TOKEN, then remove the fallback literal and rotate the token.
    token = os.getenv("VRE_TOKEN", '2c1e8f88-461c-42c0-8cc1-b7660771c9a3-843339462')
    vre = VRE("assistedlab", token, def_retriever)
    vre.init()
    index = vre.get_index()
    db = vre.get_db()
    user = User("ahmed", token)

    # Build every shared object the workers and routes use before any
    # background thread starts, so no worker can reach an undefined global.
    rec = Recommender(def_retriever)
    dm = DM()
    rg = ResponseGenerator(index, db, def_generator, def_retriever)

    # Daemon threads: the infinite worker loops must not keep the process
    # alive after the Flask server stops.
    workers = (
        (vre_fetch, 'updatevre'),
        (user_interest_decay, 'decayinterest'),
        (recommend, 'recommend'),
    )
    for target, name in workers:
        threading.Thread(target=target, name=name, daemon=True).start()

    app.run(host='127.0.0.1', port=4000)