Merged
147 commits
ae48ccd
Use jax.jit for sharding initial steps
rka97 Nov 21, 2024
eb5cac7
Use jax.jit for adamw
rka97 Nov 21, 2024
82977da
Pass yapf checks
rka97 Dec 9, 2024
99545d4
CIFAR workload sharding
rka97 Dec 9, 2024
018711a
librispeech_conformer now running
rka97 Jan 7, 2025
fbeb5f1
fix formatting
rka97 Feb 5, 2025
6e4e7b0
shard default
rka97 Feb 5, 2025
4a2c02d
start imagenet
rka97 Feb 5, 2025
47beba1
remove bn sync in imagenet (jit handles it automatically)
rka97 Feb 5, 2025
3a18f19
ImageNet-ViT also works
rka97 Feb 6, 2025
bd0f565
Start working on WMT. OOM error
rka97 Feb 20, 2025
3044efb
post-rebase, still on wmt
rka97 Feb 20, 2025
e301c49
cache sharding fix
rka97 Feb 20, 2025
e5ed97a
Merge branch 'dev' into jit_switch
priyakasimbeg Feb 21, 2025
4fcf984
target_setting_algorithms sharding, compilation caching
rka97 Feb 21, 2025
d147e39
Update tests to correct batch size
rka97 Feb 21, 2025
a2b61be
yapf and isort checks..
rka97 Feb 21, 2025
be11c23
Merge branch 'jit_switch' of https://github.com/mlcommons/algorithmic…
priyakasimbeg Mar 6, 2025
e2a3b5f
Merge branch 'dev' into jit_switch
priyakasimbeg Mar 6, 2025
a80f4ec
switch fastmri from pmap to jit
priyakasimbeg Mar 7, 2025
c39ca51
migrate criteo workload
priyakasimbeg Mar 7, 2025
06377d9
update utils function used for sharding conformer
priyakasimbeg Mar 7, 2025
9cbe7d9
update conformer and deepspeech
priyakasimbeg Mar 8, 2025
c6ecd67
debugging
priyakasimbeg Mar 11, 2025
f35690d
debugging
priyakasimbeg Mar 12, 2025
848b50c
reformatting
priyakasimbeg Mar 18, 2025
fb62eae
reformatting
priyakasimbeg Mar 18, 2025
fe3f9f0
reformatting
priyakasimbeg Mar 18, 2025
004afbd
reformatting
priyakasimbeg Mar 18, 2025
f1db3d3
reformatting
priyakasimbeg Mar 18, 2025
c208cc7
sharding deepspeech
priyakasimbeg Mar 19, 2025
2e4cc9e
ogbg jit migration
priyakasimbeg Mar 19, 2025
d3a06fc
deepspeech jit changes
priyakasimbeg Mar 20, 2025
2cfa2a9
set jax to 0.5.1
priyakasimbeg Mar 20, 2025
70705a7
merge
priyakasimbeg Mar 20, 2025
75d6315
upgrade jax to 0.5.3
priyakasimbeg Apr 1, 2025
1df0690
change bsz back
priyakasimbeg Apr 1, 2025
c1d0c66
formatting
priyakasimbeg Apr 3, 2025
1b9466c
remove debugging statements from submission_runner.py
priyakasimbeg Apr 3, 2025
7a71cf0
pyproject.toml
priyakasimbeg Apr 3, 2025
9e1f337
clean up ogbg
priyakasimbeg Apr 3, 2025
a1d0abd
clean up ogbg
priyakasimbeg Apr 3, 2025
adb2b7e
Merge branch 'jit_switch' of github.com:mlcommons/algorithmic-efficie…
priyakasimbeg Apr 3, 2025
99caa03
clean up mnist workload.py
priyakasimbeg Apr 3, 2025
b14174b
refactoring & clean up
priyakasimbeg Apr 3, 2025
a3a9b9f
simplify changes in cifar jax
priyakasimbeg Apr 3, 2025
0a340a2
small fix
priyakasimbeg Apr 3, 2025
60c1cce
rename sharding utils
priyakasimbeg Apr 3, 2025
1edb724
fix sharding rename
priyakasimbeg Apr 3, 2025
49864fb
refactoring
priyakasimbeg Apr 3, 2025
7820ac6
modifications to cifar
priyakasimbeg Apr 4, 2025
0a2043c
fix
priyakasimbeg Apr 5, 2025
95037bf
clean up and small fixes
priyakasimbeg Apr 5, 2025
e79c761
add test for sharding invariance
priyakasimbeg Apr 5, 2025
110e792
fix
priyakasimbeg Apr 8, 2025
9c91c65
Update pyproject.toml
priyakasimbeg Apr 14, 2025
21bb997
Update workload.py
priyakasimbeg Apr 14, 2025
eb56919
Update workload.py
priyakasimbeg Apr 14, 2025
c489749
Merge branch 'jit_switch' of github.com:mlcommons/algorithmic-efficie…
priyakasimbeg Apr 14, 2025
1277cc2
upgrade jax
priyakasimbeg May 19, 2025
def4ac5
update dockerfile
priyakasimbeg May 19, 2025
450cbee
remove extra installs
priyakasimbeg May 19, 2025
89718e7
update jax version
priyakasimbeg May 20, 2025
7dcf5af
update install commands for pytorch cpu only
priyakasimbeg May 20, 2025
4335688
update dockerfile
priyakasimbeg May 20, 2025
8d1fe7e
update dockerfile
priyakasimbeg May 20, 2025
240e2e5
update dockerfile
priyakasimbeg May 20, 2025
cc8d604
update dockerfile
priyakasimbeg May 20, 2025
fe56eaf
update dockerfile
priyakasimbeg May 20, 2025
de4c38b
modify initial model_state
priyakasimbeg May 27, 2025
5b7fb31
docker build script change
priyakasimbeg May 27, 2025
57b8fe6
temporarily use pre-releases for jax install
priyakasimbeg May 27, 2025
e23e99a
fix to pyproject.toml
priyakasimbeg May 28, 2025
505fab2
change defaults for job config script
priyakasimbeg May 29, 2025
4acaffe
fix docker image
priyakasimbeg Jun 4, 2025
3481f0e
jax deprecation fix for jax.tree_map
priyakasimbeg Jun 4, 2025
447d621
try to fix jax installation
priyakasimbeg Jun 4, 2025
a3df78c
temporary pip install change for jax gpu nightly
priyakasimbeg Jun 4, 2025
8aa3ffc
add step_time to summary df
priyakasimbeg Jun 7, 2025
1cc068a
capture trace
priyakasimbeg Jun 10, 2025
274a911
add flag to skip evals
priyakasimbeg Jun 10, 2025
2580f5c
add log dir to save traces to
priyakasimbeg Jun 12, 2025
00d3810
remove editable flag from docker install for ml packages
priyakasimbeg Jun 12, 2025
c87d908
add cpu version for pytorch package to pyproject.toml
priyakasimbeg Jun 12, 2025
12f1a87
merge
priyakasimbeg Jun 12, 2025
f387724
decrease logging frequency
priyakasimbeg Jun 12, 2025
f4c6072
fix pyproject.toml
priyakasimbeg Jun 12, 2025
8616a64
fix
priyakasimbeg Jun 12, 2025
89ddb7f
update dockerfile
priyakasimbeg Jun 12, 2025
993fe6f
update installation instructions
priyakasimbeg Jun 12, 2025
20b726a
use jraph.batch_np instead of jraph.batch since jraph.batch with jnp …
priyakasimbeg Jul 2, 2025
9d1f915
modify documentation
priyakasimbeg Jul 2, 2025
3486145
add plot util to visualize training metrics with wandb
priyakasimbeg Jul 3, 2025
5cbb368
expand plot_curves.py to take entire directory
priyakasimbeg Jul 9, 2025
c0edfbe
plot utils fixes
priyakasimbeg Jul 9, 2025
cdf35d8
temporarily use old stephints
priyakasimbeg Jul 9, 2025
8f6648c
add wandb
priyakasimbeg Jul 9, 2025
582951c
remove unused baselines that rely on pmap
priyakasimbeg Jul 10, 2025
b689a98
move prize qualification baselines
priyakasimbeg Jul 10, 2025
2c653f8
migrate qualification baselines to use jit
priyakasimbeg Jul 11, 2025
eb009ac
formatting
priyakasimbeg Jul 11, 2025
45a7fbe
documentation update
priyakasimbeg Jul 11, 2025
7dff874
removed unused code
priyakasimbeg Jul 11, 2025
c028f94
fix
priyakasimbeg Jul 11, 2025
55204a6
fix
priyakasimbeg Jul 11, 2025
4f1c43e
fix
priyakasimbeg Jul 11, 2025
b952422
fix name
priyakasimbeg Jul 11, 2025
5f46ec7
fix
priyakasimbeg Jul 11, 2025
3593463
add legacy LSTM layer to Deepspeech
priyakasimbeg Jul 15, 2025
7f35327
swap out lstm layer
priyakasimbeg Jul 17, 2025
86e6379
pin jax version
priyakasimbeg Jul 30, 2025
378f76c
pin jax version
priyakasimbeg Jul 30, 2025
db652ee
pin jax to 0.6.2
priyakasimbeg Jul 30, 2025
3b5a623
fix jax version
priyakasimbeg Jul 31, 2025
8e98702
pin cudnn version
priyakasimbeg Jul 31, 2025
d47c70a
add script to export runs to wandb
priyakasimbeg Aug 5, 2025
6634043
merge from dev
priyakasimbeg Aug 9, 2025
579ebc1
fix formatting for ruff
priyakasimbeg Aug 9, 2025
9d3c2d8
small fixes in make_job_config.py
priyakasimbeg Aug 9, 2025
0eba94b
fix format
priyakasimbeg Aug 9, 2025
64255e2
fix
priyakasimbeg Aug 12, 2025
fb3bb33
delete incomplete sharding test
priyakasimbeg Aug 12, 2025
d814fc7
linting
priyakasimbeg Aug 12, 2025
993949a
linting
priyakasimbeg Aug 12, 2025
be88d01
fix
priyakasimbeg Aug 12, 2025
5859935
remove unmaintained baselines
priyakasimbeg Aug 12, 2025
462e1a5
fix sharding for ogbg pytorch
priyakasimbeg Aug 12, 2025
c337cc4
fix
priyakasimbeg Aug 12, 2025
2717519
reformatting
priyakasimbeg Aug 12, 2025
91912cc
fix
priyakasimbeg Aug 12, 2025
ecbf90e
format
priyakasimbeg Aug 12, 2025
cfd4ec9
fix reference algorithm test for ogbg pytorch
priyakasimbeg Aug 12, 2025
df20d97
fix
priyakasimbeg Aug 12, 2025
5f076ac
factor out more array reshaping for pytorch workloads
priyakasimbeg Aug 12, 2025
e0ed0a2
factor out sharding from data_utils
priyakasimbeg Aug 12, 2025
26a77e9
fix
priyakasimbeg Aug 12, 2025
abfa9ee
fix pytorch input pipelines
priyakasimbeg Aug 12, 2025
af44acd
test fixes
priyakasimbeg Aug 13, 2025
17d69e5
revert step hints for speech workloads
priyakasimbeg Aug 13, 2025
bb73fef
fix imagenet
priyakasimbeg Aug 15, 2025
655e031
fix speech workloads
priyakasimbeg Aug 15, 2025
a7403ed
fix wmt
priyakasimbeg Aug 15, 2025
4573499
fix in imagenet_v2 data pipeline
priyakasimbeg Aug 15, 2025
6715342
fix wmt jax
priyakasimbeg Aug 18, 2025
e0e225d
fix ogbg pytorch
priyakasimbeg Aug 18, 2025
21a196a
fix conformer pytorch
priyakasimbeg Aug 18, 2025
43d2191
fix to ogbg CI test
priyakasimbeg Aug 19, 2025
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
@@ -180,7 +180,7 @@ jobs:
pip install -e .
python tests/reference_algorithm_tests.py --workload=ogbg --framework=pytorch --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/pytorch_nesterov.py --tuning_search_space=reference_algorithms/target_setting_algorithms/ogbg/tuning_search_space.json
python tests/reference_algorithm_tests.py --workload=ogbg --framework=jax --global_batch_size=8 --submission_path=reference_algorithms/target_setting_algorithms/jax_nesterov.py --tuning_search_space=reference_algorithms/target_setting_algorithms/ogbg/tuning_search_space.json
pytest:
pytest-params:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
12 changes: 4 additions & 8 deletions README.md
@@ -57,20 +57,16 @@ You can install this package and dependencies in a [Python virtual environment](
We recommend using a Docker container (or alternatively, a Singularity/Apptainer container) to ensure a similar environment to our scoring and testing environments.
Both options are described in detail in the [**Getting Started**](/docs/GETTING_STARTED.md) document.

_TL;DR to install the Jax version for GPU run:_
*TL;DR to install the Jax version for GPU and all workload dependencies run:*

```bash
pip3 install -e '.[pytorch_cpu]'
pip3 install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html'
pip3 install -e '.[full]'
pip3 install -e '.[pytorch_cpu,jax_gpu,full]' --extra-index-url https://download.pytorch.org/whl/cpu
```

_TL;DR to install the PyTorch version for GPU run:_
*TL;DR to install the PyTorch version for GPU and all workload dependencies run:*

```bash
pip3 install -e '.[jax_cpu]'
pip3 install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121'
pip3 install -e '.[full]'
pip3 install -e '.[jax_cpu,pytorch_gpu,full]'
```

## Getting Started
4 changes: 0 additions & 4 deletions algoperf/checkpoint_utils.py
@@ -7,7 +7,6 @@
import os
from typing import Sequence, Tuple

import jax
import numpy as np
import torch
from absl import logging
@@ -210,10 +209,7 @@ def save_checkpoint(
train_state, eval_results, global_step, preemption_count).
"""
if framework == 'jax':
model_params = jax.device_get(jax_utils.unreplicate(model_params))
opt_state, _ = optimizer_state
opt_state = jax.device_get(jax_utils.unreplicate(opt_state))
model_state = jax.device_get(jax_utils.unreplicate(model_state))
else:
if isinstance(
model_params,
13 changes: 9 additions & 4 deletions algoperf/data_utils.py
@@ -62,14 +62,19 @@ def _prepare(x):
if remainder_size != 0 or pad_to_global_batch_size:
x = pad(x, pad_size, padding_value=padding_value)

# Reshape (global_batch_size, ...) to
# (local_device_count, per_device_batch_size, ...).
# Assumes that `global_batch_size % local_device_count == 0`.
return x.reshape((local_device_count, -1, *x.shape[1:]))
# return x.reshape((local_device_count, -1, *x.shape[1:]))
return x

return jax.tree.map(_prepare, batch)


def shard(batch):
local_device_count = max(torch.cuda.device_count(), jax.local_device_count())
return jax.tree.map(
lambda x: x.reshape((local_device_count, -1, *x.shape[1:])), batch
)


def pad(
tensor: np.ndarray, pad_size: int, padding_value: int = 0
) -> np.ndarray:
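For orientation, here is a rough sketch of what the new `data_utils.shard` helper above does to a batch: every leaf is reshaped from a global batch layout to a per-device layout. The device count and shapes below are illustrative assumptions, not values from the PR.

```python
import numpy as np

# Minimal sketch of the reshape performed by data_utils.shard, assuming
# 4 local devices and a global batch size of 32 (illustrative values only).
local_device_count = 4  # stands in for max(torch.cuda.device_count(), jax.local_device_count())
batch = {
    'inputs': np.zeros((32, 224, 224, 3)),
    'targets': np.zeros((32,)),
}

sharded = {
    k: v.reshape((local_device_count, -1, *v.shape[1:])) for k, v in batch.items()
}
print(sharded['inputs'].shape)   # (4, 8, 224, 224, 3)
print(sharded['targets'].shape)  # (4, 8)
```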
40 changes: 40 additions & 0 deletions algoperf/jax_sharding_utils.py
@@ -0,0 +1,40 @@
"""Utilities for dealing with sharding in JAX."""

import jax
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P


def get_replicate_sharding():
"""Returns a sharding spec that replicates data across all devices."""
mesh = jax.sharding.Mesh(jax.devices(), ('batch',))
return NamedSharding(mesh, P())


def get_batch_dim_sharding():
"""Returns a sharding spec that shards data along the first axis."""
mesh = jax.sharding.Mesh(jax.devices(), ('batch',))
return NamedSharding(mesh, P('batch'))


def shard_along_batch_dim(x):
"""Shards a tensor across all devices."""
mesh = jax.sharding.Mesh(jax.devices(), ('batch',))
return jax.tree.map(
lambda x: jax.device_put(x, NamedSharding(mesh, P('batch'))), x
)


def replicate(x):
"""Replicates tensor across all devices."""
mesh = jax.sharding.Mesh(jax.devices(), ('batch',))
return jax.tree.map(lambda x: jax.device_put(x, NamedSharding(mesh, P())), x)


def display_shard_info(x: jax.Array):
"""Displays shard info of a jax array."""
for shard in x.addressable_shards:
print(
f'shard.device: {shard.device}, index: {shard.index}, replica_id:'
f' {shard.replica_id}.\n'
)
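A minimal usage sketch of the helpers in this new module, assuming an editable install of the repo (`pip install -e .`) and a global batch size that divides the local device count; the toy shapes are illustrative, not taken from the PR.

```python
import jax
import jax.numpy as jnp

from algoperf import jax_sharding_utils  # assumes an editable install of this repo

# Replicate model parameters: every device holds a full copy.
params = jax_sharding_utils.replicate({'w': jnp.ones((16, 4))})

# Shard a global batch along its leading (batch) dimension across devices.
batch = jax_sharding_utils.shard_along_batch_dim(jnp.ones((32, 16)))

# Inspect how the batch is laid out across devices.
jax_sharding_utils.display_shard_info(batch)

# The sharding specs are also meant to be passed to jax.jit, as the workload
# diffs below do, e.g.:
#   jax.jit(fn,
#           in_shardings=(jax_sharding_utils.get_replicate_sharding(),
#                         jax_sharding_utils.get_batch_dim_sharding()))
```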
59 changes: 36 additions & 23 deletions algoperf/workloads/cifar/cifar_jax/workload.py
@@ -7,12 +7,11 @@
import jax.numpy as jnp
import optax
import tensorflow_datasets as tfds
from flax import jax_utils
from flax import linen as nn
from flax.core import pop
from jax import lax

from algoperf import param_utils, spec
from algoperf import jax_sharding_utils, param_utils, spec
from algoperf.workloads.cifar.cifar_jax import models
from algoperf.workloads.cifar.cifar_jax.input_pipeline import create_input_iter
from algoperf.workloads.cifar.workload import BaseCifarWorkload
@@ -29,6 +28,7 @@ def _build_cifar_dataset(
repeat_final_dataset: Optional[bool] = None,
) -> Iterator[Dict[str, spec.Tensor]]:
ds_builder = tfds.builder('cifar10:3.0.2', data_dir=data_dir)
ds_builder.download_and_prepare()
train = split == 'train'
assert self.num_train_examples + self.num_validation_examples == 50000
if split in ['train', 'eval_train']:
@@ -89,8 +89,8 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
model_state, params = pop(variables, 'params')
self._param_shapes = param_utils.jax_param_shapes(params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
model_state = jax_utils.replicate(model_state)
params = jax_utils.replicate(params)
model_state = jax_sharding_utils.replicate(model_state)
params = jax_sharding_utils.replicate(params)
return params, model_state

def is_output_params(self, param_key: spec.ParameterKey) -> bool:
@@ -171,15 +171,8 @@ def _compute_metrics(
'loss': summed_loss,
'accuracy': accuracy,
}
metrics = lax.psum(metrics, axis_name='batch')
return metrics

@functools.partial(
jax.pmap,
axis_name='batch',
in_axes=(None, 0, 0, 0, None),
static_broadcasted_argnums=(0,),
)
def _eval_model(
self,
params: spec.ParameterContainer,
@@ -188,21 +181,41 @@ def _eval_model(
rng: spec.RandomState,
) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
"""Return the mean accuracy and loss as a dict."""
logits, _ = self.model_fn(
params,
batch,
model_state,
spec.ForwardPassMode.EVAL,
rng,
update_batch_norm=False,

@functools.partial(
jax.jit,
in_shardings=(
jax_sharding_utils.get_replicate_sharding(), # params
jax_sharding_utils.get_batch_dim_sharding(), # batch
jax_sharding_utils.get_replicate_sharding(), # model_state
jax_sharding_utils.get_batch_dim_sharding(), # rng
),
)
weights = batch.get('weights')
if weights is None:
weights = jnp.ones(len(logits))
return self._compute_metrics(logits, batch['targets'], weights)
def _eval_model_jitted(
params: spec.ParameterContainer,
batch: Dict[str, spec.Tensor],
model_state: spec.ModelAuxiliaryState,
rng: spec.RandomState,
) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
"""Return the mean accuracy and loss as a dict."""
logits, _ = self.model_fn(
params,
batch,
model_state,
spec.ForwardPassMode.EVAL,
rng,
update_batch_norm=False,
)
weights = batch.get('weights')
if weights is None:
weights = jnp.ones(len(logits))
return self._compute_metrics(logits, batch['targets'], weights)

metrics = _eval_model_jitted(params, batch, model_state, rng)
return jax.tree.map(lambda x: x.item(), metrics)

def _normalize_eval_metrics(
self, num_examples: int, total_metrics: Dict[str, Any]
) -> Dict[str, float]:
"""Normalize eval metrics."""
return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics)
return jax.tree.map(lambda x: x / num_examples, total_metrics)
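Stepping back, the CIFAR diff above illustrates the migration pattern this PR applies across the JAX workloads: replace `jax.pmap` (per-device leading axis, explicit `lax.pmean`/`psum`) with a single `jax.jit` carrying sharding annotations. The sketch below contrasts the two on a toy loss; the function names and shapes are illustrative assumptions, not code from the PR.

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), ('batch',))
replicated = NamedSharding(mesh, P())
batch_sharded = NamedSharding(mesh, P('batch'))


def loss_fn(params, batch):
  return jnp.mean((batch['x'] @ params['w'] - batch['y']) ** 2)


# Old pattern: pmap over a leading device axis with an explicit cross-device mean.
pmapped_loss = jax.pmap(
    lambda p, b: jax.lax.pmean(loss_fn(p, b), axis_name='batch'),
    axis_name='batch',
)

# New pattern: one jit over the global batch; the compiler inserts the
# cross-device reductions based on the sharding annotations.
jitted_loss = jax.jit(
    loss_fn, in_shardings=(replicated, batch_sharded), out_shardings=replicated
)

n_dev = jax.local_device_count()
params = {'w': jnp.ones((8, 1))}
batch = {'x': jnp.ones((4 * n_dev, 8)), 'y': jnp.zeros((4 * n_dev, 1))}

# pmap needs params replicated and the batch reshaped by hand.
print(pmapped_loss(
    jax.tree.map(lambda x: jnp.stack([x] * n_dev), params),
    jax.tree.map(lambda x: x.reshape((n_dev, -1, *x.shape[1:])), batch),
)[0])

# jit takes the global batch directly.
print(jitted_loss(params, batch))
```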
46 changes: 35 additions & 11 deletions algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py
@@ -6,9 +6,8 @@
import jax
import jax.numpy as jnp
import numpy as np
from flax import jax_utils

from algoperf import param_utils, spec
from algoperf import jax_sharding_utils, param_utils, spec
from algoperf.workloads.criteo1tb.criteo1tb_jax import models
from algoperf.workloads.criteo1tb.workload import BaseCriteo1TbDlrmSmallWorkload

@@ -106,7 +105,7 @@ def init_model_fn(
initial_params = initial_variables['params']
self._param_shapes = param_utils.jax_param_shapes(initial_params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
return jax_utils.replicate(initial_params), None
return jax_sharding_utils.replicate(initial_params), None

def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'Dense_7'
@@ -132,13 +131,40 @@ def model_fn(
logits_batch = self._model.apply({'params': params}, inputs, **apply_kwargs)
return logits_batch, None

def _build_input_queue(
self,
data_rng: spec.RandomState,
split: str,
data_dir: str,
global_batch_size: int,
cache: Optional[bool] = None,
repeat_final_dataset: Optional[bool] = None,
num_batches: Optional[int] = None,
):
it = super()._build_input_queue(
data_rng,
split,
data_dir,
global_batch_size,
cache,
repeat_final_dataset,
num_batches,
)
f = functools.partial(
jax.device_put, device=jax_sharding_utils.get_batch_dim_sharding()
)
return map(f, it)

@functools.partial(
jax.pmap,
axis_name='batch',
in_axes=(None, 0, 0),
static_broadcasted_argnums=(0,),
jax.jit,
in_shardings=(
jax_sharding_utils.get_replicate_sharding(),
jax_sharding_utils.get_batch_dim_sharding(),
),
static_argnums=(0,),
out_shardings=jax_sharding_utils.get_replicate_sharding(),
)
def _eval_batch_pmapped(
def _eval_batch_jitted(
self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor]
) -> spec.Tensor:
logits, _ = self.model_fn(
@@ -162,9 +188,7 @@ def _eval_batch(
) -> spec.Tensor:
# We do NOT psum inside of _eval_batch_pmapped, so the returned tensor of
# shape (local_device_count,) will all be different values.
return np.array(
self._eval_batch_pmapped(params, batch).sum(), dtype=np.float64
)
return np.array(self._eval_batch_jitted(params, batch), dtype=np.float64)


class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload):
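The `_build_input_queue` override above lazily places each host batch onto the devices, sharded along the batch dimension, by mapping `jax.device_put` over the iterator. Below is a minimal sketch of that pattern with a toy generator standing in for the real dataset pipeline; the names and shapes are illustrative assumptions, and it presumes an editable install of the repo.

```python
import functools

import jax
import jax.numpy as jnp

from algoperf import jax_sharding_utils  # assumes an editable install of this repo


def host_batches(num_batches, global_batch_size, feature_dim):
  """Toy stand-in for the base workload's input queue."""
  for _ in range(num_batches):
    yield {
        'inputs': jnp.ones((global_batch_size, feature_dim)),
        'targets': jnp.zeros((global_batch_size,)),
    }


# Same device_put-over-iterator pattern as the override above.
place = functools.partial(
    jax.device_put, device=jax_sharding_utils.get_batch_dim_sharding()
)
sharded_iter = map(place, host_batches(3, 32, 13))

for batch in sharded_iter:
  print(batch['inputs'].sharding)  # NamedSharding over the 'batch' mesh axis
```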
algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py
@@ -7,7 +7,7 @@
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from algoperf import param_utils, spec
from algoperf import data_utils, param_utils, spec
from algoperf.pytorch_utils import pytorch_setup
from algoperf.workloads.criteo1tb.criteo1tb_pytorch import models
from algoperf.workloads.criteo1tb.workload import BaseCriteo1TbDlrmSmallWorkload
@@ -152,6 +152,7 @@ def _build_input_queue(
num_batches=num_batches,
repeat_final_dataset=repeat_final_dataset,
)
np_iter = map(data_utils.shard, np_iter)
weights = None
while True:
if RANK == 0: