remove_models

This commit is contained in:
ahmed531998 2023-04-05 11:48:16 +02:00
parent 87374ec68b
commit ecd28e6151
6 changed files with 9 additions and 61266 deletions

View File

@ -1,37 +0,0 @@
{
"_name_or_path": "distilbert-base-uncased",
"activation": "gelu",
"architectures": [
"DistilBertForSequenceClassification"
],
"attention_dropout": 0.1,
"dim": 768,
"dropout": 0.1,
"hidden_dim": 3072,
"id2label": {
"0": "clear",
"1": "somewhat_clear",
"2": "somewhat_ambiguous",
"3": "ambiguous"
},
"initializer_range": 0.02,
"label2id": {
"ambiguous": 3,
"clear": 0,
"somewhat_ambiguous": 2,
"somewhat_clear": 1
},
"max_position_embeddings": 512,
"model_type": "distilbert",
"n_heads": 12,
"n_layers": 6,
"pad_token_id": 0,
"problem_type": "single_label_classification",
"qa_dropout": 0.1,
"seq_classif_dropout": 0.2,
"sinusoidal_pos_embds": false,
"tie_weights_": true,
"torch_dtype": "float32",
"transformers_version": "4.26.1",
"vocab_size": 30522
}

View File

@ -1,7 +0,0 @@
{
"cls_token": "[CLS]",
"mask_token": "[MASK]",
"pad_token": "[PAD]",
"sep_token": "[SEP]",
"unk_token": "[UNK]"
}

File diff suppressed because it is too large Load Diff

View File

@ -1,14 +0,0 @@
{
"cls_token": "[CLS]",
"do_lower_case": true,
"mask_token": "[MASK]",
"model_max_length": 512,
"name_or_path": "distilbert-base-uncased",
"pad_token": "[PAD]",
"sep_token": "[SEP]",
"special_tokens_map_file": null,
"strip_accents": null,
"tokenize_chinese_chars": true,
"tokenizer_class": "DistilBertTokenizer",
"unk_token": "[UNK]"
}

File diff suppressed because it is too large Load Diff

18
main.py
View File

@ -114,20 +114,20 @@ if __name__ == "__main__":
device_flag = torch.cuda.current_device() if torch.cuda.is_available() else -1
query_rewriter = pipeline("text2text-generation", model="castorini/t5-base-canard")
intent_classifier = pipeline("sentiment-analysis", model='./intent_classifier', device=device_flag)
entity_extractor = spacy.load("./entity_extractor")
offensive_classifier = pipeline("sentiment-analysis", model='./offensive_classifier', device=device_flag)
ambig_classifier = pipeline("sentiment-analysis", model='./ambig_classifier', device=device_flag)
intent_classifier = pipeline("sentiment-analysis", model='/models/intent_classifier', device=device_flag)
entity_extractor = spacy.load("/models/entity_extractor")
offensive_classifier = pipeline("sentiment-analysis", model='/models/offensive_classifier', device=device_flag)
ambig_classifier = pipeline("sentiment-analysis", model='/models/ambig_classifier', device=device_flag)
coref_resolver = spacy.load("en_coreference_web_trf")
nlu = NLU(query_rewriter, coref_resolver, intent_classifier, offensive_classifier, entity_extractor, ambig_classifier)
#load retriever and generator
retriever = SentenceTransformer('./BigRetriever/').to(device)
qa_generator = pipeline("text2text-generation", model="./train_qa", device=device_flag)
summ_generator = pipeline("text2text-generation", model="./train_summ", device=device_flag)
chat_generator = pipeline("text2text-generation", model="./train_chat", device=device_flag)
amb_generator = pipeline("text2text-generation", model="./train_amb_gen", device=device_flag)
retriever = SentenceTransformer('/models/BigRetriever/').to(device)
qa_generator = pipeline("text2text-generation", model="/models/train_qa", device=device_flag)
summ_generator = pipeline("text2text-generation", model="/models/train_summ", device=device_flag)
chat_generator = pipeline("text2text-generation", model="/models/train_chat", device=device_flag)
amb_generator = pipeline("text2text-generation", model="/models/train_amb_gen", device=device_flag)
generators = {'qa': qa_generator,
'chat': chat_generator,
'amb': amb_generator,