Commit 8da0c79

add mixed precision training
1 parent de6bb76 commit 8da0c79

File tree

6 files changed: +249 -122 lines

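The change threads a jmp mixed-precision policy through the model config and the JAX workload: parameters are stored in param_dtype, the forward pass runs in compute_dtype, and results are cast to output_dtype. As a minimal standalone sketch of how such a policy behaves (the array names and shapes below are illustrative, not taken from the diff):

import jax.numpy as jnp
import jmp

# Params kept in float32; compute and outputs in bfloat16, matching the new defaults.
policy = jmp.Policy(
  param_dtype=jnp.float32,
  compute_dtype=jnp.bfloat16,
  output_dtype=jnp.bfloat16,
)

params = {'w': jnp.ones((4, 4), dtype=jnp.float32)}
x = jnp.ones((2, 4), dtype=jnp.float32)

# cast_to_compute casts every floating-point leaf of a pytree to compute_dtype.
params_bf16, x_bf16 = policy.cast_to_compute((params, x))
y = x_bf16 @ params_bf16['w']  # forward pass runs in bfloat16
y = policy.cast_to_output(y)   # cast_to_output casts results to output_dtype

This cast-in / cast-out pattern is exactly what model_fn in the workload diff below adopts.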

algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/models.py

Lines changed: 49 additions & 42 deletions
@@ -8,6 +8,7 @@
 
 import jax
 import jax.numpy as jnp
+import jmp
 from flax import linen as nn
 
 
@@ -26,18 +27,24 @@ class ModelConfig:
   use_residual_scaling: bool = True
   tie_embeddings: bool = True  # Whether to tie input and output embed
   qknorm_epsilon: float = 1e-6
-
-  dtype: jnp.dtype = jnp.float32
   attention_init: nn.initializers.Initializer = nn.initializers.normal(
     stddev=0.02
   )
   linear_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02)
   embed_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02)
+  param_dtype: jnp.dtype = jnp.float32
+  compute_dtype: jnp.dtype = jnp.bfloat16
+  output_dtype: jnp.dtype = jnp.bfloat16
 
   def __post_init__(self):
     self.residual_init = nn.initializers.normal(
       stddev=0.02 / jnp.sqrt(2 * self.num_layers)
     )
+    self.mp_policy = jmp.Policy(
+      compute_dtype=self.compute_dtype,
+      param_dtype=self.param_dtype,
+      output_dtype=self.output_dtype,
+    )
 
 
 class Mlp(nn.Module):
@@ -49,7 +56,11 @@ class Mlp(nn.Module):
   def __call__(self, x_BxLxD: jax.Array):
     cfg = self.cfg
     linear = partial(
-      nn.Dense, kernel_init=cfg.linear_init, use_bias=False, dtype=cfg.dtype
+      nn.Dense,
+      kernel_init=cfg.linear_init,
+      use_bias=False,
+      dtype=cfg.compute_dtype,
+      param_dtype=cfg.param_dtype,
     )
     # Adjust hidden dimension to keep the number of parameters invariant to
     # the activation function used since the GLU MLP has 3 * hidden_dim * D
@@ -65,7 +76,8 @@ def __call__(self, x_BxLxD: jax.Array):
     x_BxLxD = nn.Dense(
       cfg.model_dim,
       use_bias=False,
-      dtype=cfg.dtype,
+      dtype=cfg.compute_dtype,
+      param_dtype=cfg.param_dtype,
       kernel_init=cfg.residual_init
       if cfg.use_residual_scaling
       else cfg.linear_init,
@@ -96,7 +108,7 @@ def apply_rope(q, k, freqs_cis):
 
   def rotate_tensor(x):
     # Split into real and imaginary parts
-    x_r2 = x.reshape(*x.shape[:-1], -1, 2)
+    x_r2 = x.reshape(*x.shape[:-1], -1, 2).astype(jnp.float32)
     L = x.shape[1]
     freqs = freqs_cis[:, :L, :, :, :]
 
@@ -109,7 +121,7 @@ def rotate_tensor(x):
       axis=-1,
     )
 
-    return rotated_x_r2.reshape(*x.shape)
+    return rotated_x_r2.reshape(*x.shape).astype(x.dtype)
 
   # Apply rotation to Q and K separately
   rotated_q = rotate_tensor(q)
@@ -141,7 +153,8 @@ def setup(self):
       features=(cfg.num_heads, self.Dh),
       kernel_init=cfg.attention_init,
       use_bias=False,
-      dtype=cfg.dtype,
+      dtype=cfg.compute_dtype,
+      param_dtype=cfg.param_dtype,
     )
     self.multilinear_query = self.multilinear(name='query')
     self.multilinear_key = self.multilinear(name='key')
@@ -150,7 +163,9 @@ def setup(self):
     seq_len = cfg.seq_len
     attn_scale0 = jnp.log2(seq_len**2 - seq_len)
     self.attn_scale = self.param(
-      'attn_scale', nn.initializers.constant(attn_scale0), ()
+      'attn_scale',
+      nn.initializers.constant(attn_scale0, dtype=cfg.compute_dtype),
+      (),
     )
     self.output_projection = nn.DenseGeneral(
       features=cfg.model_dim,
@@ -160,7 +175,8 @@ def setup(self):
       if cfg.use_residual_scaling
       else cfg.linear_init,
       use_bias=False,
-      dtype=cfg.dtype,
+      dtype=cfg.compute_dtype,
+      param_dtype=cfg.param_dtype,
     )
 
   def __call__(self, x_BxLxD: jax.Array):
@@ -177,32 +193,17 @@ def __call__(self, x_BxLxD: jax.Array):
     # Apply QK normalization
     q_BxLxHxDh /= jnp.linalg.norm(q_BxLxHxDh, axis=-1, keepdims=True) + self.eps
     k_BxLxHxDh /= jnp.linalg.norm(k_BxLxHxDh, axis=-1, keepdims=True) + self.eps
-
-    # Compute attention scores
-    att_BxHxLxL = jnp.einsum('...qhd,...khd->...hqk', q_BxLxHxDh, k_BxLxHxDh)
-
-    # Causal attention mask
-    L = x_BxLxD.shape[1]
-    mask_1x1xLxL = jnp.tril(jnp.ones((1, 1, L, L), dtype=jnp.bool_))
-
-    # Apply mask and softmax
-    _NEG_INF = jnp.finfo(cfg.dtype).min
-    att_BxHxLxL = jnp.where(mask_1x1xLxL, att_BxHxLxL, _NEG_INF)
-    att_BxHxLxL = (
-      self.attn_scale * att_BxHxLxL
-    )  # Learned scaling factor for QK norm
-    att_BxHxLxL = jax.nn.softmax(att_BxHxLxL, axis=-1)
-    att_BxHxLxL = att_BxHxLxL.astype(cfg.dtype)
-
-    # Compute attention output
-    out_BxLxHxDh = jnp.einsum('...hqk,...khd->...qhd', att_BxHxLxL, v_BxLxHxDh)
-
-    # Reshape and project output
+    q_BxLxHxDh *= self.attn_scale
+    out_BxLxHxDh = jax.nn.dot_product_attention(
+      query=q_BxLxHxDh,
+      key=k_BxLxHxDh,
+      value=v_BxLxHxDh,
+      is_causal=True,
+      scale=1.0,
+      implementation='cudnn' if cfg.compute_dtype is not jnp.float32 else None,
+    )
     out_BxLxD = out_BxLxHxDh.reshape(*x_BxLxD.shape)
-
-    # Output projection
     out_BxLxD = self.output_projection(out_BxLxD)
-
     return out_BxLxD
 
 
@@ -216,16 +217,16 @@ def __call__(self, in_BxLxD: jax.Array):
     cfg = self.docfg
 
     # x = x + attn( attn_norm(x) )
-    x_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)(
-      in_BxLxD
-    )
+    x_BxLxD = nn.RMSNorm(
+      param_dtype=cfg.param_dtype, epsilon=cfg.rmsnorm_epsilon
+    )(in_BxLxD)
     x_BxLxD = CausalAttn(cfg)(x_BxLxD)
     x_BxLxD += in_BxLxD
 
     # x = x + mlp( mlp_norm(x) )
-    z_BxLxD = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)(
-      x_BxLxD
-    )
+    z_BxLxD = nn.RMSNorm(
+      param_dtype=cfg.param_dtype, epsilon=cfg.rmsnorm_epsilon
+    )(x_BxLxD)
     z_BxLxD = Mlp(cfg)(z_BxLxD)
 
     return x_BxLxD + z_BxLxD
@@ -242,19 +243,24 @@ def setup(self):
       num_embeddings=cfg.vocab_size,
       features=cfg.model_dim,
       embedding_init=cfg.embed_init,
+      dtype=cfg.compute_dtype,
+      param_dtype=cfg.param_dtype,
     )
 
     self.blocks = [TBlock(cfg) for _ in range(cfg.num_layers)]
-    self.out_ln = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)
+    self.out_ln = nn.RMSNorm(
+      param_dtype=cfg.param_dtype, epsilon=cfg.rmsnorm_epsilon
+    )
 
     # Output projection - tied to input embeddings if configured
     if cfg.tie_embeddings:
-      self.output_proj = lambda x: self.embed.attend(x.astype(jnp.float32))
+      self.output_proj = lambda x: self.embed.attend(x)
     else:
      self.output_proj = nn.Dense(
        cfg.vocab_size,
        kernel_init=cfg.embed_init,
-        dtype=cfg.dtype,
+        dtype=cfg.compute_dtype,
+        param_dtype=cfg.param_dtype,
        name='output_proj',
      )
 
@@ -357,6 +363,7 @@ def main():
 
   # Make a prediction (forward pass)
   print('\nRunning forward pass...')
+  params, x_BxL = cfg.mp_policy.cast_to_compute((params, x_BxL))
   logits = model.apply(params, x_BxL)
 
   # Print output shape and sample values
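
The main rewrite in this file is CausalAttn.__call__: the explicit einsum / causal-mask / softmax path is replaced by jax.nn.dot_product_attention with is_causal=True, the learned attn_scale is folded into the queries so the fused call can use scale=1.0, and the cuDNN flash-attention kernel is requested whenever the compute dtype is not float32. A small equivalence sketch of that refactor, in float32 with implementation left unset so it runs anywhere (the shapes and the 2.5 scale are illustrative, not values from the diff):

import jax
import jax.numpy as jnp

B, L, H, Dh = 2, 8, 2, 4  # batch, sequence length, heads, head dim
q_key, k_key, v_key = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(q_key, (B, L, H, Dh))
k = jax.random.normal(k_key, (B, L, H, Dh))
v = jax.random.normal(v_key, (B, L, H, Dh))
attn_scale = 2.5  # stands in for the learned self.attn_scale parameter

# Old path: explicit scores, causal mask, learned scale on the logits, softmax.
att = attn_scale * jnp.einsum('...qhd,...khd->...hqk', q, k)
mask = jnp.tril(jnp.ones((1, 1, L, L), dtype=jnp.bool_))
att = jnp.where(mask, att, jnp.finfo(att.dtype).min)
out_ref = jnp.einsum('...hqk,...khd->...qhd', jax.nn.softmax(att, axis=-1), v)

# New path: fold the scale into q and let the fused kernel apply the causal mask.
out_new = jax.nn.dot_product_attention(
  query=q * attn_scale, key=k, value=v, is_causal=True, scale=1.0
)
assert jnp.allclose(out_ref, out_new, atol=1e-4)

Elsewhere the diff follows Flax's usual mixed-precision split: dtype on each module sets the computation dtype, while the new param_dtype keeps the stored parameters (including the RMSNorm scales) in float32, and apply_rope now rotates in float32 before casting back to the input dtype.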

algoperf/workloads/finewebedu_lm/finewebedu_lm_jax/workload.py

Lines changed: 46 additions & 3 deletions
@@ -1,9 +1,11 @@
 """LM workload implemented in Jax."""
 
+from functools import partial
 from typing import Any, Dict, Optional, Tuple
 
 import jax
 import jax.numpy as jnp
+import jmp
 
 from algoperf import jax_sharding_utils, param_utils, spec
 from algoperf.workloads.finewebedu_lm.finewebedu_lm_jax.models import (
@@ -13,10 +15,33 @@
 from algoperf.workloads.finewebedu_lm.input_pipeline import get_data_iter
 from algoperf.workloads.finewebedu_lm.workload import BaseLmWorkload
 
+replicated_sharding = jax_sharding_utils.get_replicate_sharding()
+batch_sharding = jax_sharding_utils.get_batch_dim_sharding()
+
+# Dtype mapping from string to JAX dtype
+DTYPE_MAP = {
+  'float32': jnp.float32,
+  'float16': jnp.float16,
+  'bfloat16': jnp.bfloat16,
+}
+
 
 class LmWorkload(BaseLmWorkload):
   """LM JAX workload."""
 
+  # Convert dtype strings from base class to JAX dtypes
+  @property
+  def _compute_dtype(self) -> Any:
+    return DTYPE_MAP[self._compute_dtype_str]
+
+  @property
+  def _param_dtype(self) -> Any:
+    return DTYPE_MAP[self._param_dtype_str]
+
+  @property
+  def _output_dtype(self) -> Any:
+    return DTYPE_MAP[self._output_dtype_str]
+
   def _build_input_queue(
     self,
     data_rng: jax.random.PRNGKey,
@@ -53,8 +78,14 @@ def init_model_fn(
       num_layers=self._n_layers,  # num layers
       vocab_size=self._vocab_size,
       expanded_model_dim=self._mlp_dim,  # feedforward dim
-      dtype=jnp.float32,
+      rmsnorm_epsilon=self._rmsnorm_epsilon,
+      qknorm_epsilon=self._qknorm_epsilon,
+      tie_embeddings=self._tie_embeddings,
+      param_dtype=self._param_dtype,
+      compute_dtype=self._compute_dtype,
+      output_dtype=self._output_dtype,
     )
+    self._mp_policy: jmp.Policy = cfg.mp_policy
     self._model = TransformerDo(cfg)
     input_shape = (1, self._seq_len)  # For token IDs
 
@@ -66,8 +97,7 @@ def init_model_fn(
     self._param_shapes = param_utils.jax_param_shapes(params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)
     params = jax_sharding_utils.replicate(params)
-    model_state = None
-    return params, model_state
+    return params, None
 
   def model_fn(
     self,
@@ -81,10 +111,12 @@ def model_fn(
   ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     del mode, rng, update_batch_norm, model_state, dropout_rate
     inputs = batch['inputs']
+    params, inputs = self._mp_policy.cast_to_compute((params, inputs))
     # Convert one-hot inputs to token IDs if needed
     if inputs.ndim == 3:  # one-hot encoded
       inputs = jnp.argmax(inputs, axis=-1)
     logits = self._model.apply({'params': params}, inputs)
+    logits = self._mp_policy.cast_to_output(logits)
     return logits, None
 
   def loss_fn(
@@ -139,6 +171,17 @@ def loss_fn(
       'per_example': per_example_losses,
     }
 
+  @partial(
+    jax.jit,
+    static_argnums=(0,),
+    in_shardings=(
+      replicated_sharding,
+      batch_sharding,
+      replicated_sharding,
+      replicated_sharding,
+    ),
+    out_shardings=(replicated_sharding),
+  )
   def _eval_batch(
     self,
     params: spec.ParameterContainer,
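
On the workload side, model_fn now casts params and inputs to the compute dtype on the way in and casts the logits to the output dtype on the way out, and _eval_batch is jitted directly with static_argnums=(0,) plus explicit input/output shardings built from jax_sharding_utils. Stripped of the sharding specifics, a minimal sketch of the static-self jit pattern (the Evaluator class and its scale attribute are hypothetical, purely for illustration):

from functools import partial

import jax
import jax.numpy as jnp


class Evaluator:
  def __init__(self, scale: float):
    self.scale = scale

  @partial(jax.jit, static_argnums=(0,))
  def _eval_batch(self, params, batch):
    # self is passed as a static argument: it must be hashable, and the
    # compiled function is cached per distinct instance; params and batch
    # are traced as usual.
    return self.scale * jnp.sum(params['w'] * batch)


ev = Evaluator(scale=2.0)
print(ev._eval_batch({'w': jnp.ones((4,))}, jnp.arange(4.0)))  # prints 12.0

Marking self static means its attribute values are baked in at trace time; with the default identity-based hash, mutating the instance later will not trigger a retrace.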
