test google gemma
This commit is contained in:
parent
73117674ad
commit
1efd0ac18d
|
@ -1,3 +1,5 @@
|
|||
janet.pdf
|
||||
__pycache__/
|
||||
git-filter-repo
|
||||
.gitignore
|
||||
ahmed.ibrahim39699_interests.json
|
||||
|
|
60
DM.py
60
DM.py
|
@ -6,43 +6,61 @@ class DM:
|
|||
self.working_history_consec = ""
|
||||
self.chitchat_history_consec = ""
|
||||
self.max_history_length = max_history_length
|
||||
self.history = ""
|
||||
self.chat_history = []
|
||||
self.curr_state = None
|
||||
|
||||
def update_history(self):
|
||||
to_consider = [x['modified_query'] for x in self.chat_history[-self.max_history_length*2:]]
|
||||
self.working_history_consec = " . ".join(to_consider)
|
||||
self.working_history_sep = " ||| ".join(to_consider)
|
||||
#def update_history(self):
|
||||
#to_consider = [x['modified_query'] for x in self.chat_history[-self.max_history_length*2:]]
|
||||
#self.working_history_consec = " . ".join(to_consider)
|
||||
#self.working_history_sep = " ||| ".join(to_consider)
|
||||
|
||||
chat = []
|
||||
for utt in self.chat_history:
|
||||
if utt['intent'] == 'CHITCHAT':
|
||||
if len(chat) == 4:
|
||||
chat = chat[1:]
|
||||
chat.append(utt['modified_query'])
|
||||
self.chitchat_history_consec = '. '.join(chat)
|
||||
#chat = []
|
||||
#for utt in self.chat_history:
|
||||
# if utt['intent'] == 'CHITCHAT':
|
||||
# if len(chat) == 4:
|
||||
# chat = chat[1:]
|
||||
# chat.append(utt['modified_query'])
|
||||
#self.chitchat_history_consec = '. '.join(chat)
|
||||
|
||||
|
||||
def get_consec_history(self):
|
||||
return self.working_history_consec
|
||||
#def get_consec_history(self):
|
||||
# return self.working_history_consec
|
||||
|
||||
def get_chitchat_history(self):
|
||||
return self.chitchat_history_consec
|
||||
#def get_chitchat_history(self):
|
||||
# return self.chitchat_history_consec
|
||||
|
||||
def get_sep_history(self):
|
||||
return self.working_history_sep
|
||||
#def get_sep_history(self):
|
||||
# return self.working_history_sep
|
||||
|
||||
def get_recent_state(self):
|
||||
return self.curr_state
|
||||
#def get_recent_state(self):
|
||||
# return self.curr_state
|
||||
|
||||
def get_dialogue_history(self):
|
||||
return self.chat_history
|
||||
#def get_dialogue_history(self):
|
||||
# return self.chat_history
|
||||
|
||||
def update(self, new_state):
|
||||
self.chat_history.append(new_state)
|
||||
self.curr_state = new_state
|
||||
self.update_history()
|
||||
|
||||
def update_history(self):
|
||||
to_consider = [x['modified_query'] for x in self.chat_history[-self.max_history_length*2:]]
|
||||
#self.working_history_consec = " . ".join(to_consider)
|
||||
#self.working_history_sep = " ||| ".join(to_consider)
|
||||
for utt in to_consider:
|
||||
self.history = utt if len(self.history) == 0 else f"""{self.history}
|
||||
{utt}"""
|
||||
|
||||
#user_last_utt = f"""{username}: {text}"""
|
||||
#self.history = user_last_utt if len(self.history) == 0 else f"""{self.history}
|
||||
#{user_last_utt}"""
|
||||
|
||||
def get_history(self):
|
||||
return self.history
|
||||
|
||||
|
||||
|
||||
def next_action(self):
|
||||
if self.curr_state['help']:
|
||||
return "Help"
|
||||
|
|
|
@ -2,15 +2,15 @@ FROM python:3.8
|
|||
|
||||
WORKDIR /backend_janet
|
||||
|
||||
COPY requirements_simple.txt .
|
||||
COPY requirements_main.txt .
|
||||
|
||||
ARG version_info
|
||||
ENV FLASK_APP_VERSION_INFO=${version_info}
|
||||
|
||||
RUN pip install -r requirements_simple.txt
|
||||
RUN pip install -r requirements_main.txt
|
||||
|
||||
RUN rm -fr /root/.cache/*
|
||||
|
||||
COPY . .
|
||||
|
||||
ENTRYPOINT ["python", "main_simple.py"]
|
||||
ENTRYPOINT ["python", "main.py"]
|
||||
|
|
108
NLU.py
108
NLU.py
|
@ -1,17 +1,22 @@
|
|||
import spacy
|
||||
import spacy_transformers
|
||||
import torch
|
||||
import logging
|
||||
|
||||
class NLU:
|
||||
def __init__(self, query_rewriter, coref_resolver, intent_classifier, offensive_classifier, entity_extractor, ambig_classifier):
|
||||
|
||||
def __init__(self, LLM_tokenizer, LLM_model):
|
||||
"""
|
||||
self.intent_classifier = intent_classifier
|
||||
self.entity_extractor = entity_extractor
|
||||
self.offensive_classifier = offensive_classifier
|
||||
self.coref_resolver = coref_resolver
|
||||
self.query_rewriter = query_rewriter
|
||||
self.ambig_classifier = ambig_classifier
|
||||
"""
|
||||
self.tokenizer = LLM_tokenizer
|
||||
self.model = LLM_model
|
||||
|
||||
"""
|
||||
def _resolve_coref(self, history):
|
||||
to_resolve = history + ' <COREF_SEP_TOKEN> ' + self.to_process
|
||||
doc = self.coref_resolver(to_resolve)
|
||||
|
@ -20,13 +25,13 @@ class NLU:
|
|||
clusters = [
|
||||
val for key, val in doc.spans.items() if key.startswith("coref_cluster")
|
||||
]
|
||||
"""
|
||||
|
||||
clusters = []
|
||||
for cluster in cand_clusters:
|
||||
if cluster[0].text == "I":
|
||||
continue
|
||||
clusters.append(cluster)
|
||||
"""
|
||||
|
||||
# Iterate through every found cluster
|
||||
for cluster in clusters:
|
||||
first_mention = cluster[0]
|
||||
|
@ -83,49 +88,68 @@ class NLU:
|
|||
text = history + " ||| " + self.to_process
|
||||
return self.query_rewriter(text)[0]['generated_text']
|
||||
|
||||
"""
|
||||
|
||||
def process_utterance(self, utterance, history_consec, history_sep):
|
||||
def process_utterance(self, history):
|
||||
"""
|
||||
Query -> coref resolution & intent extraction -> if intents are not confident or if query is ambig -> rewrite query and recheck -> if still ambig, ask a clarifying question
|
||||
"""
|
||||
if utterance.lower() in ["help", "list resources", "list papers", "list datasets", "list topics"]:
|
||||
return {"modified_query": utterance.lower(), "intent": "COMMAND", "entities": [], "is_offensive": False, "is_clear": True}
|
||||
#if utterance.lower() in ["help", "list resources", "list papers", "list datasets", "list topics"]:
|
||||
# return {"modified_query": utterance.lower(), "intent": "COMMAND", "entities": [], "is_offensive": False, "is_clear": True}
|
||||
|
||||
self.to_process = utterance
|
||||
#self.to_process = utterance
|
||||
|
||||
self.to_process = self._resolve_coref(history_consec)
|
||||
prompt = f"""You are Janet, the virtual assistant of the virtual research enviornment users.
|
||||
What does the user eventually want given this dialogue, which is delimited with triple backticks?
|
||||
Give your answer in one single sentence.
|
||||
Dialogue: '''{history}'''
|
||||
"""
|
||||
|
||||
intent, score = self._intentpredictor()
|
||||
chat = [{ "role": "user", "content": prompt },]
|
||||
prompt_chat = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
|
||||
inputs = self.tokenizer.encode(prompt_chat, add_special_tokens=False, return_tensors="pt")
|
||||
outputs = self.model.generate(input_ids=inputs, max_new_tokens=150)
|
||||
|
||||
if score > 0.5:
|
||||
if intent == 'CHITCHAT':
|
||||
self.to_process = utterance
|
||||
entities = self._entityextractor()
|
||||
offense = self._offensepredictor()
|
||||
if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
|
||||
return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
|
||||
return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": True}
|
||||
else:
|
||||
if self._ambigpredictor():
|
||||
self.to_process = self._rewrite_query(history_sep)
|
||||
intent, score = self._intentpredictor()
|
||||
entities = self._entityextractor()
|
||||
offense = self._offensepredictor()
|
||||
if score > 0.5 or not self._ambigpredictor():
|
||||
if intent == 'CHITCHAT':
|
||||
self.to_process = utterance
|
||||
if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
|
||||
return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
|
||||
return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense,
|
||||
"is_clear": True}
|
||||
else:
|
||||
return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense,
|
||||
"is_clear": False}
|
||||
else:
|
||||
entities = self._entityextractor()
|
||||
offense = self._offensepredictor()
|
||||
if intent == 'CHITCHAT':
|
||||
self.to_process = utterance
|
||||
if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
|
||||
return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
|
||||
return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": True}
|
||||
goal = self.tokenizer.decode(outputs[0])
|
||||
logging.debug("User's goal is:" + goal)
|
||||
|
||||
#return goal.split("<start_of_turn>model\n")[-1].split("<eos>")[0]
|
||||
return {"modified_query": goal.split("<start_of_turn>model\n")[-1].split("<eos>")[0],
|
||||
"intent": "QA", "entities": [], "is_offensive": False, "is_clear": True}
|
||||
|
||||
#self.to_process = self._resolve_coref(history_consec)
|
||||
|
||||
#intent, score = self._intentpredictor()
|
||||
|
||||
#if score > 0.5:
|
||||
# if intent == 'CHITCHAT':
|
||||
# self.to_process = utterance
|
||||
# entities = self._entityextractor()
|
||||
# offense = self._offensepredictor()
|
||||
# if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
|
||||
# return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
|
||||
# return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": True}
|
||||
#else:
|
||||
# if self._ambigpredictor():
|
||||
# self.to_process = self._rewrite_query(history_sep)
|
||||
# intent, score = self._intentpredictor()
|
||||
# entities = self._entityextractor()
|
||||
# offense = self._offensepredictor()
|
||||
# if score > 0.5 or not self._ambigpredictor():
|
||||
# if intent == 'CHITCHAT':
|
||||
# self.to_process = utterance
|
||||
# if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
|
||||
# return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
|
||||
# return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense,
|
||||
# "is_clear": True}
|
||||
# else:
|
||||
# return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense,
|
||||
# "is_clear": False}
|
||||
# else:
|
||||
# entities = self._entityextractor()
|
||||
# offense = self._offensepredictor()
|
||||
# if intent == 'CHITCHAT':
|
||||
# self.to_process = utterance
|
||||
# if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
|
||||
# return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
|
||||
# return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": True}
|
||||
|
|
|
@ -8,7 +8,7 @@ from datetime import datetime
|
|||
from datasets import Dataset
|
||||
|
||||
class ResponseGenerator:
|
||||
def __init__(self, index, db,recommender,generators, retriever, num_retrieved=3):
|
||||
def __init__(self, index=None, db=None,recommender=None,generators=None, retriever=None, num_retrieved=3):
|
||||
self.generators = generators
|
||||
self.retriever = retriever
|
||||
self.recommender = recommender
|
||||
|
|
137
main.py
137
main.py
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
import logging
|
||||
import re
|
||||
import warnings
|
||||
import faiss
|
||||
|
@ -10,7 +11,7 @@ import spacy
|
|||
import requests
|
||||
import spacy_transformers
|
||||
import torch
|
||||
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
||||
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, AutoModelForCausalLM
|
||||
from User import User
|
||||
from VRE import VRE
|
||||
from NLU import NLU
|
||||
|
@ -21,6 +22,9 @@ import pandas as pd
|
|||
import time
|
||||
import threading
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from huggingface_hub import login
|
||||
|
||||
login(token="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")
|
||||
|
||||
|
||||
|
||||
|
@ -36,46 +40,56 @@ alive = "alive"
|
|||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
device_flag = torch.cuda.current_device() if torch.cuda.is_available() else -1
|
||||
model_id = "/models/google-gemma"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
query_rewriter = pipeline("text2text-generation", model="castorini/t5-base-canard")
|
||||
intent_classifier = pipeline("sentiment-analysis", model='/models/intent_classifier', device=device_flag)
|
||||
entity_extractor = spacy.load("/models/entity_extractor")
|
||||
offensive_classifier = pipeline("sentiment-analysis", model='/models/offensive_classifier', device=device_flag)
|
||||
ambig_classifier = pipeline("sentiment-analysis", model='/models/ambig_classifier', device=device_flag)
|
||||
coref_resolver = spacy.load("en_coreference_web_trf")
|
||||
#query_rewriter = pipeline("text2text-generation", model="castorini/t5-base-canard")
|
||||
#intent_classifier = pipeline("sentiment-analysis", model='/models/intent_classifier', device=device_flag)
|
||||
#entity_extractor = spacy.load("/models/entity_extractor")
|
||||
#offensive_classifier = pipeline("sentiment-analysis", model='/models/offensive_classifier', device=device_flag)
|
||||
#ambig_classifier = pipeline("sentiment-analysis", model='/models/ambig_classifier', device=device_flag)
|
||||
#coref_resolver = spacy.load("en_coreference_web_trf")
|
||||
|
||||
nlu = NLU(query_rewriter, coref_resolver, intent_classifier, offensive_classifier, entity_extractor, ambig_classifier)
|
||||
#LLM = pipeline("text2text-generation", model="/models/google-gemma", device=device_flag)
|
||||
|
||||
LLM_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
|
||||
LLM_model = AutoModelForCausalLM.from_pretrainedAutoModelForCausalLM.from_pretrained(
|
||||
"google/gemma-2b-it",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
nlu = NLU(LLM_tokenizer, LLM_model)
|
||||
|
||||
#load retriever and generator
|
||||
retriever = SentenceTransformer('/models/retriever/').to(device)
|
||||
qa_generator = pipeline("text2text-generation", model="/models/train_qa", device=device_flag)
|
||||
summ_generator = pipeline("text2text-generation", model="/models/train_summ", device=device_flag)
|
||||
chat_generator = pipeline("text2text-generation", model="/models/train_chat", device=device_flag)
|
||||
amb_generator = pipeline("text2text-generation", model="/models/train_amb_gen", device=device_flag)
|
||||
generators = {'qa': qa_generator,
|
||||
'chat': chat_generator,
|
||||
'amb': amb_generator,
|
||||
'summ': summ_generator}
|
||||
#qa_generator = pipeline("text2text-generation", model="/models/train_qa", device=device_flag)
|
||||
#summ_generator = pipeline("text2text-generation", model="/models/train_summ", device=device_flag)
|
||||
#chat_generator = pipeline("text2text-generation", model="/models/train_chat", device=device_flag)
|
||||
#amb_generator = pipeline("text2text-generation", model="/models/train_amb_gen", device=device_flag)
|
||||
#generators = {'qa': qa_generator,
|
||||
# 'chat': chat_generator,
|
||||
# 'amb': amb_generator,
|
||||
# 'summ': summ_generator}
|
||||
rec = Recommender(retriever)
|
||||
|
||||
def vre_fetch(token):
|
||||
while True:
|
||||
try:
|
||||
time.sleep(1000)
|
||||
print('getting new material')
|
||||
users[token]['vre'].get_vre_update()
|
||||
users[token]['vre'].index_periodic_update()
|
||||
users[token]['rg'].update_index(vre.get_index())
|
||||
users[token]['rg'].update_db(vre.get_db())
|
||||
#def vre_fetch(token):
|
||||
# while True:
|
||||
# try:
|
||||
# time.sleep(1000)
|
||||
# print('getting new material')
|
||||
# users[token]['vre'].get_vre_update()
|
||||
# users[token]['vre'].index_periodic_update()
|
||||
# users[token]['rg'].update_index(vre.get_index())
|
||||
# users[token]['rg'].update_db(vre.get_db())
|
||||
#vre.get_vre_update()
|
||||
#vre.index_periodic_update()
|
||||
#rg.update_index(vre.get_index())
|
||||
#rg.update_db(vre.get_db())
|
||||
except Exception as e:
|
||||
alive = "dead_vre_fetch"
|
||||
|
||||
# except Exception as e:
|
||||
# alive = "dead_vre_fetch"
|
||||
|
||||
|
||||
"""
|
||||
def user_interest_decay(token):
|
||||
while True:
|
||||
try:
|
||||
|
@ -99,6 +113,7 @@ def clear_inactive():
|
|||
users[username]['activity'] += 1
|
||||
except Exception as e:
|
||||
alive = "dead_clear_inactive"
|
||||
"""
|
||||
|
||||
@app.route("/health", methods=['GET'])
|
||||
def health():
|
||||
|
@ -113,10 +128,13 @@ def init_dm():
|
|||
token = request.get_json().get("token")
|
||||
status = request.get_json().get("stat")
|
||||
if status == "start":
|
||||
logging.debug("status=start")
|
||||
message = {"stat": "waiting", "err": ""}
|
||||
elif status == "set":
|
||||
logging.debug("status=set")
|
||||
headers = {"gcube-token": token, "Accept": "application/json"}
|
||||
if token not in users:
|
||||
logging.debug("getting user info")
|
||||
url = 'https://api.d4science.org/rest/2/people/profile'
|
||||
response = requests.get(url, headers=headers)
|
||||
if response.status_code == 200:
|
||||
|
@ -128,12 +146,13 @@ def init_dm():
|
|||
index = vre.get_index()
|
||||
db = vre.get_db()
|
||||
|
||||
rg = ResponseGenerator(index,db, rec, generators, retriever)
|
||||
rg = ResponseGenerator(index,db, rec, retriever=retriever)
|
||||
|
||||
users[token] = {'username': username, 'name': name, 'dm': DM(), 'activity': 0, 'user': User(username, token), 'vre': vre, 'rg': rg}
|
||||
users[token] = {'username': username, 'name': name, 'dm': DM(), 'activity': 0, 'user': User(username, token),
|
||||
'vre': vre, 'rg': rg}
|
||||
|
||||
threading.Thread(target=user_interest_decay, args=(token,), name='decayinterest_'+users[token]['username']).start()
|
||||
threading.Thread(target=vre_fetch, name='updatevre'+users[token]['username'], args=(token,)).start()
|
||||
#threading.Thread(target=user_interest_decay, args=(token,), name='decayinterest_'+users[token]['username']).start()
|
||||
#threading.Thread(target=vre_fetch, name='updatevre'+users[token]['username'], args=(token,)).start()
|
||||
message = {"stat": "done", "err": ""}
|
||||
else:
|
||||
message = {"stat": "rejected", "err": ""}
|
||||
|
@ -156,43 +175,55 @@ def predict():
|
|||
message = {}
|
||||
try:
|
||||
if text == "<HELP_ON_START>":
|
||||
logging.debug("help on start - inactive")
|
||||
state = {'help': True, 'inactive': False, 'modified_query':"", 'intent':""}
|
||||
dm.update(state)
|
||||
action = dm.next_action()
|
||||
logging.debug("next action:" + action)
|
||||
#response = "Hey " + users[token]['name'].split()[0] + "! it's Janet! I am here to help you make use of the datasets and papers in the catalogue of the VRE. I can answer questions whose answers may be inside the papers. I can summarize papers for you. I can also chat with you. So, whichever it is, I am ready to chat!"
|
||||
response = rg.gen_response(action, vrename=vre.name, username=users[token]['username'], name=users[token]['name'].split()[0])
|
||||
message = {"answer": response}
|
||||
elif text == "<RECOMMEND_ON_IDLE>":
|
||||
logging.debug("recommend on idle - inactive")
|
||||
state = {'help': False, 'inactive': True, 'modified_query':"recommed: ", 'intent':""}
|
||||
dm.update(state)
|
||||
action = dm.next_action()
|
||||
logging.debug("next action:" + action)
|
||||
#response = "Hey " + users[token]['name'].split()[0] + "! it's Janet! I am here to help you make use of the datasets and papers in the catalogue of the VRE. I can answer questions whose answers may be inside the papers. I can summarize papers for you. I can also chat with you. So, whichever it is, I am ready to chat!"
|
||||
|
||||
response = rg.gen_response(action, username=users[token]['username'],name=users[token]['name'].split()[0], vrename=vre.name)
|
||||
message = {"answer": response}
|
||||
new_state = {'modified_query': response}
|
||||
new_state = {'modified_query': "Janet: " + response}
|
||||
dm.update(new_state)
|
||||
else:
|
||||
state = nlu.process_utterance(text, dm.get_consec_history(), dm.get_sep_history())
|
||||
state = nlu.process_utterance(f"""{dm.get_history()}
|
||||
user: {text}""")
|
||||
state['help'] = False
|
||||
state['inactive'] = False
|
||||
old_user_interests = user.get_user_interests()
|
||||
old_vre_material = pd.concat([vre.db['paper_db'], vre.db['dataset_db']]).reset_index(drop=True)
|
||||
user_interests = []
|
||||
for entity in state['entities']:
|
||||
if entity['entity'] == 'TOPIC':
|
||||
user_interests.append(entity['value'])
|
||||
user.update_interests(user_interests)
|
||||
new_user_interests = user.get_user_interests()
|
||||
new_vre_material = pd.concat([vre.db['paper_db'], vre.db['dataset_db']]).reset_index(drop=True)
|
||||
if (new_user_interests != old_user_interests or len(old_vre_material) != len(new_vre_material)):
|
||||
rec.generate_recommendations(users[token]['username'], new_user_interests, new_vre_material)
|
||||
#old_user_interests = user.get_user_interests()
|
||||
#old_vre_material = pd.concat([vre.db['paper_db'], vre.db['dataset_db']]).reset_index(drop=True)
|
||||
#user_interests = []
|
||||
#for entity in state['entities']:
|
||||
# if entity['entity'] == 'TOPIC':
|
||||
# user_interests.append(entity['value'])
|
||||
#user.update_interests(user_interests)
|
||||
#new_user_interests = user.get_user_interests()
|
||||
#new_vre_material = pd.concat([vre.db['paper_db'], vre.db['dataset_db']]).reset_index(drop=True)
|
||||
#if (new_user_interests != old_user_interests or len(old_vre_material) != len(new_vre_material)):
|
||||
# rec.generate_recommendations(users[token]['username'], new_user_interests, new_vre_material)
|
||||
dm.update(state)
|
||||
action = dm.next_action()
|
||||
response = rg.gen_response(action=action, utterance=state['modified_query'], state=dm.get_recent_state(), consec_history=dm.get_consec_history(), chitchat_history=dm.get_chitchat_history(), vrename=vre.name, username=users[token]['username'], name=users[token]['name'].split()[0])
|
||||
message = {"answer": response, "query": text, "cand": "candidate", "history": dm.get_consec_history(), "modQuery": state['modified_query']}
|
||||
if state['intent'] == "QA":
|
||||
split_response = response.split("_______ \n ")
|
||||
if len(split_response) > 1:
|
||||
response = split_response[1]
|
||||
new_state = {'modified_query': response, 'intent': state['intent']}
|
||||
logging.debug("Next action: " + action)
|
||||
#response = rg.gen_response(action=action, utterance=state['modified_query'], state=dm.get_recent_state(), consec_history=dm.get_consec_history(), chitchat_history=dm.get_chitchat_history(), vrename=vre.name, username=users[token]['username'], name=users[token]['name'].split()[0])
|
||||
#message = {"answer": response, "query": text, "cand": "candidate", "history": dm.get_consec_history(), "modQuery": state['modified_query']}
|
||||
message = {"answer": state['modified_query'], "query": text, "cand": "candidate", "history": dm.get_history(), "modQuery": state['modified_query']}
|
||||
|
||||
#if state['intent'] == "QA":
|
||||
# split_response = response.split("_______ \n ")
|
||||
# if len(split_response) > 1:
|
||||
# response = split_response[1]
|
||||
response =state['modified_query']
|
||||
new_state = {'modified_query': "Janet: " + response, 'intent': state['intent']}
|
||||
dm.update(new_state)
|
||||
reply = jsonify(message)
|
||||
users[token]['dm'] = dm
|
||||
|
@ -231,7 +262,7 @@ def feedback():
|
|||
|
||||
if __name__ == "__main__":
|
||||
warnings.filterwarnings("ignore")
|
||||
threading.Thread(target=clear_inactive, name='clear').start()
|
||||
#threading.Thread(target=clear_inactive, name='clear').start()
|
||||
"""
|
||||
conn = psycopg2.connect(host="janet-pg", database=os.getenv("POSTGRES_DB"), user=os.getenv("POSTGRES_USER"), password=os.getenv("POSTGRES_PASSWORD"))
|
||||
|
||||
|
|
|
@ -33,6 +33,7 @@ markupsafe==2.0.1
|
|||
psycopg2==2.9.5
|
||||
en-coreference-web-trf @ https://github.com/explosion/spacy-experimental/releases/download/v0.6.1/en_coreference_web_trf-3.4.0a2-py3-none-any.whl
|
||||
datasets
|
||||
huggingface_hub
|
||||
Werkzeug==1.0.1
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue