Skip to content

Commit 9f6f33b

Browse files
For sharded weights let's not delete explicitly
1 parent 6f72208 commit 9f6f33b

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

keras_hub/src/utils/preset_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -502,10 +502,15 @@ def jax_memory_cleanup(layer):
502502
# For jax, delete all previous allocated memory to avoid temporarily
503503
# duplicating variable allocations. torch and tensorflow have stateful
504504
# variable types and do not need this fix.
505+
# Skip deletion for sharded arrays to avoid breaking references in distributed setups.
505506
if keras.config.backend() == "jax":
506507
for weight in layer.weights:
507508
if getattr(weight, "_value", None) is not None:
508-
weight._value.delete()
509+
# Do not delete sharded arrays, as they may be referenced in JAX's
510+
# distributed computation graph and deletion can cause errors.
511+
if not (hasattr(weight._value, 'sharding')
512+
and weight._value.sharding is not None):
513+
weight._value.delete()
509514

510515

511516
def set_dtype_in_config(config, dtype=None):

0 commit comments

Comments
 (0)