File tree: 1 file changed, +6 / −1 lines changed
lines changed Original file line number Diff line number Diff line change @@ -502,10 +502,15 @@ def jax_memory_cleanup(layer):
502502 # For jax, delete all previous allocated memory to avoid temporarily
503503 # duplicating variable allocations. torch and tensorflow have stateful
504504 # variable types and do not need this fix.
505+ # Skip deletion for sharded arrays to avoid breaking references in distributed setups.
505506 if keras .config .backend () == "jax" :
506507 for weight in layer .weights :
507508 if getattr (weight , "_value" , None ) is not None :
508- weight ._value .delete ()
509+ # Do not delete sharded arrays, as they may be referenced in JAX's
510+ # distributed computation graph and deletion can cause errors.
511+ if not (hasattr (weight ._value , 'sharding' )
512+ and weight ._value .sharding is not None ):
513+ weight ._value .delete ()
509514
510515
511516def set_dtype_in_config (config , dtype = None ):
You can’t perform that action at this time.
0 commit comments