test google gemma

Ahmed Salah Tawfik Ibrahim 2024-05-30 19:09:54 +02:00
parent 73117674ad
commit 1efd0ac18d
7 changed files with 199 additions and 123 deletions

2
.gitignore

@@ -1,3 +1,5 @@
janet.pdf
__pycache__/
git-filter-repo
.gitignore
ahmed.ibrahim39699_interests.json

60
DM.py

@@ -6,43 +6,61 @@ class DM:
self.working_history_consec = ""
self.chitchat_history_consec = ""
self.max_history_length = max_history_length
self.history = ""
self.chat_history = []
self.curr_state = None
def update_history(self):
to_consider = [x['modified_query'] for x in self.chat_history[-self.max_history_length*2:]]
self.working_history_consec = " . ".join(to_consider)
self.working_history_sep = " ||| ".join(to_consider)
#def update_history(self):
#to_consider = [x['modified_query'] for x in self.chat_history[-self.max_history_length*2:]]
#self.working_history_consec = " . ".join(to_consider)
#self.working_history_sep = " ||| ".join(to_consider)
chat = []
for utt in self.chat_history:
if utt['intent'] == 'CHITCHAT':
if len(chat) == 4:
chat = chat[1:]
chat.append(utt['modified_query'])
self.chitchat_history_consec = '. '.join(chat)
#chat = []
#for utt in self.chat_history:
# if utt['intent'] == 'CHITCHAT':
# if len(chat) == 4:
# chat = chat[1:]
# chat.append(utt['modified_query'])
#self.chitchat_history_consec = '. '.join(chat)
def get_consec_history(self):
return self.working_history_consec
#def get_consec_history(self):
# return self.working_history_consec
def get_chitchat_history(self):
return self.chitchat_history_consec
#def get_chitchat_history(self):
# return self.chitchat_history_consec
def get_sep_history(self):
return self.working_history_sep
#def get_sep_history(self):
# return self.working_history_sep
def get_recent_state(self):
return self.curr_state
#def get_recent_state(self):
# return self.curr_state
def get_dialogue_history(self):
return self.chat_history
#def get_dialogue_history(self):
# return self.chat_history
def update(self, new_state):
self.chat_history.append(new_state)
self.curr_state = new_state
self.update_history()
def update_history(self):
to_consider = [x['modified_query'] for x in self.chat_history[-self.max_history_length*2:]]
#self.working_history_consec = " . ".join(to_consider)
#self.working_history_sep = " ||| ".join(to_consider)
self.history = "\n".join(to_consider)  # rebuild from the recent window each call so repeated updates do not duplicate turns
#user_last_utt = f"""{username}: {text}"""
#self.history = user_last_utt if len(self.history) == 0 else f"""{self.history}
#{user_last_utt}"""
def get_history(self):
return self.history
def next_action(self):
if self.curr_state['help']:
return "Help"

Dockerfile

@@ -2,15 +2,15 @@ FROM python:3.8
WORKDIR /backend_janet
COPY requirements_simple.txt .
COPY requirements_main.txt .
ARG version_info
ENV FLASK_APP_VERSION_INFO=${version_info}
RUN pip install -r requirements_simple.txt
RUN pip install -r requirements_main.txt
RUN rm -fr /root/.cache/*
COPY . .
ENTRYPOINT ["python", "main_simple.py"]
ENTRYPOINT ["python", "main.py"]

112
NLU.py

@@ -1,17 +1,22 @@
import spacy
import spacy_transformers
import torch
import logging
class NLU:
def __init__(self, query_rewriter, coref_resolver, intent_classifier, offensive_classifier, entity_extractor, ambig_classifier):
def __init__(self, LLM_tokenizer, LLM_model):
"""
self.intent_classifier = intent_classifier
self.entity_extractor = entity_extractor
self.offensive_classifier = offensive_classifier
self.coref_resolver = coref_resolver
self.query_rewriter = query_rewriter
self.ambig_classifier = ambig_classifier
"""
self.tokenizer = LLM_tokenizer
self.model = LLM_model
"""
def _resolve_coref(self, history):
to_resolve = history + ' <COREF_SEP_TOKEN> ' + self.to_process
doc = self.coref_resolver(to_resolve)
@@ -20,13 +25,13 @@ class NLU:
clusters = [
val for key, val in doc.spans.items() if key.startswith("coref_cluster")
]
"""
clusters = []
for cluster in cand_clusters:
if cluster[0].text == "I":
continue
clusters.append(cluster)
"""
# Iterate through every found cluster
for cluster in clusters:
first_mention = cluster[0]
@@ -83,49 +88,68 @@ class NLU:
text = history + " ||| " + self.to_process
return self.query_rewriter(text)[0]['generated_text']
"""
def process_utterance(self, utterance, history_consec, history_sep):
def process_utterance(self, history):
"""
Query -> coref resolution & intent extraction -> if the intent is not confident or the query is ambiguous -> rewrite the query and recheck -> if still ambiguous, ask a clarifying question
"""
if utterance.lower() in ["help", "list resources", "list papers", "list datasets", "list topics"]:
return {"modified_query": utterance.lower(), "intent": "COMMAND", "entities": [], "is_offensive": False, "is_clear": True}
#if utterance.lower() in ["help", "list resources", "list papers", "list datasets", "list topics"]:
# return {"modified_query": utterance.lower(), "intent": "COMMAND", "entities": [], "is_offensive": False, "is_clear": True}
self.to_process = utterance
self.to_process = self._resolve_coref(history_consec)
intent, score = self._intentpredictor()
#self.to_process = utterance
if score > 0.5:
if intent == 'CHITCHAT':
self.to_process = utterance
entities = self._entityextractor()
offense = self._offensepredictor()
if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": True}
else:
if self._ambigpredictor():
self.to_process = self._rewrite_query(history_sep)
intent, score = self._intentpredictor()
entities = self._entityextractor()
offense = self._offensepredictor()
if score > 0.5 or not self._ambigpredictor():
if intent == 'CHITCHAT':
self.to_process = utterance
if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense,
"is_clear": True}
else:
return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense,
"is_clear": False}
else:
entities = self._entityextractor()
offense = self._offensepredictor()
if intent == 'CHITCHAT':
self.to_process = utterance
if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": True}
prompt = f"""You are Janet, the virtual assistant of the virtual research enviornment users.
What does the user eventually want given this dialogue, which is delimited with triple backticks?
Give your answer in one single sentence.
Dialogue: '''{history}'''
"""
chat = [{ "role": "user", "content": prompt },]
prompt_chat = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
inputs = self.tokenizer.encode(prompt_chat, add_special_tokens=False, return_tensors="pt")
outputs = self.model.generate(input_ids=inputs, max_new_tokens=150)
goal = self.tokenizer.decode(outputs[0])
logging.debug("User's goal is:" + goal)
#return goal.split("<start_of_turn>model\n")[-1].split("<eos>")[0]
return {"modified_query": goal.split("<start_of_turn>model\n")[-1].split("<eos>")[0],
"intent": "QA", "entities": [], "is_offensive": False, "is_clear": True}
#self.to_process = self._resolve_coref(history_consec)
#intent, score = self._intentpredictor()
#if score > 0.5:
# if intent == 'CHITCHAT':
# self.to_process = utterance
# entities = self._entityextractor()
# offense = self._offensepredictor()
# if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
# return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
# return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": True}
#else:
# if self._ambigpredictor():
# self.to_process = self._rewrite_query(history_sep)
# intent, score = self._intentpredictor()
# entities = self._entityextractor()
# offense = self._offensepredictor()
# if score > 0.5 or not self._ambigpredictor():
# if intent == 'CHITCHAT':
# self.to_process = utterance
# if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
# return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
# return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense,
# "is_clear": True}
# else:
# return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense,
# "is_clear": False}
# else:
# entities = self._entityextractor()
# offense = self._offensepredictor()
# if intent == 'CHITCHAT':
# self.to_process = utterance
# if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
# return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
# return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": True}

ResponseGenerator.py

@@ -8,7 +8,7 @@ from datetime import datetime
from datasets import Dataset
class ResponseGenerator:
def __init__(self, index, db,recommender,generators, retriever, num_retrieved=3):
def __init__(self, index=None, db=None,recommender=None,generators=None, retriever=None, num_retrieved=3):
self.generators = generators
self.retriever = retriever
self.recommender = recommender

139
main.py

@@ -1,4 +1,5 @@
import os
import logging
import re
import warnings
import faiss
@@ -10,7 +11,7 @@ import spacy
import requests
import spacy_transformers
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, AutoModelForCausalLM
from User import User
from VRE import VRE
from NLU import NLU
@@ -21,6 +22,9 @@ import pandas as pd
import time
import threading
from sentence_transformers import SentenceTransformer
from huggingface_hub import login
login(token="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")
@@ -36,46 +40,56 @@ alive = "alive"
device = "cuda" if torch.cuda.is_available() else "cpu"
device_flag = torch.cuda.current_device() if torch.cuda.is_available() else -1
model_id = "/models/google-gemma"
dtype = torch.bfloat16
query_rewriter = pipeline("text2text-generation", model="castorini/t5-base-canard")
intent_classifier = pipeline("sentiment-analysis", model='/models/intent_classifier', device=device_flag)
entity_extractor = spacy.load("/models/entity_extractor")
offensive_classifier = pipeline("sentiment-analysis", model='/models/offensive_classifier', device=device_flag)
ambig_classifier = pipeline("sentiment-analysis", model='/models/ambig_classifier', device=device_flag)
coref_resolver = spacy.load("en_coreference_web_trf")
nlu = NLU(query_rewriter, coref_resolver, intent_classifier, offensive_classifier, entity_extractor, ambig_classifier)
#query_rewriter = pipeline("text2text-generation", model="castorini/t5-base-canard")
#intent_classifier = pipeline("sentiment-analysis", model='/models/intent_classifier', device=device_flag)
#entity_extractor = spacy.load("/models/entity_extractor")
#offensive_classifier = pipeline("sentiment-analysis", model='/models/offensive_classifier', device=device_flag)
#ambig_classifier = pipeline("sentiment-analysis", model='/models/ambig_classifier', device=device_flag)
#coref_resolver = spacy.load("en_coreference_web_trf")
#LLM = pipeline("text2text-generation", model="/models/google-gemma", device=device_flag)
LLM_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
LLM_model = AutoModelForCausalLM.from_pretrained(
"google/gemma-2b-it",
torch_dtype=torch.bfloat16
)
nlu = NLU(LLM_tokenizer, LLM_model)
#load retriever and generator
retriever = SentenceTransformer('/models/retriever/').to(device)
qa_generator = pipeline("text2text-generation", model="/models/train_qa", device=device_flag)
summ_generator = pipeline("text2text-generation", model="/models/train_summ", device=device_flag)
chat_generator = pipeline("text2text-generation", model="/models/train_chat", device=device_flag)
amb_generator = pipeline("text2text-generation", model="/models/train_amb_gen", device=device_flag)
generators = {'qa': qa_generator,
'chat': chat_generator,
'amb': amb_generator,
'summ': summ_generator}
#qa_generator = pipeline("text2text-generation", model="/models/train_qa", device=device_flag)
#summ_generator = pipeline("text2text-generation", model="/models/train_summ", device=device_flag)
#chat_generator = pipeline("text2text-generation", model="/models/train_chat", device=device_flag)
#amb_generator = pipeline("text2text-generation", model="/models/train_amb_gen", device=device_flag)
#generators = {'qa': qa_generator,
# 'chat': chat_generator,
# 'amb': amb_generator,
# 'summ': summ_generator}
rec = Recommender(retriever)
def vre_fetch(token):
while True:
try:
time.sleep(1000)
print('getting new material')
users[token]['vre'].get_vre_update()
users[token]['vre'].index_periodic_update()
users[token]['rg'].update_index(vre.get_index())
users[token]['rg'].update_db(vre.get_db())
#def vre_fetch(token):
# while True:
# try:
# time.sleep(1000)
# print('getting new material')
# users[token]['vre'].get_vre_update()
# users[token]['vre'].index_periodic_update()
# users[token]['rg'].update_index(vre.get_index())
# users[token]['rg'].update_db(vre.get_db())
#vre.get_vre_update()
#vre.index_periodic_update()
#rg.update_index(vre.get_index())
#rg.update_db(vre.get_db())
except Exception as e:
alive = "dead_vre_fetch"
# except Exception as e:
# alive = "dead_vre_fetch"
"""
def user_interest_decay(token):
while True:
try:
@@ -99,6 +113,7 @@ def clear_inactive():
users[username]['activity'] += 1
except Exception as e:
alive = "dead_clear_inactive"
"""
@app.route("/health", methods=['GET'])
def health():
@@ -113,10 +128,13 @@ def init_dm():
token = request.get_json().get("token")
status = request.get_json().get("stat")
if status == "start":
logging.debug("status=start")
message = {"stat": "waiting", "err": ""}
elif status == "set":
logging.debug("status=set")
headers = {"gcube-token": token, "Accept": "application/json"}
if token not in users:
logging.debug("getting user info")
url = 'https://api.d4science.org/rest/2/people/profile'
response = requests.get(url, headers=headers)
if response.status_code == 200:
@@ -128,12 +146,13 @@ def init_dm():
index = vre.get_index()
db = vre.get_db()
rg = ResponseGenerator(index,db, rec, generators, retriever)
rg = ResponseGenerator(index,db, rec, retriever=retriever)
users[token] = {'username': username, 'name': name, 'dm': DM(), 'activity': 0, 'user': User(username, token), 'vre': vre, 'rg': rg}
users[token] = {'username': username, 'name': name, 'dm': DM(), 'activity': 0, 'user': User(username, token),
'vre': vre, 'rg': rg}
threading.Thread(target=user_interest_decay, args=(token,), name='decayinterest_'+users[token]['username']).start()
threading.Thread(target=vre_fetch, name='updatevre'+users[token]['username'], args=(token,)).start()
#threading.Thread(target=user_interest_decay, args=(token,), name='decayinterest_'+users[token]['username']).start()
#threading.Thread(target=vre_fetch, name='updatevre'+users[token]['username'], args=(token,)).start()
message = {"stat": "done", "err": ""}
else:
message = {"stat": "rejected", "err": ""}
@@ -156,43 +175,55 @@ def predict():
message = {}
try:
if text == "<HELP_ON_START>":
logging.debug("help on start - inactive")
state = {'help': True, 'inactive': False, 'modified_query':"", 'intent':""}
dm.update(state)
action = dm.next_action()
logging.debug("next action:" + action)
#response = "Hey " + users[token]['name'].split()[0] + "! it's Janet! I am here to help you make use of the datasets and papers in the catalogue of the VRE. I can answer questions whose answers may be inside the papers. I can summarize papers for you. I can also chat with you. So, whichever it is, I am ready to chat!"
response = rg.gen_response(action, vrename=vre.name, username=users[token]['username'], name=users[token]['name'].split()[0])
message = {"answer": response}
elif text == "<RECOMMEND_ON_IDLE>":
logging.debug("recommend on idle - inactive")
state = {'help': False, 'inactive': True, 'modified_query':"recommend: ", 'intent':""}
dm.update(state)
action = dm.next_action()
logging.debug("next action:" + action)
#response = "Hey " + users[token]['name'].split()[0] + "! it's Janet! I am here to help you make use of the datasets and papers in the catalogue of the VRE. I can answer questions whose answers may be inside the papers. I can summarize papers for you. I can also chat with you. So, whichever it is, I am ready to chat!"
response = rg.gen_response(action, username=users[token]['username'],name=users[token]['name'].split()[0], vrename=vre.name)
message = {"answer": response}
new_state = {'modified_query': response}
new_state = {'modified_query': "Janet: " + response}
dm.update(new_state)
else:
state = nlu.process_utterance(text, dm.get_consec_history(), dm.get_sep_history())
state = nlu.process_utterance(f"""{dm.get_history()}
user: {text}""")
state['help'] = False
state['inactive'] = False
old_user_interests = user.get_user_interests()
old_vre_material = pd.concat([vre.db['paper_db'], vre.db['dataset_db']]).reset_index(drop=True)
user_interests = []
for entity in state['entities']:
if entity['entity'] == 'TOPIC':
user_interests.append(entity['value'])
user.update_interests(user_interests)
new_user_interests = user.get_user_interests()
new_vre_material = pd.concat([vre.db['paper_db'], vre.db['dataset_db']]).reset_index(drop=True)
if (new_user_interests != old_user_interests or len(old_vre_material) != len(new_vre_material)):
rec.generate_recommendations(users[token]['username'], new_user_interests, new_vre_material)
#old_user_interests = user.get_user_interests()
#old_vre_material = pd.concat([vre.db['paper_db'], vre.db['dataset_db']]).reset_index(drop=True)
#user_interests = []
#for entity in state['entities']:
# if entity['entity'] == 'TOPIC':
# user_interests.append(entity['value'])
#user.update_interests(user_interests)
#new_user_interests = user.get_user_interests()
#new_vre_material = pd.concat([vre.db['paper_db'], vre.db['dataset_db']]).reset_index(drop=True)
#if (new_user_interests != old_user_interests or len(old_vre_material) != len(new_vre_material)):
# rec.generate_recommendations(users[token]['username'], new_user_interests, new_vre_material)
dm.update(state)
action = dm.next_action()
response = rg.gen_response(action=action, utterance=state['modified_query'], state=dm.get_recent_state(), consec_history=dm.get_consec_history(), chitchat_history=dm.get_chitchat_history(), vrename=vre.name, username=users[token]['username'], name=users[token]['name'].split()[0])
message = {"answer": response, "query": text, "cand": "candidate", "history": dm.get_consec_history(), "modQuery": state['modified_query']}
if state['intent'] == "QA":
split_response = response.split("_______ \n ")
if len(split_response) > 1:
response = split_response[1]
new_state = {'modified_query': response, 'intent': state['intent']}
logging.debug("Next action: " + action)
#response = rg.gen_response(action=action, utterance=state['modified_query'], state=dm.get_recent_state(), consec_history=dm.get_consec_history(), chitchat_history=dm.get_chitchat_history(), vrename=vre.name, username=users[token]['username'], name=users[token]['name'].split()[0])
#message = {"answer": response, "query": text, "cand": "candidate", "history": dm.get_consec_history(), "modQuery": state['modified_query']}
message = {"answer": state['modified_query'], "query": text, "cand": "candidate", "history": dm.get_history(), "modQuery": state['modified_query']}
#if state['intent'] == "QA":
# split_response = response.split("_______ \n ")
# if len(split_response) > 1:
# response = split_response[1]
response = state['modified_query']
new_state = {'modified_query': "Janet: " + response, 'intent': state['intent']}
dm.update(new_state)
reply = jsonify(message)
users[token]['dm'] = dm
@@ -231,7 +262,7 @@ def feedback():
if __name__ == "__main__":
warnings.filterwarnings("ignore")
threading.Thread(target=clear_inactive, name='clear').start()
#threading.Thread(target=clear_inactive, name='clear').start()
"""
conn = psycopg2.connect(host="janet-pg", database=os.getenv("POSTGRES_DB"), user=os.getenv("POSTGRES_USER"), password=os.getenv("POSTGRES_PASSWORD"))
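
Stripped of the Flask plumbing, the new turn loop in predict() reduces to roughly this (a sketch; nlu and dm are the objects built above, text is the incoming user utterance):

state = nlu.process_utterance(f"{dm.get_history()}\nuser: {text}")
state['help'] = False
state['inactive'] = False
dm.update(state)
response = state['modified_query']  # Gemma's one-sentence goal summary doubles as the reply
dm.update({'modified_query': "Janet: " + response, 'intent': state['intent']})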

requirements_main.txt

@@ -33,6 +33,7 @@ markupsafe==2.0.1
psycopg2==2.9.5
en-coreference-web-trf @ https://github.com/explosion/spacy-experimental/releases/download/v0.6.1/en_coreference_web_trf-3.4.0a2-py3-none-any.whl
datasets
huggingface_hub
Werkzeug==1.0.1