From 1efd0ac18d7e010ff943cb8bcb362b9c2e32cdb1 Mon Sep 17 00:00:00 2001
From: Ahmed Ibrahim
Date: Thu, 30 May 2024 19:09:54 +0200
Subject: [PATCH] test google gemma
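
Swap the NLU classifier/rewriter cascade for a single call to an
instruction-tuned Gemma model that restates the user's goal from the
dialogue history; the recommender, generators, and background VRE
threads are commented out for this test. The Hugging Face access token
is now read from the HF_TOKEN environment variable instead of being
hardcoded.

A minimal sketch of the intended generation flow (a rough illustration
assuming the public "google/gemma-2b-it" checkpoint; the deployed model
path and output parsing may differ):

    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM

    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
    model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2b-it", torch_dtype=torch.bfloat16
    )

    # Wrap the instruction in Gemma's chat template and generate.
    chat = [{"role": "user", "content": "What does the user want? ..."}]
    prompt = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
    outputs = model.generate(input_ids=inputs, max_new_tokens=150)

    # Keep only the model's turn: drop the echoed prompt and the
    # trailing <end_of_turn> marker.
    decoded = tokenizer.decode(outputs[0])
    goal = decoded.split("model\n")[-1].split("<end_of_turn>")[0]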
---
 .gitignore            |   2 +
 DM.py                 |  60 +++++++++++-------
 Dockerfile            |   6 +-
 NLU.py                | 112 +++++++++++++++++++++-------------
 ResponseGenerator.py  |   2 +-
 main.py               | 139 ++++++++++++++++++++++++++----------------
 requirements_main.txt |   1 +
 7 files changed, 199 insertions(+), 123 deletions(-)

diff --git a/.gitignore b/.gitignore
index 7fb0f16..32d6c9d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
 janet.pdf
 __pycache__/
+git-filter-repo
+.gitignore
 ahmed.ibrahim39699_interests.json
diff --git a/DM.py b/DM.py
index 9603db8..d19f821 100644
--- a/DM.py
+++ b/DM.py
@@ -6,43 +6,61 @@ class DM:
         self.working_history_consec = ""
         self.chitchat_history_consec = ""
         self.max_history_length = max_history_length
+        self.history = ""
         self.chat_history = []
         self.curr_state = None

-    def update_history(self):
-        to_consider = [x['modified_query'] for x in self.chat_history[-self.max_history_length*2:]]
-        self.working_history_consec = " . ".join(to_consider)
-        self.working_history_sep = " ||| ".join(to_consider)
+    #def update_history(self):
+        #to_consider = [x['modified_query'] for x in self.chat_history[-self.max_history_length*2:]]
+        #self.working_history_consec = " . ".join(to_consider)
+        #self.working_history_sep = " ||| ".join(to_consider)

-        chat = []
-        for utt in self.chat_history:
-            if utt['intent'] == 'CHITCHAT':
-                if len(chat) == 4:
-                    chat = chat[1:]
-                chat.append(utt['modified_query'])
-        self.chitchat_history_consec = '. '.join(chat)
+        #chat = []
+        #for utt in self.chat_history:
+        #    if utt['intent'] == 'CHITCHAT':
+        #        if len(chat) == 4:
+        #            chat = chat[1:]
+        #        chat.append(utt['modified_query'])
+        #self.chitchat_history_consec = '. '.join(chat)

-    def get_consec_history(self):
-        return self.working_history_consec
+    #def get_consec_history(self):
+    #    return self.working_history_consec

-    def get_chitchat_history(self):
-        return self.chitchat_history_consec
+    #def get_chitchat_history(self):
+    #    return self.chitchat_history_consec

-    def get_sep_history(self):
-        return self.working_history_sep
+    #def get_sep_history(self):
+    #    return self.working_history_sep

-    def get_recent_state(self):
-        return self.curr_state
+    #def get_recent_state(self):
+    #    return self.curr_state

-    def get_dialogue_history(self):
-        return self.chat_history
+    #def get_dialogue_history(self):
+    #    return self.chat_history

     def update(self, new_state):
         self.chat_history.append(new_state)
         self.curr_state = new_state
         self.update_history()

+    def update_history(self):
+        to_consider = [x['modified_query'] for x in self.chat_history[-self.max_history_length*2:]]
+        #self.working_history_consec = " . ".join(to_consider)
+        #self.working_history_sep = " ||| ".join(to_consider)
+        self.history = ""  # rebuild in full so repeated updates do not duplicate turns
+        for utt in to_consider:
+            self.history = utt if len(self.history) == 0 else f"""{self.history}
+{utt}"""
+
+        #user_last_utt = f"""{username}: {text}"""
+        #self.history = user_last_utt if len(self.history) == 0 else f"""{self.history}
+        #{user_last_utt}"""
+
+    def get_history(self):
+        return self.history
+
+
     def next_action(self):
         if self.curr_state['help']:
             return "Help"
diff --git a/Dockerfile b/Dockerfile
index 03c5016..8435f65 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -2,15 +2,15 @@ FROM python:3.8

 WORKDIR /backend_janet

-COPY requirements_simple.txt .
+COPY requirements_main.txt .

 ARG version_info
 ENV FLASK_APP_VERSION_INFO=${version_info}

-RUN pip install -r requirements_simple.txt
+RUN pip install -r requirements_main.txt
 RUN rm -fr /root/.cache/*

 COPY . .

-ENTRYPOINT ["python", "main_simple.py"]
+ENTRYPOINT ["python", "main.py"]
diff --git a/NLU.py b/NLU.py
index d39bc51..8d05166 100644
--- a/NLU.py
+++ b/NLU.py
@@ -1,17 +1,22 @@
 import spacy
 import spacy_transformers
 import torch
+import logging

 class NLU:
-    def __init__(self, query_rewriter, coref_resolver, intent_classifier, offensive_classifier, entity_extractor, ambig_classifier):
-
+    def __init__(self, LLM_tokenizer, LLM_model):
+        """
         self.intent_classifier = intent_classifier
         self.entity_extractor = entity_extractor
         self.offensive_classifier = offensive_classifier
         self.coref_resolver = coref_resolver
         self.query_rewriter = query_rewriter
         self.ambig_classifier = ambig_classifier
+        """
+        self.tokenizer = LLM_tokenizer
+        self.model = LLM_model

+    """
     def _resolve_coref(self, history):
         to_resolve = history + ' ' + self.to_process
         doc = self.coref_resolver(to_resolve)
@@ -20,13 +25,13 @@ class NLU:
         clusters = [
             val for key, val in doc.spans.items() if key.startswith("coref_cluster")
         ]
-        """
+
         clusters = []
         for cluster in cand_clusters:
             if cluster[0].text == "I":
                 continue
             clusters.append(cluster)
-        """
+
         # Iterate through every found cluster
         for cluster in clusters:
             first_mention = cluster[0]
@@ -83,49 +88,68 @@ class NLU:
         text = history + " ||| " + self.to_process
         return self.query_rewriter(text)[0]['generated_text']
+    """

-    def process_utterance(self, utterance, history_consec, history_sep):
+    def process_utterance(self, history):
         """
         Query -> coref resolution & intent extraction -> if intents are not confident or if query is ambig
         -> rewrite query and recheck -> if still ambig, ask a clarifying question
         """
-        if utterance.lower() in ["help", "list resources", "list papers", "list datasets", "list topics"]:
-            return {"modified_query": utterance.lower(), "intent": "COMMAND", "entities": [], "is_offensive": False, "is_clear": True}
+        #if utterance.lower() in ["help", "list resources", "list papers", "list datasets", "list topics"]:
+        #    return {"modified_query": utterance.lower(), "intent": "COMMAND", "entities": [], "is_offensive": False, "is_clear": True}

-        self.to_process = utterance
-
-        self.to_process = self._resolve_coref(history_consec)
-
-        intent, score = self._intentpredictor()
+        #self.to_process = utterance

-        if score > 0.5:
-            if intent == 'CHITCHAT':
-                self.to_process = utterance
-            entities = self._entityextractor()
-            offense = self._offensepredictor()
-            if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
-                return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
-            return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": True}
-        else:
-            if self._ambigpredictor():
-                self.to_process = self._rewrite_query(history_sep)
-                intent, score = self._intentpredictor()
-                entities = self._entityextractor()
-                offense = self._offensepredictor()
-                if score > 0.5 or not self._ambigpredictor():
-                    if intent == 'CHITCHAT':
-                        self.to_process = utterance
-                    if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
-                        return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
-                    return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense,
-                            "is_clear": True}
-                else:
-                    return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense,
-                            "is_clear": False}
-            else:
-                entities = self._entityextractor()
-                offense = self._offensepredictor()
-                if intent == 'CHITCHAT':
-                    self.to_process = utterance
-                if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
-                    return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
-                return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": True}
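+        # One LLM call replaces the old classifier cascade: ask the model to
+        # restate the user's goal from the raw dialogue history.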
+        prompt = f"""You are Janet, the virtual assistant for users of the virtual research environment.
+        What does the user eventually want, given the dialogue delimited with triple backticks?
+        Give your answer in one single sentence.
+        Dialogue: '''{history}'''
+        """
+
+        chat = [{"role": "user", "content": prompt}]
+        prompt_chat = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
+        inputs = self.tokenizer.encode(prompt_chat, add_special_tokens=False, return_tensors="pt")
+        outputs = self.model.generate(input_ids=inputs, max_new_tokens=150)
+
+        goal = self.tokenizer.decode(outputs[0])
+        logging.debug("User's goal is: " + goal)
+
+        #return goal.split("model\n")[-1].split("<end_of_turn>")[0]
+        return {"modified_query": goal.split("model\n")[-1].split("<end_of_turn>")[0],
+                "intent": "QA", "entities": [], "is_offensive": False, "is_clear": True}
+
+        #self.to_process = self._resolve_coref(history_consec)
+
+        #intent, score = self._intentpredictor()
+
+        #if score > 0.5:
+        #    if intent == 'CHITCHAT':
+        #        self.to_process = utterance
+        #    entities = self._entityextractor()
+        #    offense = self._offensepredictor()
+        #    if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
+        #        return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
+        #    return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": True}
+        #else:
+        #    if self._ambigpredictor():
+        #        self.to_process = self._rewrite_query(history_sep)
+        #        intent, score = self._intentpredictor()
+        #        entities = self._entityextractor()
+        #        offense = self._offensepredictor()
+        #        if score > 0.5 or not self._ambigpredictor():
+        #            if intent == 'CHITCHAT':
+        #                self.to_process = utterance
+        #            if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
+        #                return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
+        #            return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense,
+        #                    "is_clear": True}
+        #        else:
+        #            return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense,
+        #                    "is_clear": False}
+        #    else:
+        #        entities = self._entityextractor()
+        #        offense = self._offensepredictor()
+        #        if intent == 'CHITCHAT':
+        #            self.to_process = utterance
+        #        if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
+        #            return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
+        #        return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": True}
diff --git a/ResponseGenerator.py b/ResponseGenerator.py
index 3fcd7b4..98f4d80 100644
--- a/ResponseGenerator.py
+++ b/ResponseGenerator.py
@@ -8,7 +8,7 @@ from datetime import datetime
 from datasets import Dataset

 class ResponseGenerator:
-    def __init__(self, index, db,recommender,generators, retriever, num_retrieved=3):
+    def __init__(self, index=None, db=None, recommender=None, generators=None, retriever=None, num_retrieved=3):
         self.generators = generators
         self.retriever = retriever
         self.recommender = recommender
diff --git a/main.py b/main.py
index e8c7871..b03332c 100644
--- a/main.py
+++ b/main.py
@@ -1,4 +1,5 @@
 import os
+import logging
 import re
 import warnings
 import faiss
@@ -10,7 +11,7 @@ import spacy
 import requests
 import spacy_transformers
 import torch
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, AutoModelForCausalLM
 from User import User
 from VRE import VRE
 from NLU import NLU
@@ -21,6 +22,9 @@ import pandas as pd
 import time
 import threading
 from sentence_transformers import SentenceTransformer
+from huggingface_hub import login
+
+login(token=os.environ["HF_TOKEN"])  # read the access token from the environment; never hardcode secrets

@@ -36,46 +40,56 @@ alive = "alive"
 device = "cuda" if torch.cuda.is_available() else "cpu"
 device_flag = torch.cuda.current_device() if torch.cuda.is_available() else -1

+model_id = "/models/google-gemma"
+dtype = torch.bfloat16

-query_rewriter = pipeline("text2text-generation", model="castorini/t5-base-canard")
-intent_classifier = pipeline("sentiment-analysis", model='/models/intent_classifier', device=device_flag)
-entity_extractor = spacy.load("/models/entity_extractor")
-offensive_classifier = pipeline("sentiment-analysis", model='/models/offensive_classifier', device=device_flag)
-ambig_classifier = pipeline("sentiment-analysis", model='/models/ambig_classifier', device=device_flag)
-coref_resolver = spacy.load("en_coreference_web_trf")
-
-nlu = NLU(query_rewriter, coref_resolver, intent_classifier, offensive_classifier, entity_extractor, ambig_classifier)
+#query_rewriter = pipeline("text2text-generation", model="castorini/t5-base-canard")
+#intent_classifier = pipeline("sentiment-analysis", model='/models/intent_classifier', device=device_flag)
+#entity_extractor = spacy.load("/models/entity_extractor")
+#offensive_classifier = pipeline("sentiment-analysis", model='/models/offensive_classifier', device=device_flag)
+#ambig_classifier = pipeline("sentiment-analysis", model='/models/ambig_classifier', device=device_flag)
+#coref_resolver = spacy.load("en_coreference_web_trf")
+
+#LLM = pipeline("text2text-generation", model="/models/google-gemma", device=device_flag)
+
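+# Load the instruction-tuned Gemma checkpoint once at startup; bfloat16 weights
+# roughly halve memory use compared to float32. The tokenizer also supplies the
+# chat template that NLU uses to frame its goal-extraction prompt.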
model="/models/train_qa", device=device_flag) +#summ_generator = pipeline("text2text-generation", model="/models/train_summ", device=device_flag) +#chat_generator = pipeline("text2text-generation", model="/models/train_chat", device=device_flag) +#amb_generator = pipeline("text2text-generation", model="/models/train_amb_gen", device=device_flag) +#generators = {'qa': qa_generator, +# 'chat': chat_generator, +# 'amb': amb_generator, +# 'summ': summ_generator} rec = Recommender(retriever) -def vre_fetch(token): - while True: - try: - time.sleep(1000) - print('getting new material') - users[token]['vre'].get_vre_update() - users[token]['vre'].index_periodic_update() - users[token]['rg'].update_index(vre.get_index()) - users[token]['rg'].update_db(vre.get_db()) +#def vre_fetch(token): +# while True: +# try: +# time.sleep(1000) +# print('getting new material') +# users[token]['vre'].get_vre_update() +# users[token]['vre'].index_periodic_update() +# users[token]['rg'].update_index(vre.get_index()) +# users[token]['rg'].update_db(vre.get_db()) #vre.get_vre_update() #vre.index_periodic_update() #rg.update_index(vre.get_index()) #rg.update_db(vre.get_db()) - except Exception as e: - alive = "dead_vre_fetch" +# except Exception as e: +# alive = "dead_vre_fetch" - +""" def user_interest_decay(token): while True: try: @@ -99,6 +113,7 @@ def clear_inactive(): users[username]['activity'] += 1 except Exception as e: alive = "dead_clear_inactive" +""" @app.route("/health", methods=['GET']) def health(): @@ -113,10 +128,13 @@ def init_dm(): token = request.get_json().get("token") status = request.get_json().get("stat") if status == "start": + logging.debug("status=start") message = {"stat": "waiting", "err": ""} elif status == "set": + logging.debug("status=set") headers = {"gcube-token": token, "Accept": "application/json"} if token not in users: + logging.debug("getting user info") url = 'https://api.d4science.org/rest/2/people/profile' response = requests.get(url, headers=headers) if response.status_code == 200: @@ -128,12 +146,13 @@ def init_dm(): index = vre.get_index() db = vre.get_db() - rg = ResponseGenerator(index,db, rec, generators, retriever) + rg = ResponseGenerator(index,db, rec, retriever=retriever) - users[token] = {'username': username, 'name': name, 'dm': DM(), 'activity': 0, 'user': User(username, token), 'vre': vre, 'rg': rg} + users[token] = {'username': username, 'name': name, 'dm': DM(), 'activity': 0, 'user': User(username, token), + 'vre': vre, 'rg': rg} - threading.Thread(target=user_interest_decay, args=(token,), name='decayinterest_'+users[token]['username']).start() - threading.Thread(target=vre_fetch, name='updatevre'+users[token]['username'], args=(token,)).start() + #threading.Thread(target=user_interest_decay, args=(token,), name='decayinterest_'+users[token]['username']).start() + #threading.Thread(target=vre_fetch, name='updatevre'+users[token]['username'], args=(token,)).start() message = {"stat": "done", "err": ""} else: message = {"stat": "rejected", "err": ""} @@ -156,43 +175,55 @@ def predict(): message = {} try: if text == "": + logging.debug("help on start - inactive") state = {'help': True, 'inactive': False, 'modified_query':"", 'intent':""} dm.update(state) action = dm.next_action() + logging.debug("next action:" + action) + #response = "Hey " + users[token]['name'].split()[0] + "! it's Janet! I am here to help you make use of the datasets and papers in the catalogue of the VRE. I can answer questions whose answers may be inside the papers. 
             state['help'] = False
             state['inactive'] = False
-            old_user_interests = user.get_user_interests()
-            old_vre_material = pd.concat([vre.db['paper_db'], vre.db['dataset_db']]).reset_index(drop=True)
-            user_interests = []
-            for entity in state['entities']:
-                if entity['entity'] == 'TOPIC':
-                    user_interests.append(entity['value'])
-            user.update_interests(user_interests)
-            new_user_interests = user.get_user_interests()
-            new_vre_material = pd.concat([vre.db['paper_db'], vre.db['dataset_db']]).reset_index(drop=True)
-            if (new_user_interests != old_user_interests or len(old_vre_material) != len(new_vre_material)):
-                rec.generate_recommendations(users[token]['username'], new_user_interests, new_vre_material)
+            #old_user_interests = user.get_user_interests()
+            #old_vre_material = pd.concat([vre.db['paper_db'], vre.db['dataset_db']]).reset_index(drop=True)
+            #user_interests = []
+            #for entity in state['entities']:
+            #    if entity['entity'] == 'TOPIC':
+            #        user_interests.append(entity['value'])
+            #user.update_interests(user_interests)
+            #new_user_interests = user.get_user_interests()
+            #new_vre_material = pd.concat([vre.db['paper_db'], vre.db['dataset_db']]).reset_index(drop=True)
+            #if (new_user_interests != old_user_interests or len(old_vre_material) != len(new_vre_material)):
+            #    rec.generate_recommendations(users[token]['username'], new_user_interests, new_vre_material)
             dm.update(state)
             action = dm.next_action()
-            response = rg.gen_response(action=action, utterance=state['modified_query'], state=dm.get_recent_state(), consec_history=dm.get_consec_history(), chitchat_history=dm.get_chitchat_history(), vrename=vre.name, username=users[token]['username'], name=users[token]['name'].split()[0])
-            message = {"answer": response, "query": text, "cand": "candidate", "history": dm.get_consec_history(), "modQuery": state['modified_query']}
-            if state['intent'] == "QA":
-                split_response = response.split("_______ \n ")
-                if len(split_response) > 1:
-                    response = split_response[1]
-            new_state = {'modified_query': response, 'intent': state['intent']}
+            logging.debug("Next action: " + action)
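+            # Response generation is bypassed in this test: the inferred goal is
+            # echoed back as the answer instead of an rg.gen_response() call.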
+            #response = rg.gen_response(action=action, utterance=state['modified_query'], state=dm.get_recent_state(), consec_history=dm.get_consec_history(), chitchat_history=dm.get_chitchat_history(), vrename=vre.name, username=users[token]['username'], name=users[token]['name'].split()[0])
+            #message = {"answer": response, "query": text, "cand": "candidate", "history": dm.get_consec_history(), "modQuery": state['modified_query']}
+            message = {"answer": state['modified_query'], "query": text, "cand": "candidate", "history": dm.get_history(), "modQuery": state['modified_query']}
+
+            #if state['intent'] == "QA":
+            #    split_response = response.split("_______ \n ")
+            #    if len(split_response) > 1:
+            #        response = split_response[1]
+            response = state['modified_query']
+            new_state = {'modified_query': "Janet: " + response, 'intent': state['intent']}
             dm.update(new_state)
         reply = jsonify(message)
         users[token]['dm'] = dm
@@ -231,7 +262,7 @@ def feedback():

 if __name__ == "__main__":
     warnings.filterwarnings("ignore")
-    threading.Thread(target=clear_inactive, name='clear').start()
+    #threading.Thread(target=clear_inactive, name='clear').start()
     """
     conn = psycopg2.connect(host="janet-pg", database=os.getenv("POSTGRES_DB"), user=os.getenv("POSTGRES_USER"),
                             password=os.getenv("POSTGRES_PASSWORD"))
diff --git a/requirements_main.txt b/requirements_main.txt
index e34ed57..b95b8b8 100644
--- a/requirements_main.txt
+++ b/requirements_main.txt
@@ -33,6 +33,7 @@ markupsafe==2.0.1
 psycopg2==2.9.5
 en-coreference-web-trf @ https://github.com/explosion/spacy-experimental/releases/download/v0.6.1/en_coreference_web_trf-3.4.0a2-py3-none-any.whl
 datasets
+huggingface_hub
 Werkzeug==1.0.1