134 lines
4.4 KiB
Python
134 lines
4.4 KiB
Python
import os
|
|
import warnings
|
|
|
|
import faiss
|
|
import torch
|
|
from flask import Flask, render_template, request, jsonify
|
|
from flask_cors import CORS, cross_origin
|
|
import spacy
|
|
import spacy_transformers
|
|
import torch
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
|
|
|
from User import User
|
|
from VRE import VRE
|
|
from NLU import NLU
|
|
from DM import DM
|
|
from Recommender import Recommender
|
|
from ResponseGenerator import ResponseGenerator
|
|
|
|
|
|
import pandas as pd
|
|
import time
|
|
import threading
|
|
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
app = Flask(__name__)

# Cross-origin requests to the two API routes are accepted only from the
# frontend origin configured in FRONTEND_URL_WITH_PORT (presumably of the form
# scheme://host:port — confirm with the deployment config). For quick local
# testing the origin can temporarily be widened to "*".
# NOTE(review): if the variable is unset, url is None; how flask_cors treats
# origins=None has not been verified here.
url = os.getenv("FRONTEND_URL_WITH_PORT")
cors = CORS(
    app,
    resources={route: {"origins": url} for route in (r"/predict", r"/feedback")},
)
|
|
|
|
def get_response(text):
    """Placeholder for getting a reply from Janet itself.

    Returns the input text unchanged, paired with the fixed tag
    ``'candAnswer'``. It is not referenced elsewhere in this file as shown.
    """
    candidate_tag = 'candAnswer'
    return (text, candidate_tag)
|
|
|
|
def vre_fetch(interval=1000):
    """Background loop that periodically refreshes VRE material.

    Every *interval* seconds it asks the module-level ``vre`` for an update and
    refreshes its index. It then hands the new index and db to the module-level
    response generator ``rg``. The loop never returns.

    Any exception raised by one refresh cycle is printed and the cycle is
    skipped. Previously a single failure ended the thread, and with it all
    future refreshes, silently.

    Args:
        interval: Seconds to wait between refreshes (default 1000, the
            original hard-coded value).
    """
    import traceback

    while True:
        time.sleep(interval)
        print('getting new material')
        try:
            vre.get_vre_update()
            vre.index_periodic_update()
            rg.update_index(vre.get_index())
            rg.update_db(vre.get_db())
        except Exception:
            # Keep the thread alive; retry on the next cycle.
            print('VRE update failed; retrying on the next cycle')
            traceback.print_exc()
|
|
|
|
def user_interest_decay():
    """Background loop that decays the current user's interests.

    Each cycle announces itself, waits 180 seconds (3 minutes), then calls
    ``decay_interests()`` on the module-level ``user``. The loop never returns.
    """
    period_seconds = 180
    announcement = "decaying interests after 3 minutes"
    while 1:
        print(announcement)
        time.sleep(period_seconds)
        user.decay_interests()
|
|
|
|
def recommend(idle_seconds=1000, poll_seconds=1000):
    """Background loop that prints a recommendation once the user is idle.

    Each cycle reads the ``'time'`` entry of the dialogue manager's most recent
    state (module-level ``dm``). If more than *idle_seconds* have passed since
    then, it asks the module-level ``rec`` for a recommendation for
    ``user.username`` and prints it when it is non-empty. It then sleeps
    *poll_seconds*. The loop never returns.

    Args:
        idle_seconds: Inactivity (seconds) after which to recommend
            (default 1000, the original hard-coded value).
        poll_seconds: Delay (seconds) between checks (default 1000).
    """
    import traceback

    while True:
        try:
            last_activity = dm.get_recent_state()['time']
        except (KeyError, IndexError, TypeError):
            # This loop starts before any /predict call, so there may be no
            # recent state yet. Skip this cycle instead of killing the thread.
            # NOTE(review): the exact "empty" shape returned by
            # DM.get_recent_state() is assumed — confirm against DM.
            last_activity = None

        if last_activity is not None and time.time() - last_activity > idle_seconds:
            print("Making Recommendation: ")
            try:
                prompt = rec.make_recommendation(user.username)
            except Exception:
                # A failed recommendation only skips this cycle.
                traceback.print_exc()
            else:
                if prompt != "":
                    print(prompt)

        time.sleep(poll_seconds)
|
|
|
|
|
|
@app.route("/predict", methods=['POST'])
def predict():
    """Handle one chat turn.

    Expects a JSON body with a string ``"message"``. The message is run through
    the NLU with the current utterance history. Any ``TOPIC`` entities become
    user interests. The dialogue manager is updated and its next action chosen,
    and the response generator produces the answer.

    Returns:
        JSON with ``answer``, ``query`` (the raw message), ``cand``,
        ``history`` and ``modQuery`` (the NLU's modified prompt). A 400 JSON
        error is returned if the body is not a JSON object with a string
        ``"message"``.
    """
    # silent=True: a missing or malformed JSON body yields None instead of
    # raising, so a clean 400 can be returned below.
    payload = request.get_json(silent=True)
    text = payload.get("message") if isinstance(payload, dict) else None
    if not isinstance(text, str):
        return jsonify({"error": "request body must be a JSON object with a string 'message'"}), 400

    state = nlu.process_utterance(text, dm.get_utt_history())

    user_interests = [
        entity['value'] for entity in state['entities'] if entity['entity'] == 'TOPIC'
    ]
    user.update_interests(user_interests)

    dm.update(state)
    action = dm.next_action()
    response = rg.gen_response(state['modified_prompt'], dm.get_recent_state(), dm.get_utt_history(), action)

    message = {
        "answer": response,
        "query": text,
        "cand": "candidate",
        "history": dm.get_utt_history(),
        "modQuery": state['modified_prompt'],
    }
    return jsonify(message)
|
|
|
|
# File that collects user feedback, one CSV row per submission.
FEEDBACK_CSV = 'feedback.csv'
# Serializes the "does the file exist?" check with the append. Flask's dev
# server handles requests in threads, so two first submissions could otherwise
# both write a header row.
_feedback_lock = threading.Lock()


@app.route('/feedback', methods=['POST'])
def feedback():
    """Append one feedback submission to ``feedback.csv``.

    Expects a JSON body with a ``"feedback"`` entry. That value is turned into
    a one-row DataFrame and appended to the CSV. The header is written only
    when the file does not exist yet.

    Returns:
        ``{"status": "done"}`` on success, or a 400 JSON error if the body
        is not a JSON object containing ``"feedback"``.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict) or 'feedback' not in payload:
        return jsonify({"status": "error", "message": "request body must be a JSON object with 'feedback'"}), 400

    data = payload['feedback']
    print(data)

    # NOTE(review): columns come from the submitted object's keys. If clients
    # send keys in a different order or set, rows will not line up with the
    # existing header.
    df = pd.DataFrame([data])

    with _feedback_lock:
        file_exists = os.path.isfile(FEEDBACK_CSV)
        df.to_csv(FEEDBACK_CSV, mode='a', index=False, header=not file_exists)

    return jsonify({"status": "done"})
|
|
|
|
if __name__ == "__main__":

    warnings.filterwarnings("ignore")

    # --- NLU components ---
    def_tokenizer = AutoTokenizer.from_pretrained("castorini/t5-base-canard")
    def_reference_resolver = AutoModelForSeq2SeqLM.from_pretrained("castorini/t5-base-canard")
    def_intent_classifier_dir = "./IntentClassifier/"
    def_entity_extractor = spacy.load("./EntityExtraction/BestModel")
    def_offense_filter_dir = "./OffensiveClassifier"

    # Torch modules take a device string. Hugging Face pipelines take a device
    # index, where -1 means CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    device_flag = torch.cuda.current_device() if torch.cuda.is_available() else -1

    nlu = NLU(device, device_flag, def_reference_resolver, def_tokenizer,
              def_intent_classifier_dir, def_offense_filter_dir, def_entity_extractor)

    # --- Retriever and generator ---
    def_retriever = SentenceTransformer('./BigRetriever/').to(device)
    def_generator = pipeline("text2text-generation", model="./generator", device=device_flag)

    # --- VRE connection and user ---
    # Credentials should not live in source control, so prefer the VRE_TOKEN
    # environment variable. The literal fallback keeps existing setups working.
    # TODO(security): remove the fallback and rotate this token.
    token = os.getenv("VRE_TOKEN", '2c1e8f88-461c-42c0-8cc1-b7660771c9a3-843339462')
    vre = VRE("assistedlab", token, def_retriever)
    vre.init()
    index = vre.get_index()
    db = vre.get_db()
    user = User("ahmed", token)

    # Create every global the background loops use (vre_fetch -> rg,
    # recommend -> dm/rec/user) before any of them starts. This way no thread
    # can reach a name that does not exist yet.
    rec = Recommender(def_retriever)
    dm = DM()
    rg = ResponseGenerator(index, db, def_generator, def_retriever)

    # These loops never return. daemon=True lets the process exit when the
    # Flask server stops (e.g. Ctrl+C) instead of hanging on them.
    for target, thread_name in (
        (vre_fetch, 'updatevre'),
        (user_interest_decay, 'decayinterest'),
        (recommend, 'recommend'),
    ):
        threading.Thread(target=target, name=thread_name, daemon=True).start()

    app.run(host='127.0.0.1', port=4000)
|