157 changes: 52 additions & 105 deletions tests/pipelines/test_pipelines_common.py
@@ -32,7 +32,6 @@
UNet2DConditionModel,
apply_faster_cache,
)
from diffusers.hooks import apply_group_offloading
from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook
from diffusers.hooks.first_block_cache import FirstBlockCacheConfig
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
@@ -2244,80 +2243,6 @@ def test_layerwise_casting_inference(self):
inputs = self.get_dummy_inputs(torch_device)
_ = pipe(**inputs)[0]

@require_torch_accelerator
def test_group_offloading_inference(self):
if not self.test_group_offloading:
return

def create_pipe():
torch.manual_seed(0)
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
return pipe

def enable_group_offload_on_component(pipe, group_offloading_kwargs):
# We intentionally don't test VAEs here, because some tests enable tiling on the VAE. If tiling is
# enabled and a forward pass is run while accelerator streams are used, the execution order of the
# layers is not traced correctly, which causes errors. To apply group offloading to a VAE, a warmup
# forward pass (even with small dummy inputs) is recommended.
for component_name in [
"text_encoder",
"text_encoder_2",
"text_encoder_3",
"transformer",
"unet",
"controlnet",
]:
if not hasattr(pipe, component_name):
continue
component = getattr(pipe, component_name)
if not getattr(component, "_supports_group_offloading", True):
continue
if hasattr(component, "enable_group_offload"):
# For diffusers ModelMixin implementations
component.enable_group_offload(torch.device(torch_device), **group_offloading_kwargs)
else:
# For other models not part of diffusers
apply_group_offloading(
component, onload_device=torch.device(torch_device), **group_offloading_kwargs
)
self.assertTrue(
all(
module._diffusers_hook.get_hook("group_offloading") is not None
for module in component.modules()
if hasattr(module, "_diffusers_hook")
)
)
for component_name in ["vae", "vqvae", "image_encoder"]:
component = getattr(pipe, component_name, None)
if isinstance(component, torch.nn.Module):
component.to(torch_device)

def run_forward(pipe):
torch.manual_seed(0)
inputs = self.get_dummy_inputs(torch_device)
return pipe(**inputs)[0]

pipe = create_pipe().to(torch_device)
output_without_group_offloading = run_forward(pipe)

pipe = create_pipe()
enable_group_offload_on_component(pipe, {"offload_type": "block_level", "num_blocks_per_group": 1})
output_with_group_offloading1 = run_forward(pipe)

pipe = create_pipe()
enable_group_offload_on_component(pipe, {"offload_type": "leaf_level"})
output_with_group_offloading2 = run_forward(pipe)

if torch.is_tensor(output_without_group_offloading):
output_without_group_offloading = output_without_group_offloading.detach().cpu().numpy()
output_with_group_offloading1 = output_with_group_offloading1.detach().cpu().numpy()
output_with_group_offloading2 = output_with_group_offloading2.detach().cpu().numpy()

self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-4))
self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-4))
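
The comment in the removed test above explains why VAE-like components are skipped: with tiling enabled and accelerator streams in use, the layer execution order is not traced correctly unless the model sees a warmup pass first. A minimal sketch of that recommendation, assuming a default-config AutoencoderKL, a CUDA device, an illustrative latent shape, and the use_stream option; none of these values come from the test itself:

import torch

from diffusers import AutoencoderKL
from diffusers.hooks import apply_group_offloading

# Illustrative standalone VAE; a real pipeline would reuse its own VAE instead.
vae = AutoencoderKL()

# Group-offload the VAE, onloading weights to the accelerator on demand.
# use_stream=True is the stream-based variant the comment refers to (assumed here).
apply_group_offloading(
    vae,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
)

# Warmup forward pass with small dummy latents so the layer execution order is
# traced before any tiled inference runs.
with torch.no_grad():
    _ = vae.decode(torch.randn(1, 4, 8, 8, device="cuda"))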

def test_torch_dtype_dict(self):
components = self.get_dummy_components()
if not components:
@@ -2364,7 +2289,7 @@ def test_pipeline_with_accelerator_device_map(self, expected_max_difference=1e-4
self.assertLess(max_diff, expected_max_difference)

@require_torch_accelerator
def test_pipeline_level_group_offloading_sanity_checks(self):
def test_group_offloading_sanity_checks(self):
components = self.get_dummy_components()
pipe: DiffusionPipeline = self.pipeline_class(**components)

@@ -2395,40 +2320,62 @@ def test_pipeline_level_group_offloading_sanity_checks(self):
self.assertTrue(torch.device(component_device).type == torch.device(offload_device).type)

@require_torch_accelerator
def test_pipeline_level_group_offloading_inference(self, expected_max_difference=1e-4):
components = self.get_dummy_components()
pipe: DiffusionPipeline = self.pipeline_class(**components)
def test_group_offloading_inference(self):
if not self.test_group_offloading:
pytest.skip("`test_group_offloading` is disabled hence skipping.")

for name, component in pipe.components.items():
if hasattr(component, "_supports_group_offloading"):
if not component._supports_group_offloading:
pytest.skip(f"{self.pipeline_class.__name__} is not suitable for this test.")
def create_pipe():
torch.manual_seed(0)
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
return pipe

# Regular inference.
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
torch.manual_seed(0)
inputs = self.get_dummy_inputs(torch_device)
inputs["generator"] = torch.manual_seed(0)
out = pipe(**inputs)[0]
def enable_group_offload_on_component(pipe, group_offloading_kwargs):
# We intentionally don't test VAEs here, because some tests enable tiling on the VAE. If tiling is
# enabled and a forward pass is run while accelerator streams are used, the execution order of the
# layers is not traced correctly, which causes errors. To apply group offloading to a VAE, a warmup
# forward pass (even with small dummy inputs) is recommended.
exclude_modules = []

Collaborator:
Can't we just set directly? Why nested loops? Or even set them via group_offloading_kwargs?

exclude_modules = ["vae", "vqvae", "image_encoder"]

Member Author:
Just readability convenience.

Passing exclude_modules = ["vae", "vqvae", "image_encoder"] is a bit misleading for the test, IMO. We would still log an info message if non-member components are passed, though:

However, I have simplified this in a commit:

exclude_modules = {"vae", "vqvae", "image_encoder"}
exclude_modules = list(exclude_modules & set(pipe.components.keys()))
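
For illustration, the intersection keeps only the names that actually exist in the pipeline's components, so a pipeline without a vqvae or image_encoder never has those names passed along. A minimal sketch with a hypothetical components mapping:

# Hypothetical components mapping for a pipeline with a VAE but no VQ-VAE or
# image encoder; only names that are actually present survive the intersection.
components = {"unet": object(), "vae": object(), "text_encoder": object()}
exclude_modules = {"vae", "vqvae", "image_encoder"}
exclude_modules = list(exclude_modules & set(components.keys()))
assert exclude_modules == ["vae"]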

for name, component in pipe.components.items():
for name in ["vae", "vqvae", "image_encoder"]:
exclude_modules.append(name)
pipe.enable_group_offload(
exclude_modules=exclude_modules, onload_device=torch_device, **group_offloading_kwargs
)
for component_name, component in pipe.components.items():
if component_name not in exclude_modules and isinstance(component, torch.nn.Module):
self.assertTrue(
all(
module._diffusers_hook.get_hook("group_offloading") is not None
for module in component.modules()
if hasattr(module, "_diffusers_hook")
)
)

pipe.to("cpu")
del pipe
def run_forward(pipe):
torch.manual_seed(0)
inputs = self.get_dummy_inputs(torch_device)
return pipe(**inputs)[0]

# Inference with offloading
pipe: DiffusionPipeline = self.pipeline_class(**components)
offload_device = "cpu"
pipe.enable_group_offload(
onload_device=torch_device,
offload_device=offload_device,
offload_type="leaf_level",
)
pipe.set_progress_bar_config(disable=None)
inputs["generator"] = torch.manual_seed(0)
out_offload = pipe(**inputs)[0]
pipe = create_pipe().to(torch_device)
output_without_group_offloading = run_forward(pipe)

max_diff = np.abs(to_np(out) - to_np(out_offload)).max()
self.assertLess(max_diff, expected_max_difference)
pipe = create_pipe()
enable_group_offload_on_component(pipe, {"offload_type": "block_level", "num_blocks_per_group": 1})
output_with_group_offloading1 = run_forward(pipe)

pipe = create_pipe()
enable_group_offload_on_component(pipe, {"offload_type": "leaf_level"})
output_with_group_offloading2 = run_forward(pipe)

if torch.is_tensor(output_without_group_offloading):
output_without_group_offloading = output_without_group_offloading.detach().cpu().numpy()
output_with_group_offloading1 = output_with_group_offloading1.detach().cpu().numpy()
output_with_group_offloading2 = output_with_group_offloading2.detach().cpu().numpy()

self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-4))
self.assertTrue(np.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-4))
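
Outside the test harness, the same pipeline-level API can be exercised directly. A minimal sketch, assuming a Stable-Diffusion-style pipeline; the model id, prompt, and step count are placeholders rather than values taken from this diff:

import torch

from diffusers import DiffusionPipeline

# Placeholder checkpoint; any pipeline whose components support group
# offloading follows the same call pattern as the test above.
pipe = DiffusionPipeline.from_pretrained("some/model-id")

# Keep parameters on CPU and onload one block group at a time onto the
# accelerator, mirroring the block_level configuration exercised above;
# the VAE is excluded from offloading, as in the test.
pipe.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    exclude_modules=["vae"],
)

image = pipe("a prompt", num_inference_steps=2).images[0]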


@is_staging_test