Skip to content

Commit f44bc6e

Browse files
allanrenucciGoogle-ML-Automation
authored andcommitted
[Mosaic GPU][NFC] Create undefined values in test with ub.poison.
`ub.PoisonOp` and `llvm.UndefOp` serve a similar purpose in tests. However, `llvm.UndefOp` is more restrictive as it can only creates an undefined value of the specified **LLVM IR** dialect type. PiperOrigin-RevId: 818583066
1 parent eadee4f commit f44bc6e

File tree

5 files changed

+19
-3
lines changed

5 files changed

+19
-3
lines changed

jax/_src/lib/mlir/dialects/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,10 @@
5656
# TODO(joelwee): Remove this once jaxlib 0.8 is the minimum.
5757
try:
5858
from jaxlib.mlir.dialects import mpmd
59+
from jaxlib.mlir.dialects import ub
5960
except ImportError:
6061
mpmd: Any = None # type: ignore[no-redef]
62+
ub: Any = None # type: ignore[no-redef]
6163
from jaxlib.mlir.dialects import sdy
6264

6365
# Alias that is set up to abstract away the transition from MHLO to StableHLO.

jaxlib/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ pytype_strict_library(
8686
"//jaxlib/mlir:sdy_dialect",
8787
"//jaxlib/mlir:sparse_tensor_dialect",
8888
"//jaxlib/mlir:stablehlo_dialect",
89+
"//jaxlib/mlir:ub_dialect",
8990
"//jaxlib/mlir:vector_dialect",
9091
"//jaxlib/mlir/_mlir_libs:_jax_mlir_ext",
9192
"//jaxlib/mosaic",

jaxlib/mlir/BUILD.bazel

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,19 @@ symlink_inputs(
8787
],
8888
)
8989

90+
symlink_inputs(
91+
name = "ub_dialect",
92+
rule = py_library,
93+
symlinked_inputs = {"srcs": {
94+
"dialects": ["@llvm-project//mlir/python:UbPyFiles"],
95+
}},
96+
deps = [
97+
":core",
98+
":ir",
99+
":mlir",
100+
],
101+
)
102+
90103
symlink_inputs(
91104
name = "math_dialect",
92105
rule = py_library,

jaxlib/tools/build_wheel.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,7 @@ def prepare_wheel(wheel_sources_path: pathlib.Path, *, cpu, wheel_sources):
311311
f"{source_file_prefix}jaxlib/mlir/dialects/sdy.py",
312312
f"{source_file_prefix}jaxlib/mlir/dialects/sparse_tensor.py",
313313
f"{source_file_prefix}jaxlib/mlir/dialects/stablehlo.py",
314+
f"{source_file_prefix}jaxlib/mlir/dialects/ub.py",
314315
f"{source_file_prefix}jaxlib/mlir/dialects/vector.py",
315316
f"{source_file_prefix}jaxlib/mlir/dialects/nvgpu.py",
316317
f"{source_file_prefix}jaxlib/mlir/dialects/nvvm.py",

tests/mosaic/gpu_dialect_test.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@
2424
from jax._src.interpreters import mlir as mlir_interpreter
2525
from jax._src.lib.mlir import ir
2626
from jax._src.lib.mlir.dialects import arith
27-
from jax._src.lib.mlir.dialects import builtin
2827
from jax._src.lib.mlir.dialects import gpu
2928
from jax._src.lib.mlir.dialects import llvm
3029
from jax._src.lib.mlir.dialects import memref
3130
from jax._src.lib.mlir.dialects import nvvm
3231
from jax._src.lib.mlir.dialects import scf
32+
from jax._src.lib.mlir.dialects import ub
3333
from jax._src.lib.mlir.dialects import vector
3434
from jax.experimental.mosaic import gpu as mgpu
3535
from jax.experimental.mosaic.gpu import dialect_lowering as lowering
@@ -86,8 +86,7 @@ def workgroup_ptr_ty() -> ir.Type:
8686

8787
def undefs(*tys: ir.Type) -> list[ir.Value]:
8888
"""Returns a list of undefined values of the given types."""
89-
# TODO(allanrenucci): Use `ub.poison` once Python bindings are available.
90-
return [builtin.unrealized_conversion_cast([ty], []) for ty in tys]
89+
return [ub.poison(ty) for ty in tys]
9190

9291

9392
class MosaicGpuTest(parameterized.TestCase):

0 commit comments

Comments
 (0)