1313# limitations under the License.
1414
1515import jax
16+ # flatbuffers needs importlib.util but fails to import it itself.
17+ import importlib .util # noqa: F401
1618from typing import List
1719
1820import jaxlib .mlir .ir as ir
2123
2224from .mhlo_helpers import custom_call
2325from . import _pocketfft
26+ from . import pocketfft_flatbuffers_py_generated as pd
2427import numpy as np
2528
29+ import flatbuffers
2630from jaxlib import xla_client
2731
2832for _name , _value in _pocketfft .registrations ().items ():
2933 xla_client .register_custom_call_target (_name , _value , platform = "cpu" )
3034
3135FftType = xla_client .FftType
3236
37+ flatbuffers_version_2 = hasattr (flatbuffers , "__version__" )
3338
34- _C2C = 0
35- _C2R = 1
36- _R2C = 2
3739
3840def _pocketfft_descriptor (shape : List [int ], dtype , fft_type : FftType ,
3941 fft_lengths : List [int ]) -> bytes :
4042 n = len (shape )
4143 assert len (fft_lengths ) >= 1
4244 assert len (fft_lengths ) <= n , (fft_lengths , n )
4345
46+ builder = flatbuffers .Builder (128 )
4447
4548 forward = fft_type in (FftType .FFT , FftType .RFFT )
46- is_double = np .finfo (dtype ).dtype == np .float64
4749 if fft_type == FftType .RFFT :
48- pocketfft_type = _R2C
50+ pocketfft_type = pd . PocketFftType . R2C
4951
5052 assert dtype in (np .float32 , np .float64 ), dtype
5153 out_dtype = np .dtype (np .complex64 if dtype == np .float32 else np .complex128 )
54+ pocketfft_dtype = (
55+ pd .PocketFftDtype .COMPLEX64
56+ if dtype == np .float32 else pd .PocketFftDtype .COMPLEX128 )
5257
5358 assert shape [- len (fft_lengths ):] == fft_lengths , (shape , fft_lengths )
5459 out_shape = list (shape )
5560 out_shape [- 1 ] = out_shape [- 1 ] // 2 + 1
5661
5762 elif fft_type == FftType .IRFFT :
58- pocketfft_type = _C2R
63+ pocketfft_type = pd . PocketFftType . C2R
5964 assert np .issubdtype (dtype , np .complexfloating ), dtype
6065
6166 out_dtype = np .dtype (np .float32 if dtype == np .complex64 else np .float64 )
67+ pocketfft_dtype = (
68+ pd .PocketFftDtype .COMPLEX64
69+ if dtype == np .complex64 else pd .PocketFftDtype .COMPLEX128 )
6270
6371 assert shape [- len (fft_lengths ):- 1 ] == fft_lengths [:- 1 ]
6472 out_shape = list (shape )
6573 out_shape [- 1 ] = fft_lengths [- 1 ]
6674 assert (out_shape [- 1 ] // 2 + 1 ) == shape [- 1 ]
6775 else :
68- pocketfft_type = _C2C
76+ pocketfft_type = pd . PocketFftType . C2C
6977
7078 assert np .issubdtype (dtype , np .complexfloating ), dtype
7179 out_dtype = dtype
80+ pocketfft_dtype = (
81+ pd .PocketFftDtype .COMPLEX64
82+ if dtype == np .complex64 else pd .PocketFftDtype .COMPLEX128 )
7283
7384 assert shape [- len (fft_lengths ):] == fft_lengths , (shape , fft_lengths )
7485 out_shape = shape
@@ -79,33 +90,54 @@ def _pocketfft_descriptor(shape: List[int], dtype, fft_type: FftType,
7990
8091 # Builds a PocketFftDescriptor flatbuffer. This descriptor is passed to the
8192 # C++ kernel to describe the FFT to perform.
82- strides_in = []
93+ pd .PocketFftDescriptorStartShapeVector (builder , n )
94+ for d in reversed (shape if fft_type != FftType .IRFFT else out_shape ):
95+ builder .PrependUint64 (d )
96+ if flatbuffers_version_2 :
97+ pocketfft_shape = builder .EndVector ()
98+ else :
99+ pocketfft_shape = builder .EndVector (n )
100+
101+ pd .PocketFftDescriptorStartStridesInVector (builder , n )
83102 stride = dtype .itemsize
84103 for d in reversed (shape ):
85- strides_in . append (stride )
104+ builder . PrependUint64 (stride )
86105 stride *= d
87-
88- strides_out = []
106+ if flatbuffers_version_2 :
107+ strides_in = builder .EndVector ()
108+ else :
109+ strides_in = builder .EndVector (n )
110+ pd .PocketFftDescriptorStartStridesOutVector (builder , n )
89111 stride = out_dtype .itemsize
90112 for d in reversed (out_shape ):
91- strides_out . append (stride )
113+ builder . PrependUint64 (stride )
92114 stride *= d
115+ if flatbuffers_version_2 :
116+ strides_out = builder .EndVector ()
117+ else :
118+ strides_out = builder .EndVector (n )
93119
94- axes = [n - len (fft_lengths ) + d for d in range (len (fft_lengths ))]
120+ pd .PocketFftDescriptorStartAxesVector (builder , len (fft_lengths ))
121+ for d in range (len (fft_lengths )):
122+ builder .PrependUint32 (n - d - 1 )
123+ if flatbuffers_version_2 :
124+ axes = builder .EndVector ()
125+ else :
126+ axes = builder .EndVector (len (fft_lengths ))
95127
96128 scale = 1. if forward else (1. / np .prod (fft_lengths ))
97- descriptor = _pocketfft . pocketfft_descriptor (
98- shape = shape if fft_type != FftType . IRFFT else out_shape ,
99- is_double = is_double ,
100- fft_type = pocketfft_type ,
101- fft_lengths = fft_lengths ,
102- strides_in = list ( reversed ( strides_in )),
103- strides_out = list ( reversed ( strides_out )),
104- axes = axes ,
105- forward = forward ,
106- scale = scale )
107-
108- return descriptor , out_dtype , out_shape
129+ pd . PocketFftDescriptorStart ( builder )
130+ pd . PocketFftDescriptorAddDtype ( builder , pocketfft_dtype )
131+ pd . PocketFftDescriptorAddFftType ( builder , pocketfft_type )
132+ pd . PocketFftDescriptorAddShape ( builder , pocketfft_shape )
133+ pd . PocketFftDescriptorAddStridesIn ( builder , strides_in )
134+ pd . PocketFftDescriptorAddStridesOut ( builder , strides_out )
135+ pd . PocketFftDescriptorAddAxes ( builder , axes )
136+ pd . PocketFftDescriptorAddForward ( builder , forward )
137+ pd . PocketFftDescriptorAddScale ( builder , scale )
138+ descriptor = pd . PocketFftDescriptorEnd ( builder )
139+ builder . Finish ( descriptor )
140+ return builder . Output () , out_dtype , out_shape
109141
110142
111143def pocketfft_mhlo (a , dtype , * , fft_type : FftType , fft_lengths : List [int ]):
0 commit comments