Compare commits
No commits in common. "e9a9afbcf8cdc73e6a450ed393d2a2b73e7182ae" and "af0bd24538df865cbea2dea53eef289d9d1164e4" have entirely different histories.
e9a9afbcf8...af0bd24538
@@ -1,2 +0,0 @@
-.git
-__pycache__

@@ -1,3 +0,0 @@
-janet.pdf
-__pycache__/
-ahmed.ibrahim39699_interests.json
@@ -2,12 +2,12 @@ FROM python:3.8
 
 WORKDIR /backend_janet
 
-COPY requirements_simple.txt .
+COPY requirements_main.txt .
 
-RUN pip install -r requirements_simple.txt
+RUN pip install -r requirements_main.txt
 
 RUN rm -fr /root/.cache/*
 
 COPY . .
 
-ENTRYPOINT ["python", "main_simple.py"]
+ENTRYPOINT ["python", "main.py"]
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1 +0,0 @@
-{"interest":{"0":"chatbots?","1":"list commands","2":"chatbots"},"frequency":{"0":2,"1":1,"2":1}}
main.py (112 changes)
@@ -34,43 +34,19 @@ cors = CORS(app, resources={r"/api/predict": {"origins": url},
 users = {}
 alive = "alive"
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
-device_flag = torch.cuda.current_device() if torch.cuda.is_available() else -1
-
-query_rewriter = pipeline("text2text-generation", model="castorini/t5-base-canard")
-intent_classifier = pipeline("sentiment-analysis", model='/models/intent_classifier', device=device_flag)
-entity_extractor = spacy.load("/models/entity_extractor")
-offensive_classifier = pipeline("sentiment-analysis", model='/models/offensive_classifier', device=device_flag)
-ambig_classifier = pipeline("sentiment-analysis", model='/models/ambig_classifier', device=device_flag)
-coref_resolver = spacy.load("en_coreference_web_trf")
-
-nlu = NLU(query_rewriter, coref_resolver, intent_classifier, offensive_classifier, entity_extractor, ambig_classifier)
-
-#load retriever and generator
-retriever = SentenceTransformer('/models/retriever/').to(device)
-qa_generator = pipeline("text2text-generation", model="/models/train_qa", device=device_flag)
-summ_generator = pipeline("text2text-generation", model="/models/train_summ", device=device_flag)
-chat_generator = pipeline("text2text-generation", model="/models/train_chat", device=device_flag)
-amb_generator = pipeline("text2text-generation", model="/models/train_amb_gen", device=device_flag)
-generators = {'qa': qa_generator,
-              'chat': chat_generator,
-              'amb': amb_generator,
-              'summ': summ_generator}
-rec = Recommender(retriever)
-
-def vre_fetch(token):
+def vre_fetch():
     while True:
         try:
             time.sleep(1000)
             print('getting new material')
-            users[token]['vre'].get_vre_update()
-            users[token]['vre'].index_periodic_update()
-            users[token]['rg'].update_index(vre.get_index())
-            users[token]['rg'].update_db(vre.get_db())
-            #vre.get_vre_update()
-            #vre.index_periodic_update()
-            #rg.update_index(vre.get_index())
-            #rg.update_db(vre.get_db())
+            #users[token]['args']['vre'].get_vre_update()
+            #users[token]['args']['vre'].index_periodic_update()
+            #users[token]['args']['rg'].update_index(vre.get_index())
+            #users[token]['args']['rg'].update_db(vre.get_db())
+            vre.get_vre_update()
+            vre.index_periodic_update()
+            rg.update_index(vre.get_index())
+            rg.update_db(vre.get_db())
         except Exception as e:
             alive = "dead_vre_fetch"
 
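The rewritten loop drops the per-user token and refreshes the single shared VRE/ResponseGenerator pair instead. Two details worth noting: time.sleep(1000) waits roughly 16.7 minutes between refreshes, and alive = "dead_vre_fetch" inside the function binds a local name (there is no global alive statement), so the module-level flag apparently consulted by /health would never actually change. A minimal sketch of the same refresh-loop pattern, with a placeholder update callable standing in for the VRE/ResponseGenerator calls (illustrative names, not from the repo):

import threading
import time

def periodic_refresh(update, interval=1000):
    # Same shape as the new vre_fetch(): sleep, refresh the shared
    # index/db, and record failures instead of crashing the thread.
    while True:
        try:
            time.sleep(interval)
            update()  # stands in for vre.get_vre_update() ... rg.update_db(...)
        except Exception:
            pass      # the original assigns a (function-local) "alive" flag here

threading.Thread(target=periodic_refresh,
                 args=(lambda: print("getting new material"),),
                 name="updatevre", daemon=True).start()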
@@ -113,7 +89,7 @@ def init_dm():
        token = request.get_json().get("token")
        status = request.get_json().get("stat")
        if status == "start":
-           message = {"stat": "waiting", "err": ""}
+           message = {"stat": "waiting"}
        elif status == "set":
            headers = {"gcube-token": token, "Accept": "application/json"}
            if token not in users:
@@ -123,25 +99,18 @@ def init_dm():
                    username = response.json()['result']['username']
                    name = response.json()['result']['fullname']
 
-                   vre = VRE("assistedlab", token, retriever)
-                   vre.init()
-                   index = vre.get_index()
-                   db = vre.get_db()
-
-                   rg = ResponseGenerator(index,db, rec, generators, retriever)
-
-                   users[token] = {'username': username, 'name': name, 'dm': DM(), 'activity': 0, 'user': User(username, token), 'vre': vre, 'rg': rg}
+                   users[token] = {'username': username, 'name': name, 'dm': DM(), 'activity': 0, 'user': User(username, token)}
 
                    threading.Thread(target=user_interest_decay, args=(token,), name='decayinterest_'+users[token]['username']).start()
-                   threading.Thread(target=vre_fetch, name='updatevre'+users[token]['username'], args=(token,)).start()
-                   message = {"stat": "done", "err": ""}
+
+                   message = {"stat": "done"}
                else:
-                   message = {"stat": "rejected", "err": ""}
+                   message = {"stat": "rejected"}
        else:
-           message = {"stat": "done", "err": ""}
+           message = {"stat": "done"}
        return message
    except Exception as e:
-       message = {"stat": "init_dm_error", "err": str(e)}
+       message = {"stat": "init_dm_error"}
        return message
 
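With the VRE and ResponseGenerator construction moved out of init_dm, registration now only records the user and starts the interest-decay thread. A hedged sketch of the handshake against this endpoint; host and port are assumptions (5000 is Flask's app.run default), and <gcube-token> is a placeholder:

import requests

base = "http://localhost:5000/api/dm"
# "start" parks the client: {"stat": "waiting"}
print(requests.post(base, json={"token": "<gcube-token>", "stat": "start"}).json())
# "set" resolves the profile via api.d4science.org: {"stat": "done"} on 200,
# {"stat": "rejected"} otherwise, {"stat": "init_dm_error"} on exceptions.
print(requests.post(base, json={"token": "<gcube-token>", "stat": "set"}).json())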
@@ -151,8 +120,8 @@ def predict():
    token = request.get_json().get("token")
    dm = users[token]['dm']
    user = users[token]['user']
-   rg = users[token]['rg']
-   vre = users[token]['vre']
+   #rg = users[token]['args']['rg']
+   #vre = users[token]['args']['vre']
    message = {}
    try:
        if text == "<HELP_ON_START>":
@@ -198,8 +167,8 @@ def predict():
        users[token]['dm'] = dm
        users[token]['user'] = user
        users[token]['activity'] = 0
-       users[token]['vre'] = vre
-       users[token]['rg'] = rg
+       #users[token]['args']['vre'] = vre
+       #users[token]['args']['rg'] = rg
        return reply
    except Exception as e:
        message = {"answer": str(e), "query": "", "cand": "candidate", "history": "", "modQuery": ""}
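For the prediction route itself, the request and reply shapes visible in this diff suggest the following exchange; the key carrying the user message is not shown in these hunks, so "text" below is an assumption, as are host and port:

import requests

reply = requests.post("http://localhost:5000/api/predict",
                      json={"token": "<gcube-token>",
                            "text": "<HELP_ON_START>"}).json()  # "text" key assumed
# The error path visible above returns:
# {"answer": <exception text>, "query": "", "cand": "candidate",
#  "history": "", "modQuery": ""}
print(reply.get("answer"))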
@@ -231,6 +200,47 @@ def feedback():
 
 if __name__ == "__main__":
     warnings.filterwarnings("ignore")
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    device_flag = torch.cuda.current_device() if torch.cuda.is_available() else -1
+
+    query_rewriter = pipeline("text2text-generation", model="castorini/t5-base-canard")
+    intent_classifier = pipeline("sentiment-analysis", model='/models/intent_classifier', device=device_flag)
+    entity_extractor = spacy.load("/models/entity_extractor")
+    offensive_classifier = pipeline("sentiment-analysis", model='/models/offensive_classifier', device=device_flag)
+    ambig_classifier = pipeline("sentiment-analysis", model='/models/ambig_classifier', device=device_flag)
+    coref_resolver = spacy.load("en_coreference_web_trf")
+
+    nlu = NLU(query_rewriter, coref_resolver, intent_classifier, offensive_classifier, entity_extractor, ambig_classifier)
+
+    #load retriever and generator
+    retriever = SentenceTransformer('/models/retriever/').to(device)
+    qa_generator = pipeline("text2text-generation", model="/models/train_qa", device=device_flag)
+    summ_generator = pipeline("text2text-generation", model="/models/train_summ", device=device_flag)
+    chat_generator = pipeline("text2text-generation", model="/models/train_chat", device=device_flag)
+    amb_generator = pipeline("text2text-generation", model="/models/train_amb_gen", device=device_flag)
+    generators = {'qa': qa_generator,
+                  'chat': chat_generator,
+                  'amb': amb_generator,
+                  'summ': summ_generator}
+    rec = Recommender(retriever)
+    vre = VRE("assistedlab", '2c1e8f88-461c-42c0-8cc1-b7660771c9a3-843339462', retriever)
+    vre.init()
+    index = vre.get_index()
+    db = vre.get_db()
+    rg = ResponseGenerator(index,db, rec, generators, retriever)
+    del retriever
+    del generators
+    del qa_generator
+    del chat_generator
+    del summ_generator
+    del amb_generator
+    del query_rewriter
+    del intent_classifier
+    del entity_extractor
+    del offensive_classifier
+    del ambig_classifier
+    del coref_resolver
+    threading.Thread(target=vre_fetch, name='updatevre').start()
+    threading.Thread(target=clear_inactive, name='clear').start()
     """
     conn = psycopg2.connect(host="janet-pg", database=os.getenv("POSTGRES_DB"), user=os.getenv("POSTGRES_USER"), password=os.getenv("POSTGRES_PASSWORD"))
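One reading of the new del block: it only unbinds the module-level names. The pipelines stay alive as long as nlu, rg, and the running threads still reference them, so this tidies the namespace rather than freeing model memory outright; CPython reclaims an object only when its last reference disappears. A small illustration of that semantics (names are illustrative):

import sys

class Pipeline:
    pass

qa_generator = Pipeline()
generators = {"qa": qa_generator}   # rg/nlu keep references like this
del qa_generator                    # unbinds the name only
print(generators["qa"])             # the object is still alive and reachable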
@@ -5,9 +5,6 @@ import shutil
 import re
 import requests
 import time
-from User import User
-from DM import DM
-import threading
 app = Flask(__name__)
 url = os.getenv("FRONTEND_URL_WITH_PORT")
 cors = CORS(app, resources={r"/api/predict": {"origins": url},
@@ -16,54 +13,31 @@ cors = CORS(app, resources={r"/api/predict": {"origins": url},
                             r"/health": {"origins": "*"}
                             })
-users = {}
-alive = "alive"
-def user_interest_decay(token):
-    while True:
-        try:
-            if token in users:
-                print("decaying interests after 3 minutes for " + users[token]['username'])
-                time.sleep(180)
-                users[token]['user'].decay_interests()
-            else:
-                break
-        except Exception as e:
-            alive = "dead_interest_decay"
-
 @app.route("/health", methods=['GET'])
 def health():
-    if alive=="alive":
-        return "Success", 200
-    else:
-        return alive, 500
+    return "Success", 200
 
 @app.route("/api/dm", methods=['POST'])
 def init_dm():
-    try:
-        token = request.get_json().get("token")
-        status = request.get_json().get("stat")
-        if status == "start":
-            message = {"stat": "waiting", "err": ""}
-        elif status == "set":
-            headers = {"gcube-token": token, "Accept": "application/json"}
-            if token not in users:
-                url = 'https://api.d4science.org/rest/2/people/profile'
-                response = requests.get(url, headers=headers)
-                if response.status_code == 200:
-                    username = response.json()['result']['username']
-                    name = response.json()['result']['fullname']
-
-                    users[token] = {'username': username, 'name': name, 'dm': DM(), 'activity': 0, 'user': User(username, token)}
-
-                    threading.Thread(target=user_interest_decay, args=(token,), name='decayinterest_'+users[token]['username']).start()
-
-                    message = {"stat": "done", "err": ""}
-                else:
-                    message = {"stat": "rejected", "err": ""}
-        else:
-            message = {"stat": "done", "err": ""}
-        return message
-    except Exception as e:
-        message = {"stat": "init_dm_error", "err": str(e)}
-        return message
+    token = request.get_json().get("token")
+    status = request.get_json().get("stat")
+    if status == "start":
+        message = {"stat": "waiting"}
+    elif status == "set":
+        headers = {"gcube-token": token, "Accept": "application/json"}
+        if token not in users:
+            url = 'https://api.d4science.org/rest/2/people/profile'
+            response = requests.get(url, headers=headers)
+            if response.status_code == 200:
+                username = response.json()['result']['username']
+                name = response.json()['result']['fullname']
+                message = {"stat": "done"}
+            else:
+                message = {"stat": "rejected"}
+    else:
+        message = {"stat": "done"}
+    return message
 @app.route("/api/predict", methods=['POST'])
 def predict():
     time.sleep(10)
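In this trimmed variant, /health always reports success and /api/dm no longer builds per-user state or spawns the interest-decay thread. Worth flagging: the line if token not in users: survives even though users = {} is removed in the same hunk, so the route would raise a NameError at request time unless users is still defined elsewhere in the file. A quick probe sketch (host and port are assumptions, matching app.run's defaults):

import requests

print(requests.get("http://localhost:5000/health").text)  # "Success", always 200 now
print(requests.post("http://localhost:5000/api/dm",
                    json={"token": "<gcube-token>", "stat": "start"}).json())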
@@ -80,7 +54,7 @@ def feedback():
     return reply
 
 if __name__ == "__main__":
-
+    """
     folder = '/app'
     for filename in os.listdir(folder):
         file_path = os.path.join(folder, filename)
@@ -91,5 +65,4 @@ if __name__ == "__main__":
                 shutil.rmtree(file_path)
         except Exception as e:
             print('Failed to delete %s. Reason: %s' % (file_path, e))
+    """
     app.run(host='0.0.0.0')
-
@@ -0,0 +1,38 @@
+faiss-gpu==1.7.2
+Flask==1.1.4
+flask-cors==3.0.10
+protobuf==3.20.0
+matplotlib==3.5.3
+nltk==3.7
+numpy==1.22.4
+pandas==1.3.5
+PyPDF2==3.0.1
+pdfquery
+html2text
+regex==2022.6.2
+requests==2.25.1
+scikit-learn==1.0.2
+scipy==1.7.3
+sentencepiece==0.1.97
+sklearn-pandas==1.8.0
+spacy==3.4.4
+spacy-alignments==0.9.0
+spacy-legacy==3.0.12
+spacy-loggers==1.0.4
+spacy-transformers==1.1.9
+spacy-experimental==0.6.2
+torch @ https://download.pytorch.org/whl/cu116/torch-1.13.1%2Bcu116-cp38-cp38-linux_x86_64.whl
+torchaudio @ https://download.pytorch.org/whl/cu116/torchaudio-0.13.1%2Bcu116-cp38-cp38-linux_x86_64.whl
+torchsummary==1.5.1
+torchtext==0.14.1
+sentence-transformers
+torchvision @ https://download.pytorch.org/whl/cu116/torchvision-0.14.1%2Bcu116-cp38-cp38-linux_x86_64.whl
+tqdm==4.64.1
+transformers
+markupsafe==2.0.1
+psycopg2==2.9.5
+en-coreference-web-trf @ https://github.com/explosion/spacy-experimental/releases/download/v0.6.1/en_coreference_web_trf-3.4.0a2-py3-none-any.whl
+Werkzeug==1.0.1
+
+
+
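The new requirements file pins the torch stack to CPython-3.8/CUDA-11.6 wheels via PEP 508 direct URLs, matching the FROM python:3.8 base image above. A hedged post-install sanity check; the expected version strings follow from the wheel filenames and pins, while CUDA availability depends on the runtime:

import torch
import spacy

print(torch.__version__)           # expected: 1.13.1+cu116 per the pinned wheel
print(torch.cuda.is_available())   # True only on a CUDA 11.6-capable host
print(spacy.__version__)           # expected: 3.4.4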