Description
Reproduction
Only happens when using liger_kernel.
Error:
[rank0]: Traceback (most recent call last):
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/scripts/dpo.py", line 159, in <module>
[rank0]: main(script_args, training_args, model_args)
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/scripts/dpo.py", line 138, in main
[rank0]: train_result = trainer.train(resume_from_checkpoint=checkpoint)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/transformers/trainer.py", line 2316, in train
[rank0]: return inner_training_loop(
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/transformers/trainer.py", line 2674, in _inner_training_loop
[rank0]: tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/transformers/trainer.py", line 4020, in training_step
[rank0]: loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/trl/trainer/dpo_trainer.py", line 1826, in compute_loss
[rank0]: loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/trl/trainer/dpo_trainer.py", line 1737, in get_batch_loss_metrics
[rank0]: model_output = self._compute_loss_liger(model, batch)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/trl/trainer/dpo_trainer.py", line 1450, in _compute_loss_liger
[rank0]: loss_output = self.dpo_loss_fn(
[rank0]: ^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/liger_kernel/chunked_loss/dpo_loss.py", line 213, in forward
[rank0]: return LigerFusedLinearDPOFunction.apply(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/autograd/function.py", line 575, in apply
[rank0]: return super().apply(*args, **kwargs) # type: ignore[misc]
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/liger_kernel/chunked_loss/dpo_loss.py", line 138, in forward
[rank0]: return super().forward(
[rank0]: ^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/liger_kernel/chunked_loss/fused_linear_preference.py", line 240, in forward
[rank0]: accumulate_chunk(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/liger_kernel/chunked_loss/fused_linear_preference.py", line 159, in accumulate_chunk
[rank0]: ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
[rank0]: return fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
[rank0]: return self._torchdynamo_orig_callable(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1164, in __call__
[rank0]: result = self._inner_convert(
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
[rank0]: return _compile(
[rank0]: ^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
[rank0]: guarded_code = compile_inner(code, one_graph, hooks, transform)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
[rank0]: return _compile_inner(code, one_graph, hooks, transform)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
[rank0]: return function(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
[rank0]: out_code = transform_code_object(code, transform)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
[rank0]: transformations(instructions, code_options)
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
[rank0]: return fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 662, in transform
[rank0]: tracer.run()
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
[rank0]: super().run()
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
[rank0]: while self.step():
[rank0]: ^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
[rank0]: self.dispatch_table[inst.opcode](self, inst)
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
[rank0]: return inner_fn(self, inst)
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2341, in CALL
[rank0]: self._call(inst)
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2335, in _call
[rank0]: self.call_function(fn, args, kwargs)
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
[rank0]: self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
[rank0]: return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
[rank0]: return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
[rank0]: return cls.inline_call(parent, func, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call
[rank0]: tracer.run()
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
[rank0]: while self.step():
[rank0]: ^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
[rank0]: self.dispatch_table[inst.opcode](self, inst)
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
[rank0]: return inner_fn(self, inst)
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2341, in CALL
[rank0]: self._call(inst)
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2335, in _call
[rank0]: self.call_function(fn, args, kwargs)
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
[rank0]: self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
[rank0]: return super().call_function(tx, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
[rank0]: return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
[rank0]: return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
[rank0]: return cls.inline_call(parent, func, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call
[rank0]: tracer.run()
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
[rank0]: while self.step():
[rank0]: ^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
[rank0]: self.dispatch_table[inst.opcode](self, inst)
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
[rank0]: return inner_fn(self, inst)
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1736, in CALL_FUNCTION_EX
[rank0]: self.call_function(fn, argsvars.items, kwargsvars)
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
[rank0]: self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
[rank0]: return super().call_function(tx, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
[rank0]: return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
[rank0]: return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
[rank0]: return cls.inline_call(parent, func, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call
[rank0]: tracer.run()
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
[rank0]: while self.step():
[rank0]: ^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
[rank0]: self.dispatch_table[inst.opcode](self, inst)
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
[rank0]: return inner_fn(self, inst)
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1736, in CALL_FUNCTION_EX
[rank0]: self.call_function(fn, argsvars.items, kwargsvars)
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
[rank0]: self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 858, in call_function
[rank0]: return self.func.call_function(tx, merged_args, merged_kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
[rank0]: return super().call_function(tx, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
[rank0]: return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
[rank0]: return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
[rank0]: return cls.inline_call(parent, func, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call
[rank0]: tracer.run()
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
[rank0]: while self.step():
[rank0]: ^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
[rank0]: self.dispatch_table[inst.opcode](self, inst)
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
[rank0]: return inner_fn(self, inst)
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2341, in CALL
[rank0]: self._call(inst)
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2335, in _call
[rank0]: self.call_function(fn, args, kwargs)
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
[rank0]: self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/variables/misc.py", line 1022, in call_function
[rank0]: return self.obj.call_method(tx, self.name, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/variables/misc.py", line 778, in call_method
[rank0]: .call_function(tx, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
[rank0]: return super().call_function(tx, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
[rank0]: return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
[rank0]: return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
[rank0]: return cls.inline_call(parent, func, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call
[rank0]: tracer.run()
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
[rank0]: while self.step():
[rank0]: ^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
[rank0]: self.dispatch_table[inst.opcode](self, inst)
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2282, in BINARY_OP
[rank0]: return _binary_op_lookup[inst.arg](self, inst)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 314, in impl
[rank0]: self.push(fn_var.call_function(self, self.popn(nargs), {}))
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/variables/builtin.py", line 1004, in call_function
[rank0]: return handler(tx, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/variables/builtin.py", line 980, in _handle_insert_op_in_graph
[rank0]: return wrap_fx_proxy(tx, proxy)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 2153, in wrap_fx_proxy
[rank0]: return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 2219, in wrap_fx_proxy_cls
[rank0]: return _wrap_fx_proxy(
[rank0]: ^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 2315, in _wrap_fx_proxy
[rank0]: example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2536, in get_fake_value
[rank0]: raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2471, in get_fake_value
[rank0]: ret_val = wrap_fake_exception(
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2017, in wrap_fake_exception
[rank0]: return fn()
[rank0]: ^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2472, in <lambda>
[rank0]: lambda: run_node(tx.output, node, args, kwargs, nnmodule)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2604, in run_node
[rank0]: raise RuntimeError(make_error_message(e)).with_traceback(
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2586, in run_node
[rank0]: return node.target(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_prims_common/wrappers.py", line 289, in _fn
[rank0]: result = fn(*args, is_out=(out is not None), **kwargs) # type: ignore[arg-type]
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_decomp/decompositions.py", line 4444, in matmul
[rank0]: return torch.ops.aten._unsafe_view(t1_folded.mv(t2), output_shape)
[rank0]: ^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/utils/_stats.py", line 21, in wrapper
[rank0]: return fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
[rank0]: return self.dispatch(func, types, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
[rank0]: return self._cached_dispatch_impl(func, types, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1386, in _cached_dispatch_impl
[rank0]: output = self._dispatch_impl(func, types, args, kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 2290, in _dispatch_impl
[rank0]: decomposition_table[func](*args, **kwargs)
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_prims_common/wrappers.py", line 291, in _fn
[rank0]: result = fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_decomp/decompositions.py", line 83, in inner
[rank0]: r = f(tree_map(increase_prec, args), **tree_map(increase_prec, kwargs))
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_decomp/decompositions.py", line 4336, in mv
[rank0]: torch._check(
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/__init__.py", line 1656, in _check
[rank0]: _check_with(RuntimeError, cond, message)
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/__init__.py", line 1638, in _check_with
[rank0]: raise error_type(message_evaluated)
[rank0]: torch._dynamo.exc.TorchRuntimeError: Failed running call_function ((GradTrackingTensor(lvl=1, value=
[rank0]: FakeTensor(..., device='cuda:0', size=(1, s2, 2048), dtype=torch.bfloat16)
[rank0]: ), GradTrackingTensor(lvl=1, value=
[rank0]: FakeTensor(..., device='cuda:0', size=(0,), dtype=torch.bfloat16,
[rank0]: requires_grad=True)
[rank0]: )), **{}):
[rank0]: size mismatch, got input (s2x2048), vec (0)
[rank0]: from user code:
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/liger_kernel/chunked_loss/fused_linear_preference.py", line 120, in fused_fwd_bwd
[rank0]: return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)(
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_functorch/apis.py", line 442, in wrapper
[rank0]: return eager_transforms.grad_and_value_impl(
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 48, in fn
[rank0]: return f(*args, **kwargs)
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/torch/_functorch/eager_transforms.py", line 1364, in grad_and_value_impl
[rank0]: output = func(*args, **kwargs)
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/liger_kernel/chunked_loss/fused_linear_preference.py", line 376, in _compute_loss
[rank0]: ) = LigerFusedLinearPreferenceBase.chunk_forward(
[rank0]: File "/home/ec2-user/mid-sft-dpo/alignment-handbook/handbook/lib/python3.11/site-packages/liger_kernel/chunked_loss/fused_linear_preference.py", line 288, in chunk_forward
[rank0]: logits_chunk = input_chunk @ weight.t()
[rank0]: Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
[rank0]: You can suppress this exception and fall back to eager by setting:
[rank0]: import torch._dynamo
[rank0]: torch._dynamo.config.suppress_errors = True
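The trace fails at `logits_chunk = input_chunk @ weight.t()`, where the input has size `(1, s2, 2048)` and the weight has size `(0,)`. A minimal sketch of the shape condition that operation needs is below. `check_linear_shapes` is a hypothetical helper written for illustration and is not part of liger_kernel, TRL, or PyTorch. It only describes the mismatch. The trace does not establish why the weight is empty.

```python
def check_linear_shapes(input_shape, weight_shape):
    """Return None if `input @ weight.T` is shape-compatible, else a message.

    Hypothetical helper: assumes the weight is laid out as (vocab, hidden),
    which is what `input_chunk @ weight.t()` implies.
    """
    if len(weight_shape) != 2:
        return (
            f"weight must be 2-D (vocab, hidden), "
            f"got {len(weight_shape)}-D shape {tuple(weight_shape)}"
        )
    hidden = input_shape[-1]
    if weight_shape[1] != hidden:
        return f"hidden size mismatch: input has {hidden}, weight has {weight_shape[1]}"
    return None


if __name__ == "__main__":
    # 7 stands in for the symbolic sequence length `s2` from the trace.
    print(check_linear_shapes((1, 7, 2048), (0,)))
    # The vocabulary size 32000 is an arbitrary illustrative value.
    print(check_linear_shapes((1, 7, 2048), (32000, 2048)))
```

Printing the weight's shape just before the loss call in the training script would show whether the weight is already empty before liger_kernel receives it.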
System Info
Versions:
torch = 2.6.0+cu126
trl = 0.25.1
liger_kernel = 0.6.3
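For anyone comparing environments, the versions above can be collected with the standard library. This is a small sketch, not an official TRL tool. `installed_version` is a hypothetical helper name, and the package names are the ones listed above.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package):
    """Return the installed version string, or 'not installed' if absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return "not installed"


if __name__ == "__main__":
    # Package names taken from the version list in this report.
    for name in ("torch", "trl", "liger_kernel"):
        print(f"{name} = {installed_version(name)}")
```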
Checklist
- I have checked that my issue isn't already filed (see open issues)
- I have included my system information
- Any code provided is minimal, complete, and reproducible (more on MREs)
- Any code provided is properly formatted in code blocks, (no screenshot, more on code blocks)
- Any traceback provided is complete