We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 9e7ae56 commit feb4827Copy full SHA for feb4827
src/diffusers/pipelines/pipeline_utils.py
@@ -504,6 +504,11 @@ def module_is_offloaded(module):
504
os.environ["PT_HPU_MAX_COMPOUND_OP_SIZE"] = "1"
505
logger.debug("Environment variable set: PT_HPU_MAX_COMPOUND_OP_SIZE=1")
506
507
+ if dtype == torch.bfloat16 and kwargs.pop("sdp_on_bf16", True):
508
+ if hasattr(torch._C, "_set_math_sdp_allow_fp16_bf16_reduction"):
509
+ torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
510
+ logger.debug("Enabled SDP with BF16 precision on HPU. To disable, please use `.to('hpu', sdp_on_bf16=False)`")
511
+
512
module_names, _ = self._get_signature_keys(self)
513
modules = [getattr(self, n, None) for n in module_names]
514
modules = [m for m in modules if isinstance(m, torch.nn.Module)]
0 commit comments