diff --git a/deeppavlov/configs/squad/refactor_squad_torch_bert.json b/deeppavlov/configs/squad/refactor_squad_torch_bert.json new file mode 100644 index 0000000000..a1be8baa09 --- /dev/null +++ b/deeppavlov/configs/squad/refactor_squad_torch_bert.json @@ -0,0 +1,175 @@ +{ + "dataset_reader": { + "class_name": "squad_dataset_reader", + "data_path": "{DOWNLOADS_PATH}/squad/" + }, + "dataset_iterator": { + "class_name": "squad_iterator", + "seed": 1337, + "shuffle": true + }, + "chainer": { + "in": [ + "context_raw", + "question_raw" + ], + "in_y": [ + "ans_raw", + "ans_raw_start" + ], + "pipe": [ + { + "class_name": "torch_squad_transformers_preprocessor", + "vocab_file": "{TRANSFORMER}", + "do_lower_case": "{LOWERCASE}", + "max_seq_length": 384, + "return_tokens": true, + "in": [ + "question_raw", + "context_raw" + ], + "out": [ + "bert_features", + "subtokens" + ] + }, + { + "class_name": "squad_bert_mapping", + "do_lower_case": "{LOWERCASE}", + "in": [ + "context_raw", + "bert_features", + "subtokens" + ], + "out": [ + "subtok2chars", + "char2subtoks" + ] + }, + { + "class_name": "squad_bert_ans_preprocessor", + "do_lower_case": "{LOWERCASE}", + "in": [ + "ans_raw", + "ans_raw_start", + "char2subtoks" + ], + "out": [ + "ans", + "ans_start", + "ans_end" + ] + }, + { + "class_name": "torch_transformers_squad", + "pretrained_bert": "{TRANSFORMER}", + "save_path": "{MODEL_PATH}/model", + "load_path": "{MODEL_PATH}/model", + "optimizer": "AdamW", + "optimizer_parameters": { + "lr": 2e-05, + "weight_decay": 0.01, + "betas": [ + 0.9, + 0.999 + ], + "eps": 1e-06 + }, + "learning_rate_drop_patience": 2, + "learning_rate_drop_div": 2.0, + "in": [ + "bert_features" + ], + "in_y": [ + "ans_start", + "ans_end" + ], + "out": [ + "ans_start_predicted", + "ans_end_predicted", + "logits" + ] + }, + { + "class_name": "squad_bert_ans_postprocessor", + "in": [ + "ans_start_predicted", + "ans_end_predicted", + "context_raw", + "bert_features", + "subtok2chars", + "subtokens" + ], + "out": [ + "ans_predicted", + "ans_start_predicted", + "ans_end_predicted" + ] + } + ], + "out": [ + "ans_predicted", + "ans_start_predicted", + "logits" + ] + }, + "train": { + "show_examples": false, + "evaluation_targets": [ + "valid" + ], + "log_every_n_batches": 250, + "val_every_n_batches": 500, + "batch_size": 10, + "pytest_max_batches": 2, + "pytest_batch_size": 5, + "validation_patience": 10, + "metrics": [ + { + "name": "squad_v1_f1", + "inputs": [ + "ans", + "ans_predicted" + ] + }, + { + "name": "squad_v1_em", + "inputs": [ + "ans", + "ans_predicted" + ] + }, + { + "name": "squad_v2_f1", + "inputs": [ + "ans", + "ans_predicted" + ] + }, + { + "name": "squad_v2_em", + "inputs": [ + "ans", + "ans_predicted" + ] + } + ], + "class_name": "torch_trainer" + }, + "metadata": { + "variables": { + "LOWERCASE": true, + "TRANSFORMER": "allenai/longformer-base-4096", + "ROOT_PATH": "~/.deeppavlov", + "DOWNLOADS_PATH": "{ROOT_PATH}/downloads", + "MODELS_PATH": "{ROOT_PATH}/models", + "MODEL_PATH": "{MODELS_PATH}/squad_torch_bert/{TRANSFORMER}" + }, + "download": [ + { + "url": "http://files.deeppavlov.ai/v1/squad/squad_torch_bert.tar.gz", + "subdir": "{ROOT_PATH}/models" + } + ] + } +} diff --git a/deeppavlov/models/preprocessors/squad_preprocessor.py b/deeppavlov/models/preprocessors/squad_preprocessor.py index c342902d4f..628f30d9d9 100644 --- a/deeppavlov/models/preprocessors/squad_preprocessor.py +++ b/deeppavlov/models/preprocessors/squad_preprocessor.py @@ -404,7 +404,10 @@ def __call__(self, contexts, bert_features, *args, **kwargs): subtokens = args[0][batch_counter] else: subtokens = features.tokens - context_start = subtokens.index('[SEP]') + 1 + if '[SEP]' in subtokens: + context_start = subtokens.index('[SEP]') + 1 + else: + context_start = subtokens.index('') + 1 idx = 0 subtok2char: Dict[int, int] = {} char2subtok: Dict[int, int] = {} diff --git a/deeppavlov/models/torch_bert/torch_transformers_squad.py b/deeppavlov/models/torch_bert/torch_transformers_squad.py index 9506ce924e..af4a67ab20 100644 --- a/deeppavlov/models/torch_bert/torch_transformers_squad.py +++ b/deeppavlov/models/torch_bert/torch_transformers_squad.py @@ -123,12 +123,14 @@ def train_on_batch(self, features: List[InputFeatures], y_st: List[List[int]], y b_input_ids = torch.cat(input_ids, dim=0).to(self.device) b_input_masks = torch.cat(input_masks, dim=0).to(self.device) b_input_type_ids = torch.cat(input_type_ids, dim=0).to(self.device) + if any(x in self.pretrained_bert for x in ['roberta', 'distilbert', 'bart', 'longformer']): + b_input_type_ids = b_input_type_ids.unsqueeze(1).expand(-1, b_input_ids.shape[-1]) y_st = [x[0] for x in y_st] y_end = [x[0] for x in y_end] b_y_st = torch.from_numpy(np.array(y_st)).to(self.device) b_y_end = torch.from_numpy(np.array(y_end)).to(self.device) - + input_ = { 'input_ids': b_input_ids, 'attention_mask': b_input_masks, @@ -184,7 +186,9 @@ def __call__(self, features: List[InputFeatures]) -> Tuple[List[int], List[int], b_input_ids = torch.cat(input_ids, dim=0).to(self.device) b_input_masks = torch.cat(input_masks, dim=0).to(self.device) b_input_type_ids = torch.cat(input_type_ids, dim=0).to(self.device) - + if any(x in self.pretrained_bert for x in ['roberta', 'distilbert', 'bart', 'longformer']): + b_input_type_ids = b_input_type_ids.unsqueeze(1).expand(-1, b_input_ids.shape[-1]) + input_ = { 'input_ids': b_input_ids, 'attention_mask': b_input_masks,