Merged

30 commits
e6c2106
added prepare_for_eval, eval only if is_time_remaining
Niccolo-Ajroldi Sep 12, 2024
8bad99d
added prepare_for_eval to all submissions
Niccolo-Ajroldi Sep 14, 2024
1c7d51c
fix formatting
Niccolo-Ajroldi Sep 15, 2024
21a580b
fix formatting
Niccolo-Ajroldi Sep 15, 2024
420b583
updated documentation
Niccolo-Ajroldi Sep 15, 2024
d9c4ee9
add prepare_for_eval to spec.py
Niccolo-Ajroldi Oct 18, 2024
9caedc5
make prepare_for_eval backward compatible
Niccolo-Ajroldi Oct 21, 2024
4d74d2c
optional prepare_for_eval arg
Niccolo-Ajroldi Oct 21, 2024
364ce41
Merge branch 'dev' into prepare_for_eval
Niccolo-Ajroldi Oct 31, 2024
8cc4f4a
default dropout rates for workloads are added
init-22 Oct 31, 2024
a6fc879
adding the dropout info in fixed workload section
init-22 Oct 31, 2024
1983899
removing bold headings
init-22 Oct 31, 2024
e16ebe0
fix: changing the dtype in random_utils to uint32
init-22 Nov 14, 2024
42da4fd
feat: package updates with python 3.11
init-22 Nov 14, 2024
1005776
fix: reverting the python311 changes
init-22 Nov 14, 2024
76b084b
fix: removed cifar10 and mnist
init-22 Nov 14, 2024
b5ad298
fix: changing PRNGkey in random_utils to key
init-22 Nov 15, 2024
9819868
fix: changing the range of MAX_UINT32
init-22 Nov 16, 2024
4b2e64e
bringing back PRNGKey instead of key, till the python311 branch is me…
init-22 Nov 16, 2024
f72028f
fix: ran yapf for passing the checks
init-22 Nov 19, 2024
2b8b771
fix: ran yapf for passing the checks
init-22 Nov 19, 2024
d8f07b7
fix: triggering the checks again
init-22 Nov 22, 2024
ff176d7
fix: triggering the checks again
init-22 Nov 22, 2024
579a485
fix pytorch_default_init()
EIFY Nov 26, 2024
9d37d3e
Merge pull request #806 from init-22/feat_default_dropout_in_doc
priyakasimbeg Dec 4, 2024
ea66793
Merge pull request #789 from Niccolo-Ajroldi/prepare_for_eval
priyakasimbeg Dec 5, 2024
90959e1
Merge pull request #810 from init-22/fix-unit32
priyakasimbeg Dec 12, 2024
fe90379
Merge pull request #819 from EIFY/torch-init-fix
priyakasimbeg Dec 12, 2024
fc526a4
modify limits of ints
priyakasimbeg Jan 18, 2025
6c8fd56
Merge pull request #834 from mlcommons/random_utils
priyakasimbeg Jan 18, 2025
46 changes: 43 additions & 3 deletions DOCUMENTATION.md
@@ -80,7 +80,7 @@ In principle, submissions are allowed to use the available hardware systems in a
Submissions provide a [per-workload batch size](#batch-size-getter) to use. Specification of the batch size for each workload is necessary to avoid running out of memory for different workloads. Therefore, submitters can determine this batch size in advance and specify it as part of the submission. Submitters may also provide per-workload batch sizes for all [randomized workloads](#randomized-workloads). If no such batch size is provided for a randomized workload, by default, submissions will then use the batch size of the most similar [fixed workload](#fixed-workloads) (for example, if there is an ImageNet fixed workload and also a randomized workload with a similarly sized model on similarly sized images, the ImageNet batch size will be used for held-out workloads generated from this randomized workload).
Note that submitters are *not* allowed to modify the *evaluation batch size*, which is set by the benchmarking codebase. However, you can file an issue if you believe that the evaluation batch size of a particular workload is set inappropriately. The working group will review this request and consider adjusting the evaluation batch size in the benchmarking codebase, thus affecting all submitters equally.
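
For illustration, a batch size getter might look like the following sketch. The workload names follow the fixed-workload naming used elsewhere in this document, while the batch size values are placeholders chosen by the submitter, not values prescribed by the benchmark.

```python
def get_batch_size(workload_name):
  # Illustrative per-workload global batch sizes, chosen by the submitter to
  # fit in memory on the benchmark hardware (the values here are placeholders).
  batch_sizes = {
      'criteo1tb': 262_144,
      'fastmri': 32,
      'imagenet_resnet': 1024,
      'imagenet_vit': 1024,
      'librispeech_conformer': 256,
      'librispeech_deepspeech': 256,
      'ogbg': 512,
      'wmt': 128,
  }
  return batch_sizes[workload_name]
```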

The **submission functions** are the *batch size getter*, *optimizer state initializer*, *variable update*, and *data selection functions*. The *fixed functions* are the *data augmentation/preprocessing*, *model initialization*, *forward pass*, and *loss function*. The trained model will be evaluated in a separate step that does not call any of the submitted code.
The **submission functions** are the *batch size getter*, *optimizer state initializer*, *variable update*, *prepare for evaluation function*, and *data selection functions*. The *fixed functions* are the *data augmentation/preprocessing*, *model initialization*, *forward pass*, and *loss function*. The trained model will be evaluated in a separate step that does not call any of the submitted code.

##### Fixed functions

@@ -220,9 +220,35 @@ def update_params(
- Cannot modify the given hyperparameters in a workload-conditional way (please see the [Valid submission](#valid-submissions) section). This rule is intended to prohibit circumventing the tuning rules by looking up a pre-tuned optimal set of hyperparameters for each workload. It is not intended to prohibit line searches and other similar techniques.
- The fixed `init_model_fn` can optionally be called during training, for example, to reinitialize the model after a failed training effort.
- Cannot replace the model parameters with pre-trained ones.
- This API supports Polyak averaging and similar methods that implement moving averages of model parameters.
- Batch norm should work here because the `model_fn` will return updated batch norm moving averages when it is told to with `update_batch_norm`.


###### Prepare for evaluation function

```python
def prepare_for_eval(
    workload: Workload,
    current_param_container: ParameterContainer,
    current_params_types: ParameterTypeTree,
    model_state: ModelAuxiliaryState,
    hyperparameters: Hyperparameters,
    loss_type: LossType,
    optimizer_state: OptimizerState,
    eval_results: List[Tuple[int, float]],
    global_step: int,
    rng: RandomState
) -> (updated_optimizer_state, updated_variables, updated_model_state)
```

- Arguments are the same as those of `update_params`, with the exception of `batch`, which is not passed to this function.
- This function is called when a submission is deemed eligible for an evaluation (see the [Evaluation during training](#evaluation-during-training) section).
- The call to `prepare_for_eval` is timed and its runtime counts toward the overall submission time.
- The returned model parameters are evaluated on the validation and test sets, provided that the accumulated submission time does not exceed the maximum runtime after this function call.
- This API supports Polyak averaging and similar methods that implement moving averages of model parameters.
- Allowed to update model state and model parameters.
- Allowed to update state for the optimizer.
- Cannot replace the model parameters with pre-trained ones.
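
For illustration, here is a minimal sketch of a submission that uses `prepare_for_eval` for Polyak/EMA averaging. It assumes the submission has been maintaining an exponential moving average of the parameters in its own optimizer state under a hypothetical `'param_ema'` key during `update_params`; the key name and state layout are illustrative choices of the submitter, not part of the API.

```python
from typing import List, Tuple

from algorithmic_efficiency import spec


def prepare_for_eval(workload: spec.Workload,
                     current_param_container: spec.ParameterContainer,
                     current_params_types: spec.ParameterTypeTree,
                     model_state: spec.ModelAuxiliaryState,
                     hyperparameters: spec.Hyperparameters,
                     loss_type: spec.LossType,
                     optimizer_state: spec.OptimizerState,
                     eval_results: List[Tuple[int, float]],
                     global_step: int,
                     rng: spec.RandomState) -> spec.UpdateReturn:
  """Swap in the EMA parameters for evaluation (illustrative sketch)."""
  del workload, current_params_types, hyperparameters, loss_type
  del eval_results, global_step, rng
  # Hypothetical: update_params stored an EMA of the weights under this key.
  averaged_params = optimizer_state.get('param_ema')
  # Fall back to the current parameters if no EMA has been accumulated yet.
  if averaged_params is None:
    averaged_params = current_param_container
  return optimizer_state, averaged_params, model_state
```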

###### Data selection

```python
@@ -252,7 +278,8 @@ def data_selection(

In general, with noisy, non-deterministic training, evaluation frequency can affect training time measurements, as more "bites of the apple" potentially allow the training code to exploit instability. We also want to discourage submissions from using complicated and unrealistic logic that attempts to guess when training is close to complete and increases the evaluation rate, while not producing a well-sampled training curve at the start of training. Simply allowing submissions complete freedom over evaluation frequency encourages competitors to work to minimize the number of evaluations, which distracts from the primary goal of finding better training algorithms.

Submissions are eligible for an untimed eval every `eval_period` seconds, run as soon as the current call of `update_params` completes. Any additional evaluations performed by the submission code count against the runtime for scoring. The harness that runs the submission code will attempt to eval every `eval_period` seconds by checking between each submission step (call of `update_params`) whether it has been at least `eval_period` seconds since that last eval and, if so, pausing the clock and running an eval. This means that if calls to `update_params` typically take a lot more than `eval_period` seconds, such submissions will not receive as many untimed evals as a submission that had an `update_params` function that took less time. However, for appropriate settings of `eval_period`, we expect this to be quite rare. Submissions are always free to restructure their `update_params` code to split work into two subsequent steps to regain the potential benefits of these untimed model evaluations. For each workload, the `eval_period` will be set such that the total evaluation time is roughly between 10% and 20% of the total training time for the target-setting runs.
Submissions are eligible for an untimed eval every `eval_period` seconds. Before proceeding to evaluation, the submission can prepare the model through a call to `prepare_for_eval`, potentially modifying the model parameters and state as well as the optimizer state. Any additional evaluations performed by the submission code count against the runtime for scoring.
The harness that runs the submission code will attempt to eval every `eval_period` seconds by checking between each submission step (call of `update_params`) whether it has been at least `eval_period` seconds since the last eval and, if so, giving the submission the opportunity to prepare for evaluation (through a timed call to `prepare_for_eval`). If the accumulated runtime does not exceed the maximum allowed runtime after the preparation step, the clock is paused and the submission is evaluated. This means that if calls to `update_params` typically take a lot more than `eval_period` seconds, such submissions will not receive as many untimed evals as a submission whose `update_params` function takes less time. However, for appropriate settings of `eval_period`, we expect this to be quite rare. Submissions are always free to restructure their `update_params` code to split work into two subsequent steps to regain the potential benefits of these untimed model evaluations. For each workload, the `eval_period` will be set such that the total evaluation time is roughly between 10% and 20% of the total training time for the target-setting runs.
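
For illustration, here is a schematic sketch of the timing logic described above. This is not the actual benchmark harness code; the function names (`update_params_step`, `prepare_step`, `run_eval`) and the timing bookkeeping are illustrative assumptions.

```python
import time


def run_timed_loop(update_params_step, prepare_step, run_eval, state,
                   eval_period, max_runtime):
  """Illustrative sketch of the harness timing logic (not the real runner)."""
  accumulated_time = 0.0
  last_eval = time.monotonic()
  while accumulated_time < max_runtime:
    start = time.monotonic()
    state = update_params_step(state)  # timed submission step
    accumulated_time += time.monotonic() - start

    if time.monotonic() - last_eval >= eval_period:
      start = time.monotonic()
      state = prepare_step(state)  # prepare_for_eval is also timed
      accumulated_time += time.monotonic() - start
      if accumulated_time < max_runtime:  # budget not exhausted by the prep step
        run_eval(state)  # untimed: the clock is paused here
        last_eval = time.monotonic()
  return state
```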

#### Valid submissions

@@ -419,6 +446,19 @@ The currently eight fixed workloads are:
| **7** | Molecular property prediction | OGBG | GNN | CE | mAP | 0.28098 | 0.268729 | 18,477 |
| **8** | Translation | WMT | Transformer | CE | BLEU | 30.8491 | 30.7219 | 48,151 |

Default Dropout Values for Different Workloads:

| Workload | Dropout Values |
|------------------------|------------------------------------------------------------------------------------------------------|
| criteo 1tb | dropout_rate: 0.0 |
| fastmri | dropout_rate: 0.0 |
| imagenet_resnet | dropout not used |
| imagenet_vit | dropout_rate: 0.0 |
| librispeech_conformer | attention_dropout_rate: 0.0 <br> attention_residual_dropout_rate: 0.1 <br> conv_residual_dropout_rate: 0.0 <br> feed_forward_dropout_rate: 0.0 <br> feed_forward_residual_dropout_rate: 0.1 <br> input_dropout_rate: 0.1 |
| librispeech_deepspeech | input_dropout_rate: 0.1 <br> feed_forward_dropout_rate: 0.1 <br> (Only for JAX - dropout_rate in CudnnLSTM class: 0.0) |
| ogbg | dropout_rate: 0.1 |
| wmt | dropout_rate: 0.1 <br> attention_dropout_rate: 0.1 |

#### Randomized workloads

In addition to the [fixed and known workloads](#fixed-workloads), there will also be randomized workloads in our benchmark. These randomized workloads will introduce minor modifications to a fixed workload (e.g. small model changes). The exact instances of these randomized workloads will only be created after the submission deadline and are thus unknown to both the submitters as well as the benchmark organizers. The instructions for creating them, i.e. providing a set or distribution of workloads to sample from, will be defined by this working group and made public with the call for submissions, to allow the members of this working group to submit as well as ensure that they do not possess any additional information compared to other submitters. We will refer to the unspecific workloads as *randomized workloads*, e.g. the set or distribution. The specific instance of such a randomized workload we call a *held-out workload*. That is, a held-out workload is a specific sample of a randomized workload that is used for one iteration of the benchmark. While we may reuse randomized workloads between iterations of the benchmark, new held-out workloads will be sampled for each new benchmark iteration.
2 changes: 1 addition & 1 deletion algorithmic_efficiency/init_utils.py
@@ -13,6 +13,6 @@ def pytorch_default_init(module: nn.Module) -> None:
  # Perform lecun_normal initialization.
  fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
  std = math.sqrt(1. / fan_in) / .87962566103423978
  nn.init.trunc_normal_(module.weight, std=std)
  nn.init.trunc_normal_(module.weight, std=std, a=-2 * std, b=2 * std)
  if module.bias is not None:
    nn.init.constant_(module.bias, 0.)
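
Context for the fix above (not part of the diff): `torch.nn.init.trunc_normal_` truncates at `a=-2.0, b=2.0` in absolute units by default, not in multiples of `std`, so explicit bounds are needed to truncate at two standard deviations. A quick illustrative check:

```python
from torch import nn

from algorithmic_efficiency.init_utils import pytorch_default_init

layer = nn.Linear(256, 128)  # fan_in = 256
pytorch_default_init(layer)

# With the explicit bounds, every weight lies within two standard deviations.
std = (1.0 / 256) ** 0.5 / 0.87962566103423978
assert layer.weight.abs().max().item() <= 2 * std + 1e-6
```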
12 changes: 6 additions & 6 deletions algorithmic_efficiency/random_utils.py
@@ -16,21 +16,21 @@

FLAGS = flags.FLAGS

# Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 32 - 1] (an
# Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 31 - 1] (an
# unsigned int), while RandomState.randint only accepts and returns signed ints.
MAX_INT32 = 2**31
MIN_INT32 = -MAX_INT32
MAX_INT32 = 2**31 - 1
MIN_INT32 = 0

SeedType = Union[int, list, np.ndarray]


def _signed_to_unsigned(seed: SeedType) -> SeedType:
  if isinstance(seed, int):
    return seed % 2**32
    return seed % MAX_INT32
  if isinstance(seed, list):
    return [s % 2**32 for s in seed]
    return [s % MAX_INT32 for s in seed]
  if isinstance(seed, np.ndarray):
    return np.array([s % 2**32 for s in seed.tolist()])
    return np.array([s % MAX_INT32 for s in seed.tolist()])


def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]:
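
As context for the bounds above (an illustrative snippet, not part of the diff): keeping derived seeds below 2**31 ensures they are accepted as `np.random.RandomState` seeds and also fit in a non-negative signed 32-bit int, which is what `randint` accepts and returns.

```python
import numpy as np

MAX_INT32 = 2**31 - 1
MIN_INT32 = 0

seed = -123456789  # e.g. a signed seed produced by folding in data
seed = seed % MAX_INT32  # reduce into [0, 2**31 - 2]
rng = np.random.RandomState(seed)  # valid seed
child_seed = rng.randint(MIN_INT32, MAX_INT32)  # non-negative, fits in int32
print(child_seed)
```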
30 changes: 30 additions & 0 deletions algorithmic_efficiency/spec.py
@@ -431,6 +431,36 @@ def update_params(workload: Workload,
  pass


PrepareForEvalFn = Callable[[
    Workload,
    ParameterContainer,
    ParameterTypeTree,
    ModelAuxiliaryState,
    Hyperparameters,
    LossType,
    OptimizerState,
    List[Tuple[int, float]],
    int,
    RandomState
],
                            UpdateReturn]


# Prepare model and optimizer for evaluation.
def prepare_for_eval(workload: Workload,
                     current_param_container: ParameterContainer,
                     current_params_types: ParameterTypeTree,
                     model_state: ModelAuxiliaryState,
                     hyperparameters: Hyperparameters,
                     loss_type: LossType,
                     optimizer_state: OptimizerState,
                     eval_results: List[Tuple[int, float]],
                     global_step: int,
                     rng: RandomState) -> UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  pass


DataSelectionFn = Callable[[
    Workload,
    Iterator[Dict[str, Any]],
@@ -302,6 +302,27 @@ def update_params(
  return (new_optimizer_state, opt_update_fn), new_params, new_model_state


def prepare_for_eval(workload: spec.Workload,
                     current_param_container: spec.ParameterContainer,
                     current_params_types: spec.ParameterTypeTree,
                     model_state: spec.ModelAuxiliaryState,
                     hyperparameters: spec.Hyperparameters,
                     loss_type: spec.LossType,
                     optimizer_state: spec.OptimizerState,
                     eval_results: List[Tuple[int, float]],
                     global_step: int,
                     rng: spec.RandomState) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  del workload
  del hyperparameters
  del current_params_types
  del loss_type
  del eval_results
  del global_step
  del rng
  return (optimizer_state, current_param_container, model_state)


def get_batch_size(workload_name):
  # Return the global batch size.
  if workload_name == 'criteo1tb':
@@ -302,6 +302,27 @@ def update_params(
  return (new_optimizer_state, opt_update_fn), new_params, new_model_state


def prepare_for_eval(workload: spec.Workload,
                     current_param_container: spec.ParameterContainer,
                     current_params_types: spec.ParameterTypeTree,
                     model_state: spec.ModelAuxiliaryState,
                     hyperparameters: spec.Hyperparameters,
                     loss_type: spec.LossType,
                     optimizer_state: spec.OptimizerState,
                     eval_results: List[Tuple[int, float]],
                     global_step: int,
                     rng: spec.RandomState) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  del workload
  del hyperparameters
  del current_params_types
  del loss_type
  del eval_results
  del global_step
  del rng
  return (optimizer_state, current_param_container, model_state)


def get_batch_size(workload_name):
  # Return the global batch size.
  if workload_name == 'criteo1tb':
@@ -304,6 +304,27 @@ def update_params(
  return (optimizer_state, current_param_container, new_model_state)


def prepare_for_eval(workload: spec.Workload,
                     current_param_container: spec.ParameterContainer,
                     current_params_types: spec.ParameterTypeTree,
                     model_state: spec.ModelAuxiliaryState,
                     hyperparameters: spec.Hyperparameters,
                     loss_type: spec.LossType,
                     optimizer_state: spec.OptimizerState,
                     eval_results: List[Tuple[int, float]],
                     global_step: int,
                     rng: spec.RandomState) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  del workload
  del hyperparameters
  del current_params_types
  del loss_type
  del eval_results
  del global_step
  del rng
  return (optimizer_state, current_param_container, model_state)


def get_batch_size(workload_name):
  # Return the global batch size.
  if workload_name == 'criteo1tb':
@@ -304,6 +304,27 @@ def update_params(
  return (optimizer_state, current_param_container, new_model_state)


def prepare_for_eval(workload: spec.Workload,
                     current_param_container: spec.ParameterContainer,
                     current_params_types: spec.ParameterTypeTree,
                     model_state: spec.ModelAuxiliaryState,
                     hyperparameters: spec.Hyperparameters,
                     loss_type: spec.LossType,
                     optimizer_state: spec.OptimizerState,
                     eval_results: List[Tuple[int, float]],
                     global_step: int,
                     rng: spec.RandomState) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  del workload
  del hyperparameters
  del current_params_types
  del loss_type
  del eval_results
  del global_step
  del rng
  return (optimizer_state, current_param_container, model_state)


def get_batch_size(workload_name):
  # Return the global batch size.
  if workload_name == 'criteo1tb':
@@ -317,6 +317,27 @@ def update_params(
  return (new_optimizer_state, opt_update_fn), new_params, new_model_state


def prepare_for_eval(workload: spec.Workload,
                     current_param_container: spec.ParameterContainer,
                     current_params_types: spec.ParameterTypeTree,
                     model_state: spec.ModelAuxiliaryState,
                     hyperparameters: spec.Hyperparameters,
                     loss_type: spec.LossType,
                     optimizer_state: spec.OptimizerState,
                     eval_results: List[Tuple[int, float]],
                     global_step: int,
                     rng: spec.RandomState) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  del workload
  del hyperparameters
  del current_params_types
  del loss_type
  del eval_results
  del global_step
  del rng
  return (optimizer_state, current_param_container, model_state)


def get_batch_size(workload_name):
  # Return the global batch size.
  if workload_name == 'criteo1tb':
@@ -317,6 +317,27 @@ def update_params(
  return (new_optimizer_state, opt_update_fn), new_params, new_model_state


def prepare_for_eval(workload: spec.Workload,
                     current_param_container: spec.ParameterContainer,
                     current_params_types: spec.ParameterTypeTree,
                     model_state: spec.ModelAuxiliaryState,
                     hyperparameters: spec.Hyperparameters,
                     loss_type: spec.LossType,
                     optimizer_state: spec.OptimizerState,
                     eval_results: List[Tuple[int, float]],
                     global_step: int,
                     rng: spec.RandomState) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  del workload
  del hyperparameters
  del current_params_types
  del loss_type
  del eval_results
  del global_step
  del rng
  return (optimizer_state, current_param_container, model_state)


def get_batch_size(workload_name):
  # Return the global batch size.
  if workload_name == 'criteo1tb':