
Commit f095d4b

Testing with linear model
1 parent b3ae647 commit f095d4b

File tree

7 files changed: +129 additions, -22 deletions


algoperf/workloads/lm/input_pipeline.py

Lines changed: 0 additions & 1 deletion
@@ -86,7 +86,6 @@ def batch_iterator():
     token_ids = jnp.stack([_tokenize(x) for x in doc['text']])
     tokens = jax.nn.one_hot(token_ids, num_classes=vocab_size)
     inputs, targets = tokens[:, :-1], tokens[:, 1:]
-    devices = jax.devices("gpu")
     inputs, targets = jax.device_put(inputs), jax.device_put(targets)
     yield inputs, targets

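The deleted line queried the GPU devices but never used the result: jax.device_put with no explicit device argument already places arrays on JAX's default device (accelerator if present, otherwise CPU). A standalone sketch of that behavior, not part of the commit:

import jax
import jax.numpy as jnp

x = jnp.ones((4, 8))
x = jax.device_put(x)  # no device given: placed on the default device
print(x.devices())     # e.g. {CudaDevice(id=0)} on a single-GPU machine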
algoperf/workloads/lm/lm_jax/models.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+from flax import linen as nn
+import jax.numpy as jnp
+
+class LinearModel(nn.Module):
+  vocab_size: int
+
+  @nn.compact
+  def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
+    x = nn.Dense(
+        512,
+        kernel_init=nn.initializers.normal(0.02),
+        bias_init=nn.initializers.zeros
+    )(inputs)
+    return nn.Dense(
+        self.vocab_size,
+        kernel_init=nn.initializers.normal(0.02),
+        bias_init=nn.initializers.zeros
+    )(x)
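A quick usage sketch for the new Flax module; the shapes mirror the tests added in this commit (the last input dimension equals vocab_size for one-hot inputs), and the concrete sizes are illustrative:

import jax
import jax.numpy as jnp
from algoperf.workloads.lm.lm_jax.models import LinearModel

model = LinearModel(vocab_size=50257)
dummy = jnp.ones((1, 512, 50257))       # (batch, seq_len, vocab_size) one-hot inputs
variables = model.init(jax.random.PRNGKey(0), dummy)  # {'params': ...}
logits = model.apply(variables, dummy)  # shape (1, 512, 50257)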

algoperf/workloads/lm/lm_jax/workload.py

Lines changed: 22 additions & 4 deletions
@@ -2,8 +2,12 @@

 from typing import Dict, Optional, Tuple

+import jax.numpy as jnp
+from flax import jax_utils
+from algoperf import param_utils
 from algoperf import spec
 from algoperf.workloads.lm.workload import BaseLmWorkload
+from algoperf.workloads.lm.lm_jax.models import LinearModel


 class LmWorkload(BaseLmWorkload):

@@ -14,18 +18,32 @@ def init_model_fn(
       rng: spec.RandomState,
       dropout_rate: Optional[float] = None,
       aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
-    """aux_dropout_rate is used as attention_dropout_rate."""
-    pass
+
+    model = LinearModel(vocab_size=self._vocab_size)
+    input_shape = (1, self._seq_len, self._vocab_size)
+    variables = model.init(rng, jnp.ones(input_shape, jnp.float32))
+    model_state, params = variables.pop('params')
+
+    self._param_shapes = param_utils.jax_param_shapes(params)
+    self._param_types = param_utils.jax_param_types(self._param_shapes)
+    model_state = jax_utils.replicate(model_state)
+    params = jax_utils.replicate(params)
+
+    return params, model_state

   def model_fn(
       self,
       params: spec.ParameterContainer,
-      augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+      batch: Dict[str, spec.Tensor],
       model_state: spec.ModelAuxiliaryState,
       mode: spec.ForwardPassMode,
       rng: spec.RandomState,
       update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
-    pass
+
+    del mode, rng, update_batch_norm  # Not used for linear model
+    inputs = batch['inputs']
+    logits = self._model.apply({'params': params, **model_state}, inputs)
+    return logits, model_state

   def _eval_batch(self,
                   params: spec.ParameterContainer,
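One caveat: model_fn reads self._model, but init_model_fn above only builds a local model and never stores it on the workload, so the forward pass would raise AttributeError as written. A minimal sketch of the presumably intended caching (an assumption, not part of the commit):

# Inside LmWorkload.init_model_fn, right after constructing the module:
model = LinearModel(vocab_size=self._vocab_size)
self._model = model  # cache the module so model_fn can call self._model.apply(...)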
algoperf/workloads/lm/lm_pytorch/models.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+import torch
+import torch.nn as nn
+
+class LinearLayer(nn.Module):
+  def __init__(self, vocab_size: int):
+    super().__init__()
+    self.bottleneck = nn.Linear(vocab_size, 512)
+    self.output = nn.Linear(512, vocab_size)
+    self.reset_parameters()
+
+  def reset_parameters(self):
+    nn.init.normal_(self.bottleneck.weight, std=0.02)
+    nn.init.zeros_(self.bottleneck.bias)
+    nn.init.normal_(self.output.weight, std=0.02)
+    nn.init.zeros_(self.output.bias)
+
+  def forward(self, x: torch.Tensor) -> torch.Tensor:
+    return self.output(self.bottleneck(x))
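A matching usage sketch for the PyTorch module; batch and sequence sizes are illustrative:

import torch
from algoperf.workloads.lm.lm_pytorch.models import LinearLayer

model = LinearLayer(vocab_size=50257)
inputs = torch.randn(2, 512, 50257)  # (batch, seq_len, vocab_size) as float
logits = model(inputs)               # shape (2, 512, 50257)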

algoperf/workloads/lm/lm_pytorch/workload.py

Lines changed: 28 additions & 4 deletions
@@ -5,10 +5,13 @@
 import jax
 import torch
 import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP

+from algoperf import param_utils
 from algoperf import pytorch_utils
 from algoperf import spec
 from algoperf.workloads.lm.workload import BaseLmWorkload
+from algoperf.workloads.lm.lm_pytorch.models import LinearLayer

 USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup()

@@ -21,18 +24,39 @@ def init_model_fn(
       rng: spec.RandomState,
       dropout_rate: Optional[float] = None,
       aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
-    """aux_dropout_rate is used as attention_dropout_rate."""
-    pass
+
+    if hasattr(self, '_model'):
+      self._model.reset_parameters()
+      return self._model, None
+
+    torch.manual_seed(rng[0])
+    self._model = LinearLayer(vocab_size=self._vocab_size)
+    self._param_shapes = param_utils.pytorch_param_shapes(self._model)
+    self._param_types = param_utils.pytorch_param_types(self._param_shapes)
+    self._model.to(DEVICE)
+
+    if N_GPUS > 1:
+      if USE_PYTORCH_DDP:
+        self._model = DDP(self._model, device_ids=[RANK], output_device=RANK)
+      else:
+        self._model = torch.nn.DataParallel(self._model)
+
+    return self._model, None

   def model_fn(
       self,
       params: spec.ParameterContainer,
-      augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+      batch: Dict[str, spec.Tensor],
       model_state: spec.ModelAuxiliaryState,
       mode: spec.ForwardPassMode,
       rng: spec.RandomState,
       update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
-    pass
+
+    del model_state, rng, update_batch_norm  # Not used for linear model
+    model = params
+    inputs = batch['inputs'].float()  # Convert one-hot to float
+    logits = model(inputs)
+    return logits, None

   def _build_input_queue(
       self,
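The hasattr(self, '_model') branch reuses one cached module across repeated init_model_fn calls and only re-draws its weights with reset_parameters(). Once the cached model has been wrapped in DDP or DataParallel, however, reset_parameters lives on the inner module rather than on the wrapper. A sketch of a guard that handles both cases (a hypothetical helper, not part of the commit):

def _reset_cached_model(self):
  model = self._model
  if isinstance(model, (DDP, torch.nn.DataParallel)):
    model = model.module  # unwrap to reach LinearLayer.reset_parameters
  model.reset_parameters()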
Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+import jax
+import jax.numpy as jnp
+import torch
+
+TEST_SEQ_LEN = 512
+
+def test_pytorch_linear():
+  from algoperf.workloads.lm.lm_pytorch.models import LinearLayer
+  vocab_size = 32000
+  model = LinearLayer(vocab_size)
+
+  batch_size = 8
+  seq_len = TEST_SEQ_LEN
+  inputs = torch.randn(batch_size, seq_len, vocab_size)
+  outputs = model(inputs)
+
+  assert outputs.shape == (batch_size, seq_len, vocab_size)
+  assert not torch.isnan(outputs).any()
+
+def test_jax_linear():
+  from algoperf.workloads.lm.lm_jax.models import LinearModel
+
+  vocab_size = 32000
+  seq_len = TEST_SEQ_LEN
+  batch_size = 8
+  model = LinearModel(vocab_size)
+  rng = jax.random.PRNGKey(0)
+  params = model.init(rng, jnp.ones((1, seq_len, vocab_size)))
+
+  inputs = jax.random.normal(rng, (batch_size, seq_len, vocab_size))
+  outputs = model.apply(params, inputs)
+
+  assert outputs.shape == (batch_size, seq_len, vocab_size)
+  assert not jnp.isnan(outputs).any()
+
+if __name__ == '__main__':
+  test_pytorch_linear()
+  test_jax_linear()
+  print("All tests passed!")

algoperf/workloads/lm/workload.py

Lines changed: 4 additions & 13 deletions
@@ -20,8 +20,8 @@
 class BaseLmWorkload(spec.Workload):
   """LM workload."""

-  _vocab_size: int = 32000
-  _seq_len: int = 2048
+  _vocab_size: int = 50257
+  _seq_len: int = 512

   def __init__(self) -> None:
     pass

@@ -106,24 +106,15 @@ def activation(self) -> str:
   def glu(self) -> bool:
     return True

+  @abc.abstractmethod
   def _build_input_queue(self,
                          data_rng: jax.random.PRNGKey,
                          split: str,
                          data_dir: str,
                          global_batch_size: int,
                          num_batches: Optional[int] = None,
                          repeat_final_dataset: bool = False):
-    ds = input_pipeline.get_lm_dataset(
-        data_rng,
-        split,
-        data_dir,
-        vocab_size=self._vocab_size,
-        global_batch_size=global_batch_size,
-        num_batches=num_batches,
-        repeat_final_dataset=repeat_final_dataset)
-
-    for batch in iter(ds):
-      yield batch
+    """Build an input queue for the given split."""

   @abc.abstractmethod
   def _eval_batch(self,