7 changes: 6 additions & 1 deletion keras_hub/src/utils/preset_utils.py
@@ -502,10 +502,15 @@ def jax_memory_cleanup(layer):
     # For jax, delete all previous allocated memory to avoid temporarily
     # duplicating variable allocations. torch and tensorflow have stateful
     # variable types and do not need this fix.
+    # Skip deletion for sharded arrays to avoid breaking references in distributed setups.
     if keras.config.backend() == "jax":
         for weight in layer.weights:
             if getattr(weight, "_value", None) is not None:
-                weight._value.delete()
+                # Do not delete sharded arrays, as they may be referenced in JAX's
+                # distributed computation graph and deletion can cause errors.
+                if not (hasattr(weight._value, 'sharding')
+                        and weight._value.sharding is not None):
+                    weight._value.delete()
 
 
 def set_dtype_in_config(config, dtype=None):
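
For context on the guard added above, here is a minimal, standalone sketch of how the `.sharding` check behaves on a plain JAX array. It is not part of the PR: it assumes a recent JAX release where `jax.Array` exposes a `sharding` attribute, and `should_delete` is a hypothetical helper named only for this illustration.

# Sketch only: mirrors the guard in jax_memory_cleanup under the assumptions above.
import jax.numpy as jnp


def should_delete(value):
    # Only delete buffers that do not report a sharding; sharded arrays may be
    # referenced elsewhere in a distributed setup.
    return not (hasattr(value, "sharding") and value.sharding is not None)


x = jnp.ones((4, 4))
# On recent JAX versions a jax.Array reports its placement via `.sharding`
# (typically SingleDeviceSharding on a single device), so the guard above
# would leave it in place rather than delete it.
print(x.sharding, should_delete(x))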