Commit 1b88a2e

Merge pull request #827 from init-22/resolve_deprecations
Update Deprecated Functions
2 parents 21a3d03 + 785d82b commit 1b88a2e

38 files changed: +135 additions, -134 deletions


algorithmic_efficiency/data_utils.py

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ def _prepare(x):
     # Assumes that `global_batch_size % local_device_count == 0`.
     return x.reshape((local_device_count, -1, *x.shape[1:]))

-  return jax.tree_map(_prepare, batch)
+  return jax.tree.map(_prepare, batch)


 def pad(tensor: np.ndarray,
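This same one-line change recurs throughout the commit: jax.tree_map is deprecated in recent JAX releases in favor of jax.tree.map. A minimal sketch of the new spelling, assuming a JAX version that exposes the jax.tree namespace (the example batch is illustrative):

    import jax
    import numpy as np

    batch = {'inputs': np.zeros((8, 4)), 'targets': np.zeros((8,))}
    # jax.tree.map applies the function to every leaf of the pytree;
    # it replaces the deprecated jax.tree_map spelling.
    shapes = jax.tree.map(lambda x: x.shape, batch)
    # shapes == {'inputs': (8, 4), 'targets': (8,)}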

algorithmic_efficiency/param_utils.py

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ def pytorch_param_types(


 def jax_param_shapes(
     params: spec.ParameterContainer) -> spec.ParameterShapeTree:
-  return jax.tree_map(lambda x: spec.ShapeTuple(x.shape), params)
+  return jax.tree.map(lambda x: spec.ShapeTuple(x.shape), params)


 def jax_param_types(param_shapes: spec.ParameterShapeTree,

algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py

Lines changed: 1 addition & 1 deletion
@@ -207,4 +207,4 @@ def _normalize_eval_metrics(
       self, num_examples: int, total_metrics: Dict[str,
                                                    Any]) -> Dict[str, float]:
     """Normalize eval metrics."""
-    return jax.tree_map(lambda x: float(x[0] / num_examples), total_metrics)
+    return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics)

algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py

Lines changed: 1 addition & 1 deletion
@@ -264,7 +264,7 @@ def _eval_model_on_split(self,
         eval_metrics[metric_name] = 0.0
       eval_metrics[metric_name] += metric_value

-    eval_metrics = jax.tree_map(lambda x: float(x[0] / num_examples),
+    eval_metrics = jax.tree.map(lambda x: float(x[0] / num_examples),
                                 eval_metrics)
     return eval_metrics

algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py

Lines changed: 2 additions & 2 deletions
@@ -70,7 +70,7 @@ class Encoder1DBlock(nn.Module):
   def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
     if not self.use_post_layer_norm:
       y = nn.LayerNorm(name='LayerNorm_0')(x)
-      y = nn.SelfAttention(
+      y = nn.MultiHeadDotProductAttention(
           num_heads=self.num_heads,
           kernel_init=nn.initializers.xavier_uniform(),
           deterministic=train,
@@ -89,7 +89,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
       x = x + y
     else:
       y = x
-      y = nn.SelfAttention(
+      y = nn.MultiHeadDotProductAttention(
           num_heads=self.num_heads,
           kernel_init=nn.initializers.xavier_uniform(),
           deterministic=train,
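Flax deprecated nn.SelfAttention in favor of nn.MultiHeadDotProductAttention, which performs self-attention when called with a single input. A minimal sketch of the replacement module under a recent Flax release (sizes and shapes are illustrative):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    attn = nn.MultiHeadDotProductAttention(
        num_heads=4,
        kernel_init=nn.initializers.xavier_uniform(),
        deterministic=True)
    x = jnp.ones((2, 16, 32))                  # (batch, sequence, features)
    params = attn.init(jax.random.PRNGKey(0), x)
    y = attn.apply(params, x)                  # queries, keys, values all from x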

algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py

Lines changed: 3 additions & 3 deletions
@@ -396,10 +396,9 @@ def __call__(self, inputs, paddings, train):
         mask_paddings > 0, mask_paddings > 0, dtype=jnp.float32)

     inputs = LayerNorm(dim=config.encoder_dim)(inputs)
-
     attention_fn = functools.partial(
         dot_product_attention, temperature=config.attention_temperature)
-    result = nn.SelfAttention(
+    result = nn.MultiHeadDotProductAttention(
         num_heads=config.num_attention_heads,
         qkv_features=config.encoder_dim,
         decode=False,
@@ -410,7 +409,8 @@ def __call__(self, inputs, paddings, train):
         broadcast_dropout=False,
         attention_fn=attention_fn,
         dropout_rate=config.attention_dropout_rate,
-        deterministic=not train)(inputs, attention_mask)
+        deterministic=not train)(
+            inputs_q=inputs, mask=attention_mask)

     if config.attention_residual_dropout_rate is None:
       attention_residual_dropout_rate = 0.1
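Besides the module rename, the call itself changes: nn.SelfAttention accepted the mask as its second positional argument, whereas in nn.MultiHeadDotProductAttention the second positional slot is the key (or key/value) input, so the mask has to be passed by keyword. A minimal sketch with illustrative shapes:

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    attn = nn.MultiHeadDotProductAttention(
        num_heads=2, qkv_features=8, deterministic=True)
    x = jnp.ones((1, 5, 8))                    # (batch, time, features)
    mask = nn.make_attention_mask(jnp.ones((1, 5)), jnp.ones((1, 5)))
    params = attn.init(jax.random.PRNGKey(0), x, mask=mask)
    out = attn.apply(params, inputs_q=x, mask=mask)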

algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py

Lines changed: 6 additions & 5 deletions
@@ -227,11 +227,12 @@ def ctc_loss(self,
                labels: spec.Tensor,
                label_paddings: spec.Tensor,
                blank_id: int = 0) -> spec.Tensor:
-    return optax.ctc_loss(logits,
-                          logit_paddings,
-                          labels,
-                          label_paddings,
-                          blank_id)
+    return optax.ctc_loss(
+        logits=logits,
+        logit_paddings=logit_paddings,
+        labels=labels,
+        label_paddings=label_paddings,
+        blank_id=blank_id)

   # Adapted from lingvo's greedy decoding logic here:
   # https://github.com/tensorflow/lingvo/blob/2ee26814c57b7dcead3f0382170f2f3da006f810/lingvo/jax/layers/ctc_objectives.py#L138.
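Here the same optax.ctc_loss call is simply switched to keyword arguments. A minimal sketch with illustrative shapes (zeros in the padding arrays mark valid frames and label positions):

    import jax.numpy as jnp
    import optax

    batch, time, vocab, max_label_len = 2, 10, 5, 4
    logits = jnp.zeros((batch, time, vocab))
    logit_paddings = jnp.zeros((batch, time))          # 0 = valid frame, 1 = padded
    labels = jnp.ones((batch, max_label_len), dtype=jnp.int32)
    label_paddings = jnp.zeros((batch, max_label_len))
    loss = optax.ctc_loss(
        logits=logits,
        logit_paddings=logit_paddings,
        labels=labels,
        label_paddings=label_paddings,
        blank_id=0)                                    # per-example loss, shape (batch,)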

algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py

Lines changed: 1 addition & 1 deletion
@@ -132,4 +132,4 @@ def _normalize_eval_metrics(
       self, num_examples: int, total_metrics: Dict[str,
                                                    Any]) -> Dict[str, float]:
     """Normalize eval metrics."""
-    return jax.tree_map(lambda x: float(x[0] / num_examples), total_metrics)
+    return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics)

algorithmic_efficiency/workloads/ogbg/input_pipeline.py

Lines changed: 2 additions & 2 deletions
@@ -51,7 +51,7 @@ def _load_dataset(split, should_shuffle, data_rng, data_dir):


 def _to_jraph(example):
   """Converts an example graph to jraph.GraphsTuple."""
-  example = jax.tree_map(lambda x: x._numpy(), example)  # pylint: disable=protected-access
+  example = jax.tree.map(lambda x: x._numpy(), example)  # pylint: disable=protected-access
   edge_feat = example['edge_feat']
   node_feat = example['node_feat']
   edge_index = example['edge_index']
@@ -150,7 +150,7 @@ def _get_batch_iterator(dataset_iter, global_batch_size, num_shards=None):
     if count == num_shards:

       def f(x):
-        return jax.tree_map(lambda *vals: np.stack(vals, axis=0), x[0], *x[1:])
+        return jax.tree.map(lambda *vals: np.stack(vals, axis=0), x[0], *x[1:])

       graphs_shards = f(graphs_shards)
       labels_shards = f(labels_shards)
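The second call site uses the multi-tree form of jax.tree.map: the mapped function receives one leaf from each tree, which is what lets it stack per-shard arrays leaf by leaf. A minimal sketch with made-up shard contents:

    import jax
    import numpy as np

    shards = [{'a': np.zeros(3), 'b': np.ones(2)} for _ in range(4)]
    # Stack corresponding leaves across all four shard pytrees.
    stacked = jax.tree.map(lambda *vals: np.stack(vals, axis=0),
                           shards[0], *shards[1:])
    # stacked['a'].shape == (4, 3); stacked['b'].shape == (4, 2)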

algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py

Lines changed: 3 additions & 3 deletions
@@ -20,8 +20,8 @@


 def _pytorch_map(inputs: Any) -> Any:
   if USE_PYTORCH_DDP:
-    return jax.tree_map(lambda a: torch.as_tensor(a, device=DEVICE), inputs)
-  return jax.tree_map(
+    return jax.tree.map(lambda a: torch.as_tensor(a, device=DEVICE), inputs)
+  return jax.tree.map(
       lambda a: torch.as_tensor(a, device=DEVICE).view(-1, a.shape[-1])
       if len(a.shape) == 3 else torch.as_tensor(a, device=DEVICE).view(-1),
       inputs)
@@ -30,7 +30,7 @@ def _pytorch_map(inputs: Any) -> Any:
 def _shard(inputs: Any) -> Any:
   if not USE_PYTORCH_DDP:
     return inputs
-  return jax.tree_map(lambda tensor: tensor[RANK], inputs)
+  return jax.tree.map(lambda tensor: tensor[RANK], inputs)


 def _graph_map(function: Callable, graph: GraphsTuple) -> GraphsTuple:
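jax.tree.map only traverses the pytree structure, so the leaves here can be NumPy arrays or torch tensors rather than JAX arrays. A minimal sketch of the pattern with an illustrative batch; the DEVICE constant below is a stand-in for the device the workload actually selects:

    import jax
    import numpy as np
    import torch

    DEVICE = 'cpu'  # illustrative only
    batch = {'inputs': np.zeros((8, 4)), 'targets': np.zeros((8,))}
    # Convert every leaf of the batch pytree to a torch tensor on DEVICE.
    torch_batch = jax.tree.map(
        lambda a: torch.as_tensor(a, device=DEVICE), batch)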
