2023-03-30 15:17:54 +02:00
from sentence_transformers import models , SentenceTransformer
from transformers import pipeline , AutoTokenizer , AutoModelForCausalLM
import faiss
from sklearn . metrics . pairwise import cosine_similarity
import numpy as np
import pandas as pd
2023-04-04 05:34:47 +02:00
from datetime import datetime
2023-03-30 15:17:54 +02:00
class ResponseGenerator:
    """Produce Janet's replies for the dialogue actions chosen upstream.

    Replies are built from three sources:

    * ``db`` -- a mapping from table name (``'paper_db'``, ``'dataset_db'``,
      ``'post_db'``, ``'content_db'``) to pandas DataFrames that are filtered
      and iterated directly;
    * ``index`` -- a mapping from index name to objects exposing a faiss-like
      ``search(xq, k) -> (distances, ids)``, whose ids are row positions in the
      corresponding table;
    * ``generators`` -- a mapping from task name (``'qa'``, ``'chat'``,
      ``'summ'``, ``'amb'``) to callables returning ``[{'generated_text': ...}]``
      (presumably transformers pipelines -- confirm against the caller).

    The last paper, dataset and post shown to the user are remembered in
    ``self.paper``, ``self.dataset`` and ``self.post`` so that follow-up
    actions (summaries, explanations, clarifications) can refer back to them.
    """

    # Decoding settings for each generator, passed as keyword arguments.
    # NOTE(review): with beam search and no ``do_sample=True``, ``temperature``
    # presumably has no effect -- kept unchanged to preserve current output.
    _QA_KWARGS = {"length_penalty": 0.5, "num_beams": 2, "max_length": 60,
                  "repetition_penalty": 2.5, "temperature": 2}
    _CHAT_KWARGS = {"length_penalty": 2.5, "num_beams": 2, "max_length": 30,
                    "repetition_penalty": 2.5, "temperature": 2}
    _SUMM_KWARGS = {"length_penalty": 1.5, "num_beams": 6, "max_length": 30,
                    "repetition_penalty": 2.5, "temperature": 2}
    _AMB_KWARGS = {"length_penalty": 2.5, "num_beams": 8, "max_length": 120,
                   "repetition_penalty": 2.5, "temperature": 2}

    # Minimum cosine similarity (exclusive) for a tag to count as matching a topic.
    _TOPIC_THRESHOLD = 0.7

    _FALLBACK = ("I am unable to generate the response. Can you please provide me with a "
                 "preferred response in the feedback form so I can learn?")

    def __init__(self, index, db, recommender, generators, retriever, num_retrieved=3):
        """Store the collaborators; ``index``, ``db`` and ``generators`` are described on the class.

        Args:
            recommender: object whose ``make_recommendation(username, name)``
                returns the recommendation text (presumably ``""`` when it has none).
            retriever: text encoder; ``encode(list_of_texts)`` output is fed to the
                indexes and to ``cosine_similarity``.
            num_retrieved: number of neighbours requested from every index search.
        """
        self.generators = generators
        self.retriever = retriever
        self.recommender = recommender
        self.db = db
        self.index = index
        self.num_retrieved = num_retrieved
        # Items currently "in focus" in the conversation; empty until one is found.
        self.paper = {}
        self.dataset = {}
        self.post = {}

    def update_index(self, index):
        """Replace the mapping of search indexes."""
        self.index = index

    def update_db(self, db):
        """Replace the mapping of pandas tables."""
        self.db = db

    def _get_resources_links(self, item):
        """Return the ``'url'`` of every entry in ``item['resources']``; ``[]`` for an empty item."""
        if len(item) == 0:
            return []
        return [rsrc['url'] for rsrc in item['resources']]

    def _get_matching_titles(self, rsrc, title):
        """Return the first row of table *rsrc* whose title equals ``title.lower()``, else ``{}``.

        Stored titles are compared as-is, so they are expected to be lowercase already.
        """
        table = self.db[rsrc]
        cand = table.loc[table['title'] == title.lower()].reset_index(drop=True)
        return cand.loc[0] if not cand.empty else {}

    @staticmethod
    def _latest(table):
        """Return the row of a non-empty *table* with the largest ``'time'`` (first one on ties)."""
        # Resolve by position so duplicate or non-default index labels cannot confuse the lookup.
        return table.iloc[table['time'].reset_index(drop=True).idxmax()]

    def _get_matching_authors(self, rsrc, author, recent=False):
        """Return a row of table *rsrc* whose author equals ``author.lower()``, else ``{}``.

        With ``recent=True`` the newest matching row (largest ``'time'``) is returned,
        otherwise the first match.
        """
        table = self.db[rsrc]
        cand = table.loc[table['author'] == author.lower()].reset_index(drop=True)
        if cand.empty:
            return {}
        return self._latest(cand) if recent else cand.loc[0]

    def _get_most_recent(self, rsrc):
        """Return the newest row (largest ``'time'``) of table *rsrc*, or ``{}`` if it is empty."""
        table = self.db[rsrc]
        if table.empty:
            return {}
        return self._latest(table)

    def _get_matching_topics(self, rsrc, topic):
        """Return the row of table *rsrc* owning the tag most similar to *topic*, else ``[]``.

        Similarity is the cosine similarity between retriever embeddings of a tag
        and of ``topic.lower()``. Only scores strictly above ``_TOPIC_THRESHOLD``
        qualify; on equal scores the first row encountered wins.
        """
        # The topic embedding does not depend on the row, so compute it only once.
        topic_vec = np.array(self.retriever.encode([topic.lower()]))
        best, best_score = None, self._TOPIC_THRESHOLD
        for _, cand in self.db[rsrc].iterrows():
            for tag in cand['tags']:
                tag_vec = np.array(self.retriever.encode([tag]))
                sim = cosine_similarity(tag_vec, topic_vec)[0][0]
                if sim > best_score:
                    best, best_score = cand, sim
        return best if best is not None else []

    def _search_index(self, index_type, db_type, query, multi=False):
        """Return rows of table *db_type* nearest to *query* according to index *index_type*.

        With *multi* true, returns a DataFrame of up to ``num_retrieved`` rows in the
        order produced by the index. Otherwise returns only the first of those rows,
        or ``{}`` when the search produced no hit.
        """
        xq = self.retriever.encode([query])
        _, ids = self.index[index_type].search(xq, self.num_retrieved)
        row_ids = np.asarray(ids[0])
        # faiss pads the result with -1 when fewer than k neighbours exist; used as an
        # iloc position, -1 would silently select the table's last row.
        hits = self.db[db_type].iloc[row_ids[row_ids >= 0]].reset_index(drop=True)
        if multi:
            return hits
        return hits.loc[0] if not hits.empty else {}

    def _describe(self, prefix, item):
        """Format *item* as ``prefix + title + '. '`` followed by its first download link."""
        if len(item) == 0:
            return "I could not find anything relevant in the catalogue."
        links = self._get_resources_links(item)
        reply = prefix + item['title'] + '. '
        if links:
            return reply + "It can be downloaded at " + links[0]
        # Previously this indexed links[0] unconditionally and crashed with IndexError.
        return reply + "No download link is available for it."

    def _find_resource(self, kind, state, utterance, title_prefix, author_prefix):
        """Shared lookup for ``findPaper`` / ``findDataset``.

        Tries the TITLE, TOPIC and AUTHOR entities in *state* (only hits that have a
        download link count) and falls back to a semantic search of *utterance*.

        Args:
            kind: ``'paper'`` or ``'dataset'``; selects ``<kind>_db`` and the
                ``<kind>_titles_index`` / ``<kind>_desc_index`` indexes.
            title_prefix, author_prefix: reply prefixes for exact title / author hits.

        Returns:
            ``(item, reply)`` -- the item that was shown and the reply text.
        """
        db_name = kind + '_db'
        relevant = "This " + kind + " could be relevant: "
        for entity in state['entities']:
            label, value = entity['entity'], entity['value']
            if label == 'TITLE':
                item = self._get_matching_titles(db_name, value)
                if self._get_resources_links(item):
                    return item, self._describe(title_prefix, item)
                # No exact, downloadable title match: use the closest title in the index.
                item = self._search_index(kind + '_titles_index', db_name, value)
                return item, self._describe(relevant, item)
            elif label == 'TOPIC':
                item = self._get_matching_topics(db_name, value)
                if self._get_resources_links(item):
                    return item, self._describe(relevant, item)
            elif label == 'AUTHOR':
                item = self._get_matching_authors(db_name, value)
                if self._get_resources_links(item):
                    return item, self._describe(author_prefix, item)
        item = self._search_index(kind + '_desc_index', db_name, utterance)
        return item, self._describe(relevant, item)

    def _qa(self, question, context):
        """Run the ``'qa'`` generator on a ``question: ... context: ...`` prompt and return its text."""
        prompt = 'question: ' + question + " context: " + context
        return self.generators['qa'](prompt, **self._QA_KWARGS)[0]['generated_text']

    def _find_post(self, utterance, state):
        """Handle ``findPost``: look a post up by topic/author, else explain or show the latest one."""
        for entity in state['entities']:
            label, value = entity['entity'], entity['value']
            if label == 'TOPIC':
                self.post = self._get_matching_topics('post_db', value)
                if len(self.post) > 0:
                    return ("This is a relevant post: " + self.post['content']
                            + ' by ' + self.post['author'])
            elif label == 'AUTHOR':
                self.post = self._get_matching_authors('post_db', value, recent=True)
                if len(self.post) > 0:
                    if len(self.post['tags']) > 0:
                        return ("Here is the most recent post by: " + self.post['author']
                                + ', which is about ' + ', '.join(self.post['tags'])
                                + '. ' + self.post['content'])
                    return ("Here is the most recent post by: " + self.post['author']
                            + ', ' + self.post['content'])
        if len(self.post) > 0:
            # No TOPIC/AUTHOR entity was given but a post from an earlier turn is
            # still in focus (a failed lookup above would have emptied it): explain it.
            answer = self._qa(utterance, self.post['content'])
            if len(self.post['tags']) > 0:
                return ("The post is about: " + answer
                        + "\nThere is a special focus on " + ', '.join(self.post['tags']))
            return "The post is about: " + answer
        self.post = self._get_most_recent('post_db')
        if len(self.post) == 0:
            return "I could not find any post in the catalogue."
        return ("This is the most recent post. " + self.post['content']
                + '\nIf you want another post, please rewrite the query specifying either '
                  'the author or the topic.')

    def _answer_from_content(self, utterance):
        """Handle ``RetGen``: answer *utterance* from the closest paragraphs of ``content_db``."""
        hits = self._search_index('content_index', 'content_db', utterance, multi=True)
        paragraphs = list(hits['content'])
        evidence = ''.join(str(n) + ") " + text + '\n' for n, text in enumerate(paragraphs, start=1))
        context = ''.join(" " + text for text in paragraphs)
        answer = self._qa(utterance, context)
        return ("According to the following evidence: " + evidence
                + "\n_______\n" + "The answer is: " + answer)

    def _summarize_paper(self, state):
        """Handle ``sumPaper``: summarise each paragraph of the focused (or TITLE-named) paper."""
        if len(self.paper) == 0:
            for entity in state['entities']:
                if entity['entity'] == 'TITLE':
                    self.paper = self._get_matching_titles('paper_db', entity['value'])
                    if len(self.paper) > 0:
                        break
        if len(self.paper) == 0:
            return ("I cannot seem to find the requested paper. "
                    "Try again by specifying the title of the paper.")
        content = self.db['content_db']
        paragraphs = content.loc[content['paperid'] == self.paper['id'], 'content']
        summaries = [self.generators['summ']('summarize: ' + text, **self._SUMM_KWARGS)[0]['generated_text']
                     for text in paragraphs]
        # Each summary is followed by a single space, as before.
        return ''.join(summary + ' ' for summary in summaries)

    def _clarify(self, utterance, state, consec_history):
        """Handle ``Clarify``: ask for missing details; ``None`` when nothing specific applies."""
        intent = state['intent']
        no_entities = len(state['entities']) == 0
        if intent in ['FINDPAPER', 'SUMMARIZEPAPER'] and no_entities:
            if len(self.paper) == 0:
                return 'Please specify the title, the topic of the paper of interest.'
        elif intent == 'FINDDATASET' and no_entities:
            if len(self.dataset) == 0:
                return 'Please specify the title, the topic of the dataset of interest.'
        elif intent == 'EXPLAINPOST' and no_entities:
            if len(self.post) != 0:
                return self._find_post(utterance, state)
            return 'Please specify the topic or the author of the post.'
        else:
            prompt = 'question: ' + utterance + ' context: ' + (consec_history or '')
            return self.generators['amb'](prompt, **self._AMB_KWARGS)[0]['generated_text']
        return None

    def gen_response(self, action, utterance=None, name=None, username=None, vrename=None,
                     state=None, consec_history=None, chitchat_history=None):
        """Return Janet's reply for the dialogue *action*.

        Args:
            action: one of ``"Help"``, ``"Recommend"``, ``"OffenseReject"``, ``"getHelp"``,
                ``"findPost"``, ``"ConvGen"``, ``"findPaper"``, ``"findDataset"``,
                ``"RetGen"``, ``"sumPaper"`` or ``"Clarify"``.
            utterance: the user's message (used by the lookup and generation actions).
            name: the user's name, used by ``Help`` and ``Recommend``.
            username: the user's account name, passed to the recommender.
            vrename: name of the VRE, used in the ``Help`` greeting.
            state: dialogue state with an ``'intent'`` and a list of ``'entities'``
                (dicts with ``'entity'`` and ``'value'`` keys).
            consec_history: context appended to the ``Clarify`` prompt.
            chitchat_history: context appended to the ``ConvGen`` prompt.

        Returns:
            The reply text; a generic fallback message for unknown actions or when
            ``Clarify`` has nothing specific to ask.
        """
        if action == "Help":
            return ("Hey " + name + "! it's Janet! I am here to help you make use of the datasets "
                    "and papers in the catalogue of the " + vrename + " VRE. I can answer questions "
                    "whose answers may be inside the papers. I can summarize papers for you. I can "
                    "also chat with you. So, whichever it is, I am ready to chat!")
        if action == "Recommend":
            prompt = self.recommender.make_recommendation(username, name)
            if prompt != "":
                return prompt
            return "I can help you with exploiting the contents of the VRE, just let me know!"
        if action == "OffenseReject":
            return "I am sorry, I cannot answer to this kind of language"
        if action == "getHelp":
            return ("I can answer questions related to the papers in the VRE's catalog. I can also "
                    "get you the posts, papers and datasets from the catalogue if you specify a topic "
                    "or an author. I am also capable of small talk and summarizing papers to an "
                    "extent. Just text me what you want and I will do it :)")
        if action == "findPost":
            return self._find_post(utterance, state)
        if action == "ConvGen":
            prompt = ('question: ' + utterance + ' context: My name is Janet. I am an AI developed '
                      'by CNR to help VRE users. ' + (chitchat_history or ''))
            return self.generators['chat'](prompt, **self._CHAT_KWARGS)[0]['generated_text']
        if action == "findPaper":
            self.paper, reply = self._find_resource(
                'paper', state, utterance,
                "Here is the paper you want: ", "Here is the paper you want: ")
            return reply
        if action == "findDataset":
            self.dataset, reply = self._find_resource(
                'dataset', state, utterance,
                "Here is the dataset you wanted: ", "Here is the dataset you want: ")
            return reply
        if action == "RetGen":
            return self._answer_from_content(utterance)
        if action == "sumPaper":
            return self._summarize_paper(state)
        if action == "Clarify":
            reply = self._clarify(utterance, state, consec_history)
            if reply is not None:
                return reply
        return self._FALLBACK