test google gemma

This commit is contained in:
parent 73117674ad
commit 1efd0ac18d
@@ -1,3 +1,5 @@
 janet.pdf
 __pycache__/
+git-filter-repo
+.gitignore
 ahmed.ibrahim39699_interests.json
DM.py (60 changed lines)

@@ -6,43 +6,61 @@ class DM:
         self.working_history_consec = ""
         self.chitchat_history_consec = ""
         self.max_history_length = max_history_length
+        self.history = ""
         self.chat_history = []
         self.curr_state = None

-    def update_history(self):
-        to_consider = [x['modified_query'] for x in self.chat_history[-self.max_history_length*2:]]
-        self.working_history_consec = " . ".join(to_consider)
-        self.working_history_sep = " ||| ".join(to_consider)
+    #def update_history(self):
+        #to_consider = [x['modified_query'] for x in self.chat_history[-self.max_history_length*2:]]
+        #self.working_history_consec = " . ".join(to_consider)
+        #self.working_history_sep = " ||| ".join(to_consider)

-        chat = []
-        for utt in self.chat_history:
-            if utt['intent'] == 'CHITCHAT':
-                if len(chat) == 4:
-                    chat = chat[1:]
-                chat.append(utt['modified_query'])
-        self.chitchat_history_consec = '. '.join(chat)
+        #chat = []
+        #for utt in self.chat_history:
+        #    if utt['intent'] == 'CHITCHAT':
+        #        if len(chat) == 4:
+        #            chat = chat[1:]
+        #        chat.append(utt['modified_query'])
+        #self.chitchat_history_consec = '. '.join(chat)


-    def get_consec_history(self):
-        return self.working_history_consec
+    #def get_consec_history(self):
+        #return self.working_history_consec

-    def get_chitchat_history(self):
-        return self.chitchat_history_consec
+    #def get_chitchat_history(self):
+        #return self.chitchat_history_consec

-    def get_sep_history(self):
-        return self.working_history_sep
+    #def get_sep_history(self):
+        #return self.working_history_sep

-    def get_recent_state(self):
-        return self.curr_state
+    #def get_recent_state(self):
+        #return self.curr_state

-    def get_dialogue_history(self):
-        return self.chat_history
+    #def get_dialogue_history(self):
+        #return self.chat_history

     def update(self, new_state):
         self.chat_history.append(new_state)
         self.curr_state = new_state
         self.update_history()

+    def update_history(self):
+        to_consider = [x['modified_query'] for x in self.chat_history[-self.max_history_length*2:]]
+        #self.working_history_consec = " . ".join(to_consider)
+        #self.working_history_sep = " ||| ".join(to_consider)
+        for utt in to_consider:
+            self.history = utt if len(self.history) == 0 else f"""{self.history}
+{utt}"""
+
+        #user_last_utt = f"""{username}: {text}"""
+        #self.history = user_last_utt if len(self.history) == 0 else f"""{self.history}
+        #{user_last_utt}"""
+
+    def get_history(self):
+        return self.history
+
+
     def next_action(self):
         if self.curr_state['help']:
             return "Help"
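For context, a minimal usage sketch of the DM change above: update_history now builds a plain newline-joined transcript in self.history, and main.py (further down) prepends get_history() to the incoming utterance. The state dict and the default-constructed DM() are illustrative assumptions, not code from this commit.

from DM import DM  # assumes DM.py as modified above

dm = DM()  # assumes __init__ supplies a default max_history_length
dm.update({"modified_query": "Janet: Hi, how can I help?", "intent": "CHITCHAT",
           "help": False, "inactive": False})

text = "find datasets about grape yields"
prompt_history = f"""{dm.get_history()}
user: {text}"""
print(prompt_history)
# Janet: Hi, how can I help?
# user: find datasets about grape yields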
@@ -2,15 +2,15 @@ FROM python:3.8

 WORKDIR /backend_janet

-COPY requirements_simple.txt .
+COPY requirements_main.txt .

 ARG version_info
 ENV FLASK_APP_VERSION_INFO=${version_info}

-RUN pip install -r requirements_simple.txt
+RUN pip install -r requirements_main.txt

 RUN rm -fr /root/.cache/*

 COPY . .

-ENTRYPOINT ["python", "main_simple.py"]
+ENTRYPOINT ["python", "main.py"]
NLU.py (108 changed lines)

@@ -1,17 +1,22 @@
 import spacy
 import spacy_transformers
 import torch
+import logging

 class NLU:
-    def __init__(self, query_rewriter, coref_resolver, intent_classifier, offensive_classifier, entity_extractor, ambig_classifier):
+    def __init__(self, LLM_tokenizer, LLM_model):
+        """
         self.intent_classifier = intent_classifier
         self.entity_extractor = entity_extractor
         self.offensive_classifier = offensive_classifier
         self.coref_resolver = coref_resolver
         self.query_rewriter = query_rewriter
         self.ambig_classifier = ambig_classifier
+        """
+        self.tokenizer = LLM_tokenizer
+        self.model = LLM_model

+    """
     def _resolve_coref(self, history):
         to_resolve = history + ' <COREF_SEP_TOKEN> ' + self.to_process
         doc = self.coref_resolver(to_resolve)

@@ -20,13 +25,13 @@ class NLU:
         clusters = [
             val for key, val in doc.spans.items() if key.startswith("coref_cluster")
         ]
-        """
         clusters = []
         for cluster in cand_clusters:
             if cluster[0].text == "I":
                 continue
             clusters.append(cluster)
-        """
         # Iterate through every found cluster
         for cluster in clusters:
             first_mention = cluster[0]
@@ -83,49 +88,68 @@ class NLU:
         text = history + " ||| " + self.to_process
         return self.query_rewriter(text)[0]['generated_text']

+    """

-    def process_utterance(self, utterance, history_consec, history_sep):
+    def process_utterance(self, history):
         """
         Query -> coref resolution & intent extraction -> if intents are not confident or if query is ambig -> rewrite query and recheck -> if still ambig, ask a clarifying question
         """
-        if utterance.lower() in ["help", "list resources", "list papers", "list datasets", "list topics"]:
-            return {"modified_query": utterance.lower(), "intent": "COMMAND", "entities": [], "is_offensive": False, "is_clear": True}
+        #if utterance.lower() in ["help", "list resources", "list papers", "list datasets", "list topics"]:
+            # return {"modified_query": utterance.lower(), "intent": "COMMAND", "entities": [], "is_offensive": False, "is_clear": True}

-        self.to_process = utterance
+        #self.to_process = utterance

-        self.to_process = self._resolve_coref(history_consec)
+        prompt = f"""You are Janet, the virtual assistant of the virtual research enviornment users.
+What does the user eventually want given this dialogue, which is delimited with triple backticks?
+Give your answer in one single sentence.
+Dialogue: '''{history}'''
+"""

-        intent, score = self._intentpredictor()
+        chat = [{ "role": "user", "content": prompt },]
+        prompt_chat = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
+        inputs = self.tokenizer.encode(prompt_chat, add_special_tokens=False, return_tensors="pt")
+        outputs = self.model.generate(input_ids=inputs, max_new_tokens=150)

-        if score > 0.5:
-            if intent == 'CHITCHAT':
-                self.to_process = utterance
-            entities = self._entityextractor()
-            offense = self._offensepredictor()
-            if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
-                return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
-            return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": True}
-        else:
-            if self._ambigpredictor():
-                self.to_process = self._rewrite_query(history_sep)
-                intent, score = self._intentpredictor()
-                entities = self._entityextractor()
-                offense = self._offensepredictor()
-                if score > 0.5 or not self._ambigpredictor():
-                    if intent == 'CHITCHAT':
-                        self.to_process = utterance
-                    if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
-                        return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
-                    return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense,
-                            "is_clear": True}
-                else:
-                    return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense,
-                            "is_clear": False}
-            else:
-                entities = self._entityextractor()
-                offense = self._offensepredictor()
-                if intent == 'CHITCHAT':
-                    self.to_process = utterance
-                if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
-                    return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
-                return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": True}
+        goal = self.tokenizer.decode(outputs[0])
+        logging.debug("User's goal is:" + goal)
+        #return goal.split("<start_of_turn>model\n")[-1].split("<eos>")[0]
+        return {"modified_query": goal.split("<start_of_turn>model\n")[-1].split("<eos>")[0],
+                "intent": "QA", "entities": [], "is_offensive": False, "is_clear": True}
+        #self.to_process = self._resolve_coref(history_consec)
+        #intent, score = self._intentpredictor()
+        #if score > 0.5:
+        #    if intent == 'CHITCHAT':
+        #        self.to_process = utterance
+        #    entities = self._entityextractor()
+        #    offense = self._offensepredictor()
+        #    if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
+        #        return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
+        #    return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": True}
+        #else:
+        #    if self._ambigpredictor():
+        #        self.to_process = self._rewrite_query(history_sep)
+        #        intent, score = self._intentpredictor()
+        #        entities = self._entityextractor()
+        #        offense = self._offensepredictor()
+        #        if score > 0.5 or not self._ambigpredictor():
+        #            if intent == 'CHITCHAT':
+        #                self.to_process = utterance
+        #            if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
+        #                return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
+        #            return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense,
+        #                    "is_clear": True}
+        #        else:
+        #            return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense,
+        #                    "is_clear": False}
+        #    else:
+        #        entities = self._entityextractor()
+        #        offense = self._offensepredictor()
+        #        if intent == 'CHITCHAT':
+        #            self.to_process = utterance
+        #        if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
+        #            return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
+        #        return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": True}
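For reference, the Gemma call chain used by the new process_utterance, pulled out as a standalone sketch; the dialogue string and the shortened prompt are made-up examples, and the checkpoint is assumed to be reachable (it is gated on the Hugging Face Hub, hence the login in main.py).

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it", torch_dtype=torch.bfloat16)

history = "user: I am looking for a dataset\nJanet: Which topic?\nuser: grape yields in Italy"
prompt = ("What does the user eventually want given this dialogue, "
          "which is delimited with triple backticks?\n"
          "Give your answer in one single sentence.\n"
          f"Dialogue: '''{history}'''")

# Gemma's chat template wraps the message in <start_of_turn>user ... <end_of_turn>
# and, with add_generation_prompt=True, appends the <start_of_turn>model header.
chat = [{"role": "user", "content": prompt}]
prompt_chat = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

inputs = tokenizer.encode(prompt_chat, add_special_tokens=False, return_tensors="pt")
outputs = model.generate(input_ids=inputs, max_new_tokens=150)

# decode() returns prompt plus completion, so the answer is recovered by splitting
# on the model-turn header and the <eos> token, exactly as process_utterance does.
goal = tokenizer.decode(outputs[0])
answer = goal.split("<start_of_turn>model\n")[-1].split("<eos>")[0]
print(answer)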
@@ -8,7 +8,7 @@ from datetime import datetime
 from datasets import Dataset

 class ResponseGenerator:
-    def __init__(self, index, db,recommender,generators, retriever, num_retrieved=3):
+    def __init__(self, index=None, db=None,recommender=None,generators=None, retriever=None, num_retrieved=3):
         self.generators = generators
         self.retriever = retriever
         self.recommender = recommender
main.py (137 changed lines)

@@ -1,4 +1,5 @@
 import os
+import logging
 import re
 import warnings
 import faiss

@@ -10,7 +11,7 @@ import spacy
 import requests
 import spacy_transformers
 import torch
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, AutoModelForCausalLM
 from User import User
 from VRE import VRE
 from NLU import NLU

@@ -21,6 +22,9 @@ import pandas as pd
 import time
 import threading
 from sentence_transformers import SentenceTransformer
+from huggingface_hub import login
+
+login(token="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")
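google/gemma-2b-it is a gated checkpoint, so the login() call above authenticates the process against the Hugging Face Hub before from_pretrained tries to download it. A sketch of an equivalent call that reads the token from the environment instead of hard-coding it (HF_TOKEN is an assumed variable name, not something this commit sets):

import os
from huggingface_hub import login

login(token=os.environ["HF_TOKEN"])  # assumes the token was exported beforehand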
@@ -36,46 +40,56 @@ alive = "alive"

 device = "cuda" if torch.cuda.is_available() else "cpu"
 device_flag = torch.cuda.current_device() if torch.cuda.is_available() else -1
+model_id = "/models/google-gemma"
+dtype = torch.bfloat16

-query_rewriter = pipeline("text2text-generation", model="castorini/t5-base-canard")
-intent_classifier = pipeline("sentiment-analysis", model='/models/intent_classifier', device=device_flag)
-entity_extractor = spacy.load("/models/entity_extractor")
-offensive_classifier = pipeline("sentiment-analysis", model='/models/offensive_classifier', device=device_flag)
-ambig_classifier = pipeline("sentiment-analysis", model='/models/ambig_classifier', device=device_flag)
-coref_resolver = spacy.load("en_coreference_web_trf")
+#query_rewriter = pipeline("text2text-generation", model="castorini/t5-base-canard")
+#intent_classifier = pipeline("sentiment-analysis", model='/models/intent_classifier', device=device_flag)
+#entity_extractor = spacy.load("/models/entity_extractor")
+#offensive_classifier = pipeline("sentiment-analysis", model='/models/offensive_classifier', device=device_flag)
+#ambig_classifier = pipeline("sentiment-analysis", model='/models/ambig_classifier', device=device_flag)
+#coref_resolver = spacy.load("en_coreference_web_trf")

-nlu = NLU(query_rewriter, coref_resolver, intent_classifier, offensive_classifier, entity_extractor, ambig_classifier)
+#LLM = pipeline("text2text-generation", model="/models/google-gemma", device=device_flag)
+
+LLM_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
+LLM_model = AutoModelForCausalLM.from_pretrained(
+    "google/gemma-2b-it",
+    torch_dtype=torch.bfloat16
+)
+
+nlu = NLU(LLM_tokenizer, LLM_model)

 #load retriever and generator
 retriever = SentenceTransformer('/models/retriever/').to(device)
-qa_generator = pipeline("text2text-generation", model="/models/train_qa", device=device_flag)
-summ_generator = pipeline("text2text-generation", model="/models/train_summ", device=device_flag)
-chat_generator = pipeline("text2text-generation", model="/models/train_chat", device=device_flag)
-amb_generator = pipeline("text2text-generation", model="/models/train_amb_gen", device=device_flag)
-generators = {'qa': qa_generator,
-              'chat': chat_generator,
-              'amb': amb_generator,
-              'summ': summ_generator}
+#qa_generator = pipeline("text2text-generation", model="/models/train_qa", device=device_flag)
+#summ_generator = pipeline("text2text-generation", model="/models/train_summ", device=device_flag)
+#chat_generator = pipeline("text2text-generation", model="/models/train_chat", device=device_flag)
+#amb_generator = pipeline("text2text-generation", model="/models/train_amb_gen", device=device_flag)
+#generators = {'qa': qa_generator,
+#              'chat': chat_generator,
+#              'amb': amb_generator,
+#              'summ': summ_generator}
 rec = Recommender(retriever)

-def vre_fetch(token):
-    while True:
-        try:
-            time.sleep(1000)
-            print('getting new material')
-            users[token]['vre'].get_vre_update()
-            users[token]['vre'].index_periodic_update()
-            users[token]['rg'].update_index(vre.get_index())
-            users[token]['rg'].update_db(vre.get_db())
+#def vre_fetch(token):
+#    while True:
+#        try:
+#            time.sleep(1000)
+#            print('getting new material')
+#            users[token]['vre'].get_vre_update()
+#            users[token]['vre'].index_periodic_update()
+#            users[token]['rg'].update_index(vre.get_index())
+#            users[token]['rg'].update_db(vre.get_db())
             #vre.get_vre_update()
             #vre.index_periodic_update()
             #rg.update_index(vre.get_index())
             #rg.update_db(vre.get_db())
-        except Exception as e:
-            alive = "dead_vre_fetch"
+#        except Exception as e:
+#            alive = "dead_vre_fetch"


+"""
 def user_interest_decay(token):
     while True:
         try:
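The hunk above also introduces model_id and dtype, which the new code does not reference yet; a purely illustrative variant that would reuse them, together with the detected device, to load the weights from the local path instead of the Hub:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "/models/google-gemma"   # as defined above; assumes the weights are vendored there
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

LLM_tokenizer = AutoTokenizer.from_pretrained(model_id)
LLM_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype).to(device)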
@@ -99,6 +113,7 @@ def clear_inactive():
             users[username]['activity'] += 1
         except Exception as e:
             alive = "dead_clear_inactive"
+"""

 @app.route("/health", methods=['GET'])
 def health():

@@ -113,10 +128,13 @@ def init_dm():
     token = request.get_json().get("token")
     status = request.get_json().get("stat")
     if status == "start":
+        logging.debug("status=start")
         message = {"stat": "waiting", "err": ""}
     elif status == "set":
+        logging.debug("status=set")
         headers = {"gcube-token": token, "Accept": "application/json"}
         if token not in users:
+            logging.debug("getting user info")
             url = 'https://api.d4science.org/rest/2/people/profile'
             response = requests.get(url, headers=headers)
             if response.status_code == 200:

@@ -128,12 +146,13 @@ def init_dm():
                 index = vre.get_index()
                 db = vre.get_db()

-                rg = ResponseGenerator(index,db, rec, generators, retriever)
+                rg = ResponseGenerator(index,db, rec, retriever=retriever)

-                users[token] = {'username': username, 'name': name, 'dm': DM(), 'activity': 0, 'user': User(username, token), 'vre': vre, 'rg': rg}
+                users[token] = {'username': username, 'name': name, 'dm': DM(), 'activity': 0, 'user': User(username, token),
+                                'vre': vre, 'rg': rg}

-                threading.Thread(target=user_interest_decay, args=(token,), name='decayinterest_'+users[token]['username']).start()
-                threading.Thread(target=vre_fetch, name='updatevre'+users[token]['username'], args=(token,)).start()
+                #threading.Thread(target=user_interest_decay, args=(token,), name='decayinterest_'+users[token]['username']).start()
+                #threading.Thread(target=vre_fetch, name='updatevre'+users[token]['username'], args=(token,)).start()
                 message = {"stat": "done", "err": ""}
             else:
                 message = {"stat": "rejected", "err": ""}
@@ -156,43 +175,55 @@ def predict():
     message = {}
     try:
         if text == "<HELP_ON_START>":
+            logging.debug("help on start - inactive")
             state = {'help': True, 'inactive': False, 'modified_query':"", 'intent':""}
             dm.update(state)
             action = dm.next_action()
+            logging.debug("next action:" + action)
+            #response = "Hey " + users[token]['name'].split()[0] + "! it's Janet! I am here to help you make use of the datasets and papers in the catalogue of the VRE. I can answer questions whose answers may be inside the papers. I can summarize papers for you. I can also chat with you. So, whichever it is, I am ready to chat!"
             response = rg.gen_response(action, vrename=vre.name, username=users[token]['username'], name=users[token]['name'].split()[0])
             message = {"answer": response}
         elif text == "<RECOMMEND_ON_IDLE>":
+            logging.debug("recommend on idle - inactive")
             state = {'help': False, 'inactive': True, 'modified_query':"recommed: ", 'intent':""}
             dm.update(state)
             action = dm.next_action()
+            logging.debug("next action:" + action)
+            #response = "Hey " + users[token]['name'].split()[0] + "! it's Janet! I am here to help you make use of the datasets and papers in the catalogue of the VRE. I can answer questions whose answers may be inside the papers. I can summarize papers for you. I can also chat with you. So, whichever it is, I am ready to chat!"

             response = rg.gen_response(action, username=users[token]['username'],name=users[token]['name'].split()[0], vrename=vre.name)
             message = {"answer": response}
-            new_state = {'modified_query': response}
+            new_state = {'modified_query': "Janet: " + response}
             dm.update(new_state)
         else:
-            state = nlu.process_utterance(text, dm.get_consec_history(), dm.get_sep_history())
+            state = nlu.process_utterance(f"""{dm.get_history()}
+user: {text}""")
             state['help'] = False
             state['inactive'] = False
-            old_user_interests = user.get_user_interests()
-            old_vre_material = pd.concat([vre.db['paper_db'], vre.db['dataset_db']]).reset_index(drop=True)
-            user_interests = []
-            for entity in state['entities']:
-                if entity['entity'] == 'TOPIC':
-                    user_interests.append(entity['value'])
-            user.update_interests(user_interests)
-            new_user_interests = user.get_user_interests()
-            new_vre_material = pd.concat([vre.db['paper_db'], vre.db['dataset_db']]).reset_index(drop=True)
-            if (new_user_interests != old_user_interests or len(old_vre_material) != len(new_vre_material)):
-                rec.generate_recommendations(users[token]['username'], new_user_interests, new_vre_material)
+            #old_user_interests = user.get_user_interests()
+            #old_vre_material = pd.concat([vre.db['paper_db'], vre.db['dataset_db']]).reset_index(drop=True)
+            #user_interests = []
+            #for entity in state['entities']:
+            #    if entity['entity'] == 'TOPIC':
+            #        user_interests.append(entity['value'])
+            #user.update_interests(user_interests)
+            #new_user_interests = user.get_user_interests()
+            #new_vre_material = pd.concat([vre.db['paper_db'], vre.db['dataset_db']]).reset_index(drop=True)
+            #if (new_user_interests != old_user_interests or len(old_vre_material) != len(new_vre_material)):
+            #    rec.generate_recommendations(users[token]['username'], new_user_interests, new_vre_material)
             dm.update(state)
             action = dm.next_action()
-            response = rg.gen_response(action=action, utterance=state['modified_query'], state=dm.get_recent_state(), consec_history=dm.get_consec_history(), chitchat_history=dm.get_chitchat_history(), vrename=vre.name, username=users[token]['username'], name=users[token]['name'].split()[0])
-            message = {"answer": response, "query": text, "cand": "candidate", "history": dm.get_consec_history(), "modQuery": state['modified_query']}
-            if state['intent'] == "QA":
-                split_response = response.split("_______ \n ")
-                if len(split_response) > 1:
-                    response = split_response[1]
-            new_state = {'modified_query': response, 'intent': state['intent']}
+            logging.debug("Next action: " + action)
+            #response = rg.gen_response(action=action, utterance=state['modified_query'], state=dm.get_recent_state(), consec_history=dm.get_consec_history(), chitchat_history=dm.get_chitchat_history(), vrename=vre.name, username=users[token]['username'], name=users[token]['name'].split()[0])
+            #message = {"answer": response, "query": text, "cand": "candidate", "history": dm.get_consec_history(), "modQuery": state['modified_query']}
+            message = {"answer": state['modified_query'], "query": text, "cand": "candidate", "history": dm.get_history(), "modQuery": state['modified_query']}
+            #if state['intent'] == "QA":
+            #    split_response = response.split("_______ \n ")
+            #    if len(split_response) > 1:
+            #        response = split_response[1]
+            response =state['modified_query']
+            new_state = {'modified_query': "Janet: " + response, 'intent': state['intent']}
             dm.update(new_state)
             reply = jsonify(message)
             users[token]['dm'] = dm
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
threading.Thread(target=clear_inactive, name='clear').start()
|
#threading.Thread(target=clear_inactive, name='clear').start()
|
||||||
"""
|
"""
|
||||||
conn = psycopg2.connect(host="janet-pg", database=os.getenv("POSTGRES_DB"), user=os.getenv("POSTGRES_USER"), password=os.getenv("POSTGRES_PASSWORD"))
|
conn = psycopg2.connect(host="janet-pg", database=os.getenv("POSTGRES_DB"), user=os.getenv("POSTGRES_USER"), password=os.getenv("POSTGRES_PASSWORD"))
|
||||||
|
|
||||||
|
|
|
@@ -33,6 +33,7 @@ markupsafe==2.0.1
 psycopg2==2.9.5
 en-coreference-web-trf @ https://github.com/explosion/spacy-experimental/releases/download/v0.6.1/en_coreference_web_trf-3.4.0a2-py3-none-any.whl
 datasets
+huggingface_hub
 Werkzeug==1.0.1