[WIP] ONNX conversion #6
base: master
Changes from 3 commits
@@ -127,3 +127,4 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+tmp/
@@ -24,9 +24,10 @@
 from ..utils import *
 from ..utils import xtqdm as tqdm
 from .task_registry import tasks
+from onnxruntime.capi.ort_trainer import ORTTrainer, IODescription, ModelDescription, LossScaler
 
 from ..training import DistributedTrainer, initialize_distributed, batch_to, set_random_seed,kill_children
-from ..data import DistributedBatchSampler, SequentialSampler, BatchSampler, AsyncDataLoader
+from ..data import DistributedBatchSampler, SequentialSampler, BatchSampler, RandomSampler, AsyncDataLoader
 
 def create_model(args, num_labels, model_class_fn):
   # Prepare model
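Reviewer note: `onnxruntime.capi.ort_trainer` is not available in every onnxruntime build (it requires a training-enabled build), so this top-level import may break the existing CLI for users on a stock wheel. One possible guard, as a sketch (the `None` fallback is my own suggestion, not part of this PR):

    # Sketch: degrade gracefully when a training-enabled onnxruntime
    # build is not installed; --do_onnx would then fail with a clear error.
    try:
        from onnxruntime.capi.ort_trainer import ORTTrainer, IODescription, ModelDescription, LossScaler
    except ImportError:
        ORTTrainer = None  # check this before calling run_onnx_training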
@@ -217,9 +218,63 @@ def run_predict(args, model, device, eval_data, prefix=None):
     if predict_fn:
       predict_fn(predicts, args.output_dir, name, prefix)
 
+def deberta_model_description(args):
+  vocab_size = 30528
+  # set concrete input sizes to permit optimization
+  input_ids_desc = IODescription('input_ids', [args.train_batch_size, args.max_seq_length], torch.int32, num_classes=vocab_size)
+  type_ids_desc = IODescription('type_ids', [args.train_batch_size, args.max_seq_length], torch.int32) # num_classes=?
+  position_ids_desc = IODescription('position_ids', [args.train_batch_size, args.max_seq_length], torch.int32) # num_classes=?
+  input_mask_desc = IODescription('input_mask', [args.train_batch_size, args.max_seq_length], torch.int32) # num_classes=?
+  labels_desc = IODescription('labels', [args.train_batch_size, args.max_seq_length], torch.float32) # num_classes=?
+
+  loss_desc = IODescription('loss', [], torch.float32)
+  return ModelDescription([input_ids_desc, type_ids_desc, position_ids_desc, input_mask_desc, labels_desc], [loss_desc])
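Reviewer note: pinning concrete sizes is the right call for graph optimization, as the comment says. If variable batch or sequence sizes are needed later, onnxruntime's ORTTrainer samples appear to also accept symbolic dimension names in IODescription; a sketch under that assumption:

    # Assumption: IODescription accepts string names for symbolic dims,
    # as in onnxruntime's BERT ort_trainer sample. This trades some graph
    # optimization for shape flexibility.
    input_ids_desc = IODescription('input_ids', ['batch', 'max_seq_length'],
                                   torch.int32, num_classes=vocab_size)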
+
+def create_ort_trainer(args, device, model):
+  # default initial settings: b1=0.9, b2=0.999, e=1e-6
+  def map_optimizer_attributes(name):
+    no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"]
+    no_decay = False
+    for no_decay_key in no_decay_keys:
+      if no_decay_key in name:
+        no_decay = True
+        break
+    if no_decay:
+      return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6}
+    else:
+      return {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6}
+
+  # we ask ORTTrainer to create a LambOptimizer with the given optimizer attributes;
+  # train_step then does the forward, backward, and optimizer steps.
+  model = ORTTrainer(model, None, deberta_model_description(args), "LambOptimizer",
+                     map_optimizer_attributes,
+                     IODescription('Learning_Rate', [1,], torch.float32),
+                     device,
+                     _opset_version=10)
+
+  return model
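For reference, the grouping above is the usual BERT-style no-decay split on parameter names; the names below are hypothetical, purely to illustrate how it resolves:

    # Hypothetical parameter names, to show how the grouping resolves:
    map_optimizer_attributes("encoder.layer.0.output.LayerNorm.weight")
    # -> {"alpha": 0.9, "beta": 0.999, "lambda": 0.0,  "epsilon": 1e-6}  (no decay)
    map_optimizer_attributes("encoder.layer.0.attention.self.query.weight")
    # -> {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6}  (decay 0.01)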
+
+def run_onnx_training(args, model, device, train_data, prefix=None):
+  # runs training through ONNX Runtime
+  trainer = create_ort_trainer(args, device, model)
+  train_sampler = RandomSampler(len(train_data))
+  batch_sampler = BatchSampler(train_sampler, args.train_batch_size)
+  batch_sampler = DistributedBatchSampler(batch_sampler, rank=args.rank, world_size=args.world_size)
+  train_dataloader = DataLoader(train_data, batch_sampler=batch_sampler, num_workers=args.workers, pin_memory=True)
+  torch.cuda.empty_cache()
+  for step, batch in enumerate(AsyncDataLoader(train_dataloader, 100)):
+    batch = batch_to(batch, device)
+    with torch.no_grad():
+      trainer.train_step(batch['input_ids'], batch['type_ids'], batch['position_ids'], batch['input_mask'], batch['labels'])
+    # conversion currently fails with:
+    #   site-packages/torch/onnx/utils.py:617: UserWarning: ONNX export failed on ATen operator broadcast_tensors
+    # because torch.onnx.symbolic_opset10.broadcast_tensors does not exist
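Reviewer note: the broadcast_tensors failure can be reproduced outside ORTTrainer with a plain export at the same opset; the sketch below uses placeholder inputs shaped per deberta_model_description. Separately, since the trainer is constructed with a Learning_Rate input description, train_step may expect a learning-rate tensor as an extra argument (an assumption based on onnxruntime's ORTTrainer samples).

    import torch

    # Minimal repro sketch; `model` and `args` as in this file, shapes mirror
    # deberta_model_description, and opset 10 matches _opset_version above.
    batch, seq = args.train_batch_size, args.max_seq_length
    dummy_inputs = (
        torch.zeros(batch, seq, dtype=torch.int32),    # input_ids
        torch.zeros(batch, seq, dtype=torch.int32),    # type_ids
        torch.zeros(batch, seq, dtype=torch.int32),    # position_ids
        torch.zeros(batch, seq, dtype=torch.int32),    # input_mask
        torch.zeros(batch, seq, dtype=torch.float32),  # labels
    )
    torch.onnx.export(model, dummy_inputs, "deberta_debug.onnx", opset_version=10)
    # If a newer torch/opset combination ships a broadcast_tensors symbolic,
    # raising opset_version here (and _opset_version above) may be enough;
    # otherwise the broadcasting call site in the model needs an explicit
    # rewrite (e.g. Tensor.expand) that the exporter can handle.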
+
 def main(args):
-  if not args.do_train and not args.do_eval and not args.do_predict:
-    raise ValueError("At least one of `do_train` or `do_eval` or `do_predict` must be True.")
+  if not args.do_train and not args.do_eval and not args.do_predict and not args.do_onnx:
+    raise ValueError("At least one of `do_train` or `do_eval` or `do_predict` or `do_onnx` must be True.")
   os.makedirs(args.output_dir, exist_ok=True)
   task_name = args.task_name.lower()
   random.seed(args.seed)
@@ -236,11 +291,11 @@ def main(args):
     test_data = processor.test_data(max_seq_len=args.max_seq_length)
     logger.info("  Prediction batch size = %d", args.predict_batch_size)
 
-  if args.do_train:
+  if args.do_train or args.do_onnx:
     train_data = processor.train_data(max_seq_len=args.max_seq_length, mask_gen = None, debug=args.debug)
   model_class_fn = processor.get_model_class_fn()
   model = create_model(args, len(label_list), model_class_fn)
-  if args.do_train:
+  if args.do_train or args.do_onnx:
     with open(os.path.join(args.output_dir, 'model_config.json'), 'w', encoding='utf-8') as fs:
       fs.write(model.config.to_json_string() + '\n')
   logger.info("Model config {}".format(model.config))
@@ -257,6 +312,10 @@ def main(args):
   if args.do_predict:
     run_predict(args, model, device, test_data, prefix=args.tag)
 
+  # trains through ONNX Runtime
+  if args.do_onnx:
+    run_onnx_training(args, model, device, train_data, prefix=args.tag)
+
 def build_argument_parser():
   parser = argparse.ArgumentParser()
@@ -437,6 +496,11 @@ def build_argument_parser():
                       default=None,
                       type=str,
                       help="The path of pre-trained RoBERTa model")
+
+  parser.add_argument("--do_onnx",
+                      default=False,
+                      action='store_true',
+                      help="Whether to run training in ONNX")
   return parser
 
 if __name__ == "__main__":
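A quick sanity check of the new flag, as a sketch (this assumes no other argument is declared required; otherwise pass those as well):

    parser = build_argument_parser()
    opts, _ = parser.parse_known_args(["--do_onnx"])
    assert opts.do_onnx is True  # store_true flag, defaults to False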