Skip to content

Commit 403977d

Browse files
committed
Reverts f44bc6e
PiperOrigin-RevId: 819903461
1 parent b1ce4c5 commit 403977d

File tree

5 files changed

+3
-19
lines changed

5 files changed

+3
-19
lines changed

jax/_src/lib/mlir/dialects/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,8 @@
5656
# TODO(joelwee): Remove this once jaxlib 0.8 is the minimum.
5757
try:
5858
from jaxlib.mlir.dialects import mpmd
59-
from jaxlib.mlir.dialects import ub
6059
except ImportError:
6160
mpmd: Any = None # type: ignore[no-redef]
62-
ub: Any = None # type: ignore[no-redef]
6361
from jaxlib.mlir.dialects import sdy
6462

6563
# Alias that is set up to abstract away the transition from MHLO to StableHLO.

jaxlib/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ pytype_strict_library(
8686
"//jaxlib/mlir:sdy_dialect",
8787
"//jaxlib/mlir:sparse_tensor_dialect",
8888
"//jaxlib/mlir:stablehlo_dialect",
89-
"//jaxlib/mlir:ub_dialect",
9089
"//jaxlib/mlir:vector_dialect",
9190
"//jaxlib/mlir/_mlir_libs:_jax_mlir_ext",
9291
"//jaxlib/mosaic",

jaxlib/mlir/BUILD.bazel

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -87,19 +87,6 @@ symlink_inputs(
8787
],
8888
)
8989

90-
symlink_inputs(
91-
name = "ub_dialect",
92-
rule = py_library,
93-
symlinked_inputs = {"srcs": {
94-
"dialects": ["@llvm-project//mlir/python:UbPyFiles"],
95-
}},
96-
deps = [
97-
":core",
98-
":ir",
99-
":mlir",
100-
],
101-
)
102-
10390
symlink_inputs(
10491
name = "math_dialect",
10592
rule = py_library,

jaxlib/tools/build_wheel.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,6 @@ def prepare_wheel(wheel_sources_path: pathlib.Path, *, cpu, wheel_sources):
311311
f"{source_file_prefix}jaxlib/mlir/dialects/sdy.py",
312312
f"{source_file_prefix}jaxlib/mlir/dialects/sparse_tensor.py",
313313
f"{source_file_prefix}jaxlib/mlir/dialects/stablehlo.py",
314-
f"{source_file_prefix}jaxlib/mlir/dialects/ub.py",
315314
f"{source_file_prefix}jaxlib/mlir/dialects/vector.py",
316315
f"{source_file_prefix}jaxlib/mlir/dialects/nvgpu.py",
317316
f"{source_file_prefix}jaxlib/mlir/dialects/nvvm.py",

tests/mosaic/gpu_dialect_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@
2424
from jax._src.interpreters import mlir as mlir_interpreter
2525
from jax._src.lib.mlir import ir
2626
from jax._src.lib.mlir.dialects import arith
27+
from jax._src.lib.mlir.dialects import builtin
2728
from jax._src.lib.mlir.dialects import gpu
2829
from jax._src.lib.mlir.dialects import llvm
2930
from jax._src.lib.mlir.dialects import memref
3031
from jax._src.lib.mlir.dialects import nvvm
3132
from jax._src.lib.mlir.dialects import scf
32-
from jax._src.lib.mlir.dialects import ub
3333
from jax._src.lib.mlir.dialects import vector
3434
from jax.experimental.mosaic import gpu as mgpu
3535
from jax.experimental.mosaic.gpu import dialect_lowering as lowering
@@ -86,7 +86,8 @@ def workgroup_ptr_ty() -> ir.Type:
8686

8787
def undefs(*tys: ir.Type) -> list[ir.Value]:
8888
"""Returns a list of undefined values of the given types."""
89-
return [ub.poison(ty) for ty in tys]
89+
# TODO(allanrenucci): Use `ub.poison` once Python bindings are available.
90+
return [builtin.unrealized_conversion_cast([ty], []) for ty in tys]
9091

9192

9293
class MosaicGpuTest(parameterized.TestCase):

0 commit comments

Comments
 (0)