
 import jax
 import jax.numpy as jnp
+import jmp
 from flax import linen as nn


@@ -26,18 +27,24 @@ class ModelConfig:
   use_residual_scaling: bool = True
   tie_embeddings: bool = True  # Whether to tie input and output embeddings
   qknorm_epsilon: float = 1e-6
-
-  dtype: jnp.dtype = jnp.float32
   attention_init: nn.initializers.Initializer = nn.initializers.normal(
       stddev=0.02
   )
   linear_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02)
   embed_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02)
+  param_dtype: jnp.dtype = jnp.float32
+  compute_dtype: jnp.dtype = jnp.bfloat16
+  output_dtype: jnp.dtype = jnp.bfloat16

   def __post_init__(self):
     self.residual_init = nn.initializers.normal(
         stddev=0.02 / jnp.sqrt(2 * self.num_layers)
     )
+    self.mp_policy = jmp.Policy(
+        compute_dtype=self.compute_dtype,
+        param_dtype=self.param_dtype,
+        output_dtype=self.output_dtype,
+    )


 class Mlp(nn.Module):
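For context, a minimal sketch of what this policy amounts to under the deepmind/jmp API (the string-based constructor below is my illustration, not something the commit uses): float32 master parameters, with activations and outputs in bfloat16.

import jmp
import jax.numpy as jnp

policy = jmp.get_policy('params=float32,compute=bfloat16,output=bfloat16')
assert policy.param_dtype == jnp.float32
assert policy.compute_dtype == policy.output_dtype == jnp.bfloat16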
@@ -49,7 +56,11 @@ class Mlp(nn.Module):
   def __call__(self, x_BxLxD: jax.Array):
     cfg = self.cfg
     linear = partial(
-        nn.Dense, kernel_init=cfg.linear_init, use_bias=False, dtype=cfg.dtype
+        nn.Dense,
+        kernel_init=cfg.linear_init,
+        use_bias=False,
+        dtype=cfg.compute_dtype,
+        param_dtype=cfg.param_dtype,
     )
     # Adjust hidden dimension to keep the number of parameters invariant to
     # the activation function used since the GLU MLP has 3 * hidden_dim * D
@@ -65,7 +76,8 @@ def __call__(self, x_BxLxD: jax.Array):
     x_BxLxD = nn.Dense(
         cfg.model_dim,
         use_bias=False,
-        dtype=cfg.dtype,
+        dtype=cfg.compute_dtype,
+        param_dtype=cfg.param_dtype,
         kernel_init=cfg.residual_init
         if cfg.use_residual_scaling
         else cfg.linear_init,
@@ -96,7 +108,7 @@ def apply_rope(q, k, freqs_cis):

   def rotate_tensor(x):
     # Split into real and imaginary parts
-    x_r2 = x.reshape(*x.shape[:-1], -1, 2)
+    x_r2 = x.reshape(*x.shape[:-1], -1, 2).astype(jnp.float32)
     L = x.shape[1]
     freqs = freqs_cis[:, :L, :, :, :]

@@ -109,7 +121,7 @@ def rotate_tensor(x):
         axis=-1,
     )

-    return rotated_x_r2.reshape(*x.shape)
+    return rotated_x_r2.reshape(*x.shape).astype(x.dtype)

   # Apply rotation to Q and K separately
   rotated_q = rotate_tensor(q)
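An aside on the float32 round-trip added to rotate_tensor (my reading of the change, not stated in the commit): bfloat16 keeps only 8 bits of significand precision (7 stored), so applying the sin/cos rotation in float32 and casting back avoids dropping small components, e.g.:

import jax.numpy as jnp

print(jnp.bfloat16(1.0 + 1e-3))  # prints 1.0: the 1e-3 increment is below bfloat16 resolution near 1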
@@ -141,7 +153,8 @@ def setup(self):
         features=(cfg.num_heads, self.Dh),
         kernel_init=cfg.attention_init,
         use_bias=False,
-        dtype=cfg.dtype,
+        dtype=cfg.compute_dtype,
+        param_dtype=cfg.param_dtype,
     )
     self.multilinear_query = self.multilinear(name='query')
     self.multilinear_key = self.multilinear(name='key')
@@ -150,7 +163,9 @@ def setup(self):
     seq_len = cfg.seq_len
     attn_scale0 = jnp.log2(seq_len**2 - seq_len)
     self.attn_scale = self.param(
-        'attn_scale', nn.initializers.constant(attn_scale0), ()
+        'attn_scale',
+        nn.initializers.constant(attn_scale0, dtype=cfg.compute_dtype),
+        (),
     )
     self.output_projection = nn.DenseGeneral(
         features=cfg.model_dim,
@@ -160,7 +175,8 @@ def setup(self):
         if cfg.use_residual_scaling
         else cfg.linear_init,
         use_bias=False,
-        dtype=cfg.dtype,
+        dtype=cfg.compute_dtype,
+        param_dtype=cfg.param_dtype,
     )

   def __call__(self, x_BxLxD: jax.Array):
@@ -177,32 +193,17 @@ def __call__(self, x_BxLxD: jax.Array):
     # Apply QK normalization
     q_BxLxHxDh /= jnp.linalg.norm(q_BxLxHxDh, axis=-1, keepdims=True) + self.eps
     k_BxLxHxDh /= jnp.linalg.norm(k_BxLxHxDh, axis=-1, keepdims=True) + self.eps
-
-    # Compute attention scores
-    att_BxHxLxL = jnp.einsum('...qhd,...khd->...hqk', q_BxLxHxDh, k_BxLxHxDh)
-
-    # Causal attention mask
-    L = x_BxLxD.shape[1]
-    mask_1x1xLxL = jnp.tril(jnp.ones((1, 1, L, L), dtype=jnp.bool_))
-
-    # Apply mask and softmax
-    _NEG_INF = jnp.finfo(cfg.dtype).min
-    att_BxHxLxL = jnp.where(mask_1x1xLxL, att_BxHxLxL, _NEG_INF)
-    att_BxHxLxL = (
-        self.attn_scale * att_BxHxLxL
-    )  # Learned scaling factor for QK norm
-    att_BxHxLxL = jax.nn.softmax(att_BxHxLxL, axis=-1)
-    att_BxHxLxL = att_BxHxLxL.astype(cfg.dtype)
-
-    # Compute attention output
-    out_BxLxHxDh = jnp.einsum('...hqk,...khd->...qhd', att_BxHxLxL, v_BxLxHxDh)
-
-    # Reshape and project output
+    q_BxLxHxDh *= self.attn_scale
+    out_BxLxHxDh = jax.nn.dot_product_attention(
+        query=q_BxLxHxDh,
+        key=k_BxLxHxDh,
+        value=v_BxLxHxDh,
+        is_causal=True,
+        scale=1.0,
+        implementation='cudnn' if cfg.compute_dtype is not jnp.float32 else None,
+    )
     out_BxLxD = out_BxLxHxDh.reshape(*x_BxLxD.shape)
-
-    # Output projection
     out_BxLxD = self.output_projection(out_BxLxD)
-
     return out_BxLxD

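A quick sanity sketch for the rewrite above (my own check, assuming jax.nn.dot_product_attention as available in recent JAX releases): folding the learned attn_scale into the query while passing scale=1.0 gives the same result as scaling the attention logits directly, since (s*q)·k == s*(q·k).

import jax
import jax.numpy as jnp

B, L, H, Dh = 1, 8, 4, 16  # toy shapes: batch, seq len, heads, head dim
q, k, v = (jax.random.normal(jax.random.key(i), (B, L, H, Dh)) for i in range(3))
s = 2.0
out_q_scaled = jax.nn.dot_product_attention(s * q, k, v, is_causal=True, scale=1.0)
out_logit_scaled = jax.nn.dot_product_attention(q, k, v, is_causal=True, scale=s)
assert jnp.allclose(out_q_scaled, out_logit_scaled, atol=1e-5)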
@@ -216,16 +217,16 @@ def __call__(self, in_BxLxD: jax.Array):
     cfg = self.docfg

     # x = x + attn( attn_norm(x) )
-    x_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)(
-        in_BxLxD
-    )
+    x_BxLxD = nn.RMSNorm(
+        param_dtype=cfg.param_dtype, epsilon=cfg.rmsnorm_epsilon
+    )(in_BxLxD)
     x_BxLxD = CausalAttn(cfg)(x_BxLxD)
     x_BxLxD += in_BxLxD

     # x = x + mlp( mlp_norm(x) )
-    z_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)(
-        x_BxLxD
-    )
+    z_BxLxD = nn.RMSNorm(
+        param_dtype=cfg.param_dtype, epsilon=cfg.rmsnorm_epsilon
+    )(x_BxLxD)
     z_BxLxD = Mlp(cfg)(z_BxLxD)

     return x_BxLxD + z_BxLxD
@@ -242,19 +243,24 @@ def setup(self):
         num_embeddings=cfg.vocab_size,
         features=cfg.model_dim,
         embedding_init=cfg.embed_init,
+        dtype=cfg.compute_dtype,
+        param_dtype=cfg.param_dtype,
     )

     self.blocks = [TBlock(cfg) for _ in range(cfg.num_layers)]
-    self.out_ln = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)
+    self.out_ln = nn.RMSNorm(
+        param_dtype=cfg.param_dtype, epsilon=cfg.rmsnorm_epsilon
+    )

     # Output projection - tied to input embeddings if configured
     if cfg.tie_embeddings:
-      self.output_proj = lambda x: self.embed.attend(x.astype(jnp.float32))
+      self.output_proj = lambda x: self.embed.attend(x)
     else:
       self.output_proj = nn.Dense(
           cfg.vocab_size,
           kernel_init=cfg.embed_init,
-          dtype=cfg.dtype,
+          dtype=cfg.compute_dtype,
+          param_dtype=cfg.param_dtype,
           name='output_proj',
       )

@@ -357,6 +363,7 @@ def main():

   # Make a prediction (forward pass)
   print('\nRunning forward pass...')
+  params, x_BxL = cfg.mp_policy.cast_to_compute((params, x_BxL))
   logits = model.apply(params, x_BxL)

   # Print output shape and sample values
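As a follow-up sketch only (cast_to_output is part of the jmp Policy API but is not shown in this commit): the policy's output dtype would typically be enforced on the logits as well.

logits = cfg.mp_policy.cast_to_output(logits)  # bfloat16, per cfg.output_dtype
print(logits.shape, logits.dtype)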