backendrepo
commit 2d4989a81c
@@ -0,0 +1,43 @@
import time


class DM:
    def __init__(self):
        self.utt_history = ""
        self.history = []
        self.state = None

    def get_utt_history(self):
        return self.utt_history

    def get_recent_state(self):
        return self.state

    def get_dialogue_state(self):
        return self.history

    def update(self, new_state):
        self.history.append(new_state)
        self.utt_history = self.utt_history + " ||| " + new_state['modified_prompt']
        self.state = {'intent': new_state['intent'],
                      'entities': new_state['entities'],
                      'offensive': new_state['is_offensive'],
                      'clear': new_state['is_clear'],
                      'time': time.time()}

    def next_action(self):
        if self.state['clear']:
            if self.state['offensive']:
                return "NoCanDo"
            else:
                if self.state['intent'] == 0:
                    return "RetGen"
                elif self.state['intent'] == 1:
                    return "ConvGen"
                elif self.state['intent'] == 2:
                    return "findPaper"
                elif self.state['intent'] == 3:
                    return "findDataset"
                elif self.state['intent'] == 4:
                    return "sumPaper"
        else:
            return "Clarify"
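A minimal usage sketch of the dialogue manager (the values below are made up for illustration; the dict keys mirror what NLU.process_utterance returns further down):

dm = DM()
dm.update({'modified_prompt': 'find me a dataset about ocean temperature',
           'intent': 3,
           'entities': [{'entity': 'TOPIC', 'value': 'ocean temperature'}],
           'is_offensive': False,
           'is_clear': True})
print(dm.next_action())  # -> "findDataset", since intent 3 maps to that action above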
@@ -0,0 +1,13 @@
FROM python:3.8

WORKDIR .

COPY requirements.txt .

RUN pip install -r requirements.txt

COPY . .

EXPOSE 4000

ENTRYPOINT ["python", "main.py"]
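Assuming the image is built from this directory, the usual workflow would be docker build -t janet-backend . followed by docker run -p 4000:4000 janet-backend (the image name is illustrative). Note that main.py below binds Flask to 127.0.0.1, so the published port is only reachable from inside the container unless the host is changed to 0.0.0.0.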
@@ -0,0 +1,143 @@
"""
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import torch


class NLU:
    def_tokenizer = AutoTokenizer.from_pretrained("castorini/t5-base-canard")
    def_model = AutoModelForSeq2SeqLM.from_pretrained("castorini/t5-base-canard")
    def_intent_classifier = pipeline("sentiment-analysis", model="/home/ahmed/PycharmProjects/Janet/JanetBackend/intent_classifier")

    def __init__(self, model=def_model, tokenizer=def_tokenizer, intent_classifier=def_intent_classifier,
                 max_history_length=1024, num_gen_seq=2, score_threshold=0.5):
        self.input = ""
        self.output = ""
        self.model = model
        self.tokenizer = tokenizer
        self.max_length = max_history_length
        self.num_return_sequences = num_gen_seq
        self.score_threshold = score_threshold
        self.label2id = {'Greet': 0, 'Bye': 1, 'GetKnowledge': 2, 'ChitChat': 3}
        self.id2label = {0: 'Greet', 1: 'Bye', 2: 'GetKnowledge', 3: 'ChitChat'}
        self.intent_classifier = intent_classifier

    def process_utterance(self, utterance, history):
        if len(history) > 0:
            # crop history
            while len(history.split(" ")) > self.max_length:
                index = history.find("|||")
                history = history[index + 4:]

            self.input = history + " ||| " + utterance
            inputs = self.tokenizer(self.input, max_length=self.max_length, truncation=True, padding="max_length",
                                    return_tensors="pt")

            candidates = self.model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"],
                                             return_dict_in_generate=True, output_scores=True,
                                             num_return_sequences=self.num_return_sequences,
                                             num_beams=self.num_return_sequences)
            for i in range(candidates["sequences"].shape[0]):
                generated_sentence = self.tokenizer.decode(candidates["sequences"][i], skip_special_tokens=True,
                                                           clean_up_tokenization_spaces=True)
                log_scores = candidates['sequences_scores']
                norm_prob = (torch.exp(log_scores[i]) / torch.exp(log_scores).sum()).item()
                if norm_prob >= self.score_threshold:
                    self.score_threshold = norm_prob
                    self.output = generated_sentence
        else:
            self.output = utterance

        intent = self.label2id[self.intent_classifier(self.output)[0]['label']]
        intent_conf = self.intent_classifier(self.output)[0]['score']

        return {"modified_prompt": self.output, "mod_confidence": self.score_threshold, "prompt_intent": intent,
                "intent_confidence": intent_conf}
"""

import threading

import spacy
import spacy_transformers
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline


class NLU:
    def __init__(self, device, device_flag, reference_resolver, tokenizer,
                 intent_classifier, offense_filter, entity_extractor,
                 max_history_length=1024):
        # entity_extractor=def_entity_extractor
        self.reference_resolver = reference_resolver
        self.device = device
        self.reference_resolver.to(device)
        self.tokenizer = tokenizer
        self.max_length = max_history_length
        self.label2idintent = {'QA': 0, 'CHITCHAT': 1, 'FINDPAPER': 2, 'FINDDATASET': 3, 'SUMMARIZEPAPER': 4}
        self.id2labelintent = {0: 'QA', 1: 'CHITCHAT', 2: 'FINDPAPER', 3: 'FINDDATASET', 4: 'SUMMARIZEPAPER'}
        self.label2idoffense = {'hate': 0, 'offensive': 1, 'neither': 2}
        self.id2labeloffense = {0: 'hate', 1: 'offensive', 2: 'neither'}
        self.intent_classifier = pipeline("sentiment-analysis", model=intent_classifier, device=device_flag)
        self.entity_extractor = entity_extractor
        self.offense_filter = pipeline("sentiment-analysis", model=offense_filter, device=device_flag)

        self.intents = None
        self.entities = None
        self.offensive = None
        self.clear = True

    def _intentpredictor(self):
        self.intents = self.label2idintent[self.intent_classifier(self.to_process)[0]['label']]

    def _entityextractor(self):
        self.entities = []
        doc = self.entity_extractor(self.to_process)
        for entity in doc.ents:
            if entity.text not in ['.', ',', '?', ';']:
                self.entities.append({'entity': entity.label_, 'value': entity.text})

    def _inappropriatedetector(self):
        self.offensive = False
        is_offensive = self.label2idoffense[self.offense_filter(self.to_process)[0]['label']]
        if is_offensive == 0 or is_offensive == 1:
            self.offensive = True

    def process_utterance(self, utterance, history):
        """
        Given an utterance and the history of the conversation, refine the query contextually and return a refined
        utterance
        """
        self.to_process = utterance
        if len(history) > 0:
            # crop history
            while len(history.split(" ")) > self.max_length:
                index = history.find("|||")
                history = history[index + 4:]

            context = history + " ||| " + utterance
            inputs = self.tokenizer(context, max_length=self.max_length, truncation=True, padding="max_length",
                                    return_tensors="pt")

            candidates = self.reference_resolver.generate(input_ids=inputs["input_ids"].to(self.device),
                                                          attention_mask=inputs["attention_mask"].to(self.device),
                                                          return_dict_in_generate=True, output_scores=True,
                                                          num_return_sequences=1,
                                                          num_beams=5)
            self.to_process = self.tokenizer.decode(candidates["sequences"][0], skip_special_tokens=True,
                                                    clean_up_tokenization_spaces=True)

        t1 = threading.Thread(target=self._intentpredictor, name='intent')
        t2 = threading.Thread(target=self._entityextractor, name='entity')
        t3 = threading.Thread(target=self._inappropriatedetector, name='offensive')

        t3.start()
        t1.start()
        t2.start()

        t3.join()
        t1.join()
        t2.join()
        return {"modified_prompt": self.to_process,
                "intent": self.intents,
                "entities": self.entities,
                "is_offensive": self.offensive,
                "is_clear": self.clear}
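For reference, a hypothetical return value of NLU.process_utterance (entity labels depend on the spaCy model that is loaded; intent 3 is FINDDATASET in label2idintent):

{"modified_prompt": "find me a dataset about ocean temperature",
 "intent": 3,
 "entities": [{"entity": "TOPIC", "value": "ocean temperature"}],
 "is_offensive": False,
 "is_clear": True}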
@@ -0,0 +1,42 @@
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import random


class Recommender:
    def __init__(self, retriever):
        self.curr_recommendations = []
        self.recommended = []
        self.retriever = retriever
        self.rand_seed = 5

    def _match_tags(self, material, interest):
        score = 0.7
        for tag in material['tags']:
            if cosine_similarity(np.array(self.retriever.encode([tag])),
                                 np.array(self.retriever.encode([interest]))) > score:
                if material not in self.curr_recommendations:
                    self.curr_recommendations.append(material)
                    self.recommended.append(False)

    def generate_recommendations(self, interests, new_material):
        for interest in interests:
            for material in new_material:
                self._match_tags(material, interest)

    def make_recommendation(self, user):
        if len(self.curr_recommendations) == 0:
            return ""
        # avoid an infinite loop once every current recommendation has already been shown
        if all(self.recommended):
            return ""
        index = random.choice(list(range(0, len(self.curr_recommendations))))
        while self.recommended[index]:
            index = random.choice(list(range(0, len(self.curr_recommendations))))
        item = self.curr_recommendations[index]
        recommendation = ("Hey " + user + "! This " + item['type'].lower()
                          + " about " + ', '.join(item['tags']).lower()
                          + " was posted recently by " + item['author'].lower()
                          + " on the catalogue. You may wanna check it out! It is titled "
                          + item['title'].lower() + ". Cheers, Janet")
        # self.curr_recommendations.remove(self.curr_recommendations[index])
        self.recommended[index] = True
        return recommendation
@@ -0,0 +1,143 @@
from sentence_transformers import models, SentenceTransformer
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import faiss
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import pandas as pd


class ResponseGenerator:
    def __init__(self, index, db,
                 generator, retriever, num_retrieved=1):
        self.generator = generator
        self.retriever = retriever
        self.db = db
        self.index = index
        self.num_retrieved = num_retrieved
        self.paper = {}
        self.dataset = {}

    def update_index(self, index):
        self.index = index

    def update_db(self, db):
        self.db = db

    def _get_resources_links(self, item):
        if len(item) == 0:
            return []
        links = []
        for rsrc in item['resources']:
            links.append(rsrc['url'])
        return links

    def _get_matching_titles(self, rsrc, title):
        cand = self.db[rsrc].loc[self.db[rsrc]['title'] == title.lower()].reset_index(drop=True)
        if not cand.empty:
            return cand.loc[0]
        else:
            return {}

    def _get_matching_topics(self, rsrc, topic):
        matches = []
        score = 0.7
        for i, cand in self.db[rsrc].iterrows():
            for tag in cand['tags']:
                sim = cosine_similarity(np.array(self.retriever.encode([tag])), np.array(self.retriever.encode([topic.lower()])))
                if sim > score:
                    if len(matches) > 0:
                        matches[0] = cand
                    else:
                        matches.append(cand)
                    score = sim
        if len(matches) > 0:
            return matches[0]
        else:
            return []

    def _search_index(self, index_type, db_type, query):
        xq = self.retriever.encode([query])
        D, I = self.index[index_type].search(xq, self.num_retrieved)
        return self.db[db_type].iloc[[I[0]][0]].reset_index(drop=True).loc[0]

    def gen_response(self, utterance, state, history, action):
        if action == "NoCanDo":
            return str("I am sorry, I cannot answer to this kind of language")

        elif action == "ConvGen":
            gen_kwargs = {"length_penalty": 2.5, "num_beams": 4, "max_length": 20}
            answer = self.generator('question: ' + utterance + ' context: ' + history, **gen_kwargs)[0]['generated_text']
            return answer

        elif action == "findPaper":
            for entity in state['entities']:
                if entity['entity'] == 'TITLE':
                    self.paper = self._get_matching_titles('paper_db', entity['value'])
                    links = self._get_resources_links(self.paper)
                    if len(self.paper) > 0 and len(links) > 0:
                        return str("This paper could be helpful: " + self.paper['title'] + '. ' + "It can be downloaded at " + links[0])
                    else:
                        self.paper = self._search_index('paper_titles_index', 'paper_db', entity['value'])
                        links = self._get_resources_links(self.paper)
                        return str("This paper could be helpful: " + self.paper['title'] + '. ' + "It can be downloaded at " + links[0])
                if entity['entity'] == 'TOPIC':
                    self.paper = self._get_matching_topics('paper_db', entity['value'])
                    links = self._get_resources_links(self.paper)
                    if len(self.paper) > 0 and len(links) > 0:
                        return str("This paper could be helpful: " + self.paper['title'] + '. ' + "It can be downloaded at " + links[0])
            # fall back to a semantic search over the paper descriptions with the whole utterance
            self.paper = self._search_index('paper_desc_index', 'paper_db', utterance)
            links = self._get_resources_links(self.paper)
            return str("This paper could be helpful: " + self.paper['title'] + '. ' + "It can be downloaded at " + links[0])

        elif action == "findDataset":
            for entity in state['entities']:
                if entity['entity'] == 'TITLE':
                    self.dataset = self._get_matching_titles('dataset_db', entity['value'])
                    links = self._get_resources_links(self.dataset)
                    if len(self.dataset) > 0 and len(links) > 0:
                        return str("This dataset could be helpful: " + self.dataset['title'] + '. ' + "It can be downloaded at " + links[0])
                    else:
                        self.dataset = self._search_index('dataset_titles_index', 'dataset_db', entity['value'])
                        links = self._get_resources_links(self.dataset)
                        return str("This dataset could be helpful: " + self.dataset['title'] + '. ' + "It can be downloaded at " + links[0])
                if entity['entity'] == 'TOPIC':
                    self.dataset = self._get_matching_topics('dataset_db', entity['value'])
                    links = self._get_resources_links(self.dataset)
                    if len(self.dataset) > 0 and len(links) > 0:
                        return str("This dataset could be helpful: " + self.dataset['title'] + '. ' + "It can be downloaded at " + links[0])
            # fall back to a semantic search over the dataset descriptions with the whole utterance
            self.dataset = self._search_index('dataset_desc_index', 'dataset_db', utterance)
            links = self._get_resources_links(self.dataset)
            return str("This dataset could be helpful: " + self.dataset['title'] + '. ' + "It can be downloaded at " + links[0])

        elif action == "RetGen":
            # retrieve the most relevant paragraph
            content = str(self._search_index('content_index', 'content_db', utterance)['content'])
            # generate the answer
            gen_seq = 'question: ' + utterance + " context: " + content

            # handle return random 2 answers
            gen_kwargs = {"length_penalty": 0.5, "num_beams": 8, "max_length": 100}
            answer = self.generator(gen_seq, **gen_kwargs)[0]['generated_text']
            return str(answer)

        elif action == "sumPaper":
            if len(self.paper) == 0:
                for entity in state['entities']:
                    if entity['entity'] == 'TITLE':
                        self.paper = self._get_matching_titles('paper_db', entity['value'])
                        if len(self.paper) > 0:
                            break
                if len(self.paper) == 0:
                    return "I cannot seem to find the requested paper. Try again by specifying the title of the paper."
            # implement that
            df = self.db['content_db'][self.db['content_db']['paperid'] == self.paper['id']]
            answer = ""
            for i, row in df.iterrows():
                gen_seq = 'summarize: ' + row['content']
                gen_kwargs = {"length_penalty": 1.5, "num_beams": 8, "max_length": 100}
                answer = self.generator(gen_seq, **gen_kwargs)[0]['generated_text'] + ' '
            return answer

        elif action == "Clarify":
            return str("Can you please clarify?")
@@ -0,0 +1,39 @@
import pandas as pd
import os


class User:
    def __init__(self, username, token, num_interests=3, directory='./', interests_file='interests.json'):
        self.username = username
        self.token = token
        self.num_interests = num_interests
        self.interests_file = directory + username + '_' + interests_file
        self.interests = pd.read_json(self.interests_file) if os.path.isfile(self.interests_file) else pd.DataFrame(columns=['interest', 'frequency'])  # {'interest': 'frequency':}

    def initialize(self):
        if self.interests.empty:
            self.interests = pd.DataFrame(columns=['interest', 'frequency'])

    def update_interests(self, topics):
        for topic in topics:
            index = self.interests.index[self.interests['interest'] == topic]
            if len(index) > 0:
                self.interests.at[index[0], 'frequency'] += 1
            else:
                self.interests = self.interests.append({'interest': topic, 'frequency': max(
                    self.interests['frequency']) if not self.interests.empty else 6}, ignore_index=True)

        self.interests = self.interests.sort_values(by='frequency', ascending=False, ignore_index=True)
        self.interests.to_json(self.interests_file)

    def decay_interests(self):
        for i, interest in self.interests.iterrows():
            if interest['frequency'] > 1:
                self.interests.at[i, 'frequency'] -= 1

    def get_user_interests(self):
        current_interests = []
        for i, row in self.interests.iterrows():
            if i < self.num_interests:
                current_interests.append(row['interest'])
        return current_interests
@@ -0,0 +1,235 @@
from datetime import datetime
import pandas as pd
import requests
import os
from io import BytesIO
import PyPDF2
from tqdm.auto import tqdm
import numpy as np
import math
import faiss
import time
import threading


class VRE:
    def __init__(self, name, token, retriever, directory='./'):
        self.name = name
        self.token = token
        self.catalogue_url = 'https://api.d4science.org/catalogue/items/'
        self.headers = {"gcube-token": self.token, "Accept": "application/json"}
        self.lastupdatetime = datetime.strptime('2021-01-01T00:00:00.000000', '%Y-%m-%dT%H:%M:%S.%f').timestamp()
        self.retriever = retriever
        self.directory = directory
        self.paper_counter = 0
        self.dataset_counter = 0
        self.content_counter = 0
        self.db = {'paper_db': pd.read_json(self.directory + self.name + '_paper.json') if os.path.isfile(self.directory + self.name + '_paper.json') else pd.DataFrame(columns=['id', 'type', 'resources', 'tags', 'title', 'author', 'notes', 'metadata_created']),
                   'dataset_db': pd.read_json(self.directory + self.name + '_dataset.json') if os.path.isfile(self.directory + self.name + '_dataset.json') else pd.DataFrame(columns=['id', 'type', 'resources', 'tags', 'title', 'author', 'notes', 'metadata_created']),
                   'content_db': pd.read_json(self.directory + self.name + '_content.json') if os.path.isfile(self.directory + self.name + '_content.json') else pd.DataFrame(columns=['id', 'paperid', 'content'])}
        self.index = {'dataset_titles_index': None if not os.path.isfile(self.directory + 'janet_dataset_titles_index') else faiss.read_index('janet_dataset_titles_index'),
                      'paper_titles_index': None if not os.path.isfile(self.directory + 'janet_paper_titles_index') else faiss.read_index('janet_paper_titles_index'),
                      'dataset_desc_index': None if not os.path.isfile(self.directory + 'janet_dataset_desc_index') else faiss.read_index('janet_dataset_desc_index'),
                      'paper_desc_index': None if not os.path.isfile(self.directory + 'janet_paper_desc_index') else faiss.read_index('janet_paper_desc_index'),
                      'content_index': None if not os.path.isfile(self.directory + 'janet_content_index') else faiss.read_index('janet_content_index')}
        self.new_income = False

    def init(self):
        # first run
        if not os.path.isfile(self.directory + self.name + '_dataset' + '.json') or not os.path.isfile(self.directory + self.name + '_paper' + '.json') or not os.path.isfile(self.directory + self.name + '_content' + '.json'):
            self.get_content()
        if self.index['dataset_titles_index'] is None:
            self.create_index('dataset_db', 'title', 'dataset_titles_index', 'janet_dataset_titles_index')
            self.populate_index('dataset_db', 'title', 'dataset_titles_index', 'janet_dataset_titles_index')

        if self.index['dataset_desc_index'] is None:
            self.create_index('dataset_db', 'notes', 'dataset_desc_index', 'janet_dataset_desc_index')
            self.populate_index('dataset_db', 'notes', 'dataset_desc_index', 'janet_dataset_desc_index')

        if self.index['paper_titles_index'] is None:
            self.create_index('paper_db', 'title', 'paper_titles_index', 'janet_paper_titles_index')
            self.populate_index('paper_db', 'title', 'paper_titles_index', 'janet_paper_titles_index')

        if self.index['paper_desc_index'] is None:
            self.create_index('paper_db', 'notes', 'paper_desc_index', 'janet_paper_desc_index')
            self.populate_index('paper_db', 'notes', 'paper_desc_index', 'janet_paper_desc_index')

        if self.index['content_index'] is None:
            self.create_index('content_db', 'content', 'content_index', 'janet_content_index')
            self.populate_index('content_db', 'content', 'content_index', 'janet_content_index')

    def index_periodic_update(self):
        if self.new_income:
            if len(self.db['content_db']) % 100 != 0:
                self.create_index('content_db', 'content', 'content_index', 'janet_content_index')
                self.populate_index('content_db', 'content', 'content_index', 'janet_content_index')
            if len(self.db['paper_db']) % 100 != 0:
                self.create_index('paper_db', 'title', 'paper_titles_index', 'janet_paper_titles_index')
                self.populate_index('paper_db', 'title', 'paper_titles_index', 'janet_paper_titles_index')
                self.create_index('paper_db', 'notes', 'paper_desc_index', 'janet_paper_desc_index')
                self.populate_index('paper_db', 'notes', 'paper_desc_index', 'janet_paper_desc_index')
            if len(self.db['dataset_db']) % 100 != 0:
                self.create_index('dataset_db', 'title', 'dataset_titles_index', 'janet_dataset_titles_index')
                self.populate_index('dataset_db', 'title', 'dataset_titles_index', 'janet_dataset_titles_index')
                self.create_index('dataset_db', 'notes', 'dataset_desc_index', 'janet_dataset_desc_index')
                self.populate_index('dataset_db', 'notes', 'dataset_desc_index', 'janet_dataset_desc_index')
            self.new_income = False

    def create_index(self, db_type, attribute, index_type, filename):
        to_index = self.db[db_type][attribute]
        for i, info in enumerate(to_index):
            if i == 0:
                emb = self.retriever.encode([info])
                sentence_embeddings = np.array(emb)
            else:
                emb = self.retriever.encode([info])
                sentence_embeddings = np.append(sentence_embeddings, emb, axis=0)

        # number of partitions of the coarse quantizer = number of posting lists
        # as rule of thumb, 4*sqrt(N) < nlist < 16*sqrt(N), where N is the size of the database
        nlist = int(4 * math.sqrt(len(sentence_embeddings))) if int(4 * math.sqrt(len(sentence_embeddings))) < len(sentence_embeddings) else len(sentence_embeddings) - 1
        code_size = 8  # = number of subquantizers = number of sub-vectors
        n_bits = 4 if len(sentence_embeddings) >= 2**4 else int(math.log2(len(sentence_embeddings)))  # n_bits of each code (8 -> 1 byte codes)
        d = sentence_embeddings.shape[1]
        coarse_quantizer = faiss.IndexFlatL2(d)  # will keep centroids of coarse quantizer (for inverted list)
        self.index[index_type] = faiss.IndexIVFPQ(coarse_quantizer, d, nlist, code_size, n_bits)
        self.index[index_type].train(sentence_embeddings)  # train on a random subset to speed up k-means (NOTE: ensure they are randomly chosen!)
        faiss.write_index(self.index[index_type], filename)

    def populate_index(self, db_type, attribute, index_type, filename):
        to_index = self.db[db_type][attribute]
        for info in to_index:
            sentence_embedding = np.array(self.retriever.encode([info]))
            self.index[index_type].add(sentence_embedding)
        faiss.write_index(self.index[index_type], filename)

    def get_content(self):
        response = requests.get(self.catalogue_url, headers=self.headers)
        items = response.json()
        items_data = []
        for item in items:
            api_url = self.catalogue_url + item + '/'
            response = requests.get(api_url, headers=self.headers)
            items_data.append(response.json())

        keys = ['type', 'resources', 'tags', 'title', 'author', 'notes', 'metadata_created']

        paper_df = pd.DataFrame(columns=['id', 'type', 'resources', 'tags', 'title', 'author', 'notes', 'metadata_created'])
        dataset_df = pd.DataFrame(columns=['id', 'type', 'resources', 'tags', 'title', 'author', 'notes', 'metadata_created'])
        content_df = pd.DataFrame(columns=['id', 'paperid', 'content'])

        for item in items_data:
            for el in item['extras']:
                if el['key'] == 'system:type':
                    rsrc = el['value']
            resources = []
            for resource in item['resources']:
                resources.append(
                    {'name': resource['name'].lower(), 'url': resource['url'], 'description': resource['description'].lower()})
            tags = []
            for tag in item['tags']:
                tags.append(tag['name'].lower())
            title = item['title'].lower()
            author = item['author'].lower()
            notes = item['notes'].lower()
            date = datetime.strptime(item['metadata_created'], '%Y-%m-%dT%H:%M:%S.%f').timestamp()
            if date > self.lastupdatetime:
                self.lastupdatetime = date
            if rsrc == 'Paper':
                self.paper_counter += 1
                paper_df.loc[str(self.paper_counter)] = [self.paper_counter, rsrc, resources, tags, title, author, notes, date]
                content_df = self.get_pdf_content(item, content_df)
                content_df = self.get_txt_content(item, content_df)
            if rsrc == 'Dataset':
                self.dataset_counter += 1
                dataset_df.loc[str(self.dataset_counter)] = [self.dataset_counter, rsrc, resources, tags, title, author, notes, date]

        self.db['paper_db'] = paper_df.sort_values(by='metadata_created', ascending=True)
        self.db['dataset_db'] = dataset_df.sort_values(by='metadata_created', ascending=True)
        self.db['content_db'] = content_df

        self.db['paper_db'].to_json(self.name + '_paper.json')
        self.db['dataset_db'].to_json(self.name + '_dataset.json')

        self.db['content_db'].to_json(self.name + '_content.json')

    # modify query
    def get_vre_update(self):
        print("Getting new items")
        response = requests.get(self.catalogue_url, headers=self.headers)
        items = response.json()
        items_data = []
        for item in items:
            api_url = self.catalogue_url + item + '/'
            response = requests.get(api_url, headers=self.headers)
            if datetime.strptime(response.json()['metadata_created'], '%Y-%m-%dT%H:%M:%S.%f').timestamp() > self.lastupdatetime:
                items_data.append(response.json())

        keys = ['type', 'resources', 'tags', 'title', 'author', 'notes', 'metadata_created']

        paper_df = pd.DataFrame(columns=['id', 'type', 'resources', 'tags', 'title', 'author', 'notes', 'metadata_created'])
        dataset_df = pd.DataFrame(columns=['id', 'type', 'resources', 'tags', 'title', 'author', 'notes', 'metadata_created'])
        content_df = pd.DataFrame(columns=['id', 'paperid', 'content'])

        for item in items_data:
            for el in item['extras']:
                if el['key'] == 'system:type':
                    rsrc = el['value']
            resources = []
            for resource in item['resources']:
                resources.append(
                    {'name': resource['name'].lower(), 'url': resource['url'], 'description': resource['description'].lower()})
            tags = []
            for tag in item['tags']:
                tags.append(tag['name'].lower())
            title = item['title'].lower()
            author = item['author'].lower()
            notes = item['notes'].lower()
            date = datetime.strptime(item['metadata_created'], '%Y-%m-%dT%H:%M:%S.%f').timestamp()
            if date > self.lastupdatetime:
                self.lastupdatetime = date

            if rsrc == 'Paper':
                self.paper_counter += 1
                paper_df.loc[str(self.paper_counter)] = [self.paper_counter, rsrc, resources, tags, title, author, notes, date]
                content_df = self.get_pdf_content(item, content_df)
                content_df = self.get_txt_content(item, content_df)
            if rsrc == 'Dataset':
                self.dataset_counter += 1
                dataset_df.loc[str(self.dataset_counter)] = [self.dataset_counter, rsrc, resources, tags, title, author, notes, date]

        self.db['paper_db'] = pd.concat([self.db['paper_db'], paper_df.sort_values(by='metadata_created', ascending=True)])
        self.db['dataset_db'] = pd.concat([self.db['dataset_db'], dataset_df.sort_values(by='metadata_created', ascending=True)])

        self.db['paper_db'].to_json(self.name + '_paper.json')
        self.db['dataset_db'].to_json(self.name + '_dataset.json')
        self.db['content_db'] = pd.concat([self.db['content_db'], content_df])
        self.db['content_db'].to_json(self.name + '_content.json')
        if not paper_df.empty or not dataset_df.empty or not content_df.empty:
            self.new_income = True

    def get_pdf_content(self, item, df):
        for rsrc in tqdm(item['resources']):
            response = requests.get(rsrc['url'])
            if 'application/pdf' in response.headers.get('content-type'):
                my_raw_data = response.content
                with BytesIO(my_raw_data) as data:
                    read_pdf = PyPDF2.PdfReader(data)
                    for page in tqdm(range(len(read_pdf.pages))):
                        content = read_pdf.pages[page].extract_text()
                        self.content_counter += 1
                        df.loc[str(self.content_counter)] = [self.content_counter, self.paper_counter, content]
        return df

    def get_txt_content(self, item, df):
        for rsrc in tqdm(item['resources']):
            response = requests.get(rsrc['url'])
            if 'text/plain' in response.headers.get('content-type'):
                content = response.text
                self.content_counter += 1
                df.loc[str(self.content_counter)] = [self.content_counter, self.paper_counter, content]
        return df

    def get_db(self):
        return self.db

    def get_index(self):
        return self.index
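As a worked instance of the nlist rule of thumb used in create_index (the number of embeddings here is hypothetical):

import math

N = 10_000                     # hypothetical number of encoded sentences
nlist = int(4 * math.sqrt(N))  # 400, the lower end of the 4*sqrt(N)..16*sqrt(N) range
nlist = nlist if nlist < N else N - 1  # same fallback as create_index for tiny collections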
@@ -0,0 +1,133 @@
import os
import warnings

import faiss
import torch
from flask import Flask, render_template, request, jsonify
from flask_cors import CORS, cross_origin
import spacy
import spacy_transformers
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

from User import User
from VRE import VRE
from NLU import NLU
from DM import DM
from Recommender import Recommender
from ResponseGenerator import ResponseGenerator

import pandas as pd
import time
import threading

from sentence_transformers import SentenceTransformer


app = Flask(__name__)
# allow frontend address
url = os.getenv("FRONTEND_URL_WITH_PORT")
cors = CORS(app, resources={r"/predict": {"origins": url}, r"/feedback": {"origins": url}})
# cors = CORS(app, resources={r"/predict": {"origins": "*"}, r"/feedback": {"origins": "*"}})


# rg = ResponseGenerator(index)


def get_response(text):
    # get response from janet itself
    return text, 'candAnswer'


def vre_fetch():
    while True:
        time.sleep(1000)
        print('getting new material')
        vre.get_vre_update()
        vre.index_periodic_update()
        rg.update_index(vre.get_index())
        rg.update_db(vre.get_db())


def user_interest_decay():
    while True:
        print("decaying interests after 3 minutes")
        time.sleep(180)
        user.decay_interests()


def recommend():
    while True:
        state = dm.get_recent_state()
        # guard: no dialogue state exists until the first /predict call
        if state is not None and time.time() - state['time'] > 1000:
            print("Making Recommendation: ")
            prompt = rec.make_recommendation(user.username)
            if prompt != "":
                print(prompt)
        time.sleep(1000)


@app.route("/predict", methods=['POST'])
def predict():
    text = request.get_json().get("message")
    state = nlu.process_utterance(text, dm.get_utt_history())
    user_interests = []
    for entity in state['entities']:
        if entity['entity'] == 'TOPIC':
            user_interests.append(entity['value'])
    user.update_interests(user_interests)
    dm.update(state)
    action = dm.next_action()
    response = rg.gen_response(state['modified_prompt'], dm.get_recent_state(), dm.get_utt_history(), action)
    message = {"answer": response, "query": text, "cand": "candidate", "history": dm.get_utt_history(), "modQuery": state['modified_prompt']}
    reply = jsonify(message)
    # reply.headers.add('Access-Control-Allow-Origin', '*')
    return reply


@app.route('/feedback', methods=['POST'])
def feedback():
    data = request.get_json()['feedback']
    # Make data frame of above data
    print(data)
    df = pd.DataFrame([data])
    file_exists = os.path.isfile('feedback.csv')

    # df = pd.DataFrame(data=[data['response'], data['length'], data['fluency'], data['truthfulness'], data['usefulness'], data['speed']],
    #                   columns=['response', 'length', 'fluency', 'truthfulness', 'usefulness', 'speed'])
    df.to_csv('feedback.csv', mode='a', index=False, header=(not file_exists))
    reply = jsonify({"status": "done"})
    # reply.headers.add('Access-Control-Allow-Origin', '*')
    return reply


if __name__ == "__main__":
    warnings.filterwarnings("ignore")
    # load NLU
    def_tokenizer = AutoTokenizer.from_pretrained("castorini/t5-base-canard")
    def_reference_resolver = AutoModelForSeq2SeqLM.from_pretrained("castorini/t5-base-canard")
    def_intent_classifier_dir = "./IntentClassifier/"
    def_entity_extractor = spacy.load("./EntityExtraction/BestModel")
    def_offense_filter_dir = "./OffensiveClassifier"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    device_flag = torch.cuda.current_device() if torch.cuda.is_available() else -1
    nlu = NLU(device, device_flag, def_reference_resolver, def_tokenizer, def_intent_classifier_dir, def_offense_filter_dir, def_entity_extractor)

    # load retriever and generator
    def_retriever = SentenceTransformer('./BigRetriever/').to(device)
    def_generator = pipeline("text2text-generation", model="./generator", device=device_flag)

    # load vre
    token = '2c1e8f88-461c-42c0-8cc1-b7660771c9a3-843339462'
    vre = VRE("assistedlab", token, def_retriever)
    vre.init()
    index = vre.get_index()
    db = vre.get_db()
    user = User("ahmed", token)

    threading.Thread(target=vre_fetch, name='updatevre').start()

    threading.Thread(target=user_interest_decay, name='decayinterest').start()

    rec = Recommender(def_retriever)

    dm = DM()
    rg = ResponseGenerator(index, db, def_generator, def_retriever)
    threading.Thread(target=recommend, name='recommend').start()
    app.run(host='127.0.0.1', port=4000)
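A minimal client-side sketch of the two endpoints (URL and payload values are illustrative; the feedback fields follow the commented-out column list in feedback() above):

import requests

r = requests.post("http://127.0.0.1:4000/predict",
                  json={"message": "find me a dataset about ocean temperature"})
print(r.json()["answer"])  # the reply also carries query, cand, history and modQuery

requests.post("http://127.0.0.1:4000/feedback",
              json={"feedback": {"response": "...", "length": 3, "fluency": 4,
                                 "truthfulness": 4, "usefulness": 5, "speed": 2}})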
@@ -0,0 +1,27 @@
faiss-gpu==1.7.2
Flask==1.1.4
flask-cors==3.0.10
protobuf==3.20.0
matplotlib==3.5.3
nltk==3.7
numpy==1.22.4
pandas==1.3.5
PyPDF2==3.0.1
regex==2022.6.2
requests==2.25.1
scikit-learn==1.0.2
scipy==1.7.3
sentence-transformers==2.2.2
sentencepiece==0.1.97
sklearn-pandas==1.8.0
spacy==3.5.0
spacy-transformers==1.2.2
torch @ https://download.pytorch.org/whl/cu116/torch-1.13.1%2Bcu116-cp38-cp38-linux_x86_64.whl
torchaudio @ https://download.pytorch.org/whl/cu116/torchaudio-0.13.1%2Bcu116-cp38-cp38-linux_x86_64.whl
torchsummary==1.5.1
torchtext==0.14.1
torchvision @ https://download.pytorch.org/whl/cu116/torchvision-0.14.1%2Bcu116-cp38-cp38-linux_x86_64.whl
tqdm==4.64.1
transformers==4.26.1
markupsafe==2.0.1
Werkzeug==1.0.1