1 change: 1 addition & 0 deletions .gitignore
@@ -127,3 +127,4 @@ dmypy.json

# Pyre type checker
.pyre/
tmp/
4 changes: 2 additions & 2 deletions DeBERTa/apps/sequence_classification.py
@@ -46,7 +46,7 @@ def forward(self, input_ids, type_ids=None, input_mask=None, labels=None, positi
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)

loss = 0
loss = torch.tensor(0).to(logits)
if labels is not None:
if self.num_labels ==1:
# regression task
@@ -68,4 +68,4 @@ def forward(self, input_ids, type_ids=None, input_mask=None, labels=None, positi
label_confidence = 1
loss = -((log_softmax(logits)*labels).sum(-1)*label_confidence).mean()

return (logits,loss)
return (loss, logits)
74 changes: 69 additions & 5 deletions DeBERTa/apps/train.py
@@ -24,9 +24,10 @@
from ..utils import *
from ..utils import xtqdm as tqdm
from .task_registry import tasks
from onnxruntime.capi.ort_trainer import ORTTrainer, IODescription, ModelDescription, LossScaler

from ..training import DistributedTrainer, initialize_distributed, batch_to, set_random_seed,kill_children
from ..data import DistributedBatchSampler, SequentialSampler, BatchSampler, AsyncDataLoader
from ..data import DistributedBatchSampler, SequentialSampler, BatchSampler, RandomSampler, AsyncDataLoader

def create_model(args, num_labels, model_class_fn):
# Prepare model
@@ -217,9 +218,63 @@ def run_predict(args, model, device, eval_data, prefix=None):
if predict_fn:
predict_fn(predicts, args.output_dir, name, prefix)

def deberta_model_description(args):
vocab_size = 30528
# set concrete input sizes to permit optimization
input_ids_desc = IODescription('input_ids', [args.train_batch_size, args.max_seq_length], torch.int32, num_classes=vocab_size)
type_ids_desc = IODescription('type_ids', [args.train_batch_size, args.max_seq_length], torch.int32) # num_classes=?
position_ids_desc = IODescription('position_ids', [args.train_batch_size, args.max_seq_length], torch.int32) # num_classes=?
input_mask_desc = IODescription('input_mask', [args.train_batch_size, args.max_seq_length], torch.int32) # num_classes=?
labels_desc = IODescription('labels', [args.train_batch_size, args.max_seq_length], torch.float32) # num_classes=?

loss_desc = IODescription('loss', [], torch.float32)
return ModelDescription([input_ids_desc, type_ids_desc, position_ids_desc, input_mask_desc, labels_desc], [loss_desc])

def create_ort_trainer(args, device, model):
# default initial settings: b1=0.9, b2=0.999, e=1e-6
def map_optimizer_attributes(name):
no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"]
no_decay = False
for no_decay_key in no_decay_keys:
if no_decay_key in name:
no_decay = True
break
if no_decay:
return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6}
else:
return {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6}

# we request ORTTrainer to create a LambOptimizer with given optimizer_attributes.
# train_step does forward, backward, and optimize step.
model = ORTTrainer(model, None, deberta_model_description(args), "LambOptimizer",
map_optimizer_attributes,
IODescription('Learning_Rate', [1,], torch.float32),
device,
_opset_version = 10)

return model

def run_onnx_training(args, model, device, train_data, prefix=None):
# runs training in ONNX
trainer = create_ort_trainer(args, device, model)
train_sampler = RandomSampler(len(train_data))
batch_sampler = BatchSampler(train_sampler, args.train_batch_size)
batch_sampler = DistributedBatchSampler(batch_sampler, rank=args.rank, world_size=args.world_size)
train_dataloader = DataLoader(train_data, batch_sampler=batch_sampler, num_workers=args.workers, pin_memory=True)
torch.cuda.empty_cache()
for step, batch in enumerate(AsyncDataLoader(train_dataloader, 100)):
#import pdb
#pdb.set_trace()
batch = batch_to(batch, device)
with torch.no_grad():
trainer.train_step(batch['input_ids'], batch['type_ids'], batch['position_ids'], batch['input_mask'], batch['labels'])
# conversion fails now with:
# site-packages/torch/onnx/utils.py:617: UserWarning: ONNX export failed on ATen operator broadcast_tensors
@ganik (Member, Author) commented on Aug 2, 2020:

broadcast_tensors and mse_loss are ops that are not currently implemented for ONNX export. To get unblocked, functional.py needs to be modified as per the comment below.

@ganik (Member, Author) commented on Aug 3, 2020:

The mse_loss implementation in https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L2682 uses two ops whose ONNX export is not implemented: broadcast_tensors() and the ATen mse_loss(). Working around this to get unblocked, I made the following patch:
#expanded_input, expanded_target = torch.broadcast_tensors(input, target)
expanded_input = input + torch.zeros(target.size())
expanded_target = target + torch.zeros(input.size())
#ret = torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
t = expanded_input - expanded_target
t = t * t
ret = torch.mean(t)
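For context, here is a minimal self-contained sketch of what an export-friendly MSE loss along the lines of this patch could look like. The function name, the explicit dtype/device arguments, and the `reduction` handling are illustrative additions and not part of the actual patch to functional.py, which simply hard-codes the mean reduction (PyTorch's default for mse_loss).

```python
import torch

def mse_loss_onnx_friendly(input, target, reduction='mean'):
    """MSE loss written with ops the ONNX exporter can handle (sketch)."""
    # Emulate torch.broadcast_tensors: adding a zero tensor of the other
    # operand's shape broadcasts both sides to their common shape.
    expanded_input = input + torch.zeros(target.size(), dtype=input.dtype, device=input.device)
    expanded_target = target + torch.zeros(input.size(), dtype=input.dtype, device=input.device)

    # Emulate torch._C._nn.mse_loss with plain elementwise ops.
    diff = expanded_input - expanded_target
    sq = diff * diff
    if reduction == 'mean':
        return torch.mean(sq)
    if reduction == 'sum':
        return torch.sum(sq)
    return sq  # reduction == 'none'
```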

# because torch.onnx.symbolic_opset10.broadcast_tensors does not exist

def main(args):
if not args.do_train and not args.do_eval and not args.do_predict:
raise ValueError("At least one of `do_train` or `do_eval` or `do_predict` must be True.")
if not args.do_train and not args.do_eval and not args.do_predict and not args.do_onnx:
raise ValueError("At least one of `do_train` or `do_eval` or `do_predict` or `do_onnx` must be True.")
os.makedirs(args.output_dir, exist_ok=True)
task_name = args.task_name.lower()
random.seed(args.seed)
@@ -236,11 +291,11 @@ def main(args):
test_data = processor.test_data(max_seq_len=args.max_seq_length)
logger.info(" Prediction batch size = %d", args.predict_batch_size)

if args.do_train:
if args.do_train or args.do_onnx:
train_data = processor.train_data(max_seq_len=args.max_seq_length, mask_gen = None, debug=args.debug)
model_class_fn = processor.get_model_class_fn()
model = create_model(args, len(label_list), model_class_fn)
if args.do_train:
if args.do_train or args.do_onnx:
with open(os.path.join(args.output_dir, 'model_config.json'), 'w', encoding='utf-8') as fs:
fs.write(model.config.to_json_string() + '\n')
logger.info("Model config {}".format(model.config))
@@ -257,6 +312,10 @@ def main(args):
if args.do_predict:
run_predict(args, model, device, test_data, prefix=args.tag)

# trains in ONNX
if args.do_onnx:
run_onnx_training(args, model, device, train_data, prefix=args.tag)

def build_argument_parser():
parser = argparse.ArgumentParser()

@@ -437,6 +496,11 @@ def build_argument_parser():
default=None,
type=str,
help="The path of pre-trained RoBERTa model")

parser.add_argument("--do_onnx",
default=False,
action='store_true',
help="Whether to run training in ONNX")
return parser

if __name__ == "__main__":
4 changes: 3 additions & 1 deletion DeBERTa/deberta/disentangled_attention.py
@@ -174,7 +174,9 @@ def linear(w,b,x):
if self.talking_head:
attention_scores = self.head_logits_proj(attention_scores.permute(0,2,3,1)).permute(0,3,1,2)

attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)
#attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)
nodex = torch.nn.Softmax(-1)
attention_probs = nodex(attention_scores + 10000.0*(attention_mask -1))
attention_probs = self.dropout(attention_probs)
if self.talking_head:
attention_probs = self.head_weights_proj(attention_probs.permute(0,2,3,1)).permute(0,3,1,2)
6 changes: 5 additions & 1 deletion DeBERTa/deberta/ops.py
@@ -115,7 +115,11 @@ def backward(ctx, grad_output):
else:
return grad_output, None

class StableDropout(torch.nn.Module):
class StableDropout(torch.nn.Dropout):
def __init__(self, drop_prob):
super().__init__()

class StableDropout1(torch.nn.Module):
""" Optimized dropout module for stabilizing the training

Args: