2023-03-30 15:17:54 +02:00
from sentence_transformers import models , SentenceTransformer
from transformers import pipeline , AutoTokenizer , AutoModelForCausalLM
import faiss
from sklearn . metrics . pairwise import cosine_similarity
import numpy as np
import pandas as pd
2023-04-04 05:34:47 +02:00
from datetime import datetime
2023-03-30 15:17:54 +02:00
class ResponseGenerator:
    """Turn a dialogue action into a textual reply.

    The generator looks papers/datasets up in pandas tables (``db``), falls
    back to nearest-neighbour search over vector indexes (``index``) and
    delegates free-text generation to the callables in ``generators``.

    The most recently found paper and dataset are remembered on the instance
    (``self.paper`` / ``self.dataset``) so that later actions such as
    summarisation or clarification can refer back to them.

    NOTE(review): every string literal below (dict keys, action names,
    messages) is reproduced exactly as it appears in the reviewed source,
    including its leading/trailing spaces. That padding looks like an
    extraction artefact - confirm against the canonical file.
    """

    # Minimum cosine similarity a tag must exceed to count as a topic match.
    TOPIC_SIMILARITY_THRESHOLD = 0.7

    # Per-resource settings for the shared "find" logic: which attribute
    # remembers the hit, which table / indexes to query, and the exact
    # message prefixes used for each kind of hit.
    _FIND_CONFIG = {
        'paper': {
            'attr': 'paper',
            'db': ' paper_db ',
            'titles_index': ' paper_titles_index ',
            'desc_index': ' paper_desc_index ',
            'exact_title': " Here is the paper you want: ",
            'exact_author': " Here is the paper you want: ",
            'relevant': " This paper could be relevant: ",
        },
        'dataset': {
            'attr': 'dataset',
            'db': ' dataset_db ',
            'titles_index': ' dataset_titles_index ',
            'desc_index': ' dataset_desc_index ',
            'exact_title': " Here is the dataset you wanted: ",
            'exact_author': " Here is the dataset you want: ",
            'relevant': " This dataset could be relevant: ",
        },
    }

    def __init__(self, index, db, recommender, generators, retriever, num_retrieved=1):
        """Store the collaborators used to build responses.

        Args:
            index: Mapping from index name to an object exposing
                ``search(vectors, k)`` that returns ``(distances, ids)``.
            db: Mapping from table name to a pandas DataFrame.
            recommender: Object exposing ``make_recommendation(username)``.
            generators: Mapping from generator name to a callable that
                returns a list of dicts holding the generated text.
            retriever: Object exposing ``encode(list_of_texts)``.
            num_retrieved: Number of neighbours requested from an index.
        """
        self.generators = generators
        self.retriever = retriever
        self.recommender = recommender
        self.db = db
        self.index = index
        self.num_retrieved = num_retrieved
        # Last paper / dataset found; empty until a lookup succeeds.
        self.paper = {}
        self.dataset = {}

    def update_index(self, index):
        """Replace the mapping of vector indexes."""
        self.index = index

    def update_db(self, db):
        """Replace the mapping of tables."""
        self.db = db

    def _get_resources_links(self, item):
        """Return the URLs of all resources attached to ``item``.

        An empty ``item`` (no record found) yields an empty list.
        """
        if len(item) == 0:
            return []
        return [rsrc[' url '] for rsrc in item[' resources ']]

    def _first_match(self, rsrc, column, value):
        """Return the first row of table ``rsrc`` whose ``column`` equals
        the lower-cased ``value``, or ``{}`` when no row matches."""
        table = self.db[rsrc]
        cand = table.loc[table[column] == value.lower()].reset_index(drop=True)
        if cand.empty:
            return {}
        return cand.loc[0]

    def _get_matching_titles(self, rsrc, title):
        """Return the first record in ``rsrc`` titled ``title`` (lower-cased), or ``{}``."""
        return self._first_match(rsrc, ' title ', title)

    def _get_matching_authors(self, rsrc, author):
        """Return the first record in ``rsrc`` authored by ``author`` (lower-cased), or ``{}``."""
        return self._first_match(rsrc, ' author ', author)

    def _get_matching_topics(self, rsrc, topic):
        """Return the record whose tag is most similar to ``topic``.

        Every tag of every row is embedded and compared with the embedded
        topic; the row owning the best-scoring tag wins, provided the score
        exceeds ``TOPIC_SIMILARITY_THRESHOLD``.

        Returns:
            The matching row, or ``[]`` when no tag is similar enough.
        """
        # The topic embedding does not depend on the row or tag: encode once.
        topic_vec = np.array(self.retriever.encode([topic.lower()]))
        best = None
        score = self.TOPIC_SIMILARITY_THRESHOLD
        for _, cand in self.db[rsrc].iterrows():
            for tag in cand[' tags ']:
                tag_vec = np.array(self.retriever.encode([tag]))
                # One vector against one vector: the result is a 1x1 matrix.
                sim = cosine_similarity(tag_vec, topic_vec)[0][0]
                if sim > score:
                    best = cand
                    score = sim
        if best is None:
            return []
        return best

    def _search_index(self, index_type, db_type, query):
        """Return the top row of ``db_type`` for ``query`` according to
        the vector index ``index_type``."""
        xq = self.retriever.encode([query])
        _, hits = self.index[index_type].search(xq, self.num_retrieved)
        # hits[0] holds the row positions for the single query vector.
        return self.db[db_type].iloc[hits[0]].reset_index(drop=True).loc[0]

    @staticmethod
    def _describe(intro, item, links):
        """Build the "found it" message for ``item``.

        The download sentence is only appended when at least one link is
        available, so records without resources no longer raise IndexError.
        """
        message = intro + item[' title '] + ' . '
        if links:
            message += " It can be downloaded at " + links[0]
        return str(message)

    def _find_item(self, kind, state, utterance):
        """Locate a paper or dataset from the entities in ``state``.

        Entities are tried in order: an exact title match (else nearest
        title in the titles index), a topic match, or an exact author
        match. If none of them produce a usable hit, the description index
        is searched with the raw ``utterance``. The last looked-up record is
        stored on ``self.paper`` / ``self.dataset``.
        """
        cfg = self._FIND_CONFIG[kind]
        attr = cfg['attr']
        db_key = cfg['db']

        for entity in state[' entities ']:
            entity_type = entity[' entity ']

            if entity_type == ' TITLE ':
                item = self._get_matching_titles(db_key, entity[' value '])
                setattr(self, attr, item)
                links = self._get_resources_links(item)
                if len(item) > 0 and len(links) > 0:
                    return self._describe(cfg['exact_title'], item, links)
                # No exact, downloadable match: use the closest title instead.
                item = self._search_index(cfg['titles_index'], db_key, entity[' value '])
                setattr(self, attr, item)
                links = self._get_resources_links(item)
                return self._describe(cfg['relevant'], item, links)

            if entity_type == ' TOPIC ':
                item = self._get_matching_topics(db_key, entity[' value '])
                setattr(self, attr, item)
                links = self._get_resources_links(item)
                if len(item) > 0 and len(links) > 0:
                    return self._describe(cfg['relevant'], item, links)

            if entity_type == ' AUTHOR ':
                item = self._get_matching_authors(db_key, entity[' value '])
                setattr(self, attr, item)
                links = self._get_resources_links(item)
                if len(item) > 0 and len(links) > 0:
                    return self._describe(cfg['exact_author'], item, links)

        # Nothing usable among the entities: search descriptions by utterance.
        item = self._search_index(cfg['desc_index'], db_key, utterance)
        setattr(self, attr, item)
        links = self._get_resources_links(item)
        return self._describe(cfg['relevant'], item, links)

    def _summarize_paper(self, state):
        """Summarise every stored paragraph of the current paper.

        When no paper is remembered yet, the TITLE entities in ``state`` are
        used to look one up first.
        """
        if len(self.paper) == 0:
            for entity in state[' entities ']:
                if entity[' entity '] == ' TITLE ':
                    self.paper = self._get_matching_titles(' paper_db ', entity[' value '])
                    if len(self.paper) > 0:
                        break
        if len(self.paper) == 0:
            return " I cannot seem to find the requested paper. Try again by specifying the title of the paper. "

        contents = self.db[' content_db ']
        paragraphs = contents[contents[' paperid '] == self.paper[' id ']]
        gen_kwargs = {" length_penalty ": 1.5, " num_beams ": 6, " max_length ": 120}
        # Concatenate the summaries of all paragraphs; previously each
        # iteration overwrote the result so only the last one survived.
        parts = [
            self.generators[' summ '](' summarize: ' + row[' content '], **gen_kwargs)[0][' generated_text '] + ' '
            for _, row in paragraphs.iterrows()
        ]
        if not parts:
            # No stored content for this paper: same value as before.
            return " "
        return ''.join(parts)

    def _clarify(self, state, utterance, consec_history):
        """Ask the user for the missing information.

        A fixed prompt is used when a find/summarise intent has no entities
        and nothing is remembered yet; in every other case a clarification
        question is generated (previously some of those cases returned
        ``None``).
        """
        intent = state[' intent ']
        if (intent in [' FINDPAPER ', ' SUMMARIZEPAPER ']
                and len(state[' entities ']) == 0
                and len(self.paper) == 0):
            return ' Please specify the title, the topic of the paper of interest. '
        if (intent == ' FINDDATASET '
                and len(state[' entities ']) == 0
                and len(self.dataset) == 0):
            return ' Please specify the title, the topic of the dataset of interest. '

        gen_kwargs = {" length_penalty ": 2.5, " num_beams ": 8, " max_length ": 120}
        question = self.generators[' amb '](
            ' question: ' + utterance + ' context: ' + consec_history, **gen_kwargs
        )[0][' generated_text ']
        return question

    def gen_response(self, action, utterance=None, username=None, state=None, consec_history=None):
        """Produce the reply for ``action``.

        Args:
            action: Name of the dialogue action to perform.
            utterance: The user's latest message.
            username: User for whom recommendations are made.
            state: Dialogue state; the find, summarise and clarify actions
                read its entities and/or intent.
            consec_history: Conversation history fed to the chat and
                clarification generators.

        Returns:
            The reply text, or ``None`` for an unrecognised action.
        """
        if action == " Help ":
            return " Hey it ' s Janet! I am here to help you make use of the datasets and papers in the VRE. I can answer questions whose answers may be inside the papers. I can summarize papers for you. I can also chat with you. So, whichever it is, I am ready to chat! "

        if action == " Recommend ":
            prompt = self.recommender.make_recommendation(username)
            # A missing, empty or whitespace-only prompt means "nothing to recommend".
            if prompt and prompt.strip():
                return prompt
            return " I can help you with exploiting the contents of the VRE, just let me know! "

        if action == " OffenseReject ":
            return " I am sorry, I cannot answer to this kind of language "

        if action == " ConvGen ":
            gen_kwargs = {" length_penalty ": 2.5, " num_beams ": 2, " max_length ": 30}
            chat_input = (
                ' history: ' + consec_history + ' ' + utterance + ' persona: '
                + ' I am Janet. My name is Janet. I am an AI developed by CNR to help VRE users. '
            )
            return self.generators[' chat '](chat_input, **gen_kwargs)[0][' generated_text ']

        if action == " findPaper ":
            return self._find_item('paper', state, utterance)

        if action == " findDataset ":
            return self._find_item('dataset', state, utterance)

        if action == " RetGen ":
            # Retrieve the most relevant paragraph, then answer from it.
            content = str(self._search_index(' content_index ', ' content_db ', utterance)[' content '])
            gen_seq = ' question: ' + utterance + " context: " + content
            gen_kwargs = {" length_penalty ": 0.5, " num_beams ": 2, " max_length ": 60}
            answer = self.generators[' qa '](gen_seq, **gen_kwargs)[0][' generated_text ']
            return str(answer)

        if action == " sumPaper ":
            return self._summarize_paper(state)

        if action == " Clarify ":
            return self._clarify(state, utterance, consec_history)

        # Unrecognised action: no reply (unchanged behaviour, now explicit).
        return None