This commit is contained in: new_code
parent 6c4ffc2740
commit ee9d8c5312
DM.py (63 lines changed)
@@ -1,43 +1,56 @@
 import time


 class DM:
-    def __init__(self):
-        self.utt_history = ""
-        self.history = []
-        self.state = None
+    def __init__(self, max_history_length=3):
+        self.working_history_sep = ""
+        self.working_history_consec = ""
+        self.max_history_length = max_history_length
+        self.chat_history = []
+        self.curr_state = None

-    def get_utt_history(self):
-        return self.utt_history
+    def update_history(self):
+        to_consider = [x.get('modified_query', '') for x in self.chat_history[-self.max_history_length*2:]]
+        self.working_history_consec = " . ".join(to_consider)
+        self.working_history_sep = " ||| ".join(to_consider)
+
+    def get_consec_history(self):
+        return self.working_history_consec
+
+    def get_sep_history(self):
+        return self.working_history_sep

     def get_recent_state(self):
-        return self.state
+        return self.curr_state

-    def get_dialogue_state(self):
-        return self.history
+    def get_dialogue_history(self):
+        return self.chat_history

     def update(self, new_state):
-        self.history.append(new_state)
-        self.utt_history = self.utt_history + " ||| " + new_state['modified_prompt']
-        self.state = {'intent': new_state['intent'],
-                      'entities': new_state['entities'],
-                      'offensive': new_state['is_offensive'],
-                      'clear': new_state['is_clear'],
-                      'time': time.time()}
+        self.chat_history.append(new_state)
+        self.curr_state = new_state
+        self.update_history()

     def next_action(self):
-        if self.state['clear']:
-            if self.state['offensive']:
-                return "NoCanDo"
+        if self.curr_state['help']:
+            return "Help"
+        elif self.curr_state['inactive']:
+            return "Recommend"
+        elif self.curr_state['is_clear']:
+            if self.curr_state['is_offensive']:
+                return "OffenseReject"
             else:
-                if self.state['intent'] == 0:
+                if self.curr_state['intent'] == 'QA':
                     return "RetGen"
-                elif self.state['intent'] == 1:
+                elif self.curr_state['intent'] == 'CHITCHAT':
                     return "ConvGen"
-                elif self.state['intent'] == 2:
+                elif self.curr_state['intent'] == 'FINDPAPER':
                     return "findPaper"
-                elif self.state['intent'] == 3:
+                elif self.curr_state['intent'] == 'FINDDATASET':
                     return "findDataset"
-                elif self.state['intent'] == 4:
+                elif self.curr_state['intent'] == 'SUMMARIZEPAPER':
                     return "sumPaper"
         else:
-            return "Clarify"
+            if self.curr_state['intent'] in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(self.curr_state['entities']) == 0:
+                return "ClarifyResource"
+            else:
+                return "GenClarify"
NLU.py (215 lines changed)
@@ -1,143 +1,112 @@
-"""
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
-import torch
-
-
-class NLU:
-    def_tokenizer = AutoTokenizer.from_pretrained("castorini/t5-base-canard")
-    def_model = AutoModelForSeq2SeqLM.from_pretrained("castorini/t5-base-canard")
-    def_intent_classifier = pipeline("sentiment-analysis", model="/home/ahmed/PycharmProjects/Janet/JanetBackend/intent_classifier")
-
-    def __init__(self, model=def_model, tokenizer=def_tokenizer, intent_classifier=def_intent_classifier,
-                 max_history_length=1024, num_gen_seq=2, score_threshold=0.5):
-        self.input = ""
-        self.output = ""
-        self.model = model
-        self.tokenizer = tokenizer
-        self.max_length = max_history_length
-        self.num_return_sequences = num_gen_seq
-        self.score_threshold = score_threshold
-        self.label2id = {'Greet': 0, 'Bye': 1, 'GetKnowledge': 2, 'ChitChat': 3}
-        self.id2label = {0: 'Greet', 1: 'Bye', 2: 'GetKnowledge', 3: 'ChitChat'}
-        self.intent_classifier = intent_classifier
-
-    def process_utterance(self, utterance, history):
-        if len(history) > 0:
-            # crop history
-            while len(history.split(" ")) > self.max_length:
-                index = history.find("|||")
-                history = history[index + 4:]
-
-            self.input = history + " ||| " + utterance
-            inputs = self.tokenizer(self.input, max_length=self.max_length, truncation=True, padding="max_length",
-                                    return_tensors="pt")
-
-            candidates = self.model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"],
-                                             return_dict_in_generate=True, output_scores=True,
-                                             num_return_sequences=self.num_return_sequences,
-                                             num_beams=self.num_return_sequences)
-            for i in range(candidates["sequences"].shape[0]):
-                generated_sentence = self.tokenizer.decode(candidates["sequences"][i], skip_special_tokens=True,
-                                                           clean_up_tokenization_spaces=True)
-                log_scores = candidates['sequences_scores']
-                norm_prob = (torch.exp(log_scores[i]) / torch.exp(log_scores).sum()).item()
-                if norm_prob >= self.score_threshold:
-                    self.score_threshold = norm_prob
-                    self.output = generated_sentence
-        else:
-            self.output = utterance
-
-        intent = self.label2id[self.intent_classifier(self.output)[0]['label']]
-        intent_conf = self.intent_classifier(self.output)[0]['score']
-
-        return {"modified_prompt": self.output, "mod_confidence": self.score_threshold, "prompt_intent": intent,
-                "intent_confidence": intent_conf}
-"""
-
-import threading

 import spacy
 import spacy_transformers
 import torch
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline


 class NLU:
-    def __init__(self, device, device_flag, reference_resolver, tokenizer,
-                 intent_classifier, offense_filter, entity_extractor,
-                 max_history_length=1024):
-        #entity_extractor=def_entity_extractor
-        self.reference_resolver = reference_resolver
-        self.device = device
-        self.reference_resolver.to(device)
-        self.tokenizer = tokenizer
-        self.max_length = max_history_length
-        self.label2idintent = {'QA': 0, 'CHITCHAT': 1, 'FINDPAPER': 2, 'FINDDATASET': 3, 'SUMMARIZEPAPER': 4}
-        self.id2labelintent = {0: 'QA', 1: 'CHITCHAT', 2: 'FINDPAPER', 3: 'FINDDATASET', 4: 'SUMMARIZEPAPER'}
-        self.label2idoffense = {'hate': 0, 'offensive': 1, 'neither': 2}
-        self.id2labeloffense = {0: 'hate', 1: 'offensive', 2: 'neither'}
-        self.intent_classifier = pipeline("sentiment-analysis", model=intent_classifier, device=device_flag)
-        self.entity_extractor = entity_extractor
-        self.offense_filter = pipeline("sentiment-analysis", model=offense_filter, device=device_flag)
+    def __init__(self, query_rewriter, coref_resolver, intent_classifier, offensive_classifier, entity_extractor, ambig_classifier):
+        self.intents = None
+        self.entities = None
+        self.offensive = None
+        self.clear = True
+        self.intent_classifier = intent_classifier
+        self.entity_extractor = entity_extractor
+        self.offensive_classifier = offensive_classifier
+        self.coref_resolver = coref_resolver
+        self.query_rewriter = query_rewriter
+        self.ambig_classifier = ambig_classifier

+    def _resolve_coref(self, history):
+        to_resolve = history + ' <COREF_SEP_TOKEN> ' + self.to_process
+        doc = self.coref_resolver(to_resolve)
+        token_mention_mapper = {}
+        output_string = ""
+        clusters = [
+            val for key, val in doc.spans.items() if key.startswith("coref_cluster")
+        ]
+
+        # Iterate through every found cluster
+        for cluster in clusters:
+            first_mention = cluster[0]
+            # Iterate through every other span in the cluster
+            for mention_span in list(cluster)[1:]:
+                # Set first_mention as value for the first token in mention_span in the token_mention_mapper
+                token_mention_mapper[mention_span[0].idx] = first_mention.text + mention_span[0].whitespace_
+                for token in mention_span[1:]:
+                    # Set empty string for all the other tokens in mention_span
+                    token_mention_mapper[token.idx] = ""
+
+        # Iterate through every token in the Doc
+        for token in doc:
+            # Check if token exists in token_mention_mapper
+            if token.idx in token_mention_mapper:
+                output_string += token_mention_mapper[token.idx]
+            # Else add original token text
+            else:
+                output_string += token.text + token.whitespace_
+        cleaned_query = output_string.split(" <COREF_SEP_TOKEN> ", 1)[1]
+        return cleaned_query

     def _intentpredictor(self):
-        self.intents = self.label2idintent[self.intent_classifier(self.to_process)[0]['label']]
+        pred = self.intent_classifier(self.to_process)[0]
+        return pred['label'], pred['score']

+    def _ambigpredictor(self):
+        pred = self.ambig_classifier(self.to_process)[0]
+        if pred['label'] in ['clear', 'somewhat_clear']:
+            return False
+        else:
+            return True

     def _entityextractor(self):
-        self.entities = []
+        entities = []
         doc = self.entity_extractor(self.to_process)
         for entity in doc.ents:
             if entity.text not in ['.', ',', '?', ';']:
-                self.entities.append({'entity': entity.label_, 'value': entity.text})
+                entities.append({'entity': entity.label_, 'value': entity.text})
+        return entities

-    def _inappropriatedetector(self):
-        self.offensive = False
-        is_offensive = self.label2idoffense[self.offense_filter(self.to_process)[0]['label']]
-        if is_offensive == 0 or is_offensive == 1:
-            self.offensive = True
+    def _offensepredictor(self):
+        pred = self.offensive_classifier(self.to_process)[0]['label']
+        if pred != "neither":
+            return True
+        else:
+            return False

-    def process_utterance(self, utterance, history):
+    def _rewrite_query(self, history):
+        text = history + " ||| " + self.to_process
+        return self.query_rewriter(text)[0]['generated_text']

+    def process_utterance(self, utterance, history_consec, history_sep):
+        """
+        Given an utterance and the history of the conversation, contextually refine the query and return the
+        refined utterance.
+        Query -> coref resolution & intent extraction -> if the intent is not confident or the query is ambiguous
+        -> rewrite the query and recheck -> if still ambiguous, ask a clarifying question.
+        """
         self.to_process = utterance
-        if len(history) > 0:
-            # crop history
-            while len(history.split(" ")) > self.max_length:
-                index = history.find("|||")
-                history = history[index + 4:]
-
-            context = history + " ||| " + utterance
-            inputs = self.tokenizer(context, max_length=self.max_length, truncation=True, padding="max_length",
-                                    return_tensors="pt")
-
-            candidates = self.reference_resolver.generate(input_ids=inputs["input_ids"].to(self.device),
-                                                          attention_mask=inputs["attention_mask"].to(self.device),
-                                                          return_dict_in_generate=True, output_scores=True,
-                                                          num_return_sequences=1,
-                                                          num_beams=5)
-            self.to_process = self.tokenizer.decode(candidates["sequences"][0], skip_special_tokens=True,
-                                                    clean_up_tokenization_spaces=True)
-
-        t1 = threading.Thread(target=self._intentpredictor, name='intent')
-        t2 = threading.Thread(target=self._entityextractor, name='entity')
-        t3 = threading.Thread(target=self._inappropriatedetector, name='offensive')
-
-        t3.start()
-        t1.start()
-        t2.start()
-
-        t3.join()
-        t1.join()
-        t2.join()
-        return {"modified_prompt": self.to_process,
-                "intent": self.intents,
-                "entities": self.entities,
-                "is_offensive": self.offensive,
-                "is_clear": self.clear}
+        self.to_process = self._resolve_coref(history_consec)
+
+        intent, score = self._intentpredictor()
+        #print(score)
+        if score > 0.5:
+            entities = self._entityextractor()
+            offense = self._offensepredictor()
+            if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
+                return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
+            return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": True}
+        else:
+            if self._ambigpredictor():
+                self.to_process = self._rewrite_query(history_sep)
+                intent, score = self._intentpredictor()
+                entities = self._entityextractor()
+                offense = self._offensepredictor()
+                if score > 0.5 or not self._ambigpredictor():
+                    if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
+                        return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
+                    return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": True}
+                else:
+                    return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
+            else:
+                entities = self._entityextractor()
+                offense = self._offensepredictor()
+                if intent in ['FINDPAPER', 'FINDDATASET', 'SUMMARIZEPAPER'] and len(entities) == 0:
+                    return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": False}
+                return {"modified_query": self.to_process, "intent": intent, "entities": entities, "is_offensive": offense, "is_clear": True}
ResponseGenerator.py
@@ -4,13 +4,13 @@ import faiss
 from sklearn.metrics.pairwise import cosine_similarity
 import numpy as np
 import pandas as pd

 from datetime import datetime

 class ResponseGenerator:
-    def __init__(self, index, db,
-                 generator, retriever, num_retrieved=1):
-        self.generator = generator
+    def __init__(self, index, db, recommender, generators, retriever, num_retrieved=1):
+        self.generators = generators
         self.retriever = retriever
+        self.recommender = recommender
         self.db = db
         self.index = index
         self.num_retrieved = num_retrieved

@@ -37,6 +37,13 @@ class ResponseGenerator:
         else:
             return {}

+    def _get_matching_authors(self, rsrc, author):
+        cand = self.db[rsrc].loc[self.db[rsrc]['author'] == author.lower()].reset_index(drop=True)
+        if not cand.empty:
+            return cand.loc[0]
+        else:
+            return {}
+
     def _get_matching_topics(self, rsrc, topic):
         matches = []
         score = 0.7

@@ -60,13 +67,22 @@ class ResponseGenerator:
         return self.db[db_type].iloc[[I[0]][0]].reset_index(drop=True).loc[0]


-    def gen_response(self, utterance, state, history, action):
-        if action == "NoCanDo":
-            return str("I am sorry, I cannot answer to this kind of language")
+    def gen_response(self, action, utterance=None, username=None, state=None, consec_history=None):
+        if action == "Help":
+            return "Hey it's Janet! I am here to help you make use of the datasets and papers in the VRE. I can answer questions whose answers may be inside the papers. I can summarize papers for you. I can also chat with you. So, whichever it is, I am ready to chat!"
+        elif action == "Recommend":
+            prompt = self.recommender.make_recommendation(username)
+            if prompt != "":
+                return prompt
+            else:
+                return "I can help you with exploiting the contents of the VRE, just let me know!"
+
+        elif action == "OffenseReject":
+            return "I am sorry, I cannot answer to this kind of language"

         elif action == "ConvGen":
-            gen_kwargs = {"length_penalty": 2.5, "num_beams": 4, "max_length": 20}
-            answer = self.generator('question: ' + utterance + ' context: ' + history, **gen_kwargs)[0]['generated_text']
+            gen_kwargs = {"length_penalty": 2.5, "num_beams": 2, "max_length": 30}
+            answer = self.generators['chat']('history: ' + consec_history + ' ' + utterance + ' persona: ' + 'I am Janet. My name is Janet. I am an AI developed by CNR to help VRE users.', **gen_kwargs)[0]['generated_text']
             return answer

         elif action == "findPaper":

@@ -75,19 +91,26 @@ class ResponseGenerator:
                     self.paper = self._get_matching_titles('paper_db', entity['value'])
                     links = self._get_resources_links(self.paper)
                     if len(self.paper) > 0 and len(links) > 0:
-                        return str("This paper could be helpful: " + self.paper['title'] + '. ' + "It can be downloaded at " + links[0])
+                        return str("Here is the paper you want: " + self.paper['title'] + '. ' + "It can be downloaded at " + links[0])
                     else:
                         self.paper = self._search_index('paper_titles_index', 'paper_db', entity['value'])
                         links = self._get_resources_links(self.paper)
-                        return str("This paper could be helpful: " + self.paper['title'] + '. ' + "It can be downloaded at " + links[0])
+                        return str("This paper could be relevant: " + self.paper['title'] + '. ' + "It can be downloaded at " + links[0])
                 if(entity['entity'] == 'TOPIC'):
                     self.paper = self._get_matching_topics('paper_db', entity['value'])
                     links = self._get_resources_links(self.paper)
                     if len(self.paper) > 0 and len(links) > 0:
-                        return str("This paper could be helpful: " + self.paper['title'] + '. ' + "It can be downloaded at " + links[0])
+                        return str("This paper could be relevant: " + self.paper['title'] + '. ' + "It can be downloaded at " + links[0])
+
+                if(entity['entity'] == 'AUTHOR'):
+                    self.paper = self._get_matching_authors('paper_db', entity['value'])
+                    links = self._get_resources_links(self.paper)
+                    if len(self.paper) > 0 and len(links) > 0:
+                        return str("Here is the paper you want: " + self.paper['title'] + '. ' + "It can be downloaded at " + links[0])

             self.paper = self._search_index('paper_desc_index', 'paper_db', utterance)
             links = self._get_resources_links(self.paper)
-            return str("This paper could be helpful: " + self.paper['title'] + '. ' + "It can be downloaded at " + links[0])
+            return str("This paper could be relevant: " + self.paper['title'] + '. ' + "It can be downloaded at " + links[0])

         elif action == "findDataset":
             for entity in state['entities']:

@@ -95,19 +118,26 @@ class ResponseGenerator:
                     self.dataset = self._get_matching_titles('dataset_db', entity['value'])
                     links = self._get_resources_links(self.dataset)
                     if len(self.dataset) > 0 and len(links) > 0:
-                        return str("This dataset could be helpful: " + self.dataset['title'] + '. ' + "It can be downloaded at " + links[0])
+                        return str("Here is the dataset you wanted: " + self.dataset['title'] + '. ' + "It can be downloaded at " + links[0])
                     else:
                         self.dataset = self._search_index('dataset_titles_index', 'dataset_db', entity['value'])
                         links = self._get_resources_links(self.dataset)
-                        return str("This dataset could be helpful: " + self.dataset['title'] + '. ' + "It can be downloaded at " + links[0])
+                        return str("This dataset could be relevant: " + self.dataset['title'] + '. ' + "It can be downloaded at " + links[0])
                 if(entity['entity'] == 'TOPIC'):
                     self.dataset = self._get_matching_topics('dataset_db', entity['value'])
                     links = self._get_resources_links(self.dataset)
                     if len(self.dataset) > 0 and len(links) > 0:
-                        return str("This dataset could be helpful: " + self.dataset['title'] + '. ' + "It can be downloaded at " + links[0])
+                        return str("This dataset could be relevant: " + self.dataset['title'] + '. ' + "It can be downloaded at " + links[0])
+
+                if(entity['entity'] == 'AUTHOR'):
+                    self.dataset = self._get_matching_authors('dataset_db', entity['value'])
+                    links = self._get_resources_links(self.dataset)
+                    if len(self.dataset) > 0 and len(links) > 0:
+                        return str("Here is the dataset you want: " + self.dataset['title'] + '. ' + "It can be downloaded at " + links[0])

             self.dataset = self._search_index('dataset_desc_index', 'dataset_db', utterance)
             links = self._get_resources_links(self.dataset)
-            return str("This dataset could be helpful: " + self.dataset['title'] + '. ' + "It can be downloaded at " + links[0])
+            return str("This dataset could be relevant: " + self.dataset['title'] + '. ' + "It can be downloaded at " + links[0])


         elif action == "RetGen":

@@ -117,8 +147,8 @@ class ResponseGenerator:
             gen_seq = 'question: ' + utterance + " context: " + content

             #handle return random 2 answers
-            gen_kwargs = {"length_penalty": 0.5, "num_beams": 8, "max_length": 100}
-            answer = self.generator(gen_seq, **gen_kwargs)[0]['generated_text']
+            gen_kwargs = {"length_penalty": 0.5, "num_beams": 2, "max_length": 60}
+            answer = self.generators['qa'](gen_seq, **gen_kwargs)[0]['generated_text']
             return str(answer)

         elif action == "sumPaper":

@@ -135,9 +165,16 @@ class ResponseGenerator:
             answer = ""
             for i, row in df.iterrows():
                 gen_seq = 'summarize: ' + row['content']
-                gen_kwargs = {"length_penalty": 1.5, "num_beams": 8, "max_length": 100}
-                answer = self.generator(gen_seq, **gen_kwargs)[0]['generated_text'] + ' '
+                gen_kwargs = {"length_penalty": 1.5, "num_beams": 6, "max_length": 120}
+                answer += self.generators['summ'](gen_seq, **gen_kwargs)[0]['generated_text'] + ' '
             return answer

-        elif action == "Clarify":
-            return str("Can you please clarify?")
+        elif action == "ClarifyResource":
+            if state['intent'] in ['FINDPAPER', 'SUMMARIZEPAPER']:
+                return 'Please specify the title, the topic or the paper of interest.'
+            else:
+                return 'Please specify the title, the topic or the dataset of interest.'
+        elif action == "GenClarify":
+            gen_kwargs = {"length_penalty": 2.5, "num_beams": 8, "max_length": 120}
+            question = self.generators['amb']('question: ' + utterance + ' context: ' + consec_history, **gen_kwargs)[0]['generated_text']
+            return question
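
Note: the strings returned by DM.next_action map one-to-one onto these branches; a sketch of the dispatch as main.py performs it (dm, rg, user and state are assumed to be wired as in main.py below):

# Sketch: route a DM action to the matching gen_response call.
action = dm.next_action()
if action in ("Help", "OffenseReject"):
    reply = rg.gen_response(action)
elif action == "Recommend":
    reply = rg.gen_response(action, username=user.username)
else:
    reply = rg.gen_response(action,
                            utterance=state['modified_query'],
                            state=dm.get_recent_state(),
                            consec_history=dm.get_consec_history())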
User.py (3 lines changed)
@@ -20,8 +20,7 @@ class User:
         if len(index) > 0:
             self.interests.at[index[0], 'frequency'] += 1
         else:
-            self.interests = self.interests.append({'interest': topic, 'frequency': max(
-                self.interests['frequency']) if not self.interests.empty else 6}, ignore_index=True)
+            self.interests = self.interests.append({'interest': topic, 'frequency': max(self.interests['frequency']) if not self.interests.empty else 6}, ignore_index=True)

         self.interests = self.interests.sort_values(by='frequency', ascending=False, ignore_index=True)
         self.interests.to_json(self.interests_file)
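
Note: DataFrame.append, used on the changed line, is deprecated and removed in pandas 2.0; a self-contained equivalent with pd.concat (same two columns assumed) would be:

# Sketch: append-free version of the insert above (pandas >= 2.0).
import pandas as pd

interests = pd.DataFrame(columns=['interest', 'frequency'])
topic = 'food security'  # example topic
row = {'interest': topic,
       'frequency': max(interests['frequency']) if not interests.empty else 6}
interests = pd.concat([interests, pd.DataFrame([row])], ignore_index=True)
print(interests)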
main.py (112 lines changed)
@@ -1,6 +1,5 @@
 import os
-import warnings

 import faiss
 import torch
 from flask import Flask, render_template, request, jsonify

@@ -10,45 +9,37 @@ import spacy
 import spacy_transformers
 import torch
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

 from User import User
 from VRE import VRE
 from NLU import NLU
 from DM import DM
 from Recommender import Recommender
 from ResponseGenerator import ResponseGenerator


 import pandas as pd
 import time
 import threading

 from sentence_transformers import SentenceTransformer



 app = Flask(__name__)
 #allow frontend address
 url = os.getenv("FRONTEND_URL_WITH_PORT")
 cors = CORS(app, resources={r"/predict": {"origins": url}, r"/feedback": {"origins": url}})
 #cors = CORS(app, resources={r"/predict": {"origins": "*"}, r"/feedback": {"origins": "*"}})

+"""
 conn = psycopg2.connect(
-    host="janet-pg",
+    host="https://janet-app-db.d4science.org",
     database=os.getenv("POSTGRES_DB"),
     user=os.getenv("POSTGRES_USER"),
     password=os.getenv("POSTGRES_PASSWORD"))
+"""
+conn = psycopg2.connect(host="https://janet-app-db.d4science.org",
+                        database="janet",
+                        user="janet_user",
+                        password="2fb5e81fec5a2d906a04")

 cur = conn.cursor()



 #rg = ResponseGenerator(index)

-def get_response(text):
-    # get response from janet itself
-    return text, 'candAnswer'

 def vre_fetch():
     while True:
         time.sleep(1000)

@@ -64,20 +55,26 @@ def user_interest_decay():
             time.sleep(180)
             user.decay_interests()

-def recommend():
-    while True:
-        if time.time() - dm.get_recent_state()['time'] > 1000:
-            print("Making Recommendation: ")
-            prompt = rec.make_recommendation(user.username)
-            if prompt != "":
-                print(prompt)
-        time.sleep(1000)


 @app.route("/predict", methods=['POST'])
 def predict():
     text = request.get_json().get("message")
-    state = nlu.process_utterance(text, dm.get_utt_history())
+    message = {}
+    if text == "<HELP_ON_START>":
+        state = {'help': True, 'inactive': False}
+        dm.update(state)
+        action = dm.next_action()
+        response = rg.gen_response(action)
+        message = {"answer": response}
+    elif text == "<RECOMMEND_ON_IDLE>":
+        state = {'help': False, 'inactive': True}
+        dm.update(state)
+        action = dm.next_action()
+        response = rg.gen_response(action, username=user.username)
+        message = {"answer": response}
+    else:
+        state = nlu.process_utterance(text, dm.get_consec_history(), dm.get_sep_history())
+        state['help'] = False
+        state['inactive'] = False
+        user_interests = []
+        for entity in state['entities']:
+            if entity['entity'] == 'TOPIC':

@@ -85,56 +82,52 @@ def predict():
         user.update_interests(user_interests)
         dm.update(state)
         action = dm.next_action()
-        response = rg.gen_response(state['modified_prompt'], dm.get_recent_state(), dm.get_utt_history(), action)
-        message = {"answer": response, "query": text, "cand": "candidate", "history": dm.get_utt_history(), "modQuery": state['modified_prompt']}
+        response = rg.gen_response(action, utterance=state['modified_query'], state=dm.get_recent_state(), consec_history=dm.get_consec_history())
+        message = {"answer": response, "query": text, "cand": "candidate", "history": dm.get_consec_history(), "modQuery": state['modified_query']}
+        new_state = {'modified_query': response}
+        dm.update(new_state)
     reply = jsonify(message)
     #reply.headers.add('Access-Control-Allow-Origin', '*')
     return reply

 @app.route('/feedback', methods=['POST'])
 def feedback():
     data = request.get_json()['feedback']
     # Make data frame of above data
     print(data)
     #df = pd.DataFrame([data])
     #file_exists = os.path.isfile('feedback.csv')

     #df = pd.DataFrame(data=[data['response'], data['length'], data['fluency'], data['truthfulness'], data['usefulness'], data['speed']]
     #                  ,columns=['response', 'length', 'fluency', 'truthfulness', 'usefulness', 'speed'])
     #df.to_csv('feedback.csv', mode='a', index=False, header=(not file_exists))
-    cur.execute('INSERT INTO feedback (query, history, janet_modified_query,'
-                ' is_modified_query_correct, user_modified_query,'
-                ' response, preferred_response, response_length_feedback,'
-                ' response_fluency_feedback, response_truth_feedback,'
-                ' response_useful_feedback, response_time_feedback, response_intent)'
-                ' VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)',
+    cur.execute('INSERT INTO feedback (query, history, janet_modified_query, is_modified_query_correct, user_modified_query, response, preferred_response, response_length_feedback, response_fluency_feedback, response_truth_feedback, response_useful_feedback, response_time_feedback, response_intent) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)',
                 (data['query'], data['history'], data['modQuery'],
                  data['queryModCorrect'], data['correctQuery'],
                  data['janetResponse'], data['preferredResponse'], data['length'],
                  data['fluency'], data['truthfulness'], data['usefulness'],
                  data['speed'], data['intent'])
                 )

     reply = jsonify({"status": "done"})
     #reply.headers.add('Access-Control-Allow-Origin', '*')
     return reply

 if __name__ == "__main__":
-    warnings.filterwarnings("ignore")
     #load NLU
-    def_tokenizer = AutoTokenizer.from_pretrained("castorini/t5-base-canard")
-    def_reference_resolver = AutoModelForSeq2SeqLM.from_pretrained("castorini/t5-base-canard")
-    def_intent_classifier_dir = "./IntentClassifier/"
-    def_entity_extractor = spacy.load("./EntityExtraction/BestModel")
-    def_offense_filter_dir = "./OffensiveClassifier"
     device = "cuda" if torch.cuda.is_available() else "cpu"
     device_flag = torch.cuda.current_device() if torch.cuda.is_available() else -1
-    nlu = NLU(device, device_flag, def_reference_resolver, def_tokenizer, def_intent_classifier_dir, def_offense_filter_dir, def_entity_extractor)
+    query_rewriter = pipeline("text2text-generation", model="castorini/t5-base-canard")
+    intent_classifier = pipeline("sentiment-analysis", model='./intent_classifier', device=device_flag)
+    entity_extractor = spacy.load("./entity_extractor")
+    offensive_classifier = pipeline("sentiment-analysis", model='./offensive_classifier', device=device_flag)
+    ambig_classifier = pipeline("sentiment-analysis", model='./ambig_classifier', device=device_flag)
+    coref_resolver = spacy.load("en_coreference_web_trf")
+
+    nlu = NLU(query_rewriter, coref_resolver, intent_classifier, offensive_classifier, entity_extractor, ambig_classifier)

     #load retriever and generator
-    def_retriever = SentenceTransformer('./BigRetriever/').to(device)
-    def_generator = pipeline("text2text-generation", model="./generator", device=device_flag)
+    retriever = SentenceTransformer('./BigRetriever/').to(device)
+    qa_generator = pipeline("text2text-generation", model="./train_qa", device=device_flag)
+    summ_generator = pipeline("text2text-generation", model="./train_summ", device=device_flag)
+    chat_generator = pipeline("text2text-generation", model="./train_chat", device=device_flag)
+    amb_generator = pipeline("text2text-generation", model="./train_amb_gen", device=device_flag)
+    generators = {'qa': qa_generator,
+                  'chat': chat_generator,
+                  'amb': amb_generator,
+                  'summ': summ_generator}

     #load vre
     token = '2c1e8f88-461c-42c0-8cc1-b7660771c9a3-843339462'

@@ -148,15 +141,14 @@ if __name__ == "__main__":
     threading.Thread(target=user_interest_decay, name='decayinterest').start()

-    rec = Recommender(def_retriever)
+    rec = Recommender(retriever)

     dm = DM()
-    rg = ResponseGenerator(index,db,def_generator,def_retriever)
-    threading.Thread(target=recommend, name='recommend').start()
+    rg = ResponseGenerator(index, db, rec, generators, retriever)


-    cur.execute('CREATE TABLE IF NOT EXISTS feedback (id serial PRIMARY KEY,'
+    cur.execute('CREATE TABLE IF NOT EXISTS feedback_trial (id serial PRIMARY KEY,'
                 'query text NOT NULL,'
                 'history text NOT NULL,'
                 'janet_modified_query text NOT NULL,'
requirements.txt
@@ -11,18 +11,26 @@ regex==2022.6.2
 requests==2.25.1
 scikit-learn==1.0.2
 scipy==1.7.3
-sentence-transformers==2.2.2
 sentencepiece==0.1.97
 sklearn-pandas==1.8.0
-spacy==3.5.0
-spacy-transformers==1.2.2
+spacy==3.4.4
+spacy-alignments==0.9.0
+spacy-legacy==3.0.12
+spacy-loggers==1.0.4
+spacy-transformers==1.1.9
+spacy-experimental==0.6.2
 torch @ https://download.pytorch.org/whl/cu116/torch-1.13.1%2Bcu116-cp38-cp38-linux_x86_64.whl
 torchaudio @ https://download.pytorch.org/whl/cu116/torchaudio-0.13.1%2Bcu116-cp38-cp38-linux_x86_64.whl
 torchsummary==1.5.1
 torchtext==0.14.1
+sentence-transformers
 torchvision @ https://download.pytorch.org/whl/cu116/torchvision-0.14.1%2Bcu116-cp38-cp38-linux_x86_64.whl
 tqdm==4.64.1
-transformers==4.26.1
+transformers
 markupsafe==2.0.1
 psycopg2==2.9.5
+en-coreference-web-trf @ https://github.com/explosion/spacy-experimental/releases/download/v0.6.1/en_coreference_web_trf-3.4.0a2-py3-none-any.whl
 Werkzeug==1.0.1
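
Note: the pinned spacy-experimental release and the en-coreference-web-trf wheel provide the experimental coreference pipeline that NLU._resolve_coref relies on; a minimal sanity check (assuming the requirements above are installed):

# Sketch: verify the coreference model exposes clusters the way
# NLU._resolve_coref expects (doc.spans keys prefixed "coref_cluster...").
import spacy

nlp = spacy.load("en_coreference_web_trf")
doc = nlp("Janet found a paper. She summarized it yesterday.")
for key, spans in doc.spans.items():
    if key.startswith("coref_cluster"):
        print(key, [span.text for span in spans])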