remove_models
This commit is contained in:
parent
87374ec68b
commit
ecd28e6151
|
@ -1,37 +0,0 @@
|
|||
{
|
||||
"_name_or_path": "distilbert-base-uncased",
|
||||
"activation": "gelu",
|
||||
"architectures": [
|
||||
"DistilBertForSequenceClassification"
|
||||
],
|
||||
"attention_dropout": 0.1,
|
||||
"dim": 768,
|
||||
"dropout": 0.1,
|
||||
"hidden_dim": 3072,
|
||||
"id2label": {
|
||||
"0": "clear",
|
||||
"1": "somewhat_clear",
|
||||
"2": "somewhat_ambiguous",
|
||||
"3": "ambiguous"
|
||||
},
|
||||
"initializer_range": 0.02,
|
||||
"label2id": {
|
||||
"ambiguous": 3,
|
||||
"clear": 0,
|
||||
"somewhat_ambiguous": 2,
|
||||
"somewhat_clear": 1
|
||||
},
|
||||
"max_position_embeddings": 512,
|
||||
"model_type": "distilbert",
|
||||
"n_heads": 12,
|
||||
"n_layers": 6,
|
||||
"pad_token_id": 0,
|
||||
"problem_type": "single_label_classification",
|
||||
"qa_dropout": 0.1,
|
||||
"seq_classif_dropout": 0.2,
|
||||
"sinusoidal_pos_embds": false,
|
||||
"tie_weights_": true,
|
||||
"torch_dtype": "float32",
|
||||
"transformers_version": "4.26.1",
|
||||
"vocab_size": 30522
|
||||
}
|
|
@ -1,7 +0,0 @@
|
|||
{
|
||||
"cls_token": "[CLS]",
|
||||
"mask_token": "[MASK]",
|
||||
"pad_token": "[PAD]",
|
||||
"sep_token": "[SEP]",
|
||||
"unk_token": "[UNK]"
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -1,14 +0,0 @@
|
|||
{
|
||||
"cls_token": "[CLS]",
|
||||
"do_lower_case": true,
|
||||
"mask_token": "[MASK]",
|
||||
"model_max_length": 512,
|
||||
"name_or_path": "distilbert-base-uncased",
|
||||
"pad_token": "[PAD]",
|
||||
"sep_token": "[SEP]",
|
||||
"special_tokens_map_file": null,
|
||||
"strip_accents": null,
|
||||
"tokenize_chinese_chars": true,
|
||||
"tokenizer_class": "DistilBertTokenizer",
|
||||
"unk_token": "[UNK]"
|
||||
}
|
File diff suppressed because it is too large
Load Diff
18
main.py
18
main.py
|
@ -114,20 +114,20 @@ if __name__ == "__main__":
|
|||
device_flag = torch.cuda.current_device() if torch.cuda.is_available() else -1
|
||||
|
||||
query_rewriter = pipeline("text2text-generation", model="castorini/t5-base-canard")
|
||||
intent_classifier = pipeline("sentiment-analysis", model='./intent_classifier', device=device_flag)
|
||||
entity_extractor = spacy.load("./entity_extractor")
|
||||
offensive_classifier = pipeline("sentiment-analysis", model='./offensive_classifier', device=device_flag)
|
||||
ambig_classifier = pipeline("sentiment-analysis", model='./ambig_classifier', device=device_flag)
|
||||
intent_classifier = pipeline("sentiment-analysis", model='/models/intent_classifier', device=device_flag)
|
||||
entity_extractor = spacy.load("/models/entity_extractor")
|
||||
offensive_classifier = pipeline("sentiment-analysis", model='/models/offensive_classifier', device=device_flag)
|
||||
ambig_classifier = pipeline("sentiment-analysis", model='/models/ambig_classifier', device=device_flag)
|
||||
coref_resolver = spacy.load("en_coreference_web_trf")
|
||||
|
||||
nlu = NLU(query_rewriter, coref_resolver, intent_classifier, offensive_classifier, entity_extractor, ambig_classifier)
|
||||
|
||||
#load retriever and generator
|
||||
retriever = SentenceTransformer('./BigRetriever/').to(device)
|
||||
qa_generator = pipeline("text2text-generation", model="./train_qa", device=device_flag)
|
||||
summ_generator = pipeline("text2text-generation", model="./train_summ", device=device_flag)
|
||||
chat_generator = pipeline("text2text-generation", model="./train_chat", device=device_flag)
|
||||
amb_generator = pipeline("text2text-generation", model="./train_amb_gen", device=device_flag)
|
||||
retriever = SentenceTransformer('/models/BigRetriever/').to(device)
|
||||
qa_generator = pipeline("text2text-generation", model="/models/train_qa", device=device_flag)
|
||||
summ_generator = pipeline("text2text-generation", model="/models/train_summ", device=device_flag)
|
||||
chat_generator = pipeline("text2text-generation", model="/models/train_chat", device=device_flag)
|
||||
amb_generator = pipeline("text2text-generation", model="/models/train_amb_gen", device=device_flag)
|
||||
generators = {'qa': qa_generator,
|
||||
'chat': chat_generator,
|
||||
'amb': amb_generator,
|
||||
|
|
Loading…
Reference in New Issue