Skip to content

Commit 27d766d

Browse files
[Pallas/interpreter] Move SharedMemory out of the main interpreter source file.
This is the first step in a list of changes that aim to make the data structures used by the interpreter re-usable. PiperOrigin-RevId: 817149786
1 parent c15d107 commit 27d766d

14 files changed

+1774
-1441
lines changed

jax/_src/pallas/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1443,7 +1443,7 @@ def _core_map_abstract_eval(*args, jaxpr, mesh, **kwargs):
14431443
effs = set()
14441444
if interpret:
14451445
try:
1446-
from jax._src.pallas.mosaic import interpret as mosaic_tpu_interpret # Avoid circular dependency.
1446+
from jax._src.pallas.mosaic.interpret import interpret_pallas_call as mosaic_tpu_interpret # Avoid circular dependency.
14471447
if isinstance(interpret, mosaic_tpu_interpret.InterpretParams):
14481448
effs = mosaic_tpu_interpret.get_interpret_effects()
14491449
except ImportError:
@@ -1619,7 +1619,7 @@ def _core_map_typecheck_rule(_, *in_atoms, jaxpr, mesh, **kwargs):
16191619
effs = set()
16201620
if interpret:
16211621
try:
1622-
from jax._src.pallas.mosaic import interpret as mosaic_tpu_interpret # Avoid circular dependency.
1622+
from jax._src.pallas.mosaic.interpret import interpret_pallas_call as mosaic_tpu_interpret # Avoid circular dependency.
16231623
if isinstance(interpret, mosaic_tpu_interpret.InterpretParams):
16241624
effs = mosaic_tpu_interpret.get_interpret_effects()
16251625
except ImportError:

jax/_src/pallas/mosaic/BUILD

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -230,19 +230,3 @@ py_library(
230230
"//jax/_src/pallas",
231231
],
232232
)
233-
234-
py_library(
235-
name = "interpret",
236-
srcs = ["interpret.py"],
237-
deps = [
238-
":core",
239-
":primitives",
240-
":verification",
241-
"//jax",
242-
"//jax/_src:core",
243-
"//jax/_src:source_info_util",
244-
"//jax/_src:util",
245-
"//jax/_src/lib",
246-
"//jax/_src/pallas",
247-
] + py_deps("numpy"),
248-
)
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright 2025 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Package for Pallas TPU Interpret Mode
16+
17+
load("@rules_python//python:defs.bzl", "py_library")
18+
load("//jaxlib:jax.bzl", "py_deps", "pytype_strict_library")
19+
20+
package(
21+
default_applicable_licenses = [],
22+
default_visibility = [
23+
"//jax:internal",
24+
],
25+
)
26+
27+
py_library(
28+
name = "interpret_pallas_call",
29+
srcs = ["interpret_pallas_call.py"],
30+
deps = [
31+
":race_detection_state",
32+
":shared_memory",
33+
":vector_clock",
34+
"//jax",
35+
"//jax/_src:api",
36+
"//jax/_src:callback",
37+
"//jax/_src:config",
38+
"//jax/_src:core",
39+
"//jax/_src:frozen_dict",
40+
"//jax/_src:lax",
41+
"//jax/_src:mlir",
42+
"//jax/_src:source_info_util",
43+
"//jax/_src:typing",
44+
"//jax/_src:util",
45+
"//jax/_src/pallas",
46+
"//jax/_src/pallas/mosaic:core",
47+
"//jax/_src/pallas/mosaic:primitives",
48+
"//jax/_src/pallas/mosaic:verification",
49+
] + py_deps("numpy"),
50+
)
51+
52+
pytype_strict_library(
53+
name = "vector_clock",
54+
srcs = ["vector_clock.py"],
55+
deps = py_deps("numpy"),
56+
)
57+
58+
pytype_strict_library(
59+
name = "shared_memory",
60+
srcs = ["shared_memory.py"],
61+
deps = [
62+
":race_detection_state",
63+
":vector_clock",
64+
"//jax",
65+
"//jax/_src:source_info_util",
66+
"//jax/_src:typing",
67+
"//jax/_src/pallas",
68+
"//jax/_src/pallas/mosaic:core",
69+
] + py_deps("numpy"),
70+
)
71+
72+
pytype_strict_library(
73+
name = "race_detection_state",
74+
srcs = ["race_detection_state.py"],
75+
deps = [
76+
":vector_clock",
77+
"//jax/_src:source_info_util",
78+
],
79+
)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2025 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.

0 commit comments

Comments
 (0)