1 change: 1 addition & 0 deletions tests/quantization/bnb/test_4bit.py

@@ -886,6 +886,7 @@ def quantization_config(self):
             components_to_quantize=["transformer", "text_encoder_2"],
         )
 
+    @require_bitsandbytes_version_greater("0.46.1")
     def test_torch_compile(self):
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
         super().test_torch_compile()
4 changes: 4 additions & 0 deletions tests/quantization/bnb/test_mixed_int8.py

@@ -847,6 +847,10 @@ def quantization_config(self):
             components_to_quantize=["transformer", "text_encoder_2"],
         )
 
+    @pytest.mark.xfail(
Member Author:

@matthewdouglas I get:

- 0/0: expected type of 'module._modules['norm_out']._modules['linear']._parameters['weight'].CB' to be a tensor type, ' but found <class 'NoneType'>

Member:

For the time being, I'm not sure we can do a whole lot to avoid this for bnb int8; at the very least, it is not a high priority for us. I'm not 100% sure, but you might be able to get around this by making a forward pass through the model prior to compiling it.
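
For reference, a minimal sketch of that suggested workaround (not part of this PR; the checkpoint name and quantization setup are illustrative, and whether the warm-up pass actually avoids the guard error is unverified):

import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig

# Hypothetical int8 pipeline setup, mirroring what the test helper builds.
quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_8bit",
    quant_kwargs={"load_in_8bit": True},
    components_to_quantize=["transformer", "text_encoder_2"],
)
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
).to("cuda")

# Warm-up: one eager forward pass so the bnb int8 layers populate their
# quantization state (e.g. weight.CB) before torch.compile traces the model.
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)

# Only compile, and run the compiled path, after the int8 state exists.
pipe.transformer.compile(fullgraph=True)
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)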

Member Author:

I will note this down then in the xfail reason.

+        reason="Test fails because of an offloading problem from Accelerate with confusion in hooks."
+        " Test passes without recompilation context manager."
+    )
     def test_torch_compile(self):
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
         super()._test_torch_compile(torch_dtype=torch.float16)
10 changes: 8 additions & 2 deletions tests/quantization/test_torch_compile_utils.py

@@ -56,12 +56,18 @@ def _test_torch_compile(self, torch_dtype=torch.bfloat16):
         pipe.transformer.compile(fullgraph=True)
 
         # small resolutions to ensure speedy execution.
-        pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
+        with torch._dynamo.config.patch(error_on_recompile=True):
+            pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
 
     def _test_torch_compile_with_cpu_offload(self, torch_dtype=torch.bfloat16):
         pipe = self._init_pipeline(self.quantization_config, torch_dtype)
         pipe.enable_model_cpu_offload()
-        pipe.transformer.compile()
+        # regional compilation is better for offloading.
+        # see: https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/
+        if getattr(pipe.transformer, "_repeated_blocks"):
+            pipe.transformer.compile_repeated_blocks(fullgraph=True)
+        else:
+            pipe.transformer.compile()
 
         # small resolutions to ensure speedy execution.
         pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)