From 18560aa636988be81d4f368463bca98c39c8cb62 Mon Sep 17 00:00:00 2001 From: ahmed531998 Date: Thu, 6 Apr 2023 00:37:20 +0200 Subject: [PATCH] cors_fix --- main.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/main.py b/main.py index 34d76b6..b034135 100644 --- a/main.py +++ b/main.py @@ -24,7 +24,7 @@ from sentence_transformers import SentenceTransformer app = Flask(__name__) url = os.getenv("FRONTEND_URL_WITH_PORT") -cors = CORS(app, resources={r"/predict": {"origins": url}, r"/feedback": {"origins": url}}) +cors = CORS(app, resources={r"/predict": {"origins": "*"}, r"/feedback": {"origins": "*"}}) conn = psycopg2.connect( host="janet-pg", @@ -113,20 +113,20 @@ if __name__ == "__main__": device_flag = torch.cuda.current_device() if torch.cuda.is_available() else -1 query_rewriter = pipeline("text2text-generation", model="castorini/t5-base-canard") - intent_classifier = pipeline("sentiment-analysis", model='/app/models/intent_classifier', device=device_flag) - entity_extractor = spacy.load("/app/models/entity_extractor") - offensive_classifier = pipeline("sentiment-analysis", model='/app/models/offensive_classifier', device=device_flag) - ambig_classifier = pipeline("sentiment-analysis", model='/app/models/ambig_classifier', device=device_flag) + intent_classifier = pipeline("sentiment-analysis", model='/models/intent_classifier', device=device_flag) + entity_extractor = spacy.load("/models/entity_extractor") + offensive_classifier = pipeline("sentiment-analysis", model='/models/offensive_classifier', device=device_flag) + ambig_classifier = pipeline("sentiment-analysis", model='/models/ambig_classifier', device=device_flag) coref_resolver = spacy.load("en_coreference_web_trf") nlu = NLU(query_rewriter, coref_resolver, intent_classifier, offensive_classifier, entity_extractor, ambig_classifier) #load retriever and generator - retriever = SentenceTransformer('/app/models/BigRetriever/').to(device) - qa_generator = 
pipeline("text2text-generation", model="/app/models/train_qa", device=device_flag) - summ_generator = pipeline("text2text-generation", model="/app/models/train_summ", device=device_flag) - chat_generator = pipeline("text2text-generation", model="/app/models/train_chat", device=device_flag) - amb_generator = pipeline("text2text-generation", model="/app/models/train_amb_gen", device=device_flag) + retriever = SentenceTransformer('/models/BigRetriever/').to(device) + qa_generator = pipeline("text2text-generation", model="/models/train_qa", device=device_flag) + summ_generator = pipeline("text2text-generation", model="/models/train_summ", device=device_flag) + chat_generator = pipeline("text2text-generation", model="/models/train_chat", device=device_flag) + amb_generator = pipeline("text2text-generation", model="/models/train_amb_gen", device=device_flag) generators = {'qa': qa_generator, 'chat': chat_generator, 'amb': amb_generator,