Janet/JanetBackEnd/main.py

134 lines
4.4 KiB
Python

import os
import warnings
import faiss
import torch
from flask import Flask, render_template, request, jsonify
from flask_cors import CORS, cross_origin
import spacy
import spacy_transformers
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from User import User
from VRE import VRE
from NLU import NLU
from DM import DM
from Recommender import Recommender
from ResponseGenerator import ResponseGenerator
import pandas as pd
import time
import threading
from sentence_transformers import SentenceTransformer
app = Flask(__name__)

# Only the frontend origin configured in FRONTEND_URL_WITH_PORT may call the
# API endpoints below.
# NOTE(review): when the variable is unset, url is None and is passed to
# flask_cors unchanged — confirm that is the intended behavior.
url = os.getenv("FRONTEND_URL_WITH_PORT")
_cors_resources = {route: {"origins": url} for route in (r"/predict", r"/feedback")}
cors = CORS(app, resources=_cors_resources)
def get_response(text):
    """Echo ``text`` back paired with the fixed candidate label.

    Placeholder for a response produced by Janet itself.

    Returns:
        A 2-tuple ``(text, 'candAnswer')``.
    """
    candidate_label = 'candAnswer'
    return (text, candidate_label)
def vre_fetch(interval=1000):
    """Periodically refresh VRE material and push it into the response generator.

    Never returns; intended as the target of a background thread. Each cycle
    sleeps ``interval`` seconds, then asks the module-level ``vre`` for an
    update, refreshes its index, and hands the new index and db to the
    module-level ``rg``.

    Args:
        interval: Seconds to sleep before each refresh. Defaults to 1000,
            the previously hard-coded value.
    """
    import traceback

    while True:
        time.sleep(interval)
        print('getting new material')
        try:
            vre.get_vre_update()
            vre.index_periodic_update()
            rg.update_index(vre.get_index())
            rg.update_db(vre.get_db())
        except Exception:
            # Previously any error here ended the thread, silently stopping all
            # future refreshes. Log it and try again on the next cycle instead.
            traceback.print_exc()
def user_interest_decay():
    """Background loop: decay the module-level ``user``'s interests every 180 s.

    Never returns; meant to be run as a thread target.
    """
    decay_period_s = 180
    while True:
        print("decaying interests after 3 minutes")
        time.sleep(decay_period_s)
        user.decay_interests()
def recommend():
    """Background loop: print a recommendation once the user has been idle.

    Every 1000 s, compare the current time with the ``'time'`` entry of the
    dialogue manager's most recent state. If more than 1000 s have elapsed,
    ask the recommender for a prompt for the current user and print it when it
    is non-empty. The recommendation is only printed to stdout.

    NOTE(review): assumes ``dm.get_recent_state()['time']`` is a
    ``time.time()``-style timestamp and that a state exists before the first
    utterance — confirm.
    """
    idle_threshold_s = 1000
    poll_interval_s = 1000
    while True:
        idle_for = time.time() - dm.get_recent_state()['time']
        if idle_for > idle_threshold_s:
            print("Making Recommendation: ")
            prompt = rec.make_recommendation(user.username)
            if prompt != "":
                print(prompt)
        time.sleep(poll_interval_s)
@app.route("/predict", methods=['POST'])
def predict():
    """Answer one user utterance.

    Expects a JSON object body with a string ``message``. The steps are:
    run the utterance through the NLU (with the dialogue history), record
    TOPIC entities as user interests, update the dialogue manager, pick the
    next action, and generate the response.

    Returns:
        JSON with keys ``answer``, ``query``, ``cand``, ``history`` and
        ``modQuery``. On a malformed request (body not a JSON object, or
        ``message`` missing or not a string), returns HTTP 400 with
        ``{"error": ...}``.
    """
    # silent=True: a missing/invalid JSON body yields None instead of raising,
    # so it can be answered with a clear 400 below.
    payload = request.get_json(silent=True)
    text = payload.get("message") if isinstance(payload, dict) else None
    if not isinstance(text, str):
        # Untrusted input: never hand None or non-string values to the NLU.
        return jsonify({"error": "request body must be a JSON object with a string 'message'"}), 400

    state = nlu.process_utterance(text, dm.get_utt_history())
    user_interests = [entity['value'] for entity in state['entities']
                      if entity['entity'] == 'TOPIC']
    user.update_interests(user_interests)
    dm.update(state)
    action = dm.next_action()
    response = rg.gen_response(state['modified_prompt'], dm.get_recent_state(), dm.get_utt_history(), action)
    message = {"answer": response, "query": text, "cand": "candidate", "history": dm.get_utt_history(), "modQuery": state['modified_prompt']}
    return jsonify(message)
@app.route('/feedback', methods = ['POST'])
def feedback():
    """Append one feedback record to ``feedback.csv`` in the working directory.

    Expects a JSON object body with a ``feedback`` field. The record becomes
    the single row of a pandas DataFrame, which is appended to the CSV. The
    header is written only when the file does not exist yet.

    NOTE(review): presumably ``feedback`` is a dict of field -> value (its
    keys become the columns). Rows are appended without re-checking the
    existing header, so records whose keys differ or arrive in a different
    order will land in misaligned columns — confirm clients send a fixed
    schema.

    Returns:
        ``{"status": "done"}``. On a malformed request (body not a JSON
        object, or ``feedback`` missing), returns HTTP 400 with
        ``{"error": ...}``.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict) or 'feedback' not in payload:
        return jsonify({"error": "request body must be a JSON object with a 'feedback' field"}), 400
    data = payload['feedback']
    print(data)
    df = pd.DataFrame([data])
    file_exists = os.path.isfile('feedback.csv')
    df.to_csv('feedback.csv', mode='a', index=False, header=(not file_exists))
    return jsonify({"status": "done"})
if __name__ == "__main__":
    warnings.filterwarnings("ignore")

    # --- NLU models ---
    def_tokenizer = AutoTokenizer.from_pretrained("castorini/t5-base-canard")
    def_reference_resolver = AutoModelForSeq2SeqLM.from_pretrained("castorini/t5-base-canard")
    def_intent_classifier_dir = "./IntentClassifier/"
    def_entity_extractor = spacy.load("./EntityExtraction/BestModel")
    def_offense_filter_dir = "./OffensiveClassifier"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # transformers pipelines take an integer device index; -1 selects the CPU.
    device_flag = torch.cuda.current_device() if torch.cuda.is_available() else -1
    nlu = NLU(device, device_flag, def_reference_resolver, def_tokenizer,
              def_intent_classifier_dir, def_offense_filter_dir, def_entity_extractor)

    # --- Retriever and generator ---
    def_retriever = SentenceTransformer('./BigRetriever/').to(device)
    def_generator = pipeline("text2text-generation", model="./generator", device=device_flag)

    # --- VRE and user ---
    # NOTE(security): this credential is committed to source control. Prefer
    # supplying it via the VRE_TOKEN environment variable, then rotate it and
    # remove the literal fallback.
    token = os.getenv("VRE_TOKEN", '2c1e8f88-461c-42c0-8cc1-b7660771c9a3-843339462')
    vre = VRE("assistedlab", token, def_retriever)
    vre.init()
    index = vre.get_index()
    db = vre.get_db()
    user = User("ahmed", token)

    # Build every global the background loops reference (vre, user, rec, dm,
    # rg) before starting any of them, so no thread can see a missing name.
    rec = Recommender(def_retriever)
    dm = DM()
    rg = ResponseGenerator(index, db, def_generator, def_retriever)

    # daemon=True: these loops never return. Non-daemon threads would keep
    # the process alive after the Flask server stops (e.g. on Ctrl+C).
    threading.Thread(target=vre_fetch, name='updatevre', daemon=True).start()
    threading.Thread(target=user_interest_decay, name='decayinterest', daemon=True).start()
    threading.Thread(target=recommend, name='recommend', daemon=True).start()

    app.run(host='127.0.0.1', port=4000)