Skip to content

Commit 76d310e

Browse files
ChromeHeartsThe sparsecore Authors
authored andcommitted
Add unstack_and_unshard for SparseCore
PiperOrigin-RevId: 708384192
1 parent 48edd8d commit 76d310e

File tree

21 files changed

+1367
-23
lines changed

21 files changed

+1367
-23
lines changed

.bazelrc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,11 @@ build:clang --copt=-Wno-gnu-offsetof-extensions
9898
# Disable clang extention that rejects unknown arguments.
9999
build:clang --copt=-Qunused-arguments
100100

101+
##############################################################################
102+
# Test configurations.
103+
##############################################################################
104+
test:cpu --test_env=JAX_PLATFORMS=cpu --test_tag_filters=cpu
105+
101106
#############################################################################
102107
# Some configs to make getting some forms of debug builds. In general, the
103108
# codebase is only regularly built with optimizations. Use 'debug_symbols' to
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
name: Build and test
2+
3+
on:
4+
# Only run workflow on pushes to main (includes PR merge), and on
5+
# opened pull-requests.
6+
push:
7+
branches:
8+
- main
9+
pull_request:
10+
11+
jobs:
12+
build_and_test:
13+
runs-on: ubuntu-24.04
14+
strategy:
15+
matrix:
16+
python-version: ["3.10"]
17+
18+
steps:
19+
- uses: actions/checkout@v4
20+
21+
- name: Set up Python ${{ matrix.python-version }}
22+
uses: actions/setup-python@v5
23+
with:
24+
python-version: ${{ matrix.python-version }}
25+
cache: 'pip'
26+
27+
- name: Display Python version
28+
run: python -c "import sys; print(sys.version)"
29+
30+
- name: Install dependencies
31+
run: |
32+
python -m pip install --upgrade pip setuptools wheel
33+
. build/install_bazelisk.sh
34+
35+
# Load different caches depending on if this is a pull-request or merge.
36+
# If merge (or push commit), use a read-write cache based on the python
37+
# version, branch, and commit-sha.
38+
# If pull-request, use a read-only cache based on the target python
39+
# version, branch, and PR base sha.
40+
- if: github.event_name != 'pull_request'
41+
name: Mount bazel cache (main)
42+
uses: actions/cache@v4
43+
with:
44+
path: "/home/runner/.cache/bazel"
45+
key: bazel-${{ matrix.python-version }}-${{ github.ref_name }}-${{ github.sha }}
46+
restore-keys: |
47+
bazel-${{ matrix.python-version }}-${{ github.ref_name }}
48+
bazel-${{ matrix.python-version }}-
49+
bazel-
50+
51+
- if: github.event_name == 'pull_request'
52+
name: Mount bazel cache (pull-request)
53+
uses: actions/cache/restore@v4
54+
with:
55+
path: "/home/runner/.cache/bazel"
56+
key: bazel-${{ matrix.python-version }}-${{ github.base_ref }}-${{ github.event.pull_request.base.sha }}
57+
restore-keys: |
58+
bazel-${{ matrix.python-version }}-${{ github.base_ref }}
59+
bazel-${{ matrix.python-version }}-
60+
bazel-
61+
62+
- name: Build all targets
63+
run: |
64+
export HERMETIC_PYTHON_VERSION=${{ matrix.python-version }}
65+
bazel build //...
66+
67+
- name: Build pip wheel
68+
run: |
69+
bazel run //build:build_pip_package -- $PWD
70+
71+
- name: Run CPU tests
72+
run: |
73+
bazel test --config=cpu --test_output=errors --keep_going //...

.gitignore

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,9 @@ poetry.lock
2222

2323
# PyCharm
2424
.idea
25+
26+
# Bazel
27+
/bazel-*
28+
29+
# Built wheels.
30+
/*.whl

build/requirements.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Library.
22
absl-py
3-
flax
3+
flax @ https://github.com/google/flax/archive/e2134af.zip
44
numpy
55
dm-tree
66
# Pre-release of JAX required for SparseCore TPUs.

build/requirements_lock_3_10.txt

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -200,9 +200,8 @@ etils[epath,epy]==1.10.0 \
200200
# clu
201201
# optax
202202
# orbax-checkpoint
203-
flax==0.10.1 \
204-
--hash=sha256:5218959706bc659a1f282ca537446163093d186d8edb9b1405c0efee4d90d22a \
205-
--hash=sha256:ea98ed843c37954af2e262ea47356312a046794d7a5490d31682dffe908e25d3
203+
flax @ https://github.com/google/flax/archive/e2134af.zip \
204+
--hash=sha256:6384171c69e4a09a1f4fa9c15acd6b48ad9332429c6b61a13412ecced088985d
206205
# via
207206
# -r build/requirements.in
208207
# clu
@@ -463,6 +462,7 @@ numpy==2.1.3 \
463462
# orbax-checkpoint
464463
# scipy
465464
# tensorstore
465+
# treescope
466466
opt-einsum==3.4.0 \
467467
--hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \
468468
--hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac
@@ -641,6 +641,10 @@ toolz==1.0.0 \
641641
--hash=sha256:292c8f1c4e7516bf9086f8850935c799a874039c8bcf959d47b600e4c44a6236 \
642642
--hash=sha256:2c86e3d9a04798ac556793bced838816296a2f085017664e4995cb40a1047a02
643643
# via chex
644+
treescope==0.1.7 \
645+
--hash=sha256:14e6527d4bfe6770ac9cbb8058e49b6685444d7cd0d3f85fd10c42491848b102 \
646+
--hash=sha256:2c82ecb633f18d50e5809dd473703cf05aa074a4f3d1add74de7cf7ccdf81ae3
647+
# via flax
644648
typing-extensions==4.12.2 \
645649
--hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \
646650
--hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8

build/requirements_lock_3_11.txt

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -200,9 +200,8 @@ etils[epath,epy]==1.10.0 \
200200
# clu
201201
# optax
202202
# orbax-checkpoint
203-
flax==0.10.1 \
204-
--hash=sha256:5218959706bc659a1f282ca537446163093d186d8edb9b1405c0efee4d90d22a \
205-
--hash=sha256:ea98ed843c37954af2e262ea47356312a046794d7a5490d31682dffe908e25d3
203+
flax @ https://github.com/google/flax/archive/e2134af.zip \
204+
--hash=sha256:6384171c69e4a09a1f4fa9c15acd6b48ad9332429c6b61a13412ecced088985d
206205
# via
207206
# -r build/requirements.in
208207
# clu
@@ -464,6 +463,7 @@ numpy==2.1.3 \
464463
# orbax-checkpoint
465464
# scipy
466465
# tensorstore
466+
# treescope
467467
opt-einsum==3.4.0 \
468468
--hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \
469469
--hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac
@@ -642,6 +642,10 @@ toolz==1.0.0 \
642642
--hash=sha256:292c8f1c4e7516bf9086f8850935c799a874039c8bcf959d47b600e4c44a6236 \
643643
--hash=sha256:2c86e3d9a04798ac556793bced838816296a2f085017664e4995cb40a1047a02
644644
# via chex
645+
treescope==0.1.7 \
646+
--hash=sha256:14e6527d4bfe6770ac9cbb8058e49b6685444d7cd0d3f85fd10c42491848b102 \
647+
--hash=sha256:2c82ecb633f18d50e5809dd473703cf05aa074a4f3d1add74de7cf7ccdf81ae3
648+
# via flax
645649
typing-extensions==4.12.2 \
646650
--hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \
647651
--hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8

build/requirements_lock_3_12.txt

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -200,9 +200,8 @@ etils[epath,epy]==1.10.0 \
200200
# clu
201201
# optax
202202
# orbax-checkpoint
203-
flax==0.10.1 \
204-
--hash=sha256:5218959706bc659a1f282ca537446163093d186d8edb9b1405c0efee4d90d22a \
205-
--hash=sha256:ea98ed843c37954af2e262ea47356312a046794d7a5490d31682dffe908e25d3
203+
flax @ https://github.com/google/flax/archive/e2134af.zip \
204+
--hash=sha256:6384171c69e4a09a1f4fa9c15acd6b48ad9332429c6b61a13412ecced088985d
206205
# via
207206
# -r build/requirements.in
208207
# clu
@@ -464,6 +463,7 @@ numpy==2.1.3 \
464463
# orbax-checkpoint
465464
# scipy
466465
# tensorstore
466+
# treescope
467467
opt-einsum==3.4.0 \
468468
--hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \
469469
--hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac
@@ -642,6 +642,10 @@ toolz==1.0.0 \
642642
--hash=sha256:292c8f1c4e7516bf9086f8850935c799a874039c8bcf959d47b600e4c44a6236 \
643643
--hash=sha256:2c86e3d9a04798ac556793bced838816296a2f085017664e4995cb40a1047a02
644644
# via chex
645+
treescope==0.1.7 \
646+
--hash=sha256:14e6527d4bfe6770ac9cbb8058e49b6685444d7cd0d3f85fd10c42491848b102 \
647+
--hash=sha256:2c82ecb633f18d50e5809dd473703cf05aa074a4f3d1add74de7cf7ccdf81ae3
648+
# via flax
645649
typing-extensions==4.12.2 \
646650
--hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \
647651
--hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8

build/requirements_lock_3_13.txt

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -200,9 +200,8 @@ etils[epath,epy]==1.10.0 \
200200
# clu
201201
# optax
202202
# orbax-checkpoint
203-
flax==0.10.1 \
204-
--hash=sha256:5218959706bc659a1f282ca537446163093d186d8edb9b1405c0efee4d90d22a \
205-
--hash=sha256:ea98ed843c37954af2e262ea47356312a046794d7a5490d31682dffe908e25d3
203+
flax @ https://github.com/google/flax/archive/e2134af.zip \
204+
--hash=sha256:6384171c69e4a09a1f4fa9c15acd6b48ad9332429c6b61a13412ecced088985d
206205
# via
207206
# -r build/requirements.in
208207
# clu
@@ -464,6 +463,7 @@ numpy==2.1.3 \
464463
# orbax-checkpoint
465464
# scipy
466465
# tensorstore
466+
# treescope
467467
opt-einsum==3.4.0 \
468468
--hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \
469469
--hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac
@@ -642,6 +642,10 @@ toolz==1.0.0 \
642642
--hash=sha256:292c8f1c4e7516bf9086f8850935c799a874039c8bcf959d47b600e4c44a6236 \
643643
--hash=sha256:2c86e3d9a04798ac556793bced838816296a2f085017664e4995cb40a1047a02
644644
# via chex
645+
treescope==0.1.7 \
646+
--hash=sha256:14e6527d4bfe6770ac9cbb8058e49b6685444d7cd0d3f85fd10c42491848b102 \
647+
--hash=sha256:2c82ecb633f18d50e5809dd473703cf05aa074a4f3d1add74de7cf7ccdf81ae3
648+
# via flax
645649
typing-extensions==4.12.2 \
646650
--hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \
647651
--hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright 2024 The JAX SC Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
load("//third_party/bazel/python:pypi.bzl", "pypi_requirement")
15+
load("//third_party/bazel/python:pytype.bzl", "pytype_strict_library")
16+
17+
package(
18+
default_applicable_licenses = ["//:license"],
19+
default_visibility = [
20+
"//jax_tpu_embedding/sparsecore:__subpackages__",
21+
],
22+
)
23+
24+
pytype_strict_library(
25+
name = "utils",
26+
srcs = ["utils.py"],
27+
deps = [pypi_requirement("jax")],
28+
)
29+
30+
pytype_strict_library(
31+
name = "decompose",
32+
srcs = ["decompose.py"],
33+
deps = [
34+
":preprocess",
35+
":utils",
36+
pypi_requirement("jax"),
37+
],
38+
)
39+
40+
pytype_strict_library(
41+
name = "preprocess",
42+
srcs = ["preprocess.py"],
43+
deps = [
44+
":utils",
45+
pypi_requirement("jax"),
46+
],
47+
)
48+
49+
pytype_strict_library(
50+
name = "auto_pipelining",
51+
srcs = ["auto_pipelining.py"],
52+
deps = [
53+
":decompose",
54+
pypi_requirement("jax"),
55+
],
56+
)

0 commit comments

Comments
 (0)