Skip to content

Commit 5b576cb

Browse files
hawkinspjax authors
authored andcommitted
Revert: Drop flatbuffers as a Python dependency of JAX.
This change appears to be causing crashes on Mac. Original description: Change the pocketfft custom kernel in jaxlib to generate its flatbuffer descriptor in C++ instead. Surprisingly this code is actually much more readable in C++ because the flatbuffers Python API does not have a readable but less efficient API. Breaking changes to the flatbuffers Python APIs have caused breakage in JAX in the past, and we can avoid the dependency completely without much work. PiperOrigin-RevId: 457559793
1 parent 5b865ed commit 5b576cb

File tree

7 files changed

+71
-65
lines changed

7 files changed

+71
-65
lines changed

build/build_wheel.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ def prepare_wheel(sources_path):
180180
copy_to_jaxlib(f"__main__/jaxlib/_lapack.{pyext}")
181181
copy_to_jaxlib("__main__/jaxlib/mhlo_helpers.py")
182182
copy_to_jaxlib(f"__main__/jaxlib/_pocketfft.{pyext}")
183+
copy_to_jaxlib("__main__/jaxlib/pocketfft_flatbuffers_py_generated.py")
183184
copy_to_jaxlib("__main__/jaxlib/pocketfft.py")
184185
copy_to_jaxlib("__main__/jaxlib/gpu_prng.py")
185186
copy_to_jaxlib("__main__/jaxlib/gpu_linalg.py")

build/test-requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
cloudpickle
22
colorama>=0.4.4
3+
flatbuffers==2.0
34
# TODO(jakevdp): fix use of deprecated NEAREST resampling for more recent pillow.
45
pillow>=8.3.1,<9.1.0
56
pytest-benchmark

jaxlib/BUILD

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
load(
1818
"//jaxlib:jax.bzl",
1919
"flatbuffer_cc_library",
20+
"flatbuffer_py_library",
2021
"pybind_extension",
2122
)
2223

@@ -85,6 +86,7 @@ py_library(
8586
":_lapack",
8687
":_pocketfft",
8788
":cpu_feature_guard",
89+
":pocketfft_flatbuffers_py",
8890
],
8991
)
9092

@@ -146,6 +148,11 @@ flatbuffer_cc_library(
146148
srcs = ["pocketfft.fbs"],
147149
)
148150

151+
flatbuffer_py_library(
152+
name = "pocketfft_flatbuffers_py",
153+
srcs = ["pocketfft.fbs"],
154+
)
155+
149156
cc_library(
150157
name = "pocketfft_kernels",
151158
srcs = ["pocketfft_kernels.cc"],
@@ -171,9 +178,7 @@ pybind_extension(
171178
module_name = "_pocketfft",
172179
deps = [
173180
":kernel_pybind11_helpers",
174-
":pocketfft_flatbuffers_cc",
175181
":pocketfft_kernels",
176-
"@flatbuffers//:runtime_cc",
177182
"@pybind11",
178183
],
179184
)

jaxlib/jax.bzl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ load("@org_tensorflow//tensorflow/core/platform/default:build_config.bzl", _pyx_
1818
load("@org_tensorflow//tensorflow:tensorflow.bzl", _pybind_extension = "pybind_extension")
1919
load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", _if_cuda_is_configured = "if_cuda_is_configured")
2020
load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library")
21-
load("@flatbuffers//:build_defs.bzl", _flatbuffer_cc_library = "flatbuffer_cc_library")
21+
load("@flatbuffers//:build_defs.bzl", _flatbuffer_cc_library = "flatbuffer_cc_library", _flatbuffer_py_library = "flatbuffer_py_library")
2222

2323
# Explicitly re-exports names to avoid "unused variable" warnings from .bzl
2424
# lint tools.
@@ -30,6 +30,7 @@ pybind_extension = _pybind_extension
3030
if_cuda_is_configured = _if_cuda_is_configured
3131
if_rocm_is_configured = _if_rocm_is_configured
3232
flatbuffer_cc_library = _flatbuffer_cc_library
33+
flatbuffer_py_library = _flatbuffer_py_library
3334

3435
def py_extension(name, srcs, copts, deps):
3536
pybind_extension(name, srcs = srcs, copts = copts, deps = deps, module_name = name)

jaxlib/pocketfft.cc

Lines changed: 2 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -14,55 +14,21 @@ limitations under the License.
1414
==============================================================================*/
1515

1616
#include <complex>
17-
#include <vector>
1817

1918
#include "jaxlib/kernel_pybind11_helpers.h"
20-
#include "jaxlib/pocketfft_generated.h"
2119
#include "jaxlib/pocketfft_kernels.h"
2220
#include "include/pybind11/pybind11.h"
23-
#include "include/pybind11/stl.h"
24-
25-
namespace py = pybind11;
2621

2722
namespace jax {
2823
namespace {
2924

30-
py::bytes BuildPocketFftDescriptor(const std::vector<uint64_t>& shape,
31-
bool is_double, int fft_type,
32-
const std::vector<uint64_t>& fft_lengths,
33-
const std::vector<uint64_t>& strides_in,
34-
const std::vector<uint64_t>& strides_out,
35-
const std::vector<uint32_t>& axes,
36-
bool forward, double scale) {
37-
PocketFftDescriptorT descriptor;
38-
descriptor.shape = shape;
39-
descriptor.fft_type = static_cast<PocketFftType>(fft_type);
40-
descriptor.dtype =
41-
is_double ? PocketFftDtype_COMPLEX128 : PocketFftDtype_COMPLEX64;
42-
descriptor.strides_in = strides_in;
43-
descriptor.strides_out = strides_out;
44-
descriptor.axes = axes;
45-
descriptor.forward = forward;
46-
descriptor.scale = scale;
47-
flatbuffers::FlatBufferBuilder fbb;
48-
fbb.Finish(PocketFftDescriptor::Pack(fbb, &descriptor));
49-
return py::bytes(reinterpret_cast<char*>(fbb.GetBufferPointer()),
50-
fbb.GetSize());
51-
}
52-
53-
py::dict Registrations() {
25+
pybind11::dict Registrations() {
5426
pybind11::dict dict;
5527
dict["pocketfft"] = EncapsulateFunction(PocketFft);
5628
return dict;
5729
}
5830

59-
PYBIND11_MODULE(_pocketfft, m) {
60-
m.def("registrations", &Registrations);
61-
m.def("pocketfft_descriptor", &BuildPocketFftDescriptor, py::arg("shape"),
62-
py::arg("is_double"), py::arg("fft_type"), py::arg("fft_lengths"),
63-
py::arg("strides_in"), py::arg("strides_out"), py::arg("axes"),
64-
py::arg("forward"), py::arg("scale"));
65-
}
31+
PYBIND11_MODULE(_pocketfft, m) { m.def("registrations", &Registrations); }
6632

6733
} // namespace
6834
} // namespace jax

jaxlib/pocketfft.py

Lines changed: 57 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414

1515
import jax
16+
# flatbuffers needs importlib.util but fails to import it itself.
17+
import importlib.util # noqa: F401
1618
from typing import List
1719

1820
import jaxlib.mlir.ir as ir
@@ -21,54 +23,63 @@
2123

2224
from .mhlo_helpers import custom_call
2325
from . import _pocketfft
26+
from . import pocketfft_flatbuffers_py_generated as pd
2427
import numpy as np
2528

29+
import flatbuffers
2630
from jaxlib import xla_client
2731

2832
for _name, _value in _pocketfft.registrations().items():
2933
xla_client.register_custom_call_target(_name, _value, platform="cpu")
3034

3135
FftType = xla_client.FftType
3236

37+
flatbuffers_version_2 = hasattr(flatbuffers, "__version__")
3338

34-
_C2C = 0
35-
_C2R = 1
36-
_R2C = 2
3739

3840
def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType,
3941
fft_lengths: List[int]) -> bytes:
4042
n = len(shape)
4143
assert len(fft_lengths) >= 1
4244
assert len(fft_lengths) <= n, (fft_lengths, n)
4345

46+
builder = flatbuffers.Builder(128)
4447

4548
forward = fft_type in (FftType.FFT, FftType.RFFT)
46-
is_double = np.finfo(dtype).dtype == np.float64
4749
if fft_type == FftType.RFFT:
48-
pocketfft_type = _R2C
50+
pocketfft_type = pd.PocketFftType.R2C
4951

5052
assert dtype in (np.float32, np.float64), dtype
5153
out_dtype = np.dtype(np.complex64 if dtype == np.float32 else np.complex128)
54+
pocketfft_dtype = (
55+
pd.PocketFftDtype.COMPLEX64
56+
if dtype == np.float32 else pd.PocketFftDtype.COMPLEX128)
5257

5358
assert shape[-len(fft_lengths):] == fft_lengths, (shape, fft_lengths)
5459
out_shape = list(shape)
5560
out_shape[-1] = out_shape[-1] // 2 + 1
5661

5762
elif fft_type == FftType.IRFFT:
58-
pocketfft_type = _C2R
63+
pocketfft_type = pd.PocketFftType.C2R
5964
assert np.issubdtype(dtype, np.complexfloating), dtype
6065

6166
out_dtype = np.dtype(np.float32 if dtype == np.complex64 else np.float64)
67+
pocketfft_dtype = (
68+
pd.PocketFftDtype.COMPLEX64
69+
if dtype == np.complex64 else pd.PocketFftDtype.COMPLEX128)
6270

6371
assert shape[-len(fft_lengths):-1] == fft_lengths[:-1]
6472
out_shape = list(shape)
6573
out_shape[-1] = fft_lengths[-1]
6674
assert (out_shape[-1] // 2 + 1) == shape[-1]
6775
else:
68-
pocketfft_type = _C2C
76+
pocketfft_type = pd.PocketFftType.C2C
6977

7078
assert np.issubdtype(dtype, np.complexfloating), dtype
7179
out_dtype = dtype
80+
pocketfft_dtype = (
81+
pd.PocketFftDtype.COMPLEX64
82+
if dtype == np.complex64 else pd.PocketFftDtype.COMPLEX128)
7283

7384
assert shape[-len(fft_lengths):] == fft_lengths, (shape, fft_lengths)
7485
out_shape = shape
@@ -79,33 +90,54 @@ def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType,
7990

8091
# Builds a PocketFftDescriptor flatbuffer. This descriptor is passed to the
8192
# C++ kernel to describe the FFT to perform.
82-
strides_in = []
93+
pd.PocketFftDescriptorStartShapeVector(builder, n)
94+
for d in reversed(shape if fft_type != FftType.IRFFT else out_shape):
95+
builder.PrependUint64(d)
96+
if flatbuffers_version_2:
97+
pocketfft_shape = builder.EndVector()
98+
else:
99+
pocketfft_shape = builder.EndVector(n)
100+
101+
pd.PocketFftDescriptorStartStridesInVector(builder, n)
83102
stride = dtype.itemsize
84103
for d in reversed(shape):
85-
strides_in.append(stride)
104+
builder.PrependUint64(stride)
86105
stride *= d
87-
88-
strides_out = []
106+
if flatbuffers_version_2:
107+
strides_in = builder.EndVector()
108+
else:
109+
strides_in = builder.EndVector(n)
110+
pd.PocketFftDescriptorStartStridesOutVector(builder, n)
89111
stride = out_dtype.itemsize
90112
for d in reversed(out_shape):
91-
strides_out.append(stride)
113+
builder.PrependUint64(stride)
92114
stride *= d
115+
if flatbuffers_version_2:
116+
strides_out = builder.EndVector()
117+
else:
118+
strides_out = builder.EndVector(n)
93119

94-
axes = [n - len(fft_lengths) + d for d in range(len(fft_lengths))]
120+
pd.PocketFftDescriptorStartAxesVector(builder, len(fft_lengths))
121+
for d in range(len(fft_lengths)):
122+
builder.PrependUint32(n - d - 1)
123+
if flatbuffers_version_2:
124+
axes = builder.EndVector()
125+
else:
126+
axes = builder.EndVector(len(fft_lengths))
95127

96128
scale = 1. if forward else (1. / np.prod(fft_lengths))
97-
descriptor = _pocketfft.pocketfft_descriptor(
98-
shape=shape if fft_type != FftType.IRFFT else out_shape,
99-
is_double=is_double,
100-
fft_type=pocketfft_type,
101-
fft_lengths=fft_lengths,
102-
strides_in=list(reversed(strides_in)),
103-
strides_out=list(reversed(strides_out)),
104-
axes=axes,
105-
forward=forward,
106-
scale=scale)
107-
108-
return descriptor, out_dtype, out_shape
129+
pd.PocketFftDescriptorStart(builder)
130+
pd.PocketFftDescriptorAddDtype(builder, pocketfft_dtype)
131+
pd.PocketFftDescriptorAddFftType(builder, pocketfft_type)
132+
pd.PocketFftDescriptorAddShape(builder, pocketfft_shape)
133+
pd.PocketFftDescriptorAddStridesIn(builder, strides_in)
134+
pd.PocketFftDescriptorAddStridesOut(builder, strides_out)
135+
pd.PocketFftDescriptorAddAxes(builder, axes)
136+
pd.PocketFftDescriptorAddForward(builder, forward)
137+
pd.PocketFftDescriptorAddScale(builder, scale)
138+
descriptor = pd.PocketFftDescriptorEnd(builder)
139+
builder.Finish(descriptor)
140+
return builder.Output(), out_dtype, out_shape
109141

110142

111143
def pocketfft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):

jaxlib/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
author_email='[email protected]',
3939
packages=['jaxlib', 'jaxlib.xla_extension'],
4040
python_requires='>=3.7',
41-
install_requires=['scipy>=1.5', 'numpy>=1.19', 'absl-py'],
41+
install_requires=['scipy>=1.5', 'numpy>=1.19', 'absl-py', 'flatbuffers >= 1.12, < 3.0'],
4242
url='https://github.com/google/jax',
4343
license='Apache-2.0',
4444
classifiers=[

0 commit comments

Comments
 (0)