Fix CI ImportError: FlashAttention2 and decorator order for all parameterized tests #4176
Conversation
After installing, the slow test now fails with RuntimeError: FlashAttention only supports Ampere GPUs or newer. Traceback:
___ GRPOTrainerSlowTester.test_vlm_training_0_HuggingFaceTB_SmolVLM_Instruct ___
a = (<tests.slow.test_grpo_slow.GRPOTrainerSlowTester testMethod=test_vlm_training_0_HuggingFaceTB_SmolVLM_Instruct>,)
kw = {}
@wraps(func)
def standalone_func(*a, **kw):
> return func(*(a + p.args), **p.kwargs, **kw)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/parameterized/parameterized.py:620:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/slow/test_grpo_slow.py:317: in test_vlm_training
trainer.train()
.venv/lib/python3.11/site-packages/transformers/trainer.py:2328: in train
return inner_training_loop(
.venv/lib/python3.11/site-packages/transformers/trainer.py:2672: in _inner_training_loop
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/transformers/trainer.py:4003: in training_step
inputs = self._prepare_inputs(inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
trl/extras/profiling.py:98: in wrapper
return func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
trl/trainer/grpo_trainer.py:1008: in _prepare_inputs
generation_batch = self._generate_and_score_completions(generation_batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
trl/trainer/grpo_trainer.py:1418: in _generate_and_score_completions
) = self._generate(prompts, images)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
trl/trainer/grpo_trainer.py:1333: in _generate
prompt_completion_ids = unwrapped_model.generate(
.venv/lib/python3.11/site-packages/peft/peft_model.py:1973: in generate
outputs = self.base_model.generate(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py:120: in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/transformers/generation/utils.py:2539: in generate
result = self._sample(
.venv/lib/python3.11/site-packages/transformers/generation/utils.py:2867: in _sample
outputs = self(**model_inputs, return_dict=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1773: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1784: in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/transformers/utils/generic.py:940: in wrapper
output = func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/transformers/models/idefics3/modeling_idefics3.py:973: in forward
outputs = self.model(
.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1773: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1784: in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/transformers/utils/generic.py:940: in wrapper
output = func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/transformers/models/idefics3/modeling_idefics3.py:795: in forward
image_hidden_states = self.get_image_features(pixel_values, pixel_attention_mask)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/transformers/models/idefics3/modeling_idefics3.py:722: in get_image_features
image_hidden_states = self.vision_model(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1773: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1784: in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/transformers/models/idefics3/modeling_idefics3.py:572: in forward
encoder_outputs = self.encoder(
.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1773: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1784: in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/transformers/models/idefics3/modeling_idefics3.py:397: in forward
layer_outputs = encoder_layer(
.venv/lib/python3.11/site-packages/transformers/modeling_layers.py:94: in __call__
return super().__call__(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1773: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1784: in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/transformers/models/idefics3/modeling_idefics3.py:317: in forward
hidden_states, attn_weights = self.self_attn(
.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1773: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1784: in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/transformers/models/idefics3/modeling_idefics3.py:243: in forward
attn_output, attn_weights = attention_interface(
.venv/lib/python3.11/site-packages/transformers/integrations/flash_attention.py:66: in flash_attention_forward
attn_output = _flash_attention_forward(
.venv/lib/python3.11/site-packages/transformers/modeling_flash_attention_utils.py:664: in _flash_attention_forward
out = flash_fn(query_states, key_states, value_states, **flash_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py:1196: in flash_attn_func
return FlashAttnFunc.apply(
.venv/lib/python3.11/site-packages/torch/autograd/function.py:576: in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py:834: in forward
out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
.venv/lib/python3.11/site-packages/torch/_ops.py:1243: in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/torch/_library/autograd.py:111: in autograd_impl
result = forward_no_grad(*args, Metadata(keyset, keyword_only_args))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/torch/_library/autograd.py:40: in forward_no_grad
result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/torch/_ops.py:836: in redispatch
return self._handle.redispatch_boxed(keyset, *args, **kwargs) # type: ignore[return-value]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/torch/_library/custom_ops.py:344: in backend_impl
result = self._backend_fns[device_type](*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/torch/_compile.py:53: in inner
return disable_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py:929: in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/torch/_library/custom_ops.py:377: in wrapped_fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
q = tensor([[[[-8.9062e-01, -3.3594e+00, -1.1094e+00, ..., 2.1094e+00,
-3.2812e+00, 3.8867e-01],
[...12e+00, 1.0781e+00, ..., 2.8125e-01,
-2.3242e-01, 3.8438e+00]]]], device='cuda:0', dtype=torch.bfloat16)
k = tensor([[[[-1.2891e+00, -3.0625e+00, -1.6094e+00, ..., 2.7832e-02,
3.4180e-01, 2.0000e+00],
[...61e-01, -1.5332e-01, ..., 2.2559e-01,
5.8838e-02, -6.1719e-01]]]], device='cuda:0', dtype=torch.bfloat16)
v = tensor([[[[ 1.4648e-02, -6.9824e-02, -4.4678e-02, ..., -4.0234e-01,
1.6113e-02, -4.2480e-02],
[...41e+00, 9.9609e-02, ..., -4.7363e-02,
-2.5146e-02, -3.8818e-02]]]], device='cuda:0', dtype=torch.bfloat16)
dropout_p = 0.0, softmax_scale = 0.11785113019775792, causal = False
window_size_left = -1, window_size_right = -1, softcap = 0.0
alibi_slopes = None, return_softmax = False
@_torch_custom_op_wrapper("flash_attn::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _flash_attn_forward(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
dropout_p: float,
softmax_scale: float,
causal: bool,
window_size_left: int,
window_size_right: int,
softcap: float,
alibi_slopes: Optional[torch.Tensor],
return_softmax: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
> out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd(
q,
k,
v,
None,
alibi_slopes,
dropout_p,
softmax_scale,
causal,
window_size_left,
window_size_right,
softcap,
return_softmax,
None,
)
E RuntimeError: FlashAttention only supports Ampere GPUs or newer.
.venv/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py:91: RuntimeError
I think the order of the decorators was not right, so the skips were ignored.
Now the test is properly skipped.
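For illustration, here is a minimal standalone sketch (not the actual TRL test; `has_flash_attn`, the class, and the test names are made up, and `unittest.skipUnless` stands in for `@require_flash_attn`). With the skip decorator written below `@parameterized.expand`, the skip wraps the test function first and every generated test inherits it; with the order swapped, the generated tests are injected before the skip is applied and run unguarded:

```python
# Minimal sketch with assumed names -- not the TRL test itself.
# unittest.skipUnless stands in for @require_flash_attn, and has_flash_attn
# is a fake condition standing in for "flash-attn is usable on this machine".
import unittest

from parameterized import parameterized

has_flash_attn = False  # pretend flash-attn is unavailable


class DecoratorOrderSketch(unittest.TestCase):
    # Correct order: the skip is applied to the function first, so the tests
    # generated by parameterized.expand (test_vlm_like_0_model_a, ...) copy
    # the skip marker via functools.wraps and are reported as skipped.
    @parameterized.expand([("model_a",), ("model_b",)])
    @unittest.skipUnless(has_flash_attn, "flash-attn required")
    def test_vlm_like(self, model_name):
        self.fail("should never run when flash-attn is unavailable")

    # Wrong order (what this PR fixes), shown only as a comment:
    #
    #   @unittest.skipUnless(has_flash_attn, "flash-attn required")
    #   @parameterized.expand([("model_a",), ("model_b",)])
    #   def test_vlm_like(self, model_name): ...
    #
    # Decorators apply bottom-up, so expand injects the generated tests into
    # the class before the outer skip runs; the skip never reaches them and
    # they execute, which is how the traceback above was produced.


if __name__ == "__main__":
    unittest.main(verbosity=2)  # both generated tests are reported as skipped
```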
("HuggingFaceTB/SmolVLM-Instruct",), # Only test the smaller model to avoid OOM | ||
] | ||
) | ||
@require_flash_attn |
ok so @require_X must be after @parameterized.expand. For consistency, can you also modify:
- test_training_vlm_and_vllm (it's deactivated but still, it's good to have it right)
- test_training_with_judge
- test_training_with_transformers_paged
- test_xpo_trainer_judge_training
Done, also for: test_dpo_trainer_with_liger
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
lgtm
Fix:
This PR moves the decorators @require_flash_attn, @require_bitsandbytes, and @require_peft to be applied after the @parameterized.expand decorator for the test_vlm_training method, so they are not ignored and the test skip is properly applied. Additionally, this PR fixes the decorator order for all parameterized tests.
Fix #4175.
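As a quick sanity check of the ordering rule (generic names, assumed purely for illustration; this is not part of the PR), note that the standalone_func wrapper seen at the top of the traceback is created with functools.wraps, which copies unittest's __unittest_skip__ marker from the wrapped function onto each generated test, but only when the skip decorator was applied before @parameterized.expand:

```python
# Sketch with assumed names: confirm that the skip marker set by a skip-style
# decorator propagates to the tests that parameterized.expand generates.
import unittest

from parameterized import parameterized


class SkipPropagationCheck(unittest.TestCase):
    @parameterized.expand([("a",), ("b",)])
    @unittest.skip("demo: always skip")
    def test_generated(self, name):
        pass


for attr in sorted(dir(SkipPropagationCheck)):
    if attr.startswith("test_generated_"):
        fn = getattr(SkipPropagationCheck, attr)
        # Each generated test should report True here; with the two decorators
        # swapped, the marker is missing and the tests would actually run.
        print(attr, getattr(fn, "__unittest_skip__", False))
```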