Skip to content

Commit 06cd1fe

Browse files
hawkinsp and jax authors
authored and committed
Move dtype canonicalization out of core.AbstractValue subclasses.
This is a strictly mechanical change that moves abstract value canonicalization out of the core.AbstractValue subclasses and into their callers. This makes it safe to manipulate non-canonical abstract values even inside an -x32 context. The callers to which canonicalization was added were: a) all callers of `ConcreteArray` inside the JAX tree; b) all callers of `ShapedArray` and `UnshapedArray` that were found to be passing non-canonical dtypes during a global presubmit. These were identified by adding an assertion that the dtype is in fact canonical and fixing all the resulting test failures. PiperOrigin-RevId: 414704700
1 parent 56f029f commit 06cd1fe

File tree

14 files changed

+56
-38
lines changed

14 files changed

+56
-38
lines changed

jax/_src/abstract_arrays.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,18 +49,23 @@ def zeros_like_array(x):
4949
np.complex64, np.complex128,
5050
np.longlong, np.intc}
5151

52+
def canonical_concrete_aval(val, weak_type=None):
53+
return ConcreteArray(dtypes.canonicalize_dtype(np.result_type(val)), val,
54+
weak_type=weak_type)
55+
5256
for t in array_types:
53-
core.pytype_aval_mappings[t] = ConcreteArray
57+
core.pytype_aval_mappings[t] = canonical_concrete_aval
5458
ad_util.jaxval_zeros_likers[t] = zeros_like_array
5559

5660
core.literalable_types.update(array_types)
5761

5862
def _zeros_like_python_scalar(t, x):
59-
aval = core.ShapedArray((), dtypes.python_scalar_dtypes[t], weak_type=True)
63+
dtype = dtypes.canonicalize_dtype(dtypes.python_scalar_dtypes[t])
64+
aval = core.ShapedArray((), dtype, weak_type=True)
6065
return ad_util.zeros_like_aval(aval)
6166

6267
def _make_concrete_python_scalar(t, x):
63-
return ConcreteArray(
68+
return canonical_concrete_aval(
6469
np.array(x, dtype=dtypes._scalar_type_to_dtype(t, x)),
6570
weak_type=True)
6671

jax/_src/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969
local_devices, process_index,
7070
process_count, host_id, host_ids,
7171
host_count, default_backend)
72-
from jax.core import ConcreteArray, ShapedArray, raise_to_shaped
72+
from jax.core import ShapedArray, raise_to_shaped
7373
from jax.interpreters import partial_eval as pe
7474
from jax.interpreters import xla
7575
from jax.interpreters import pxla

jax/_src/device_array.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from jax import core
2525
from jax._src.config import config
26+
from jax._src import abstract_arrays
2627
from jax._src import dtypes
2728
from jax._src import profiler
2829
from jax._src.lib import xla_client as xc
@@ -306,4 +307,4 @@ class DeletedBuffer(object): pass
306307
device_array_types: List[type] = [xc.Buffer, _DeviceArray]
307308
for _device_array in device_array_types:
308309
core.literalable_types.add(_device_array)
309-
core.pytype_aval_mappings[_device_array] = core.ConcreteArray
310+
core.pytype_aval_mappings[_device_array] = abstract_arrays.canonical_concrete_aval

jax/_src/lax/control_flow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1480,7 +1480,7 @@ def scan(f, init, xs, length=None):
14801480
return carry, stacked_y
14811481

14821482
x_shapes = [masking.padded_shape_as_value(x.shape[1:]) for x in xs_flat]
1483-
x_dtypes = [x.dtype for x in xs_flat]
1483+
x_dtypes = [dtypes.canonicalize_dtype(x.dtype) for x in xs_flat]
14841484
x_avals = tuple(_map(ShapedArray, x_shapes, x_dtypes))
14851485

14861486
def _create_jaxpr(init):
@@ -2038,7 +2038,7 @@ def masked(*args):
20382038
for new_c, c in zip(new_carry, carry)]
20392039
return [i + 1] + new_carry + ys
20402040

2041-
aval = ShapedArray((), dtypes.int_)
2041+
aval = ShapedArray((), dtypes.canonicalize_dtype(dtypes.int_))
20422042
const_avals, carry_avals, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
20432043
return _make_closed_jaxpr(masked, [aval] + const_avals + [aval] + carry_avals + x_avals)
20442044

jax/_src/lax/slicing.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1240,7 +1240,8 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers,
12401240
if core.symbolic_equal_dim(operand.shape[0], 0):
12411241
output_shape = _gather_shape_rule(
12421242
core.ShapedArray(operand.shape[1:], operand.dtype),
1243-
core.ShapedArray(indices.shape[1:], indices.dtype),
1243+
core.ShapedArray(indices.shape[1:],
1244+
dtypes.canonicalize_dtype(indices.dtype)),
12441245
dimension_numbers=dimension_numbers, slice_sizes=slice_sizes,
12451246
unique_indices=unique_indices, indices_are_sorted=indices_are_sorted,
12461247
mode=mode, fill_value=fill_value)
@@ -1456,8 +1457,8 @@ def _scatter_translation_rule(ctx, avals_in, avals_out, operand, indices,
14561457
if mode == GatherScatterMode.CLIP:
14571458
clip_fn = xla.lower_fun(_clamp_scatter_indices, multiple_results=False,
14581459
new_style=True)
1459-
indices, = clip_fn(ctx, avals_in, [indices_aval.update(dtype=np.int64)],
1460-
operand, indices, updates, dnums=dimension_numbers)
1460+
indices, = clip_fn(ctx, avals_in, None, operand, indices, updates,
1461+
dnums=dimension_numbers)
14611462

14621463
c = ctx.builder
14631464

@@ -1477,8 +1478,8 @@ def _scatter_add_translation_rule(
14771478
if mode == GatherScatterMode.CLIP:
14781479
clip_fn = xla.lower_fun(_clamp_scatter_indices, multiple_results=False,
14791480
new_style=True)
1480-
indices, = clip_fn(ctx, avals_in, [indices_aval.update(dtype=np.int64)],
1481-
operand, indices, updates, dnums=dimension_numbers)
1481+
indices, = clip_fn(ctx, avals_in, None, operand, indices, updates,
1482+
dnums=dimension_numbers)
14821483

14831484
dtype = operand_aval.dtype
14841485
scatter_dims = _scatter_dimensions_proto(

jax/_src/lax/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
6060
least_specialized = _max(map(type, avals),
6161
key=operator.attrgetter('array_abstraction_level'))
6262
if least_specialized is core.ConcreteArray:
63-
return core.ConcreteArray(prim.impl(*[x.val for x in avals], **kwargs),
64-
weak_type=weak_type)
63+
out = prim.impl(*[x.val for x in avals], **kwargs)
64+
return core.ConcreteArray(out.dtype, out, weak_type=weak_type)
6565
elif least_specialized is core.ShapedArray:
6666
return core.ShapedArray(shape_rule(*avals, **kwargs),
6767
dtype_rule(*avals, **kwargs), weak_type=weak_type,
@@ -81,7 +81,7 @@ def standard_multi_result_abstract_eval(
8181
weak_types = weak_type_rule(*avals, **kwargs)
8282
if least_specialized is core.ConcreteArray:
8383
out_vals = prim.impl(*[x.val for x in avals], **kwargs)
84-
return [core.ConcreteArray(val, weak_type=weak_type)
84+
return [core.ConcreteArray(val.dtype, val, weak_type=weak_type)
8585
for val, weak_type in safe_zip(out_vals, weak_types)]
8686
elif least_specialized is core.ShapedArray:
8787
out_shapes = shape_rule(*avals, **kwargs)

jax/core.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,7 +1041,7 @@ class UnshapedArray(AbstractValue):
10411041
array_abstraction_level = 2
10421042

10431043
def __init__(self, dtype, weak_type=False):
1044-
self.dtype = np.dtype(dtypes.canonicalize_dtype(dtype))
1044+
self.dtype = np.dtype(dtype)
10451045
self.weak_type = weak_type
10461046

10471047
def update(self, dtype=None, weak_type=None):
@@ -1183,19 +1183,20 @@ class ConcreteArray(ShapedArray):
11831183
__slots__ = ['val']
11841184
array_abstraction_level = 0
11851185

1186-
def __init__(self, val, weak_type=None):
1187-
super().__init__(np.shape(val), np.result_type(val),
1188-
weak_type=dtypes.is_weakly_typed(val) if weak_type is None else weak_type)
1186+
def __init__(self, dtype, val, weak_type=None):
1187+
super().__init__(
1188+
np.shape(val), dtype,
1189+
weak_type=dtypes.is_weakly_typed(val) if weak_type is None else weak_type)
11891190
# Note: canonicalized self.dtype doesn't necessarily match self.val
1191+
assert self.dtype == dtypes.canonicalize_dtype(np.result_type(val)), (val, dtype)
11901192
self.val = val
11911193
assert self.dtype != np.dtype('O'), val
11921194

1193-
def update(self, val=None, weak_type=None):
1194-
if val is None:
1195-
val = self.val
1196-
if weak_type is None:
1197-
weak_type = self.weak_type
1198-
return ConcreteArray(val, weak_type)
1195+
def update(self, dtype=None, val=None, weak_type=None):
1196+
dtype = self.dtype if dtype is None else dtype
1197+
val = self.val if val is None else val
1198+
weak_type = self.weak_type if weak_type is None else weak_type
1199+
return ConcreteArray(dtype, val, weak_type)
11991200

12001201
def __eq__(self, other):
12011202
if (type(self) is type(other) and self.dtype == other.dtype
@@ -1271,7 +1272,8 @@ def raise_to_shaped(aval: AbstractValue, weak_type=None):
12711272
Bot: lambda aval, _: aval,
12721273
UnshapedArray: lambda aval, _: aval,
12731274
ShapedArray: lambda aval, weak_type: ShapedArray(
1274-
aval.shape, aval.dtype, weak_type, aval.named_shape)
1275+
aval.shape, dtypes.canonicalize_dtype(aval.dtype), weak_type,
1276+
aval.named_shape)
12751277
}
12761278

12771279
### Operations on shapes and dimension sizes.

jax/experimental/jax2tf/call_tf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -351,14 +351,14 @@ def is_fully_known_shape(s):
351351
xla_comp_parameter_shapes = xla_comp.program_shape().parameter_shapes()
352352
found_parameter_avals = [
353353
core.ShapedArray(found_xla_shape.dimensions(),
354-
found_xla_shape.numpy_dtype())
354+
dtypes.canonicalize_dtype(found_xla_shape.numpy_dtype()))
355355
for found_xla_shape in xla_comp_parameter_shapes
356356
]
357357
# Add the captured_inputs to args_flat_sig_tf
358358
expected_args_flat_sig_tf = list(args_flat_sig_tf) + list(captured_inputs)
359359
expected_parameter_avals = [
360360
core.ShapedArray(tuple(arg_sig.shape.as_list()),
361-
arg_sig.dtype.as_numpy_dtype)
361+
dtypes.canonicalize_dtype(arg_sig.dtype.as_numpy_dtype))
362362
for arg_sig in expected_args_flat_sig_tf]
363363
if found_parameter_avals != expected_parameter_avals:
364364
msg = ("Compiled TensorFlow function has unexpected parameter types " +

jax/interpreters/masking.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,8 @@ def __init__(self, trace, val, polymorphic_shape):
459459

460460
@property
461461
def aval(self):
462-
return ShapedArray(self.polymorphic_shape, self.dtype)
462+
return ShapedArray(self.polymorphic_shape,
463+
dtypes.canonicalize_dtype(self.dtype))
463464

464465
@property
465466
def dtype(self):

jax/interpreters/pxla.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from jax._src.config import config
4646
from jax import core
4747
from jax import linear_util as lu
48+
from jax._src import abstract_arrays
4849
from jax._src.abstract_arrays import array_types
4950
from jax.core import ConcreteArray, ShapedArray
5051
from jax._src import device_array
@@ -740,7 +741,7 @@ def _register_handlers_for_sharded_device_array(sda):
740741
shard_arg_handlers[sda] = _shard_sharded_device_array_slow_path
741742
xla.register_constant_handler(sda, _sharded_device_array_constant_handler)
742743

743-
core.pytype_aval_mappings[sda] = ConcreteArray
744+
core.pytype_aval_mappings[sda] = abstract_arrays.canonical_concrete_aval
744745
dispatch.device_put_handlers[sda] = dispatch._device_put_array
745746
xla.pytype_aval_mappings[sda] = op.attrgetter("aval")
746747
xla.canonicalize_dtype_handlers[sda] = identity

0 commit comments

Comments
 (0)