Skip to content

Commit 2492498

Browse files
committed
CI: Add numba linker and drop float32 parametrizations
1 parent f6ade3b commit 2492498

File tree

2 files changed

+56
-88
lines changed

2 files changed

+56
-88
lines changed

.github/workflows/test.yml

Lines changed: 43 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ jobs:
6565
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
6666

6767
test:
68-
name: "${{ matrix.os }} test py${{ matrix.python-version }} : fast-compile ${{ matrix.fast-compile }} : float32 ${{ matrix.float32 }} : ${{ matrix.part }}"
68+
name: "mode ${{ matrix.default-mode }} : py${{ matrix.python-version }} : ${{ matrix.os }} : ${{ matrix.part[0] }}"
6969
needs:
7070
- changes
7171
- style
@@ -74,101 +74,62 @@ jobs:
7474
strategy:
7575
fail-fast: false
7676
matrix:
77-
os: ["ubuntu-latest"]
77+
default-mode: ["C", "NUMBA", "FAST_COMPILE"]
7878
python-version: ["3.11", "3.14"]
79-
fast-compile: [0, 1]
80-
float32: [0, 1]
81-
install-numba: [0]
79+
os: ["ubuntu-latest"]
8280
install-jax: [0]
8381
install-torch: [0]
8482
install-mlx: [0]
8583
install-xarray: [0]
8684
part:
87-
- "tests --ignore=tests/scan --ignore=tests/tensor --ignore=tests/xtensor"
88-
- "tests/scan"
89-
- "tests/tensor --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_math_scipy.py --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/signal --ignore=tests/tensor/conv --ignore=tests/tensor/rewriting --ignore=tests/tensor/linalg --ignore=tests/tensor/test_nlinalg.py --ignore=tests/tensor/test_slinalg.py --ignore=tests/tensor/test_pad.py"
90-
- "tests/tensor/test_basic.py tests/tensor/test_elemwise.py"
91-
- "tests/tensor/test_math.py"
92-
- "tests/tensor/test_math_scipy.py tests/tensor/test_blas.py tests/tensor/signal tests/tensor/conv tests/tensor/test_pad.py"
93-
- "tests/tensor/rewriting"
94-
- "tests/tensor/linalg tests/tensor/test_nlinalg.py tests/tensor/test_slinalg.py"
85+
- [ "*rest", "tests --ignore=tests/scan --ignore=tests/tensor --ignore=tests/xtensor --ignore=tests/link/numba" ]
86+
- [ "scan", "tests/scan" ]
87+
- [ "tensor *rest", "tests/tensor --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_math_scipy.py --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/signal --ignore=tests/tensor/conv --ignore=tests/tensor/rewriting --ignore=tests/tensor/linalg --ignore=tests/tensor/test_nlinalg.py --ignore=tests/tensor/test_slinalg.py --ignore=tests/tensor/test_pad.py" ]
88+
- [ "tensor basic+elemwise", "tests/tensor/test_basic.py tests/tensor/test_elemwise.py" ]
89+
- [ "tensor math", "tests/tensor/test_math.py" ]
90+
- [ "tensor scipy+blas+conv+pad", "tests/tensor/test_math_scipy.py tests/tensor/test_blas.py tests/tensor/signal tests/tensor/conv tests/tensor/test_pad.py" ]
91+
- [ "tensor rewriting", "tests/tensor/rewriting" ]
92+
- [ "tensor linalg", "tests/tensor/linalg tests/tensor/test_nlinalg.py tests/tensor/test_slinalg.py" ]
9593
exclude:
9694
- python-version: "3.11"
97-
fast-compile: 1
98-
- python-version: "3.11"
99-
float32: 1
100-
- fast-compile: 1
101-
float32: 1
95+
default-mode: "FAST_COMPILE"
10296
include:
103-
- os: "ubuntu-latest"
104-
part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link --ignore=pytensor/ipython.py"
97+
- part: ["doctests", "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link --ignore=pytensor/ipython.py"]
98+
default-mode: "C"
10599
python-version: "3.12"
106-
fast-compile: 0
107-
float32: 0
108-
install-numba: 0
109-
install-jax: 0
110-
install-torch: 0
111-
install-mlx: 0
112-
install-xarray: 0
113-
- install-numba: 1
114100
os: "ubuntu-latest"
115-
python-version: "3.11"
116-
fast-compile: 0
117-
float32: 0
118-
part: "tests/link/numba --ignore=tests/link/numba/test_slinalg.py"
119-
- install-numba: 1
101+
- part: ["numba link", "tests/link/numba --ignore=tests/link/numba/test_slinalg.py"]
102+
default-mode: "C"
103+
python-version: "3.12"
120104
os: "ubuntu-latest"
121-
python-version: "3.14"
122-
fast-compile: 0
123-
float32: 0
124-
part: "tests/link/numba --ignore=tests/link/numba/test_slinalg.py"
125-
- install-numba: 1
105+
- part: ["numba link slinalg", "tests/link/numba/test_slinalg.py"]
106+
default-mode: "C"
107+
python-version: "3.13"
126108
os: "ubuntu-latest"
109+
- part: ["jax link", "tests/link/jax"]
110+
install-jax: 1
111+
default-mode: "C"
127112
python-version: "3.14"
128-
fast-compile: 0
129-
float32: 0
130-
part: "tests/link/numba/test_slinalg.py"
131-
- install-jax: 1
132113
os: "ubuntu-latest"
114+
- part: ["pytorch link", "tests/link/pytorch"]
115+
install-torch: 1
116+
default-mode: "C"
133117
python-version: "3.11"
134-
fast-compile: 0
135-
float32: 0
136-
part: "tests/link/jax"
137-
- install-jax: 1
138118
os: "ubuntu-latest"
119+
- part: ["xtensor", "tests/xtensor"]
120+
install-xarray: 1
121+
default-mode: "C"
139122
python-version: "3.14"
140-
fast-compile: 0
141-
float32: 0
142-
part: "tests/link/jax"
143-
- install-torch: 1
144123
os: "ubuntu-latest"
145-
python-version: "3.11"
146-
fast-compile: 0
147-
float32: 0
148-
part: "tests/link/pytorch"
149-
- install-xarray: 1
150-
os: "ubuntu-latest"
151-
python-version: "3.14"
152-
fast-compile: 0
153-
float32: 0
154-
part: "tests/xtensor"
155-
- os: "macos-15"
156-
python-version: "3.11"
157-
fast-compile: 0
158-
float32: 0
124+
- part: ["mlx link", "tests/link/mlx"]
159125
install-mlx: 1
160-
install-numba: 0
161-
install-jax: 0
162-
install-torch: 0
163-
part: "tests/link/mlx"
164-
- os: "macos-15"
126+
default-mode: "C"
127+
python-version: "3.11"
128+
os: "macos-15"
129+
- part: ["macos smoke test", "tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py tests/tensor/test_blas.py"]
130+
default-mode: "C"
165131
python-version: "3.14"
166-
fast-compile: 0
167-
float32: 0
168-
install-numba: 0
169-
install-jax: 0
170-
install-torch: 0
171-
part: "tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py tests/tensor/test_blas.py"
132+
os: "macos-15"
172133

173134
steps:
174135
- uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
@@ -199,11 +160,10 @@ jobs:
199160
run: |
200161
201162
if [[ $OS == "macos-15" ]]; then
202-
micromamba install --yes -q "python~=${PYTHON_VERSION}" "numpy${NUMPY_VERSION}" scipy pip graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx libblas=*=*accelerate;
163+
micromamba install --yes -q "python~=${PYTHON_VERSION}" numpy scipy "numba>=0.63" pip graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx libblas=*=*accelerate;
203164
else
204-
micromamba install --yes -q "python~=${PYTHON_VERSION}" mkl "numpy${NUMPY_VERSION}" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx;
165+
micromamba install --yes -q "python~=${PYTHON_VERSION}" numpy scipy "numba>=0.63" pip graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx mkl mkl-service;
205166
fi
206-
if [[ $INSTALL_NUMBA == "1" ]]; then pip install "numba>=0.63"; fi
207167
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro equinox && pip install tfp-nightly; fi
208168
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
209169
if [[ $INSTALL_MLX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "mlx<0.29.4"; fi
@@ -219,28 +179,26 @@ jobs:
219179
fi
220180
env:
221181
PYTHON_VERSION: ${{ matrix.python-version }}
222-
INSTALL_NUMBA: ${{ matrix.install-numba }}
223182
INSTALL_JAX: ${{ matrix.install-jax }}
224-
INSTALL_TORCH: ${{ matrix.install-torch}}
183+
INSTALL_TORCH: ${{ matrix.install-torch }}
225184
INSTALL_XARRAY: ${{ matrix.install-xarray }}
226185
INSTALL_MLX: ${{ matrix.install-mlx }}
227186
OS: ${{ matrix.os}}
228187

229188
- name: Run tests
230189
shell: micromamba-shell {0}
231190
run: |
232-
if [[ $FAST_COMPILE == "1" ]]; then export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,mode=FAST_COMPILE; fi
233-
if [[ $FLOAT32 == "1" ]]; then export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,floatX=float32; fi
191+
if [[ $DEFAULT_MODE == "FAST_COMPILE" ]]; then export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,mode=FAST_COMPILE; fi
192+
if [[ $DEFAULT_MODE == "NUMBA" ]]; then export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,linker=numba; fi
234193
export PYTENSOR_FLAGS=$PYTENSOR_FLAGS,warn__ignore_bug_before=all,on_opt_error=raise,on_shape_error=raise,gcc__cxxflags=-pipe
235194
python -m pytest -r A --verbose --runslow --durations=50 --cov=pytensor/ --cov-report=xml:coverage/coverage-${MATRIX_ID}.xml --no-cov-on-fail $PART --benchmark-skip
236195
env:
237196
MATRIX_ID: ${{ steps.matrix-id.outputs.id }}
238197
MKL_THREADING_LAYER: GNU
239198
MKL_NUM_THREADS: 1
240199
OMP_NUM_THREADS: 1
241-
PART: ${{ matrix.part }}
242-
FAST_COMPILE: ${{ matrix.fast-compile }}
243-
FLOAT32: ${{ matrix.float32 }}
200+
PART: ${{ matrix.part[1] }}
201+
DEFAULT_MODE: ${{ matrix.default-mode }}
244202

245203
- name: Upload coverage file
246204
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2

tests/compile/test_mode.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,18 @@ def test_modes(self):
7777

7878
# Linkers to use with regular Mode
7979
if config.cxx:
80-
linkers = ["py", "c|py", "c|py_nogc", "vm", "vm_nogc", "cvm", "cvm_nogc"]
80+
linkers = [
81+
"py",
82+
"c|py",
83+
"c|py_nogc",
84+
"vm",
85+
"vm_nogc",
86+
"cvm",
87+
"cvm_nogc",
88+
"numba",
89+
]
8190
else:
82-
linkers = ["py", "c|py", "c|py_nogc", "vm", "vm_nogc"]
91+
linkers = ["py", "c|py", "c|py_nogc", "vm", "vm_nogc", "numba"]
8392
modes = predef_modes + [Mode(linker, "fast_run") for linker in linkers]
8493

8594
for mode in modes:
@@ -93,11 +102,12 @@ def test_modes(self):
93102

94103
# regression check:
95104
# there should be
105+
# - NumbaLinker
96106
# - `VMLinker`
97107
# - OpWiseCLinker (FAST_RUN)
98108
# - PerformLinker (FAST_COMPILE)
99109
# - DebugMode's Linker (DEBUG_MODE)
100-
assert 4 == len(set(linker_classes_involved))
110+
assert 5 == len(set(linker_classes_involved))
101111

102112

103113
class TestOldModesProblem:

0 commit comments

Comments
 (0)