From e6c2106c2460d0149235dd4eccfd4017b0952734 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 12 Sep 2024 15:30:15 +0200 Subject: [PATCH 001/105] added prepare_for_eval, eval only if is_time_remaining --- algorithmic_efficiency/spec.py | 14 +- submission_runner.py | 203 ++++++++++++++++------------- submissions/template/submission.py | 20 +++ 3 files changed, 149 insertions(+), 88 deletions(-) diff --git a/algorithmic_efficiency/spec.py b/algorithmic_efficiency/spec.py index 285983957..792093a2e 100644 --- a/algorithmic_efficiency/spec.py +++ b/algorithmic_efficiency/spec.py @@ -406,7 +406,19 @@ def init_optimizer_state(workload: Workload, RandomState ], UpdateReturn] - +PrepareForEvalFn = Callable[[ + Workload, + ParameterContainer, + ParameterTypeTree, + ModelAuxiliaryState, + Hyperparameters, + LossType, + OptimizerState, + List[Tuple[int, float]], + int, + RandomState +], + UpdateReturn] # Each call to this function is considered a "step". # Can raise a TrainingCompleteError if it believes it has achieved the goal and diff --git a/submission_runner.py b/submission_runner.py index 551173bf5..5df1f05ff 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -200,6 +200,7 @@ def train_once( init_optimizer_state: spec.InitOptimizerFn, update_params: spec.UpdateParamsFn, data_selection: spec.DataSelectionFn, + prepare_for_eval: spec.PrepareForEvalFn, hyperparameters: Optional[spec.Hyperparameters], rng_seed: int, rng: spec.RandomState, @@ -335,7 +336,7 @@ def train_once( not train_state['training_complete']: step_rng = prng.fold_in(rng, global_step) - data_select_rng, update_rng, eval_rng = prng.split(step_rng, 3) + data_select_rng, update_rng, prep_eval_rng, eval_rng = prng.split(step_rng, 4) with profiler.profile('Data selection'): batch = data_selection(workload, @@ -370,101 +371,128 @@ def train_once( train_state['accumulated_submission_time'] += ( train_step_end_time - train_state['last_step_end_time']) - # Use 3x the runtime budget for the self-tuning ruleset. - max_allowed_runtime_sec = ( - workload.max_allowed_runtime_sec if FLAGS.tuning_ruleset == 'external' - else 3 * workload.max_allowed_runtime_sec) - train_state['is_time_remaining'] = ( - train_state['accumulated_submission_time'] < max_allowed_runtime_sec) + # Check if submission is eligible for an untimed eval. if ((train_step_end_time - train_state['last_eval_time']) >= workload.eval_period_time_sec or train_state['training_complete']): - with profiler.profile('Evaluation'): + + # Prepare for evaluation (timed). + with profiler.profile('Prepare for eval'): del batch - _reset_cuda_mem() - - try: - eval_start_time = get_time() - latest_eval_result = workload.eval_model(global_eval_batch_size, - model_params, - model_state, - eval_rng, - data_dir, - imagenet_v2_data_dir, - global_step) - # Check if targets reached. - # Note that this is one of the stopping conditions for the length of - # a training run. To score the run we only consider the time - # to validation target retrospectively. - train_state['validation_goal_reached'] = ( - workload.has_reached_validation_target(latest_eval_result) or - train_state['validation_goal_reached']) - train_state['test_goal_reached'] = ( - workload.has_reached_test_target(latest_eval_result) or - train_state['test_goal_reached']) - goals_reached = ( - train_state['validation_goal_reached'] and - train_state['test_goal_reached']) - # Save last eval time. - eval_end_time = get_time() - train_state['last_eval_time'] = eval_end_time - - # Accumulate eval time. - train_state[ - 'accumulated_eval_time'] += eval_end_time - eval_start_time - - # Add times to eval results for logging. - latest_eval_result['score'] = ( - train_state['accumulated_submission_time']) - latest_eval_result[ - 'total_duration'] = eval_end_time - global_start_time - latest_eval_result['accumulated_submission_time'] = train_state[ - 'accumulated_submission_time'] - latest_eval_result['accumulated_eval_time'] = train_state[ - 'accumulated_eval_time'] - latest_eval_result['accumulated_logging_time'] = train_state[ - 'accumulated_logging_time'] - time_since_start = latest_eval_result['total_duration'] - logging.info(f'Time since start: {time_since_start:.2f}s, ' - f'\tStep: {global_step}, \t{latest_eval_result}') - eval_results.append((global_step, latest_eval_result)) - - logging_start_time = get_time() - - if log_dir is not None and RANK == 0: - metrics_logger.append_scalar_metrics( - latest_eval_result, - global_step=global_step, - preemption_count=preemption_count, - is_eval=True, - ) - if save_checkpoints: - checkpoint_utils.save_checkpoint( - framework=FLAGS.framework, - optimizer_state=optimizer_state, - model_params=model_params, - model_state=model_state, - train_state=train_state, - eval_results=eval_results, - global_step=global_step, - preemption_count=preemption_count, - checkpoint_dir=log_dir, - save_intermediate_checkpoints=FLAGS - .save_intermediate_checkpoints) + prepare_for_eval_start_time = get_time() + optimizer_state, model_params, model_state = prepare_for_eval( + workload=workload, + current_param_container=model_params, + current_params_types=workload.model_params_types, + model_state=model_state, + hyperparameters=hyperparameters, + loss_type=workload.loss_type, + optimizer_state=optimizer_state, + eval_results=eval_results, + global_step=global_step, + rng=prep_eval_rng) + prepare_for_eval_end_time = get_time() + + # Update sumbission time. + train_state['accumulated_submission_time'] += ( + prepare_for_eval_end_time - prepare_for_eval_start_time) + + # Check if time is remaining, + # use 3x the runtime budget for the self-tuning ruleset. + max_allowed_runtime_sec = ( + workload.max_allowed_runtime_sec if FLAGS.tuning_ruleset == 'external' + else 3 * workload.max_allowed_runtime_sec) + train_state['is_time_remaining'] = ( + train_state['accumulated_submission_time'] < max_allowed_runtime_sec) - logging_end_time = get_time() - train_state['accumulated_logging_time'] += ( - logging_end_time - logging_start_time) + # Eval if time is remaining (untimed). + if train_state['is_time_remaining']: + with profiler.profile('Evaluation'): _reset_cuda_mem() - except RuntimeError as e: - logging.exception(f'Eval step {global_step} error.\n') - if 'out of memory' in str(e): - logging.warning('Error: GPU out of memory during eval during step ' - f'{global_step}, error : {str(e)}.') + try: + eval_start_time = get_time() + latest_eval_result = workload.eval_model(global_eval_batch_size, + model_params, + model_state, + eval_rng, + data_dir, + imagenet_v2_data_dir, + global_step) + # Check if targets reached. + # Note that this is one of the stopping conditions for the length of + # a training run. To score the run we only consider the time + # to validation target retrospectively. + train_state['validation_goal_reached'] = ( + workload.has_reached_validation_target(latest_eval_result) or + train_state['validation_goal_reached']) + train_state['test_goal_reached'] = ( + workload.has_reached_test_target(latest_eval_result) or + train_state['test_goal_reached']) + goals_reached = ( + train_state['validation_goal_reached'] and + train_state['test_goal_reached']) + # Save last eval time. + eval_end_time = get_time() + train_state['last_eval_time'] = eval_end_time + + # Accumulate eval time. + train_state[ + 'accumulated_eval_time'] += eval_end_time - eval_start_time + + # Add times to eval results for logging. + latest_eval_result['score'] = ( + train_state['accumulated_submission_time']) + latest_eval_result[ + 'total_duration'] = eval_end_time - global_start_time + latest_eval_result['accumulated_submission_time'] = train_state[ + 'accumulated_submission_time'] + latest_eval_result['accumulated_eval_time'] = train_state[ + 'accumulated_eval_time'] + latest_eval_result['accumulated_logging_time'] = train_state[ + 'accumulated_logging_time'] + time_since_start = latest_eval_result['total_duration'] + logging.info(f'Time since start: {time_since_start:.2f}s, ' + f'\tStep: {global_step}, \t{latest_eval_result}') + eval_results.append((global_step, latest_eval_result)) + + logging_start_time = get_time() + + if log_dir is not None and RANK == 0: + metrics_logger.append_scalar_metrics( + latest_eval_result, + global_step=global_step, + preemption_count=preemption_count, + is_eval=True, + ) + if save_checkpoints: + checkpoint_utils.save_checkpoint( + framework=FLAGS.framework, + optimizer_state=optimizer_state, + model_params=model_params, + model_state=model_state, + train_state=train_state, + eval_results=eval_results, + global_step=global_step, + preemption_count=preemption_count, + checkpoint_dir=log_dir, + save_intermediate_checkpoints=FLAGS + .save_intermediate_checkpoints) + + logging_end_time = get_time() + train_state['accumulated_logging_time'] += ( + logging_end_time - logging_start_time) + _reset_cuda_mem() + except RuntimeError as e: + logging.exception(f'Eval step {global_step} error.\n') + if 'out of memory' in str(e): + logging.warning('Error: GPU out of memory during eval during step ' + f'{global_step}, error : {str(e)}.') + _reset_cuda_mem() + train_state['last_step_end_time'] = get_time() metrics = {'eval_results': eval_results, 'global_step': global_step} @@ -518,6 +546,7 @@ def score_submission_on_workload(workload: spec.Workload, init_optimizer_state = submission_module.init_optimizer_state update_params = submission_module.update_params data_selection = submission_module.data_selection + prepare_for_eval = submission_module.prepare_for_eval try: global_batch_size = submission_module.get_batch_size(workload_name) except ValueError: @@ -589,7 +618,7 @@ def score_submission_on_workload(workload: spec.Workload, global_eval_batch_size, data_dir, imagenet_v2_data_dir, init_optimizer_state, - update_params, data_selection, + update_params, data_selection, prepare_for_eval, hyperparameters, rng_seed, rng, diff --git a/submissions/template/submission.py b/submissions/template/submission.py index 5ef195db5..848d8af44 100644 --- a/submissions/template/submission.py +++ b/submissions/template/submission.py @@ -42,6 +42,26 @@ def update_params(workload: spec.Workload, pass +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + # batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """ + Returns: + new_optimizer_state + new_params + new_model_state + """ + pass + + def get_batch_size(workload_name): """ Gets batch size for workload. From 8bad99d663f34ce4b6b6c4a2a40b828e19fc3a5b Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Sat, 14 Sep 2024 18:11:43 +0200 Subject: [PATCH 002/105] added prepare_for_eval to all submissions --- .../external_tuning/jax_nadamw_full_budget.py | 21 ++++++++++++++++ .../jax_nadamw_target_setting.py | 21 ++++++++++++++++ .../pytorch_nadamw_full_budget.py | 21 ++++++++++++++++ .../pytorch_nadamw_target_setting.py | 21 ++++++++++++++++ .../self_tuning/jax_nadamw_full_budget.py | 21 ++++++++++++++++ .../self_tuning/jax_nadamw_target_setting.py | 21 ++++++++++++++++ .../self_tuning/pytorch_nadamw_full_budget.py | 21 ++++++++++++++++ .../pytorch_nadamw_target_setting.py | 21 ++++++++++++++++ .../cifar/cifar_jax/submission.py | 25 +++++++++++++++++-- .../cifar/cifar_pytorch/submission.py | 21 ++++++++++++++++ .../mnist/mnist_jax/submission.py | 21 ++++++++++++++++ .../mnist/mnist_pytorch/submission.py | 21 ++++++++++++++++ .../adafactor/jax/submission.py | 21 ++++++++++++++++ .../adafactor/pytorch/submission.py | 21 ++++++++++++++++ .../paper_baselines/adamw/jax/submission.py | 21 ++++++++++++++++ .../adamw/pytorch/submission.py | 21 ++++++++++++++++ .../paper_baselines/lamb/jax/submission.py | 21 ++++++++++++++++ .../lamb/pytorch/submission.py | 21 ++++++++++++++++ .../momentum/jax/submission.py | 21 ++++++++++++++++ .../momentum/pytorch/submission.py | 21 ++++++++++++++++ .../paper_baselines/nadamw/jax/submission.py | 21 ++++++++++++++++ .../nadamw/pytorch/submission.py | 21 ++++++++++++++++ .../nesterov/jax/submission.py | 21 ++++++++++++++++ .../nesterov/pytorch/submission.py | 21 ++++++++++++++++ .../paper_baselines/sam/jax/submission.py | 21 ++++++++++++++++ .../paper_baselines/sam/pytorch/submission.py | 21 ++++++++++++++++ .../paper_baselines/shampoo/jax/submission.py | 21 ++++++++++++++++ .../jax_submission_base.py | 21 ++++++++++++++++ .../pytorch_submission_base.py | 21 ++++++++++++++++ 29 files changed, 611 insertions(+), 2 deletions(-) diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py index 98193f01f..5f203c5c6 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -299,6 +299,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py index 66fdc4ebb..32f4e830e 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -299,6 +299,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py index ebc49d428..ba56cd99f 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -301,6 +301,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py index 524bc20af..e2c44d9c1 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py @@ -301,6 +301,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py index 4f53afb56..502b7e5b4 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -314,6 +314,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index 60a1f784d..8bc2eed95 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -314,6 +314,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py index f8e87ec2a..bbf548ccb 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py @@ -316,6 +316,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py index 1de26417f..992f769f3 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py @@ -316,6 +316,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py index 2971efe9a..b2256fc5a 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py @@ -108,8 +108,6 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. def update_params(workload: spec.Workload, current_param_container: spec.ParameterContainer, current_params_types: spec.ParameterTypeTree, @@ -134,6 +132,29 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + +# Not allowed to update the model parameters, hyperparameters, global step, or +# optimzier state. def data_selection(workload: spec.Workload, input_queue: Iterator[Dict[str, spec.Tensor]], optimizer_state: spec.OptimizerState, diff --git a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py index 358c6bffc..b55c31afc 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py @@ -96,6 +96,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + # Not allowed to update the model parameters, hyperparameters, global step, or # optimzier state. def data_selection(workload: spec.Workload, diff --git a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py index 896609d51..f09886215 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py @@ -106,6 +106,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), updated_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + # Not allowed to update the model parameters, hyperparameters, global step, or # optimzier state. def data_selection(workload: spec.Workload, diff --git a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py index f1601e606..8b5151c77 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py @@ -72,6 +72,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + # Not allowed to update the model parameters, hyperparameters, global step, or # optimzier state. def data_selection(workload: spec.Workload, diff --git a/reference_algorithms/paper_baselines/adafactor/jax/submission.py b/reference_algorithms/paper_baselines/adafactor/jax/submission.py index 2dd85c29b..ed2ee371f 100644 --- a/reference_algorithms/paper_baselines/adafactor/jax/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/jax/submission.py @@ -157,6 +157,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py index e6fef17dc..5f6540020 100644 --- a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py @@ -265,6 +265,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 80a963600..5d2107ba6 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -157,6 +157,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py index 32353e5b4..2b42bb5a4 100644 --- a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py @@ -125,6 +125,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/lamb/jax/submission.py b/reference_algorithms/paper_baselines/lamb/jax/submission.py index 27d635ee9..e08d5b433 100644 --- a/reference_algorithms/paper_baselines/lamb/jax/submission.py +++ b/reference_algorithms/paper_baselines/lamb/jax/submission.py @@ -165,6 +165,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py index 7d0d8763e..da5865087 100644 --- a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py @@ -258,6 +258,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/momentum/jax/submission.py b/reference_algorithms/paper_baselines/momentum/jax/submission.py index cccb3c1b5..1ab362dd6 100644 --- a/reference_algorithms/paper_baselines/momentum/jax/submission.py +++ b/reference_algorithms/paper_baselines/momentum/jax/submission.py @@ -191,6 +191,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py index ec5c0b31c..999321bd5 100644 --- a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py @@ -144,6 +144,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py index 98193f01f..5f203c5c6 100644 --- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py @@ -299,6 +299,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py index ebc49d428..ba56cd99f 100644 --- a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py @@ -301,6 +301,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index f3b0aeed4..20109a9e3 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -191,6 +191,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py index fe9154934..b4b8b77af 100644 --- a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py @@ -144,6 +144,27 @@ def update_params(workload: spec.Workload, return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py index 85b3d7441..9f12c4f3f 100644 --- a/reference_algorithms/paper_baselines/sam/jax/submission.py +++ b/reference_algorithms/paper_baselines/sam/jax/submission.py @@ -244,6 +244,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/sam/pytorch/submission.py b/reference_algorithms/paper_baselines/sam/pytorch/submission.py index 2cab75972..cf5e49f4f 100644 --- a/reference_algorithms/paper_baselines/sam/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/sam/pytorch/submission.py @@ -216,6 +216,27 @@ def _loss_fn(params, update_batch_norm=True): return (optimizer_state, current_param_container, new_model_state) +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/paper_baselines/shampoo/jax/submission.py b/reference_algorithms/paper_baselines/shampoo/jax/submission.py index 9c6b66b7f..b596f0bdc 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/submission.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/submission.py @@ -160,6 +160,27 @@ def update_params(workload: spec.Workload, return (new_optimizer_state, opt_update_fn), new_params, new_model_state +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + def get_batch_size(workload_name): # Return the global batch size. if workload_name == 'criteo1tb': diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index 2a641b520..31e8a8850 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -109,3 +109,24 @@ def update_params(workload: spec.Workload, 'grad_norm': grad_norm[0], }, global_step) return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) diff --git a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py index f9e40212b..549d2dc58 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py @@ -89,3 +89,24 @@ def update_params(workload: spec.Workload, grad_norm.item()) return (optimizer_state, current_param_container, new_model_state) + + +def prepare_for_eval(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) From 1c7d51c0eb2cf64295f030b6ef0566bcd24b01cf Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Sun, 15 Sep 2024 11:08:48 +0200 Subject: [PATCH 003/105] fix formatting --- submission_runner.py | 25 ++++++++++++++----------- submissions/template/submission.py | 1 - 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 5df1f05ff..a711be9ac 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -336,7 +336,8 @@ def train_once( not train_state['training_complete']: step_rng = prng.fold_in(rng, global_step) - data_select_rng, update_rng, prep_eval_rng, eval_rng = prng.split(step_rng, 4) + data_select_rng, update_rng, prep_eval_rng, eval_rng = \ + prng.split(step_rng, 4) with profiler.profile('Data selection'): batch = data_selection(workload, @@ -414,12 +415,12 @@ def train_once( try: eval_start_time = get_time() latest_eval_result = workload.eval_model(global_eval_batch_size, - model_params, - model_state, - eval_rng, - data_dir, - imagenet_v2_data_dir, - global_step) + model_params, + model_state, + eval_rng, + data_dir, + imagenet_v2_data_dir, + global_step) # Check if targets reached. # Note that this is one of the stopping conditions for the length of # a training run. To score the run we only consider the time @@ -454,7 +455,7 @@ def train_once( 'accumulated_logging_time'] time_since_start = latest_eval_result['total_duration'] logging.info(f'Time since start: {time_since_start:.2f}s, ' - f'\tStep: {global_step}, \t{latest_eval_result}') + f'\tStep: {global_step}, \t{latest_eval_result}') eval_results.append((global_step, latest_eval_result)) logging_start_time = get_time() @@ -489,8 +490,9 @@ def train_once( except RuntimeError as e: logging.exception(f'Eval step {global_step} error.\n') if 'out of memory' in str(e): - logging.warning('Error: GPU out of memory during eval during step ' - f'{global_step}, error : {str(e)}.') + logging.warning( + 'Error: GPU out of memory during eval during step ' + f'{global_step}, error : {str(e)}.') _reset_cuda_mem() train_state['last_step_end_time'] = get_time() @@ -618,7 +620,8 @@ def score_submission_on_workload(workload: spec.Workload, global_eval_batch_size, data_dir, imagenet_v2_data_dir, init_optimizer_state, - update_params, data_selection, prepare_for_eval, + update_params, data_selection, + prepare_for_eval, hyperparameters, rng_seed, rng, diff --git a/submissions/template/submission.py b/submissions/template/submission.py index 848d8af44..445e1f7cd 100644 --- a/submissions/template/submission.py +++ b/submissions/template/submission.py @@ -47,7 +47,6 @@ def prepare_for_eval(workload: spec.Workload, current_params_types: spec.ParameterTypeTree, model_state: spec.ModelAuxiliaryState, hyperparameters: spec.Hyperparameters, - # batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, eval_results: List[Tuple[int, float]], From 21a580b56c1f19cff11b13b62d4fceb1dc003f29 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Sun, 15 Sep 2024 11:18:03 +0200 Subject: [PATCH 004/105] fix formatting --- algorithmic_efficiency/spec.py | 3 ++- submission_runner.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/spec.py b/algorithmic_efficiency/spec.py index 792093a2e..25bd7b6d0 100644 --- a/algorithmic_efficiency/spec.py +++ b/algorithmic_efficiency/spec.py @@ -418,7 +418,8 @@ def init_optimizer_state(workload: Workload, int, RandomState ], - UpdateReturn] + UpdateReturn] + # Each call to this function is considered a "step". # Can raise a TrainingCompleteError if it believes it has achieved the goal and diff --git a/submission_runner.py b/submission_runner.py index a711be9ac..632cb450b 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -620,7 +620,7 @@ def score_submission_on_workload(workload: spec.Workload, global_eval_batch_size, data_dir, imagenet_v2_data_dir, init_optimizer_state, - update_params, data_selection, + update_params, data_selection, prepare_for_eval, hyperparameters, rng_seed, From 420b583f8bd60ca13b6b7cf9a7d0b8211d5c904b Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Sun, 15 Sep 2024 12:48:13 +0200 Subject: [PATCH 005/105] updated documentation --- DOCUMENTATION.md | 33 ++++++++++++++++++++++++++++++--- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index 607f47ead..586e03d8c 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -80,7 +80,7 @@ In principle, submissions are allowed to use the available hardware systems in a Submissions provide a [per-workload batch size](#batch-size-getter) to use. Specification of the batch size for each workload is necessary to avoid running out of memory for different workloads. Therefore, submitters can determine this batch size in advance and specify it as part of the submission. Submitters may also provide per-workload batch sizes for all [randomized workloads](#randomized-workloads). If no such batch size is provided for a randomized workload, by default, submissions will then use the batch size of the most similar [fixed workload](#fixed-workloads) (for example, if there is an ImageNet fixed workload and also a randomized workload with a similarly sized model on similarly sized images, the ImageNet batch size will be used for held-out workloads generated from this randomized workload). Note that submitters are *not* allowed to modify the *evaluation batch size*, which is set by the benchmarking codebase. However, you can file an issue if you believe that the evaluation batch size of a particular workload is set inappropriately. The working group will review this request and consider adjusting the evaluation batch size in the benchmarking codebase, thus affecting all submitters equally. -The **submission functions** are the *batch size getter*, *optimizer state initializer*, *variable update*, and *data selection functions*. The *fixed functions* are the *data augmentation/preprocessing*, *model initialization*, *forward pass*, and *loss function*. The trained model will be evaluated in a separate step that does not call any of the submitted code. +The **submission functions** are the *batch size getter*, *optimizer state initializer*, *variable update*, *prepare for evaluation function*, and *data selection functions*. The *fixed functions* are the *data augmentation/preprocessing*, *model initialization*, *forward pass*, and *loss function*. The trained model will be evaluated in a separate step that does not call any of the submitted code. ##### Fixed functions @@ -218,9 +218,35 @@ def update_params( - Cannot modify the given hyperparameters in a workload-conditional way (please see the [Valid submission](#valid-submissions) section). This rule is intended to prohibit circumventing the tuning rules by looking up a pre-tuned optimal set of hyperparameters for each workload. It is not intended to prohibit line searches and other similar techniques. - The fixed `init_model_fn` can optionally be called during training, for example, to reinitialize the model after a failed training effort. - Cannot replace the model parameters with pre-trained ones. -- This API supports Polyak averaging and similar methods that implement moving averages of model parameters. - Batch norm should work here because the `model_fn` will return updated batch norm moving averages when it is told to with `update_batch_norm`. + +###### Prepare for evaluation function + +```python +def prepare_for_eval( + workload: Workload, + current_param_container: ParameterContainer, + current_params_types: ParameterTypeTree, + model_state: ModelAuxiliaryState, + hyperparameters: Hyperparameters, + loss_type: LossType, + optimizer_state: OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: RandomState +) -> (updated_optimizer_state, updated_variables, updated_model_state) +``` + +- Arguments are the same of `update_param`, with the only exception of `batch`. +- This function is called when a submission is deemed eligible for an evaluation (see [Evluation during training](#evaluation-during-training) section). + - The call to `prepare_for_eval` is timed and its runtime accumulates to the overall submission time. + - The returned model parameters are evaluated on the validation and test sets, provided that the accumulated submission time does not exceed the maximum runtime after this function call. +- This API supports Polyak averaging and similar methods that implement moving averages of model parameters. +- Allowed to update model state and model parameters. +- Allowed to update state for the optimizer. +- Cannot replace the model parameters with pre-trained ones. + ###### Data selection ```python @@ -250,7 +276,8 @@ def data_selection( In general, with noisy, non-deterministic training, evaluation frequency can affect training time measurements as more "bites of the apple" potentially allows the training code to exploit instability. We also want to discourage submissions from complicated and unrealistic logic that attempts to guess when training is close to complete and increases the evaluation rate, while not producing a well-sampled training curve at the start of training. Simply allowing submissions complete freedom over evaluation frequency encourages competitors to work to minimize the number of evaluations, which distracts from the primary goal of finding better training algorithms. -Submissions are eligible for an untimed eval every `eval_period` seconds, run as soon as the current call of `update_params` completes. Any additional evaluations performed by the submission code count against the runtime for scoring. The harness that runs the submission code will attempt to eval every `eval_period` seconds by checking between each submission step (call of `update_params`) whether it has been at least `eval_period` seconds since that last eval and, if so, pausing the clock and running an eval. This means that if calls to `update_params` typically take a lot more than `eval_period` seconds, such submissions will not receive as many untimed evals as a submission that had an `update_params` function that took less time. However, for appropriate settings of `eval_period`, we expect this to be quite rare. Submissions are always free to restructure their `update_params` code to split work into two subsequent steps to regain the potential benefits of these untimed model evaluations. For each workload, the `eval_period` will be set such that the total evaluation time is roughly between 10% and 20% of the total training time for the target-setting runs. +Submissions are eligible for an untimed eval every `eval_period` seconds. Before proceeding to evaluation, the submission can prepare the model through a call to `prepare_for_eval`, effectively modifying the model parameters and state as well as the the optimizer state. Any additional evaluations performed by the submission code count against the runtime for scoring. +The harness that runs the submission code will attempt to eval every `eval_period` seconds by checking between each submission step (call of `update_params`) whether it has been at least `eval_period` seconds since that last eval, if so, the submission is given the possibility to prepare for evaluation (through a timed call to `prepare_for_eval`). If the accumulated runtime does not exceed the maximum allowed runtime after the preparation step, the clock is paused, and the submission is evaluated. This means that if calls to `update_params` typically take a lot more than `eval_period` seconds, such submissions will not receive as many untimed evals as a submission that had an `update_params` function that took less time. However, for appropriate settings of `eval_period`, we expect this to be quite rare. Submissions are always free to restructure their `update_params` code to split work into two subsequent steps to regain the potential benefits of these untimed model evaluations. For each workload, the `eval_period` will be set such that the total evaluation time is roughly between 10% and 20% of the total training time for the target-setting runs. #### Valid submissions From d9c4ee9d3a85f55e069db21b39feaf216ee9d42d Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Fri, 18 Oct 2024 17:19:41 +0200 Subject: [PATCH 006/105] add prepare_for_eval to spec.py --- algorithmic_efficiency/spec.py | 43 ++++++++++++++++++++++++---------- 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/algorithmic_efficiency/spec.py b/algorithmic_efficiency/spec.py index 25bd7b6d0..b8be5fcaa 100644 --- a/algorithmic_efficiency/spec.py +++ b/algorithmic_efficiency/spec.py @@ -406,19 +406,6 @@ def init_optimizer_state(workload: Workload, RandomState ], UpdateReturn] -PrepareForEvalFn = Callable[[ - Workload, - ParameterContainer, - ParameterTypeTree, - ModelAuxiliaryState, - Hyperparameters, - LossType, - OptimizerState, - List[Tuple[int, float]], - int, - RandomState -], - UpdateReturn] # Each call to this function is considered a "step". @@ -442,6 +429,36 @@ def update_params(workload: Workload, pass +PrepareForEvalFn = Callable[[ + Workload, + ParameterContainer, + ParameterTypeTree, + ModelAuxiliaryState, + Hyperparameters, + LossType, + OptimizerState, + List[Tuple[int, float]], + int, + RandomState +], + UpdateReturn] + + +# Prepare model and optimizer for evaluation. +def prepare_for_eval(workload: Workload, + current_param_container: ParameterContainer, + current_params_types: ParameterTypeTree, + model_state: ModelAuxiliaryState, + hyperparameters: Hyperparameters, + loss_type: LossType, + optimizer_state: OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: RandomState) -> UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + pass + + DataSelectionFn = Callable[[ Workload, Iterator[Dict[str, Any]], From 9caedc5570550708aba7d2695e15b2480ca7cf0f Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Mon, 21 Oct 2024 11:48:35 +0200 Subject: [PATCH 007/105] make prepare_for_eval backward compatible --- submission_runner.py | 42 ++++++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 632cb450b..3ef30ffba 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -378,25 +378,27 @@ def train_once( workload.eval_period_time_sec or train_state['training_complete']): # Prepare for evaluation (timed). - with profiler.profile('Prepare for eval'): - del batch - prepare_for_eval_start_time = get_time() - optimizer_state, model_params, model_state = prepare_for_eval( - workload=workload, - current_param_container=model_params, - current_params_types=workload.model_params_types, - model_state=model_state, - hyperparameters=hyperparameters, - loss_type=workload.loss_type, - optimizer_state=optimizer_state, - eval_results=eval_results, - global_step=global_step, - rng=prep_eval_rng) - prepare_for_eval_end_time = get_time() - - # Update sumbission time. - train_state['accumulated_submission_time'] += ( - prepare_for_eval_end_time - prepare_for_eval_start_time) + if prepare_for_eval is not None: + + with profiler.profile('Prepare for eval'): + del batch + prepare_for_eval_start_time = get_time() + optimizer_state, model_params, model_state = prepare_for_eval( + workload=workload, + current_param_container=model_params, + current_params_types=workload.model_params_types, + model_state=model_state, + hyperparameters=hyperparameters, + loss_type=workload.loss_type, + optimizer_state=optimizer_state, + eval_results=eval_results, + global_step=global_step, + rng=prep_eval_rng) + prepare_for_eval_end_time = get_time() + + # Update sumbission time. + train_state['accumulated_submission_time'] += ( + prepare_for_eval_end_time - prepare_for_eval_start_time) # Check if time is remaining, # use 3x the runtime budget for the self-tuning ruleset. @@ -548,7 +550,7 @@ def score_submission_on_workload(workload: spec.Workload, init_optimizer_state = submission_module.init_optimizer_state update_params = submission_module.update_params data_selection = submission_module.data_selection - prepare_for_eval = submission_module.prepare_for_eval + prepare_for_eval = getattr(submission_module, 'prepare_for_eval', None) try: global_batch_size = submission_module.get_batch_size(workload_name) except ValueError: From 4d74d2ccee73ae6096a9fceff6a7b60c80f8f5a7 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Mon, 21 Oct 2024 12:00:29 +0200 Subject: [PATCH 008/105] optional prepare_for_eval arg --- submission_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/submission_runner.py b/submission_runner.py index 3ef30ffba..c396cb027 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -200,7 +200,7 @@ def train_once( init_optimizer_state: spec.InitOptimizerFn, update_params: spec.UpdateParamsFn, data_selection: spec.DataSelectionFn, - prepare_for_eval: spec.PrepareForEvalFn, + prepare_for_eval: Optional[spec.PrepareForEvalFn], hyperparameters: Optional[spec.Hyperparameters], rng_seed: int, rng: spec.RandomState, From 8cc4f4a0278406fb3b2ad5a3f9f5d4b5fd329daf Mon Sep 17 00:00:00 2001 From: init-22 Date: Thu, 31 Oct 2024 22:26:53 +0530 Subject: [PATCH 009/105] default dropout rates for workloads are added --- DOCUMENTATION.md | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index 607f47ead..851d85dbc 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -400,6 +400,22 @@ Submissions will be scored based on their performance on the [fixed workload](#f Furthermore, a less computationally expensive subset of the fixed workloads is collected with the [qualification set](#qualification-set). Submitters without enough compute resources to self-report on the full set of fixed and held-out workloads can instead self-report on this smaller qualification set. Well-performing submissions can thereby qualify for computational resources provided by sponsors of the benchmark to be scored on the full benchmark set. +#### Default Dropout Values for Different Workloads: + +| Workload | Dropout Values | +|------------------------|------------------------------------------------------------------------------------------------------| +| cifar | dropout not used | +| criteo 1tb | dropout_rate: 0.0 | +| fastmri | dropout_rate: 0.0 | +| imagenet_resnet | dropout not used | +| imagenet_vit | dropout_rate: 0.0 | +| librispeech_conformer | attention_dropout_rate: 0.0
attention_residual_dropout_rate: 0.1
conv_residual_dropout_rate: 0.0
feed_forward_dropout_rate: 0.0
feed_forward_residual_dropout_rate: 0.1
input_dropout_rate: 0.1 | +| librispeech_deepspeech | input_dropout_rate: 0.1
feed_forward_dropout_rate: 0.1
(Only for JAX - dropout_rate in CudnnLSTM class: 0.0) | +| mnist | dropout not used | +| ogbg | dropout_rate: 0.1 | +| wmt | dropout_rate: 0.1
attention_dropout_rate: 0.1 | + + NOTE: Submitters are no longer required to self-report results for AlgoPerf competition v0.5. #### Fixed workloads From a6fc879e119cc805bafd98ecd086b1243b9a42c7 Mon Sep 17 00:00:00 2001 From: init-22 Date: Thu, 31 Oct 2024 22:45:21 +0530 Subject: [PATCH 010/105] adding the dropout info in fixed workload section --- DOCUMENTATION.md | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index 851d85dbc..2decbcb46 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -400,22 +400,6 @@ Submissions will be scored based on their performance on the [fixed workload](#f Furthermore, a less computationally expensive subset of the fixed workloads is collected with the [qualification set](#qualification-set). Submitters without enough compute resources to self-report on the full set of fixed and held-out workloads can instead self-report on this smaller qualification set. Well-performing submissions can thereby qualify for computational resources provided by sponsors of the benchmark to be scored on the full benchmark set. -#### Default Dropout Values for Different Workloads: - -| Workload | Dropout Values | -|------------------------|------------------------------------------------------------------------------------------------------| -| cifar | dropout not used | -| criteo 1tb | dropout_rate: 0.0 | -| fastmri | dropout_rate: 0.0 | -| imagenet_resnet | dropout not used | -| imagenet_vit | dropout_rate: 0.0 | -| librispeech_conformer | attention_dropout_rate: 0.0
attention_residual_dropout_rate: 0.1
conv_residual_dropout_rate: 0.0
feed_forward_dropout_rate: 0.0
feed_forward_residual_dropout_rate: 0.1
input_dropout_rate: 0.1 | -| librispeech_deepspeech | input_dropout_rate: 0.1
feed_forward_dropout_rate: 0.1
(Only for JAX - dropout_rate in CudnnLSTM class: 0.0) | -| mnist | dropout not used | -| ogbg | dropout_rate: 0.1 | -| wmt | dropout_rate: 0.1
attention_dropout_rate: 0.1 | - - NOTE: Submitters are no longer required to self-report results for AlgoPerf competition v0.5. #### Fixed workloads @@ -433,6 +417,23 @@ The currently eight fixed workloads are: | **7** | Molecular property prediction | OGBG | GNN | CE | mAP | 0.28098 | 0.268729 | 18,477 | | **8** | Translation | WMT | Transformer | CE | BLEU | 30.8491 | 30.7219 | 48,151 | +#### Default Dropout Values for Different Workloads: + +| Workload | Dropout Values | +|------------------------|------------------------------------------------------------------------------------------------------| +| cifar | dropout not used | +| criteo 1tb | dropout_rate: 0.0 | +| fastmri | dropout_rate: 0.0 | +| imagenet_resnet | dropout not used | +| imagenet_vit | dropout_rate: 0.0 | +| librispeech_conformer | attention_dropout_rate: 0.0
attention_residual_dropout_rate: 0.1
conv_residual_dropout_rate: 0.0
feed_forward_dropout_rate: 0.0
feed_forward_residual_dropout_rate: 0.1
input_dropout_rate: 0.1 | +| librispeech_deepspeech | input_dropout_rate: 0.1
feed_forward_dropout_rate: 0.1
(Only for JAX - dropout_rate in CudnnLSTM class: 0.0) | +| mnist | dropout not used | +| ogbg | dropout_rate: 0.1 | +| wmt | dropout_rate: 0.1
attention_dropout_rate: 0.1 | + + + #### Randomized workloads In addition to the [fixed and known workloads](#fixed-workloads), there will also be randomized workloads in our benchmark. These randomized workloads will introduce minor modifications to a fixed workload (e.g. small model changes). The exact instances of these randomized workloads will only be created after the submission deadline and are thus unknown to both the submitters as well as the benchmark organizers. The instructions for creating them, i.e. providing a set or distribution of workloads to sample from, will be defined by this working group and made public with the call for submissions, to allow the members of this working group to submit as well as ensure that they do not possess any additional information compared to other submitters. We will refer to the unspecific workloads as *randomized workloads*, e.g. the set or distribution. The specific instance of such a randomized workload we call a *held-out workload*. That is, a held-out workload is a specific sample of a randomized workload that is used for one iteration of the benchmark. While we may reuse randomized workloads between iterations of the benchmark, new held-out workloads will be sampled for each new benchmark iteration. From 19838992f8edb766860f655670215c037ddcc834 Mon Sep 17 00:00:00 2001 From: init-22 Date: Thu, 31 Oct 2024 22:47:07 +0530 Subject: [PATCH 011/105] removing bold headings --- DOCUMENTATION.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index 2decbcb46..0c9c429c6 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -417,7 +417,7 @@ The currently eight fixed workloads are: | **7** | Molecular property prediction | OGBG | GNN | CE | mAP | 0.28098 | 0.268729 | 18,477 | | **8** | Translation | WMT | Transformer | CE | BLEU | 30.8491 | 30.7219 | 48,151 | -#### Default Dropout Values for Different Workloads: +Default Dropout Values for Different Workloads: | Workload | Dropout Values | |------------------------|------------------------------------------------------------------------------------------------------| From e16ebe091aebcf126694c63588ccf15f8f1b3cf0 Mon Sep 17 00:00:00 2001 From: init-22 Date: Thu, 14 Nov 2024 20:33:56 +0530 Subject: [PATCH 012/105] fix: changing the dtype in random_utils to uint32 --- algorithmic_efficiency/random_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index cf1ea6c32..31317047e 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -18,8 +18,8 @@ # Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 32 - 1] (an # unsigned int), while RandomState.randint only accepts and returns signed ints. -MAX_INT32 = 2**31 -MIN_INT32 = -MAX_INT32 +MAX_UINT32 = 2**31 +MIN_UINT32 = 0 SeedType = Union[int, list, np.ndarray] @@ -35,13 +35,13 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType: def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]: rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32) + new_seed = rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32) return [new_seed, data] def _split(seed: SeedType, num: int = 2) -> SeedType: rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2]) + return rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32, size=[num, 2]) def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name From 42da4fd7bb8ee92d8fb47e4c07456ac3f2d45e3d Mon Sep 17 00:00:00 2001 From: init-22 Date: Thu, 14 Nov 2024 23:21:17 +0530 Subject: [PATCH 013/105] feat: package updates with python 3.11 --- docker/Dockerfile | 29 +++++++++++++++- setup.cfg | 86 +++++++++++++++++++++++------------------------ 2 files changed, 71 insertions(+), 44 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 9b72aea86..24d05b495 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -11,7 +11,34 @@ FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04 RUN echo "Setting up machine" RUN apt-get update RUN apt-get install -y curl tar -RUN DEBIAN_FRONTEND=noninteractive apt-get install -y git python3 pip wget ffmpeg +RUN DEBIAN_FRONTEND=noninteractive apt-get install -y git ffmpeg + +# Install prerequisites +RUN apt-get update && apt-get install -y \ + wget \ + build-essential \ + zlib1g-dev \ + libncurses5-dev \ + libssl-dev \ + libreadline-dev \ + libffi-dev \ + curl \ + libbz2-dev \ + liblzma-dev + +# Download and install Python 3.11 +RUN cd /tmp \ + && wget https://www.python.org/ftp/python/3.11.0/Python-3.11.0.tgz \ + && tar -xvzf Python-3.11.0.tgz \ + && cd Python-3.11.0 \ + && ./configure --enable-optimizations \ + && make -j$(nproc) \ + && make altinstall + +# Create symlinks for python and pip (use 'pip' instead of 'pip3') +RUN ln -s /usr/local/bin/python3.11 /usr/bin/python \ + && ln -s /usr/local/bin/pip3.11 /usr/bin/pip + RUN apt-get install libtcmalloc-minimal4 RUN apt-get install unzip RUN apt-get install pigz diff --git a/setup.cfg b/setup.cfg index 4afefd164..deeb1c6c4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,6 +24,7 @@ classifiers = Programming Language :: Python :: 3.8 Programming Language :: Python :: 3.9 Programming Language :: Python :: 3.10 + Programming Language :: Python :: 3.11 Topic :: Scientific/Engineering :: Artificial Intelligence [options] @@ -34,22 +35,22 @@ setup_requires = setuptools_scm # Dependencies of the project: install_requires = - absl-py==1.4.0 + absl-py==2.1.1 # Pin to avoid unpinned install in dependencies that requires Python>=3.9. - networkx==3.1 - docker==7.0.0 - numpy>=1.23 - pandas>=2.0.1 - tensorflow==2.12.0 - tensorflow-datasets==4.9.2 - tensorflow-probability==0.20.0 - tensorflow-addons==0.20.0 + networkx==3.2.1 + docker==7.1.0 + numpy>=1.26.4 + pandas==2.2.3 + tensorflow==2.18.0 + tensorflow-datasets==4.9.7 + tensorflow-addons==0.23.0 gputil==1.4.0 - psutil==5.9.5 - clu==0.0.7 - matplotlib>=3.7.2 + psutil==6.1.0 + clu==0.0.12 + matplotlib>=3.9.2 tabulate==0.9.0 -python_requires = >=3.8 + wandb==0.18.7 +python_requires = >=3.11 ############################################################################### @@ -79,78 +80,77 @@ full_dev = # Dependencies for developing the package dev = - isort==5.12.0 - pylint==2.17.4 - pytest==7.3.1 - yapf==0.33.0 - pre-commit==3.3.1 + isort==5.13.2 + pylint==3.3.1 + pytest==8.3.3 + yapf==0.43.0 + pre-commit==4.0.1 # Workloads # criteo1tb = - scikit-learn==1.2.2 + scikit-learn==1.5.2 fastmri = - h5py==3.8.0 - scikit-image==0.20.0 + h5py==3.12.1 + scikit-image==0.24.0 ogbg = jraph==0.0.6.dev0 - scikit-learn==1.2.2 + scikit-learn==1.5.2 librispeech_conformer = - sentencepiece==0.1.99 - tensorflow-text==2.12.1 + sentencepiece==0.2.0 + tensorflow-text==2.18.0 pydub==0.25.1 wmt = - sentencepiece==0.1.99 - tensorflow-text==2.12.1 - sacrebleu==1.3.1 + sentencepiece==0.2.0 + tensorflow-text==2.18.0 + sacrebleu==2.4.3 # Frameworks # # JAX Core jax_core_deps = - flax==0.6.10 - optax==0.1.5 + flax==0.10.1 + optax==0.2.4 # Fix chex (optax dependency) version. # Not fixing it can raise dependency issues with our # jax version. # Todo(kasimbeg): verify if this is necessary after we # upgrade jax. - chex==0.1.7 - ml_dtypes==0.2.0 - protobuf==4.25.3 + chex==0.1.87 + ml_dtypes==0.4.1 + protobuf==4.25.5 # JAX CPU jax_cpu = - jax==0.4.10 - jaxlib==0.4.10 + jax==0.4.35 + jaxlib==0.4.35 %(jax_core_deps)s # JAX GPU # Note this installs both jax and jaxlib. jax_gpu = - jax==0.4.10 - jaxlib==0.4.10+cuda12.cudnn88 + jax==0.4.35 + jaxlib==0.4.35 + jax-cuda12-plugin[with_cuda]==0.4.35 + jax-cuda12-pjrt==0.4.35 %(jax_core_deps)s # PyTorch CPU pytorch_cpu = - torch==2.1.0 - torchvision==0.16.0 + torch==2.5.0 + torchvision==0.20.0 # PyTorch GPU # Note: omit the cuda suffix and installing from the appropriate # wheel will result in using locally installed CUDA. pytorch_gpu = - torch==2.1.0 - torchvision==0.16.0 + torch==2.5.0 + torchvision==0.20.0 -# wandb -wandb = - wandb==0.16.5 ############################################################################### # Linting Configurations # From 10057769ca798f52ab6c32de41e589b45d6a5b6b Mon Sep 17 00:00:00 2001 From: init-22 Date: Thu, 14 Nov 2024 23:26:54 +0530 Subject: [PATCH 014/105] fix: reverting the python311 changes --- docker/Dockerfile | 29 +--------------- setup.cfg | 86 +++++++++++++++++++++++------------------------ 2 files changed, 44 insertions(+), 71 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 24d05b495..9b72aea86 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -11,34 +11,7 @@ FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04 RUN echo "Setting up machine" RUN apt-get update RUN apt-get install -y curl tar -RUN DEBIAN_FRONTEND=noninteractive apt-get install -y git ffmpeg - -# Install prerequisites -RUN apt-get update && apt-get install -y \ - wget \ - build-essential \ - zlib1g-dev \ - libncurses5-dev \ - libssl-dev \ - libreadline-dev \ - libffi-dev \ - curl \ - libbz2-dev \ - liblzma-dev - -# Download and install Python 3.11 -RUN cd /tmp \ - && wget https://www.python.org/ftp/python/3.11.0/Python-3.11.0.tgz \ - && tar -xvzf Python-3.11.0.tgz \ - && cd Python-3.11.0 \ - && ./configure --enable-optimizations \ - && make -j$(nproc) \ - && make altinstall - -# Create symlinks for python and pip (use 'pip' instead of 'pip3') -RUN ln -s /usr/local/bin/python3.11 /usr/bin/python \ - && ln -s /usr/local/bin/pip3.11 /usr/bin/pip - +RUN DEBIAN_FRONTEND=noninteractive apt-get install -y git python3 pip wget ffmpeg RUN apt-get install libtcmalloc-minimal4 RUN apt-get install unzip RUN apt-get install pigz diff --git a/setup.cfg b/setup.cfg index deeb1c6c4..4afefd164 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,7 +24,6 @@ classifiers = Programming Language :: Python :: 3.8 Programming Language :: Python :: 3.9 Programming Language :: Python :: 3.10 - Programming Language :: Python :: 3.11 Topic :: Scientific/Engineering :: Artificial Intelligence [options] @@ -35,22 +34,22 @@ setup_requires = setuptools_scm # Dependencies of the project: install_requires = - absl-py==2.1.1 + absl-py==1.4.0 # Pin to avoid unpinned install in dependencies that requires Python>=3.9. - networkx==3.2.1 - docker==7.1.0 - numpy>=1.26.4 - pandas==2.2.3 - tensorflow==2.18.0 - tensorflow-datasets==4.9.7 - tensorflow-addons==0.23.0 + networkx==3.1 + docker==7.0.0 + numpy>=1.23 + pandas>=2.0.1 + tensorflow==2.12.0 + tensorflow-datasets==4.9.2 + tensorflow-probability==0.20.0 + tensorflow-addons==0.20.0 gputil==1.4.0 - psutil==6.1.0 - clu==0.0.12 - matplotlib>=3.9.2 + psutil==5.9.5 + clu==0.0.7 + matplotlib>=3.7.2 tabulate==0.9.0 - wandb==0.18.7 -python_requires = >=3.11 +python_requires = >=3.8 ############################################################################### @@ -80,77 +79,78 @@ full_dev = # Dependencies for developing the package dev = - isort==5.13.2 - pylint==3.3.1 - pytest==8.3.3 - yapf==0.43.0 - pre-commit==4.0.1 + isort==5.12.0 + pylint==2.17.4 + pytest==7.3.1 + yapf==0.33.0 + pre-commit==3.3.1 # Workloads # criteo1tb = - scikit-learn==1.5.2 + scikit-learn==1.2.2 fastmri = - h5py==3.12.1 - scikit-image==0.24.0 + h5py==3.8.0 + scikit-image==0.20.0 ogbg = jraph==0.0.6.dev0 - scikit-learn==1.5.2 + scikit-learn==1.2.2 librispeech_conformer = - sentencepiece==0.2.0 - tensorflow-text==2.18.0 + sentencepiece==0.1.99 + tensorflow-text==2.12.1 pydub==0.25.1 wmt = - sentencepiece==0.2.0 - tensorflow-text==2.18.0 - sacrebleu==2.4.3 + sentencepiece==0.1.99 + tensorflow-text==2.12.1 + sacrebleu==1.3.1 # Frameworks # # JAX Core jax_core_deps = - flax==0.10.1 - optax==0.2.4 + flax==0.6.10 + optax==0.1.5 # Fix chex (optax dependency) version. # Not fixing it can raise dependency issues with our # jax version. # Todo(kasimbeg): verify if this is necessary after we # upgrade jax. - chex==0.1.87 - ml_dtypes==0.4.1 - protobuf==4.25.5 + chex==0.1.7 + ml_dtypes==0.2.0 + protobuf==4.25.3 # JAX CPU jax_cpu = - jax==0.4.35 - jaxlib==0.4.35 + jax==0.4.10 + jaxlib==0.4.10 %(jax_core_deps)s # JAX GPU # Note this installs both jax and jaxlib. jax_gpu = - jax==0.4.35 - jaxlib==0.4.35 - jax-cuda12-plugin[with_cuda]==0.4.35 - jax-cuda12-pjrt==0.4.35 + jax==0.4.10 + jaxlib==0.4.10+cuda12.cudnn88 %(jax_core_deps)s # PyTorch CPU pytorch_cpu = - torch==2.5.0 - torchvision==0.20.0 + torch==2.1.0 + torchvision==0.16.0 # PyTorch GPU # Note: omit the cuda suffix and installing from the appropriate # wheel will result in using locally installed CUDA. pytorch_gpu = - torch==2.5.0 - torchvision==0.20.0 + torch==2.1.0 + torchvision==0.16.0 +# wandb +wandb = + wandb==0.16.5 ############################################################################### # Linting Configurations # From ce99901e2ea31528c19161da35798f09c424c8e5 Mon Sep 17 00:00:00 2001 From: init-22 Date: Thu, 14 Nov 2024 23:31:40 +0530 Subject: [PATCH 015/105] feat: package updates with python311 --- algorithmic_efficiency/random_utils.py | 8 +-- docker/Dockerfile | 29 ++++++++- setup.cfg | 86 +++++++++++++------------- 3 files changed, 75 insertions(+), 48 deletions(-) diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index cf1ea6c32..31317047e 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -18,8 +18,8 @@ # Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 32 - 1] (an # unsigned int), while RandomState.randint only accepts and returns signed ints. -MAX_INT32 = 2**31 -MIN_INT32 = -MAX_INT32 +MAX_UINT32 = 2**31 +MIN_UINT32 = 0 SeedType = Union[int, list, np.ndarray] @@ -35,13 +35,13 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType: def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]: rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32) + new_seed = rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32) return [new_seed, data] def _split(seed: SeedType, num: int = 2) -> SeedType: rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2]) + return rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32, size=[num, 2]) def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name diff --git a/docker/Dockerfile b/docker/Dockerfile index 9b72aea86..24d05b495 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -11,7 +11,34 @@ FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04 RUN echo "Setting up machine" RUN apt-get update RUN apt-get install -y curl tar -RUN DEBIAN_FRONTEND=noninteractive apt-get install -y git python3 pip wget ffmpeg +RUN DEBIAN_FRONTEND=noninteractive apt-get install -y git ffmpeg + +# Install prerequisites +RUN apt-get update && apt-get install -y \ + wget \ + build-essential \ + zlib1g-dev \ + libncurses5-dev \ + libssl-dev \ + libreadline-dev \ + libffi-dev \ + curl \ + libbz2-dev \ + liblzma-dev + +# Download and install Python 3.11 +RUN cd /tmp \ + && wget https://www.python.org/ftp/python/3.11.0/Python-3.11.0.tgz \ + && tar -xvzf Python-3.11.0.tgz \ + && cd Python-3.11.0 \ + && ./configure --enable-optimizations \ + && make -j$(nproc) \ + && make altinstall + +# Create symlinks for python and pip (use 'pip' instead of 'pip3') +RUN ln -s /usr/local/bin/python3.11 /usr/bin/python \ + && ln -s /usr/local/bin/pip3.11 /usr/bin/pip + RUN apt-get install libtcmalloc-minimal4 RUN apt-get install unzip RUN apt-get install pigz diff --git a/setup.cfg b/setup.cfg index 4afefd164..deeb1c6c4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,6 +24,7 @@ classifiers = Programming Language :: Python :: 3.8 Programming Language :: Python :: 3.9 Programming Language :: Python :: 3.10 + Programming Language :: Python :: 3.11 Topic :: Scientific/Engineering :: Artificial Intelligence [options] @@ -34,22 +35,22 @@ setup_requires = setuptools_scm # Dependencies of the project: install_requires = - absl-py==1.4.0 + absl-py==2.1.1 # Pin to avoid unpinned install in dependencies that requires Python>=3.9. - networkx==3.1 - docker==7.0.0 - numpy>=1.23 - pandas>=2.0.1 - tensorflow==2.12.0 - tensorflow-datasets==4.9.2 - tensorflow-probability==0.20.0 - tensorflow-addons==0.20.0 + networkx==3.2.1 + docker==7.1.0 + numpy>=1.26.4 + pandas==2.2.3 + tensorflow==2.18.0 + tensorflow-datasets==4.9.7 + tensorflow-addons==0.23.0 gputil==1.4.0 - psutil==5.9.5 - clu==0.0.7 - matplotlib>=3.7.2 + psutil==6.1.0 + clu==0.0.12 + matplotlib>=3.9.2 tabulate==0.9.0 -python_requires = >=3.8 + wandb==0.18.7 +python_requires = >=3.11 ############################################################################### @@ -79,78 +80,77 @@ full_dev = # Dependencies for developing the package dev = - isort==5.12.0 - pylint==2.17.4 - pytest==7.3.1 - yapf==0.33.0 - pre-commit==3.3.1 + isort==5.13.2 + pylint==3.3.1 + pytest==8.3.3 + yapf==0.43.0 + pre-commit==4.0.1 # Workloads # criteo1tb = - scikit-learn==1.2.2 + scikit-learn==1.5.2 fastmri = - h5py==3.8.0 - scikit-image==0.20.0 + h5py==3.12.1 + scikit-image==0.24.0 ogbg = jraph==0.0.6.dev0 - scikit-learn==1.2.2 + scikit-learn==1.5.2 librispeech_conformer = - sentencepiece==0.1.99 - tensorflow-text==2.12.1 + sentencepiece==0.2.0 + tensorflow-text==2.18.0 pydub==0.25.1 wmt = - sentencepiece==0.1.99 - tensorflow-text==2.12.1 - sacrebleu==1.3.1 + sentencepiece==0.2.0 + tensorflow-text==2.18.0 + sacrebleu==2.4.3 # Frameworks # # JAX Core jax_core_deps = - flax==0.6.10 - optax==0.1.5 + flax==0.10.1 + optax==0.2.4 # Fix chex (optax dependency) version. # Not fixing it can raise dependency issues with our # jax version. # Todo(kasimbeg): verify if this is necessary after we # upgrade jax. - chex==0.1.7 - ml_dtypes==0.2.0 - protobuf==4.25.3 + chex==0.1.87 + ml_dtypes==0.4.1 + protobuf==4.25.5 # JAX CPU jax_cpu = - jax==0.4.10 - jaxlib==0.4.10 + jax==0.4.35 + jaxlib==0.4.35 %(jax_core_deps)s # JAX GPU # Note this installs both jax and jaxlib. jax_gpu = - jax==0.4.10 - jaxlib==0.4.10+cuda12.cudnn88 + jax==0.4.35 + jaxlib==0.4.35 + jax-cuda12-plugin[with_cuda]==0.4.35 + jax-cuda12-pjrt==0.4.35 %(jax_core_deps)s # PyTorch CPU pytorch_cpu = - torch==2.1.0 - torchvision==0.16.0 + torch==2.5.0 + torchvision==0.20.0 # PyTorch GPU # Note: omit the cuda suffix and installing from the appropriate # wheel will result in using locally installed CUDA. pytorch_gpu = - torch==2.1.0 - torchvision==0.16.0 + torch==2.5.0 + torchvision==0.20.0 -# wandb -wandb = - wandb==0.16.5 ############################################################################### # Linting Configurations # From 21fb3f902d5744c8331be89f896c2376977f7f12 Mon Sep 17 00:00:00 2001 From: init-22 Date: Thu, 14 Nov 2024 23:46:17 +0530 Subject: [PATCH 016/105] fix: absl package version change --- docker/Dockerfile | 12 +++++++----- setup.cfg | 2 +- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 24d05b495..497ffb2c1 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -28,9 +28,9 @@ RUN apt-get update && apt-get install -y \ # Download and install Python 3.11 RUN cd /tmp \ - && wget https://www.python.org/ftp/python/3.11.0/Python-3.11.0.tgz \ - && tar -xvzf Python-3.11.0.tgz \ - && cd Python-3.11.0 \ + && wget https://www.python.org/ftp/python/3.11.10/Python-3.11.10.tgz \ + && tar -xvzf Python-3.11.10.tgz \ + && cd Python-3.11.10 \ && ./configure --enable-optimizations \ && make -j$(nproc) \ && make altinstall @@ -55,11 +55,13 @@ RUN echo "Setting up directories for data and experiment_runs" RUN mkdir -p data/ RUN mkdir -p experiment_runs/ +RUN pip install --upgrade pip + # Install Algorithmic efficiency repo RUN echo "Setting up algorithmic_efficiency repo" -ARG branch="main" +ARG branch="python311" ARG framework="both" -ARG git_url=https://github.com/mlcommons/algorithmic-efficiency.git +ARG git_url=https://github.com/init-22/algorithmic-efficiency.git RUN git clone $git_url && cd /algorithmic-efficiency RUN cd /algorithmic-efficiency && git checkout $branch diff --git a/setup.cfg b/setup.cfg index deeb1c6c4..e952513df 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,7 +35,7 @@ setup_requires = setuptools_scm # Dependencies of the project: install_requires = - absl-py==2.1.1 + absl-py==2.1.0 # Pin to avoid unpinned install in dependencies that requires Python>=3.9. networkx==3.2.1 docker==7.1.0 From 67b9f15108486a1a29b348031e1b50a82fa55b40 Mon Sep 17 00:00:00 2001 From: init-22 Date: Fri, 15 Nov 2024 00:09:04 +0530 Subject: [PATCH 017/105] fix: pytorch version change --- setup.cfg | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/setup.cfg b/setup.cfg index e952513df..a74faa197 100644 --- a/setup.cfg +++ b/setup.cfg @@ -141,15 +141,15 @@ jax_gpu = # PyTorch CPU pytorch_cpu = - torch==2.5.0 - torchvision==0.20.0 + torch==2.5.1 + torchvision==0.20.1 # PyTorch GPU # Note: omit the cuda suffix and installing from the appropriate # wheel will result in using locally installed CUDA. pytorch_gpu = - torch==2.5.0 - torchvision==0.20.0 + torch==2.5.1 + torchvision==0.20.1 ############################################################################### From 78df36f2f0f173ad651b81527cda8d55f85028b0 Mon Sep 17 00:00:00 2001 From: init-22 Date: Fri, 15 Nov 2024 00:42:26 +0530 Subject: [PATCH 018/105] fix: tf version to use numpy < 2 --- docker/Dockerfile | 2 -- setup.cfg | 4 ++-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 497ffb2c1..88fc55243 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -87,8 +87,6 @@ RUN if [ "$framework" = "jax" ] ; then \ RUN cd /algorithmic-efficiency && pip install -e '.[full]' -RUN cd /algorithmic-efficiency && pip install -e '.[wandb]' - RUN cd /algorithmic-efficiency && git fetch origin RUN cd /algorithmic-efficiency && git pull diff --git a/setup.cfg b/setup.cfg index a74faa197..2a300469a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -41,7 +41,7 @@ install_requires = docker==7.1.0 numpy>=1.26.4 pandas==2.2.3 - tensorflow==2.18.0 + tensorflow==2.17.0 tensorflow-datasets==4.9.7 tensorflow-addons==0.23.0 gputil==1.4.0 @@ -105,7 +105,7 @@ librispeech_conformer = wmt = sentencepiece==0.2.0 - tensorflow-text==2.18.0 + tensorflow-text==2.17.0 sacrebleu==2.4.3 # Frameworks # From 76b084b556af6bd58d1fbf40d5215cce510146b9 Mon Sep 17 00:00:00 2001 From: init-22 Date: Fri, 15 Nov 2024 00:55:52 +0530 Subject: [PATCH 019/105] fix: removed cifar10 and mnist --- DOCUMENTATION.md | 4 ---- 1 file changed, 4 deletions(-) diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index 0c9c429c6..990656d38 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -421,19 +421,15 @@ Default Dropout Values for Different Workloads: | Workload | Dropout Values | |------------------------|------------------------------------------------------------------------------------------------------| -| cifar | dropout not used | | criteo 1tb | dropout_rate: 0.0 | | fastmri | dropout_rate: 0.0 | | imagenet_resnet | dropout not used | | imagenet_vit | dropout_rate: 0.0 | | librispeech_conformer | attention_dropout_rate: 0.0
attention_residual_dropout_rate: 0.1
conv_residual_dropout_rate: 0.0
feed_forward_dropout_rate: 0.0
feed_forward_residual_dropout_rate: 0.1
input_dropout_rate: 0.1 | | librispeech_deepspeech | input_dropout_rate: 0.1
feed_forward_dropout_rate: 0.1
(Only for JAX - dropout_rate in CudnnLSTM class: 0.0) | -| mnist | dropout not used | | ogbg | dropout_rate: 0.1 | | wmt | dropout_rate: 0.1
attention_dropout_rate: 0.1 | - - #### Randomized workloads In addition to the [fixed and known workloads](#fixed-workloads), there will also be randomized workloads in our benchmark. These randomized workloads will introduce minor modifications to a fixed workload (e.g. small model changes). The exact instances of these randomized workloads will only be created after the submission deadline and are thus unknown to both the submitters as well as the benchmark organizers. The instructions for creating them, i.e. providing a set or distribution of workloads to sample from, will be defined by this working group and made public with the call for submissions, to allow the members of this working group to submit as well as ensure that they do not possess any additional information compared to other submitters. We will refer to the unspecific workloads as *randomized workloads*, e.g. the set or distribution. The specific instance of such a randomized workload we call a *held-out workload*. That is, a held-out workload is a specific sample of a randomized workload that is used for one iteration of the benchmark. While we may reuse randomized workloads between iterations of the benchmark, new held-out workloads will be sampled for each new benchmark iteration. From b5ad298c57659b742eb4061e9a80d152eb13abce Mon Sep 17 00:00:00 2001 From: init-22 Date: Fri, 15 Nov 2024 18:58:00 +0530 Subject: [PATCH 020/105] fix: changing PRNGkey in random_utils to key --- algorithmic_efficiency/random_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index 31317047e..91b415cd7 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -75,5 +75,5 @@ def split(seed: SeedType, num: int = 2) -> SeedType: def PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name if FLAGS.framework == 'jax': _check_jax_install() - return jax_rng.PRNGKey(seed) + return jax_rng.key(seed) return _PRNGKey(seed) From 2584416e8cc82bb61ef7a1d2a395a25da919f93f Mon Sep 17 00:00:00 2001 From: init-22 Date: Fri, 15 Nov 2024 19:09:40 +0530 Subject: [PATCH 021/105] fix: librispeech requirement of tf-text rolled back to v2.17 --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 2a300469a..078b694b8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -100,7 +100,7 @@ ogbg = librispeech_conformer = sentencepiece==0.2.0 - tensorflow-text==2.18.0 + tensorflow-text==2.17.0 pydub==0.25.1 wmt = From d603ce921b211918ce0e3d27742032f5e7ece674 Mon Sep 17 00:00:00 2001 From: init-22 Date: Fri, 15 Nov 2024 19:11:38 +0530 Subject: [PATCH 022/105] fix: using the main repo and branch for testing --- docker/Dockerfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 88fc55243..ee9136cbf 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -59,9 +59,9 @@ RUN pip install --upgrade pip # Install Algorithmic efficiency repo RUN echo "Setting up algorithmic_efficiency repo" -ARG branch="python311" +ARG branch="main" ARG framework="both" -ARG git_url=https://github.com/init-22/algorithmic-efficiency.git +ARG git_url=https://github.com/mlcommons/algorithmic-efficiency.git RUN git clone $git_url && cd /algorithmic-efficiency RUN cd /algorithmic-efficiency && git checkout $branch From 98198682e2798f7f0a9ef6b374d6a9346ec8ec22 Mon Sep 17 00:00:00 2001 From: init-22 Date: Sat, 16 Nov 2024 12:30:47 +0530 Subject: [PATCH 023/105] fix: changing the range of MAX_UINT32 --- algorithmic_efficiency/random_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index 91b415cd7..93dc263bd 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -18,7 +18,7 @@ # Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 32 - 1] (an # unsigned int), while RandomState.randint only accepts and returns signed ints. -MAX_UINT32 = 2**31 +MAX_UINT32 = 2**32-1 MIN_UINT32 = 0 SeedType = Union[int, list, np.ndarray] @@ -26,11 +26,11 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType: if isinstance(seed, int): - return seed % 2**32 + return seed % MAX_UINT32 if isinstance(seed, list): - return [s % 2**32 for s in seed] + return [s % MAX_UINT32 for s in seed] if isinstance(seed, np.ndarray): - return np.array([s % 2**32 for s in seed.tolist()]) + return np.array([s % MAX_UINT32 for s in seed.tolist()]) def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]: From 4b2e64e34e672212e4fe947674c508a410de9ef0 Mon Sep 17 00:00:00 2001 From: init-22 Date: Sat, 16 Nov 2024 12:58:26 +0530 Subject: [PATCH 024/105] bringing back PRNGKey instead of key, till the python311 branch is merged --- algorithmic_efficiency/random_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index 93dc263bd..bcfc59c92 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -75,5 +75,5 @@ def split(seed: SeedType, num: int = 2) -> SeedType: def PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name if FLAGS.framework == 'jax': _check_jax_install() - return jax_rng.key(seed) + return jax_rng.PRNGKey(seed) return _PRNGKey(seed) From be68f8cbf4a528804c78eff886ffd7e36e04fca8 Mon Sep 17 00:00:00 2001 From: init-22 Date: Sat, 16 Nov 2024 13:57:11 +0530 Subject: [PATCH 025/105] fix: overflow error resolved and PRNGKey to key --- algorithmic_efficiency/checkpoint_utils.py | 2 +- algorithmic_efficiency/random_utils.py | 10 +++++----- setup.cfg | 8 ++++---- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/algorithmic_efficiency/checkpoint_utils.py b/algorithmic_efficiency/checkpoint_utils.py index 29c1a821e..04dad0eb7 100644 --- a/algorithmic_efficiency/checkpoint_utils.py +++ b/algorithmic_efficiency/checkpoint_utils.py @@ -231,7 +231,7 @@ def save_checkpoint(framework: str, target=checkpoint_state, step=global_step, overwrite=True, - keep=np.Inf if save_intermediate_checkpoints else 1) + keep=np.inf if save_intermediate_checkpoints else 1) else: if not save_intermediate_checkpoints: checkpoint_files = gfile.glob( diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index 31317047e..93dc263bd 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -18,7 +18,7 @@ # Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 32 - 1] (an # unsigned int), while RandomState.randint only accepts and returns signed ints. -MAX_UINT32 = 2**31 +MAX_UINT32 = 2**32-1 MIN_UINT32 = 0 SeedType = Union[int, list, np.ndarray] @@ -26,11 +26,11 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType: if isinstance(seed, int): - return seed % 2**32 + return seed % MAX_UINT32 if isinstance(seed, list): - return [s % 2**32 for s in seed] + return [s % MAX_UINT32 for s in seed] if isinstance(seed, np.ndarray): - return np.array([s % 2**32 for s in seed.tolist()]) + return np.array([s % MAX_UINT32 for s in seed.tolist()]) def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]: @@ -75,5 +75,5 @@ def split(seed: SeedType, num: int = 2) -> SeedType: def PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name if FLAGS.framework == 'jax': _check_jax_install() - return jax_rng.PRNGKey(seed) + return jax_rng.key(seed) return _PRNGKey(seed) diff --git a/setup.cfg b/setup.cfg index 078b694b8..6e6a1c957 100644 --- a/setup.cfg +++ b/setup.cfg @@ -39,9 +39,9 @@ install_requires = # Pin to avoid unpinned install in dependencies that requires Python>=3.9. networkx==3.2.1 docker==7.1.0 - numpy>=1.26.4 + numpy>=2.1.3 pandas==2.2.3 - tensorflow==2.17.0 + tensorflow==2.18.0 tensorflow-datasets==4.9.7 tensorflow-addons==0.23.0 gputil==1.4.0 @@ -100,12 +100,12 @@ ogbg = librispeech_conformer = sentencepiece==0.2.0 - tensorflow-text==2.17.0 + tensorflow-text==2.18.0 pydub==0.25.1 wmt = sentencepiece==0.2.0 - tensorflow-text==2.17.0 + tensorflow-text==2.18.0 sacrebleu==2.4.3 # Frameworks # From f72028f1236ab562a6d1d50f43ec099de9325bbb Mon Sep 17 00:00:00 2001 From: init-22 Date: Tue, 19 Nov 2024 21:07:46 +0530 Subject: [PATCH 026/105] fix: ran yapf for passing the checks --- scoring/score_submissions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py index 59295b686..8cc06b15f 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -211,7 +211,8 @@ def main(_): verbosity=0, self_tuning_ruleset=FLAGS.self_tuning_ruleset, strict=FLAGS.strict, - output_dir=FLAGS.output_dir,) + output_dir=FLAGS.output_dir, + ) if not os.path.exists(FLAGS.output_dir): os.mkdir(FLAGS.output_dir) performance_profile.plot_performance_profiles( From 2b8b771d169bb97a257e735ba3c26d01d6000396 Mon Sep 17 00:00:00 2001 From: init-22 Date: Tue, 19 Nov 2024 21:11:19 +0530 Subject: [PATCH 027/105] fix: ran yapf for passing the checks --- algorithmic_efficiency/random_utils.py | 2 +- scoring/score_submissions.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index bcfc59c92..f40a98003 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -18,7 +18,7 @@ # Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 32 - 1] (an # unsigned int), while RandomState.randint only accepts and returns signed ints. -MAX_UINT32 = 2**32-1 +MAX_UINT32 = 2**32 - 1 MIN_UINT32 = 0 SeedType = Union[int, list, np.ndarray] diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py index 59295b686..8cc06b15f 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -211,7 +211,8 @@ def main(_): verbosity=0, self_tuning_ruleset=FLAGS.self_tuning_ruleset, strict=FLAGS.strict, - output_dir=FLAGS.output_dir,) + output_dir=FLAGS.output_dir, + ) if not os.path.exists(FLAGS.output_dir): os.mkdir(FLAGS.output_dir) performance_profile.plot_performance_profiles( From e890c893297a6e64cbfdc6d63f87ee7f7b4d385a Mon Sep 17 00:00:00 2001 From: init-22 Date: Wed, 20 Nov 2024 19:13:50 +0530 Subject: [PATCH 028/105] fix: minor changes in docs --- GETTING_STARTED.md | 2 +- algorithmic_efficiency/logger_utils.py | 2 +- setup.cfg | 3 --- 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md index 006b972ec..aa493bc9f 100644 --- a/GETTING_STARTED.md +++ b/GETTING_STARTED.md @@ -35,7 +35,7 @@ The specs on the benchmarking machines are: > **Prerequisites:** > -> - Python minimum requirement >= 3.8 +> - Python minimum requirement >= 3.11 > - CUDA 12.1 > - NVIDIA Driver version 535.104.05 diff --git a/algorithmic_efficiency/logger_utils.py b/algorithmic_efficiency/logger_utils.py index 609d996e6..155e55356 100644 --- a/algorithmic_efficiency/logger_utils.py +++ b/algorithmic_efficiency/logger_utils.py @@ -211,7 +211,7 @@ def _get_system_software_info() -> Dict: system_software_info['os_platform'] = \ platform.platform() # Ex. 'Linux-5.4.48-x86_64-with-glibc2.29' system_software_info['python_version'] = platform.python_version( - ) # Ex. '3.8.10' + ) # Ex. '3.11.10' system_software_info['python_compiler'] = platform.python_compiler( ) # Ex. 'GCC 9.3.0' # Note: do not store hostname as that may be sensitive diff --git a/setup.cfg b/setup.cfg index 6e6a1c957..5023f1ba6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -21,9 +21,6 @@ classifiers = Intended Audience :: Science/Research License :: OSI Approved :: Apache Software License Operating System :: OS Independent - Programming Language :: Python :: 3.8 - Programming Language :: Python :: 3.9 - Programming Language :: Python :: 3.10 Programming Language :: Python :: 3.11 Topic :: Scientific/Engineering :: Artificial Intelligence From d8f07b73c7a6b0d049c513e4b846696ed7df1da8 Mon Sep 17 00:00:00 2001 From: init-22 Date: Fri, 22 Nov 2024 22:15:01 +0530 Subject: [PATCH 029/105] fix: triggering the checks again --- scoring/performance_profile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index 0d5ca9770..f4f2d5679 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -321,7 +321,7 @@ def compute_performance_profiles(submissions, df = df[BASE_WORKLOADS + HELDOUT_WORKLOADS] # Sort workloads alphabetically (for better display) df = df.reindex(sorted(df.columns), axis=1) - + # Save time to target dataframe df.to_csv(os.path.join(output_dir, 'time_to_targets.csv')) # For each held-out workload set to inf if the base workload is inf or nan From ff176d7719d55238a898f29c3fe5e3812bd6a06c Mon Sep 17 00:00:00 2001 From: init-22 Date: Fri, 22 Nov 2024 22:28:00 +0530 Subject: [PATCH 030/105] fix: triggering the checks again --- scoring/performance_profile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index 0d5ca9770..f4f2d5679 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -321,7 +321,7 @@ def compute_performance_profiles(submissions, df = df[BASE_WORKLOADS + HELDOUT_WORKLOADS] # Sort workloads alphabetically (for better display) df = df.reindex(sorted(df.columns), axis=1) - + # Save time to target dataframe df.to_csv(os.path.join(output_dir, 'time_to_targets.csv')) # For each held-out workload set to inf if the base workload is inf or nan From 579a4850045de8b7349686105564709ded589b64 Mon Sep 17 00:00:00 2001 From: EIFY Date: Tue, 26 Nov 2024 14:20:16 -0800 Subject: [PATCH 031/105] fix pytorch_default_init() torch.nn.init.trunc_normal_() defaults to truncation at (a, b), not (a * std, b * std). --- algorithmic_efficiency/init_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/init_utils.py b/algorithmic_efficiency/init_utils.py index 66ed041ce..185480cc7 100644 --- a/algorithmic_efficiency/init_utils.py +++ b/algorithmic_efficiency/init_utils.py @@ -13,6 +13,6 @@ def pytorch_default_init(module: nn.Module) -> None: # Perform lecun_normal initialization. fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight) std = math.sqrt(1. / fan_in) / .87962566103423978 - nn.init.trunc_normal_(module.weight, std=std) + nn.init.trunc_normal_(module.weight, std=std, a=-2 * std, b=2 * std) if module.bias is not None: nn.init.constant_(module.bias, 0.) From 1bc2a7b2d5de45309bbcab035bff587c9f19ef27 Mon Sep 17 00:00:00 2001 From: init-22 Date: Sat, 30 Nov 2024 13:07:46 +0530 Subject: [PATCH 032/105] fix: changing the python versions in workflow to pass the tests --- .github/workflows/CI.yml | 48 +++++++++++++------------- .github/workflows/linting.yml | 12 +++---- .github/workflows/traindiffs_tests.yml | 2 +- 3 files changed, 31 insertions(+), 31 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 05d94e896..fe2441bfe 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -7,10 +7,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -25,10 +25,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -42,10 +42,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -59,10 +59,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -77,10 +77,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -96,10 +96,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -113,10 +113,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -130,10 +130,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -148,10 +148,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -166,10 +166,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install Modules and Run @@ -184,10 +184,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install pytest @@ -208,10 +208,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.11.10 cache: 'pip' # Cache pip dependencies\. cache-dependency-path: '**/setup.py' - name: Install pytest diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml index 89b5ef288..628fc012b 100644 --- a/.github/workflows/linting.yml +++ b/.github/workflows/linting.yml @@ -7,10 +7,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v2 with: - python-version: 3.9 + python-version: 3.11.10 - name: Install pylint run: | python -m pip install --upgrade pip @@ -27,10 +27,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v2 with: - python-version: 3.9 + python-version: 3.11.10 - name: Install isort run: | python -m pip install --upgrade pip @@ -43,10 +43,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - name: Set up Python 3.9 + - name: Set up Python 3.11.10 uses: actions/setup-python@v2 with: - python-version: 3.9 + python-version: 3.11.10 - name: Install yapf run: | python -m pip install --upgrade pip diff --git a/.github/workflows/traindiffs_tests.yml b/.github/workflows/traindiffs_tests.yml index 382f0dfe1..a2fdcb453 100644 --- a/.github/workflows/traindiffs_tests.yml +++ b/.github/workflows/traindiffs_tests.yml @@ -3,7 +3,7 @@ name: Containerized Training Differences Tests Jax vs PyTorch on: pull_request: branches: - - 'main' + - 'python311' jobs: build_and_push_docker_image: From 7a0fee3224e3d4e8602a2aca2819358bf97acf00 Mon Sep 17 00:00:00 2001 From: init-22 Date: Sat, 30 Nov 2024 22:59:52 +0530 Subject: [PATCH 033/105] fix: changing numpy compatible version --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 5023f1ba6..0aa4dce49 100644 --- a/setup.cfg +++ b/setup.cfg @@ -36,7 +36,7 @@ install_requires = # Pin to avoid unpinned install in dependencies that requires Python>=3.9. networkx==3.2.1 docker==7.1.0 - numpy>=2.1.3 + numpy>=2.0.2 pandas==2.2.3 tensorflow==2.18.0 tensorflow-datasets==4.9.7 From 7cdea1638ceb2a3c0019e95c0a63f0c36605064a Mon Sep 17 00:00:00 2001 From: init-22 Date: Sat, 30 Nov 2024 23:07:52 +0530 Subject: [PATCH 034/105] adding key_data to check the CI tests --- submission_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 551173bf5..0024c35d4 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -210,7 +210,7 @@ def train_once( ) -> Tuple[spec.Timing, Dict[str, Any]]: _reset_cuda_mem() data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4) - + data_rng = jax.random.key_data(data_rng) # Workload setup. logging.info('Initializing dataset.') if hasattr(workload, '_eval_num_workers'): @@ -336,7 +336,7 @@ def train_once( step_rng = prng.fold_in(rng, global_step) data_select_rng, update_rng, eval_rng = prng.split(step_rng, 3) - + eval_rng = jax.random.key_data(eval_rng) with profiler.profile('Data selection'): batch = data_selection(workload, input_queue, From 7264c3f80d0bd38a1c50f107d715765a7c76dcdc Mon Sep 17 00:00:00 2001 From: init-22 Date: Sun, 1 Dec 2024 14:41:05 +0530 Subject: [PATCH 035/105] fix: updated packge of sacrebleu changed the way it used to work, hence using the corpus_bleu from the main package --- algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py | 3 ++- setup.cfg | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py index 0ba49c2f6..327ca34ad 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py @@ -5,6 +5,7 @@ from absl import logging import jax +import sacrebleu import tensorflow as tf import torch import torch.distributed as dist @@ -162,7 +163,7 @@ def translate_and_calculate_bleu(self, predictions.append(self._decode_tokens(predicted[idx])) # Calculate BLEU score for translated eval corpus against reference. - bleu_score = bleu.corpus_bleu(predictions, [references]).score + bleu_score = sacrebleu.corpus_bleu(predictions, [references]).score return bleu_score def init_model_fn( diff --git a/setup.cfg b/setup.cfg index 0aa4dce49..23e86a13b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -104,7 +104,6 @@ wmt = sentencepiece==0.2.0 tensorflow-text==2.18.0 sacrebleu==2.4.3 - # Frameworks # # JAX Core From abbdc8262917fd8e38ba954f8cdaf478a5d8d1c7 Mon Sep 17 00:00:00 2001 From: init-22 Date: Sun, 1 Dec 2024 16:11:01 +0530 Subject: [PATCH 036/105] fix: temporarily commenting tfa --- .../workloads/imagenet_resnet/imagenet_jax/randaugment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index 5f92b1482..d0bbecb8f 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -8,7 +8,7 @@ import math import tensorflow as tf -from tensorflow_addons import image as contrib_image +#from tensorflow_addons import image as contrib_image # This signifies the max integer that the controller RNN could predict for the # augmentation scheme. From 86029a742094a653e5bf9a6f17f0d42c0990671d Mon Sep 17 00:00:00 2001 From: init-22 Date: Mon, 2 Dec 2024 22:10:24 +0530 Subject: [PATCH 037/105] fix: explicitly using mask kwarg to use MultiHeadDotProductAttention and also using sacrebleu --- .../workloads/imagenet_resnet/imagenet_jax/randaugment.py | 1 + algorithmic_efficiency/workloads/wmt/wmt_jax/models.py | 6 +++--- algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py | 3 ++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index d0bbecb8f..af1b763c1 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -8,6 +8,7 @@ import math import tensorflow as tf + #from tensorflow_addons import image as contrib_image # This signifies the max integer that the controller RNN could predict for the diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py index e4b5cd014..7bbc0b168 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py @@ -224,7 +224,7 @@ def __call__(self, inputs, encoder_mask=None): dropout_rate=attention_dropout_rate, deterministic=cfg.deterministic)(cfg.attention_temp * x, x, - encoder_mask) + mask=encoder_mask) if cfg.dropout_rate is None: dropout_rate = 0.1 @@ -288,7 +288,7 @@ def __call__(self, broadcast_dropout=False, dropout_rate=attention_dropout_rate, deterministic=cfg.deterministic, - decode=cfg.decode)(cfg.attention_temp * x, x, decoder_mask) + decode=cfg.decode)(cfg.attention_temp * x, x, mask=decoder_mask) if cfg.dropout_rate is None: dropout_rate = 0.1 else: @@ -311,7 +311,7 @@ def __call__(self, dropout_rate=attention_dropout_rate, deterministic=cfg.deterministic)(cfg.attention_temp * y, encoded, - encoder_decoder_mask) + mask=encoder_decoder_mask) y = nn.Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic) y = y + x diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py index 046d5e469..442c85899 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py @@ -12,6 +12,7 @@ import jax.numpy as jnp import numpy as np import optax +import sacrebleu from algorithmic_efficiency import param_utils from algorithmic_efficiency import spec @@ -203,7 +204,7 @@ def translate_and_calculate_bleu(self, predictions.append(self._decode_tokens(predicted[idx])) # Calculate BLEU score for translated eval corpus against reference. - bleu_score = bleu.corpus_bleu(predictions, [references]).score + bleu_score = sacrebleu.corpus_bleu(predictions, [references]).score return bleu_score def init_model_fn( From aca45a2b1e1df7e42a5108df8e30d49baf6ef6e2 Mon Sep 17 00:00:00 2001 From: init-22 Date: Mon, 2 Dec 2024 22:42:21 +0530 Subject: [PATCH 038/105] fix: using flax.core.pop instead of variables.pop, better way to update batch_stats --- .../workloads/imagenet_resnet/imagenet_jax/workload.py | 7 ++++--- .../workloads/imagenet_vit/imagenet_jax/workload.py | 3 ++- .../librispeech_conformer/librispeech_jax/workload.py | 7 ++++--- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py index d8de214f5..8ab4adbb9 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -11,6 +11,7 @@ from flax import jax_utils from flax import linen as nn +from flax.core import pop import jax from jax import lax import jax.numpy as jnp @@ -79,8 +80,8 @@ def sync_batch_stats( # In this case each device has its own version of the batch statistics and # we average them. avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy( - {'batch_stats': avg_fn(model_state['batch_stats'])}) + new_model_state = model_state.copy() # Create a shallow copy + new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) return new_model_state def init_model_fn( @@ -111,7 +112,7 @@ def init_model_fn( input_shape = (1, 224, 224, 3) variables = jax.jit(model.init)({'params': rng}, jnp.ones(input_shape, model.dtype)) - model_state, params = variables.pop('params') + model_state, params = pop(variables, "params") self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) model_state = jax_utils.replicate(model_state) diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py index 2ad71ffd0..5f826d035 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py @@ -4,6 +4,7 @@ from flax import jax_utils from flax import linen as nn +from flax.core import pop import jax import jax.numpy as jnp @@ -28,7 +29,7 @@ def initialized(self, key: spec.RandomState, variables = jax.jit( model.init)({'params': params_rng, 'dropout': dropout_rng}, jnp.ones(input_shape)) - model_state, params = variables.pop('params') + model_state, params = pop(variables, "params") return params, model_state def init_model_fn( diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index f4d1ab0f3..d805e8b17 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -3,6 +3,7 @@ from typing import Dict, Iterator, Optional, Tuple from flax import jax_utils +from flax.core import pop import flax.linen as nn import jax from jax import lax @@ -89,7 +90,7 @@ def init_model_fn( variables = model_init_fn({'params': params_rng, 'dropout': dropout_rng}, *fake_input_batch) - model_state, params = variables.pop('params') + model_state, params = pop(variables, "params") self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) @@ -374,8 +375,8 @@ def sync_batch_stats( # In this case each device has its own version of the batch statistics and # we average them. avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy( - {'batch_stats': avg_fn(model_state['batch_stats'])}) + new_model_state = model_state.copy() + new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) return new_model_state From 2618c5e6b1dcbdf48c2625f4cfbdca93fdc53993 Mon Sep 17 00:00:00 2001 From: init-22 Date: Mon, 2 Dec 2024 22:50:27 +0530 Subject: [PATCH 039/105] fix: changing the traindiffs_tests branch to main again --- .github/workflows/traindiffs_tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/traindiffs_tests.yml b/.github/workflows/traindiffs_tests.yml index a2fdcb453..382f0dfe1 100644 --- a/.github/workflows/traindiffs_tests.yml +++ b/.github/workflows/traindiffs_tests.yml @@ -3,7 +3,7 @@ name: Containerized Training Differences Tests Jax vs PyTorch on: pull_request: branches: - - 'python311' + - 'main' jobs: build_and_push_docker_image: From 8c9062564c920e7fea8c3ee6abc8fce51d663c82 Mon Sep 17 00:00:00 2001 From: init-22 Date: Mon, 2 Dec 2024 23:23:09 +0530 Subject: [PATCH 040/105] fix: unfreeze() in test_param_shapes expect FrozenDict also added flax.core.pop instead of variables.pop --- .../workloads/cifar/cifar_jax/workload.py | 7 ++++--- tests/test_param_shapes.py | 6 +++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py index b019d1cee..6ec90b99a 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py @@ -5,6 +5,7 @@ from flax import jax_utils from flax import linen as nn +from flax.core import pop import jax from jax import lax import jax.numpy as jnp @@ -75,8 +76,8 @@ def sync_batch_stats( # In this case each device has its own version of the batch statistics # and we average them. avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy( - {'batch_stats': avg_fn(model_state['batch_stats'])}) + new_model_state = model_state.copy() + new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) return new_model_state def init_model_fn( @@ -93,7 +94,7 @@ def init_model_fn( input_shape = (1, 32, 32, 3) variables = jax.jit(model.init)({'params': rng}, jnp.ones(input_shape, model.dtype)) - model_state, params = variables.pop('params') + model_state, params = pop(variables, 'params') self._param_shapes = param_utils.jax_param_shapes(params) self._param_types = param_utils.jax_param_types(self._param_shapes) model_state = jax_utils.replicate(model_state) diff --git a/tests/test_param_shapes.py b/tests/test_param_shapes.py index b67625213..4ad56c873 100644 --- a/tests/test_param_shapes.py +++ b/tests/test_param_shapes.py @@ -3,6 +3,7 @@ import jax import numpy as np import pytest +from flax.core import FrozenDict # isort: skip_file # pylint:disable=line-too-long @@ -51,8 +52,11 @@ def test_param_shapes(workload): jax_workload, pytorch_workload = get_workload(workload) # Compare number of parameter tensors of both models. + jax_workload_param_shapes = jax_workload.param_shapes + if isinstance(jax_workload_param_shapes, dict): + jax_workload_param_shapes = FrozenDict(jax_workload_param_shapes) jax_param_shapes = jax.tree_util.tree_leaves( - jax_workload.param_shapes.unfreeze()) + jax_workload_param_shapes.unfreeze()) pytorch_param_shapes = jax.tree_util.tree_leaves( pytorch_workload.param_shapes) if workload == 'wmt': From 1b587b75890c39c3b3ebf5359b7f82b260e06bc6 Mon Sep 17 00:00:00 2001 From: init-22 Date: Tue, 3 Dec 2024 20:30:41 +0530 Subject: [PATCH 041/105] fix: formatting changes with yapf --- algorithmic_efficiency/profiler.py | 4 +-- algorithmic_efficiency/random_utils.py | 2 +- .../workloads/cifar/cifar_jax/workload.py | 2 +- .../fastmri/fastmri_pytorch/workload.py | 4 +-- .../imagenet_jax/randaugment.py | 8 ++---- .../imagenet_pytorch/workload.py | 4 +-- .../librispeech_jax/models.py | 10 +++---- .../librispeech_jax/spectrum_augmenter.py | 4 +-- .../librispeech_jax/workload.py | 2 +- .../librispeech_pytorch/workload.py | 9 +++--- .../librispeech_jax/models.py | 10 +++---- .../workloads/mnist/workload.py | 7 ++--- .../workloads/wmt/wmt_jax/models.py | 13 ++++----- .../workloads/wmt/wmt_pytorch/models.py | 4 +-- .../external_tuning/jax_nadamw_full_budget.py | 10 ++++--- .../jax_nadamw_target_setting.py | 10 ++++--- .../self_tuning/jax_nadamw_full_budget.py | 10 ++++--- .../self_tuning/jax_nadamw_target_setting.py | 10 ++++--- .../paper_baselines/nadamw/jax/submission.py | 10 ++++--- .../paper_baselines/sam/jax/submission.py | 8 +++--- .../shampoo/jax/distributed_shampoo.py | 28 +++++++------------ .../target_setting_algorithms/jax_nadamw.py | 10 ++++--- submission_runner.py | 4 +-- tests/modeldiffs/wmt/compare.py | 6 ++-- .../modeldiffs/wmt_attention_temp/compare.py | 6 ++-- tests/modeldiffs/wmt_glu_tanh/compare.py | 6 ++-- tests/modeldiffs/wmt_post_ln/compare.py | 6 ++-- 27 files changed, 98 insertions(+), 109 deletions(-) diff --git a/algorithmic_efficiency/profiler.py b/algorithmic_efficiency/profiler.py index fa2a1bee2..d73efd964 100644 --- a/algorithmic_efficiency/profiler.py +++ b/algorithmic_efficiency/profiler.py @@ -72,8 +72,8 @@ def _make_report( float(np.std(d)), len(d), float(np.sum(d)), - 100.0 * float(np.sum(d)) / total_duration) for a, - d in self.recorded_durations.items()] + 100.0 * float(np.sum(d)) / total_duration) + for a, d in self.recorded_durations.items()] report.sort(key=lambda x: x[5], reverse=True) total_calls = sum(x[3] for x in report) return report, total_calls, total_duration diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index 93dc263bd..b5b30ce22 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -18,7 +18,7 @@ # Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 32 - 1] (an # unsigned int), while RandomState.randint only accepts and returns signed ints. -MAX_UINT32 = 2**32-1 +MAX_UINT32 = 2**32 - 1 MIN_UINT32 = 0 SeedType = Union[int, list, np.ndarray] diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py index 6ec90b99a..dd4643a60 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py @@ -76,7 +76,7 @@ def sync_batch_stats( # In this case each device has its own version of the batch statistics # and we average them. avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy() + new_model_state = model_state.copy() new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) return new_model_state diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py index 74f6aa13d..a2f0828e3 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py @@ -252,9 +252,7 @@ def _eval_model_on_split(self, for _ in range(num_batches): batch = next(self._eval_iters[split]) batch_metrics = self._eval_model(params, batch, model_rng) - total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() - } + total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index af1b763c1..94c66033a 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -313,8 +313,7 @@ def build_lut(histo, step): # If step is zero, return the original image. Otherwise, build # lut from the full histogram and step and then index from it. result = tf.cond( - tf.equal(step, 0), - lambda: im, + tf.equal(step, 0), lambda: im, lambda: tf.gather(build_lut(histo, step), im)) return tf.cast(result, tf.uint8) @@ -549,7 +548,6 @@ def distort_image_with_randaugment(image, num_layers, magnitude, key): translate_const=100) image = tf.cond( tf.equal(i, op_to_select), - lambda selected_func=func, - selected_args=args: selected_func(image, *selected_args), - lambda: image) + lambda selected_func=func, selected_args=args: selected_func( + image, *selected_args), lambda: image) return image diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py index 3549911fa..0ed944191 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -309,9 +309,7 @@ def _eval_model_on_split(self, update_batch_norm=False) weights = batch.get('weights') batch_metrics = self._compute_metrics(logits, batch['targets'], weights) - total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() - } + total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py index ed05f4335..db8cbc70a 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py @@ -153,8 +153,8 @@ def setup(self): self.kernel = self.param('kernel', nn.initializers.xavier_uniform(), self.filter_shape) - self.bias = self.param( - 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels) + self.bias = self.param('bias', lambda rng, s: jnp.zeros(s, jnp.float32), + self.output_channels) @nn.compact def __call__(self, inputs, paddings): @@ -442,12 +442,10 @@ def setup(self): dtype = self.config.dtype self.ra_mean = self.variable('batch_stats', - 'mean', - lambda s: jnp.zeros(s, dtype), + 'mean', lambda s: jnp.zeros(s, dtype), dim) self.ra_var = self.variable('batch_stats', - 'var', - lambda s: jnp.ones(s, dtype), + 'var', lambda s: jnp.ones(s, dtype), dim) self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py index 2a6f73d4d..c16740629 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py @@ -81,8 +81,8 @@ def _get_mask(self, jnp.expand_dims(jnp.arange(multiplicity, dtype=jnp.int32), 0), [batch_size, 1]) multiplicity_tensor = masks_per_frame * choose_range - multiplicity_weights = (multiplicity_weights < - multiplicity_tensor).astype(jnp.int32) + multiplicity_weights = (multiplicity_weights + < multiplicity_tensor).astype(jnp.int32) pre_mask = jnp.einsum('bmt,bm->bt', pre_mask, multiplicity_weights) else: pre_mask = jnp.einsum('bmt->bt', pre_mask) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index d805e8b17..64e41989f 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -375,7 +375,7 @@ def sync_batch_stats( # In this case each device has its own version of the batch statistics and # we average them. avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x') - new_model_state = model_state.copy() + new_model_state = model_state.copy() new_model_state['batch_stats'] = avg_fn(model_state['batch_stats']) return new_model_state diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 155b30920..31d069e88 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -260,8 +260,9 @@ def greedy_decode( idxs = torch.arange( fin_result.numel(), device=result.device).view(*fin_result.shape) mask = torch.arange( - fin_result.shape[1], device=result.device).view( - 1, -1) < result.count_nonzero(dim=1).view(-1, 1) + fin_result.shape[1], + device=result.device).view(1, -1) < result.count_nonzero(dim=1).view( + -1, 1) fin_result.view(-1)[idxs[mask != 0]] = result[result != blank_id] padding = fin_result == 0 return fin_result, padding @@ -329,9 +330,7 @@ def _eval_model_on_split(self, 'word_errors': word_errors, 'num_words': num_words, } - total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() - } + total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py index f9eb732e9..c2fe540a6 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -139,8 +139,8 @@ def setup(self): self.kernel = self.param('kernel', nn.initializers.xavier_uniform(), self.filter_shape) - self.bias = self.param( - 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels) + self.bias = self.param('bias', lambda rng, s: jnp.zeros(s, jnp.float32), + self.output_channels) @nn.compact def __call__(self, inputs, paddings, train): @@ -273,12 +273,10 @@ def setup(self): dtype = self.dtype self.ra_mean = self.variable('batch_stats', - 'mean', - lambda s: jnp.zeros(s, dtype), + 'mean', lambda s: jnp.zeros(s, dtype), dim) self.ra_var = self.variable('batch_stats', - 'var', - lambda s: jnp.ones(s, dtype), + 'var', lambda s: jnp.ones(s, dtype), dim) self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype) diff --git a/algorithmic_efficiency/workloads/mnist/workload.py b/algorithmic_efficiency/workloads/mnist/workload.py index dcc195170..ad950b869 100644 --- a/algorithmic_efficiency/workloads/mnist/workload.py +++ b/algorithmic_efficiency/workloads/mnist/workload.py @@ -46,8 +46,7 @@ def _build_mnist_dataset( ds = ds.map( lambda x: { 'inputs': _normalize(x['image'], train_mean, train_stddev), - 'targets': x['label'], - }) + 'targets': x['label'],}) is_train = split == 'train' if cache: @@ -214,8 +213,6 @@ def _eval_model_on_split(self, batch, model_state, per_device_model_rngs) - total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() - } + total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} return self._normalize_eval_metrics(num_examples, total_metrics) diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py index 7bbc0b168..97fee032f 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py @@ -222,9 +222,8 @@ def __call__(self, inputs, encoder_mask=None): use_bias=False, broadcast_dropout=False, dropout_rate=attention_dropout_rate, - deterministic=cfg.deterministic)(cfg.attention_temp * x, - x, - mask=encoder_mask) + deterministic=cfg.deterministic)( + cfg.attention_temp * x, x, mask=encoder_mask) if cfg.dropout_rate is None: dropout_rate = 0.1 @@ -288,7 +287,8 @@ def __call__(self, broadcast_dropout=False, dropout_rate=attention_dropout_rate, deterministic=cfg.deterministic, - decode=cfg.decode)(cfg.attention_temp * x, x, mask=decoder_mask) + decode=cfg.decode)( + cfg.attention_temp * x, x, mask=decoder_mask) if cfg.dropout_rate is None: dropout_rate = 0.1 else: @@ -309,9 +309,8 @@ def __call__(self, use_bias=False, broadcast_dropout=False, dropout_rate=attention_dropout_rate, - deterministic=cfg.deterministic)(cfg.attention_temp * y, - encoded, - mask=encoder_decoder_mask) + deterministic=cfg.deterministic)( + cfg.attention_temp * y, encoded, mask=encoder_decoder_mask) y = nn.Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic) y = y + x diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py index a1c7ce15e..089f1bfbb 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py @@ -942,8 +942,8 @@ def forward(self, # not the remaining zero elements. if attn_mask is not None: raise ValueError('Attention mask has to be None for decode == True.') - attn_mask = (torch.arange(max_len, device=k.device) >= - cache_index).reshape(1, max_len) + attn_mask = (torch.arange(max_len, device=k.device) + >= cache_index).reshape(1, max_len) # Update sequence length to account for complete sequence. seq_len = k.size(1) diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py index 98193f01f..ad4d8e6f5 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -123,8 +123,9 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), + mu_hat, + nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -139,8 +140,9 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, + updates, + moments) def _bias_correction(moment, decay, count): diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py index 66fdc4ebb..bde851468 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -123,8 +123,9 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), + mu_hat, + nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -139,8 +140,9 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, + updates, + moments) def _bias_correction(moment, decay, count): diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py index 4f53afb56..4122be181 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -132,8 +132,9 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), + mu_hat, + nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -148,8 +149,9 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, + updates, + moments) def _bias_correction(moment, decay, count): diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index 60a1f784d..6b5faa6b8 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -132,8 +132,9 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), + mu_hat, + nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -148,8 +149,9 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, + updates, + moments) def _bias_correction(moment, decay, count): diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py index 98193f01f..ad4d8e6f5 100644 --- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py @@ -123,8 +123,9 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), + mu_hat, + nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -139,8 +140,9 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, + updates, + moments) def _bias_correction(moment, decay, count): diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py index 85b3d7441..d33daadb8 100644 --- a/reference_algorithms/paper_baselines/sam/jax/submission.py +++ b/reference_algorithms/paper_baselines/sam/jax/submission.py @@ -67,8 +67,9 @@ def update_fn(updates, state, grad_fn_params_tuple): # the noised parameters in the same order as on the original gradients and # with the same 1e-6 epsilon that is used when clipping the gradients. updates = dual_vector(updates) - noised_params = jax.tree_util.tree_map( - lambda p, u: p + rho * u, params, updates) + noised_params = jax.tree_util.tree_map(lambda p, u: p + rho * u, + params, + updates) (_, (n_valid_examples, _)), updates = grad_fn(noised_params) # Get correct global mean grad. (n_valid_examples, updates) = lax.psum((n_valid_examples, updates), @@ -80,8 +81,7 @@ def update_fn(updates, state, grad_fn_params_tuple): sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(updates))) scaled_updates = jax.tree_map( lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates) - updates = jax.lax.cond(updates_norm > grad_clip, - lambda _: scaled_updates, + updates = jax.lax.cond(updates_norm > grad_clip, lambda _: scaled_updates, lambda _: updates, None) updates, state = base_opt_update_fn(updates, state, params) diff --git a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py index 725529cae..722dab06b 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py @@ -595,8 +595,8 @@ def matrix_inverse_pth_root( if padding_start is not None: # Zero out padding in identity as well for convergence checks. - ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype( - matrix.dtype) + ix = (jnp.arange(matrix_size, dtype=jnp.int32) + < padding_start).astype(matrix.dtype) matrix *= ix[jnp.newaxis, :] matrix *= ix[:, jnp.newaxis] identity *= ix @@ -815,8 +815,8 @@ def matrix_inverse_pth_root_eigh( alpha = jnp.asarray(-1.0 / p, _MAT_INV_PTH_ROOT_DTYPE) identity = jnp.eye(matrix_size, dtype=_MAT_INV_PTH_ROOT_DTYPE) if padding_start is not None: - ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype( - matrix.dtype) + ix = (jnp.arange(matrix_size, dtype=jnp.int32) + < padding_start).astype(matrix.dtype) matrix *= ix[jnp.newaxis, :] matrix *= ix[:, jnp.newaxis] identity *= ix @@ -1809,17 +1809,13 @@ def sharded_update_fn(grads, state, params): )) new_stats_flat = jax.tree_map( - lambda g, - s, - p: _compute_stats(g, s, p, state.count), + lambda g, s, p: _compute_stats(g, s, p, state.count), grads_flat, stats_flat, params_flat) outputs = jax.tree_map( - lambda g, - s, - p: _transform_grad(g, s, p, state.count), + lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat, new_stats_flat, params_flat) @@ -1923,8 +1919,8 @@ def _internal_inverse_pth_root_all(): errors = metrics.inverse_pth_root_errors errors = errors.reshape((-1, 1, 1)) predicate = jnp.logical_or( - jnp.isnan(errors), - errors >= inverse_failure_threshold).astype(new_preconditioners.dtype) + jnp.isnan(errors), errors + >= inverse_failure_threshold).astype(new_preconditioners.dtype) # TODO(rohananil): Check for numerical instabilities. new_conditional_preconditioners = ( predicate * global_stats.preconditioners + @@ -2442,9 +2438,7 @@ def update_fn(grads, state, params): stats_grads = treedef.flatten_up_to(grads_custom) new_stats_flat = jax.tree_map( - lambda g, - s, - p: _compute_stats(g, s, p, state.count), + lambda g, s, p: _compute_stats(g, s, p, state.count), stats_grads, stats_flat, params_flat) @@ -2453,9 +2447,7 @@ def update_fn(grads, state, params): params_flat, state.count) outputs = jax.tree_map( - lambda g, - s, - p: _transform_grad(g, s, p, state.count), + lambda g, s, p: _transform_grad(g, s, p, state.count), grads_flat, new_stats_flat, params_flat) diff --git a/reference_algorithms/target_setting_algorithms/jax_nadamw.py b/reference_algorithms/target_setting_algorithms/jax_nadamw.py index 21f2a7b2b..fc866f80a 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_nadamw.py @@ -108,8 +108,9 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), + mu_hat, + nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -124,8 +125,9 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, + updates, + moments) def _bias_correction(moment, decay, count): diff --git a/submission_runner.py b/submission_runner.py index 0024c35d4..a6bea1aa8 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -377,8 +377,8 @@ def train_once( train_state['is_time_remaining'] = ( train_state['accumulated_submission_time'] < max_allowed_runtime_sec) # Check if submission is eligible for an untimed eval. - if ((train_step_end_time - train_state['last_eval_time']) >= - workload.eval_period_time_sec or train_state['training_complete']): + if ((train_step_end_time - train_state['last_eval_time']) + >= workload.eval_period_time_sec or train_state['training_complete']): with profiler.profile('Evaluation'): del batch _reset_cuda_mem() diff --git a/tests/modeldiffs/wmt/compare.py b/tests/modeldiffs/wmt/compare.py index 41fc5ee17..8f9154f53 100644 --- a/tests/modeldiffs/wmt/compare.py +++ b/tests/modeldiffs/wmt/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): value - for key, - value in out.items() + for k in key): + value + for key, value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) diff --git a/tests/modeldiffs/wmt_attention_temp/compare.py b/tests/modeldiffs/wmt_attention_temp/compare.py index 92ce4eb44..ff7103d43 100644 --- a/tests/modeldiffs/wmt_attention_temp/compare.py +++ b/tests/modeldiffs/wmt_attention_temp/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): value - for key, - value in out.items() + for k in key): + value + for key, value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) diff --git a/tests/modeldiffs/wmt_glu_tanh/compare.py b/tests/modeldiffs/wmt_glu_tanh/compare.py index b8d860479..d24d818a2 100644 --- a/tests/modeldiffs/wmt_glu_tanh/compare.py +++ b/tests/modeldiffs/wmt_glu_tanh/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): value - for key, - value in out.items() + for k in key): + value + for key, value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) diff --git a/tests/modeldiffs/wmt_post_ln/compare.py b/tests/modeldiffs/wmt_post_ln/compare.py index 3f5469d8d..7d0556345 100644 --- a/tests/modeldiffs/wmt_post_ln/compare.py +++ b/tests/modeldiffs/wmt_post_ln/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): value - for key, - value in out.items() + for k in key): + value + for key, value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) From c65d93e5b4adfa6e493e6101048738afd8dc15d9 Mon Sep 17 00:00:00 2001 From: init-22 Date: Tue, 3 Dec 2024 20:37:32 +0530 Subject: [PATCH 042/105] fix: running yapf again with 0.32, earlier using 0.43 --- algorithmic_efficiency/profiler.py | 4 ++-- .../workloads/fastmri/fastmri_pytorch/workload.py | 4 +++- .../imagenet_resnet/imagenet_jax/randaugment.py | 8 +++++--- .../imagenet_resnet/imagenet_pytorch/workload.py | 4 +++- .../librispeech_conformer/librispeech_jax/models.py | 10 ++++++---- .../librispeech_jax/spectrum_augmenter.py | 4 ++-- .../librispeech_pytorch/workload.py | 9 +++++---- .../librispeech_deepspeech/librispeech_jax/models.py | 10 ++++++---- algorithmic_efficiency/workloads/mnist/workload.py | 7 +++++-- .../workloads/wmt/wmt_pytorch/models.py | 4 ++-- setup.cfg | 2 +- submission_runner.py | 4 ++-- tests/modeldiffs/wmt/compare.py | 6 +++--- tests/modeldiffs/wmt_attention_temp/compare.py | 6 +++--- tests/modeldiffs/wmt_glu_tanh/compare.py | 6 +++--- tests/modeldiffs/wmt_post_ln/compare.py | 6 +++--- 16 files changed, 54 insertions(+), 40 deletions(-) diff --git a/algorithmic_efficiency/profiler.py b/algorithmic_efficiency/profiler.py index d73efd964..fa2a1bee2 100644 --- a/algorithmic_efficiency/profiler.py +++ b/algorithmic_efficiency/profiler.py @@ -72,8 +72,8 @@ def _make_report( float(np.std(d)), len(d), float(np.sum(d)), - 100.0 * float(np.sum(d)) / total_duration) - for a, d in self.recorded_durations.items()] + 100.0 * float(np.sum(d)) / total_duration) for a, + d in self.recorded_durations.items()] report.sort(key=lambda x: x[5], reverse=True) total_calls = sum(x[3] for x in report) return report, total_calls, total_duration diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py index a2f0828e3..74f6aa13d 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py @@ -252,7 +252,9 @@ def _eval_model_on_split(self, for _ in range(num_batches): batch = next(self._eval_iters[split]) batch_metrics = self._eval_model(params, batch, model_rng) - total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} + total_metrics = { + k: v + batch_metrics[k] for k, v in total_metrics.items() + } if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index 94c66033a..af1b763c1 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -313,7 +313,8 @@ def build_lut(histo, step): # If step is zero, return the original image. Otherwise, build # lut from the full histogram and step and then index from it. result = tf.cond( - tf.equal(step, 0), lambda: im, + tf.equal(step, 0), + lambda: im, lambda: tf.gather(build_lut(histo, step), im)) return tf.cast(result, tf.uint8) @@ -548,6 +549,7 @@ def distort_image_with_randaugment(image, num_layers, magnitude, key): translate_const=100) image = tf.cond( tf.equal(i, op_to_select), - lambda selected_func=func, selected_args=args: selected_func( - image, *selected_args), lambda: image) + lambda selected_func=func, + selected_args=args: selected_func(image, *selected_args), + lambda: image) return image diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py index 0ed944191..3549911fa 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -309,7 +309,9 @@ def _eval_model_on_split(self, update_batch_norm=False) weights = batch.get('weights') batch_metrics = self._compute_metrics(logits, batch['targets'], weights) - total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} + total_metrics = { + k: v + batch_metrics[k] for k, v in total_metrics.items() + } if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py index db8cbc70a..ed05f4335 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py @@ -153,8 +153,8 @@ def setup(self): self.kernel = self.param('kernel', nn.initializers.xavier_uniform(), self.filter_shape) - self.bias = self.param('bias', lambda rng, s: jnp.zeros(s, jnp.float32), - self.output_channels) + self.bias = self.param( + 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels) @nn.compact def __call__(self, inputs, paddings): @@ -442,10 +442,12 @@ def setup(self): dtype = self.config.dtype self.ra_mean = self.variable('batch_stats', - 'mean', lambda s: jnp.zeros(s, dtype), + 'mean', + lambda s: jnp.zeros(s, dtype), dim) self.ra_var = self.variable('batch_stats', - 'var', lambda s: jnp.ones(s, dtype), + 'var', + lambda s: jnp.ones(s, dtype), dim) self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py index c16740629..2a6f73d4d 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py @@ -81,8 +81,8 @@ def _get_mask(self, jnp.expand_dims(jnp.arange(multiplicity, dtype=jnp.int32), 0), [batch_size, 1]) multiplicity_tensor = masks_per_frame * choose_range - multiplicity_weights = (multiplicity_weights - < multiplicity_tensor).astype(jnp.int32) + multiplicity_weights = (multiplicity_weights < + multiplicity_tensor).astype(jnp.int32) pre_mask = jnp.einsum('bmt,bm->bt', pre_mask, multiplicity_weights) else: pre_mask = jnp.einsum('bmt->bt', pre_mask) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 31d069e88..155b30920 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -260,9 +260,8 @@ def greedy_decode( idxs = torch.arange( fin_result.numel(), device=result.device).view(*fin_result.shape) mask = torch.arange( - fin_result.shape[1], - device=result.device).view(1, -1) < result.count_nonzero(dim=1).view( - -1, 1) + fin_result.shape[1], device=result.device).view( + 1, -1) < result.count_nonzero(dim=1).view(-1, 1) fin_result.view(-1)[idxs[mask != 0]] = result[result != blank_id] padding = fin_result == 0 return fin_result, padding @@ -330,7 +329,9 @@ def _eval_model_on_split(self, 'word_errors': word_errors, 'num_words': num_words, } - total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} + total_metrics = { + k: v + batch_metrics[k] for k, v in total_metrics.items() + } if USE_PYTORCH_DDP: for metric in total_metrics.values(): dist.all_reduce(metric) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py index c2fe540a6..f9eb732e9 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -139,8 +139,8 @@ def setup(self): self.kernel = self.param('kernel', nn.initializers.xavier_uniform(), self.filter_shape) - self.bias = self.param('bias', lambda rng, s: jnp.zeros(s, jnp.float32), - self.output_channels) + self.bias = self.param( + 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels) @nn.compact def __call__(self, inputs, paddings, train): @@ -273,10 +273,12 @@ def setup(self): dtype = self.dtype self.ra_mean = self.variable('batch_stats', - 'mean', lambda s: jnp.zeros(s, dtype), + 'mean', + lambda s: jnp.zeros(s, dtype), dim) self.ra_var = self.variable('batch_stats', - 'var', lambda s: jnp.ones(s, dtype), + 'var', + lambda s: jnp.ones(s, dtype), dim) self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype) diff --git a/algorithmic_efficiency/workloads/mnist/workload.py b/algorithmic_efficiency/workloads/mnist/workload.py index ad950b869..dcc195170 100644 --- a/algorithmic_efficiency/workloads/mnist/workload.py +++ b/algorithmic_efficiency/workloads/mnist/workload.py @@ -46,7 +46,8 @@ def _build_mnist_dataset( ds = ds.map( lambda x: { 'inputs': _normalize(x['image'], train_mean, train_stddev), - 'targets': x['label'],}) + 'targets': x['label'], + }) is_train = split == 'train' if cache: @@ -213,6 +214,8 @@ def _eval_model_on_split(self, batch, model_state, per_device_model_rngs) - total_metrics = {k: v + batch_metrics[k] for k, v in total_metrics.items()} + total_metrics = { + k: v + batch_metrics[k] for k, v in total_metrics.items() + } return self._normalize_eval_metrics(num_examples, total_metrics) diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py index 089f1bfbb..a1c7ce15e 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py @@ -942,8 +942,8 @@ def forward(self, # not the remaining zero elements. if attn_mask is not None: raise ValueError('Attention mask has to be None for decode == True.') - attn_mask = (torch.arange(max_len, device=k.device) - >= cache_index).reshape(1, max_len) + attn_mask = (torch.arange(max_len, device=k.device) >= + cache_index).reshape(1, max_len) # Update sequence length to account for complete sequence. seq_len = k.size(1) diff --git a/setup.cfg b/setup.cfg index 23e86a13b..e8044fe02 100644 --- a/setup.cfg +++ b/setup.cfg @@ -80,7 +80,7 @@ dev = isort==5.13.2 pylint==3.3.1 pytest==8.3.3 - yapf==0.43.0 + yapf==0.32.0 pre-commit==4.0.1 # Workloads # diff --git a/submission_runner.py b/submission_runner.py index a6bea1aa8..0024c35d4 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -377,8 +377,8 @@ def train_once( train_state['is_time_remaining'] = ( train_state['accumulated_submission_time'] < max_allowed_runtime_sec) # Check if submission is eligible for an untimed eval. - if ((train_step_end_time - train_state['last_eval_time']) - >= workload.eval_period_time_sec or train_state['training_complete']): + if ((train_step_end_time - train_state['last_eval_time']) >= + workload.eval_period_time_sec or train_state['training_complete']): with profiler.profile('Evaluation'): del batch _reset_cuda_mem() diff --git a/tests/modeldiffs/wmt/compare.py b/tests/modeldiffs/wmt/compare.py index 8f9154f53..41fc5ee17 100644 --- a/tests/modeldiffs/wmt/compare.py +++ b/tests/modeldiffs/wmt/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): - value - for key, value in out.items() + for k in key): value + for key, + value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) diff --git a/tests/modeldiffs/wmt_attention_temp/compare.py b/tests/modeldiffs/wmt_attention_temp/compare.py index ff7103d43..92ce4eb44 100644 --- a/tests/modeldiffs/wmt_attention_temp/compare.py +++ b/tests/modeldiffs/wmt_attention_temp/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): - value - for key, value in out.items() + for k in key): value + for key, + value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) diff --git a/tests/modeldiffs/wmt_glu_tanh/compare.py b/tests/modeldiffs/wmt_glu_tanh/compare.py index d24d818a2..b8d860479 100644 --- a/tests/modeldiffs/wmt_glu_tanh/compare.py +++ b/tests/modeldiffs/wmt_glu_tanh/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): - value - for key, value in out.items() + for k in key): value + for key, + value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) diff --git a/tests/modeldiffs/wmt_post_ln/compare.py b/tests/modeldiffs/wmt_post_ln/compare.py index 7d0556345..3f5469d8d 100644 --- a/tests/modeldiffs/wmt_post_ln/compare.py +++ b/tests/modeldiffs/wmt_post_ln/compare.py @@ -76,9 +76,9 @@ def sd_transform(sd): out = { tuple( k.replace('SelfAttention', 'MultiHeadDotProductAttention') - for k in key): - value - for key, value in out.items() + for k in key): value + for key, + value in out.items() } elif 'Dense' in k_str: new_key = (*k[:2], 'MlpBlock_0', *k[2:]) From 3afd1dff5e6bf0780c5ff77e2e7daedba74928cb Mon Sep 17 00:00:00 2001 From: init-22 Date: Tue, 3 Dec 2024 20:39:03 +0530 Subject: [PATCH 043/105] fix: running yapf again with 0.32, earlier using 0.43 --- .../external_tuning/jax_nadamw_full_budget.py | 10 +++---- .../jax_nadamw_target_setting.py | 10 +++---- .../self_tuning/jax_nadamw_full_budget.py | 10 +++---- .../self_tuning/jax_nadamw_target_setting.py | 10 +++---- .../paper_baselines/nadamw/jax/submission.py | 10 +++---- .../paper_baselines/sam/jax/submission.py | 8 +++--- .../shampoo/jax/distributed_shampoo.py | 28 ++++++++++++------- .../target_setting_algorithms/jax_nadamw.py | 10 +++---- 8 files changed, 46 insertions(+), 50 deletions(-) diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py index ad4d8e6f5..98193f01f 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -123,9 +123,8 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), - mu_hat, - nu_hat) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -140,9 +139,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, - updates, - moments) + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py index bde851468..66fdc4ebb 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -123,9 +123,8 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), - mu_hat, - nu_hat) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -140,9 +139,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, - updates, - moments) + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py index 4122be181..4f53afb56 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -132,9 +132,8 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), - mu_hat, - nu_hat) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -149,9 +148,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, - updates, - moments) + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index 6b5faa6b8..60a1f784d 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -132,9 +132,8 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), - mu_hat, - nu_hat) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -149,9 +148,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, - updates, - moments) + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py index ad4d8e6f5..98193f01f 100644 --- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py @@ -123,9 +123,8 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), - mu_hat, - nu_hat) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -140,9 +139,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, - updates, - moments) + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py index d33daadb8..85b3d7441 100644 --- a/reference_algorithms/paper_baselines/sam/jax/submission.py +++ b/reference_algorithms/paper_baselines/sam/jax/submission.py @@ -67,9 +67,8 @@ def update_fn(updates, state, grad_fn_params_tuple): # the noised parameters in the same order as on the original gradients and # with the same 1e-6 epsilon that is used when clipping the gradients. updates = dual_vector(updates) - noised_params = jax.tree_util.tree_map(lambda p, u: p + rho * u, - params, - updates) + noised_params = jax.tree_util.tree_map( + lambda p, u: p + rho * u, params, updates) (_, (n_valid_examples, _)), updates = grad_fn(noised_params) # Get correct global mean grad. (n_valid_examples, updates) = lax.psum((n_valid_examples, updates), @@ -81,7 +80,8 @@ def update_fn(updates, state, grad_fn_params_tuple): sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(updates))) scaled_updates = jax.tree_map( lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates) - updates = jax.lax.cond(updates_norm > grad_clip, lambda _: scaled_updates, + updates = jax.lax.cond(updates_norm > grad_clip, + lambda _: scaled_updates, lambda _: updates, None) updates, state = base_opt_update_fn(updates, state, params) diff --git a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py index 722dab06b..725529cae 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py @@ -595,8 +595,8 @@ def matrix_inverse_pth_root( if padding_start is not None: # Zero out padding in identity as well for convergence checks. - ix = (jnp.arange(matrix_size, dtype=jnp.int32) - < padding_start).astype(matrix.dtype) + ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype( + matrix.dtype) matrix *= ix[jnp.newaxis, :] matrix *= ix[:, jnp.newaxis] identity *= ix @@ -815,8 +815,8 @@ def matrix_inverse_pth_root_eigh( alpha = jnp.asarray(-1.0 / p, _MAT_INV_PTH_ROOT_DTYPE) identity = jnp.eye(matrix_size, dtype=_MAT_INV_PTH_ROOT_DTYPE) if padding_start is not None: - ix = (jnp.arange(matrix_size, dtype=jnp.int32) - < padding_start).astype(matrix.dtype) + ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype( + matrix.dtype) matrix *= ix[jnp.newaxis, :] matrix *= ix[:, jnp.newaxis] identity *= ix @@ -1809,13 +1809,17 @@ def sharded_update_fn(grads, state, params): )) new_stats_flat = jax.tree_map( - lambda g, s, p: _compute_stats(g, s, p, state.count), + lambda g, + s, + p: _compute_stats(g, s, p, state.count), grads_flat, stats_flat, params_flat) outputs = jax.tree_map( - lambda g, s, p: _transform_grad(g, s, p, state.count), + lambda g, + s, + p: _transform_grad(g, s, p, state.count), grads_flat, new_stats_flat, params_flat) @@ -1919,8 +1923,8 @@ def _internal_inverse_pth_root_all(): errors = metrics.inverse_pth_root_errors errors = errors.reshape((-1, 1, 1)) predicate = jnp.logical_or( - jnp.isnan(errors), errors - >= inverse_failure_threshold).astype(new_preconditioners.dtype) + jnp.isnan(errors), + errors >= inverse_failure_threshold).astype(new_preconditioners.dtype) # TODO(rohananil): Check for numerical instabilities. new_conditional_preconditioners = ( predicate * global_stats.preconditioners + @@ -2438,7 +2442,9 @@ def update_fn(grads, state, params): stats_grads = treedef.flatten_up_to(grads_custom) new_stats_flat = jax.tree_map( - lambda g, s, p: _compute_stats(g, s, p, state.count), + lambda g, + s, + p: _compute_stats(g, s, p, state.count), stats_grads, stats_flat, params_flat) @@ -2447,7 +2453,9 @@ def update_fn(grads, state, params): params_flat, state.count) outputs = jax.tree_map( - lambda g, s, p: _transform_grad(g, s, p, state.count), + lambda g, + s, + p: _transform_grad(g, s, p, state.count), grads_flat, new_stats_flat, params_flat) diff --git a/reference_algorithms/target_setting_algorithms/jax_nadamw.py b/reference_algorithms/target_setting_algorithms/jax_nadamw.py index fc866f80a..21f2a7b2b 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_nadamw.py @@ -108,9 +108,8 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map(lambda m, v: m / (raise_power(v + eps_root) + eps), - mu_hat, - nu_hat) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) return optax.GradientTransformation(init_fn, update_fn) @@ -125,9 +124,8 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t, - updates, - moments) + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): From 6ff2010d884e9d14911beab6dbce1a546a0a6213 Mon Sep 17 00:00:00 2001 From: init-22 Date: Tue, 3 Dec 2024 21:21:03 +0530 Subject: [PATCH 044/105] fix: latest versions of typing dont support Text instead str is recommended --- algorithmic_efficiency/halton.py | 14 +++++++------- .../workloads/wmt/wmt_jax/workload.py | 2 +- .../workloads/wmt/wmt_pytorch/workload.py | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/algorithmic_efficiency/halton.py b/algorithmic_efficiency/halton.py index 9eb30861d..d710e3fce 100644 --- a/algorithmic_efficiency/halton.py +++ b/algorithmic_efficiency/halton.py @@ -10,13 +10,13 @@ import functools import itertools import math -from typing import Any, Callable, Dict, List, Sequence, Text, Tuple, Union +from typing import Any, Callable, Dict, List, Sequence, Tuple, Union from absl import logging from numpy import random -_SweepSequence = List[Dict[Text, Any]] -_GeneratorFn = Callable[[float], Tuple[Text, float]] +_SweepSequence = List[Dict[str, Any]] +_GeneratorFn = Callable[[float], Tuple[str, float]] def generate_primes(n: int) -> List[int]: @@ -195,10 +195,10 @@ def generate_sequence(num_samples: int, return halton_sequence -def _generate_double_point(name: Text, +def _generate_double_point(name: str, min_val: float, max_val: float, - scaling: Text, + scaling: str, halton_point: float) -> Tuple[str, float]: """Generate a float hyperparameter value from a Halton sequence point.""" if scaling not in ['linear', 'log']: @@ -234,7 +234,7 @@ def interval(start: int, end: int) -> Tuple[int, int]: return start, end -def loguniform(name: Text, range_endpoints: Tuple[int, int]) -> _GeneratorFn: +def loguniform(name: str, range_endpoints: Tuple[int, int]) -> _GeneratorFn: min_val, max_val = range_endpoints return functools.partial(_generate_double_point, name, @@ -244,7 +244,7 @@ def loguniform(name: Text, range_endpoints: Tuple[int, int]) -> _GeneratorFn: def uniform( - name: Text, search_points: Union[_DiscretePoints, + name: str, search_points: Union[_DiscretePoints, Tuple[int, int]]) -> _GeneratorFn: if isinstance(search_points, _DiscretePoints): return functools.partial(_generate_discrete_point, diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py index 442c85899..72108c9d9 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py @@ -16,7 +16,7 @@ from algorithmic_efficiency import param_utils from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.wmt import bleu +#from algorithmic_efficiency.workloads.wmt import bleu from algorithmic_efficiency.workloads.wmt.wmt_jax import decode from algorithmic_efficiency.workloads.wmt.wmt_jax import models from algorithmic_efficiency.workloads.wmt.workload import BaseWmtWorkload diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py index 327ca34ad..b554b2ab3 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py @@ -16,7 +16,7 @@ from algorithmic_efficiency import param_utils from algorithmic_efficiency import pytorch_utils from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.wmt import bleu +#from algorithmic_efficiency.workloads.wmt import bleu from algorithmic_efficiency.workloads.wmt.wmt_pytorch import decode from algorithmic_efficiency.workloads.wmt.wmt_pytorch.models import Transformer from algorithmic_efficiency.workloads.wmt.workload import BaseWmtWorkload From 55bacbd493c425fda147bc59aa97341f73b1ef17 Mon Sep 17 00:00:00 2001 From: init-22 Date: Tue, 3 Dec 2024 21:24:18 +0530 Subject: [PATCH 045/105] fix: minor yapf --- algorithmic_efficiency/halton.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/halton.py b/algorithmic_efficiency/halton.py index d710e3fce..1f36b07bf 100644 --- a/algorithmic_efficiency/halton.py +++ b/algorithmic_efficiency/halton.py @@ -245,7 +245,7 @@ def loguniform(name: str, range_endpoints: Tuple[int, int]) -> _GeneratorFn: def uniform( name: str, search_points: Union[_DiscretePoints, - Tuple[int, int]]) -> _GeneratorFn: + Tuple[int, int]]) -> _GeneratorFn: if isinstance(search_points, _DiscretePoints): return functools.partial(_generate_discrete_point, name, From 5eac985fcefc7fa0f93c2e4f28e0d71ca6db7d3d Mon Sep 17 00:00:00 2001 From: init-22 Date: Sat, 7 Dec 2024 21:07:21 +0530 Subject: [PATCH 046/105] fix: going back to sacrebleu v1.3.1 --- algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py | 5 ++--- algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py | 5 ++--- setup.cfg | 2 +- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py index 72108c9d9..046d5e469 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py @@ -12,11 +12,10 @@ import jax.numpy as jnp import numpy as np import optax -import sacrebleu from algorithmic_efficiency import param_utils from algorithmic_efficiency import spec -#from algorithmic_efficiency.workloads.wmt import bleu +from algorithmic_efficiency.workloads.wmt import bleu from algorithmic_efficiency.workloads.wmt.wmt_jax import decode from algorithmic_efficiency.workloads.wmt.wmt_jax import models from algorithmic_efficiency.workloads.wmt.workload import BaseWmtWorkload @@ -204,7 +203,7 @@ def translate_and_calculate_bleu(self, predictions.append(self._decode_tokens(predicted[idx])) # Calculate BLEU score for translated eval corpus against reference. - bleu_score = sacrebleu.corpus_bleu(predictions, [references]).score + bleu_score = bleu.corpus_bleu(predictions, [references]).score return bleu_score def init_model_fn( diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py index b554b2ab3..0ba49c2f6 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py @@ -5,7 +5,6 @@ from absl import logging import jax -import sacrebleu import tensorflow as tf import torch import torch.distributed as dist @@ -16,7 +15,7 @@ from algorithmic_efficiency import param_utils from algorithmic_efficiency import pytorch_utils from algorithmic_efficiency import spec -#from algorithmic_efficiency.workloads.wmt import bleu +from algorithmic_efficiency.workloads.wmt import bleu from algorithmic_efficiency.workloads.wmt.wmt_pytorch import decode from algorithmic_efficiency.workloads.wmt.wmt_pytorch.models import Transformer from algorithmic_efficiency.workloads.wmt.workload import BaseWmtWorkload @@ -163,7 +162,7 @@ def translate_and_calculate_bleu(self, predictions.append(self._decode_tokens(predicted[idx])) # Calculate BLEU score for translated eval corpus against reference. - bleu_score = sacrebleu.corpus_bleu(predictions, [references]).score + bleu_score = bleu.corpus_bleu(predictions, [references]).score return bleu_score def init_model_fn( diff --git a/setup.cfg b/setup.cfg index e8044fe02..a7c224407 100644 --- a/setup.cfg +++ b/setup.cfg @@ -103,7 +103,7 @@ librispeech_conformer = wmt = sentencepiece==0.2.0 tensorflow-text==2.18.0 - sacrebleu==2.4.3 + sacrebleu==1.3.1 # Frameworks # # JAX Core From 786771169b0f9bafe241692ac9411d30fccce62d Mon Sep 17 00:00:00 2001 From: init-22 Date: Tue, 17 Dec 2024 21:13:16 +0530 Subject: [PATCH 047/105] feat: custom tf_addons support in TF2.18 --- .../imagenet_jax/custom_tf_addons.py | 433 ++++++++++++++++++ .../imagenet_jax/randaugment.py | 16 +- 2 files changed, 441 insertions(+), 8 deletions(-) create mode 100644 algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py new file mode 100644 index 000000000..eda67d226 --- /dev/null +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py @@ -0,0 +1,433 @@ +""" +Note: +The following code is adapted from: +https://github.com/tensorflow/addons/tree/master/tensorflow_addons/image + + +""" + +import math +from typing import Callable, List, Optional, Union + +import numpy as np +import tensorflow as tf + +_IMAGE_DTYPES = { + tf.dtypes.uint8, + tf.dtypes.int32, + tf.dtypes.int64, + tf.dtypes.float16, + tf.dtypes.float32, + tf.dtypes.float64, +} + +Number = Union[float, + int, + np.float16, + np.float32, + np.float64, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64,] + +TensorLike = Union[List[Union[Number, list]], + tuple, + Number, + np.ndarray, + tf.Tensor, + tf.SparseTensor, + tf.Variable,] + + +def get_ndims(image): + return image.get_shape().ndims or tf.rank(image) + + +def to_4D_image(image): + """Convert 2/3/4D image to 4D image. + + Args: + image: 2/3/4D `Tensor`. + + Returns: + 4D `Tensor` with the same type. + """ + with tf.control_dependencies([ + tf.debugging.assert_rank_in( + image, [2, 3, 4], message="`image` must be 2/3/4D tensor") + ]): + ndims = image.get_shape().ndims + if ndims is None: + return _dynamic_to_4D_image(image) + elif ndims == 2: + return image[None, :, :, None] + elif ndims == 3: + return image[None, :, :, :] + else: + return image + + +def _dynamic_to_4D_image(image): + shape = tf.shape(image) + original_rank = tf.rank(image) + # 4D image => [N, H, W, C] or [N, C, H, W] + # 3D image => [1, H, W, C] or [1, C, H, W] + # 2D image => [1, H, W, 1] + left_pad = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32) + right_pad = tf.cast(tf.equal(original_rank, 2), dtype=tf.int32) + new_shape = tf.concat( + [ + tf.ones(shape=left_pad, dtype=tf.int32), + shape, + tf.ones(shape=right_pad, dtype=tf.int32), + ], + axis=0, + ) + return tf.reshape(image, new_shape) + + +def from_4D_image(image, ndims): + """Convert back to an image with `ndims` rank. + + Args: + image: 4D `Tensor`. + ndims: The original rank of the image. + + Returns: + `ndims`-D `Tensor` with the same type. + """ + with tf.control_dependencies( + [tf.debugging.assert_rank(image, 4, + message="`image` must be 4D tensor")]): + if isinstance(ndims, tf.Tensor): + return _dynamic_from_4D_image(image, ndims) + elif ndims == 2: + return tf.squeeze(image, [0, 3]) + elif ndims == 3: + return tf.squeeze(image, [0]) + else: + return image + + +def _dynamic_from_4D_image(image, original_rank): + shape = tf.shape(image) + # 4D image <= [N, H, W, C] or [N, C, H, W] + # 3D image <= [1, H, W, C] or [1, C, H, W] + # 2D image <= [1, H, W, 1] + begin = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32) + end = 4 - tf.cast(tf.equal(original_rank, 2), dtype=tf.int32) + new_shape = shape[begin:end] + return tf.reshape(image, new_shape) + + +def transform( + images: TensorLike, + transforms: TensorLike, + interpolation: str = "nearest", + fill_mode: str = "constant", + output_shape: Optional[list] = None, + name: Optional[str] = None, + fill_value: TensorLike = 0.0, +) -> tf.Tensor: + """Applies the given transform(s) to the image(s). + + Args: + images: A tensor of shape (num_images, num_rows, num_columns, + num_channels) (NHWC), (num_rows, num_columns, num_channels) (HWC), or + (num_rows, num_columns) (HW). + transforms: Projective transform matrix/matrices. A vector of length 8 or + tensor of size N x 8. If one row of transforms is + [a0, a1, a2, b0, b1, b2, c0, c1], then it maps the *output* point + `(x, y)` to a transformed *input* point + `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`, + where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to + the transform mapping input points to output points. Note that + gradients are not backpropagated into transformation parameters. + interpolation: Interpolation mode. + Supported values: "nearest", "bilinear". + fill_mode: Points outside the boundaries of the input are filled according + to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`). + - *reflect*: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last pixel. + - *constant*: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond the edge with the + same constant value k = 0. + - *wrap*: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - *nearest*: `(a a a a | a b c d | d d d d)` + The input is extended by the nearest pixel. + fill_value: a float represents the value to be filled outside the + boundaries when `fill_mode` is "constant". + output_shape: Output dimesion after the transform, [height, width]. + If None, output is the same size as input image. + + name: The name of the op. + + Returns: + Image(s) with the same type and shape as `images`, with the given + transform(s) applied. Transformed coordinates outside of the input image + will be filled with zeros. + + Raises: + TypeError: If `image` is an invalid type. + ValueError: If output shape is not 1-D int32 Tensor. + """ + with tf.name_scope(name or "transform"): + image_or_images = tf.convert_to_tensor(images, name="images") + transform_or_transforms = tf.convert_to_tensor( + transforms, name="transforms", dtype=tf.dtypes.float32) + if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: + raise TypeError("Invalid dtype %s." % image_or_images.dtype) + images = to_4D_image(image_or_images) + original_ndims = get_ndims(image_or_images) + + if output_shape is None: + output_shape = tf.shape(images)[1:3] + + output_shape = tf.convert_to_tensor( + output_shape, tf.dtypes.int32, name="output_shape") + + if not output_shape.get_shape().is_compatible_with([2]): + raise ValueError("output_shape must be a 1-D Tensor of 2 elements: " + "new_height, new_width") + + if len(transform_or_transforms.get_shape()) == 1: + transforms = transform_or_transforms[None] + elif transform_or_transforms.get_shape().ndims is None: + raise ValueError("transforms rank must be statically known") + elif len(transform_or_transforms.get_shape()) == 2: + transforms = transform_or_transforms + else: + transforms = transform_or_transforms + raise ValueError("transforms should have rank 1 or 2, but got rank %d" % + len(transforms.get_shape())) + + fill_value = tf.convert_to_tensor( + fill_value, dtype=tf.float32, name="fill_value") + output = tf.raw_ops.ImageProjectiveTransformV3( + images=images, + transforms=transforms, + output_shape=output_shape, + interpolation=interpolation.upper(), + fill_mode=fill_mode.upper(), + fill_value=fill_value, + ) + return from_4D_image(output, original_ndims) + + +def angles_to_projective_transforms( + angles: TensorLike, + image_height: TensorLike, + image_width: TensorLike, + name: Optional[str] = None, +) -> tf.Tensor: + """Returns projective transform(s) for the given angle(s). + + Args: + angles: A scalar angle to rotate all images by, or (for batches of + images) a vector with an angle to rotate each image in the batch. The + rank must be statically known (the shape is not `TensorShape(None)`. + image_height: Height of the image(s) to be transformed. + image_width: Width of the image(s) to be transformed. + + Returns: + A tensor of shape (num_images, 8). Projective transforms which can be + given to `transform` op. + """ + with tf.name_scope(name or "angles_to_projective_transforms"): + angle_or_angles = tf.convert_to_tensor( + angles, name="angles", dtype=tf.dtypes.float32) + if len(angle_or_angles.get_shape()) == 0: + angles = angle_or_angles[None] + elif len(angle_or_angles.get_shape()) == 1: + angles = angle_or_angles + else: + raise ValueError("angles should have rank 0 or 1.") + cos_angles = tf.math.cos(angles) + sin_angles = tf.math.sin(angles) + x_offset = ((image_width - 1) - + (cos_angles * (image_width - 1) - sin_angles * + (image_height - 1))) / 2.0 + y_offset = ((image_height - 1) - + (sin_angles * (image_width - 1) + cos_angles * + (image_height - 1))) / 2.0 + num_angles = tf.shape(angles)[0] + return tf.concat( + values=[ + cos_angles[:, None], + -sin_angles[:, None], + x_offset[:, None], + sin_angles[:, None], + cos_angles[:, None], + y_offset[:, None], + tf.zeros((num_angles, 2), tf.dtypes.float32), + ], + axis=1, + ) + + +def rotate( + images: TensorLike, + angles: TensorLike, + interpolation: str = "nearest", + fill_mode: str = "constant", + name: Optional[str] = None, + fill_value: TensorLike = 0.0, +) -> tf.Tensor: + """Rotate image(s) counterclockwise by the passed angle(s) in radians. + + Args: + images: A tensor of shape + `(num_images, num_rows, num_columns, num_channels)` + (NHWC), `(num_rows, num_columns, num_channels)` (HWC), or + `(num_rows, num_columns)` (HW). + angles: A scalar angle to rotate all images by, or (if `images` has rank 4) + a vector of length num_images, with an angle for each image in the + batch. + interpolation: Interpolation mode. Supported values: "nearest", + "bilinear". + fill_mode: Points outside the boundaries of the input are filled according + to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`). + - *reflect*: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last pixel. + - *constant*: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond the edge with the + same constant value k = 0. + - *wrap*: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - *nearest*: `(a a a a | a b c d | d d d d)` + The input is extended by the nearest pixel. + fill_value: a float represents the value to be filled outside the + boundaries when `fill_mode` is "constant". + name: The name of the op. + + Returns: + Image(s) with the same type and shape as `images`, rotated by the given + angle(s). Empty space due to the rotation will be filled with zeros. + + Raises: + TypeError: If `images` is an invalid type. + """ + with tf.name_scope(name or "rotate"): + image_or_images = tf.convert_to_tensor(images) + if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: + raise TypeError("Invalid dtype %s." % image_or_images.dtype) + images = to_4D_image(image_or_images) + original_ndims = get_ndims(image_or_images) + + image_height = tf.cast(tf.shape(images)[1], tf.dtypes.float32)[None] + image_width = tf.cast(tf.shape(images)[2], tf.dtypes.float32)[None] + output = transform( + images, + angles_to_projective_transforms(angles, image_height, image_width), + interpolation=interpolation, + fill_mode=fill_mode, + fill_value=fill_value, + ) + return from_4D_image(output, original_ndims) + + +def translations_to_projective_transforms(translations: TensorLike, + name: Optional[str] = None + ) -> tf.Tensor: + """Returns projective transform(s) for the given translation(s). + + Args: + translations: A 2-element list representing `[dx, dy]` or a matrix of + 2-element lists representing `[dx, dy]` to translate for each image + (for a batch of images). The rank must be statically known + (the shape is not `TensorShape(None)`). + name: The name of the op. + Returns: + A tensor of shape `(num_images, 8)` projective transforms which can be + given to `tfa.image.transform`. + """ + with tf.name_scope(name or "translations_to_projective_transforms"): + translation_or_translations = tf.convert_to_tensor( + translations, name="translations", dtype=tf.dtypes.float32) + if translation_or_translations.get_shape().ndims is None: + raise TypeError( + "translation_or_translations rank must be statically known") + elif len(translation_or_translations.get_shape()) == 1: + translations = translation_or_translations[None] + elif len(translation_or_translations.get_shape()) == 2: + translations = translation_or_translations + else: + raise TypeError("Translations should have rank 1 or 2.") + num_translations = tf.shape(translations)[0] + # The translation matrix looks like: + # [[1 0 -dx] + # [0 1 -dy] + # [0 0 1]] + # where the last entry is implicit. + # Translation matrices are always float32. + return tf.concat( + values=[ + tf.ones((num_translations, 1), tf.dtypes.float32), + tf.zeros((num_translations, 1), tf.dtypes.float32), + -translations[:, 0, None], + tf.zeros((num_translations, 1), tf.dtypes.float32), + tf.ones((num_translations, 1), tf.dtypes.float32), + -translations[:, 1, None], + tf.zeros((num_translations, 2), tf.dtypes.float32), + ], + axis=1, + ) + + +@tf.function +def translate( + images: TensorLike, + translations: TensorLike, + interpolation: str = "nearest", + fill_mode: str = "constant", + name: Optional[str] = None, + fill_value: TensorLike = 0.0, +) -> tf.Tensor: + """Translate image(s) by the passed vectors(s). + + Args: + images: A tensor of shape + `(num_images, num_rows, num_columns, num_channels)` (NHWC), + `(num_rows, num_columns, num_channels)` (HWC), or + `(num_rows, num_columns)` (HW). The rank must be statically known (the + shape is not `TensorShape(None)`). + translations: A vector representing `[dx, dy]` or (if `images` has rank 4) + a matrix of length num_images, with a `[dx, dy]` vector for each image + in the batch. + interpolation: Interpolation mode. Supported values: "nearest", + "bilinear". + fill_mode: Points outside the boundaries of the input are filled according + to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`). + - *reflect*: `(d c b a | a b c d | d c b a)` + The input is extended by reflecting about the edge of the last pixel. + - *constant*: `(k k k k | a b c d | k k k k)` + The input is extended by filling all values beyond the edge with the + same constant value k = 0. + - *wrap*: `(a b c d | a b c d | a b c d)` + The input is extended by wrapping around to the opposite edge. + - *nearest*: `(a a a a | a b c d | d d d d)` + The input is extended by the nearest pixel. + fill_value: a float represents the value to be filled outside the + boundaries when `fill_mode` is "constant". + name: The name of the op. + Returns: + Image(s) with the same type and shape as `images`, translated by the + given vector(s). Empty space due to the translation will be filled with + zeros. + Raises: + TypeError: If `images` is an invalid type. + """ + with tf.name_scope(name or "translate"): + return transform( + images, + translations_to_projective_transforms(translations), + interpolation=interpolation, + fill_mode=fill_mode, + fill_value=fill_value, + ) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index af1b763c1..f3a946245 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -9,7 +9,9 @@ import tensorflow as tf -#from tensorflow_addons import image as contrib_image +from .custom_tf_addons import rotate +from .custom_tf_addons import transform +from .custom_tf_addons import translate # This signifies the max integer that the controller RNN could predict for the # augmentation scheme. @@ -177,19 +179,19 @@ def rotate(image, degrees, replace): # In practice, we should randomize the rotation degrees by flipping # it negatively half the time, but that's done on 'degrees' outside # of the function. - image = contrib_image.rotate(wrap(image), radians) + image = rotate(wrap(image), radians) return unwrap(image, replace) def translate_x(image, pixels, replace): """Equivalent of PIL Translate in X dimension.""" - image = contrib_image.translate(wrap(image), [-pixels, 0]) + image = translate(wrap(image), [-pixels, 0]) return unwrap(image, replace) def translate_y(image, pixels, replace): """Equivalent of PIL Translate in Y dimension.""" - image = contrib_image.translate(wrap(image), [0, -pixels]) + image = translate(wrap(image), [0, -pixels]) return unwrap(image, replace) @@ -199,8 +201,7 @@ def shear_x(image, level, replace): # with a matrix form of: # [1 level # 0 1]. - image = contrib_image.transform( - wrap(image), [1., level, 0., 0., 1., 0., 0., 0.]) + image = transform(wrap(image), [1., level, 0., 0., 1., 0., 0., 0.]) return unwrap(image, replace) @@ -210,8 +211,7 @@ def shear_y(image, level, replace): # with a matrix form of: # [1 0 # level 1]. - image = contrib_image.transform( - wrap(image), [1., 0., 0., level, 1., 0., 0., 0.]) + image = transform(wrap(image), [1., 0., 0., level, 1., 0., 0., 0.]) return unwrap(image, replace) From d6dd2e8e16145e73f69664bc81690ac06857319b Mon Sep 17 00:00:00 2001 From: init-22 Date: Tue, 17 Dec 2024 21:50:11 +0530 Subject: [PATCH 048/105] fix: resolving pylint issues in custom_tf_addons --- .../imagenet_jax/custom_tf_addons.py | 27 +++++++++---------- .../imagenet_jax/randaugment.py | 4 +-- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py index eda67d226..79aef6791 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py @@ -6,8 +6,7 @@ """ -import math -from typing import Callable, List, Optional, Union +from typing import List, Optional, Union import numpy as np import tensorflow as tf @@ -48,7 +47,7 @@ def get_ndims(image): return image.get_shape().ndims or tf.rank(image) -def to_4D_image(image): +def to_4d_image(image): """Convert 2/3/4D image to 4D image. Args: @@ -63,7 +62,7 @@ def to_4D_image(image): ]): ndims = image.get_shape().ndims if ndims is None: - return _dynamic_to_4D_image(image) + return _dynamic_to_4d_image(image) elif ndims == 2: return image[None, :, :, None] elif ndims == 3: @@ -72,7 +71,7 @@ def to_4D_image(image): return image -def _dynamic_to_4D_image(image): +def _dynamic_to_4d_image(image): shape = tf.shape(image) original_rank = tf.rank(image) # 4D image => [N, H, W, C] or [N, C, H, W] @@ -91,7 +90,7 @@ def _dynamic_to_4D_image(image): return tf.reshape(image, new_shape) -def from_4D_image(image, ndims): +def from_4d_image(image, ndims): """Convert back to an image with `ndims` rank. Args: @@ -105,7 +104,7 @@ def from_4D_image(image, ndims): [tf.debugging.assert_rank(image, 4, message="`image` must be 4D tensor")]): if isinstance(ndims, tf.Tensor): - return _dynamic_from_4D_image(image, ndims) + return _dynamic_from_4d_image(image, ndims) elif ndims == 2: return tf.squeeze(image, [0, 3]) elif ndims == 3: @@ -114,7 +113,7 @@ def from_4D_image(image, ndims): return image -def _dynamic_from_4D_image(image, original_rank): +def _dynamic_from_4d_image(image, original_rank): shape = tf.shape(image) # 4D image <= [N, H, W, C] or [N, C, H, W] # 3D image <= [1, H, W, C] or [1, C, H, W] @@ -183,7 +182,7 @@ def transform( transforms, name="transforms", dtype=tf.dtypes.float32) if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: raise TypeError("Invalid dtype %s." % image_or_images.dtype) - images = to_4D_image(image_or_images) + images = to_4d_image(image_or_images) original_ndims = get_ndims(image_or_images) if output_shape is None: @@ -217,7 +216,7 @@ def transform( fill_mode=fill_mode.upper(), fill_value=fill_value, ) - return from_4D_image(output, original_ndims) + return from_4d_image(output, original_ndims) def angles_to_projective_transforms( @@ -271,7 +270,7 @@ def angles_to_projective_transforms( ) -def rotate( +def rotate_img( images: TensorLike, angles: TensorLike, interpolation: str = "nearest", @@ -286,7 +285,7 @@ def rotate( `(num_images, num_rows, num_columns, num_channels)` (NHWC), `(num_rows, num_columns, num_channels)` (HWC), or `(num_rows, num_columns)` (HW). - angles: A scalar angle to rotate all images by, or (if `images` has rank 4) + angles: A scalar angle to rotate all images by (if `images` has rank 4) a vector of length num_images, with an angle for each image in the batch. interpolation: Interpolation mode. Supported values: "nearest", @@ -317,7 +316,7 @@ def rotate( image_or_images = tf.convert_to_tensor(images) if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: raise TypeError("Invalid dtype %s." % image_or_images.dtype) - images = to_4D_image(image_or_images) + images = to_4d_image(image_or_images) original_ndims = get_ndims(image_or_images) image_height = tf.cast(tf.shape(images)[1], tf.dtypes.float32)[None] @@ -329,7 +328,7 @@ def rotate( fill_mode=fill_mode, fill_value=fill_value, ) - return from_4D_image(output, original_ndims) + return from_4d_image(output, original_ndims) def translations_to_projective_transforms(translations: TensorLike, diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index f3a946245..dd00146cd 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -9,7 +9,7 @@ import tensorflow as tf -from .custom_tf_addons import rotate +from .custom_tf_addons import rotate_img from .custom_tf_addons import transform from .custom_tf_addons import translate @@ -179,7 +179,7 @@ def rotate(image, degrees, replace): # In practice, we should randomize the rotation degrees by flipping # it negatively half the time, but that's done on 'degrees' outside # of the function. - image = rotate(wrap(image), radians) + image = rotate_img(wrap(image), radians) return unwrap(image, replace) From a0b587aed0ccecb794a46e2ba99713c56ed69f93 Mon Sep 17 00:00:00 2001 From: init-22 Date: Tue, 17 Dec 2024 22:04:59 +0530 Subject: [PATCH 049/105] resolved pyline and changed the pylint version to current version of main --- .../imagenet_jax/custom_tf_addons.py | 20 ++++++++++++------- setup.cfg | 2 +- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py index 79aef6791..3d6939218 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py @@ -241,12 +241,15 @@ def angles_to_projective_transforms( with tf.name_scope(name or "angles_to_projective_transforms"): angle_or_angles = tf.convert_to_tensor( angles, name="angles", dtype=tf.dtypes.float32) + + if len(angle_or_angles.get_shape()) not in (0, 1): + raise ValueError("angles should have rank 0 or 1.") + if len(angle_or_angles.get_shape()) == 0: angles = angle_or_angles[None] - elif len(angle_or_angles.get_shape()) == 1: - angles = angle_or_angles else: - raise ValueError("angles should have rank 0 or 1.") + angles = angle_or_angles + cos_angles = tf.math.cos(angles) sin_angles = tf.math.sin(angles) x_offset = ((image_width - 1) - @@ -352,12 +355,15 @@ def translations_to_projective_transforms(translations: TensorLike, if translation_or_translations.get_shape().ndims is None: raise TypeError( "translation_or_translations rank must be statically known") - elif len(translation_or_translations.get_shape()) == 1: + + if len(translation_or_translations.get_shape()) not in (1, 2): + raise TypeError("Translations should have rank 1 or 2.") + + if len(translation_or_translations.get_shape()) == 1: translations = translation_or_translations[None] - elif len(translation_or_translations.get_shape()) == 2: - translations = translation_or_translations else: - raise TypeError("Translations should have rank 1 or 2.") + translations = translation_or_translations + num_translations = tf.shape(translations)[0] # The translation matrix looks like: # [[1 0 -dx] diff --git a/setup.cfg b/setup.cfg index a7c224407..7977267bd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -78,7 +78,7 @@ full_dev = # Dependencies for developing the package dev = isort==5.13.2 - pylint==3.3.1 + pylint==2.16.1 pytest==8.3.3 yapf==0.32.0 pre-commit==4.0.1 From 9393145ba91b9432c1732f5bd9d8865c2cb232f8 Mon Sep 17 00:00:00 2001 From: init-22 Date: Wed, 18 Dec 2024 20:58:42 +0530 Subject: [PATCH 050/105] fix: removing tensorflow addons from setup cfg --- setup.cfg | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 7977267bd..2d246b48b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -40,7 +40,6 @@ install_requires = pandas==2.2.3 tensorflow==2.18.0 tensorflow-datasets==4.9.7 - tensorflow-addons==0.23.0 gputil==1.4.0 psutil==6.1.0 clu==0.0.12 From 53eff1d469635408aff5d80a28f3248c4bd79464 Mon Sep 17 00:00:00 2001 From: init-22 Date: Fri, 20 Dec 2024 00:41:47 +0530 Subject: [PATCH 051/105] fix: adding absolute paths for custom_tf_addons in randaugment --- .../imagenet_resnet/imagenet_jax/randaugment.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index dd00146cd..e920331bc 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -9,9 +9,12 @@ import tensorflow as tf -from .custom_tf_addons import rotate_img -from .custom_tf_addons import transform -from .custom_tf_addons import translate +from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ + rotate_img +from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ + transform +from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ + translate # This signifies the max integer that the controller RNN could predict for the # augmentation scheme. From d21d8205d565c94d82b312709491deac0b31de31 Mon Sep 17 00:00:00 2001 From: init-22 Date: Sun, 22 Dec 2024 16:11:59 +0530 Subject: [PATCH 052/105] fix: changes jax.tree_map to jax.tree.map --- algorithmic_efficiency/data_utils.py | 2 +- algorithmic_efficiency/param_utils.py | 2 +- .../workloads/cifar/cifar_jax/workload.py | 2 +- .../imagenet_resnet/imagenet_jax/workload.py | 2 +- .../workloads/mnist/mnist_jax/workload.py | 2 +- .../workloads/ogbg/input_pipeline.py | 4 +-- .../workloads/ogbg/ogbg_pytorch/workload.py | 6 ++-- .../workloads/wmt/wmt_jax/decode.py | 8 ++--- .../workloads/wmt/wmt_jax/workload.py | 4 +-- .../workloads/wmt/wmt_pytorch/decode.py | 8 ++--- .../workloads/wmt/wmt_pytorch/workload.py | 2 +- .../external_tuning/jax_nadamw_full_budget.py | 16 +++++----- .../jax_nadamw_target_setting.py | 16 +++++----- .../self_tuning/jax_nadamw_full_budget.py | 16 +++++----- .../self_tuning/jax_nadamw_target_setting.py | 16 +++++----- .../cifar/cifar_jax/submission.py | 2 +- .../mnist/mnist_jax/submission.py | 2 +- .../adafactor/jax/sharded_adafactor.py | 16 +++++----- .../adafactor/jax/submission.py | 6 ++-- .../paper_baselines/adamw/jax/submission.py | 6 ++-- .../paper_baselines/lamb/jax/submission.py | 6 ++-- .../momentum/jax/submission.py | 6 ++-- .../paper_baselines/nadamw/jax/submission.py | 16 +++++----- .../nesterov/jax/submission.py | 6 ++-- .../paper_baselines/sam/jax/submission.py | 10 +++---- .../shampoo/jax/distributed_shampoo.py | 30 +++++++++---------- .../paper_baselines/shampoo/jax/submission.py | 6 ++-- .../target_setting_algorithms/jax_adamw.py | 2 +- .../target_setting_algorithms/jax_momentum.py | 2 +- .../target_setting_algorithms/jax_nadamw.py | 12 ++++---- .../target_setting_algorithms/jax_nesterov.py | 2 +- .../jax_submission_base.py | 4 +-- tests/modeldiffs/vanilla_sgd_jax.py | 2 +- tests/reference_algorithm_tests.py | 4 +-- .../imagenet_jax/workload_test.py | 2 +- 35 files changed, 124 insertions(+), 124 deletions(-) diff --git a/algorithmic_efficiency/data_utils.py b/algorithmic_efficiency/data_utils.py index 901f0b582..38a76381f 100644 --- a/algorithmic_efficiency/data_utils.py +++ b/algorithmic_efficiency/data_utils.py @@ -65,7 +65,7 @@ def _prepare(x): # Assumes that `global_batch_size % local_device_count == 0`. return x.reshape((local_device_count, -1, *x.shape[1:])) - return jax.tree_map(_prepare, batch) + return jax.tree.map(_prepare, batch) def pad(tensor: np.ndarray, diff --git a/algorithmic_efficiency/param_utils.py b/algorithmic_efficiency/param_utils.py index b430366b1..916eb8728 100644 --- a/algorithmic_efficiency/param_utils.py +++ b/algorithmic_efficiency/param_utils.py @@ -66,7 +66,7 @@ def pytorch_param_types( def jax_param_shapes( params: spec.ParameterContainer) -> spec.ParameterShapeTree: - return jax.tree_map(lambda x: spec.ShapeTuple(x.shape), params) + return jax.tree.map(lambda x: spec.ShapeTuple(x.shape), params) def jax_param_types(param_shapes: spec.ParameterShapeTree, diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py index 6bbf9c64b..60f15c2f0 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py @@ -207,4 +207,4 @@ def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any]) -> Dict[str, float]: """Normalize eval metrics.""" - return jax.tree_map(lambda x: float(x[0] / num_examples), total_metrics) + return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py index 91cdec60a..4366fcf25 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -264,7 +264,7 @@ def _eval_model_on_split(self, eval_metrics[metric_name] = 0.0 eval_metrics[metric_name] += metric_value - eval_metrics = jax.tree_map(lambda x: float(x[0] / num_examples), + eval_metrics = jax.tree.map(lambda x: float(x[0] / num_examples), eval_metrics) return eval_metrics diff --git a/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py b/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py index efbd73e33..dcb0b6f36 100644 --- a/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py +++ b/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py @@ -132,4 +132,4 @@ def _normalize_eval_metrics( self, num_examples: int, total_metrics: Dict[str, Any]) -> Dict[str, float]: """Normalize eval metrics.""" - return jax.tree_map(lambda x: float(x[0] / num_examples), total_metrics) + return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics) diff --git a/algorithmic_efficiency/workloads/ogbg/input_pipeline.py b/algorithmic_efficiency/workloads/ogbg/input_pipeline.py index a301d677a..3cb6f51de 100644 --- a/algorithmic_efficiency/workloads/ogbg/input_pipeline.py +++ b/algorithmic_efficiency/workloads/ogbg/input_pipeline.py @@ -51,7 +51,7 @@ def _load_dataset(split, should_shuffle, data_rng, data_dir): def _to_jraph(example): """Converts an example graph to jraph.GraphsTuple.""" - example = jax.tree_map(lambda x: x._numpy(), example) # pylint: disable=protected-access + example = jax.tree.map(lambda x: x._numpy(), example) # pylint: disable=protected-access edge_feat = example['edge_feat'] node_feat = example['node_feat'] edge_index = example['edge_index'] @@ -150,7 +150,7 @@ def _get_batch_iterator(dataset_iter, global_batch_size, num_shards=None): if count == num_shards: def f(x): - return jax.tree_map(lambda *vals: np.stack(vals, axis=0), x[0], *x[1:]) + return jax.tree.map(lambda *vals: np.stack(vals, axis=0), x[0], *x[1:]) graphs_shards = f(graphs_shards) labels_shards = f(labels_shards) diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py index d4817226d..e66a7a151 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py @@ -20,8 +20,8 @@ def _pytorch_map(inputs: Any) -> Any: if USE_PYTORCH_DDP: - return jax.tree_map(lambda a: torch.as_tensor(a, device=DEVICE), inputs) - return jax.tree_map( + return jax.tree.map(lambda a: torch.as_tensor(a, device=DEVICE), inputs) + return jax.tree.map( lambda a: torch.as_tensor(a, device=DEVICE).view(-1, a.shape[-1]) if len(a.shape) == 3 else torch.as_tensor(a, device=DEVICE).view(-1), inputs) @@ -30,7 +30,7 @@ def _pytorch_map(inputs: Any) -> Any: def _shard(inputs: Any) -> Any: if not USE_PYTORCH_DDP: return inputs - return jax.tree_map(lambda tensor: tensor[RANK], inputs) + return jax.tree.map(lambda tensor: tensor[RANK], inputs) def _graph_map(function: Callable, graph: GraphsTuple) -> GraphsTuple: diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/decode.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/decode.py index 85d0eaac4..dfead5918 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/decode.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/decode.py @@ -86,7 +86,7 @@ def gather_fn(x): return x return x[batch_indices, beam_indices] - return jax.tree_map(gather_fn, nested) + return jax.tree.map(gather_fn, nested) def gather_topk_beams(nested, score_or_log_prob, batch_size, new_beam_size): @@ -139,7 +139,7 @@ def beam_init(batch_size, beam_size, max_decode_len, cache): finished_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32) finished_flags0 = jnp.zeros((batch_size, beam_size), jnp.bool_) # add beam dimension to attention cache pytree elements - beam_cache0 = jax.tree_map(lambda x: add_beam_dim(x, beam_size), cache) + beam_cache0 = jax.tree.map(lambda x: add_beam_dim(x, beam_size), cache) return BeamState( cur_index=cur_index0, live_logprobs=live_logprobs0, @@ -225,7 +225,7 @@ def beam_search_loop_body_fn(state): (batch_size, beam_size, 1))) # Flatten beam dimension into batch to be compatible with model. # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...} - flat_cache = jax.tree_map(flatten_beam_dim, state.cache) + flat_cache = jax.tree.map(flatten_beam_dim, state.cache) # Call fast-decoder model on current tokens to get next-position logits. # --> [batch * beam, vocab] @@ -236,7 +236,7 @@ def beam_search_loop_body_fn(state): logits = unflatten_beam_dim(flat_logits, batch_size, beam_size) # Unflatten beam dimension in attention cache arrays # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...} - new_cache = jax.tree_map( + new_cache = jax.tree.map( lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache) # Gather log probabilities from logits diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py index 046d5e469..dd6728450 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py @@ -94,7 +94,7 @@ def eval_step(self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]: replicated_eval_metrics = self.eval_step_pmapped(params, batch) - return jax.tree_map(lambda x: jnp.sum(x, axis=0), replicated_eval_metrics) + return jax.tree.map(lambda x: jnp.sum(x, axis=0), replicated_eval_metrics) @functools.partial( jax.pmap, axis_name='batch', static_broadcasted_argnums=(0,)) @@ -291,7 +291,7 @@ def _normalize_eval_metrics( """Normalize eval metrics.""" del num_examples eval_denominator = total_metrics.pop('denominator') - return jax.tree_map(lambda x: float(x / eval_denominator), total_metrics) + return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics) class WmtWorkloadPostLN(WmtWorkload): diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/decode.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/decode.py index 0488a144f..078560c36 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/decode.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/decode.py @@ -98,7 +98,7 @@ def gather_fn(x): return x return x[batch_indices, beam_indices] - return jax.tree_map(gather_fn, nested) + return jax.tree.map(gather_fn, nested) def gather_topk_beams(nested: Dict[str, Any], @@ -164,7 +164,7 @@ def beam_init(batch_size: int, dtype=torch.bool, device=DEVICE) # add beam dimension to attention cache pytree elements - beam_cache0 = jax.tree_map(lambda x: add_beam_dim(x, beam_size), cache) + beam_cache0 = jax.tree.map(lambda x: add_beam_dim(x, beam_size), cache) return BeamState( cur_index=cur_index0, live_logprobs=live_logprobs0, @@ -251,7 +251,7 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState: state.live_seqs[:batch_size, :beam_size, cur_index:cur_index + 1]) # Flatten beam dimension into batch to be compatible with model. # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...} - flat_cache = jax.tree_map(flatten_beam_dim, state.cache) + flat_cache = jax.tree.map(flatten_beam_dim, state.cache) # Call fast-decoder model on current tokens to get next-position logits. # --> [batch * beam, vocab] @@ -262,7 +262,7 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState: logits = unflatten_beam_dim(flat_logits, batch_size, beam_size) # Unflatten beam dimension in attention cache arrays # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...} - new_cache = jax.tree_map( + new_cache = jax.tree.map( lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache) # Gather log probabilities from logits diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py index 0ba49c2f6..9c1c21e93 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py @@ -347,7 +347,7 @@ def _normalize_eval_metrics( dist.all_reduce(metric) total_metrics = {k: v.item() for k, v in total_metrics.items()} eval_denominator = total_metrics.pop('denominator') - return jax.tree_map(lambda x: float(x / eval_denominator), total_metrics) + return jax.tree.map(lambda x: float(x / eval_denominator), total_metrics) class WmtWorkloadPostLN(WmtWorkload): diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py index 36e7e5607..30f9068d1 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -111,8 +111,8 @@ def scale_by_nadam(b1: float = 0.9, raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment + mu = jax.tree.map(jnp.zeros_like, params) # First moment + nu = jax.tree.map(jnp.zeros_like, params) # Second moment return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) def update_fn(updates, state, params=None): @@ -123,7 +123,7 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( + updates = jax.tree.map( lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) @@ -139,14 +139,14 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( + return jax.tree.map( lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): """Perform bias correction. This becomes a no-op as count goes to infinity.""" beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment) def scale_by_learning_rate(learning_rate, flip_sign=True): @@ -188,7 +188,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): b2=hyperparameters.beta2, eps=1e-8, weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -236,7 +236,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -244,7 +244,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py index 07281f540..71b1c5e1e 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -111,8 +111,8 @@ def scale_by_nadam(b1: float = 0.9, raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment + mu = jax.tree.map(jnp.zeros_like, params) # First moment + nu = jax.tree.map(jnp.zeros_like, params) # Second moment return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) def update_fn(updates, state, params=None): @@ -123,7 +123,7 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( + updates = jax.tree.map( lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) @@ -139,14 +139,14 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( + return jax.tree.map( lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): """Perform bias correction. This becomes a no-op as count goes to infinity.""" beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment) def scale_by_learning_rate(learning_rate, flip_sign=True): @@ -188,7 +188,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): b2=hyperparameters.beta2, eps=1e-8, weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -236,7 +236,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -244,7 +244,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py index 0d194ef7a..127e660d0 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -120,8 +120,8 @@ def scale_by_nadam(b1: float = 0.9, raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment + mu = jax.tree.map(jnp.zeros_like, params) # First moment + nu = jax.tree.map(jnp.zeros_like, params) # Second moment return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) def update_fn(updates, state, params=None): @@ -132,7 +132,7 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( + updates = jax.tree.map( lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) @@ -148,14 +148,14 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( + return jax.tree.map( lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): """Perform bias correction. This becomes a no-op as count goes to infinity.""" beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment) def scale_by_learning_rate(learning_rate, flip_sign=True): @@ -200,7 +200,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): b2=hyperparameters['beta2'], eps=1e-8, weight_decay=hyperparameters['weight_decay']) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -248,7 +248,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -256,7 +256,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index 60fc25ec4..92c0f599c 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -120,8 +120,8 @@ def scale_by_nadam(b1: float = 0.9, raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment + mu = jax.tree.map(jnp.zeros_like, params) # First moment + nu = jax.tree.map(jnp.zeros_like, params) # Second moment return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) def update_fn(updates, state, params=None): @@ -132,7 +132,7 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( + updates = jax.tree.map( lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) @@ -148,14 +148,14 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( + return jax.tree.map( lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): """Perform bias correction. This becomes a no-op as count goes to infinity.""" beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment) def scale_by_learning_rate(learning_rate, flip_sign=True): @@ -200,7 +200,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): b2=hyperparameters['beta2'], eps=1e-8, weight_decay=hyperparameters['weight_decay']) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -248,7 +248,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -256,7 +256,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py index e8e0bf4ac..055de8569 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py @@ -60,7 +60,7 @@ def init_optimizer_state(workload: spec.Workload, del model_params del model_state del rng - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) opt_init_fn, opt_update_fn = optimizer(hyperparameters, workload.num_train_examples) diff --git a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py index b33c0285b..b7c4dd2f2 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py @@ -26,7 +26,7 @@ def init_optimizer_state(workload: spec.Workload, del model_params del model_state del rng - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) opt_init_fn, opt_update_fn = optax.chain( optax.scale_by_adam( diff --git a/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py b/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py index 9f4da9132..ff98464ae 100644 --- a/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py +++ b/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py @@ -316,11 +316,11 @@ def to_state(self, count, result_tree): """Maps from a tree of (factored) values to separate trees of values.""" return ShardedAdafactorState( count=count, - m=jax.tree_map(lambda o: o.m, result_tree), - m_scale=jax.tree_map(lambda o: o.m_scale, result_tree), - vr=jax.tree_map(lambda o: o.vr, result_tree), - vc=jax.tree_map(lambda o: o.vc, result_tree), - v=jax.tree_map(lambda o: o.v, result_tree)) + m=jax.tree.map(lambda o: o.m, result_tree), + m_scale=jax.tree.map(lambda o: o.m_scale, result_tree), + vr=jax.tree.map(lambda o: o.vr, result_tree), + vc=jax.tree.map(lambda o: o.vc, result_tree), + v=jax.tree.map(lambda o: o.v, result_tree)) def init(self, param): """Initializes the optimizer state for a given param.""" @@ -667,7 +667,7 @@ def init_fn(params): """Initializes the optimizer's state.""" return sharded_adafactor_helper.to_state( jnp.zeros([], jnp.int32), - jax.tree_map(sharded_adafactor_helper.init, params)) + jax.tree.map(sharded_adafactor_helper.init, params)) def update_fn(updates, state, params=None): if params is None: @@ -677,7 +677,7 @@ def update_fn(updates, state, params=None): compute_var_and_slot_update_fn = functools.partial( sharded_adafactor_helper.compute_var_and_slot_update, state.count) - output = jax.tree_map(compute_var_and_slot_update_fn, + output = jax.tree.map(compute_var_and_slot_update_fn, updates, state.m, state.m_scale, @@ -685,7 +685,7 @@ def update_fn(updates, state, params=None): state.vc, state.v, params) - updates = jax.tree_map(lambda o: o.update, output) + updates = jax.tree.map(lambda o: o.update, output) count_plus_one = state.count + jnp.array(1, jnp.int32) updated_states = sharded_adafactor_helper.to_state(count_plus_one, output) return updates, updated_states diff --git a/reference_algorithms/paper_baselines/adafactor/jax/submission.py b/reference_algorithms/paper_baselines/adafactor/jax/submission.py index 0fcb9da0f..133468aea 100644 --- a/reference_algorithms/paper_baselines/adafactor/jax/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/jax/submission.py @@ -46,7 +46,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): learning_rate=lr_schedule_fn, beta1=1.0 - hyperparameters.one_minus_beta1, weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -94,7 +94,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -102,7 +102,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index e80a29693..60a336250 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -46,7 +46,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): b2=hyperparameters.beta2, eps=1e-8, weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -94,7 +94,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -102,7 +102,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/reference_algorithms/paper_baselines/lamb/jax/submission.py b/reference_algorithms/paper_baselines/lamb/jax/submission.py index ebcdc9914..7a3e1289c 100644 --- a/reference_algorithms/paper_baselines/lamb/jax/submission.py +++ b/reference_algorithms/paper_baselines/lamb/jax/submission.py @@ -53,7 +53,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): b2=hyperparameters.beta2, eps=1e-8, weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -102,7 +102,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -110,7 +110,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/reference_algorithms/paper_baselines/momentum/jax/submission.py b/reference_algorithms/paper_baselines/momentum/jax/submission.py index 271ef860b..182fbe644 100644 --- a/reference_algorithms/paper_baselines/momentum/jax/submission.py +++ b/reference_algorithms/paper_baselines/momentum/jax/submission.py @@ -28,7 +28,7 @@ def init_optimizer_state(workload: spec.Workload, lr_schedule_fn = create_lr_schedule_fn(workload.step_hint, hyperparameters) # Create optimizer. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) opt_init_fn, opt_update_fn = sgd( learning_rate=lr_schedule_fn, @@ -128,7 +128,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -136,7 +136,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py index 36e7e5607..30f9068d1 100644 --- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py @@ -111,8 +111,8 @@ def scale_by_nadam(b1: float = 0.9, raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment + mu = jax.tree.map(jnp.zeros_like, params) # First moment + nu = jax.tree.map(jnp.zeros_like, params) # Second moment return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) def update_fn(updates, state, params=None): @@ -123,7 +123,7 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( + updates = jax.tree.map( lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) @@ -139,14 +139,14 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( + return jax.tree.map( lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): """Perform bias correction. This becomes a no-op as count goes to infinity.""" beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment) def scale_by_learning_rate(learning_rate, flip_sign=True): @@ -188,7 +188,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): b2=hyperparameters.beta2, eps=1e-8, weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -236,7 +236,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -244,7 +244,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index a435643e4..e45d8a854 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -28,7 +28,7 @@ def init_optimizer_state(workload: spec.Workload, lr_schedule_fn = create_lr_schedule_fn(workload.step_hint, hyperparameters) # Create optimizer. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) opt_init_fn, opt_update_fn = sgd( learning_rate=lr_schedule_fn, @@ -128,7 +128,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -136,7 +136,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py index 5f45901dd..3f029fbfd 100644 --- a/reference_algorithms/paper_baselines/sam/jax/submission.py +++ b/reference_algorithms/paper_baselines/sam/jax/submission.py @@ -24,7 +24,7 @@ def dual_vector(y: jnp.ndarray) -> jnp.ndarray: """ gradient_norm = jnp.sqrt( sum(jnp.sum(jnp.square(e)) for e in jax.tree_util.tree_leaves(y))) - normalized_gradient = jax.tree_map(lambda x: x / gradient_norm, y) + normalized_gradient = jax.tree.map(lambda x: x / gradient_norm, y) return normalized_gradient @@ -73,12 +73,12 @@ def update_fn(updates, state, grad_fn_params_tuple): # Get correct global mean grad. (n_valid_examples, updates) = lax.psum((n_valid_examples, updates), axis_name=batch_axis_name) - updates = jax.tree_map(lambda x: x / n_valid_examples, updates) + updates = jax.tree.map(lambda x: x / n_valid_examples, updates) if grad_clip: updates_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(updates))) - scaled_updates = jax.tree_map( + scaled_updates = jax.tree.map( lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates) updates = jax.lax.cond(updates_norm > grad_clip, lambda _: scaled_updates, @@ -136,7 +136,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): base_opt_update_fn=opt_update_fn) # Initialize optimizer state. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -186,7 +186,7 @@ def _loss_fn(params, update_batch_norm=True): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) diff --git a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py index 725529cae..a5c2732ac 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py @@ -342,7 +342,7 @@ def init_training_metrics( """Initialize TrainingMetrics, masked if disabled.""" if not generate_training_metrics: return optax.MaskedNode() - return jax.tree_map( + return jax.tree.map( functools.partial(jnp.repeat, repeats=num_statistics), default_training_metrics()) @@ -356,14 +356,14 @@ def init_training_metrics_shapes( num_statistics, generate_training_metrics, ) - return jax.tree_map(lambda arr: [list(arr.shape), arr.dtype], seed) + return jax.tree.map(lambda arr: [list(arr.shape), arr.dtype], seed) def init_training_metrics_pspec(generate_training_metrics,): """Initialize training metrics partition specification.""" if not generate_training_metrics: return optax.MaskedNode() - return jax.tree_map(lambda _: jax.sharding.PartitionSpec(), + return jax.tree.map(lambda _: jax.sharding.PartitionSpec(), default_training_metrics()) @@ -1253,7 +1253,7 @@ def _add_metrics_into_local_stats(local_stats, metrics, keep_old): index_start = int(local_stat.index_start) index_end = int(len(local_stat.sizes)) + index_start # pylint:disable=cell-var-from-loop Used immediately. - per_stat_metrics = jax.tree_map(lambda x: x[index_start:index_end], metrics) + per_stat_metrics = jax.tree.map(lambda x: x[index_start:index_end], metrics) # We don't want to update the metrics if we didn't do a new inverse p-th # root calculation to find a new preconditioner, so that TensorBoard curves # look consistent (otherwise they'd oscillate between NaN and measured @@ -1808,7 +1808,7 @@ def sharded_update_fn(grads, state, params): local_stat, )) - new_stats_flat = jax.tree_map( + new_stats_flat = jax.tree.map( lambda g, s, p: _compute_stats(g, s, p, state.count), @@ -1816,7 +1816,7 @@ def sharded_update_fn(grads, state, params): stats_flat, params_flat) - outputs = jax.tree_map( + outputs = jax.tree.map( lambda g, s, p: _transform_grad(g, s, p, state.count), @@ -1981,7 +1981,7 @@ def _init(param): )) return ShampooState( - count=jnp.zeros([], jnp.int32), stats=jax.tree_map(_init, params)) + count=jnp.zeros([], jnp.int32), stats=jax.tree.map(_init, params)) def _skip_preconditioning(param): return len(param.shape) < skip_preconditioning_rank_lt or any( @@ -2140,7 +2140,7 @@ def _internal_inverse_pth_root_all(): preconditioners = jax.lax.all_gather(preconditioners, batch_axis_name) metrics = jax.lax.all_gather(metrics, batch_axis_name) preconditioners_flat = unbatch(preconditioners) - metrics_flat = jax.tree_map(unbatch, metrics) + metrics_flat = jax.tree.map(unbatch, metrics) else: preconditioners, metrics = _matrix_inverse_pth_root_vmap( all_statistics[0], @@ -2149,9 +2149,9 @@ def _internal_inverse_pth_root_all(): _maybe_ix(all_preconditioners, 0), ) preconditioners_flat = unbatch(jnp.stack([preconditioners])) - metrics = jax.tree_map( + metrics = jax.tree.map( functools.partial(jnp.expand_dims, axis=0), metrics) - metrics_flat = jax.tree_map(unbatch, metrics) + metrics_flat = jax.tree.map(unbatch, metrics) return preconditioners_flat, metrics_flat @@ -2166,7 +2166,7 @@ def _internal_inverse_pth_root_all(): s[:, :precond_dim(s.shape[0])] for s in packed_statistics ] n = len(packed_statistics) - metrics_init = jax.tree_map( + metrics_init = jax.tree.map( lambda x: [x] * n, default_training_metrics().replace( inverse_pth_root_errors=inverse_failure_threshold)) @@ -2215,12 +2215,12 @@ def _select_preconditioner(error, new_p, old_p): if generate_training_metrics: # pylint:disable=cell-var-from-loop Used immediately. - metrics_for_state = jax.tree_map( + metrics_for_state = jax.tree.map( lambda x: jnp.stack(x[idx:idx + num_statistics]), metrics_flat, is_leaf=lambda x: isinstance(x, list)) assert jax.tree_util.tree_all( - jax.tree_map(lambda x: len(state.statistics) == len(x), + jax.tree.map(lambda x: len(state.statistics) == len(x), metrics_for_state)) # If we skipped preconditioner computation, record old metrics. metrics_for_state = efficient_cond(perform_step, @@ -2441,7 +2441,7 @@ def update_fn(grads, state, params): if custom_preconditioner and grads_custom is not None: stats_grads = treedef.flatten_up_to(grads_custom) - new_stats_flat = jax.tree_map( + new_stats_flat = jax.tree.map( lambda g, s, p: _compute_stats(g, s, p, state.count), @@ -2452,7 +2452,7 @@ def update_fn(grads, state, params): new_stats_flat = _compute_preconditioners(new_stats_flat, params_flat, state.count) - outputs = jax.tree_map( + outputs = jax.tree.map( lambda g, s, p: _transform_grad(g, s, p, state.count), diff --git a/reference_algorithms/paper_baselines/shampoo/jax/submission.py b/reference_algorithms/paper_baselines/shampoo/jax/submission.py index 294ad2706..4a257d17b 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/submission.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/submission.py @@ -49,7 +49,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): weight_decay=hyperparameters.weight_decay, batch_axis_name='batch', eigh=False) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) @@ -97,7 +97,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -105,7 +105,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/reference_algorithms/target_setting_algorithms/jax_adamw.py b/reference_algorithms/target_setting_algorithms/jax_adamw.py index 6d2cfe245..bb85ecf05 100644 --- a/reference_algorithms/target_setting_algorithms/jax_adamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_adamw.py @@ -29,7 +29,7 @@ def init_optimizer_state(workload: spec.Workload, hyperparameters) # Create optimizer. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) epsilon = ( hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) diff --git a/reference_algorithms/target_setting_algorithms/jax_momentum.py b/reference_algorithms/target_setting_algorithms/jax_momentum.py index 08a0f7e9d..c5fc2a0c6 100644 --- a/reference_algorithms/target_setting_algorithms/jax_momentum.py +++ b/reference_algorithms/target_setting_algorithms/jax_momentum.py @@ -32,7 +32,7 @@ def init_optimizer_state(workload: spec.Workload, hyperparameters) # Create optimizer. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) opt_init_fn, opt_update_fn = sgd( learning_rate=lr_schedule_fn, diff --git a/reference_algorithms/target_setting_algorithms/jax_nadamw.py b/reference_algorithms/target_setting_algorithms/jax_nadamw.py index 21f2a7b2b..1e6b691fc 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_nadamw.py @@ -96,8 +96,8 @@ def scale_by_nadam(b1: float = 0.9, raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment + mu = jax.tree.map(jnp.zeros_like, params) # First moment + nu = jax.tree.map(jnp.zeros_like, params) # Second moment return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) def update_fn(updates, state, params=None): @@ -108,7 +108,7 @@ def update_fn(updates, state, params=None): mu_hat = _update_moment(updates, mu, b1, 1) mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( + updates = jax.tree.map( lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) @@ -124,14 +124,14 @@ class ScaleByAdamState(NamedTuple): def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( + return jax.tree.map( lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) def _bias_correction(moment, decay, count): """Perform bias correction. This becomes a no-op as count goes to infinity.""" beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment) def scale_by_learning_rate(learning_rate, flip_sign=True): @@ -156,7 +156,7 @@ def init_optimizer_state(workload: spec.Workload, hyperparameters) # Create optimizer. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) epsilon = ( hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) diff --git a/reference_algorithms/target_setting_algorithms/jax_nesterov.py b/reference_algorithms/target_setting_algorithms/jax_nesterov.py index 6b27e0e2a..e5abde50b 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nesterov.py +++ b/reference_algorithms/target_setting_algorithms/jax_nesterov.py @@ -32,7 +32,7 @@ def init_optimizer_state(workload: spec.Workload, hyperparameters) # Create optimizer. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) opt_init_fn, opt_update_fn = sgd( learning_rate=lr_schedule_fn, diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index 7a16c07cb..703310df4 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -53,7 +53,7 @@ def _loss_fn(params): (summed_loss, n_valid_examples, grad) = lax.psum( (summed_loss, n_valid_examples, grad), axis_name='batch') loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) @@ -61,7 +61,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) diff --git a/tests/modeldiffs/vanilla_sgd_jax.py b/tests/modeldiffs/vanilla_sgd_jax.py index d45694bcb..18dce968a 100644 --- a/tests/modeldiffs/vanilla_sgd_jax.py +++ b/tests/modeldiffs/vanilla_sgd_jax.py @@ -21,7 +21,7 @@ def init_optimizer_state(workload: spec.Workload, del rng # Create optimizer. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes) opt_init_fn, opt_update_fn = optax.sgd(learning_rate=0.001) diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py index f107be8d7..6afea8a8e 100644 --- a/tests/reference_algorithm_tests.py +++ b/tests/reference_algorithm_tests.py @@ -97,9 +97,9 @@ def _make_fake_image_batch(batch_shape, data_shape, num_classes): def _pytorch_map(inputs): if USE_PYTORCH_DDP: - return jax.tree_map( + return jax.tree.map( lambda a: torch.as_tensor(a[RANK], device=PYTORCH_DEVICE), inputs) - return jax.tree_map( + return jax.tree.map( lambda a: torch.as_tensor(a, device=PYTORCH_DEVICE).view(-1, a.shape[-1]) if len(a.shape) == 3 else torch.as_tensor(a, device=PYTORCH_DEVICE).view( -1), diff --git a/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py b/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py index 6a85c2196..49fd85fef 100644 --- a/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py +++ b/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py @@ -10,7 +10,7 @@ def _pytree_total_diff(pytree_a, pytree_b): - pytree_diff = jax.tree_map(lambda a, b: jnp.sum(a - b), pytree_a, pytree_b) + pytree_diff = jax.tree.map(lambda a, b: jnp.sum(a - b), pytree_a, pytree_b) pytree_diff = jax.tree_util.tree_leaves(pytree_diff) return jnp.sum(jnp.array(pytree_diff)) From 785d82bff29454a1053cd0bf3e0fdd0354851bd1 Mon Sep 17 00:00:00 2001 From: init-22 Date: Sun, 22 Dec 2024 17:06:16 +0530 Subject: [PATCH 053/105] fix: MultiHeadDotProductAttention and optax ctc_loss changes --- .../workloads/imagenet_vit/imagenet_jax/models.py | 4 ++-- .../librispeech_conformer/librispeech_jax/models.py | 6 +++--- .../librispeech_conformer/librispeech_jax/workload.py | 11 ++++++----- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py index 639800b44..79ad54097 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py @@ -70,7 +70,7 @@ class Encoder1DBlock(nn.Module): def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: if not self.use_post_layer_norm: y = nn.LayerNorm(name='LayerNorm_0')(x) - y = nn.SelfAttention( + y = nn.MultiHeadDotProductAttention( num_heads=self.num_heads, kernel_init=nn.initializers.xavier_uniform(), deterministic=train, @@ -89,7 +89,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor: x = x + y else: y = x - y = nn.SelfAttention( + y = nn.MultiHeadDotProductAttention( num_heads=self.num_heads, kernel_init=nn.initializers.xavier_uniform(), deterministic=train, diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py index cb6287c5e..85a8d1bb7 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py @@ -396,10 +396,9 @@ def __call__(self, inputs, paddings, train): mask_paddings > 0, mask_paddings > 0, dtype=jnp.float32) inputs = LayerNorm(dim=config.encoder_dim)(inputs) - attention_fn = functools.partial( dot_product_attention, temperature=config.attention_temperature) - result = nn.SelfAttention( + result = nn.MultiHeadDotProductAttention( num_heads=config.num_attention_heads, qkv_features=config.encoder_dim, decode=False, @@ -410,7 +409,8 @@ def __call__(self, inputs, paddings, train): broadcast_dropout=False, attention_fn=attention_fn, dropout_rate=config.attention_dropout_rate, - deterministic=not train)(inputs, attention_mask) + deterministic=not train)( + inputs_q=inputs, mask=attention_mask) if config.attention_residual_dropout_rate is None: attention_residual_dropout_rate = 0.1 diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index 05faf1135..f546ef785 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -227,11 +227,12 @@ def ctc_loss(self, labels: spec.Tensor, label_paddings: spec.Tensor, blank_id: int = 0) -> spec.Tensor: - return optax.ctc_loss(logits, - logit_paddings, - labels, - label_paddings, - blank_id) + return optax.ctc_loss( + logits=logits, + logit_paddings=logit_paddings, + labels=labels, + label_paddings=label_paddings, + blank_id=blank_id) # Adapted from lingvo's greedy decoding logic here: # https://github.com/tensorflow/lingvo/blob/2ee26814c57b7dcead3f0382170f2f3da006f810/lingvo/jax/layers/ctc_objectives.py#L138. From d4aa90a8e8de930deb7981a931f6ff672ca1c9e1 Mon Sep 17 00:00:00 2001 From: init-22 Date: Sun, 22 Dec 2024 19:18:10 +0530 Subject: [PATCH 054/105] fix: removed the sacrebleu dependency --- algorithmic_efficiency/workloads/wmt/bleu.py | 366 ++++++++++++++++++- setup.cfg | 2 +- 2 files changed, 355 insertions(+), 13 deletions(-) diff --git a/algorithmic_efficiency/workloads/wmt/bleu.py b/algorithmic_efficiency/workloads/wmt/bleu.py index 1efc87381..dda6d102a 100644 --- a/algorithmic_efficiency/workloads/wmt/bleu.py +++ b/algorithmic_efficiency/workloads/wmt/bleu.py @@ -1,8 +1,20 @@ +""" +Removing the dependency on sacrebleu, we reimplement the BLEU score computation in this file. +Reference: +https://github.com/mjpost/sacrebleu/blob/v1.3.1/sacrebleu.py. +""" + +from collections import Counter +from collections import namedtuple from itertools import zip_longest -from typing import Sequence +import logging +import math +import re +import sys +from typing import List, Sequence +import unicodedata from absl import logging -import sacrebleu import torch import torch.distributed as dist @@ -10,10 +22,340 @@ USE_PYTORCH_DDP, _, DEVICE, N_GPUS = pytorch_setup() +NGRAM_ORDER = 4 +# The default floor value to use with `--smooth floor` +SMOOTH_VALUE_DEFAULT = 0.0 + + +def my_log(num): + """ + Floors the log function + + :param num: the number + :return: log(num) floored to a very low number + """ + + if num == 0.0: + return -9999999999 + return math.log(num) + + +def tokenize_13a(line): + """ + Tokenizes an input line using a relatively minimal tokenization that is however equivalent to mteval-v13a, used by WMT. + + :param line: a segment to tokenize + :return: the tokenized line + """ + + norm = line + + # language-independent part: + norm = norm.replace('', '') + norm = norm.replace('-\n', '') + norm = norm.replace('\n', ' ') + norm = norm.replace('"', '"') + norm = norm.replace('&', '&') + norm = norm.replace('<', '<') + norm = norm.replace('>', '>') + + # language-dependent part (assuming Western languages): + norm = " {} ".format(norm) + norm = re.sub(r'([\{-\~\[-\` -\&\(-\+\:-\@\/])', ' \\1 ', norm) + norm = re.sub(r'([^0-9])([\.,])', '\\1 \\2 ', + norm) # tokenize period and comma unless preceded by a digit + norm = re.sub(r'([\.,])([^0-9])', ' \\1 \\2', + norm) # tokenize period and comma unless followed by a digit + norm = re.sub(r'([0-9])(-)', '\\1 \\2 ', + norm) # tokenize dash when preceded by a digit + norm = re.sub(r'\s+', ' ', norm) # one space only between words + norm = re.sub(r'^\s+', '', norm) # no leading space + norm = re.sub(r'\s+$', '', norm) # no trailing space + + return norm + + +class UnicodeRegex: + """Ad-hoc hack to recognize all punctuation and symbols. + + without depending on https://pypi.python.org/pypi/regex/.""" + + def _property_chars(prefix): + return ''.join( + chr(x) + for x in range(sys.maxunicode) + if unicodedata.category(chr(x)).startswith(prefix)) + + punctuation = _property_chars('P') + nondigit_punct_re = re.compile(r'([^\d])([' + punctuation + r'])') + punct_nondigit_re = re.compile(r'([' + punctuation + r'])([^\d])') + symbol_re = re.compile('([' + _property_chars('S') + '])') + + +def tokenize_v14_international(string): + r"""Tokenize a string following the official BLEU implementation. + + See https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983 + In our case, the input string is expected to be just one line + and no HTML entities de-escaping is needed. + So we just tokenize on punctuation and symbols, + except when a punctuation is preceded and followed by a digit + (e.g. a comma/dot as a thousand/decimal separator). + + Note that a number (e.g., a year) followed by a dot at the end of sentence is NOT tokenized, + i.e. the dot stays with the number because `s/(\p{P})(\P{N})/ $1 $2/g` + does not match this case (unless we add a space after each sentence). + However, this error is already in the original mteval-v14.pl + and we want to be consistent with it. + The error is not present in the non-international version, + which uses `$norm_text = " $norm_text "` (or `norm = " {} ".format(norm)` in Python). + + :param string: the input string + :return: a list of tokens + """ + string = UnicodeRegex.nondigit_punct_re.sub(r'\1 \2 ', string) + string = UnicodeRegex.punct_nondigit_re.sub(r' \1 \2', string) + string = UnicodeRegex.symbol_re.sub(r' \1 ', string) + return string.strip() + + +def tokenize_zh(sentence): + """MIT License + Copyright (c) 2017 - Shujian Huang + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + + The tokenization of Chinese text in this script contains two steps: separate each Chinese + characters (by utf-8 encoding); tokenize the non Chinese part (following the mteval script). + Author: Shujian Huang huangsj@nju.edu.cn + + :param sentence: input sentence + :return: tokenized sentence + """ + + def is_chinese_char(uchar): + """ + :param uchar: input char in unicode + :return: whether the input char is a Chinese character. + """ + if uchar >= u'\u3400' and uchar <= u'\u4db5': # CJK Unified Ideographs Extension A, release 3.0 + return True + elif uchar >= u'\u4e00' and uchar <= u'\u9fa5': # CJK Unified Ideographs, release 1.1 + return True + elif uchar >= u'\u9fa6' and uchar <= u'\u9fbb': # CJK Unified Ideographs, release 4.1 + return True + elif uchar >= u'\uf900' and uchar <= u'\ufa2d': # CJK Compatibility Ideographs, release 1.1 + return True + elif uchar >= u'\ufa30' and uchar <= u'\ufa6a': # CJK Compatibility Ideographs, release 3.2 + return True + elif uchar >= u'\ufa70' and uchar <= u'\ufad9': # CJK Compatibility Ideographs, release 4.1 + return True + elif uchar >= u'\u20000' and uchar <= u'\u2a6d6': # CJK Unified Ideographs Extension B, release 3.1 + return True + elif uchar >= u'\u2f800' and uchar <= u'\u2fa1d': # CJK Compatibility Supplement, release 3.1 + return True + elif uchar >= u'\uff00' and uchar <= u'\uffef': # Full width ASCII, full width of English punctuation, half width Katakana, half wide half width kana, Korean alphabet + return True + elif uchar >= u'\u2e80' and uchar <= u'\u2eff': # CJK Radicals Supplement + return True + elif uchar >= u'\u3000' and uchar <= u'\u303f': # CJK punctuation mark + return True + elif uchar >= u'\u31c0' and uchar <= u'\u31ef': # CJK stroke + return True + elif uchar >= u'\u2f00' and uchar <= u'\u2fdf': # Kangxi Radicals + return True + elif uchar >= u'\u2ff0' and uchar <= u'\u2fff': # Chinese character structure + return True + elif uchar >= u'\u3100' and uchar <= u'\u312f': # Phonetic symbols + return True + elif uchar >= u'\u31a0' and uchar <= u'\u31bf': # Phonetic symbols (Taiwanese and Hakka expansion) + return True + elif uchar >= u'\ufe10' and uchar <= u'\ufe1f': + return True + elif uchar >= u'\ufe30' and uchar <= u'\ufe4f': + return True + elif uchar >= u'\u2600' and uchar <= u'\u26ff': + return True + elif uchar >= u'\u2700' and uchar <= u'\u27bf': + return True + elif uchar >= u'\u3200' and uchar <= u'\u32ff': + return True + elif uchar >= u'\u3300' and uchar <= u'\u33ff': + return True + + return False + + sentence = sentence.strip() + sentence_in_chars = "" + for char in sentence: + if is_chinese_char(char): + sentence_in_chars += " " + sentence_in_chars += char + sentence_in_chars += " " + else: + sentence_in_chars += char + sentence = sentence_in_chars + + # TODO: the code above could probably be replaced with the following line: + # import regex + # sentence = regex.sub(r'(\p{Han})', r' \1 ', sentence) + + # tokenize punctuation + sentence = re.sub(r'([\{-\~\[-\` -\&\(-\+\:-\@\/])', r' \1 ', sentence) + + # tokenize period and comma unless preceded by a digit + sentence = re.sub(r'([^0-9])([\.,])', r'\1 \2 ', sentence) + + # tokenize period and comma unless followed by a digit + sentence = re.sub(r'([\.,])([^0-9])', r' \1 \2', sentence) + + # tokenize dash when preceded by a digit + sentence = re.sub(r'([0-9])(-)', r'\1 \2 ', sentence) + + # one space only between words + sentence = re.sub(r'\s+', r' ', sentence) + + # no leading or trailing spaces + sentence = sentence.strip() + + return sentence + + +TOKENIZERS = { + '13a': tokenize_13a, + 'intl': tokenize_v14_international, + 'zh': tokenize_zh, + 'none': lambda x: x, +} +DEFAULT_TOKENIZER = '13a' + + +def extract_ngrams(line, min_order=1, max_order=NGRAM_ORDER) -> Counter: + """Extracts all the ngrams (1 <= n <= NGRAM_ORDER) from a sequence of tokens. + + :param line: a segment containing a sequence of words + :param max_order: collect n-grams from 1<=n<=max + :return: a dictionary containing ngrams and counts + """ + + ngrams = Counter() + tokens = line.split() + for n in range(min_order, max_order + 1): + for i in range(0, len(tokens) - n + 1): + ngram = ' '.join(tokens[i:i + n]) + ngrams[ngram] += 1 + + return ngrams + + +def ref_stats(output, refs): + ngrams = Counter() + closest_diff = None + closest_len = None + for ref in refs: + tokens = ref.split() + reflen = len(tokens) + diff = abs(len(output.split()) - reflen) + if closest_diff is None or diff < closest_diff: + closest_diff = diff + closest_len = reflen + elif diff == closest_diff: + if reflen < closest_len: + closest_len = reflen + + ngrams_ref = extract_ngrams(ref) + for ngram in ngrams_ref.keys(): + ngrams[ngram] = max(ngrams[ngram], ngrams_ref[ngram]) + + return ngrams, closest_diff, closest_len + + +BLEU = namedtuple('BLEU', + 'score, counts, totals, precisions, bp, sys_len, ref_len') + + +def compute_bleu(correct: List[int], + total: List[int], + sys_len: int, + ref_len: int, + smooth_method='none', + smooth_value=SMOOTH_VALUE_DEFAULT, + use_effective_order=False) -> BLEU: + """Computes BLEU score from its sufficient statistics. Adds smoothing. + + Smoothing methods (citing "A Systematic Comparison of Smoothing Techniques for Sentence-Level BLEU", + Boxing Chen and Colin Cherry, WMT 2014: http://aclweb.org/anthology/W14-3346) + + - exp: NIST smoothing method (Method 3) + - floor: Method 1 + - add-k: Method 2 (generalizing Lin and Och, 2004) + - none: do nothing. + + :param correct: List of counts of correct ngrams, 1 <= n <= NGRAM_ORDER + :param total: List of counts of total ngrams, 1 <= n <= NGRAM_ORDER + :param sys_len: The cumulative system length + :param ref_len: The cumulative reference length + :param smooth: The smoothing method to use + :param smooth_value: The smoothing value added, if smooth method 'floor' is used + :param use_effective_order: Use effective order. + :return: A BLEU object with the score (100-based) and other statistics. + """ + + precisions = [0 for x in range(NGRAM_ORDER)] + + smooth_mteval = 1. + effective_order = NGRAM_ORDER + for n in range(NGRAM_ORDER): + if smooth_method == 'add-k' and n > 1: + correct[n] += smooth_value + total[n] += smooth_value + if total[n] == 0: + break + + if use_effective_order: + effective_order = n + 1 + + if correct[n] == 0: + if smooth_method == 'exp': + smooth_mteval *= 2 + precisions[n] = 100. / (smooth_mteval * total[n]) + elif smooth_method == 'floor': + precisions[n] = 100. * smooth_value / total[n] + else: + precisions[n] = 100. * correct[n] / total[n] + + # If the system guesses no i-grams, 1 <= i <= NGRAM_ORDER, the BLEU score is 0 (technically undefined). + # This is a problem for sentence-level BLEU or a corpus of short sentences, where systems will get no credit + # if sentence lengths fall under the NGRAM_ORDER threshold. This fix scales NGRAM_ORDER to the observed + # maximum order. It is only available through the API and off by default + + brevity_penalty = 1.0 + if sys_len < ref_len: + brevity_penalty = math.exp(1 - ref_len / sys_len) if sys_len > 0 else 0.0 + + bleu = brevity_penalty * math.exp( + sum(map(my_log, precisions[:effective_order])) / effective_order) + + return BLEU._make( + [bleu, correct, total, precisions, brevity_penalty, sys_len, ref_len]) + -# Modified (added sync for PyTorch DDP) from -# https://github.com/mjpost/sacrebleu/blob/v1.3.1/sacrebleu.py. -# Assumes that sacrebleu==1.3.1 is installed. def corpus_bleu(sys_stream: Sequence[str], ref_streams: Sequence[str], smooth_method: str = 'exp', @@ -21,7 +363,7 @@ def corpus_bleu(sys_stream: Sequence[str], force: bool = False, lowercase: bool = False, tokenize: str = '13a', - use_effective_order: bool = False) -> sacrebleu.BLEU: + use_effective_order: bool = False) -> BLEU: """Produces BLEU scores along with its sufficient statistics from a source against one or more references. :param sys_stream: The system stream (a sequence of segments). @@ -44,8 +386,8 @@ def corpus_bleu(sys_stream: Sequence[str], sys_len = 0 ref_len = 0 - correct = [0 for _ in range(sacrebleu.NGRAM_ORDER)] - total = [0 for _ in range(sacrebleu.NGRAM_ORDER)] + correct = [0 for _ in range(NGRAM_ORDER)] + total = [0 for _ in range(NGRAM_ORDER)] # Look for already-tokenized sentences. tokenized_count = 0 @@ -70,14 +412,14 @@ def corpus_bleu(sys_stream: Sequence[str], 'or don\'t care, you can suppress this message with ' '\'--force\'.') - output, *refs = [sacrebleu.TOKENIZERS[tokenize](x.rstrip()) for x in lines] + output, *refs = [TOKENIZERS[tokenize](x.rstrip()) for x in lines] - ref_ngrams, _, closest_len = sacrebleu.ref_stats(output, refs) + ref_ngrams, _, closest_len = ref_stats(output, refs) sys_len += len(output.split()) ref_len += closest_len - sys_ngrams = sacrebleu.extract_ngrams(output) + sys_ngrams = extract_ngrams(output) for ngram, sys_ngram in sys_ngrams.items(): n = len(ngram.split()) correct[n - 1] += min(sys_ngram, ref_ngrams.get(ngram, 0)) @@ -100,7 +442,7 @@ def corpus_bleu(sys_stream: Sequence[str], dist.all_reduce(total) total = total.cpu().numpy().tolist() - return sacrebleu.compute_bleu( + return compute_bleu( correct, total, sys_len, diff --git a/setup.cfg b/setup.cfg index 2d246b48b..8e37acb7a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -102,7 +102,7 @@ librispeech_conformer = wmt = sentencepiece==0.2.0 tensorflow-text==2.18.0 - sacrebleu==1.3.1 + # Frameworks # # JAX Core From 5e348e4234b061f1819bddcd8d6a3b70ef9804b2 Mon Sep 17 00:00:00 2001 From: init-22 Date: Mon, 23 Dec 2024 00:31:48 +0530 Subject: [PATCH 055/105] fix: resolving pylint errors --- algorithmic_efficiency/workloads/wmt/bleu.py | 132 ++++++++++--------- 1 file changed, 71 insertions(+), 61 deletions(-) diff --git a/algorithmic_efficiency/workloads/wmt/bleu.py b/algorithmic_efficiency/workloads/wmt/bleu.py index dda6d102a..22f6a57e0 100644 --- a/algorithmic_efficiency/workloads/wmt/bleu.py +++ b/algorithmic_efficiency/workloads/wmt/bleu.py @@ -1,5 +1,6 @@ """ -Removing the dependency on sacrebleu, we reimplement the BLEU score computation in this file. +Removing the dependency on sacrebleu, we reimplement the BLEU score computation +in this file. Reference: https://github.com/mjpost/sacrebleu/blob/v1.3.1/sacrebleu.py. """ @@ -42,7 +43,8 @@ def my_log(num): def tokenize_13a(line): """ - Tokenizes an input line using a relatively minimal tokenization that is however equivalent to mteval-v13a, used by WMT. + Tokenizes an input line using a relatively minimal tokenization that is + however equivalent to mteval-v13a, used by WMT. :param line: a segment to tokenize :return: the tokenized line @@ -80,6 +82,7 @@ class UnicodeRegex: without depending on https://pypi.python.org/pypi/regex/.""" + @staticmethod def _property_chars(prefix): return ''.join( chr(x) @@ -95,20 +98,23 @@ def _property_chars(prefix): def tokenize_v14_international(string): r"""Tokenize a string following the official BLEU implementation. - See https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983 + See + https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983 In our case, the input string is expected to be just one line and no HTML entities de-escaping is needed. So we just tokenize on punctuation and symbols, except when a punctuation is preceded and followed by a digit (e.g. a comma/dot as a thousand/decimal separator). - Note that a number (e.g., a year) followed by a dot at the end of sentence is NOT tokenized, + Note that a number (e.g., a year) followed by a dot at the end of sentence + is NOT tokenized, i.e. the dot stays with the number because `s/(\p{P})(\P{N})/ $1 $2/g` does not match this case (unless we add a space after each sentence). However, this error is already in the original mteval-v14.pl and we want to be consistent with it. The error is not present in the non-international version, - which uses `$norm_text = " $norm_text "` (or `norm = " {} ".format(norm)` in Python). + which uses, + `$norm_text = " $norm_text "` (or `norm = " {} ".format(norm)` in Python). :param string: the input string :return: a list of tokens @@ -123,26 +129,28 @@ def tokenize_zh(sentence): """MIT License Copyright (c) 2017 - Shujian Huang - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in all - copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - SOFTWARE. - - The tokenization of Chinese text in this script contains two steps: separate each Chinese - characters (by utf-8 encoding); tokenize the non Chinese part (following the mteval script). + Permission is hereby granted, free of charge, to any person obtaining + a copy of this software and associated documentation files + (the "Software"), to deal in the Software without restriction, including + without limitation the rights to use, copy, modify, merge, publish, + distribute, sublicense, and/or sell copies of the Software, and to + permit persons to whom the Software is furnished to do so, subject to the + following conditions: + + The above copyright notice and this permission notice shall be included + in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, + DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE + USE OR OTHER DEALINGS IN THE SOFTWARE. + + The tokenization of Chinese text in this script contains two steps: + separate each Chinese characters (by utf-8 encoding); + tokenize the non Chinese part (following the mteval script). Author: Shujian Huang huangsj@nju.edu.cn :param sentence: input sentence @@ -151,54 +159,53 @@ def tokenize_zh(sentence): def is_chinese_char(uchar): """ - :param uchar: input char in unicode - :return: whether the input char is a Chinese character. - """ - if uchar >= u'\u3400' and uchar <= u'\u4db5': # CJK Unified Ideographs Extension A, release 3.0 + :param uchar: input char in unicode + :return: whether the input char is a Chinese character. + """ + if "\u3400" <= uchar <= "\u4db5": return True - elif uchar >= u'\u4e00' and uchar <= u'\u9fa5': # CJK Unified Ideographs, release 1.1 + elif "\u4e00" <= uchar <= "\u9fa5": return True - elif uchar >= u'\u9fa6' and uchar <= u'\u9fbb': # CJK Unified Ideographs, release 4.1 + elif "\u9fa6" <= uchar <= "\u9fbb": return True - elif uchar >= u'\uf900' and uchar <= u'\ufa2d': # CJK Compatibility Ideographs, release 1.1 + elif "\uf900" <= uchar <= "\ufa2d": return True - elif uchar >= u'\ufa30' and uchar <= u'\ufa6a': # CJK Compatibility Ideographs, release 3.2 + elif "\ufa30" <= uchar <= "\ufa6a": return True - elif uchar >= u'\ufa70' and uchar <= u'\ufad9': # CJK Compatibility Ideographs, release 4.1 + elif "\ufa70" <= uchar <= "\ufad9": return True - elif uchar >= u'\u20000' and uchar <= u'\u2a6d6': # CJK Unified Ideographs Extension B, release 3.1 + elif "\u20000" <= uchar <= "\u2a6d6": return True - elif uchar >= u'\u2f800' and uchar <= u'\u2fa1d': # CJK Compatibility Supplement, release 3.1 + elif "\u2f800" <= uchar <= "\u2fa1d": return True - elif uchar >= u'\uff00' and uchar <= u'\uffef': # Full width ASCII, full width of English punctuation, half width Katakana, half wide half width kana, Korean alphabet + elif "\uff00" <= uchar <= "\uffef": return True - elif uchar >= u'\u2e80' and uchar <= u'\u2eff': # CJK Radicals Supplement + elif "\u2e80" <= uchar <= "\u2eff": return True - elif uchar >= u'\u3000' and uchar <= u'\u303f': # CJK punctuation mark + elif "\u3000" <= uchar <= "\u303f": return True - elif uchar >= u'\u31c0' and uchar <= u'\u31ef': # CJK stroke + elif "\u31c0" <= uchar <= "\u31ef": return True - elif uchar >= u'\u2f00' and uchar <= u'\u2fdf': # Kangxi Radicals + elif "\u2f00" <= uchar <= "\u2fdf": return True - elif uchar >= u'\u2ff0' and uchar <= u'\u2fff': # Chinese character structure + elif "\u2ff0" <= uchar <= "\u2fff": return True - elif uchar >= u'\u3100' and uchar <= u'\u312f': # Phonetic symbols + elif "\u3100" <= uchar <= "\u312f": return True - elif uchar >= u'\u31a0' and uchar <= u'\u31bf': # Phonetic symbols (Taiwanese and Hakka expansion) + elif "\u31a0" <= uchar <= "\u31bf": return True - elif uchar >= u'\ufe10' and uchar <= u'\ufe1f': + elif "\ufe10" <= uchar <= "\ufe1f": return True - elif uchar >= u'\ufe30' and uchar <= u'\ufe4f': + elif "\ufe30" <= uchar <= "\ufe4f": return True - elif uchar >= u'\u2600' and uchar <= u'\u26ff': + elif "\u2600" <= uchar <= "\u26ff": return True - elif uchar >= u'\u2700' and uchar <= u'\u27bf': + elif "\u2700" <= uchar <= "\u27bf": return True - elif uchar >= u'\u3200' and uchar <= u'\u32ff': + elif "\u3200" <= uchar <= "\u32ff": return True - elif uchar >= u'\u3300' and uchar <= u'\u33ff': + elif "\u3300" <= uchar <= "\u33ff": return True - return False sentence = sentence.strip() @@ -280,13 +287,13 @@ def ref_stats(output, refs): closest_len = reflen ngrams_ref = extract_ngrams(ref) - for ngram in ngrams_ref.keys(): + for ngram in ngrams_ref: ngrams[ngram] = max(ngrams[ngram], ngrams_ref[ngram]) return ngrams, closest_diff, closest_len -BLEU = namedtuple('BLEU', +BLEU = namedtuple('BLE', 'score, counts, totals, precisions, bp, sys_len, ref_len') @@ -299,8 +306,9 @@ def compute_bleu(correct: List[int], use_effective_order=False) -> BLEU: """Computes BLEU score from its sufficient statistics. Adds smoothing. - Smoothing methods (citing "A Systematic Comparison of Smoothing Techniques for Sentence-Level BLEU", - Boxing Chen and Colin Cherry, WMT 2014: http://aclweb.org/anthology/W14-3346) + Smoothing methods (citing "A Systematic Comparison of Smoothing Techniques + for Sentence-Level BLEU", Boxing Chen and Colin Cherry, + WMT 2014: http://aclweb.org/anthology/W14-3346) - exp: NIST smoothing method (Method 3) - floor: Method 1 @@ -312,7 +320,7 @@ def compute_bleu(correct: List[int], :param sys_len: The cumulative system length :param ref_len: The cumulative reference length :param smooth: The smoothing method to use - :param smooth_value: The smoothing value added, if smooth method 'floor' is used + :param smooth_value: The smoothing value added, if smooth is 'floor' :param use_effective_order: Use effective order. :return: A BLEU object with the score (100-based) and other statistics. """ @@ -340,10 +348,12 @@ def compute_bleu(correct: List[int], else: precisions[n] = 100. * correct[n] / total[n] - # If the system guesses no i-grams, 1 <= i <= NGRAM_ORDER, the BLEU score is 0 (technically undefined). - # This is a problem for sentence-level BLEU or a corpus of short sentences, where systems will get no credit - # if sentence lengths fall under the NGRAM_ORDER threshold. This fix scales NGRAM_ORDER to the observed - # maximum order. It is only available through the API and off by default + # If the system guesses no i-grams, 1 <= i <= NGRAM_ORDER, the BLEU + # score is 0 (technically undefined). This is a problem for sentence-level + # BLEU or a corpus of short sentences, where systems will get no credit + # if sentence lengths fall under the NGRAM_ORDER threshold. This fix scales + # NGRAM_ORDER to the observed maximum order. + # It is only available through the API and off by default brevity_penalty = 1.0 if sys_len < ref_len: @@ -374,7 +384,7 @@ def corpus_bleu(sys_stream: Sequence[str], :param force: Ignore data that looks already tokenized. :param lowercase: Lowercase the data. :param tokenize: The tokenizer to use. - :return: A BLEU object containing everything you'd want. + :return: A BLEU object containing everything yo'd want. """ # Add some robustness to the input arguments. From b769e6cc5b877e5ed40ef7e86aaebd0c53d9d5ab Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 9 Jan 2025 17:59:09 +0000 Subject: [PATCH 056/105] fix startup script for python version upgrade --- docker/scripts/startup.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index 527e8306a..1dbba9565 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -156,7 +156,7 @@ fi if [[ ${TEST} == "true" ]]; then cd algorithmic-efficiency - COMMAND="python3 tests/test_traindiffs.py" + COMMAND="python tests/test_traindiffs.py" echo $COMMAND eval $COMMAND exit @@ -209,7 +209,7 @@ TUNING_RULESET_FLAG="--tuning_ruleset=${TUNING_RULESET}" # Set run command prefix depending on framework if [[ "${FRAMEWORK}" == "jax" ]]; then - COMMAND_PREFIX="python3" + COMMAND_PREFIX="python" else COMMAND_PREFIX="torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8" fi From b65157eb4a8798029f3b21a38a807dd9a6067fa9 Mon Sep 17 00:00:00 2001 From: Isaac Date: Tue, 14 Jan 2025 15:52:44 +0000 Subject: [PATCH 057/105] fix: getargspec is not supported in python311, using getfullargspec instead --- .../workloads/imagenet_resnet/imagenet_jax/randaugment.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index e920331bc..41002ff9b 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -482,13 +482,13 @@ def _parse_policy_info(name, # Check to see if prob is passed into function. This is used for operations # where we alter bboxes independently. - if 'prob' in inspect.getargspec(func)[0]: + if 'prob' in inspect.getfullargspec(func)[0]: args = tuple([prob] + list(args)) # Add in replace arg if it is required for the function that is being called. - if 'replace' in inspect.getargspec(func)[0]: + if 'replace' in inspect.getfullargspec(func)[0]: # Make sure replace is the final argument - assert 'replace' == inspect.getargspec(func)[0][-1] + assert 'replace' == inspect.getfullargspec(func)[0][-1] args = tuple(list(args) + [replace_value]) return (func, prob, args) From 8327283285f2127d951c9f8e9c20a30a47444ff4 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Wed, 15 Jan 2025 13:07:55 +0100 Subject: [PATCH 058/105] Create equivalent pyproject toml --- pyproject.toml | 327 +++++++++++++++++++++++++++++++++++++++++++++++++ setup.cfg | 314 ----------------------------------------------- setup.py | 4 - 3 files changed, 327 insertions(+), 318 deletions(-) create mode 100644 pyproject.toml delete mode 100644 setup.cfg delete mode 100644 setup.py diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..10fd4b730 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,327 @@ +############################################################################### +# MLCommons Algorithmic Efficiency. # +############################################################################### + +[project] +name = "algorithmic_efficiency" +dynamic = ["version"] +description = "Codebase for the AlgoPerf: Training Algorithms benchmark" +authors = [ + { name = "MLCommons Algorithms Working Group", email = "algorithms@mlcommons.org" }, +] +license = { text = "Apache 2.0" } +readme = "README.md" +requires-python = ">=3.8" +keywords = [ + "algoperf", + "algorithmic-efficiency", + "machine-learning", + "deep-learning", + "optimization", + "benchmarking", + "training-methods", +] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering :: Artificial Intelligence", +] +dependencies = [ + "absl-py==1.4.0", + "networkx==3.1", + "docker==7.0.0", + "numpy>=1.23", + "pandas>=2.0.1", + "tensorflow==2.12.0", + "tensorflow-datasets==4.9.2", + "tensorflow-probability==0.20.0", + "tensorflow-addons==0.20.0", + "gputil==1.4.0", + "psutil==5.9.5", + "clu==0.0.7", + "matplotlib>=3.7.2", + "tabulate==0.9.0", +] + +[build-system] +requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +packages = ["algorithmic_efficiency"] +py-modules = ["submission_runner"] +include-package-data = true +zip-safe = false + +[tool.setuptools.dynamic] +version = { attr = "algorithmic_efficiency.__version__" } + +############################################################################### +# (Optional) Dependencies # +############################################################################### +[project.optional-dependencies] +# All workloads +full = [ + "algorithmic_efficiency[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt]", +] +# All workloads plus development dependencies +full_dev = ["algorithmic_efficiency[full,dev]"] +# Dependencies for developing the package +dev = [ + "isort==5.12.0", + "pylint==2.17.4", + "pytest==7.3.1", + "yapf==0.33.0", + "pre-commit==3.3.1", +] + +# Workloads +criteo1tb = ["scikit-learn==1.2.2"] +fastmri = ["h5py==3.8.0", "scikit-image==0.20.0"] +ogbg = ["jraph==0.0.6.dev0", "scikit-learn==1.2.2"] +librispeech_conformer = [ + "sentencepiece==0.1.99", + "tensorflow-text==2.12.1", + "pydub==0.25.1", +] +wmt = ["sentencepiece==0.1.99", "tensorflow-text==2.12.1", "sacrebleu==1.3.1"] + +# Frameworks +jax_core_deps = [ + "flax==0.6.10", + "optax==0.1.5", + # Todo(kasimbeg): verify if this is necessary after we upgrade jax. + "chex==0.1.7", + "ml_dtypes==0.2.0", + "protobuf==4.25.3", +] +jax_cpu = [ + "jax==0.4.10", + "jaxlib==0.4.10", + "algorithmic_efficiency[jax_core_deps]", +] +jax_gpu = [ + "jax==0.4.10", + "jaxlib==0.4.10+cuda12.cudnn88", + "algorithmic_efficiency[jax_core_deps]", +] +pytorch_cpu = ["torch==2.1.0", "torchvision==0.16.0"] +pytorch_gpu = [ + "torch==2.1.0", + "torchvision==0.16.0", +] # Note: omit the cuda suffix and installing from the appropriate wheel will result in using locally installed CUDA. +wandb = ["wandb==0.16.5"] + +############################################################################### +# Linting Configurations # +############################################################################### + +# yapf configuration +[tool.yapf] +based_on_style = "yapf" +each_dict_entry_on_separate_line = false +split_all_top_level_comma_separated_values = true + +# isort configuration +[tool.isort] +profile = "google" + +# pylint configuration +[tool.pylint.MASTER] +persistent = false +ignore = "get_references_web.py,get_references_web_single_group.py" + +[tool.pylint.REPORTS] +reports = false +msg-template = "{msg_id}:{line:3} {obj}: {msg} [{symbol}]" + +[tool.pylint.MESSAGES_CONTROL] +enable = "indexing-exception,old-raise-syntax" + +[tool.pylint.BASIC] +# Required attributes for module, separated by a comma +#required-attributes= +# Regular expression which should only match the name +# of functions or classes which do not require a docstring. +no-docstring-rgx = "(__.*__|main)" +# Min length in lines of a function that requires a docstring. +docstring-min-length = 10 +# Regular expression which should only match correct module names. The +# leading underscore is sanctioned for private modules by Google's style +# guide. +# +# There are exceptions to the basic rule (_?[a-z][a-z0-9_]*) to cover +# requirements of Python's module system. +module-rgx = "^(_?[a-z][a-z0-9_]*)|__init__$" +# Regular expression which should only match correct module level names +const-rgx = "^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$" +# Regular expression which should only match correct class attribute +class-attribute-rgx = "^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$" +# Regular expression which should only match correct class names +class-rgx = "^_?[A-Z][a-zA-Z0-9]*$" +# Regular expression which should only match correct function names. +# 'camel_case' and 'snake_case' group names are used for consistency of naming +# styles across functions and methods. +function-rgx = "^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$" +# Regular expression which should only match correct method names. +# 'camel_case' and 'snake_case' group names are used for consistency of naming +# styles across functions and methods. 'exempt' indicates a name which is +# consistent with all naming styles. +method-rgx = "(?x)^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|_testDatasetSize|setUpClass|test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|(?:test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next|(?P_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P_{0,2}[a-z][a-z0-9_]*))$" +# Regular expression which should only match correct instance attribute names +attr-rgx = "^_{0,2}[a-z][a-z0-9_]*$" +# Regular expression which should only match correct argument names +argument-rgx = "^[a-z][a-z0-9_]*$" +# Regular expression which should only match correct variable names +variable-rgx = "^[a-z][a-z0-9_]*$" +# Regular expression which should only match correct list comprehension / +# generator expression variable names +inlinevar-rgx = "^[a-z][a-z0-9_]*$" +# Good variable names which should always be accepted, separated by a comma +good-names = "main,_" +# Bad variable names which should always be refused, separated by a comma +bad-names = "" +# List of builtins function names that should not be used, separated by a comma +#bad-functions=input,apply,reduce +# List of decorators that define properties, such as abc.abstractproperty. +property-classes = "abc.abstractproperty" + +[tool.pylint.typecheck] +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members = true + +# List of decorators that create context managers from functions, such as +# contextlib.contextmanager. +contextmanager-decorators = [ + "contextlib.contextmanager", + "contextlib2.contextmanager", +] + +[tool.pylint.VARIABLES] +# Tells whether we should check for unused import in __init__ files. +init-import = false + +# A regular expression matching names used for dummy variables (i.e. not used). +dummy-variables-rgx = "^\\*{0,2}(_$|unused_|dummy_)" + +# List of additional names supposed to be defined in builtins. +additional-builtins = [] + +[tool.pylint.CLASSES] +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods = ["__init__", "__new__", "setUp"] + +# Valid names for the first argument to a class method. +valid-classmethod-first-arg = ["cls", "class_"] + +[tool.pylint.EXCEPTIONS] +overgeneral-exceptions = [ + "builtins.StandardError", + "builtins.Exception", + "builtins.BaseException", +] + +[tool.pylint.IMPORTS] +# Deprecated modules which should not be used, separated by a comma +deprecated-modules = ["regsub", "TERMIOS", "Bastion", "rexec", "sets"] + +[tool.pylint.FORMAT] +# List of checkers and warnings to disable. +disable = [ + "abstract-method", + "access-member-before-definition", + "arguments-differ", + "assignment-from-no-return", + "attribute-defined-outside-init", + "bad-mcs-classmethod-argument", + "bad-option-value", + "c-extension-no-member", + "consider-merging-isinstance", + "consider-using-dict-comprehension", + "consider-using-enumerate", + "consider-using-in", + "consider-using-set-comprehension", + "consider-using-ternary", + "deprecated-method", + "design", + "file-ignored", + "fixme", + "global-statement", + "import-error", + "inconsistent-return-statements", + "invalid-unary-operand-type", + "len-as-condition", + "locally-disabled", + "locally-enabled", + "misplaced-comparison-constant", + "missing-docstring", + "multiple-imports", + "no-else-return", + "no-member", + "no-name-in-module", + "no-self-use", + "no-value-for-parameter", + "not-an-iterable", + "not-context-manager", + "pointless-except", + "protected-access", + "redefined-argument-from-local", + "signature-differs", + "similarities", + "simplifiable-if-expression", + "star-args", + "super-init-not-called", + "suppressed-message", + "too-many-function-args", + "trailing-comma-tuple", + "trailing-newlines", + "ungrouped-imports", + "unnecessary-pass", + "unsubscriptable-object", + "unused-argument", + "useless-object-inheritance", + "useless-return", + "useless-suppression", + "wrong-import-order", + "wrong-import-position", + "unneeded-not", + "unexpected-keyword-arg", + "redundant-keyword-arg", + "unspecified-encoding", + "logging-fstring-interpolation", + "consider-using-f-string", + "use-dict-literal", +] +# Maximum number of characters on a single line. +max-line-length = 80 +ignore-long-lines = "(?x)(^\\s*(import|from)\\s|^\\s*(\\#\\ )??$|^[a-zA-Z_][a-zA-Z0-9_]*\\s*=\\s*('[^']\\S+'|\"[^\"]\\S+\"))" +# Maximum number of lines in a module +max-module-lines = 99999 +# String used as indentation unit. We differ from PEP8's normal 4 spaces. +indent-string = ' ' +single-line-if-stmt = true +# Do not warn about multiple statements on a single line for constructs like +# if test: stmt +[tool.pylint.LOGGING] +logging-modules = "logging,absl.logging" +# Add logging modules. +[tool.pylint.MISCELLANEOUS] +# Maximum line length for lambdas +#short-func-length=1 +# List of module members that should be marked as deprecated. +# All of the string functions are listed in 4.1.4 Deprecated string functions +# in the Python 2.4 docs. +#deprecated-members=string.atof,string.atoi,string.atol,string.capitalize,string.expandtabs,string.find,string.rfind,string.index,string.rindex,string.count,string.lower,string.split,string.rsplit,string.splitfields,string.join,string.joinfields,string.lstrip,string.rstrip,string.strip,string.swapcase,string.translate,string.upper,string.ljust,string.rjust,string.center,string.zfill,string.replace,sys.exitfunc,sys.maxint +# List of exceptions that do not need to be mentioned in the Raises section of +# a docstring. +#ignore-exceptions=AssertionError,NotImplementedError,StopIteration,TypeError +# Number of spaces of indent required when the last token on the preceding line +# is an open (, [, or {. +indent-after-paren = 4 diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 4afefd164..000000000 --- a/setup.cfg +++ /dev/null @@ -1,314 +0,0 @@ -############################################################################### -# MLCommons Algorithmic Efficiency. # -############################################################################### - -[metadata] -name = algorithmic_efficiency -version = attr: algorithmic_efficiency.__version__ -description = MLCommons Algorithmic Efficiency -url = https://github.com/mlcommons/algorithmic-efficiency -author = MLCommons Algorithmic Efficiency -author_email = algorithms@mlcommons.org -license = Apache 2.0 -long_description = file: README.md -long_description_content_type = text/markdown -keywords = algorithmic-efficiency, machine-learning, deep-learning, - optimization, benchmarking, training-methods -platforms = any -classifiers = - Development Status :: 3 - Alpha - Intended Audience :: Developers - Intended Audience :: Science/Research - License :: OSI Approved :: Apache Software License - Operating System :: OS Independent - Programming Language :: Python :: 3.8 - Programming Language :: Python :: 3.9 - Programming Language :: Python :: 3.10 - Topic :: Scientific/Engineering :: Artificial Intelligence - -[options] -zip_safe = False -packages = find: -include_package_data = True -setup_requires = - setuptools_scm -# Dependencies of the project: -install_requires = - absl-py==1.4.0 - # Pin to avoid unpinned install in dependencies that requires Python>=3.9. - networkx==3.1 - docker==7.0.0 - numpy>=1.23 - pandas>=2.0.1 - tensorflow==2.12.0 - tensorflow-datasets==4.9.2 - tensorflow-probability==0.20.0 - tensorflow-addons==0.20.0 - gputil==1.4.0 - psutil==5.9.5 - clu==0.0.7 - matplotlib>=3.7.2 - tabulate==0.9.0 -python_requires = >=3.8 - - -############################################################################### -# Additional Dependencies # -############################################################################### - -[options.extras_require] -# Add extra dependencies, e.g. to run tests or for the different frameworks. -# Use as `pip install -e '.[jax_gpu]' -f https://storage.googleapis.com/jax-releases/jax_releases.html` -# or `pip install -e '.[dev]'` - -# Bundled installs # - -# All workloads -full = - %(criteo1tb)s - %(fastmri)s - %(ogbg)s - %(librispeech_conformer)s - %(wmt)s - -# All workloads plus development dependencies -full_dev = - %(full)s - %(dev)s - - -# Dependencies for developing the package -dev = - isort==5.12.0 - pylint==2.17.4 - pytest==7.3.1 - yapf==0.33.0 - pre-commit==3.3.1 - -# Workloads # -criteo1tb = - scikit-learn==1.2.2 - -fastmri = - h5py==3.8.0 - scikit-image==0.20.0 - -ogbg = - jraph==0.0.6.dev0 - scikit-learn==1.2.2 - -librispeech_conformer = - sentencepiece==0.1.99 - tensorflow-text==2.12.1 - pydub==0.25.1 - -wmt = - sentencepiece==0.1.99 - tensorflow-text==2.12.1 - sacrebleu==1.3.1 - -# Frameworks # - -# JAX Core -jax_core_deps = - flax==0.6.10 - optax==0.1.5 - # Fix chex (optax dependency) version. - # Not fixing it can raise dependency issues with our - # jax version. - # Todo(kasimbeg): verify if this is necessary after we - # upgrade jax. - chex==0.1.7 - ml_dtypes==0.2.0 - protobuf==4.25.3 - - -# JAX CPU -jax_cpu = - jax==0.4.10 - jaxlib==0.4.10 - %(jax_core_deps)s - -# JAX GPU -# Note this installs both jax and jaxlib. -jax_gpu = - jax==0.4.10 - jaxlib==0.4.10+cuda12.cudnn88 - %(jax_core_deps)s - -# PyTorch CPU -pytorch_cpu = - torch==2.1.0 - torchvision==0.16.0 - -# PyTorch GPU -# Note: omit the cuda suffix and installing from the appropriate -# wheel will result in using locally installed CUDA. -pytorch_gpu = - torch==2.1.0 - torchvision==0.16.0 - -# wandb -wandb = - wandb==0.16.5 - -############################################################################### -# Linting Configurations # -############################################################################### - -# yapf configuration -[yapf] -based_on_style = yapf -each_dict_entry_on_separate_line = false -split_all_top_level_comma_separated_values = true - - -# isort configuration -[isort] -profile=google - - -# pylint configuration -[pylint.MASTER] -persistent=no # Pickle collected data for later comparisons. -#cache-size=500 # Set the cache size for astng objects. -# Ignore Py3 files -ignore=get_references_web.py,get_references_web_single_group.py -[pylint.REPORTS] -# Set the output format. -# output-format=sorted-text -# Put messages in a separate file for each module / package specified on the -# command line instead of printing them on stdout. Reports (if any) will be -# written in a file name "pylint_global.[txt|html]". -#files-output=no -# Tells whether to display a full report or only the messages. -reports=no -# Disable the report(s) with the given id(s). -#disable-report=R0001,R0002,R0003,R0004,R0101,R0102,R0201,R0202,R0220,R0401,R0402,R0701,R0801,R0901,R0902,R0903,R0904,R0911,R0912,R0913,R0914,R0915,R0921,R0922,R0923 -# Error message template (continued on second line) -msg-template={msg_id}:{line:3} {obj}: {msg} [{symbol}] -[pylint.'MESSAGES CONTROL'] -# List of checkers and warnings to enable. -enable=indexing-exception,old-raise-syntax - - -[pylint.BASIC] -# Required attributes for module, separated by a comma -#required-attributes= -# Regular expression which should only match the name -# of functions or classes which do not require a docstring. -no-docstring-rgx=(__.*__|main) -# Min length in lines of a function that requires a docstring. -docstring-min-length=10 -# Regular expression which should only match correct module names. The -# leading underscore is sanctioned for private modules by Google's style -# guide. -# -# There are exceptions to the basic rule (_?[a-z][a-z0-9_]*) to cover -# requirements of Python's module system. -module-rgx=^(_?[a-z][a-z0-9_]*)|__init__$ -# Regular expression which should only match correct module level names -const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ -# Regular expression which should only match correct class attribute -class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ -# Regular expression which should only match correct class names -class-rgx=^_?[A-Z][a-zA-Z0-9]*$ -# Regular expression which should only match correct function names. -# 'camel_case' and 'snake_case' group names are used for consistency of naming -# styles across functions and methods. -function-rgx=^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ -# Regular expression which should only match correct method names. -# 'camel_case' and 'snake_case' group names are used for consistency of naming -# styles across functions and methods. 'exempt' indicates a name which is -# consistent with all naming styles. -method-rgx=(?x) - ^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase - |tearDownTestCase|setupSelf|tearDownClass|_testDatasetSize|setUpClass - |(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next) - |(?P_{0,2}[A-Z][a-zA-Z0-9_]*) - |(?P_{0,2}[a-z][a-z0-9_]*))$ -# Regular expression which should only match correct instance attribute names -attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ -# Regular expression which should only match correct argument names -argument-rgx=^[a-z][a-z0-9_]*$ -# Regular expression which should only match correct variable names -variable-rgx=^[a-z][a-z0-9_]*$ -# Regular expression which should only match correct list comprehension / -# generator expression variable names -inlinevar-rgx=^[a-z][a-z0-9_]*$ -# Good variable names which should always be accepted, separated by a comma -good-names=main,_ -# Bad variable names which should always be refused, separated by a comma -bad-names= -# List of builtins function names that should not be used, separated by a comma -#bad-functions=input,apply,reduce -# List of decorators that define properties, such as abc.abstractproperty. -property-classes=abc.abstractproperty -[pylint.TYPECHECK] -# Tells whether missing members accessed in mixin class should be ignored. A -# mixin class is detected if its name ends with "mixin" (case insensitive). -ignore-mixin-members=yes -# List of decorators that create context managers from functions, such as -# contextlib.contextmanager. -contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager -[pylint.VARIABLES] -# Tells whether we should check for unused import in __init__ files. -init-import=no -# A regular expression matching names used for dummy variables (i.e. not used). -dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) -# List of additional names supposed to be defined in builtins. Remember that -# you should avoid to define new builtins when possible. -additional-builtins= -[pylint.CLASSES] -# List of method names used to declare (i.e. assign) instance attributes. -defining-attr-methods=__init__,__new__,setUp -# "class_" is also a valid for the first argument to a class method. -valid-classmethod-first-arg=cls,class_ -[pylint.EXCEPTIONS] -overgeneral-exceptions=builtins.StandardError,builtins.Exception,builtins.BaseException -[pylint.IMPORTS] -# Deprecated modules which should not be used, separated by a comma -deprecated-modules=regsub,TERMIOS,Bastion,rexec,sets -[pylint.FORMAT] -# List of checkers and warnings to disable. -disable=abstract-method,access-member-before-definition,arguments-differ,assignment-from-no-return,attribute-defined-outside-init,bad-mcs-classmethod-argument,bad-option-value,c-extension-no-member,consider-merging-isinstance,consider-using-dict-comprehension,consider-using-enumerate,consider-using-in,consider-using-set-comprehension,consider-using-ternary,deprecated-method,design,file-ignored,fixme,global-statement,import-error,inconsistent-return-statements,invalid-unary-operand-type,len-as-condition,locally-disabled,locally-enabled,misplaced-comparison-constant,missing-docstring,multiple-imports,no-else-return,no-member,no-name-in-module,no-self-use,no-value-for-parameter,not-an-iterable,not-context-manager,pointless-except,protected-access,redefined-argument-from-local,signature-differs,similarities,simplifiable-if-expression,star-args,super-init-not-called,suppressed-message,too-many-function-args,trailing-comma-tuple,trailing-newlines,ungrouped-imports,unnecessary-pass,unsubscriptable-object,unused-argument,useless-object-inheritance,useless-return,useless-suppression,wrong-import-order,wrong-import-position,unneeded-not,unexpected-keyword-arg,redundant-keyword-arg,unspecified-encoding,logging-fstring-interpolation,consider-using-f-string,use-dict-literal - -# Maximum number of characters on a single line. -max-line-length=80 -# Regexp for a line that is allowed to be longer than the limit. -# This "ignore" regex is today composed of several independent parts: -# (1) Long import lines -# (2) URLs in comments or pydocs. Detecting URLs by regex is a hard problem and -# no amount of tweaking will make a perfect regex AFAICT. This one is a good -# compromise. -# (3) Constant string literals at the start of files don't need to be broken -# across lines. Allowing long paths and urls to be on a single -# line. Also requires that the string not be a triplequoted string. -ignore-long-lines=(?x) - (^\s*(import|from)\s - |^\s*(\#\ )??$ - |^[a-zA-Z_][a-zA-Z0-9_]*\s*=\s*("[^"]\S+"|'[^']\S+') - ) -# Maximum number of lines in a module -max-module-lines=99999 -# String used as indentation unit. We differ from PEP8's normal 4 spaces. -indent-string=' ' -# Do not warn about multiple statements on a single line for constructs like -# if test: stmt -single-line-if-stmt=y -[pylint.LOGGING] -# Add logging modules. -logging-modules=logging,absl.logging -[pylint.MISCELLANEOUS] -# Maximum line length for lambdas -#short-func-length=1 -# List of module members that should be marked as deprecated. -# All of the string functions are listed in 4.1.4 Deprecated string functions -# in the Python 2.4 docs. -#deprecated-members=string.atof,string.atoi,string.atol,string.capitalize,string.expandtabs,string.find,string.rfind,string.index,string.rindex,string.count,string.lower,string.split,string.rsplit,string.splitfields,string.join,string.joinfields,string.lstrip,string.rstrip,string.strip,string.swapcase,string.translate,string.upper,string.ljust,string.rjust,string.center,string.zfill,string.replace,sys.exitfunc,sys.maxint -# List of exceptions that do not need to be mentioned in the Raises section of -# a docstring. -#ignore-exceptions=AssertionError,NotImplementedError,StopIteration,TypeError -# Number of spaces of indent required when the last token on the preceding line -# is an open (, [, or {. -indent-after-paren=4 diff --git a/setup.py b/setup.py deleted file mode 100644 index a4ead8f48..000000000 --- a/setup.py +++ /dev/null @@ -1,4 +0,0 @@ -from setuptools import setup - -if __name__ == "__main__": - setup() From c8dc704ec42a77cf6ac84c83fbcada13c9b41a1c Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Wed, 15 Jan 2025 13:21:11 +0100 Subject: [PATCH 059/105] yapf requires toml --- .github/workflows/linting.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml index 89b5ef288..e49686358 100644 --- a/.github/workflows/linting.yml +++ b/.github/workflows/linting.yml @@ -50,7 +50,7 @@ jobs: - name: Install yapf run: | python -m pip install --upgrade pip - pip install yapf==0.32 + pip install yapf==0.32 toml - name: Run yapf run: | yapf . --diff --recursive From 50658bc1e23cb5f4076f6565c3629d0dc1a8e1fa Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Wed, 15 Jan 2025 13:55:34 +0100 Subject: [PATCH 060/105] Revert to auto-finding packages (includes `tests/`) --- pyproject.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 10fd4b730..3ff79bece 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,11 +54,13 @@ requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2"] build-backend = "setuptools.build_meta" [tool.setuptools] -packages = ["algorithmic_efficiency"] py-modules = ["submission_runner"] include-package-data = true zip-safe = false +[tool.setuptools.packages] +find = {} # Scanning implicit namespaces is active by default + [tool.setuptools.dynamic] version = { attr = "algorithmic_efficiency.__version__" } From 616a0f499967b95db1ff2f47b36effcee975fc5c Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Wed, 15 Jan 2025 14:36:33 +0100 Subject: [PATCH 061/105] Match version to GH --- CONTRIBUTING.md | 6 ++++++ algorithmic_efficiency/__init__.py | 2 +- pyproject.toml | 2 +- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 364bbee62..bc5d004e9 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -22,6 +22,7 @@ - [Style Testing](#style-testing) - [Unit and Integration Tests](#unit-and-integration-tests) - [Regression Tests](#regression-tests) + - [Versioning](#versioning) ## Contributing to MLCommons @@ -276,3 +277,8 @@ To run a regression test: 2. Turn on the self-hosted runner. 3. Run the self-hosted runner application for the runner to accept jobs. 4. Open a pull request into mian to trigger the workflow. + +### Versioning + +The package version is centrally defined in `algorithmic_efficiency/__init__.py`. +When releasing a new version, update the version number in `algorithmic_efficiency/__init__.py` and create a new release in the GitHub UI. diff --git a/algorithmic_efficiency/__init__.py b/algorithmic_efficiency/__init__.py index a0e473e1d..05485dcaa 100644 --- a/algorithmic_efficiency/__init__.py +++ b/algorithmic_efficiency/__init__.py @@ -1,3 +1,3 @@ """Algorithmic Efficiency.""" -__version__ = '0.1.0' +__version__ = "0.1.5" diff --git a/pyproject.toml b/pyproject.toml index 3ff79bece..eb3271ee3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,7 @@ dependencies = [ ] [build-system] -requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2"] +requires = ["setuptools>=45"] build-backend = "setuptools.build_meta" [tool.setuptools] From bad76f55311715144476542454d1fcde60f509e5 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Wed, 15 Jan 2025 15:13:08 +0100 Subject: [PATCH 062/105] Let `setuptools_scm` handle versioning. --- .gitignore | 4 +++- CONTRIBUTING.md | 11 +++++++++-- algorithmic_efficiency/__init__.py | 4 +++- pyproject.toml | 6 +++--- 4 files changed, 18 insertions(+), 7 deletions(-) diff --git a/.gitignore b/.gitignore index d2e212366..85063bcf4 100644 --- a/.gitignore +++ b/.gitignore @@ -23,4 +23,6 @@ wandb/ scoring/plots/ !scoring/test_data/experiment_dir/study_0/mnist_jax/trial_0/eval_measurements.csv -!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_1/eval_measurements.csv \ No newline at end of file +!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_1/eval_measurements.csv + +algorithmic_efficiency/_version.py \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index bc5d004e9..a93289852 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -280,5 +280,12 @@ To run a regression test: ### Versioning -The package version is centrally defined in `algorithmic_efficiency/__init__.py`. -When releasing a new version, update the version number in `algorithmic_efficiency/__init__.py` and create a new release in the GitHub UI. +The package version is automatically determined by the `setuptools_scm` package based on the last git tag. +It follows the structure `major.minor.patch` + `devN` where `N` is the number of commits since the last tag. +It automatically increments the patch version (i.e. it guesses the next version) if there are commits after the last tag. +Additionally, if there are uncommitted changes, the version will include a suffix separated by a `+` character and includes the last commit hash plus the date on dirt workdir (see [setuptools_scm's documentation](https://setuptools-scm.readthedocs.io/en/latest/extending/#setuptools_scmlocal_scheme) with the default version and local scheme). +You can check what version `setuptools_scm` is creating by running `python -m setuptools_scm`. + +To create a new version, create a new release (and tag) in the GitHub UI. +The package version is automatically updated to the new version. +Once the package is installed, the version can be accessed as the package attribute `algorithmic_efficiency.__version__`, i.e. via `python -c "import algorithmic_efficiency; print(algorithmic_efficiency.__version__)"`. diff --git a/algorithmic_efficiency/__init__.py b/algorithmic_efficiency/__init__.py index 05485dcaa..7d54f8290 100644 --- a/algorithmic_efficiency/__init__.py +++ b/algorithmic_efficiency/__init__.py @@ -1,3 +1,5 @@ """Algorithmic Efficiency.""" -__version__ = "0.1.5" +from ._version import version as __version__ + +__all__ = ["__version__"] diff --git a/pyproject.toml b/pyproject.toml index eb3271ee3..2c6d28458 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,7 @@ dependencies = [ ] [build-system] -requires = ["setuptools>=45"] +requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2"] build-backend = "setuptools.build_meta" [tool.setuptools] @@ -61,8 +61,8 @@ zip-safe = false [tool.setuptools.packages] find = {} # Scanning implicit namespaces is active by default -[tool.setuptools.dynamic] -version = { attr = "algorithmic_efficiency.__version__" } +[tool.setuptools_scm] +version_file = "algorithmic_efficiency/_version.py" ############################################################################### # (Optional) Dependencies # From ff4a457ea6eea6e2603887066e3e2735a8867d2a Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Wed, 15 Jan 2025 15:23:37 +0100 Subject: [PATCH 063/105] Fix version test --- tests/version_test.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/version_test.py b/tests/version_test.py index 9f7006aab..7e3001324 100644 --- a/tests/version_test.py +++ b/tests/version_test.py @@ -4,10 +4,13 @@ def test_version_attribute(): - """Check whether __version__ exists and is a valid string.""" + """Check whether __version__ exists and is a valid string.""" - assert hasattr(algorithmic_efficiency, "__version__") - version = algorithmic_efficiency.__version__ - assert isinstance(version, str) - version_elements = version.split(".") - assert all(el.isnumeric() for el in version_elements) + assert hasattr(algorithmic_efficiency, "__version__") + version = algorithmic_efficiency.__version__ + assert isinstance(version, str) + version_elements = version.split(".") + print(version_elements) + # Only check the first three elements, i.e. major, minor, patch. + # The remaining elements contain commit hash and dirty status. + assert all(el.isnumeric() for el in version_elements[0:3]) From f97c880bb8a425430666257f3cdac2ef5a6a8187 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Wed, 15 Jan 2025 15:25:06 +0100 Subject: [PATCH 064/105] Match file name of version test to the other tests --- .github/workflows/CI.yml | 2 +- tests/{version_test.py => test_version.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename tests/{version_test.py => test_version.py} (100%) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 05d94e896..e05b74eef 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -199,7 +199,7 @@ jobs: pip install .[pytorch_cpu] - name: Run pytest tests run: | - pytest -vx tests/version_test.py + pytest -vx tests/test_version.py pytest -vx tests/test_num_params.py pytest -vx tests/test_param_shapes.py pytest -vx tests/test_param_types.py diff --git a/tests/version_test.py b/tests/test_version.py similarity index 100% rename from tests/version_test.py rename to tests/test_version.py From f98b55480041a31d5dc07f0af26937eee8750a49 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Wed, 15 Jan 2025 15:38:17 +0100 Subject: [PATCH 065/105] Fix linting --- pyproject.toml | 4 +++- tests/test_version.py | 18 +++++++++--------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2c6d28458..0788d48a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -129,6 +129,8 @@ wandb = ["wandb==0.16.5"] based_on_style = "yapf" each_dict_entry_on_separate_line = false split_all_top_level_comma_separated_values = true +[tool.yapfignore] +ignore_patterns = ["algorithmic_efficiency/_version.py"] # isort configuration [tool.isort] @@ -137,7 +139,7 @@ profile = "google" # pylint configuration [tool.pylint.MASTER] persistent = false -ignore = "get_references_web.py,get_references_web_single_group.py" +ignore = "get_references_web.py,get_references_web_single_group.py,_version.py" [tool.pylint.REPORTS] reports = false diff --git a/tests/test_version.py b/tests/test_version.py index 7e3001324..37aa26ea9 100644 --- a/tests/test_version.py +++ b/tests/test_version.py @@ -4,13 +4,13 @@ def test_version_attribute(): - """Check whether __version__ exists and is a valid string.""" + """Check whether __version__ exists and is a valid string.""" - assert hasattr(algorithmic_efficiency, "__version__") - version = algorithmic_efficiency.__version__ - assert isinstance(version, str) - version_elements = version.split(".") - print(version_elements) - # Only check the first three elements, i.e. major, minor, patch. - # The remaining elements contain commit hash and dirty status. - assert all(el.isnumeric() for el in version_elements[0:3]) + assert hasattr(algorithmic_efficiency, "__version__") + version = algorithmic_efficiency.__version__ + assert isinstance(version, str) + version_elements = version.split(".") + print(version_elements) + # Only check the first three elements, i.e. major, minor, patch. + # The remaining elements contain commit hash and dirty status. + assert all(el.isnumeric() for el in version_elements[0:3]) From 8171a32e12047c448003f75041b0886a6f09365c Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Wed, 15 Jan 2025 16:00:28 +0100 Subject: [PATCH 066/105] Update version test to only check major and minor elements, excluding patch version. --- tests/test_version.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_version.py b/tests/test_version.py index 37aa26ea9..ef01d4f32 100644 --- a/tests/test_version.py +++ b/tests/test_version.py @@ -11,6 +11,7 @@ def test_version_attribute(): assert isinstance(version, str) version_elements = version.split(".") print(version_elements) - # Only check the first three elements, i.e. major, minor, patch. + # Only check the first two elements, i.e. major, minor + # (patch is not checked as it is not required). # The remaining elements contain commit hash and dirty status. - assert all(el.isnumeric() for el in version_elements[0:3]) + assert all(el.isnumeric() for el in version_elements[0:2]) From 96cc471df4b85d22595907a0b18d268a10186141 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Wed, 15 Jan 2025 16:04:03 +0100 Subject: [PATCH 067/105] Rename job in regression tests workflow from `criteo_resnet_pytorch` to `criteo_embed_init_pytorch` due to likely typo. --- .github/workflows/regression_tests_variants.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/regression_tests_variants.yml b/.github/workflows/regression_tests_variants.yml index ef1585d0d..b234575b7 100644 --- a/.github/workflows/regression_tests_variants.yml +++ b/.github/workflows/regression_tests_variants.yml @@ -72,7 +72,7 @@ jobs: run: | docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w criteo1tb_resnet -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - criteo_resnet_pytorch: + criteo_embed_init_pytorch: runs-on: self-hosted needs: build_and_push_pytorch_docker_image steps: From 230bf8471e9d660d60ca7c191800a5769483de72 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Wed, 15 Jan 2025 16:05:34 +0100 Subject: [PATCH 068/105] Fix some markdown linting issues. --- DOCUMENTATION.md | 3 +-- GETTING_STARTED.md | 15 +++++++++------ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index 05d6515bd..d04b247f3 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -222,7 +222,6 @@ def update_params( - Cannot replace the model parameters with pre-trained ones. - Batch norm should work here because the `model_fn` will return updated batch norm moving averages when it is told to with `update_batch_norm`. - ###### Prepare for evaluation function ```python @@ -278,7 +277,7 @@ def data_selection( In general, with noisy, non-deterministic training, evaluation frequency can affect training time measurements as more "bites of the apple" potentially allows the training code to exploit instability. We also want to discourage submissions from complicated and unrealistic logic that attempts to guess when training is close to complete and increases the evaluation rate, while not producing a well-sampled training curve at the start of training. Simply allowing submissions complete freedom over evaluation frequency encourages competitors to work to minimize the number of evaluations, which distracts from the primary goal of finding better training algorithms. -Submissions are eligible for an untimed eval every `eval_period` seconds. Before proceeding to evaluation, the submission can prepare the model through a call to `prepare_for_eval`, effectively modifying the model parameters and state as well as the the optimizer state. Any additional evaluations performed by the submission code count against the runtime for scoring. +Submissions are eligible for an untimed eval every `eval_period` seconds. Before proceeding to evaluation, the submission can prepare the model through a call to `prepare_for_eval`, effectively modifying the model parameters and state as well as the the optimizer state. Any additional evaluations performed by the submission code count against the runtime for scoring. The harness that runs the submission code will attempt to eval every `eval_period` seconds by checking between each submission step (call of `update_params`) whether it has been at least `eval_period` seconds since that last eval, if so, the submission is given the possibility to prepare for evaluation (through a timed call to `prepare_for_eval`). If the accumulated runtime does not exceed the maximum allowed runtime after the preparation step, the clock is paused, and the submission is evaluated. This means that if calls to `update_params` typically take a lot more than `eval_period` seconds, such submissions will not receive as many untimed evals as a submission that had an `update_params` function that took less time. However, for appropriate settings of `eval_period`, we expect this to be quite rare. Submissions are always free to restructure their `update_params` code to split work into two subsequent steps to regain the potential benefits of these untimed model evaluations. For each workload, the `eval_period` will be set such that the total evaluation time is roughly between 10% and 20% of the total training time for the target-setting runs. #### Valid submissions diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md index 006b972ec..7d53c35e2 100644 --- a/GETTING_STARTED.md +++ b/GETTING_STARTED.md @@ -18,6 +18,8 @@ - [Docker Tips](#docker-tips) - [Score your Submission](#score-your-submission) - [Running workloads](#running-workloads) +- [Package your Submission code](#package-your-submission-code) +- [Package Logs for Self-Reporting Submissions](#package-logs-for-self-reporting-submissions) ## Set Up and Installation @@ -80,7 +82,6 @@ To set up a virtual enviornment and install this repository pip3 install -e '.[full]' ``` -
Per workload installations @@ -414,22 +415,24 @@ submission_folder/ ``` Specifically we require that: + 1. There exist subdirectories in the the submission folder named after the ruleset: `external_tuning` or `self_tuning`. -2. The ruleset subdirectories contain directories named according to -some identifier of the algorithm. -3. Each algorithm subdirectory contains a `submission.py` module. Additional helper modules are allowed if prefer to you organize your code into multiple files. If there are additional python packages that have to be installed for the algorithm also include a `requirements.txt` with package names and versions in the algorithm subdirectory. +2. The ruleset subdirectories contain directories named according to +some identifier of the algorithm. +3. Each algorithm subdirectory contains a `submission.py` module. Additional helper modules are allowed if prefer to you organize your code into multiple files. If there are additional python packages that have to be installed for the algorithm also include a `requirements.txt` with package names and versions in the algorithm subdirectory. 4. For `external_tuning` algorithms the algorithm subdirectory should contain a `tuning_search_space.json`. To check that your submission folder meets the above requirements you can run the `submissions/repo_checker.py` script. ## Package Logs for Self-Reporting Submissions + To prepare your submission for self reporting run: -``` +```bash python3 package_logs.py --experiment_dir --destination_dir ``` -The destination directiory will contain the logs packed in studies and trials required for self-reporting. +The destination directiory will contain the logs packed in studies and trials required for self-reporting. **Good Luck!** From ce44582c5e4e81e71a41798076315534259c97a0 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Wed, 15 Jan 2025 16:14:51 +0100 Subject: [PATCH 069/105] Add trailing new line --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 85063bcf4..403b08c2b 100644 --- a/.gitignore +++ b/.gitignore @@ -25,4 +25,4 @@ scoring/plots/ !scoring/test_data/experiment_dir/study_0/mnist_jax/trial_0/eval_measurements.csv !scoring/test_data/experiment_dir/study_0/mnist_jax/trial_1/eval_measurements.csv -algorithmic_efficiency/_version.py \ No newline at end of file +algorithmic_efficiency/_version.py From 37f556d7e5a6e22c87b59ddde663c3ca5b263280 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Wed, 15 Jan 2025 16:34:41 +0100 Subject: [PATCH 070/105] Rename package from `algorithmic-efficiency` to `algoperf`. --- .github/workflows/linting.yml | 2 +- .gitignore | 6 +-- CHANGELOG.md | 20 +++++---- CONTRIBUTING.md | 8 ++-- DOCUMENTATION.md | 2 +- GETTING_STARTED.md | 2 +- .../__init__.py | 0 .../checkpoint_utils.py | 4 +- .../data_utils.py | 2 +- .../halton.py | 0 .../init_utils.py | 0 .../interop_utils.py | 2 +- .../logger_utils.py | 4 +- .../param_utils.py | 2 +- .../profiler.py | 0 .../pytorch_utils.py | 8 ++-- .../random_utils.py | 0 {algorithmic_efficiency => algoperf}/spec.py | 0 .../workloads/__init__.py | 0 .../workloads/cifar/__init__.py | 0 .../workloads/cifar/cifar_jax/__init__.py | 0 .../cifar/cifar_jax/input_pipeline.py | 4 +- .../workloads/cifar/cifar_jax/models.py | 4 +- .../workloads/cifar/cifar_jax/workload.py | 10 ++--- .../workloads/cifar/cifar_pytorch/__init__.py | 0 .../workloads/cifar/cifar_pytorch/models.py | 10 ++--- .../workloads/cifar/cifar_pytorch/workload.py | 12 +++--- .../workloads/cifar/workload.py | 6 +-- .../workloads/criteo1tb/__init__.py | 0 .../criteo1tb/criteo1tb_jax/__init__.py | 0 .../criteo1tb/criteo1tb_jax/models.py | 0 .../criteo1tb/criteo1tb_jax/workload.py | 8 ++-- .../criteo1tb/criteo1tb_pytorch/__init__.py | 0 .../criteo1tb/criteo1tb_pytorch/models.py | 0 .../criteo1tb/criteo1tb_pytorch/workload.py | 10 ++--- .../workloads/criteo1tb/input_pipeline.py | 2 +- .../workloads/criteo1tb/workload.py | 4 +- .../workloads/fastmri/__init__.py | 0 .../workloads/fastmri/fastmri_jax/__init__.py | 0 .../workloads/fastmri/fastmri_jax/models.py | 0 .../workloads/fastmri/fastmri_jax/ssim.py | 0 .../workloads/fastmri/fastmri_jax/workload.py | 12 +++--- .../fastmri/fastmri_pytorch/__init__.py | 0 .../fastmri/fastmri_pytorch/models.py | 2 +- .../workloads/fastmri/fastmri_pytorch/ssim.py | 2 +- .../fastmri/fastmri_pytorch/workload.py | 14 +++---- .../workloads/fastmri/input_pipeline.py | 2 +- .../workloads/fastmri/workload.py | 4 +- .../workloads/imagenet_resnet/__init__.py | 0 .../imagenet_resnet/imagenet_jax/__init__.py | 0 .../imagenet_jax/input_pipeline.py | 6 +-- .../imagenet_resnet/imagenet_jax/models.py | 2 +- .../imagenet_jax/randaugment.py | 0 .../imagenet_resnet/imagenet_jax/workload.py | 14 +++---- .../imagenet_pytorch/__init__.py | 0 .../imagenet_pytorch/models.py | 4 +- .../imagenet_pytorch/randaugment.py | 2 +- .../imagenet_pytorch/workload.py | 18 ++++---- .../workloads/imagenet_resnet/imagenet_v2.py | 6 +-- .../workloads/imagenet_resnet/workload.py | 2 +- .../workloads/imagenet_vit/__init__.py | 0 .../imagenet_vit/imagenet_jax/__init__.py | 0 .../imagenet_vit/imagenet_jax/models.py | 2 +- .../imagenet_vit/imagenet_jax/workload.py | 12 +++--- .../imagenet_vit/imagenet_pytorch/__init__.py | 0 .../imagenet_vit/imagenet_pytorch/models.py | 6 +-- .../imagenet_vit/imagenet_pytorch/workload.py | 14 +++---- .../workloads/imagenet_vit/workload.py | 4 +- .../librispeech_conformer/__init__.py | 0 .../librispeech_conformer/input_pipeline.py | 0 .../librispeech_jax/__init__.py | 0 .../librispeech_preprocessor.py | 0 .../librispeech_jax/models.py | 4 +- .../librispeech_jax/spectrum_augmenter.py | 0 .../librispeech_jax/workload.py | 14 +++---- .../librispeech_pytorch/__init__.py | 0 .../librispeech_pytorch/models.py | 4 +- .../librispeech_pytorch/preprocessor.py | 0 .../librispeech_pytorch/spectrum_augmenter.py | 0 .../librispeech_pytorch/workload.py | 18 ++++---- .../librispeech_conformer/metrics.py | 0 .../librispeech_conformer/workload.py | 2 +- .../librispeech_deepspeech/__init__.py | 0 .../librispeech_jax/__init__.py | 0 .../librispeech_jax/models.py | 4 +- .../librispeech_jax/workload.py | 8 ++-- .../librispeech_pytorch/__init__.py | 0 .../librispeech_pytorch/models.py | 4 +- .../librispeech_pytorch/workload.py | 14 +++---- .../workloads/mnist/__init__.py | 0 .../workloads/mnist/mnist_jax/__init__.py | 0 .../workloads/mnist/mnist_jax/workload.py | 6 +-- .../workloads/mnist/mnist_pytorch/__init__.py | 0 .../workloads/mnist/mnist_pytorch/workload.py | 10 ++--- .../workloads/mnist/workload.py | 8 ++-- .../workloads/ogbg/__init__.py | 0 .../workloads/ogbg/input_pipeline.py | 0 .../workloads/ogbg/metrics.py | 2 +- .../workloads/ogbg/ogbg_jax/__init__.py | 0 .../workloads/ogbg/ogbg_jax/models.py | 0 .../workloads/ogbg/ogbg_jax/workload.py | 10 ++--- .../workloads/ogbg/ogbg_pytorch/__init__.py | 0 .../workloads/ogbg/ogbg_pytorch/models.py | 2 +- .../workloads/ogbg/ogbg_pytorch/workload.py | 12 +++--- .../workloads/ogbg/workload.py | 8 ++-- .../workloads/utils.py | 0 .../workloads/wmt/__init__.py | 0 .../workloads/wmt/bleu.py | 2 +- .../workloads/wmt/input_pipeline.py | 6 +-- .../workloads/wmt/tokenizer.py | 0 .../workloads/wmt/wmt_jax/__init__.py | 0 .../workloads/wmt/wmt_jax/decode.py | 0 .../workloads/wmt/wmt_jax/models.py | 0 .../workloads/wmt/wmt_jax/workload.py | 12 +++--- .../workloads/wmt/wmt_pytorch/__init__.py | 0 .../workloads/wmt/wmt_pytorch/decode.py | 2 +- .../workloads/wmt/wmt_pytorch/models.py | 0 .../workloads/wmt/wmt_pytorch/workload.py | 14 +++---- .../workloads/wmt/workload.py | 6 +-- .../workloads/workloads.py | 4 +- datasets/dataset_setup.py | 4 +- .../external_tuning/jax_nadamw_full_budget.py | 2 +- .../jax_nadamw_target_setting.py | 2 +- .../pytorch_nadamw_full_budget.py | 4 +- .../pytorch_nadamw_target_setting.py | 4 +- .../self_tuning/jax_nadamw_full_budget.py | 2 +- .../self_tuning/jax_nadamw_target_setting.py | 2 +- .../self_tuning/pytorch_nadamw_full_budget.py | 4 +- .../pytorch_nadamw_target_setting.py | 4 +- pyproject.toml | 14 +++---- .../cifar/cifar_jax/submission.py | 2 +- .../cifar/cifar_pytorch/submission.py | 2 +- .../mnist/mnist_jax/submission.py | 2 +- .../mnist/mnist_pytorch/submission.py | 2 +- .../adafactor/jax/submission.py | 2 +- .../adafactor/pytorch/submission.py | 4 +- .../paper_baselines/adamw/jax/submission.py | 2 +- .../adamw/pytorch/submission.py | 4 +- .../paper_baselines/lamb/jax/submission.py | 2 +- .../lamb/pytorch/submission.py | 2 +- .../momentum/jax/submission.py | 2 +- .../momentum/pytorch/submission.py | 4 +- .../paper_baselines/nadamw/jax/submission.py | 2 +- .../nadamw/pytorch/submission.py | 4 +- .../nesterov/jax/submission.py | 2 +- .../nesterov/pytorch/submission.py | 4 +- .../paper_baselines/sam/jax/submission.py | 2 +- .../paper_baselines/sam/pytorch/submission.py | 4 +- .../paper_baselines/shampoo/jax/submission.py | 2 +- .../data_selection.py | 2 +- .../target_setting_algorithms/jax_adamw.py | 2 +- .../target_setting_algorithms/jax_momentum.py | 2 +- .../target_setting_algorithms/jax_nadamw.py | 2 +- .../target_setting_algorithms/jax_nesterov.py | 2 +- .../jax_submission_base.py | 2 +- .../pytorch_adamw.py | 2 +- .../pytorch_momentum.py | 2 +- .../pytorch_nadamw.py | 2 +- .../pytorch_nesterov.py | 2 +- .../pytorch_submission_base.py | 4 +- scoring/performance_profile.py | 6 +-- scoring/run_workloads.py | 4 +- scoring/scoring_utils.py | 4 +- submission_runner.py | 22 +++++----- submissions/template/submission.py | 2 +- tests/modeldiffs/criteo1tb/compare.py | 6 +-- .../criteo1tb_embed_init/compare.py | 6 +-- .../modeldiffs/criteo1tb_layernorm/compare.py | 6 +-- tests/modeldiffs/criteo1tb_resnet/compare.py | 6 +-- tests/modeldiffs/fastmri/compare.py | 6 +-- tests/modeldiffs/fastmri_layernorm/compare.py | 6 +-- .../modeldiffs/fastmri_model_size/compare.py | 6 +-- tests/modeldiffs/fastmri_tanh/compare.py | 6 +-- tests/modeldiffs/imagenet_resnet/compare.py | 6 +-- .../imagenet_resnet/gelu_compare.py | 6 +-- .../imagenet_resnet/silu_compare.py | 6 +-- tests/modeldiffs/imagenet_vit/compare.py | 6 +-- tests/modeldiffs/imagenet_vit_glu/compare.py | 6 +-- tests/modeldiffs/imagenet_vit_map/compare.py | 6 +-- .../modeldiffs/imagenet_vit_postln/compare.py | 6 +-- .../librispeech_conformer/compare.py | 6 +-- .../compare.py | 6 +-- .../librispeech_conformer_gelu/compare.py | 6 +-- .../compare.py | 6 +-- .../librispeech_deepspeech/compare.py | 6 +-- .../compare.py | 6 +-- .../librispeech_deepspeech_normaug/compare.py | 6 +-- .../librispeech_deepspeech_tanh/compare.py | 6 +-- tests/modeldiffs/ogbg/compare.py | 6 +-- tests/modeldiffs/ogbg_gelu/compare.py | 6 +-- tests/modeldiffs/ogbg_model_size/compare.py | 6 +-- tests/modeldiffs/ogbg_silu/compare.py | 6 +-- tests/modeldiffs/vanilla_sgd_jax.py | 2 +- tests/modeldiffs/vanilla_sgd_pytorch.py | 2 +- tests/modeldiffs/wmt/compare.py | 6 +-- .../modeldiffs/wmt_attention_temp/compare.py | 6 +-- tests/modeldiffs/wmt_glu_tanh/compare.py | 6 +-- tests/modeldiffs/wmt_post_ln/compare.py | 6 +-- tests/reference_algorithm_tests.py | 14 +++---- tests/submission_runner_test.py | 2 +- tests/test_baselines.py | 4 +- tests/test_num_params.py | 38 ++++++++--------- tests/test_param_shapes.py | 40 +++++++++--------- tests/test_param_types.py | 42 +++++++++---------- tests/test_ssim.py | 10 ++--- tests/test_version.py | 6 +-- tests/version_test.py | 16 +++++++ .../imagenet_jax/workload_test.py | 4 +- 208 files changed, 489 insertions(+), 467 deletions(-) rename {algorithmic_efficiency => algoperf}/__init__.py (100%) rename {algorithmic_efficiency => algoperf}/checkpoint_utils.py (98%) rename {algorithmic_efficiency => algoperf}/data_utils.py (99%) rename {algorithmic_efficiency => algoperf}/halton.py (100%) rename {algorithmic_efficiency => algoperf}/init_utils.py (100%) rename {algorithmic_efficiency => algoperf}/interop_utils.py (90%) rename {algorithmic_efficiency => algoperf}/logger_utils.py (99%) rename {algorithmic_efficiency => algoperf}/param_utils.py (99%) rename {algorithmic_efficiency => algoperf}/profiler.py (100%) rename {algorithmic_efficiency => algoperf}/pytorch_utils.py (89%) rename {algorithmic_efficiency => algoperf}/random_utils.py (100%) rename {algorithmic_efficiency => algoperf}/spec.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/__init__.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/cifar/__init__.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/cifar/cifar_jax/__init__.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/cifar/cifar_jax/input_pipeline.py (98%) rename {algorithmic_efficiency => algoperf}/workloads/cifar/cifar_jax/models.py (93%) rename {algorithmic_efficiency => algoperf}/workloads/cifar/cifar_jax/workload.py (95%) rename {algorithmic_efficiency => algoperf}/workloads/cifar/cifar_pytorch/__init__.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/cifar/cifar_pytorch/models.py (92%) rename {algorithmic_efficiency => algoperf}/workloads/cifar/cifar_pytorch/workload.py (96%) rename {algorithmic_efficiency => algoperf}/workloads/cifar/workload.py (97%) rename {algorithmic_efficiency => algoperf}/workloads/criteo1tb/__init__.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/criteo1tb/criteo1tb_jax/__init__.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/criteo1tb/criteo1tb_jax/models.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/criteo1tb/criteo1tb_jax/workload.py (96%) rename {algorithmic_efficiency => algoperf}/workloads/criteo1tb/criteo1tb_pytorch/__init__.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/criteo1tb/criteo1tb_pytorch/models.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/criteo1tb/criteo1tb_pytorch/workload.py (97%) rename {algorithmic_efficiency => algoperf}/workloads/criteo1tb/input_pipeline.py (98%) rename {algorithmic_efficiency => algoperf}/workloads/criteo1tb/workload.py (97%) rename {algorithmic_efficiency => algoperf}/workloads/fastmri/__init__.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/fastmri/fastmri_jax/__init__.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/fastmri/fastmri_jax/models.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/fastmri/fastmri_jax/ssim.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/fastmri/fastmri_jax/workload.py (95%) rename {algorithmic_efficiency => algoperf}/workloads/fastmri/fastmri_pytorch/__init__.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/fastmri/fastmri_pytorch/models.py (99%) rename {algorithmic_efficiency => algoperf}/workloads/fastmri/fastmri_pytorch/ssim.py (98%) rename {algorithmic_efficiency => algoperf}/workloads/fastmri/fastmri_pytorch/workload.py (96%) rename {algorithmic_efficiency => algoperf}/workloads/fastmri/input_pipeline.py (99%) rename {algorithmic_efficiency => algoperf}/workloads/fastmri/workload.py (96%) rename {algorithmic_efficiency => algoperf}/workloads/imagenet_resnet/__init__.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/imagenet_resnet/imagenet_jax/__init__.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py (98%) rename {algorithmic_efficiency => algoperf}/workloads/imagenet_resnet/imagenet_jax/models.py (98%) rename {algorithmic_efficiency => algoperf}/workloads/imagenet_resnet/imagenet_jax/randaugment.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/imagenet_resnet/imagenet_jax/workload.py (95%) rename {algorithmic_efficiency => algoperf}/workloads/imagenet_resnet/imagenet_pytorch/__init__.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/imagenet_resnet/imagenet_pytorch/models.py (98%) rename {algorithmic_efficiency => algoperf}/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py (99%) rename {algorithmic_efficiency => algoperf}/workloads/imagenet_resnet/imagenet_pytorch/workload.py (95%) rename {algorithmic_efficiency => algoperf}/workloads/imagenet_resnet/imagenet_v2.py (90%) rename {algorithmic_efficiency => algoperf}/workloads/imagenet_resnet/workload.py (99%) rename {algorithmic_efficiency => algoperf}/workloads/imagenet_vit/__init__.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/imagenet_vit/imagenet_jax/__init__.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/imagenet_vit/imagenet_jax/models.py (99%) rename {algorithmic_efficiency => algoperf}/workloads/imagenet_vit/imagenet_jax/workload.py (91%) rename {algorithmic_efficiency => algoperf}/workloads/imagenet_vit/imagenet_pytorch/__init__.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/imagenet_vit/imagenet_pytorch/models.py (98%) rename {algorithmic_efficiency => algoperf}/workloads/imagenet_vit/imagenet_pytorch/workload.py (87%) rename {algorithmic_efficiency => algoperf}/workloads/imagenet_vit/workload.py (96%) rename {algorithmic_efficiency => algoperf}/workloads/librispeech_conformer/__init__.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/librispeech_conformer/input_pipeline.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/librispeech_conformer/librispeech_jax/__init__.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/librispeech_conformer/librispeech_jax/librispeech_preprocessor.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/librispeech_conformer/librispeech_jax/models.py (99%) rename {algorithmic_efficiency => algoperf}/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/librispeech_conformer/librispeech_jax/workload.py (96%) rename {algorithmic_efficiency => algoperf}/workloads/librispeech_conformer/librispeech_pytorch/__init__.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/librispeech_conformer/librispeech_pytorch/models.py (98%) rename {algorithmic_efficiency => algoperf}/workloads/librispeech_conformer/librispeech_pytorch/preprocessor.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/librispeech_conformer/librispeech_pytorch/spectrum_augmenter.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/librispeech_conformer/librispeech_pytorch/workload.py (95%) rename {algorithmic_efficiency => algoperf}/workloads/librispeech_conformer/metrics.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/librispeech_conformer/workload.py (98%) rename {algorithmic_efficiency => algoperf}/workloads/librispeech_deepspeech/__init__.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/librispeech_deepspeech/librispeech_jax/__init__.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/librispeech_deepspeech/librispeech_jax/models.py (99%) rename {algorithmic_efficiency => algoperf}/workloads/librispeech_deepspeech/librispeech_jax/workload.py (95%) rename {algorithmic_efficiency => algoperf}/workloads/librispeech_deepspeech/librispeech_pytorch/__init__.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/librispeech_deepspeech/librispeech_pytorch/models.py (98%) rename {algorithmic_efficiency => algoperf}/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py (88%) rename {algorithmic_efficiency => algoperf}/workloads/mnist/__init__.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/mnist/mnist_jax/__init__.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/mnist/mnist_jax/workload.py (96%) rename {algorithmic_efficiency => algoperf}/workloads/mnist/mnist_pytorch/__init__.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/mnist/mnist_pytorch/workload.py (96%) rename {algorithmic_efficiency => algoperf}/workloads/mnist/workload.py (97%) rename {algorithmic_efficiency => algoperf}/workloads/ogbg/__init__.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/ogbg/input_pipeline.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/ogbg/metrics.py (98%) rename {algorithmic_efficiency => algoperf}/workloads/ogbg/ogbg_jax/__init__.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/ogbg/ogbg_jax/models.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/ogbg/ogbg_jax/workload.py (95%) rename {algorithmic_efficiency => algoperf}/workloads/ogbg/ogbg_pytorch/__init__.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/ogbg/ogbg_pytorch/models.py (99%) rename {algorithmic_efficiency => algoperf}/workloads/ogbg/ogbg_pytorch/workload.py (96%) rename {algorithmic_efficiency => algoperf}/workloads/ogbg/workload.py (96%) rename {algorithmic_efficiency => algoperf}/workloads/utils.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/wmt/__init__.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/wmt/bleu.py (98%) rename {algorithmic_efficiency => algoperf}/workloads/wmt/input_pipeline.py (98%) rename {algorithmic_efficiency => algoperf}/workloads/wmt/tokenizer.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/wmt/wmt_jax/__init__.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/wmt/wmt_jax/decode.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/wmt/wmt_jax/models.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/wmt/wmt_jax/workload.py (97%) rename {algorithmic_efficiency => algoperf}/workloads/wmt/wmt_pytorch/__init__.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/wmt/wmt_pytorch/decode.py (99%) rename {algorithmic_efficiency => algoperf}/workloads/wmt/wmt_pytorch/models.py (100%) rename {algorithmic_efficiency => algoperf}/workloads/wmt/wmt_pytorch/workload.py (97%) rename {algorithmic_efficiency => algoperf}/workloads/wmt/workload.py (97%) rename {algorithmic_efficiency => algoperf}/workloads/workloads.py (98%) create mode 100644 tests/version_test.py diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml index e49686358..0efa7b236 100644 --- a/.github/workflows/linting.yml +++ b/.github/workflows/linting.yml @@ -17,7 +17,7 @@ jobs: pip install pylint==2.16.1 - name: Run pylint run: | - pylint algorithmic_efficiency + pylint algoperf pylint reference_algorithms pylint prize_qualification_baselines pylint submission_runner.py diff --git a/.gitignore b/.gitignore index 403b08c2b..7d35f0ccc 100644 --- a/.gitignore +++ b/.gitignore @@ -12,8 +12,8 @@ makefile *.swp */data/ *events.out.tfevents* -algorithmic_efficiency/workloads/librispeech_conformer/data_dir -algorithmic_efficiency/workloads/librispeech_conformer/work_dir +algoperf/workloads/librispeech_conformer/data_dir +algoperf/workloads/librispeech_conformer/work_dir *.flac *.npy *.csv @@ -25,4 +25,4 @@ scoring/plots/ !scoring/test_data/experiment_dir/study_0/mnist_jax/trial_0/eval_measurements.csv !scoring/test_data/experiment_dir/study_0/mnist_jax/trial_1/eval_measurements.csv -algorithmic_efficiency/_version.py +algoperf/_version.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 95cd40775..685926506 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,34 +4,39 @@ - Finalized variant workload targets. - Fix in random_utils helper function. -- For conformer PyTorch Dropout layers set `inplace=True`. +- For conformer PyTorch Dropout layers set `inplace=True`. - Clear CUDA cache at begining of each trial for PyTorch. ## algoperf-benchmark-0.1.4 (2024-03-26) Upgrade CUDA version to CUDA 12.1: + - Upgrade CUDA version in Dockerfiles that will be used for scoring. - Update Jax and PyTorch package version tags to use local CUDA installation. -Add flag for completely disabling checkpointing. +Add flag for completely disabling checkpointing. + - Note that we will run with checkpointing off at scoring time. -Update Deepspeech and Conformer variant target setting configurations. -- Note that variant targets are not final. +Update Deepspeech and Conformer variant target setting configurations. + +- Note that variant targets are not final. Fixed bug in scoring code to take best trial in a study for external-tuning ruleset. -Added instructions for submission. +Added instructions for submission. -Changed default number of workers for PyTorch data loaders to 0. Running with >0 may lead to incorrect eval results see https://github.com/mlcommons/algorithmic-efficiency/issues/732. +Changed default number of workers for PyTorch data loaders to 0. Running with >0 may lead to incorrect eval results see . ## algoperf-benchmark-0.1.2 (2024-03-04) + Workload variant additions and fixes: + - Add Deepspeech workload variant - Fix bugs in Imagenet ResNet, WMT and Criteo1tb variants Add prize qualification logs for external tuning ruleset. -Note: FastMRI trials with dropout are not yet added due to https://github.com/mlcommons/algorithmic-efficiency/issues/664. +Note: FastMRI trials with dropout are not yet added due to . Add missing funcitonality to Docker startup script for self_tuning ruleset. Add self_tuning ruleset option to script that runs all workloads for scoring. @@ -41,6 +46,7 @@ Datasetup fixes. Fix tests that check training differences in PyTorch and JAX on GPU. ## algoperf-benchmark-0.1.1 (2024-01-19) + Bug fixes to FastMRI metric calculation and targets. Added workload variants and targets for ogbg, fastmri, librispeech_conformer, imagenet_resnet, imagenet_vit, criteo1tb to be used as held-out workloads. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a93289852..c98a5009e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -205,7 +205,7 @@ docker run -t -d \ -v $HOME/data/:/data/ \ -v $HOME/experiment_runs/:/experiment_runs \ -v $HOME/experiment_runs/logs:/logs \ --v $HOME/algorithmic-efficiency:/algorithmic-efficiency \ +-v $HOME/algorithmic-efficiency:/algoperf \ --gpus all \ --ipc=host \ \ @@ -229,7 +229,7 @@ To run the below commands, use the versions installed via `pip install -e '.[dev To automatically fix formatting errors, run the following (*WARNING:* this will edit your code, so it is suggested to make a git commit first!): ```bash -yapf -i -r -vv -p algorithmic_efficiency datasets prize_qualification_baselines reference_algorithms tests *.py +yapf -i -r -vv -p algoperf datasets prize_qualification_baselines reference_algorithms tests *.py ``` To sort all import orderings, run the following: @@ -247,7 +247,7 @@ isort . --check --diff To print out all offending pylint issues, run the following: ```bash -pylint algorithmic_efficiency +pylint algoperf pylint datasets pylint prize_qualification_baselines pylint reference_algorithms @@ -288,4 +288,4 @@ You can check what version `setuptools_scm` is creating by running `python -m se To create a new version, create a new release (and tag) in the GitHub UI. The package version is automatically updated to the new version. -Once the package is installed, the version can be accessed as the package attribute `algorithmic_efficiency.__version__`, i.e. via `python -c "import algorithmic_efficiency; print(algorithmic_efficiency.__version__)"`. +Once the package is installed, the version can be accessed as the package attribute `algoperf.__version__`, i.e. via `python -c "import algoperf; print(algoperf.__version__)"`. diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index d04b247f3..63439cb09 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -641,4 +641,4 @@ That said, while submitting Adam with some novel heuristic to set various hyperp The JAX and PyTorch versions of the Criteo, FastMRI, Librispeech, OGBG, and WMT workloads use the same TensorFlow input pipelines. Due to differences in how JAX and PyTorch distribute computations across devices, the PyTorch workloads have an additional overhead for these workloads. Since we use PyTorch's [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel) implementation, there is one Python process for each device. Depending on the hardware and the settings of the cluster, running a TensorFlow input pipeline in each Python process can lead to errors, since too many threads are created in each process. See [this PR thread](https://github.com/mlcommons/algorithmic-efficiency/pull/85) for more details. -While this issue might not affect all setups, we currently implement a different strategy: we only run the TensorFlow input pipeline in one Python process (with `rank == 0`), and [broadcast](https://pytorch.org/docs/stable/distributed.html#torch.distributed.broadcast) the batches to all other devices. This introduces an additional communication overhead for each batch. See the [implementation for the WMT workload](https://github.com/mlcommons/algorithmic-efficiency/blob/main/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py#L215-L288) as an example. +While this issue might not affect all setups, we currently implement a different strategy: we only run the TensorFlow input pipeline in one Python process (with `rank == 0`), and [broadcast](https://pytorch.org/docs/stable/distributed.html#torch.distributed.broadcast) the batches to all other devices. This introduces an additional communication overhead for each batch. See the [implementation for the WMT workload](https://github.com/mlcommons/algorithmic-efficiency/blob/main/algoperf/workloads/wmt/wmt_pytorch/workload.py#L215-L288) as an example. diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md index 7d53c35e2..a4b4460a6 100644 --- a/GETTING_STARTED.md +++ b/GETTING_STARTED.md @@ -58,7 +58,7 @@ To set up a virtual enviornment and install this repository cd algorithmic-efficiency ``` -3. Run the following pip3 install commands based on your chosen framework to install `algorithmic_efficiency` and its dependencies. +3. Run the following pip3 install commands based on your chosen framework to install `algoperf` and its dependencies. For **JAX**: diff --git a/algorithmic_efficiency/__init__.py b/algoperf/__init__.py similarity index 100% rename from algorithmic_efficiency/__init__.py rename to algoperf/__init__.py diff --git a/algorithmic_efficiency/checkpoint_utils.py b/algoperf/checkpoint_utils.py similarity index 98% rename from algorithmic_efficiency/checkpoint_utils.py rename to algoperf/checkpoint_utils.py index 29c1a821e..8d3fc5102 100644 --- a/algorithmic_efficiency/checkpoint_utils.py +++ b/algoperf/checkpoint_utils.py @@ -16,8 +16,8 @@ from tensorflow.io import gfile # pytype: disable=import-error import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup _, _, DEVICE, _ = pytorch_setup() CheckpointReturn = Tuple[spec.OptimizerState, diff --git a/algorithmic_efficiency/data_utils.py b/algoperf/data_utils.py similarity index 99% rename from algorithmic_efficiency/data_utils.py rename to algoperf/data_utils.py index 901f0b582..b09731fbe 100644 --- a/algorithmic_efficiency/data_utils.py +++ b/algoperf/data_utils.py @@ -11,7 +11,7 @@ from torch.utils.data import DistributedSampler from torch.utils.data import Sampler -from algorithmic_efficiency import spec +from algoperf import spec def shard_and_maybe_pad_np( diff --git a/algorithmic_efficiency/halton.py b/algoperf/halton.py similarity index 100% rename from algorithmic_efficiency/halton.py rename to algoperf/halton.py diff --git a/algorithmic_efficiency/init_utils.py b/algoperf/init_utils.py similarity index 100% rename from algorithmic_efficiency/init_utils.py rename to algoperf/init_utils.py diff --git a/algorithmic_efficiency/interop_utils.py b/algoperf/interop_utils.py similarity index 90% rename from algorithmic_efficiency/interop_utils.py rename to algoperf/interop_utils.py index e307042a9..0c6535d7a 100644 --- a/algorithmic_efficiency/interop_utils.py +++ b/algoperf/interop_utils.py @@ -1,7 +1,7 @@ import jax.dlpack import torch -from algorithmic_efficiency import spec +from algoperf import spec def jax_to_pytorch(x: spec.Tensor, take_ownership: bool = False) -> spec.Tensor: diff --git a/algorithmic_efficiency/logger_utils.py b/algoperf/logger_utils.py similarity index 99% rename from algorithmic_efficiency/logger_utils.py rename to algoperf/logger_utils.py index 609d996e6..37a8ab246 100644 --- a/algorithmic_efficiency/logger_utils.py +++ b/algoperf/logger_utils.py @@ -18,8 +18,8 @@ import psutil import torch.distributed as dist -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup USE_PYTORCH_DDP, RANK, DEVICE, _ = pytorch_setup() diff --git a/algorithmic_efficiency/param_utils.py b/algoperf/param_utils.py similarity index 99% rename from algorithmic_efficiency/param_utils.py rename to algoperf/param_utils.py index b430366b1..00fde1cce 100644 --- a/algorithmic_efficiency/param_utils.py +++ b/algoperf/param_utils.py @@ -6,7 +6,7 @@ import jax from torch import nn -from algorithmic_efficiency import spec +from algoperf import spec def pytorch_param_shapes(model: nn.Module) -> Dict[str, spec.ShapeTuple]: diff --git a/algorithmic_efficiency/profiler.py b/algoperf/profiler.py similarity index 100% rename from algorithmic_efficiency/profiler.py rename to algoperf/profiler.py diff --git a/algorithmic_efficiency/pytorch_utils.py b/algoperf/pytorch_utils.py similarity index 89% rename from algorithmic_efficiency/pytorch_utils.py rename to algoperf/pytorch_utils.py index 590f500fa..4a674985d 100644 --- a/algorithmic_efficiency/pytorch_utils.py +++ b/algoperf/pytorch_utils.py @@ -7,11 +7,11 @@ import torch import torch.distributed as dist -from algorithmic_efficiency import spec -from algorithmic_efficiency.profiler import Profiler -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.models import \ +from algoperf import spec +from algoperf.profiler import Profiler +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \ BatchNorm as ConformerBatchNorm -from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ BatchNorm as DeepspeechBatchNorm diff --git a/algorithmic_efficiency/random_utils.py b/algoperf/random_utils.py similarity index 100% rename from algorithmic_efficiency/random_utils.py rename to algoperf/random_utils.py diff --git a/algorithmic_efficiency/spec.py b/algoperf/spec.py similarity index 100% rename from algorithmic_efficiency/spec.py rename to algoperf/spec.py diff --git a/algorithmic_efficiency/workloads/__init__.py b/algoperf/workloads/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/__init__.py rename to algoperf/workloads/__init__.py diff --git a/algorithmic_efficiency/workloads/cifar/__init__.py b/algoperf/workloads/cifar/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/cifar/__init__.py rename to algoperf/workloads/cifar/__init__.py diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/__init__.py b/algoperf/workloads/cifar/cifar_jax/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/cifar/cifar_jax/__init__.py rename to algoperf/workloads/cifar/cifar_jax/__init__.py diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/input_pipeline.py b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py similarity index 98% rename from algorithmic_efficiency/workloads/cifar/cifar_jax/input_pipeline.py rename to algoperf/workloads/cifar/cifar_jax/input_pipeline.py index 3e6a68844..728d05f29 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/input_pipeline.py +++ b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py @@ -13,8 +13,8 @@ import tensorflow as tf import tensorflow_datasets as tfds -from algorithmic_efficiency import spec -from algorithmic_efficiency.data_utils import shard_and_maybe_pad_np +from algoperf import spec +from algoperf.data_utils import shard_and_maybe_pad_np def preprocess_for_train(image: spec.Tensor, diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py b/algoperf/workloads/cifar/cifar_jax/models.py similarity index 93% rename from algorithmic_efficiency/workloads/cifar/cifar_jax/models.py rename to algoperf/workloads/cifar/cifar_jax/models.py index 059352fb6..4d5df766e 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py +++ b/algoperf/workloads/cifar/cifar_jax/models.py @@ -10,8 +10,8 @@ from flax import linen as nn import jax.numpy as jnp -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.models import \ +from algoperf import spec +from algoperf.workloads.imagenet_resnet.imagenet_jax.models import \ ResNetBlock ModuleDef = nn.Module diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py similarity index 95% rename from algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py rename to algoperf/workloads/cifar/cifar_jax/workload.py index 8268c6ca3..f4bcffbc3 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -11,12 +11,12 @@ import optax import tensorflow_datasets as tfds -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.cifar.cifar_jax import models -from algorithmic_efficiency.workloads.cifar.cifar_jax.input_pipeline import \ +from algoperf import param_utils +from algoperf import spec +from algoperf.workloads.cifar.cifar_jax import models +from algoperf.workloads.cifar.cifar_jax.input_pipeline import \ create_input_iter -from algorithmic_efficiency.workloads.cifar.workload import BaseCifarWorkload +from algoperf.workloads.cifar.workload import BaseCifarWorkload class CifarWorkload(BaseCifarWorkload): diff --git a/algorithmic_efficiency/workloads/cifar/cifar_pytorch/__init__.py b/algoperf/workloads/cifar/cifar_pytorch/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/cifar/cifar_pytorch/__init__.py rename to algoperf/workloads/cifar/cifar_pytorch/__init__.py diff --git a/algorithmic_efficiency/workloads/cifar/cifar_pytorch/models.py b/algoperf/workloads/cifar/cifar_pytorch/models.py similarity index 92% rename from algorithmic_efficiency/workloads/cifar/cifar_pytorch/models.py rename to algoperf/workloads/cifar/cifar_pytorch/models.py index b592e10ab..393d568b9 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_pytorch/models.py +++ b/algoperf/workloads/cifar/cifar_pytorch/models.py @@ -10,13 +10,13 @@ import torch from torch import nn -from algorithmic_efficiency import spec -from algorithmic_efficiency.init_utils import pytorch_default_init -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.models import \ +from algoperf import spec +from algoperf.init_utils import pytorch_default_init +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \ BasicBlock -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.models import \ +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \ Bottleneck -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.models import \ +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \ conv1x1 diff --git a/algorithmic_efficiency/workloads/cifar/cifar_pytorch/workload.py b/algoperf/workloads/cifar/cifar_pytorch/workload.py similarity index 96% rename from algorithmic_efficiency/workloads/cifar/cifar_pytorch/workload.py rename to algoperf/workloads/cifar/cifar_pytorch/workload.py index 7abcf4d6c..2ba92f0b9 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_pytorch/workload.py +++ b/algoperf/workloads/cifar/cifar_pytorch/workload.py @@ -12,13 +12,13 @@ from torchvision import transforms from torchvision.datasets import CIFAR10 -from algorithmic_efficiency import data_utils -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import pytorch_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.cifar.cifar_pytorch.models import \ +from algoperf import data_utils +from algoperf import param_utils +from algoperf import pytorch_utils +from algoperf import spec +from algoperf.workloads.cifar.cifar_pytorch.models import \ resnet18 -from algorithmic_efficiency.workloads.cifar.workload import BaseCifarWorkload +from algoperf.workloads.cifar.workload import BaseCifarWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() diff --git a/algorithmic_efficiency/workloads/cifar/workload.py b/algoperf/workloads/cifar/workload.py similarity index 97% rename from algorithmic_efficiency/workloads/cifar/workload.py rename to algoperf/workloads/cifar/workload.py index 9e36cb291..c0d565108 100644 --- a/algorithmic_efficiency/workloads/cifar/workload.py +++ b/algoperf/workloads/cifar/workload.py @@ -7,9 +7,9 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup -import algorithmic_efficiency.random_utils as prng +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup +import algoperf.random_utils as prng USE_PYTORCH_DDP, _, _, _ = pytorch_setup() diff --git a/algorithmic_efficiency/workloads/criteo1tb/__init__.py b/algoperf/workloads/criteo1tb/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/criteo1tb/__init__.py rename to algoperf/workloads/criteo1tb/__init__.py diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/__init__.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/__init__.py rename to algoperf/workloads/criteo1tb/criteo1tb_jax/__init__.py diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py similarity index 100% rename from algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py rename to algoperf/workloads/criteo1tb/criteo1tb_jax/models.py diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py similarity index 96% rename from algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py rename to algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py index 3743dc1ff..91761e458 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -8,10 +8,10 @@ import jax.numpy as jnp import numpy as np -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax import models -from algorithmic_efficiency.workloads.criteo1tb.workload import \ +from algoperf import param_utils +from algoperf import spec +from algoperf.workloads.criteo1tb.criteo1tb_jax import models +from algoperf.workloads.criteo1tb.workload import \ BaseCriteo1TbDlrmSmallWorkload diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/__init__.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/__init__.py rename to algoperf/workloads/criteo1tb/criteo1tb_pytorch/__init__.py diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py similarity index 100% rename from algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py rename to algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py similarity index 97% rename from algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py rename to algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py index 446267440..726aa8705 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -7,11 +7,11 @@ import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch import models -from algorithmic_efficiency.workloads.criteo1tb.workload import \ +from algoperf import param_utils +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup +from algoperf.workloads.criteo1tb.criteo1tb_pytorch import models +from algoperf.workloads.criteo1tb.workload import \ BaseCriteo1TbDlrmSmallWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() diff --git a/algorithmic_efficiency/workloads/criteo1tb/input_pipeline.py b/algoperf/workloads/criteo1tb/input_pipeline.py similarity index 98% rename from algorithmic_efficiency/workloads/criteo1tb/input_pipeline.py rename to algoperf/workloads/criteo1tb/input_pipeline.py index cb091b3a5..7e254336a 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/input_pipeline.py +++ b/algoperf/workloads/criteo1tb/input_pipeline.py @@ -12,7 +12,7 @@ import tensorflow as tf -from algorithmic_efficiency import data_utils +from algoperf import data_utils _NUM_DAY_23_FILES = 36 diff --git a/algorithmic_efficiency/workloads/criteo1tb/workload.py b/algoperf/workloads/criteo1tb/workload.py similarity index 97% rename from algorithmic_efficiency/workloads/criteo1tb/workload.py rename to algoperf/workloads/criteo1tb/workload.py index f18f2656f..80ec9d67a 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/workload.py +++ b/algoperf/workloads/criteo1tb/workload.py @@ -7,8 +7,8 @@ from absl import flags import torch.distributed as dist -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.criteo1tb import input_pipeline +from algoperf import spec +from algoperf.workloads.criteo1tb import input_pipeline FLAGS = flags.FLAGS diff --git a/algorithmic_efficiency/workloads/fastmri/__init__.py b/algoperf/workloads/fastmri/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/fastmri/__init__.py rename to algoperf/workloads/fastmri/__init__.py diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/__init__.py b/algoperf/workloads/fastmri/fastmri_jax/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/fastmri/fastmri_jax/__init__.py rename to algoperf/workloads/fastmri/fastmri_jax/__init__.py diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/models.py b/algoperf/workloads/fastmri/fastmri_jax/models.py similarity index 100% rename from algorithmic_efficiency/workloads/fastmri/fastmri_jax/models.py rename to algoperf/workloads/fastmri/fastmri_jax/models.py diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/ssim.py b/algoperf/workloads/fastmri/fastmri_jax/ssim.py similarity index 100% rename from algorithmic_efficiency/workloads/fastmri/fastmri_jax/ssim.py rename to algoperf/workloads/fastmri/fastmri_jax/ssim.py diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py similarity index 95% rename from algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py rename to algoperf/workloads/fastmri/fastmri_jax/workload.py index a5dfe8c22..393aa19d7 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -8,12 +8,12 @@ import jax import jax.numpy as jnp -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import spec -import algorithmic_efficiency.random_utils as prng -from algorithmic_efficiency.workloads.fastmri.fastmri_jax.models import UNet -from algorithmic_efficiency.workloads.fastmri.fastmri_jax.ssim import ssim -from algorithmic_efficiency.workloads.fastmri.workload import \ +from algoperf import param_utils +from algoperf import spec +import algoperf.random_utils as prng +from algoperf.workloads.fastmri.fastmri_jax.models import UNet +from algoperf.workloads.fastmri.fastmri_jax.ssim import ssim +from algoperf.workloads.fastmri.workload import \ BaseFastMRIWorkload diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/__init__.py b/algoperf/workloads/fastmri/fastmri_pytorch/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/__init__.py rename to algoperf/workloads/fastmri/fastmri_pytorch/__init__.py diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py b/algoperf/workloads/fastmri/fastmri_pytorch/models.py similarity index 99% rename from algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py rename to algoperf/workloads/fastmri/fastmri_pytorch/models.py index 6c0ab19e2..28f20bf20 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/models.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/models.py @@ -12,7 +12,7 @@ from torch import Tensor from torch.nn import functional as F -from algorithmic_efficiency import init_utils +from algoperf import init_utils class UNet(nn.Module): diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/ssim.py b/algoperf/workloads/fastmri/fastmri_pytorch/ssim.py similarity index 98% rename from algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/ssim.py rename to algoperf/workloads/fastmri/fastmri_pytorch/ssim.py index eff6fb62f..45b61bea4 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/ssim.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/ssim.py @@ -6,7 +6,7 @@ import torch.nn.functional as F from torchvision.transforms.functional import pad as pad_fn -from algorithmic_efficiency.pytorch_utils import pytorch_setup +from algoperf.pytorch_utils import pytorch_setup DEVICE = pytorch_setup()[2] diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py similarity index 96% rename from algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py rename to algoperf/workloads/fastmri/fastmri_pytorch/workload.py index 74f6aa13d..f40654678 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py @@ -9,14 +9,14 @@ import torch.nn.functional as F from torch.nn.parallel import DistributedDataParallel as DDP -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import pytorch_utils -from algorithmic_efficiency import spec -import algorithmic_efficiency.random_utils as prng -from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.models import \ +from algoperf import param_utils +from algoperf import pytorch_utils +from algoperf import spec +import algoperf.random_utils as prng +from algoperf.workloads.fastmri.fastmri_pytorch.models import \ UNet -from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.ssim import ssim -from algorithmic_efficiency.workloads.fastmri.workload import \ +from algoperf.workloads.fastmri.fastmri_pytorch.ssim import ssim +from algoperf.workloads.fastmri.workload import \ BaseFastMRIWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() diff --git a/algorithmic_efficiency/workloads/fastmri/input_pipeline.py b/algoperf/workloads/fastmri/input_pipeline.py similarity index 99% rename from algorithmic_efficiency/workloads/fastmri/input_pipeline.py rename to algoperf/workloads/fastmri/input_pipeline.py index 8f6ddafd1..f20611f43 100644 --- a/algorithmic_efficiency/workloads/fastmri/input_pipeline.py +++ b/algoperf/workloads/fastmri/input_pipeline.py @@ -9,7 +9,7 @@ import jax import tensorflow as tf -from algorithmic_efficiency import data_utils +from algoperf import data_utils _TRAIN_DIR = 'knee_singlecoil_train' _VAL_DIR = 'knee_singlecoil_val' diff --git a/algorithmic_efficiency/workloads/fastmri/workload.py b/algoperf/workloads/fastmri/workload.py similarity index 96% rename from algorithmic_efficiency/workloads/fastmri/workload.py rename to algoperf/workloads/fastmri/workload.py index a8fd1abbb..e9a2a313a 100644 --- a/algorithmic_efficiency/workloads/fastmri/workload.py +++ b/algoperf/workloads/fastmri/workload.py @@ -3,8 +3,8 @@ import math from typing import Optional -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.fastmri import input_pipeline +from algoperf import spec +from algoperf.workloads.fastmri import input_pipeline class BaseFastMRIWorkload(spec.Workload): diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/__init__.py b/algoperf/workloads/imagenet_resnet/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/imagenet_resnet/__init__.py rename to algoperf/workloads/imagenet_resnet/__init__.py diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/__init__.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/__init__.py rename to algoperf/workloads/imagenet_resnet/imagenet_jax/__init__.py diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py similarity index 98% rename from algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py rename to algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py index 422eb9f7a..709a318c2 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py @@ -12,9 +12,9 @@ import tensorflow as tf import tensorflow_datasets as tfds -from algorithmic_efficiency import data_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax import \ +from algoperf import data_utils +from algoperf import spec +from algoperf.workloads.imagenet_resnet.imagenet_jax import \ randaugment TFDS_SPLIT_NAME = { diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py similarity index 98% rename from algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py rename to algoperf/workloads/imagenet_resnet/imagenet_jax/models.py index 34cd17440..ffa60b260 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py @@ -10,7 +10,7 @@ from flax import linen as nn import jax.numpy as jnp -from algorithmic_efficiency import spec +from algoperf import spec ModuleDef = nn.Module diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py similarity index 100% rename from algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py rename to algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py similarity index 95% rename from algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py rename to algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py index 2747fc2db..b445e9f00 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -17,15 +17,15 @@ import optax import tensorflow_datasets as tfds -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import random_utils as prng -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_resnet import imagenet_v2 -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax import \ +from algoperf import param_utils +from algoperf import random_utils as prng +from algoperf import spec +from algoperf.workloads.imagenet_resnet import imagenet_v2 +from algoperf.workloads.imagenet_resnet.imagenet_jax import \ input_pipeline -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax import \ +from algoperf.workloads.imagenet_resnet.imagenet_jax import \ models -from algorithmic_efficiency.workloads.imagenet_resnet.workload import \ +from algoperf.workloads.imagenet_resnet.workload import \ BaseImagenetResNetWorkload diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/__init__.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/__init__.py rename to algoperf/workloads/imagenet_resnet/imagenet_pytorch/__init__.py diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py similarity index 98% rename from algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/models.py rename to algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py index 2b9093940..aba9e671f 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/models.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py @@ -11,8 +11,8 @@ from torch import nn from torch import Tensor -from algorithmic_efficiency import spec -from algorithmic_efficiency.init_utils import pytorch_default_init +from algoperf import spec +from algoperf.init_utils import pytorch_default_init def conv3x3(in_planes: int, diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py similarity index 99% rename from algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py rename to algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py index 829d82d74..c7a98e77a 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py @@ -14,7 +14,7 @@ from torchvision.transforms import functional as F from torchvision.transforms import InterpolationMode -from algorithmic_efficiency import spec +from algoperf import spec def cutout(img: spec.Tensor, pad_size: int) -> spec.Tensor: diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py similarity index 95% rename from algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py rename to algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py index 3549911fa..7a08f325e 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -16,17 +16,17 @@ from torchvision import transforms from torchvision.datasets.folder import ImageFolder -from algorithmic_efficiency import data_utils -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import pytorch_utils -from algorithmic_efficiency import spec -import algorithmic_efficiency.random_utils as prng -from algorithmic_efficiency.workloads.imagenet_resnet import imagenet_v2 -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch import \ +from algoperf import data_utils +from algoperf import param_utils +from algoperf import pytorch_utils +from algoperf import spec +import algoperf.random_utils as prng +from algoperf.workloads.imagenet_resnet import imagenet_v2 +from algoperf.workloads.imagenet_resnet.imagenet_pytorch import \ randaugment -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.models import \ +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \ resnet50 -from algorithmic_efficiency.workloads.imagenet_resnet.workload import \ +from algoperf.workloads.imagenet_resnet.workload import \ BaseImagenetResNetWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_v2.py b/algoperf/workloads/imagenet_resnet/imagenet_v2.py similarity index 90% rename from algorithmic_efficiency/workloads/imagenet_resnet/imagenet_v2.py rename to algoperf/workloads/imagenet_resnet/imagenet_v2.py index 05ab12eb1..f63ddbc34 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_v2.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_v2.py @@ -8,9 +8,9 @@ import tensorflow_datasets as tfds -from algorithmic_efficiency import data_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax import \ +from algoperf import data_utils +from algoperf import spec +from algoperf.workloads.imagenet_resnet.imagenet_jax import \ input_pipeline diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/workload.py b/algoperf/workloads/imagenet_resnet/workload.py similarity index 99% rename from algorithmic_efficiency/workloads/imagenet_resnet/workload.py rename to algoperf/workloads/imagenet_resnet/workload.py index 2e06805f7..8b3393ded 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/workload.py +++ b/algoperf/workloads/imagenet_resnet/workload.py @@ -3,7 +3,7 @@ import math from typing import Dict, Iterator, Optional, Tuple -from algorithmic_efficiency import spec +from algoperf import spec class BaseImagenetResNetWorkload(spec.Workload): diff --git a/algorithmic_efficiency/workloads/imagenet_vit/__init__.py b/algoperf/workloads/imagenet_vit/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/imagenet_vit/__init__.py rename to algoperf/workloads/imagenet_vit/__init__.py diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/__init__.py b/algoperf/workloads/imagenet_vit/imagenet_jax/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/__init__.py rename to algoperf/workloads/imagenet_vit/imagenet_jax/__init__.py diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py similarity index 99% rename from algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py rename to algoperf/workloads/imagenet_vit/imagenet_jax/models.py index 639800b44..cfa104b53 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py @@ -10,7 +10,7 @@ from flax import linen as nn import jax.numpy as jnp -from algorithmic_efficiency import spec +from algoperf import spec def posemb_sincos_2d(h: int, diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py similarity index 91% rename from algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py rename to algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index 2ad71ffd0..2261aac6d 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -7,14 +7,14 @@ import jax import jax.numpy as jnp -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.workload import \ +from algoperf import param_utils +from algoperf import spec +from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \ ImagenetResNetWorkload -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax import models -from algorithmic_efficiency.workloads.imagenet_vit.workload import \ +from algoperf.workloads.imagenet_vit.imagenet_jax import models +from algoperf.workloads.imagenet_vit.workload import \ BaseImagenetVitWorkload -from algorithmic_efficiency.workloads.imagenet_vit.workload import \ +from algoperf.workloads.imagenet_vit.workload import \ decode_variant diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/__init__.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/__init__.py rename to algoperf/workloads/imagenet_vit/imagenet_pytorch/__init__.py diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py similarity index 98% rename from algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py rename to algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py index 02d708da8..4fac8bd35 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py @@ -12,9 +12,9 @@ from torch import nn import torch.nn.functional as F -from algorithmic_efficiency import init_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.wmt.wmt_pytorch.models import \ +from algoperf import init_utils +from algoperf import spec +from algoperf.workloads.wmt.wmt_pytorch.models import \ MultiheadAttention diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py similarity index 87% rename from algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py rename to algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py index 703d40b07..20b294b47 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -6,16 +6,16 @@ import torch from torch.nn.parallel import DistributedDataParallel as DDP -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import pytorch_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.workload import \ +from algoperf import param_utils +from algoperf import pytorch_utils +from algoperf import spec +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import \ ImagenetResNetWorkload -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch import \ +from algoperf.workloads.imagenet_vit.imagenet_pytorch import \ models -from algorithmic_efficiency.workloads.imagenet_vit.workload import \ +from algoperf.workloads.imagenet_vit.workload import \ BaseImagenetVitWorkload -from algorithmic_efficiency.workloads.imagenet_vit.workload import \ +from algoperf.workloads.imagenet_vit.workload import \ decode_variant USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() diff --git a/algorithmic_efficiency/workloads/imagenet_vit/workload.py b/algoperf/workloads/imagenet_vit/workload.py similarity index 96% rename from algorithmic_efficiency/workloads/imagenet_vit/workload.py rename to algoperf/workloads/imagenet_vit/workload.py index ed0118ca0..7f06715a3 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/workload.py +++ b/algoperf/workloads/imagenet_vit/workload.py @@ -2,8 +2,8 @@ from typing import Dict, Iterator, Optional -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_resnet.workload import \ +from algoperf import spec +from algoperf.workloads.imagenet_resnet.workload import \ BaseImagenetResNetWorkload diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/__init__.py b/algoperf/workloads/librispeech_conformer/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/librispeech_conformer/__init__.py rename to algoperf/workloads/librispeech_conformer/__init__.py diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/input_pipeline.py b/algoperf/workloads/librispeech_conformer/input_pipeline.py similarity index 100% rename from algorithmic_efficiency/workloads/librispeech_conformer/input_pipeline.py rename to algoperf/workloads/librispeech_conformer/input_pipeline.py diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/__init__.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/__init__.py rename to algoperf/workloads/librispeech_conformer/librispeech_jax/__init__.py diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/librispeech_preprocessor.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/librispeech_preprocessor.py similarity index 100% rename from algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/librispeech_preprocessor.py rename to algoperf/workloads/librispeech_conformer/librispeech_jax/librispeech_preprocessor.py diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py similarity index 99% rename from algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py rename to algoperf/workloads/librispeech_conformer/librispeech_jax/models.py index cb6287c5e..adb5e803c 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py @@ -22,9 +22,9 @@ import jax.numpy as jnp import numpy as np -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax import \ +from algoperf.workloads.librispeech_conformer.librispeech_jax import \ librispeech_preprocessor as preprocessor -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax import \ +from algoperf.workloads.librispeech_conformer.librispeech_jax import \ spectrum_augmenter diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py similarity index 100% rename from algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py rename to algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py similarity index 96% rename from algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py rename to algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index e362f973b..b4fdb0811 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -11,14 +11,14 @@ import optax import torch -from algorithmic_efficiency import data_utils -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.librispeech_conformer import metrics -from algorithmic_efficiency.workloads.librispeech_conformer import workload -from algorithmic_efficiency.workloads.librispeech_conformer.input_pipeline import \ +from algoperf import data_utils +from algoperf import param_utils +from algoperf import spec +from algoperf.workloads.librispeech_conformer import metrics +from algoperf.workloads.librispeech_conformer import workload +from algoperf.workloads.librispeech_conformer.input_pipeline import \ LibriSpeechDataset -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax import \ +from algoperf.workloads.librispeech_conformer.librispeech_jax import \ models diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/__init__.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/__init__.py rename to algoperf/workloads/librispeech_conformer/librispeech_pytorch/__init__.py diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py similarity index 98% rename from algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py rename to algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py index 61400806a..db1e24521 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py @@ -12,9 +12,9 @@ from torch.nn import init import torch.nn.functional as F -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch import \ +from algoperf.workloads.librispeech_conformer.librispeech_pytorch import \ preprocessor -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \ +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \ SpecAug diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/preprocessor.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/preprocessor.py similarity index 100% rename from algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/preprocessor.py rename to algoperf/workloads/librispeech_conformer/librispeech_pytorch/preprocessor.py diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/spectrum_augmenter.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/spectrum_augmenter.py similarity index 100% rename from algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/spectrum_augmenter.py rename to algoperf/workloads/librispeech_conformer/librispeech_pytorch/spectrum_augmenter.py diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py similarity index 95% rename from algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py rename to algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 155b30920..592e63989 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -10,16 +10,16 @@ import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP -from algorithmic_efficiency import data_utils -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import pytorch_utils -from algorithmic_efficiency import spec -import algorithmic_efficiency.random_utils as prng -from algorithmic_efficiency.workloads.librispeech_conformer import metrics -from algorithmic_efficiency.workloads.librispeech_conformer import workload -from algorithmic_efficiency.workloads.librispeech_conformer.input_pipeline import \ +from algoperf import data_utils +from algoperf import param_utils +from algoperf import pytorch_utils +from algoperf import spec +import algoperf.random_utils as prng +from algoperf.workloads.librispeech_conformer import metrics +from algoperf.workloads.librispeech_conformer import workload +from algoperf.workloads.librispeech_conformer.input_pipeline import \ LibriSpeechDataset -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch import \ +from algoperf.workloads.librispeech_conformer.librispeech_pytorch import \ models USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/metrics.py b/algoperf/workloads/librispeech_conformer/metrics.py similarity index 100% rename from algorithmic_efficiency/workloads/librispeech_conformer/metrics.py rename to algoperf/workloads/librispeech_conformer/metrics.py diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/workload.py b/algoperf/workloads/librispeech_conformer/workload.py similarity index 98% rename from algorithmic_efficiency/workloads/librispeech_conformer/workload.py rename to algoperf/workloads/librispeech_conformer/workload.py index c2413c076..c9f5a3c59 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/workload.py +++ b/algoperf/workloads/librispeech_conformer/workload.py @@ -1,7 +1,7 @@ import math from typing import Dict -from algorithmic_efficiency import spec +from algoperf import spec class BaseLibrispeechWorkload(spec.Workload): diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/__init__.py b/algoperf/workloads/librispeech_deepspeech/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/librispeech_deepspeech/__init__.py rename to algoperf/workloads/librispeech_deepspeech/__init__.py diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/__init__.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/__init__.py rename to algoperf/workloads/librispeech_deepspeech/librispeech_jax/__init__.py diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py similarity index 99% rename from algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py rename to algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py index f9eb732e9..b116f44cd 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -16,9 +16,9 @@ from jax.experimental import rnn import jax.numpy as jnp -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax import \ +from algoperf.workloads.librispeech_conformer.librispeech_jax import \ librispeech_preprocessor as preprocessor -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax import \ +from algoperf.workloads.librispeech_conformer.librispeech_jax import \ spectrum_augmenter Array = jnp.ndarray diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py similarity index 95% rename from algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py rename to algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index a0db6d607..3e0781deb 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -6,11 +6,11 @@ import jax.numpy as jnp import numpy as np -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import \ +from algoperf import param_utils +from algoperf import spec +from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \ LibriSpeechConformerWorkload -from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_jax import \ +from algoperf.workloads.librispeech_deepspeech.librispeech_jax import \ models diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/__init__.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/__init__.py rename to algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/__init__.py diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py similarity index 98% rename from algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py rename to algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py index bdf556f1c..84d317326 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py @@ -11,9 +11,9 @@ import torch.distributed.nn as dist_nn import torch.nn.functional as F -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch import \ +from algoperf.workloads.librispeech_conformer.librispeech_pytorch import \ preprocessor -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \ +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \ SpecAug USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py similarity index 88% rename from algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py rename to algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index 626bac278..4f8ad1974 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -3,16 +3,16 @@ import torch from torch.nn.parallel import DistributedDataParallel as DDP -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.models import \ +from algoperf import param_utils +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \ initialize -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import \ +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \ LibriSpeechConformerWorkload -from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ DeepspeechConfig -from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ DeepspeechEncoderDecoder USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() diff --git a/algorithmic_efficiency/workloads/mnist/__init__.py b/algoperf/workloads/mnist/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/mnist/__init__.py rename to algoperf/workloads/mnist/__init__.py diff --git a/algorithmic_efficiency/workloads/mnist/mnist_jax/__init__.py b/algoperf/workloads/mnist/mnist_jax/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/mnist/mnist_jax/__init__.py rename to algoperf/workloads/mnist/mnist_jax/__init__.py diff --git a/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py b/algoperf/workloads/mnist/mnist_jax/workload.py similarity index 96% rename from algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py rename to algoperf/workloads/mnist/mnist_jax/workload.py index efbd73e33..8154026d1 100644 --- a/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py +++ b/algoperf/workloads/mnist/mnist_jax/workload.py @@ -10,9 +10,9 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.mnist.workload import BaseMnistWorkload +from algoperf import param_utils +from algoperf import spec +from algoperf.workloads.mnist.workload import BaseMnistWorkload class _Model(nn.Module): diff --git a/algorithmic_efficiency/workloads/mnist/mnist_pytorch/__init__.py b/algoperf/workloads/mnist/mnist_pytorch/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/mnist/mnist_pytorch/__init__.py rename to algoperf/workloads/mnist/mnist_pytorch/__init__.py diff --git a/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py b/algoperf/workloads/mnist/mnist_pytorch/workload.py similarity index 96% rename from algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py rename to algoperf/workloads/mnist/mnist_pytorch/workload.py index e638df078..780e1bca0 100644 --- a/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py +++ b/algoperf/workloads/mnist/mnist_pytorch/workload.py @@ -10,11 +10,11 @@ import torch.nn.functional as F from torch.nn.parallel import DistributedDataParallel as DDP -from algorithmic_efficiency import init_utils -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup -from algorithmic_efficiency.workloads.mnist.workload import BaseMnistWorkload +from algoperf import init_utils +from algoperf import param_utils +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup +from algoperf.workloads.mnist.workload import BaseMnistWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() diff --git a/algorithmic_efficiency/workloads/mnist/workload.py b/algoperf/workloads/mnist/workload.py similarity index 97% rename from algorithmic_efficiency/workloads/mnist/workload.py rename to algoperf/workloads/mnist/workload.py index dcc195170..f53aadd0b 100644 --- a/algorithmic_efficiency/workloads/mnist/workload.py +++ b/algoperf/workloads/mnist/workload.py @@ -10,10 +10,10 @@ import tensorflow_datasets as tfds import torch -from algorithmic_efficiency import data_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup -import algorithmic_efficiency.random_utils as prng +from algoperf import data_utils +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup +import algoperf.random_utils as prng USE_PYTORCH_DDP, _, _, _ = pytorch_setup() diff --git a/algorithmic_efficiency/workloads/ogbg/__init__.py b/algoperf/workloads/ogbg/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/ogbg/__init__.py rename to algoperf/workloads/ogbg/__init__.py diff --git a/algorithmic_efficiency/workloads/ogbg/input_pipeline.py b/algoperf/workloads/ogbg/input_pipeline.py similarity index 100% rename from algorithmic_efficiency/workloads/ogbg/input_pipeline.py rename to algoperf/workloads/ogbg/input_pipeline.py diff --git a/algorithmic_efficiency/workloads/ogbg/metrics.py b/algoperf/workloads/ogbg/metrics.py similarity index 98% rename from algorithmic_efficiency/workloads/ogbg/metrics.py rename to algoperf/workloads/ogbg/metrics.py index a654eb2ae..55f83d905 100644 --- a/algorithmic_efficiency/workloads/ogbg/metrics.py +++ b/algoperf/workloads/ogbg/metrics.py @@ -11,7 +11,7 @@ import torch import torch.distributed as dist -from algorithmic_efficiency.pytorch_utils import pytorch_setup +from algoperf.pytorch_utils import pytorch_setup USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/__init__.py b/algoperf/workloads/ogbg/ogbg_jax/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/ogbg/ogbg_jax/__init__.py rename to algoperf/workloads/ogbg/ogbg_jax/__init__.py diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py similarity index 100% rename from algorithmic_efficiency/workloads/ogbg/ogbg_jax/models.py rename to algoperf/workloads/ogbg/ogbg_jax/models.py diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py b/algoperf/workloads/ogbg/ogbg_jax/workload.py similarity index 95% rename from algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py rename to algoperf/workloads/ogbg/ogbg_jax/workload.py index ec0c0658d..e895d15a7 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/workload.py +++ b/algoperf/workloads/ogbg/ogbg_jax/workload.py @@ -8,11 +8,11 @@ import jraph import optax -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.ogbg import metrics -from algorithmic_efficiency.workloads.ogbg.ogbg_jax import models -from algorithmic_efficiency.workloads.ogbg.workload import BaseOgbgWorkload +from algoperf import param_utils +from algoperf import spec +from algoperf.workloads.ogbg import metrics +from algoperf.workloads.ogbg.ogbg_jax import models +from algoperf.workloads.ogbg.workload import BaseOgbgWorkload class OgbgWorkload(BaseOgbgWorkload): diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/__init__.py b/algoperf/workloads/ogbg/ogbg_pytorch/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/__init__.py rename to algoperf/workloads/ogbg/ogbg_pytorch/__init__.py diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py b/algoperf/workloads/ogbg/ogbg_pytorch/models.py similarity index 99% rename from algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py rename to algoperf/workloads/ogbg/ogbg_pytorch/models.py index d93013b87..fe9b29bc1 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/models.py +++ b/algoperf/workloads/ogbg/ogbg_pytorch/models.py @@ -8,7 +8,7 @@ import torch from torch import nn -from algorithmic_efficiency import init_utils +from algoperf import init_utils def _make_mlp(in_dim, hidden_dims, dropout_rate, activation_fn): diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py similarity index 96% rename from algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py rename to algoperf/workloads/ogbg/ogbg_pytorch/workload.py index d4817226d..75252a6b9 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_pytorch/workload.py +++ b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py @@ -8,12 +8,12 @@ import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import pytorch_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.ogbg import metrics -from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.models import GNN -from algorithmic_efficiency.workloads.ogbg.workload import BaseOgbgWorkload +from algoperf import param_utils +from algoperf import pytorch_utils +from algoperf import spec +from algoperf.workloads.ogbg import metrics +from algoperf.workloads.ogbg.ogbg_pytorch.models import GNN +from algoperf.workloads.ogbg.workload import BaseOgbgWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() diff --git a/algorithmic_efficiency/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py similarity index 96% rename from algorithmic_efficiency/workloads/ogbg/workload.py rename to algoperf/workloads/ogbg/workload.py index a32f385cb..c6a2162d7 100644 --- a/algorithmic_efficiency/workloads/ogbg/workload.py +++ b/algoperf/workloads/ogbg/workload.py @@ -7,10 +7,10 @@ import jax -from algorithmic_efficiency import random_utils as prng -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.ogbg import input_pipeline -from algorithmic_efficiency.workloads.ogbg import metrics +from algoperf import random_utils as prng +from algoperf import spec +from algoperf.workloads.ogbg import input_pipeline +from algoperf.workloads.ogbg import metrics class BaseOgbgWorkload(spec.Workload): diff --git a/algorithmic_efficiency/workloads/utils.py b/algoperf/workloads/utils.py similarity index 100% rename from algorithmic_efficiency/workloads/utils.py rename to algoperf/workloads/utils.py diff --git a/algorithmic_efficiency/workloads/wmt/__init__.py b/algoperf/workloads/wmt/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/wmt/__init__.py rename to algoperf/workloads/wmt/__init__.py diff --git a/algorithmic_efficiency/workloads/wmt/bleu.py b/algoperf/workloads/wmt/bleu.py similarity index 98% rename from algorithmic_efficiency/workloads/wmt/bleu.py rename to algoperf/workloads/wmt/bleu.py index 1efc87381..10719819c 100644 --- a/algorithmic_efficiency/workloads/wmt/bleu.py +++ b/algoperf/workloads/wmt/bleu.py @@ -6,7 +6,7 @@ import torch import torch.distributed as dist -from algorithmic_efficiency.pytorch_utils import pytorch_setup +from algoperf.pytorch_utils import pytorch_setup USE_PYTORCH_DDP, _, DEVICE, N_GPUS = pytorch_setup() diff --git a/algorithmic_efficiency/workloads/wmt/input_pipeline.py b/algoperf/workloads/wmt/input_pipeline.py similarity index 98% rename from algorithmic_efficiency/workloads/wmt/input_pipeline.py rename to algoperf/workloads/wmt/input_pipeline.py index af1c54994..d743b43b0 100644 --- a/algorithmic_efficiency/workloads/wmt/input_pipeline.py +++ b/algoperf/workloads/wmt/input_pipeline.py @@ -6,9 +6,9 @@ import tensorflow as tf import tensorflow_datasets as tfds -from algorithmic_efficiency import data_utils -from algorithmic_efficiency.pytorch_utils import pytorch_setup -from algorithmic_efficiency.workloads.wmt import tokenizer +from algoperf import data_utils +from algoperf.pytorch_utils import pytorch_setup +from algoperf.workloads.wmt import tokenizer RANK = pytorch_setup()[1] # Avoid multithreading in all processes but the first (rank 0). diff --git a/algorithmic_efficiency/workloads/wmt/tokenizer.py b/algoperf/workloads/wmt/tokenizer.py similarity index 100% rename from algorithmic_efficiency/workloads/wmt/tokenizer.py rename to algoperf/workloads/wmt/tokenizer.py diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/__init__.py b/algoperf/workloads/wmt/wmt_jax/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/wmt/wmt_jax/__init__.py rename to algoperf/workloads/wmt/wmt_jax/__init__.py diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/decode.py b/algoperf/workloads/wmt/wmt_jax/decode.py similarity index 100% rename from algorithmic_efficiency/workloads/wmt/wmt_jax/decode.py rename to algoperf/workloads/wmt/wmt_jax/decode.py diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/models.py b/algoperf/workloads/wmt/wmt_jax/models.py similarity index 100% rename from algorithmic_efficiency/workloads/wmt/wmt_jax/models.py rename to algoperf/workloads/wmt/wmt_jax/models.py diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py similarity index 97% rename from algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py rename to algoperf/workloads/wmt/wmt_jax/workload.py index 046d5e469..9f919e7cb 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py +++ b/algoperf/workloads/wmt/wmt_jax/workload.py @@ -13,12 +13,12 @@ import numpy as np import optax -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.wmt import bleu -from algorithmic_efficiency.workloads.wmt.wmt_jax import decode -from algorithmic_efficiency.workloads.wmt.wmt_jax import models -from algorithmic_efficiency.workloads.wmt.workload import BaseWmtWorkload +from algoperf import param_utils +from algoperf import spec +from algoperf.workloads.wmt import bleu +from algoperf.workloads.wmt.wmt_jax import decode +from algoperf.workloads.wmt.wmt_jax import models +from algoperf.workloads.wmt.workload import BaseWmtWorkload def _to_host(x: spec.Tensor) -> spec.Tensor: diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/__init__.py b/algoperf/workloads/wmt/wmt_pytorch/__init__.py similarity index 100% rename from algorithmic_efficiency/workloads/wmt/wmt_pytorch/__init__.py rename to algoperf/workloads/wmt/wmt_pytorch/__init__.py diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/decode.py b/algoperf/workloads/wmt/wmt_pytorch/decode.py similarity index 99% rename from algorithmic_efficiency/workloads/wmt/wmt_pytorch/decode.py rename to algoperf/workloads/wmt/wmt_pytorch/decode.py index 0488a144f..ebfc64c50 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/decode.py +++ b/algoperf/workloads/wmt/wmt_pytorch/decode.py @@ -10,7 +10,7 @@ import torch import torch.nn.functional as F -from algorithmic_efficiency.pytorch_utils import pytorch_setup +from algoperf.pytorch_utils import pytorch_setup DEVICE = pytorch_setup()[2] diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py b/algoperf/workloads/wmt/wmt_pytorch/models.py similarity index 100% rename from algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py rename to algoperf/workloads/wmt/wmt_pytorch/models.py diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py b/algoperf/workloads/wmt/wmt_pytorch/workload.py similarity index 97% rename from algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py rename to algoperf/workloads/wmt/wmt_pytorch/workload.py index 0ba49c2f6..9d1248efd 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py +++ b/algoperf/workloads/wmt/wmt_pytorch/workload.py @@ -12,13 +12,13 @@ import torch.nn.functional as F from torch.nn.parallel import DistributedDataParallel as DDP -from algorithmic_efficiency import param_utils -from algorithmic_efficiency import pytorch_utils -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.wmt import bleu -from algorithmic_efficiency.workloads.wmt.wmt_pytorch import decode -from algorithmic_efficiency.workloads.wmt.wmt_pytorch.models import Transformer -from algorithmic_efficiency.workloads.wmt.workload import BaseWmtWorkload +from algoperf import param_utils +from algoperf import pytorch_utils +from algoperf import spec +from algoperf.workloads.wmt import bleu +from algoperf.workloads.wmt.wmt_pytorch import decode +from algoperf.workloads.wmt.wmt_pytorch.models import Transformer +from algoperf.workloads.wmt.workload import BaseWmtWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() diff --git a/algorithmic_efficiency/workloads/wmt/workload.py b/algoperf/workloads/wmt/workload.py similarity index 97% rename from algorithmic_efficiency/workloads/wmt/workload.py rename to algoperf/workloads/wmt/workload.py index 68ebdc94b..e9a07d2b3 100644 --- a/algorithmic_efficiency/workloads/wmt/workload.py +++ b/algoperf/workloads/wmt/workload.py @@ -9,9 +9,9 @@ import numpy as np import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.wmt import input_pipeline -from algorithmic_efficiency.workloads.wmt.wmt_jax import decode +from algoperf import spec +from algoperf.workloads.wmt import input_pipeline +from algoperf.workloads.wmt.wmt_jax import decode VOCAB_PATH = './wmt_256/sentencepiece_model' WORKDIR = './wmt_256' diff --git a/algorithmic_efficiency/workloads/workloads.py b/algoperf/workloads/workloads.py similarity index 98% rename from algorithmic_efficiency/workloads/workloads.py rename to algoperf/workloads/workloads.py index bb57f598e..4712f4e25 100644 --- a/algorithmic_efficiency/workloads/workloads.py +++ b/algoperf/workloads/workloads.py @@ -4,9 +4,9 @@ import inspect import os -from algorithmic_efficiency import spec +from algoperf import spec -BASE_WORKLOADS_DIR = 'algorithmic_efficiency/workloads/' +BASE_WORKLOADS_DIR = 'algoperf/workloads/' WORKLOADS = { 'cifar': { diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py index 5b43a3f87..efe923dbe 100644 --- a/datasets/dataset_setup.py +++ b/datasets/dataset_setup.py @@ -71,8 +71,8 @@ import tensorflow_datasets as tfds from torchvision.datasets import CIFAR10 -from algorithmic_efficiency.workloads.wmt import tokenizer -from algorithmic_efficiency.workloads.wmt.input_pipeline import \ +from algoperf.workloads.wmt import tokenizer +from algoperf.workloads.wmt.input_pipeline import \ normalize_feature_names from datasets import librispeech_preprocess from datasets import librispeech_tokenizer diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py index 36e7e5607..445074b69 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -22,7 +22,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec _GRAD_CLIP_EPS = 1e-6 diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py index 07281f540..ac21f1327 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -22,7 +22,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec _GRAD_CLIP_EPS = 1e-6 diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py index a12523bde..a2f9fb4c5 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -11,8 +11,8 @@ from torch.optim.lr_scheduler import LinearLR from torch.optim.lr_scheduler import SequentialLR -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup USE_PYTORCH_DDP = pytorch_setup()[0] diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py index 93b41987e..a37b0d341 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py @@ -11,8 +11,8 @@ from torch.optim.lr_scheduler import LinearLR from torch.optim.lr_scheduler import SequentialLR -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup USE_PYTORCH_DDP = pytorch_setup()[0] diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py index 0d194ef7a..3e24e2e89 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -22,7 +22,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec _GRAD_CLIP_EPS = 1e-6 diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index 60fc25ec4..eb6b3ebb3 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -22,7 +22,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec _GRAD_CLIP_EPS = 1e-6 diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py index 2dc29acad..3ef286877 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py @@ -11,8 +11,8 @@ from torch.optim.lr_scheduler import LinearLR from torch.optim.lr_scheduler import SequentialLR -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup USE_PYTORCH_DDP = pytorch_setup()[0] diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py index 6cc44cb12..e9f8810a6 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py @@ -11,8 +11,8 @@ from torch.optim.lr_scheduler import LinearLR from torch.optim.lr_scheduler import SequentialLR -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup USE_PYTORCH_DDP = pytorch_setup()[0] diff --git a/pyproject.toml b/pyproject.toml index 0788d48a5..c21191adc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ ############################################################################### [project] -name = "algorithmic_efficiency" +name = "algoperf" dynamic = ["version"] description = "Codebase for the AlgoPerf: Training Algorithms benchmark" authors = [ @@ -62,7 +62,7 @@ zip-safe = false find = {} # Scanning implicit namespaces is active by default [tool.setuptools_scm] -version_file = "algorithmic_efficiency/_version.py" +version_file = "algoperf/_version.py" ############################################################################### # (Optional) Dependencies # @@ -70,10 +70,10 @@ version_file = "algorithmic_efficiency/_version.py" [project.optional-dependencies] # All workloads full = [ - "algorithmic_efficiency[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt]", + "algoperf[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt]", ] # All workloads plus development dependencies -full_dev = ["algorithmic_efficiency[full,dev]"] +full_dev = ["algoperf[full,dev]"] # Dependencies for developing the package dev = [ "isort==5.12.0", @@ -106,12 +106,12 @@ jax_core_deps = [ jax_cpu = [ "jax==0.4.10", "jaxlib==0.4.10", - "algorithmic_efficiency[jax_core_deps]", + "algoperf[jax_core_deps]", ] jax_gpu = [ "jax==0.4.10", "jaxlib==0.4.10+cuda12.cudnn88", - "algorithmic_efficiency[jax_core_deps]", + "algoperf[jax_core_deps]", ] pytorch_cpu = ["torch==2.1.0", "torchvision==0.16.0"] pytorch_gpu = [ @@ -130,7 +130,7 @@ based_on_style = "yapf" each_dict_entry_on_separate_line = false split_all_top_level_comma_separated_values = true [tool.yapfignore] -ignore_patterns = ["algorithmic_efficiency/_version.py"] +ignore_patterns = ["algoperf/_version.py"] # isort configuration [tool.isort] diff --git a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py index e8e0bf4ac..614d66107 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py @@ -9,7 +9,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec def get_batch_size(workload_name): diff --git a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py index c3e7a546b..d8b91f83a 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py @@ -7,7 +7,7 @@ from torch.optim.lr_scheduler import LinearLR from torch.optim.lr_scheduler import SequentialLR -from algorithmic_efficiency import spec +from algoperf import spec def get_batch_size(workload_name): diff --git a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py index b33c0285b..4148148a0 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py @@ -9,7 +9,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec def get_batch_size(workload_name): diff --git a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py index b868bc787..dedd96793 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py @@ -4,7 +4,7 @@ import torch -from algorithmic_efficiency import spec +from algoperf import spec def get_batch_size(workload_name): diff --git a/reference_algorithms/paper_baselines/adafactor/jax/submission.py b/reference_algorithms/paper_baselines/adafactor/jax/submission.py index 0fcb9da0f..abaf36ea5 100644 --- a/reference_algorithms/paper_baselines/adafactor/jax/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/jax/submission.py @@ -9,7 +9,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec from reference_algorithms.paper_baselines.adafactor.jax.sharded_adafactor import \ sharded_adafactor diff --git a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py index c0eed45ef..7aa457a25 100644 --- a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py @@ -10,8 +10,8 @@ from torch.optim.lr_scheduler import LinearLR from torch.optim.lr_scheduler import SequentialLR -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup USE_PYTORCH_DDP = pytorch_setup()[0] diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index e80a29693..da0ccdc12 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -9,7 +9,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec _GRAD_CLIP_EPS = 1e-6 diff --git a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py index 8da4e1671..21d9b6b57 100644 --- a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py @@ -9,8 +9,8 @@ from torch.optim.lr_scheduler import LinearLR from torch.optim.lr_scheduler import SequentialLR -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup USE_PYTORCH_DDP = pytorch_setup()[0] diff --git a/reference_algorithms/paper_baselines/lamb/jax/submission.py b/reference_algorithms/paper_baselines/lamb/jax/submission.py index ebcdc9914..9623e912a 100644 --- a/reference_algorithms/paper_baselines/lamb/jax/submission.py +++ b/reference_algorithms/paper_baselines/lamb/jax/submission.py @@ -9,7 +9,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec _GRAD_CLIP_EPS = 1e-6 diff --git a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py index c0ecee69e..c1c6cec0a 100644 --- a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py @@ -10,7 +10,7 @@ from torch.optim.lr_scheduler import LinearLR from torch.optim.lr_scheduler import SequentialLR -from algorithmic_efficiency import spec +from algoperf import spec # Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py diff --git a/reference_algorithms/paper_baselines/momentum/jax/submission.py b/reference_algorithms/paper_baselines/momentum/jax/submission.py index 271ef860b..7af999be8 100644 --- a/reference_algorithms/paper_baselines/momentum/jax/submission.py +++ b/reference_algorithms/paper_baselines/momentum/jax/submission.py @@ -9,7 +9,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec _GRAD_CLIP_EPS = 1e-6 diff --git a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py index 272a79b4c..c3760d20e 100644 --- a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py @@ -8,8 +8,8 @@ import torch.distributed.nn as dist_nn from torch.optim.lr_scheduler import LambdaLR -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup USE_PYTORCH_DDP = pytorch_setup()[0] diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py index 36e7e5607..445074b69 100644 --- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py @@ -22,7 +22,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec _GRAD_CLIP_EPS = 1e-6 diff --git a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py index a12523bde..a2f9fb4c5 100644 --- a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py @@ -11,8 +11,8 @@ from torch.optim.lr_scheduler import LinearLR from torch.optim.lr_scheduler import SequentialLR -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup USE_PYTORCH_DDP = pytorch_setup()[0] diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index a435643e4..0c9fe48c4 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -9,7 +9,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec _GRAD_CLIP_EPS = 1e-6 diff --git a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py index aac4146a4..b4432fbff 100644 --- a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py @@ -8,8 +8,8 @@ import torch.distributed.nn as dist_nn from torch.optim.lr_scheduler import LambdaLR -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup USE_PYTORCH_DDP = pytorch_setup()[0] diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py index 5f45901dd..09995d0ef 100644 --- a/reference_algorithms/paper_baselines/sam/jax/submission.py +++ b/reference_algorithms/paper_baselines/sam/jax/submission.py @@ -9,7 +9,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec _GRAD_CLIP_EPS = 1e-6 diff --git a/reference_algorithms/paper_baselines/sam/pytorch/submission.py b/reference_algorithms/paper_baselines/sam/pytorch/submission.py index 243174d34..92603f036 100644 --- a/reference_algorithms/paper_baselines/sam/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/sam/pytorch/submission.py @@ -9,8 +9,8 @@ from torch.optim.lr_scheduler import LinearLR from torch.optim.lr_scheduler import SequentialLR -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup USE_PYTORCH_DDP = pytorch_setup()[0] diff --git a/reference_algorithms/paper_baselines/shampoo/jax/submission.py b/reference_algorithms/paper_baselines/shampoo/jax/submission.py index 294ad2706..41e223c9e 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/submission.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/submission.py @@ -9,7 +9,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec from reference_algorithms.paper_baselines.shampoo.jax.distributed_shampoo import \ distributed_shampoo diff --git a/reference_algorithms/target_setting_algorithms/data_selection.py b/reference_algorithms/target_setting_algorithms/data_selection.py index ce24482fc..5e70f9f8b 100644 --- a/reference_algorithms/target_setting_algorithms/data_selection.py +++ b/reference_algorithms/target_setting_algorithms/data_selection.py @@ -1,6 +1,6 @@ from typing import Dict, Iterator, Tuple -from algorithmic_efficiency import spec +from algoperf import spec def data_selection( diff --git a/reference_algorithms/target_setting_algorithms/jax_adamw.py b/reference_algorithms/target_setting_algorithms/jax_adamw.py index 6d2cfe245..edf9bae7a 100644 --- a/reference_algorithms/target_setting_algorithms/jax_adamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_adamw.py @@ -4,7 +4,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec from reference_algorithms.target_setting_algorithms import cosine_warmup from reference_algorithms.target_setting_algorithms.data_selection import \ data_selection # pylint: disable=unused-import diff --git a/reference_algorithms/target_setting_algorithms/jax_momentum.py b/reference_algorithms/target_setting_algorithms/jax_momentum.py index 08a0f7e9d..6cdd9a8d6 100644 --- a/reference_algorithms/target_setting_algorithms/jax_momentum.py +++ b/reference_algorithms/target_setting_algorithms/jax_momentum.py @@ -7,7 +7,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec from reference_algorithms.target_setting_algorithms.data_selection import \ data_selection # pylint: disable=unused-import from reference_algorithms.target_setting_algorithms.get_batch_size import \ diff --git a/reference_algorithms/target_setting_algorithms/jax_nadamw.py b/reference_algorithms/target_setting_algorithms/jax_nadamw.py index 21f2a7b2b..9e23cf86f 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_nadamw.py @@ -8,7 +8,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec from reference_algorithms.target_setting_algorithms import cosine_warmup from reference_algorithms.target_setting_algorithms.data_selection import \ data_selection # pylint: disable=unused-import diff --git a/reference_algorithms/target_setting_algorithms/jax_nesterov.py b/reference_algorithms/target_setting_algorithms/jax_nesterov.py index 6b27e0e2a..9ef43fafb 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nesterov.py +++ b/reference_algorithms/target_setting_algorithms/jax_nesterov.py @@ -7,7 +7,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec from reference_algorithms.target_setting_algorithms.data_selection import \ data_selection # pylint: disable=unused-import from reference_algorithms.target_setting_algorithms.get_batch_size import \ diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index 7a16c07cb..221cdf411 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -7,7 +7,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec _GRAD_CLIP_EPS = 1e-6 diff --git a/reference_algorithms/target_setting_algorithms/pytorch_adamw.py b/reference_algorithms/target_setting_algorithms/pytorch_adamw.py index 0dcb5ab14..c87bdfb7d 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_adamw.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_adamw.py @@ -2,7 +2,7 @@ import torch -from algorithmic_efficiency import spec +from algoperf import spec from reference_algorithms.target_setting_algorithms import cosine_warmup from reference_algorithms.target_setting_algorithms.data_selection import \ data_selection # pylint: disable=unused-import diff --git a/reference_algorithms/target_setting_algorithms/pytorch_momentum.py b/reference_algorithms/target_setting_algorithms/pytorch_momentum.py index 1a2df449a..584caff39 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_momentum.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_momentum.py @@ -3,7 +3,7 @@ import torch from torch.optim.lr_scheduler import LambdaLR -from algorithmic_efficiency import spec +from algoperf import spec from reference_algorithms.target_setting_algorithms.data_selection import \ data_selection # pylint: disable=unused-import from reference_algorithms.target_setting_algorithms.get_batch_size import \ diff --git a/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py b/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py index 71b819e66..a9dee1d79 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py @@ -6,7 +6,7 @@ import torch from torch import Tensor -from algorithmic_efficiency import spec +from algoperf import spec from reference_algorithms.target_setting_algorithms import cosine_warmup from reference_algorithms.target_setting_algorithms.data_selection import \ data_selection # pylint: disable=unused-import diff --git a/reference_algorithms/target_setting_algorithms/pytorch_nesterov.py b/reference_algorithms/target_setting_algorithms/pytorch_nesterov.py index 830e5eac9..8e10db4ef 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_nesterov.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_nesterov.py @@ -3,7 +3,7 @@ import torch from torch.optim.lr_scheduler import LambdaLR -from algorithmic_efficiency import spec +from algoperf import spec from reference_algorithms.target_setting_algorithms.data_selection import \ data_selection # pylint: disable=unused-import from reference_algorithms.target_setting_algorithms.get_batch_size import \ diff --git a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py index 2e2876555..bbfd8b0f2 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py @@ -6,8 +6,8 @@ import torch import torch.distributed.nn as dist_nn -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup USE_PYTORCH_DDP = pytorch_setup()[0] diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py index f4f2d5679..615ac6fe1 100644 --- a/scoring/performance_profile.py +++ b/scoring/performance_profile.py @@ -38,14 +38,14 @@ import pandas as pd from tabulate import tabulate -from algorithmic_efficiency.workloads.workloads import get_base_workload_name -import algorithmic_efficiency.workloads.workloads as workloads_registry +from algoperf.workloads.workloads import get_base_workload_name +import algoperf.workloads.workloads as workloads_registry from scoring import scoring_utils WORKLOADS = workloads_registry.WORKLOADS BASE_WORKLOADS = workloads_registry.BASE_WORKLOADS WORKLOAD_NAME_PATTERN = '(.*)(_jax|_pytorch)' -BASE_WORKLOADS_DIR = 'algorithmic_efficiency/workloads/' +BASE_WORKLOADS_DIR = 'algoperf/workloads/' # Open json file to read heldout workloads # TODO: This probably shouldn't be hardcoded but passed as an argument. with open("held_out_workloads_algoperf_v05.json", "r") as f: diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index e474b6910..99c6e810e 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -20,8 +20,8 @@ from absl import flags from absl import logging -from algorithmic_efficiency import random_utils as prng -from algorithmic_efficiency.workloads.workloads import get_base_workload_name +from algoperf import random_utils as prng +from algoperf.workloads.workloads import get_base_workload_name import docker flags.DEFINE_string( diff --git a/scoring/scoring_utils.py b/scoring/scoring_utils.py index 0dd997ab9..ac513816e 100644 --- a/scoring/scoring_utils.py +++ b/scoring/scoring_utils.py @@ -7,7 +7,7 @@ from absl import logging import pandas as pd -import algorithmic_efficiency.workloads.workloads as workloads_registry +import algoperf.workloads.workloads as workloads_registry TRIAL_LINE_REGEX = '(.*) --- Tuning run (\d+)/(\d+) ---' METRICS_LINE_REGEX = '(.*) Metrics: ({.*})' @@ -17,7 +17,7 @@ WORKLOADS = workloads_registry.WORKLOADS WORKLOAD_NAME_PATTERN = '(.*)(_jax|_pytorch)' -BASE_WORKLOADS_DIR = 'algorithmic_efficiency/workloads/' +BASE_WORKLOADS_DIR = 'algoperf/workloads/' #### File IO helper functions ### diff --git a/submission_runner.py b/submission_runner.py index 9f9b8ff42..1be56aeab 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -40,17 +40,17 @@ # it unavailable to JAX. tf.config.set_visible_devices([], 'GPU') -from algorithmic_efficiency import checkpoint_utils -from algorithmic_efficiency import halton -from algorithmic_efficiency import logger_utils -from algorithmic_efficiency import random_utils as prng -from algorithmic_efficiency import spec -from algorithmic_efficiency.profiler import PassThroughProfiler -from algorithmic_efficiency.profiler import Profiler -from algorithmic_efficiency.pytorch_utils import pytorch_init -from algorithmic_efficiency.pytorch_utils import pytorch_setup -from algorithmic_efficiency.pytorch_utils import sync_ddp_time -from algorithmic_efficiency.workloads import workloads +from algoperf import checkpoint_utils +from algoperf import halton +from algoperf import logger_utils +from algoperf import random_utils as prng +from algoperf import spec +from algoperf.profiler import PassThroughProfiler +from algoperf.profiler import Profiler +from algoperf.pytorch_utils import pytorch_init +from algoperf.pytorch_utils import pytorch_setup +from algoperf.pytorch_utils import sync_ddp_time +from algoperf.workloads import workloads # disable only for deepspeech if it works fine for other workloads. os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' diff --git a/submissions/template/submission.py b/submissions/template/submission.py index 20991ab66..a4fdc62b4 100644 --- a/submissions/template/submission.py +++ b/submissions/template/submission.py @@ -6,7 +6,7 @@ """ from typing import Any, Dict, Iterator, List, Optional, Tuple -from algorithmic_efficiency import spec +from algoperf import spec def init_optimizer_state(workload: spec.Workload, diff --git a/tests/modeldiffs/criteo1tb/compare.py b/tests/modeldiffs/criteo1tb/compare.py index adbade983..d280803af 100644 --- a/tests/modeldiffs/criteo1tb/compare.py +++ b/tests/modeldiffs/criteo1tb/compare.py @@ -7,10 +7,10 @@ import numpy as np import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import \ +from algoperf import spec +from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import \ Criteo1TbDlrmSmallWorkload as JaxWorkload -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import \ +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \ Criteo1TbDlrmSmallWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/criteo1tb_embed_init/compare.py b/tests/modeldiffs/criteo1tb_embed_init/compare.py index 0748e2d71..73744c667 100644 --- a/tests/modeldiffs/criteo1tb_embed_init/compare.py +++ b/tests/modeldiffs/criteo1tb_embed_init/compare.py @@ -7,10 +7,10 @@ import numpy as np import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import \ +from algoperf import spec +from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import \ Criteo1TbDlrmSmallEmbedInitWorkload as JaxWorkload -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import \ +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \ Criteo1TbDlrmSmallEmbedInitWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/criteo1tb_layernorm/compare.py b/tests/modeldiffs/criteo1tb_layernorm/compare.py index 0a6e5c5ac..96e3cc5cc 100644 --- a/tests/modeldiffs/criteo1tb_layernorm/compare.py +++ b/tests/modeldiffs/criteo1tb_layernorm/compare.py @@ -7,10 +7,10 @@ import numpy as np import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import \ +from algoperf import spec +from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import \ Criteo1TbDlrmSmallLayerNormWorkload as JaxWorkload -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import \ +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \ Criteo1TbDlrmSmallLayerNormWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py index 288442594..188e4cac3 100644 --- a/tests/modeldiffs/criteo1tb_resnet/compare.py +++ b/tests/modeldiffs/criteo1tb_resnet/compare.py @@ -8,10 +8,10 @@ import numpy as np import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import \ +from algoperf import spec +from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import \ Criteo1TbDlrmSmallResNetWorkload as JaxWorkload -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import \ +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \ Criteo1TbDlrmSmallResNetWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/fastmri/compare.py b/tests/modeldiffs/fastmri/compare.py index 56b74b32d..da5f0ba0a 100644 --- a/tests/modeldiffs/fastmri/compare.py +++ b/tests/modeldiffs/fastmri/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import \ +from algoperf import spec +from algoperf.workloads.fastmri.fastmri_jax.workload import \ FastMRIWorkload as JaxWorkload -from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import \ +from algoperf.workloads.fastmri.fastmri_pytorch.workload import \ FastMRIWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/fastmri_layernorm/compare.py b/tests/modeldiffs/fastmri_layernorm/compare.py index 23ccf26d7..5f1eb1842 100644 --- a/tests/modeldiffs/fastmri_layernorm/compare.py +++ b/tests/modeldiffs/fastmri_layernorm/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import \ +from algoperf import spec +from algoperf.workloads.fastmri.fastmri_jax.workload import \ FastMRILayerNormWorkload as JaxWorkload -from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import \ +from algoperf.workloads.fastmri.fastmri_pytorch.workload import \ FastMRILayerNormWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/fastmri_model_size/compare.py b/tests/modeldiffs/fastmri_model_size/compare.py index b61516c29..ebb8669f8 100644 --- a/tests/modeldiffs/fastmri_model_size/compare.py +++ b/tests/modeldiffs/fastmri_model_size/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import \ +from algoperf import spec +from algoperf.workloads.fastmri.fastmri_jax.workload import \ FastMRIModelSizeWorkload as JaxWorkload -from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import \ +from algoperf.workloads.fastmri.fastmri_pytorch.workload import \ FastMRIModelSizeWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/fastmri_tanh/compare.py b/tests/modeldiffs/fastmri_tanh/compare.py index 0f455387c..558bc2ba1 100644 --- a/tests/modeldiffs/fastmri_tanh/compare.py +++ b/tests/modeldiffs/fastmri_tanh/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import \ +from algoperf import spec +from algoperf.workloads.fastmri.fastmri_jax.workload import \ FastMRITanhWorkload as JaxWorkload -from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import \ +from algoperf.workloads.fastmri.fastmri_pytorch.workload import \ FastMRITanhWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/imagenet_resnet/compare.py b/tests/modeldiffs/imagenet_resnet/compare.py index fb730f1bf..0a6a1b7c5 100644 --- a/tests/modeldiffs/imagenet_resnet/compare.py +++ b/tests/modeldiffs/imagenet_resnet/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.workload import \ +from algoperf import spec +from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \ ImagenetResNetWorkload as JaxWorkload -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.workload import \ +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import \ ImagenetResNetWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/imagenet_resnet/gelu_compare.py b/tests/modeldiffs/imagenet_resnet/gelu_compare.py index 6c8adbec2..4f20873b7 100644 --- a/tests/modeldiffs/imagenet_resnet/gelu_compare.py +++ b/tests/modeldiffs/imagenet_resnet/gelu_compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.workload import \ +from algoperf import spec +from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \ ImagenetResNetGELUWorkload as JaxWorkload -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.workload import \ +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import \ ImagenetResNetGELUWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff from tests.modeldiffs.imagenet_resnet.compare import key_transform diff --git a/tests/modeldiffs/imagenet_resnet/silu_compare.py b/tests/modeldiffs/imagenet_resnet/silu_compare.py index 7668cdbd9..e94fdcd4c 100644 --- a/tests/modeldiffs/imagenet_resnet/silu_compare.py +++ b/tests/modeldiffs/imagenet_resnet/silu_compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.workload import \ +from algoperf import spec +from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \ ImagenetResNetSiLUWorkload as JaxWorkload -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.workload import \ +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import \ ImagenetResNetSiLUWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff from tests.modeldiffs.imagenet_resnet.compare import key_transform diff --git a/tests/modeldiffs/imagenet_vit/compare.py b/tests/modeldiffs/imagenet_vit/compare.py index ba21e63da..b7b9af794 100644 --- a/tests/modeldiffs/imagenet_vit/compare.py +++ b/tests/modeldiffs/imagenet_vit/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ +from algoperf import spec +from algoperf.workloads.imagenet_vit.imagenet_jax.workload import \ ImagenetVitWorkload as JaxWorkload -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ +from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import \ ImagenetVitWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/imagenet_vit_glu/compare.py b/tests/modeldiffs/imagenet_vit_glu/compare.py index 2c0aa546d..11edcd84e 100644 --- a/tests/modeldiffs/imagenet_vit_glu/compare.py +++ b/tests/modeldiffs/imagenet_vit_glu/compare.py @@ -9,10 +9,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ +from algoperf import spec +from algoperf.workloads.imagenet_vit.imagenet_jax.workload import \ ImagenetVitGluWorkload as JaxWorkload -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ +from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import \ ImagenetVitGluWorkload as PyTorchWorkload sd_transform = None diff --git a/tests/modeldiffs/imagenet_vit_map/compare.py b/tests/modeldiffs/imagenet_vit_map/compare.py index e7c4c2ee8..70bcd2e04 100644 --- a/tests/modeldiffs/imagenet_vit_map/compare.py +++ b/tests/modeldiffs/imagenet_vit_map/compare.py @@ -9,10 +9,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ +from algoperf import spec +from algoperf.workloads.imagenet_vit.imagenet_jax.workload import \ ImagenetVitMapWorkload as JaxWorkload -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ +from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import \ ImagenetVitMapWorkload as PytWorkload diff --git a/tests/modeldiffs/imagenet_vit_postln/compare.py b/tests/modeldiffs/imagenet_vit_postln/compare.py index 8a9063cac..113a65a2a 100644 --- a/tests/modeldiffs/imagenet_vit_postln/compare.py +++ b/tests/modeldiffs/imagenet_vit_postln/compare.py @@ -9,10 +9,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import \ +from algoperf import spec +from algoperf.workloads.imagenet_vit.imagenet_jax.workload import \ ImagenetVitPostLNWorkload as JaxWorkload -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import \ +from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import \ ImagenetViTPostLNWorkload as PyTorchWorkload sd_transform = None diff --git a/tests/modeldiffs/librispeech_conformer/compare.py b/tests/modeldiffs/librispeech_conformer/compare.py index cfe6c7381..5bfbf915a 100644 --- a/tests/modeldiffs/librispeech_conformer/compare.py +++ b/tests/modeldiffs/librispeech_conformer/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import \ +from algoperf import spec +from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \ LibriSpeechConformerWorkload as JaxWorkload -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import \ +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \ LibriSpeechConformerWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py b/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py index 8480fca02..bb9a8fae1 100644 --- a/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py +++ b/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import \ +from algoperf import spec +from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \ LibriSpeechConformerAttentionTemperatureWorkload as JaxWorkload -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import \ +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \ LibriSpeechConformerAttentionTemperatureWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/librispeech_conformer_gelu/compare.py b/tests/modeldiffs/librispeech_conformer_gelu/compare.py index caa9b09b9..629418488 100644 --- a/tests/modeldiffs/librispeech_conformer_gelu/compare.py +++ b/tests/modeldiffs/librispeech_conformer_gelu/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import \ +from algoperf import spec +from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \ LibriSpeechConformerGeluWorkload as JaxWorkload -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import \ +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \ LibriSpeechConformerGeluWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/librispeech_conformer_layernorm/compare.py b/tests/modeldiffs/librispeech_conformer_layernorm/compare.py index 1a94d3c77..48fe991f7 100644 --- a/tests/modeldiffs/librispeech_conformer_layernorm/compare.py +++ b/tests/modeldiffs/librispeech_conformer_layernorm/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import \ +from algoperf import spec +from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \ LibriSpeechConformerLayerNormWorkload as JaxWorkload -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import \ +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \ LibriSpeechConformerLayerNormWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/librispeech_deepspeech/compare.py b/tests/modeldiffs/librispeech_deepspeech/compare.py index edcc3ba87..81e12b15d 100644 --- a/tests/modeldiffs/librispeech_deepspeech/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_jax.workload import \ +from algoperf import spec +from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import \ LibriSpeechDeepSpeechWorkload as JaxWorkload -from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ LibriSpeechDeepSpeechWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py b/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py index 6c00bdf69..ea106ebe4 100644 --- a/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_jax.workload import \ +from algoperf import spec +from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import \ LibriSpeechDeepSpeechTanhWorkload as JaxWorkload -from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ LibriSpeechDeepSpeechTanhWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff from tests.modeldiffs.librispeech_deepspeech.compare import key_transform diff --git a/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py b/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py index c68d6adf9..ecb6d28af 100644 --- a/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_jax.workload import \ +from algoperf import spec +from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import \ LibriSpeechDeepSpeechNormAndSpecAugWorkload as JaxWorkload -from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ LibriSpeechDeepSpeechNormAndSpecAugWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff from tests.modeldiffs.librispeech_deepspeech.compare import key_transform diff --git a/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py b/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py index 4cfdf4f21..31d9029b4 100644 --- a/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py +++ b/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_jax.workload import \ +from algoperf import spec +from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import \ LibriSpeechDeepSpeechNoResNetWorkload as JaxWorkload -from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \ LibriSpeechDeepSpeechNoResNetWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff from tests.modeldiffs.librispeech_deepspeech.compare import key_transform diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py index 56316ba12..43ca48764 100644 --- a/tests/modeldiffs/ogbg/compare.py +++ b/tests/modeldiffs/ogbg/compare.py @@ -8,10 +8,10 @@ import numpy as np import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.ogbg.ogbg_jax.workload import \ +from algoperf import spec +from algoperf.workloads.ogbg.ogbg_jax.workload import \ OgbgWorkload as JaxWorkload -from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.workload import \ +from algoperf.workloads.ogbg.ogbg_pytorch.workload import \ OgbgWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/ogbg_gelu/compare.py b/tests/modeldiffs/ogbg_gelu/compare.py index b58bcd461..062588fe2 100644 --- a/tests/modeldiffs/ogbg_gelu/compare.py +++ b/tests/modeldiffs/ogbg_gelu/compare.py @@ -8,10 +8,10 @@ import numpy as np import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.ogbg.ogbg_jax.workload import \ +from algoperf import spec +from algoperf.workloads.ogbg.ogbg_jax.workload import \ OgbgGeluWorkload as JaxWorkload -from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.workload import \ +from algoperf.workloads.ogbg.ogbg_pytorch.workload import \ OgbgGeluWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/ogbg_model_size/compare.py b/tests/modeldiffs/ogbg_model_size/compare.py index 62443bbb5..2eb70d097 100644 --- a/tests/modeldiffs/ogbg_model_size/compare.py +++ b/tests/modeldiffs/ogbg_model_size/compare.py @@ -8,10 +8,10 @@ import numpy as np import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.ogbg.ogbg_jax.workload import \ +from algoperf import spec +from algoperf.workloads.ogbg.ogbg_jax.workload import \ OgbgModelSizeWorkload as JaxWorkload -from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.workload import \ +from algoperf.workloads.ogbg.ogbg_pytorch.workload import \ OgbgModelSizeWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/ogbg_silu/compare.py b/tests/modeldiffs/ogbg_silu/compare.py index 2922b7046..19e446030 100644 --- a/tests/modeldiffs/ogbg_silu/compare.py +++ b/tests/modeldiffs/ogbg_silu/compare.py @@ -8,10 +8,10 @@ import numpy as np import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.ogbg.ogbg_jax.workload import \ +from algoperf import spec +from algoperf.workloads.ogbg.ogbg_jax.workload import \ OgbgSiluWorkload as JaxWorkload -from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.workload import \ +from algoperf.workloads.ogbg.ogbg_pytorch.workload import \ OgbgSiluWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/vanilla_sgd_jax.py b/tests/modeldiffs/vanilla_sgd_jax.py index d45694bcb..62b98bd17 100644 --- a/tests/modeldiffs/vanilla_sgd_jax.py +++ b/tests/modeldiffs/vanilla_sgd_jax.py @@ -3,7 +3,7 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec +from algoperf import spec from reference_algorithms.target_setting_algorithms.data_selection import \ data_selection # pylint: disable=unused-import from reference_algorithms.target_setting_algorithms.jax_submission_base import \ diff --git a/tests/modeldiffs/vanilla_sgd_pytorch.py b/tests/modeldiffs/vanilla_sgd_pytorch.py index 254ef6018..a6a0c5fa6 100644 --- a/tests/modeldiffs/vanilla_sgd_pytorch.py +++ b/tests/modeldiffs/vanilla_sgd_pytorch.py @@ -1,6 +1,6 @@ import torch -from algorithmic_efficiency import spec +from algoperf import spec from reference_algorithms.target_setting_algorithms.data_selection import \ data_selection # pylint: disable=unused-import from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \ diff --git a/tests/modeldiffs/wmt/compare.py b/tests/modeldiffs/wmt/compare.py index 41fc5ee17..73bc03f78 100644 --- a/tests/modeldiffs/wmt/compare.py +++ b/tests/modeldiffs/wmt/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.wmt.wmt_jax.workload import \ +from algoperf import spec +from algoperf.workloads.wmt.wmt_jax.workload import \ WmtWorkload as JaxWorkload -from algorithmic_efficiency.workloads.wmt.wmt_pytorch.workload import \ +from algoperf.workloads.wmt.wmt_pytorch.workload import \ WmtWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/wmt_attention_temp/compare.py b/tests/modeldiffs/wmt_attention_temp/compare.py index 92ce4eb44..01dc2895c 100644 --- a/tests/modeldiffs/wmt_attention_temp/compare.py +++ b/tests/modeldiffs/wmt_attention_temp/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.wmt.wmt_jax.workload import \ +from algoperf import spec +from algoperf.workloads.wmt.wmt_jax.workload import \ WmtWorkloadAttentionTemp as JaxWorkload -from algorithmic_efficiency.workloads.wmt.wmt_pytorch.workload import \ +from algoperf.workloads.wmt.wmt_pytorch.workload import \ WmtWorkloadAttentionTemp as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/wmt_glu_tanh/compare.py b/tests/modeldiffs/wmt_glu_tanh/compare.py index b8d860479..77e71c826 100644 --- a/tests/modeldiffs/wmt_glu_tanh/compare.py +++ b/tests/modeldiffs/wmt_glu_tanh/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.wmt.wmt_jax.workload import \ +from algoperf import spec +from algoperf.workloads.wmt.wmt_jax.workload import \ WmtWorkloadGLUTanH as JaxWorkload -from algorithmic_efficiency.workloads.wmt.wmt_pytorch.workload import \ +from algoperf.workloads.wmt.wmt_pytorch.workload import \ WmtWorkloadGLUTanH as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/modeldiffs/wmt_post_ln/compare.py b/tests/modeldiffs/wmt_post_ln/compare.py index 3f5469d8d..909fcd672 100644 --- a/tests/modeldiffs/wmt_post_ln/compare.py +++ b/tests/modeldiffs/wmt_post_ln/compare.py @@ -6,10 +6,10 @@ import jax import torch -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.wmt.wmt_jax.workload import \ +from algoperf import spec +from algoperf.workloads.wmt.wmt_jax.workload import \ WmtWorkloadPostLN as JaxWorkload -from algorithmic_efficiency.workloads.wmt.wmt_pytorch.workload import \ +from algoperf.workloads.wmt.wmt_pytorch.workload import \ WmtWorkloadPostLN as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py index f107be8d7..3f279a605 100644 --- a/tests/reference_algorithm_tests.py +++ b/tests/reference_algorithm_tests.py @@ -40,14 +40,14 @@ import torch import torch.distributed as dist -from algorithmic_efficiency import halton -from algorithmic_efficiency import pytorch_utils -from algorithmic_efficiency import random_utils as prng -from algorithmic_efficiency.profiler import PassThroughProfiler -from algorithmic_efficiency.workloads import workloads -from algorithmic_efficiency.workloads.ogbg import \ +from algoperf import halton +from algoperf import pytorch_utils +from algoperf import random_utils as prng +from algoperf.profiler import PassThroughProfiler +from algoperf.workloads import workloads +from algoperf.workloads.ogbg import \ input_pipeline as ogbg_input_pipeline -from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.workload import \ +from algoperf.workloads.ogbg.ogbg_pytorch.workload import \ _graph_map import submission_runner from tests.modeldiffs import diff as diff_utils diff --git a/tests/submission_runner_test.py b/tests/submission_runner_test.py index cc98e603e..ff724b201 100644 --- a/tests/submission_runner_test.py +++ b/tests/submission_runner_test.py @@ -13,7 +13,7 @@ from absl.testing import absltest from absl.testing import parameterized -from algorithmic_efficiency.profiler import PassThroughProfiler +from algoperf.profiler import PassThroughProfiler import submission_runner FLAGS = flags.FLAGS diff --git a/tests/test_baselines.py b/tests/test_baselines.py index f79e629e7..b2be8aa11 100644 --- a/tests/test_baselines.py +++ b/tests/test_baselines.py @@ -12,8 +12,8 @@ from absl.testing import absltest from absl.testing import parameterized -from algorithmic_efficiency.profiler import PassThroughProfiler -from algorithmic_efficiency.workloads import workloads +from algoperf.profiler import PassThroughProfiler +from algoperf.workloads import workloads import submission_runner FLAGS = flags.FLAGS diff --git a/tests/test_num_params.py b/tests/test_num_params.py index 574fd0aa5..83a23c9a4 100644 --- a/tests/test_num_params.py +++ b/tests/test_num_params.py @@ -5,42 +5,42 @@ import pytest import torch -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.models import \ +from algoperf.workloads.criteo1tb.criteo1tb_jax.models import \ DlrmSmall as JaxDlrmSmall -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.models import \ +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models import \ DlrmSmall as PyTorchDlrmSmall -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.models import \ +from algoperf.workloads.imagenet_resnet.imagenet_jax.models import \ ResNet18 as JaxResNet_c10 -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.models import \ +from algoperf.workloads.imagenet_resnet.imagenet_jax.models import \ ResNet50 as JaxResNet -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.models import \ +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \ resnet18 as PyTorchResNet_c10 -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.models import \ +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \ resnet50 as PyTorchResNet -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.models import \ +from algoperf.workloads.imagenet_vit.imagenet_jax.models import \ ViT as JaxViT -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.models import \ +from algoperf.workloads.imagenet_vit.imagenet_pytorch.models import \ ViT as PyTorchViT -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.models import \ +from algoperf.workloads.librispeech_conformer.librispeech_jax.models import \ Conformer as JaxConformer -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.models import \ +from algoperf.workloads.librispeech_conformer.librispeech_jax.models import \ ConformerConfig as JaxConformerConfig -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.models import \ +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \ ConformerConfig as PytorchConformerConfig -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.models import \ +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \ ConformerEncoderDecoder as PytorchConformer -from algorithmic_efficiency.workloads.mnist.mnist_jax.workload import \ +from algoperf.workloads.mnist.mnist_jax.workload import \ _Model as JaxMLP -from algorithmic_efficiency.workloads.mnist.mnist_pytorch.workload import \ +from algoperf.workloads.mnist.mnist_pytorch.workload import \ _Model as PyTorchMLP -from algorithmic_efficiency.workloads.ogbg.ogbg_jax.models import GNN as JaxGNN -from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.models import \ +from algoperf.workloads.ogbg.ogbg_jax.models import GNN as JaxGNN +from algoperf.workloads.ogbg.ogbg_pytorch.models import \ GNN as PyTorchGNN -from algorithmic_efficiency.workloads.wmt.wmt_jax.models import \ +from algoperf.workloads.wmt.wmt_jax.models import \ Transformer as JaxTransformer -from algorithmic_efficiency.workloads.wmt.wmt_jax.models import \ +from algoperf.workloads.wmt.wmt_jax.models import \ TransformerConfig -from algorithmic_efficiency.workloads.wmt.wmt_pytorch.models import \ +from algoperf.workloads.wmt.wmt_pytorch.models import \ Transformer as PyTorchTransformer WORKLOADS = [ diff --git a/tests/test_param_shapes.py b/tests/test_param_shapes.py index b67625213..96a7bace5 100644 --- a/tests/test_param_shapes.py +++ b/tests/test_param_shapes.py @@ -6,26 +6,26 @@ # isort: skip_file # pylint:disable=line-too-long -from algorithmic_efficiency.workloads.cifar.cifar_jax.workload import CifarWorkload as JaxCifarWorkload -from algorithmic_efficiency.workloads.cifar.cifar_pytorch.workload import CifarWorkload as PyTorchCifarWorkload -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import Criteo1TbDlrmSmallWorkload as JaxCriteoWorkload -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import Criteo1TbDlrmSmallWorkload as PyTorchCriteoWorkload -from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import FastMRIWorkload as JaxFastMRIWorkload -from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import FastMRIWorkload as PyTorchFastMRIWorkload -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.workload import ImagenetResNetWorkload as JaxImagenetResNetWorkload -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.workload import ImagenetResNetWorkload as PyTorchImagenetResNetWorkload -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import ImagenetVitWorkload as JaxImagenetViTWorkload -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import ImagenetVitWorkload as PyTorchImagenetViTWorkload -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import LibriSpeechConformerWorkload as JaxLibriSpeechConformerWorkload -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import LibriSpeechConformerWorkload as PytorchLibriSpeechConformerWorkload -from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_jax.workload import LibriSpeechDeepSpeechWorkload as JaxLibriSpeechDeepSpeechWorkload -from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.workload import LibriSpeechDeepSpeechWorkload as PytorchLibriSpeechDeepSpeechWorkload -from algorithmic_efficiency.workloads.mnist.mnist_jax.workload import MnistWorkload as JaxMnistWorkload -from algorithmic_efficiency.workloads.mnist.mnist_pytorch.workload import MnistWorkload as PyTorchMnistWorkload -from algorithmic_efficiency.workloads.ogbg.ogbg_jax.workload import OgbgWorkload as JaxOgbgWorkload -from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.workload import OgbgWorkload as PyTorchOgbgWorkload -from algorithmic_efficiency.workloads.wmt.wmt_jax.workload import WmtWorkload as JaxWmtWorkload -from algorithmic_efficiency.workloads.wmt.wmt_pytorch.workload import WmtWorkload as PyTorchWmtWorkload +from algoperf.workloads.cifar.cifar_jax.workload import CifarWorkload as JaxCifarWorkload +from algoperf.workloads.cifar.cifar_pytorch.workload import CifarWorkload as PyTorchCifarWorkload +from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import Criteo1TbDlrmSmallWorkload as JaxCriteoWorkload +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import Criteo1TbDlrmSmallWorkload as PyTorchCriteoWorkload +from algoperf.workloads.fastmri.fastmri_jax.workload import FastMRIWorkload as JaxFastMRIWorkload +from algoperf.workloads.fastmri.fastmri_pytorch.workload import FastMRIWorkload as PyTorchFastMRIWorkload +from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import ImagenetResNetWorkload as JaxImagenetResNetWorkload +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import ImagenetResNetWorkload as PyTorchImagenetResNetWorkload +from algoperf.workloads.imagenet_vit.imagenet_jax.workload import ImagenetVitWorkload as JaxImagenetViTWorkload +from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import ImagenetVitWorkload as PyTorchImagenetViTWorkload +from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import LibriSpeechConformerWorkload as JaxLibriSpeechConformerWorkload +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import LibriSpeechConformerWorkload as PytorchLibriSpeechConformerWorkload +from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import LibriSpeechDeepSpeechWorkload as JaxLibriSpeechDeepSpeechWorkload +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import LibriSpeechDeepSpeechWorkload as PytorchLibriSpeechDeepSpeechWorkload +from algoperf.workloads.mnist.mnist_jax.workload import MnistWorkload as JaxMnistWorkload +from algoperf.workloads.mnist.mnist_pytorch.workload import MnistWorkload as PyTorchMnistWorkload +from algoperf.workloads.ogbg.ogbg_jax.workload import OgbgWorkload as JaxOgbgWorkload +from algoperf.workloads.ogbg.ogbg_pytorch.workload import OgbgWorkload as PyTorchOgbgWorkload +from algoperf.workloads.wmt.wmt_jax.workload import WmtWorkload as JaxWmtWorkload +from algoperf.workloads.wmt.wmt_pytorch.workload import WmtWorkload as PyTorchWmtWorkload # pylint:enable=line-too-long WORKLOADS = [ diff --git a/tests/test_param_types.py b/tests/test_param_types.py index 7cf8f63c3..d3722ae86 100644 --- a/tests/test_param_types.py +++ b/tests/test_param_types.py @@ -2,30 +2,30 @@ import pytest from absl import logging -from algorithmic_efficiency import spec +from algoperf import spec # isort: skip_file # pylint:disable=line-too-long -from algorithmic_efficiency.workloads.cifar.cifar_jax.workload import CifarWorkload as JaxCifarWorkload -from algorithmic_efficiency.workloads.cifar.cifar_pytorch.workload import CifarWorkload as PyTorchCifarWorkload -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.workload import Criteo1TbDlrmSmallWorkload as JaxCriteoWorkload -from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.workload import Criteo1TbDlrmSmallWorkload as PyTorchCriteoWorkload -from algorithmic_efficiency.workloads.fastmri.fastmri_jax.workload import FastMRIWorkload as JaxFastMRIWorkload -from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.workload import FastMRIWorkload as PyTorchFastMRIWorkload -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.workload import ImagenetResNetWorkload as JaxImagenetResNetWorkload -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.workload import ImagenetResNetWorkload as PyTorchImagenetResNetWorkload -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.workload import ImagenetVitWorkload as JaxImagenetViTWorkload -from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.workload import ImagenetVitWorkload as PyTorchImagenetViTWorkload -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import LibriSpeechConformerWorkload as JaxLibriSpeechConformerWorkload -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import LibriSpeechConformerWorkload as PytorchLibriSpeechConformerWorkload -from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_jax.workload import LibriSpeechDeepSpeechWorkload as JaxLibriSpeechDeepSpeechWorkload -from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.workload import LibriSpeechDeepSpeechWorkload as PytorchLibriSpeechDeepSpeechWorkload -from algorithmic_efficiency.workloads.mnist.mnist_jax.workload import MnistWorkload as JaxMnistWorkload -from algorithmic_efficiency.workloads.mnist.mnist_pytorch.workload import MnistWorkload as PyTorchMnistWorkload -from algorithmic_efficiency.workloads.ogbg.ogbg_jax.workload import OgbgWorkload as JaxOgbgWorkload -from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.workload import OgbgWorkload as PyTorchOgbgWorkload -from algorithmic_efficiency.workloads.wmt.wmt_jax.workload import WmtWorkload as JaxWmtWorkload -from algorithmic_efficiency.workloads.wmt.wmt_pytorch.workload import WmtWorkload as PyTorchWmtWorkload +from algoperf.workloads.cifar.cifar_jax.workload import CifarWorkload as JaxCifarWorkload +from algoperf.workloads.cifar.cifar_pytorch.workload import CifarWorkload as PyTorchCifarWorkload +from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import Criteo1TbDlrmSmallWorkload as JaxCriteoWorkload +from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import Criteo1TbDlrmSmallWorkload as PyTorchCriteoWorkload +from algoperf.workloads.fastmri.fastmri_jax.workload import FastMRIWorkload as JaxFastMRIWorkload +from algoperf.workloads.fastmri.fastmri_pytorch.workload import FastMRIWorkload as PyTorchFastMRIWorkload +from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import ImagenetResNetWorkload as JaxImagenetResNetWorkload +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import ImagenetResNetWorkload as PyTorchImagenetResNetWorkload +from algoperf.workloads.imagenet_vit.imagenet_jax.workload import ImagenetVitWorkload as JaxImagenetViTWorkload +from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import ImagenetVitWorkload as PyTorchImagenetViTWorkload +from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import LibriSpeechConformerWorkload as JaxLibriSpeechConformerWorkload +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import LibriSpeechConformerWorkload as PytorchLibriSpeechConformerWorkload +from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import LibriSpeechDeepSpeechWorkload as JaxLibriSpeechDeepSpeechWorkload +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import LibriSpeechDeepSpeechWorkload as PytorchLibriSpeechDeepSpeechWorkload +from algoperf.workloads.mnist.mnist_jax.workload import MnistWorkload as JaxMnistWorkload +from algoperf.workloads.mnist.mnist_pytorch.workload import MnistWorkload as PyTorchMnistWorkload +from algoperf.workloads.ogbg.ogbg_jax.workload import OgbgWorkload as JaxOgbgWorkload +from algoperf.workloads.ogbg.ogbg_pytorch.workload import OgbgWorkload as PyTorchOgbgWorkload +from algoperf.workloads.wmt.wmt_jax.workload import WmtWorkload as JaxWmtWorkload +from algoperf.workloads.wmt.wmt_pytorch.workload import WmtWorkload as PyTorchWmtWorkload # pylint:enable=line-too-long WORKLOADS = [ diff --git a/tests/test_ssim.py b/tests/test_ssim.py index fadf41f64..ba0b2ca7f 100644 --- a/tests/test_ssim.py +++ b/tests/test_ssim.py @@ -9,14 +9,14 @@ import numpy as np import torch -from algorithmic_efficiency.pytorch_utils import pytorch_setup -from algorithmic_efficiency.workloads.fastmri.fastmri_jax.ssim import \ +from algoperf.pytorch_utils import pytorch_setup +from algoperf.workloads.fastmri.fastmri_jax.ssim import \ _uniform_filter as _jax_uniform_filter -from algorithmic_efficiency.workloads.fastmri.fastmri_jax.ssim import \ +from algoperf.workloads.fastmri.fastmri_jax.ssim import \ ssim as jax_ssim -from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.ssim import \ +from algoperf.workloads.fastmri.fastmri_pytorch.ssim import \ _uniform_filter as _pytorch_uniform_filter -from algorithmic_efficiency.workloads.fastmri.fastmri_pytorch.ssim import \ +from algoperf.workloads.fastmri.fastmri_pytorch.ssim import \ ssim as pytorch_ssim # Make sure no GPU memory is preallocated to Jax. diff --git a/tests/test_version.py b/tests/test_version.py index ef01d4f32..d1bfbd18f 100644 --- a/tests/test_version.py +++ b/tests/test_version.py @@ -1,13 +1,13 @@ """Check whether the __version__ attribute is set correctly.""" -import algorithmic_efficiency +import algoperf def test_version_attribute(): """Check whether __version__ exists and is a valid string.""" - assert hasattr(algorithmic_efficiency, "__version__") - version = algorithmic_efficiency.__version__ + assert hasattr(algoperf, "__version__") + version = algoperf.__version__ assert isinstance(version, str) version_elements = version.split(".") print(version_elements) diff --git a/tests/version_test.py b/tests/version_test.py new file mode 100644 index 000000000..2205b305f --- /dev/null +++ b/tests/version_test.py @@ -0,0 +1,16 @@ +"""Check whether the __version__ attribute is set correctly.""" + +import algoperf + + +def test_version_attribute(): + """Check whether __version__ exists and is a valid string.""" + + assert hasattr(algoperf, "__version__") + version = algoperf.__version__ + assert isinstance(version, str) + version_elements = version.split(".") + print(version_elements) + # Only check the first three elements, i.e. major, minor, patch. + # The remaining elements contain commit hash and dirty status. + assert all(el.isnumeric() for el in version_elements[0:3]) diff --git a/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py b/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py index 6a85c2196..66b1dbc6a 100644 --- a/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py +++ b/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py @@ -4,8 +4,8 @@ import jax import jax.numpy as jnp -from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.workload import \ +from algoperf import spec +from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \ ImagenetResNetWorkload From bc666a7420f01a15bd8f96ab66249e6afa6ced9e Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Wed, 15 Jan 2025 16:47:18 +0100 Subject: [PATCH 071/105] Fix linting (due to shorter package name in imports) --- algoperf/workloads/cifar/cifar_jax/models.py | 3 +-- algoperf/workloads/cifar/cifar_jax/workload.py | 3 +-- algoperf/workloads/cifar/cifar_pytorch/models.py | 3 +-- .../workloads/cifar/cifar_pytorch/workload.py | 3 +-- .../workloads/fastmri/fastmri_jax/workload.py | 3 +-- .../fastmri/fastmri_pytorch/workload.py | 6 ++---- .../imagenet_jax/input_pipeline.py | 3 +-- .../imagenet_resnet/imagenet_jax/workload.py | 6 ++---- .../imagenet_resnet/imagenet_pytorch/workload.py | 6 ++---- .../workloads/imagenet_resnet/imagenet_v2.py | 3 +-- .../imagenet_vit/imagenet_jax/workload.py | 6 ++---- .../imagenet_vit/imagenet_pytorch/models.py | 3 +-- .../imagenet_vit/imagenet_pytorch/workload.py | 9 +++------ .../librispeech_jax/workload.py | 3 +-- .../librispeech_pytorch/workload.py | 3 +-- .../librispeech_jax/workload.py | 3 +-- tests/modeldiffs/wmt/compare.py | 3 +-- tests/reference_algorithm_tests.py | 6 ++---- tests/test_num_params.py | 15 +++++---------- tests/test_ssim.py | 3 +-- tests/version_test.py | 16 ---------------- 21 files changed, 31 insertions(+), 78 deletions(-) delete mode 100644 tests/version_test.py diff --git a/algoperf/workloads/cifar/cifar_jax/models.py b/algoperf/workloads/cifar/cifar_jax/models.py index 4d5df766e..957079272 100644 --- a/algoperf/workloads/cifar/cifar_jax/models.py +++ b/algoperf/workloads/cifar/cifar_jax/models.py @@ -11,8 +11,7 @@ import jax.numpy as jnp from algoperf import spec -from algoperf.workloads.imagenet_resnet.imagenet_jax.models import \ - ResNetBlock +from algoperf.workloads.imagenet_resnet.imagenet_jax.models import ResNetBlock ModuleDef = nn.Module diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py index f4bcffbc3..952bb977d 100644 --- a/algoperf/workloads/cifar/cifar_jax/workload.py +++ b/algoperf/workloads/cifar/cifar_jax/workload.py @@ -14,8 +14,7 @@ from algoperf import param_utils from algoperf import spec from algoperf.workloads.cifar.cifar_jax import models -from algoperf.workloads.cifar.cifar_jax.input_pipeline import \ - create_input_iter +from algoperf.workloads.cifar.cifar_jax.input_pipeline import create_input_iter from algoperf.workloads.cifar.workload import BaseCifarWorkload diff --git a/algoperf/workloads/cifar/cifar_pytorch/models.py b/algoperf/workloads/cifar/cifar_pytorch/models.py index 393d568b9..e6a7a8a81 100644 --- a/algoperf/workloads/cifar/cifar_pytorch/models.py +++ b/algoperf/workloads/cifar/cifar_pytorch/models.py @@ -16,8 +16,7 @@ BasicBlock from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \ Bottleneck -from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \ - conv1x1 +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import conv1x1 class ResNet(nn.Module): diff --git a/algoperf/workloads/cifar/cifar_pytorch/workload.py b/algoperf/workloads/cifar/cifar_pytorch/workload.py index 2ba92f0b9..b16d62204 100644 --- a/algoperf/workloads/cifar/cifar_pytorch/workload.py +++ b/algoperf/workloads/cifar/cifar_pytorch/workload.py @@ -16,8 +16,7 @@ from algoperf import param_utils from algoperf import pytorch_utils from algoperf import spec -from algoperf.workloads.cifar.cifar_pytorch.models import \ - resnet18 +from algoperf.workloads.cifar.cifar_pytorch.models import resnet18 from algoperf.workloads.cifar.workload import BaseCifarWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py index 393aa19d7..1156cf30a 100644 --- a/algoperf/workloads/fastmri/fastmri_jax/workload.py +++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py @@ -13,8 +13,7 @@ import algoperf.random_utils as prng from algoperf.workloads.fastmri.fastmri_jax.models import UNet from algoperf.workloads.fastmri.fastmri_jax.ssim import ssim -from algoperf.workloads.fastmri.workload import \ - BaseFastMRIWorkload +from algoperf.workloads.fastmri.workload import BaseFastMRIWorkload class FastMRIWorkload(BaseFastMRIWorkload): diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py index f40654678..58943de2f 100644 --- a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py @@ -13,11 +13,9 @@ from algoperf import pytorch_utils from algoperf import spec import algoperf.random_utils as prng -from algoperf.workloads.fastmri.fastmri_pytorch.models import \ - UNet +from algoperf.workloads.fastmri.fastmri_pytorch.models import UNet from algoperf.workloads.fastmri.fastmri_pytorch.ssim import ssim -from algoperf.workloads.fastmri.workload import \ - BaseFastMRIWorkload +from algoperf.workloads.fastmri.workload import BaseFastMRIWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py index 709a318c2..66105335b 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py @@ -14,8 +14,7 @@ from algoperf import data_utils from algoperf import spec -from algoperf.workloads.imagenet_resnet.imagenet_jax import \ - randaugment +from algoperf.workloads.imagenet_resnet.imagenet_jax import randaugment TFDS_SPLIT_NAME = { 'train': 'train', 'eval_train': 'train', 'validation': 'validation' diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py index b445e9f00..9494fd63c 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -21,10 +21,8 @@ from algoperf import random_utils as prng from algoperf import spec from algoperf.workloads.imagenet_resnet import imagenet_v2 -from algoperf.workloads.imagenet_resnet.imagenet_jax import \ - input_pipeline -from algoperf.workloads.imagenet_resnet.imagenet_jax import \ - models +from algoperf.workloads.imagenet_resnet.imagenet_jax import input_pipeline +from algoperf.workloads.imagenet_resnet.imagenet_jax import models from algoperf.workloads.imagenet_resnet.workload import \ BaseImagenetResNetWorkload diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py index 7a08f325e..92b651ba2 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -22,10 +22,8 @@ from algoperf import spec import algoperf.random_utils as prng from algoperf.workloads.imagenet_resnet import imagenet_v2 -from algoperf.workloads.imagenet_resnet.imagenet_pytorch import \ - randaugment -from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \ - resnet50 +from algoperf.workloads.imagenet_resnet.imagenet_pytorch import randaugment +from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import resnet50 from algoperf.workloads.imagenet_resnet.workload import \ BaseImagenetResNetWorkload diff --git a/algoperf/workloads/imagenet_resnet/imagenet_v2.py b/algoperf/workloads/imagenet_resnet/imagenet_v2.py index f63ddbc34..84d364586 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_v2.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_v2.py @@ -10,8 +10,7 @@ from algoperf import data_utils from algoperf import spec -from algoperf.workloads.imagenet_resnet.imagenet_jax import \ - input_pipeline +from algoperf.workloads.imagenet_resnet.imagenet_jax import input_pipeline def get_imagenet_v2_iter(data_dir: str, diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py index 2261aac6d..9a6190f5e 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py @@ -12,10 +12,8 @@ from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \ ImagenetResNetWorkload from algoperf.workloads.imagenet_vit.imagenet_jax import models -from algoperf.workloads.imagenet_vit.workload import \ - BaseImagenetVitWorkload -from algoperf.workloads.imagenet_vit.workload import \ - decode_variant +from algoperf.workloads.imagenet_vit.workload import BaseImagenetVitWorkload +from algoperf.workloads.imagenet_vit.workload import decode_variant # Make sure we inherit from the ViT base workload first. diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py index 4fac8bd35..fcf0992d3 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py @@ -14,8 +14,7 @@ from algoperf import init_utils from algoperf import spec -from algoperf.workloads.wmt.wmt_pytorch.models import \ - MultiheadAttention +from algoperf.workloads.wmt.wmt_pytorch.models import MultiheadAttention def posemb_sincos_2d(patches: spec.Tensor, temperature=10_000.) -> spec.Tensor: diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py index 20b294b47..97bb38515 100644 --- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py +++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py @@ -11,12 +11,9 @@ from algoperf import spec from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import \ ImagenetResNetWorkload -from algoperf.workloads.imagenet_vit.imagenet_pytorch import \ - models -from algoperf.workloads.imagenet_vit.workload import \ - BaseImagenetVitWorkload -from algoperf.workloads.imagenet_vit.workload import \ - decode_variant +from algoperf.workloads.imagenet_vit.imagenet_pytorch import models +from algoperf.workloads.imagenet_vit.workload import BaseImagenetVitWorkload +from algoperf.workloads.imagenet_vit.workload import decode_variant USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py index b4fdb0811..8d9872461 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -18,8 +18,7 @@ from algoperf.workloads.librispeech_conformer import workload from algoperf.workloads.librispeech_conformer.input_pipeline import \ LibriSpeechDataset -from algoperf.workloads.librispeech_conformer.librispeech_jax import \ - models +from algoperf.workloads.librispeech_conformer.librispeech_jax import models class LibriSpeechConformerWorkload(workload.BaseLibrispeechWorkload): diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 592e63989..974b3bb19 100644 --- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -19,8 +19,7 @@ from algoperf.workloads.librispeech_conformer import workload from algoperf.workloads.librispeech_conformer.input_pipeline import \ LibriSpeechDataset -from algoperf.workloads.librispeech_conformer.librispeech_pytorch import \ - models +from algoperf.workloads.librispeech_conformer.librispeech_pytorch import models USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup() diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 3e0781deb..9fd0898b4 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -10,8 +10,7 @@ from algoperf import spec from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \ LibriSpeechConformerWorkload -from algoperf.workloads.librispeech_deepspeech.librispeech_jax import \ - models +from algoperf.workloads.librispeech_deepspeech.librispeech_jax import models class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload): diff --git a/tests/modeldiffs/wmt/compare.py b/tests/modeldiffs/wmt/compare.py index 73bc03f78..64401ef7f 100644 --- a/tests/modeldiffs/wmt/compare.py +++ b/tests/modeldiffs/wmt/compare.py @@ -7,8 +7,7 @@ import torch from algoperf import spec -from algoperf.workloads.wmt.wmt_jax.workload import \ - WmtWorkload as JaxWorkload +from algoperf.workloads.wmt.wmt_jax.workload import WmtWorkload as JaxWorkload from algoperf.workloads.wmt.wmt_pytorch.workload import \ WmtWorkload as PyTorchWorkload from tests.modeldiffs.diff import out_diff diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py index 3f279a605..0a17e470c 100644 --- a/tests/reference_algorithm_tests.py +++ b/tests/reference_algorithm_tests.py @@ -45,10 +45,8 @@ from algoperf import random_utils as prng from algoperf.profiler import PassThroughProfiler from algoperf.workloads import workloads -from algoperf.workloads.ogbg import \ - input_pipeline as ogbg_input_pipeline -from algoperf.workloads.ogbg.ogbg_pytorch.workload import \ - _graph_map +from algoperf.workloads.ogbg import input_pipeline as ogbg_input_pipeline +from algoperf.workloads.ogbg.ogbg_pytorch.workload import _graph_map import submission_runner from tests.modeldiffs import diff as diff_utils diff --git a/tests/test_num_params.py b/tests/test_num_params.py index 83a23c9a4..b0633025e 100644 --- a/tests/test_num_params.py +++ b/tests/test_num_params.py @@ -17,8 +17,7 @@ resnet18 as PyTorchResNet_c10 from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \ resnet50 as PyTorchResNet -from algoperf.workloads.imagenet_vit.imagenet_jax.models import \ - ViT as JaxViT +from algoperf.workloads.imagenet_vit.imagenet_jax.models import ViT as JaxViT from algoperf.workloads.imagenet_vit.imagenet_pytorch.models import \ ViT as PyTorchViT from algoperf.workloads.librispeech_conformer.librispeech_jax.models import \ @@ -29,17 +28,13 @@ ConformerConfig as PytorchConformerConfig from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \ ConformerEncoderDecoder as PytorchConformer -from algoperf.workloads.mnist.mnist_jax.workload import \ - _Model as JaxMLP +from algoperf.workloads.mnist.mnist_jax.workload import _Model as JaxMLP from algoperf.workloads.mnist.mnist_pytorch.workload import \ _Model as PyTorchMLP from algoperf.workloads.ogbg.ogbg_jax.models import GNN as JaxGNN -from algoperf.workloads.ogbg.ogbg_pytorch.models import \ - GNN as PyTorchGNN -from algoperf.workloads.wmt.wmt_jax.models import \ - Transformer as JaxTransformer -from algoperf.workloads.wmt.wmt_jax.models import \ - TransformerConfig +from algoperf.workloads.ogbg.ogbg_pytorch.models import GNN as PyTorchGNN +from algoperf.workloads.wmt.wmt_jax.models import Transformer as JaxTransformer +from algoperf.workloads.wmt.wmt_jax.models import TransformerConfig from algoperf.workloads.wmt.wmt_pytorch.models import \ Transformer as PyTorchTransformer diff --git a/tests/test_ssim.py b/tests/test_ssim.py index ba0b2ca7f..920556964 100644 --- a/tests/test_ssim.py +++ b/tests/test_ssim.py @@ -12,8 +12,7 @@ from algoperf.pytorch_utils import pytorch_setup from algoperf.workloads.fastmri.fastmri_jax.ssim import \ _uniform_filter as _jax_uniform_filter -from algoperf.workloads.fastmri.fastmri_jax.ssim import \ - ssim as jax_ssim +from algoperf.workloads.fastmri.fastmri_jax.ssim import ssim as jax_ssim from algoperf.workloads.fastmri.fastmri_pytorch.ssim import \ _uniform_filter as _pytorch_uniform_filter from algoperf.workloads.fastmri.fastmri_pytorch.ssim import \ diff --git a/tests/version_test.py b/tests/version_test.py deleted file mode 100644 index 2205b305f..000000000 --- a/tests/version_test.py +++ /dev/null @@ -1,16 +0,0 @@ -"""Check whether the __version__ attribute is set correctly.""" - -import algoperf - - -def test_version_attribute(): - """Check whether __version__ exists and is a valid string.""" - - assert hasattr(algoperf, "__version__") - version = algoperf.__version__ - assert isinstance(version, str) - version_elements = version.split(".") - print(version_elements) - # Only check the first three elements, i.e. major, minor, patch. - # The remaining elements contain commit hash and dirty status. - assert all(el.isnumeric() for el in version_elements[0:3]) From d9f13ab9b6fd9c5ee0bd99d72ce7eb04851aa1c9 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 16 Jan 2025 22:30:09 +0000 Subject: [PATCH 072/105] upgrade_jax --- setup.cfg | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/setup.cfg b/setup.cfg index 8e37acb7a..43458cb07 100644 --- a/setup.cfg +++ b/setup.cfg @@ -121,17 +121,17 @@ jax_core_deps = # JAX CPU jax_cpu = - jax==0.4.35 - jaxlib==0.4.35 + jax==0.4.38 + jaxlib==0.4.38 %(jax_core_deps)s # JAX GPU # Note this installs both jax and jaxlib. jax_gpu = - jax==0.4.35 - jaxlib==0.4.35 - jax-cuda12-plugin[with_cuda]==0.4.35 - jax-cuda12-pjrt==0.4.35 + jax==0.4.38 + jaxlib==0.4.38 + jax-cuda12-plugin[with_cuda]==0.4.38 + jax-cuda12-pjrt==0.4.38 %(jax_core_deps)s # PyTorch CPU From 01eb8819dbe36d8b54987758706247c63b3f73df Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 16 Jan 2025 23:16:02 +0000 Subject: [PATCH 073/105] change jax version --- setup.cfg | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/setup.cfg b/setup.cfg index 43458cb07..040d1e26a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -128,10 +128,10 @@ jax_cpu = # JAX GPU # Note this installs both jax and jaxlib. jax_gpu = - jax==0.4.38 - jaxlib==0.4.38 - jax-cuda12-plugin[with_cuda]==0.4.38 - jax-cuda12-pjrt==0.4.38 + jax==0.4.36 + jaxlib==0.4.36 + jax-cuda12-plugin[with_cuda]==0.4.36 + jax-cuda12-pjrt==0.4.36 %(jax_core_deps)s # PyTorch CPU From c9b641158e5ff41918f7e94f5d492c0b30e4ed80 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 16 Jan 2025 23:28:20 +0000 Subject: [PATCH 074/105] change jax python version --- setup.cfg | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 040d1e26a..8c512d32e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -121,8 +121,8 @@ jax_core_deps = # JAX CPU jax_cpu = - jax==0.4.38 - jaxlib==0.4.38 + jax==0.4.36 + jaxlib==0.4.36 %(jax_core_deps)s # JAX GPU From fc526a433395e27b1656f3986073c367f4d6654e Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Sat, 18 Jan 2025 07:54:54 +0000 Subject: [PATCH 075/105] modify limits of ints --- algorithmic_efficiency/random_utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index f40a98003..a579976ad 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -16,32 +16,32 @@ FLAGS = flags.FLAGS -# Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 32 - 1] (an +# Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 31 - 1] (an # unsigned int), while RandomState.randint only accepts and returns signed ints. -MAX_UINT32 = 2**32 - 1 -MIN_UINT32 = 0 +MAX_INT32 = 2**31 - 1 +MIN_INT32 = 0 SeedType = Union[int, list, np.ndarray] def _signed_to_unsigned(seed: SeedType) -> SeedType: if isinstance(seed, int): - return seed % MAX_UINT32 + return seed % MAX_INT32 if isinstance(seed, list): - return [s % MAX_UINT32 for s in seed] + return [s % MAX_INT32 for s in seed] if isinstance(seed, np.ndarray): - return np.array([s % MAX_UINT32 for s in seed.tolist()]) + return np.array([s % MAX_INT32 for s in seed.tolist()]) def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]: rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - new_seed = rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32) + new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32) return [new_seed, data] def _split(seed: SeedType, num: int = 2) -> SeedType: rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - return rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32, size=[num, 2]) + return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2]) def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name From 7d580f1eb3955d42de328d90f423a09cc0ed25b5 Mon Sep 17 00:00:00 2001 From: init-22 Date: Sat, 18 Jan 2025 15:27:26 +0530 Subject: [PATCH 076/105] fix: using jax.random.key_data only when the workload is jax --- submission_runner.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 4d494f607..b371489bd 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -213,7 +213,9 @@ def train_once( ) -> Tuple[spec.Timing, Dict[str, Any]]: _reset_cuda_mem() data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4) - data_rng = jax.random.key_data(data_rng) + + if FLAGS.framework == 'jax': + data_rng = jax.random.key_data(data_rng) # Workload setup. logging.info('Initializing dataset.') if hasattr(workload, '_eval_num_workers'): @@ -345,7 +347,9 @@ def train_once( data_select_rng, update_rng, prep_eval_rng, eval_rng = \ prng.split(step_rng, 4) - eval_rng = jax.random.key_data(eval_rng) + + if FLAGS.framework == 'jax': + eval_rng = jax.random.key_data(eval_rng) with profiler.profile('Data selection'): batch = data_selection(workload, From 57156188502a70a50e29164ad29822e093c2b8db Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Jan 2025 21:43:38 +0000 Subject: [PATCH 077/105] revert to use PRNGKey --- algorithmic_efficiency/random_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/random_utils.py b/algorithmic_efficiency/random_utils.py index 603a644d1..a579976ad 100644 --- a/algorithmic_efficiency/random_utils.py +++ b/algorithmic_efficiency/random_utils.py @@ -75,5 +75,5 @@ def split(seed: SeedType, num: int = 2) -> SeedType: def PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name if FLAGS.framework == 'jax': _check_jax_install() - return jax_rng.key(seed) + return jax_rng.PRNGKey(seed) return _PRNGKey(seed) From 3fb722d5e1ba47ab8d5e00caa29a534ae6b70e89 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Jan 2025 21:46:35 +0000 Subject: [PATCH 078/105] revert changes to submission runner for prng key --- submission_runner.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index b371489bd..228cbc4d7 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -214,8 +214,7 @@ def train_once( _reset_cuda_mem() data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4) - if FLAGS.framework == 'jax': - data_rng = jax.random.key_data(data_rng) + data_rng = jax.random.key_data(data_rng) # Workload setup. logging.info('Initializing dataset.') if hasattr(workload, '_eval_num_workers'): @@ -348,8 +347,7 @@ def train_once( data_select_rng, update_rng, prep_eval_rng, eval_rng = \ prng.split(step_rng, 4) - if FLAGS.framework == 'jax': - eval_rng = jax.random.key_data(eval_rng) + eval_rng = jax.random.key_data(eval_rng) with profiler.profile('Data selection'): batch = data_selection(workload, From 9b7cee40010d45561f38acad0a701b0188e72181 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Wed, 22 Jan 2025 23:42:38 +0000 Subject: [PATCH 079/105] remove extracting key_data --- submission_runner.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 228cbc4d7..06963fc9d 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -214,7 +214,6 @@ def train_once( _reset_cuda_mem() data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4) - data_rng = jax.random.key_data(data_rng) # Workload setup. logging.info('Initializing dataset.') if hasattr(workload, '_eval_num_workers'): @@ -347,8 +346,6 @@ def train_once( data_select_rng, update_rng, prep_eval_rng, eval_rng = \ prng.split(step_rng, 4) - eval_rng = jax.random.key_data(eval_rng) - with profiler.profile('Data selection'): batch = data_selection(workload, input_queue, From 5775ed166b241abf12088dd2569ef9894eb9e6de Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 23 Jan 2025 00:58:24 +0000 Subject: [PATCH 080/105] cast np.int32 as int for random.Random arg --- .../workloads/cifar/cifar_pytorch/workload.py | 2 +- .../workloads/imagenet_resnet/imagenet_pytorch/workload.py | 2 +- .../librispeech_conformer/librispeech_pytorch/workload.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/algorithmic_efficiency/workloads/cifar/cifar_pytorch/workload.py b/algorithmic_efficiency/workloads/cifar/cifar_pytorch/workload.py index 7abcf4d6c..119c6378c 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_pytorch/workload.py @@ -82,7 +82,7 @@ def _build_dataset( } if split == 'eval_train': train_indices = indices_split['train'] - random.Random(data_rng[0]).shuffle(train_indices) + random.Random(int(data_rng[0])).shuffle(train_indices) indices_split['eval_train'] = train_indices[:self.num_eval_train_examples] if split in indices_split: dataset = torch.utils.data.Subset(dataset, indices_split[split]) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py index 3549911fa..6387a40c0 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -120,7 +120,7 @@ def _build_dataset( if split == 'eval_train': indices = list(range(self.num_train_examples)) - random.Random(data_rng[0]).shuffle(indices) + random.Random(int(data_rng[0])).shuffle(indices) dataset = torch.utils.data.Subset(dataset, indices[:self.num_eval_train_examples]) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 155b30920..83f0a2de7 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -166,7 +166,7 @@ def _build_input_queue( ds = LibriSpeechDataset(split=ds_split, data_dir=data_dir) if split == 'eval_train': indices = list(range(len(ds))) - random.Random(data_rng[0]).shuffle(indices) + random.Random(int(data_rng[0])).shuffle(indices) ds = torch.utils.data.Subset(ds, indices[:self.num_eval_train_examples]) sampler = None From 1352e70ad71fef6b0508bcc95d3af317fe62fa90 Mon Sep 17 00:00:00 2001 From: init-22 Date: Sun, 26 Jan 2025 11:36:58 +0530 Subject: [PATCH 081/105] fix: vim installation --- docker/Dockerfile | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 77dac5313..07375dd92 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -24,7 +24,8 @@ RUN apt-get update && apt-get install -y \ libffi-dev \ curl \ libbz2-dev \ - liblzma-dev + liblzma-dev \ + vim # Download and install Python 3.11 RUN cd /tmp \ @@ -91,6 +92,7 @@ RUN cd /algorithmic-efficiency && pip install -e '.[full]' RUN cd /algorithmic-efficiency && git fetch origin RUN cd /algorithmic-efficiency && git pull +RUN pip install wandb # Todo: remove this, this is temporary for developing COPY scripts/startup.sh /algorithmic-efficiency/docker/scripts/startup.sh From d7eebf88bf2fe61fc523991d62e5a0095af9fd64 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 30 Jan 2025 01:49:23 +0000 Subject: [PATCH 082/105] use inductor backend to compile deepspeech instead of eager --- submission_runner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/submission_runner.py b/submission_runner.py index 06963fc9d..d2dcb03ac 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -242,8 +242,9 @@ def train_once( 'ogbg', 'criteo1tb', 'imagenet_vit', + 'librispeech_deepspeech' ] - eager_backend_workloads = ['librispeech_deepspeech'] + eager_backend_workloads = [] aot_eager_backend_workloads = [] loss_compilation_workloads = [ 'fastmri', 'librispeech_deepspeech', 'ogbg', 'wmt' From 58159c5fd7c80c7dc43a62f9df714baf7d82eadb Mon Sep 17 00:00:00 2001 From: init-22 Date: Sun, 2 Feb 2025 23:17:14 +0530 Subject: [PATCH 083/105] adding mem_fraction 0.80 for jax workfloads to resolve OOM of certain worklods --- docker/Dockerfile | 1 - submission_runner.py | 13 +++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 07375dd92..76bc5cfe0 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -92,7 +92,6 @@ RUN cd /algorithmic-efficiency && pip install -e '.[full]' RUN cd /algorithmic-efficiency && git fetch origin RUN cd /algorithmic-efficiency && git pull -RUN pip install wandb # Todo: remove this, this is temporary for developing COPY scripts/startup.sh /algorithmic-efficiency/docker/scripts/startup.sh diff --git a/submission_runner.py b/submission_runner.py index d2dcb03ac..2acc9d33c 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -693,12 +693,21 @@ def main(_): # Prevent OOM on librispeech conformer. base_workload = workloads.get_base_workload_name(FLAGS.workload) - if base_workload == 'librispeech_conformer': - os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.85' + + if base_workload == [ + 'librispeech_conformer', + 'librispeech_deepspeech', + 'imagenet_vit', + 'criteo1tb' + ]: + os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80' if FLAGS.set_pytorch_max_split_size: os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' + if FLAGS.framework == 'pytorch' and base_workload == 'librispeech_conformer': + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' + # Extend path according to framework. workload_metadata['workload_path'] = os.path.join( BASE_WORKLOADS_DIR, From 81bc93d2394762d883058d922ebc524ad69706f5 Mon Sep 17 00:00:00 2001 From: init-22 Date: Mon, 3 Feb 2025 15:26:07 +0530 Subject: [PATCH 084/105] mem fraction typo --- submission_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/submission_runner.py b/submission_runner.py index 2acc9d33c..6024ba1a2 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -694,7 +694,7 @@ def main(_): # Prevent OOM on librispeech conformer. base_workload = workloads.get_base_workload_name(FLAGS.workload) - if base_workload == [ + if base_workload in [ 'librispeech_conformer', 'librispeech_deepspeech', 'imagenet_vit', From f6ca2bce0593a622cf53f90c1750bf27848eb892 Mon Sep 17 00:00:00 2001 From: init-22 Date: Mon, 3 Feb 2025 16:02:46 +0530 Subject: [PATCH 085/105] env variable for conformer set at the top --- submission_runner.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 6024ba1a2..da4e8371c 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -21,6 +21,12 @@ import itertools import json import os + +os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. +# disable only for deepspeech if it works fine for other workloads. +os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' + import struct import time from types import MappingProxyType @@ -30,12 +36,10 @@ from absl import flags from absl import logging import jax +import tensorflow as tf import torch import torch.distributed as dist -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. -import tensorflow as tf - # Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make # it unavailable to JAX. tf.config.set_visible_devices([], 'GPU') @@ -52,9 +56,6 @@ from algorithmic_efficiency.pytorch_utils import sync_ddp_time from algorithmic_efficiency.workloads import workloads -# disable only for deepspeech if it works fine for other workloads. -os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' - # TODO(znado): make a nicer registry of workloads that lookup in. BASE_WORKLOADS_DIR = workloads.BASE_WORKLOADS_DIR @@ -702,12 +703,13 @@ def main(_): ]: os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80' + if base_workload != 'librispeech_conformer': + # Remove the environment variable (only for workloads other than librispeech conformer). + del os.environ['PYTORCH_CUDA_ALLOC_CONF'] + if FLAGS.set_pytorch_max_split_size: os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' - if FLAGS.framework == 'pytorch' and base_workload == 'librispeech_conformer': - os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' - # Extend path according to framework. workload_metadata['workload_path'] = os.path.join( BASE_WORKLOADS_DIR, From 59126ae04c70a46adf36d4873176b1c13a1ba9f7 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Mon, 3 Feb 2025 12:00:06 +0100 Subject: [PATCH 086/105] Update documentation with new targets --- DOCUMENTATION.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index 63439cb09..255ccdeea 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -418,7 +418,7 @@ In each trial, the tuning trial with the fastest training time to achieve the *v Submissions to this ruleset are not allowed to have user-defined hyperparameters. This ruleset allows both submissions that use the same hyperparameters for all workloads, including the randomized ones (e.g. Adam with default parameters), as well as submissions that perform inner-loop tuning during their training run (e.g. SGD with line searches). -Submissions will run on one instance of the [benchmarking hardware](#benchmarking-hardware). As always, submissions are allowed to perform inner-loop tuning (e.g. for their learning rate) but the tuning efforts will be part of their score. A submission will run *S=5* times and its score will be the median time to reach the target evaluation metric value on the validation set. To account for the lack of external tuning, submissions have a longer time budget to reach the target performance. Compared to the [external tuning ruleset](#external-tuning-ruleset), the `max_runtime` is tripled. Runs that do not reach the target performance of the evaluation metric within this allotted time budget have an infinite time. +Submissions will run on one instance of the [benchmarking hardware](#benchmarking-hardware). As always, submissions are allowed to perform inner-loop tuning (e.g. for their learning rate) but the tuning efforts will be part of their score. A submission will run *S=5* times and its score will be the median time to reach the target evaluation metric value on the validation set. To account for the lack of external tuning, submissions have a longer time budget to reach the target performance. Compared to the [external tuning ruleset](#external-tuning-ruleset), the `max_runtime` is $1.5$ times longer. Runs that do not reach the target performance of the evaluation metric within this allotted time budget have an infinite time. ### Workloads @@ -439,11 +439,11 @@ The currently eight fixed workloads are: | | **Task** | **Dataset** | **Model** | **Loss** | **Metric** | Validation
**Target** | Test
**Target** | Maximum
**Runtime**
(in secs) | |------------|-------------------------------|-------------|-------------------------|----------|------------|--------------------------|----------------------|------------------------| | **1** | Clickthrough rate prediction | Criteo 1TB | DLRMsmall | CE | CE | 0.123735 | 0.126041 | 7,703 | -| **2** | MRI reconstruction | fastMRI | U-Net | L1 | SSIM | 0.723653 | 0.740633 | 8,859 | -| **3
4** | Image classification | ImageNet | ResNet-50
ViT | CE | ER | 0.22569
0.22691 | 0.3440
0.3481 | 63,008
77,520 | -| **5
6** | Speech recognition | LibriSpeech | Conformer
DeepSpeech | CTC | WER | 0.085884
0.119936 | 0.052981
0.074143 | 61,068
55,506 | -| **7** | Molecular property prediction | OGBG | GNN | CE | mAP | 0.28098 | 0.268729 | 18,477 | -| **8** | Translation | WMT | Transformer | CE | BLEU | 30.8491 | 30.7219 | 48,151 | +| **2** | MRI reconstruction | fastMRI | U-Net | L1 | SSIM | 0.723653 | 0.740633 | 4,430 | +| **3
4** | Image classification | ImageNet | ResNet-50
ViT | CE | ER | 0.22569
0.22691 | 0.3440
0.3481 | 66,159
69,768 | +| **5
6** | Speech recognition | LibriSpeech | Conformer
DeepSpeech | CTC | WER | 0.085884
0.119936 | 0.052981
0.074143 | 58,015
44,405 | +| **7** | Molecular property prediction | OGBG | GNN | CE | mAP | 0.28098 | 0.268729 | 12,011 | +| **8** | Translation | WMT | Transformer | CE | BLEU | 30.8491 | 30.7219 | 43,336 | Default Dropout Values for Different Workloads: @@ -503,7 +503,7 @@ For self-reported results, it is acceptable to perform the tuning trials on hard Target performances on the validation and test sets will be defined for each [workload](#workloads) separately. For the [fixed workloads](#fixed-workloads), we take the best performance achievable by one of four standard algorithms (AdamW, NadamW, Nesterov Momentum, and Heavy Ball Momentum). These target-setting algorithms will follow the general process of the external tuning ruleset, with a significantly larger tuning budget of $200$ trials to guarantee competitive performance. Once the best algorithm and its hyperparameters are determined, training is repeated $20$ times. The median of the best achieved validation errors across seeds is used as the *validation* target. Out of the $10$ repeated runs that achieved this validation target, we took the worst achieved test error across seeds as our *test* target. Taking the median validation performance after rerunning the best hyperparameter point prevents our procedure from selecting a lucky outlier. To save computational resources, we only tuned two training algorithms instead of four, for the [randomized workloads](#randomized-workloads). For each workload variant, we used NadamW and the other best-performing training algorithm on the corresponding fixed workload the randomized workload is based on. -Both [tuning rulesets](#tuning) will use the same target performances. The runtime of the target-setting algorithms on each workload will be chosen to match published results and is constrained by the overall time budget of roughly a single week for all fixed workloads. The `max_runtime` for submissions on each workload is $\frac{1}{3}$ longer than the runtime of the target-setting algorithms (this `max_runtime` will be three times as much for the self-tuning ruleset, see the [Self-tuning ruleset](#self-tuning-ruleset) section). +Both [tuning rulesets](#tuning) will use the same target performances. The runtime of the target-setting algorithms on each workload will be chosen to match published results and is constrained by the overall time budget of roughly a single week for all fixed workloads. The initial `max_runtime` for submissions on each workload was $\frac{1}{3}$ longer than the runtime of the target-setting algorithms (this `max_runtime` will be $1.5$ times as much for the self-tuning ruleset, see the [Self-tuning ruleset](#self-tuning-ruleset) section). After the initial round of submissions, we have adapated the `max_runtime` based on the performance of the submissions (see [this issue](https://github.com/mlcommons/algorithmic-efficiency/issues/836)). #### Benchmark score using performance profiles From ff0086c81d6602d60da1be3cd94691fef2a3510e Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Mon, 3 Feb 2025 12:03:45 +0100 Subject: [PATCH 087/105] Use 1.5 instead of 3x the budget for self-tuning --- submission_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 1be56aeab..a9d13e7cb 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -409,10 +409,10 @@ def train_once( prepare_for_eval_end_time - prepare_for_eval_start_time) # Check if time is remaining, - # use 3x the runtime budget for the self-tuning ruleset. + # use 1.5x the runtime budget for the self-tuning ruleset. max_allowed_runtime_sec = ( workload.max_allowed_runtime_sec if FLAGS.tuning_ruleset == 'external' - else 3 * workload.max_allowed_runtime_sec) + else 1.5 * workload.max_allowed_runtime_sec) train_state['is_time_remaining'] = ( train_state['accumulated_submission_time'] < max_allowed_runtime_sec) From 03bc79e4b905d2ab423e5bb7d8bbf346f42ccefa Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Mon, 3 Feb 2025 12:04:02 +0100 Subject: [PATCH 088/105] Update max allowed runtimes for each workload --- algoperf/workloads/criteo1tb/workload.py | 2 +- algoperf/workloads/fastmri/workload.py | 2 +- .../workloads/imagenet_resnet/workload.py | 2 +- algoperf/workloads/imagenet_vit/workload.py | 5 ++-- .../librispeech_conformer/workload.py | 2 +- .../librispeech_jax/workload.py | 12 +++++----- .../librispeech_pytorch/workload.py | 23 ++++++++++--------- algoperf/workloads/ogbg/workload.py | 5 ++-- algoperf/workloads/wmt/workload.py | 2 +- 9 files changed, 27 insertions(+), 28 deletions(-) diff --git a/algoperf/workloads/criteo1tb/workload.py b/algoperf/workloads/criteo1tb/workload.py index 80ec9d67a..db87771e8 100644 --- a/algoperf/workloads/criteo1tb/workload.py +++ b/algoperf/workloads/criteo1tb/workload.py @@ -93,7 +93,7 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 7703 # ~2 hours. + return 7_703 # ~2.1 hours. @property def eval_period_time_sec(self) -> int: diff --git a/algoperf/workloads/fastmri/workload.py b/algoperf/workloads/fastmri/workload.py index e9a2a313a..5b4ce7b3e 100644 --- a/algoperf/workloads/fastmri/workload.py +++ b/algoperf/workloads/fastmri/workload.py @@ -95,7 +95,7 @@ def accelerations(self): @property def max_allowed_runtime_sec(self) -> int: - return 8859 # ~2.5 hours + return 4_430 # ~1.2 hours @property def eval_period_time_sec(self) -> int: diff --git a/algoperf/workloads/imagenet_resnet/workload.py b/algoperf/workloads/imagenet_resnet/workload.py index 8b3393ded..9a18f3681 100644 --- a/algoperf/workloads/imagenet_resnet/workload.py +++ b/algoperf/workloads/imagenet_resnet/workload.py @@ -102,7 +102,7 @@ def resize_size(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 63_008 # ~17.5 hours + return 66_159 # ~18.4 hours @property def eval_period_time_sec(self) -> int: diff --git a/algoperf/workloads/imagenet_vit/workload.py b/algoperf/workloads/imagenet_vit/workload.py index 7f06715a3..30b774eec 100644 --- a/algoperf/workloads/imagenet_vit/workload.py +++ b/algoperf/workloads/imagenet_vit/workload.py @@ -3,8 +3,7 @@ from typing import Dict, Iterator, Optional from algoperf import spec -from algoperf.workloads.imagenet_resnet.workload import \ - BaseImagenetResNetWorkload +from algoperf.workloads.imagenet_resnet.workload import BaseImagenetResNetWorkload def decode_variant(variant: str) -> Dict[str, int]: @@ -81,7 +80,7 @@ def eval_batch_size(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 77_520 # ~22 hours + return 69_768 # ~19.4 hours @property def eval_period_time_sec(self) -> int: diff --git a/algoperf/workloads/librispeech_conformer/workload.py b/algoperf/workloads/librispeech_conformer/workload.py index c9f5a3c59..92cc7f61f 100644 --- a/algoperf/workloads/librispeech_conformer/workload.py +++ b/algoperf/workloads/librispeech_conformer/workload.py @@ -79,7 +79,7 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 61_068 # ~17 hours + return 58_015 # ~16.1 hours @property def eval_period_time_sec(self) -> int: diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 9fd0898b4..43ff4f066 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -1,15 +1,15 @@ import functools from typing import Dict, Optional, Tuple -from flax import jax_utils import jax import jax.numpy as jnp import numpy as np +from flax import jax_utils -from algoperf import param_utils -from algoperf import spec -from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \ - LibriSpeechConformerWorkload +from algoperf import param_utils, spec +from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import ( + LibriSpeechConformerWorkload, +) from algoperf.workloads.librispeech_deepspeech.librispeech_jax import models @@ -104,7 +104,7 @@ def step_hint(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 55_506 # ~15.4 hours + return 44_405 # ~12.3 hours @property def use_tanh(self) -> bool: diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index 4f8ad1974..d368d0bc5 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -3,17 +3,18 @@ import torch from torch.nn.parallel import DistributedDataParallel as DDP -from algoperf import param_utils -from algoperf import spec +from algoperf import param_utils, spec from algoperf.pytorch_utils import pytorch_setup -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \ - initialize -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \ - LibriSpeechConformerWorkload -from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ - DeepspeechConfig -from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ - DeepspeechEncoderDecoder +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import ( + initialize, +) +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import ( + LibriSpeechConformerWorkload, +) +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import ( + DeepspeechConfig, + DeepspeechEncoderDecoder, +) USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() @@ -81,7 +82,7 @@ def step_hint(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 55_506 # ~15.4 hours + return 44_405 # ~12.3 hours @property def use_tanh(self) -> bool: diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py index c6a2162d7..4a4e94990 100644 --- a/algoperf/workloads/ogbg/workload.py +++ b/algoperf/workloads/ogbg/workload.py @@ -9,8 +9,7 @@ from algoperf import random_utils as prng from algoperf import spec -from algoperf.workloads.ogbg import input_pipeline -from algoperf.workloads.ogbg import metrics +from algoperf.workloads.ogbg import input_pipeline, metrics class BaseOgbgWorkload(spec.Workload): @@ -88,7 +87,7 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 18_477 # ~5 hours + return 12_011 # ~3.3 hours @property def eval_period_time_sec(self) -> int: diff --git a/algoperf/workloads/wmt/workload.py b/algoperf/workloads/wmt/workload.py index e9a07d2b3..5a3e97ed8 100644 --- a/algoperf/workloads/wmt/workload.py +++ b/algoperf/workloads/wmt/workload.py @@ -88,7 +88,7 @@ def train_stddev(self): @property def max_allowed_runtime_sec(self) -> int: - return 48_151 # ~13.5 hours + return 43_336 # ~12.0 hours @property def eval_period_time_sec(self) -> int: From 16eb8d645b7205f9eec84a7fb8ae46a2e1613a36 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Mon, 3 Feb 2025 12:04:13 +0100 Subject: [PATCH 089/105] Clarify in comment that its the old budgets --- scoring/compute_speedups.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scoring/compute_speedups.py b/scoring/compute_speedups.py index 5fb5f259d..d0e5bf70b 100644 --- a/scoring/compute_speedups.py +++ b/scoring/compute_speedups.py @@ -25,6 +25,7 @@ 'Whether to save the results to disk.') FLAGS = flags.FLAGS +# These are the old budgets, used in the first iteration of the competition. MAX_BUDGETS = { 'criteo1tb': 7703, 'fastmri': 8859, From d3f788d9337c8d5fd0dda594abda1bcd463cf37c Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Mon, 3 Feb 2025 12:13:12 +0100 Subject: [PATCH 090/105] Adapt step hint as well --- algoperf/spec.py | 2 +- algoperf/workloads/criteo1tb/workload.py | 2 +- algoperf/workloads/fastmri/workload.py | 4 ++-- algoperf/workloads/imagenet_resnet/workload.py | 4 ++-- algoperf/workloads/imagenet_vit/workload.py | 4 ++-- algoperf/workloads/librispeech_conformer/workload.py | 4 ++-- .../librispeech_deepspeech/librispeech_jax/workload.py | 4 ++-- .../librispeech_deepspeech/librispeech_pytorch/workload.py | 4 ++-- algoperf/workloads/ogbg/workload.py | 4 ++-- algoperf/workloads/wmt/workload.py | 4 ++-- 10 files changed, 18 insertions(+), 18 deletions(-) diff --git a/algoperf/spec.py b/algoperf/spec.py index 381d52f32..cf4f1a14e 100644 --- a/algoperf/spec.py +++ b/algoperf/spec.py @@ -206,7 +206,7 @@ def eval_period_time_sec(self) -> int: @property @abc.abstractmethod def step_hint(self) -> int: - """Max num steps the baseline algo was given to reach the target.""" + """Approx. steps the baseline can do in the allowed runtime budget.""" @property def param_shapes(self): diff --git a/algoperf/workloads/criteo1tb/workload.py b/algoperf/workloads/criteo1tb/workload.py index db87771e8..617b2e987 100644 --- a/algoperf/workloads/criteo1tb/workload.py +++ b/algoperf/workloads/criteo1tb/workload.py @@ -123,7 +123,7 @@ def _build_input_queue( @property def step_hint(self) -> int: - """Max num steps the baseline algo was given to reach the target.""" + """Approx. steps the baseline can do in the allowed runtime budget.""" return 10_666 def _eval_model_on_split(self, diff --git a/algoperf/workloads/fastmri/workload.py b/algoperf/workloads/fastmri/workload.py index 5b4ce7b3e..051749cc3 100644 --- a/algoperf/workloads/fastmri/workload.py +++ b/algoperf/workloads/fastmri/workload.py @@ -103,8 +103,8 @@ def eval_period_time_sec(self) -> int: @property def step_hint(self) -> int: - """Max num steps the baseline algo was given to reach the target.""" - return 36_189 + """Approx. steps the baseline can do in the allowed runtime budget.""" + return 18_094 def _build_input_queue(self, data_rng: spec.RandomState, diff --git a/algoperf/workloads/imagenet_resnet/workload.py b/algoperf/workloads/imagenet_resnet/workload.py index 9a18f3681..83fe97108 100644 --- a/algoperf/workloads/imagenet_resnet/workload.py +++ b/algoperf/workloads/imagenet_resnet/workload.py @@ -144,5 +144,5 @@ def _build_input_queue( @property def step_hint(self) -> int: - """Max num steps the baseline algo was given to reach the target.""" - return 186_666 + """Approx. steps the baseline can do in the allowed runtime budget.""" + return 195_999 diff --git a/algoperf/workloads/imagenet_vit/workload.py b/algoperf/workloads/imagenet_vit/workload.py index 30b774eec..9c885ca7c 100644 --- a/algoperf/workloads/imagenet_vit/workload.py +++ b/algoperf/workloads/imagenet_vit/workload.py @@ -109,5 +109,5 @@ def _build_dataset( @property def step_hint(self) -> int: - """Max num steps the baseline algo was given to reach the target.""" - return 186_666 + """Approx. steps the baseline can do in the allowed runtime budget.""" + return 167_999 diff --git a/algoperf/workloads/librispeech_conformer/workload.py b/algoperf/workloads/librispeech_conformer/workload.py index 92cc7f61f..94f01dd97 100644 --- a/algoperf/workloads/librispeech_conformer/workload.py +++ b/algoperf/workloads/librispeech_conformer/workload.py @@ -87,5 +87,5 @@ def eval_period_time_sec(self) -> int: @property def step_hint(self) -> int: - """Max num steps the baseline algo was given to reach the target.""" - return 80_000 + """Approx. steps the baseline can do in the allowed runtime budget.""" + return 76_000 diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 43ff4f066..1cadebf45 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -99,8 +99,8 @@ def test_target_value(self) -> float: @property def step_hint(self) -> int: - """Max num steps the baseline algo was given to reach the target.""" - return 48_000 + """Approx. steps the baseline can do in the allowed runtime budget.""" + return 38_400 @property def max_allowed_runtime_sec(self) -> int: diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index d368d0bc5..c72c1daee 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -77,8 +77,8 @@ def test_target_value(self) -> float: @property def step_hint(self) -> int: - """Max num steps the baseline algo was given to reach the target.""" - return 48_000 + """Approx. steps the baseline can do in the allowed runtime budget.""" + return 38_400 @property def max_allowed_runtime_sec(self) -> int: diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py index 4a4e94990..ca123f885 100644 --- a/algoperf/workloads/ogbg/workload.py +++ b/algoperf/workloads/ogbg/workload.py @@ -139,8 +139,8 @@ def loss_fn( @property def step_hint(self) -> int: - """Max num steps the baseline algo was given to reach the target.""" - return 80_000 + """Approx. steps the baseline can do in the allowed runtime budget.""" + return 52_000 @abc.abstractmethod def _normalize_eval_metrics( diff --git a/algoperf/workloads/wmt/workload.py b/algoperf/workloads/wmt/workload.py index 5a3e97ed8..51b33373d 100644 --- a/algoperf/workloads/wmt/workload.py +++ b/algoperf/workloads/wmt/workload.py @@ -96,8 +96,8 @@ def eval_period_time_sec(self) -> int: @property def step_hint(self) -> int: - """Max num steps the baseline algo was given to reach the target.""" - return 133_333 + """Approx. steps the baseline can do in the allowed runtime budget.""" + return 120_000 @property def pre_ln(self) -> bool: From b3dec6742341833d9f24bb67f96cf46da658e009 Mon Sep 17 00:00:00 2001 From: Frank Schneider Date: Mon, 3 Feb 2025 12:15:03 +0100 Subject: [PATCH 091/105] Clarify `step_hint` --- DOCUMENTATION.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index 255ccdeea..795846efd 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -91,7 +91,7 @@ With the exception of `_build_input_queue`, submitters can call any of these fun def step_hint(self): -> int ``` -- The `step_hint` function gives the number of global steps the baseline algorithm was allowed to use to reach the targets for a workload. Note that the baseline algorithms may have reached the target in fewer steps than this, but these were the max number of steps the baseline algorithms used for their learning rate schedules. Submitters can use this to help specify learning rate (or other) schedules. +- The `step_hint` function gives the number of global steps the baseline algorithm can perform with the `max_runtime` to reach the targets for a workload. The `step_hint` is therefore dependent on the `max_runtime` and the workload. Note that the baseline algorithms may have reached the target in fewer steps than this, but these were the max number of steps the baseline algorithms used for their learning rate schedules. Submitters can use this to help specify learning rate (or other) schedules. ###### Data augmentation and preprocessing From b4ed6cc11d1730b3402e83e5162302d14681a9c9 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 4 Feb 2025 01:38:52 +0000 Subject: [PATCH 092/105] set env variables for pytorch before initializing w ddp. --- submission_runner.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index da4e8371c..495fd2039 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -22,11 +22,6 @@ import json import os -os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. -# disable only for deepspeech if it works fine for other workloads. -os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' - import struct import time from types import MappingProxyType @@ -56,6 +51,11 @@ from algorithmic_efficiency.pytorch_utils import sync_ddp_time from algorithmic_efficiency.workloads import workloads +# Environment variables +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. +# disable only for deepspeech if it works fine for other workloads +os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' + # TODO(znado): make a nicer registry of workloads that lookup in. BASE_WORKLOADS_DIR = workloads.BASE_WORKLOADS_DIR @@ -681,6 +681,14 @@ def main(_): else: profiler = PassThroughProfiler() + # Set PyTorch environment variables before initializing w DDP + base_workload = workloads.get_base_workload_name(FLAGS.workload) + if base_workload == 'librispeech_conformer': + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' + + if FLAGS.set_pytorch_max_split_size: + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' + if FLAGS.framework == 'pytorch': pytorch_init(USE_PYTORCH_DDP, RANK, profiler) @@ -692,9 +700,6 @@ def main(_): workload_metadata = WORKLOADS[FLAGS.workload] - # Prevent OOM on librispeech conformer. - base_workload = workloads.get_base_workload_name(FLAGS.workload) - if base_workload in [ 'librispeech_conformer', 'librispeech_deepspeech', @@ -703,13 +708,6 @@ def main(_): ]: os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80' - if base_workload != 'librispeech_conformer': - # Remove the environment variable (only for workloads other than librispeech conformer). - del os.environ['PYTORCH_CUDA_ALLOC_CONF'] - - if FLAGS.set_pytorch_max_split_size: - os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' - # Extend path according to framework. workload_metadata['workload_path'] = os.path.join( BASE_WORKLOADS_DIR, From ebf0341ab43d78d7f70be88416543a970c327efb Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 5 Feb 2025 23:08:59 +0000 Subject: [PATCH 093/105] set jax to 0.4.26 --- setup.cfg | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/setup.cfg b/setup.cfg index 8c512d32e..3c435c453 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,7 +42,8 @@ install_requires = tensorflow-datasets==4.9.7 gputil==1.4.0 psutil==6.1.0 - clu==0.0.12 + # clu==0.0.12 + clu matplotlib>=3.9.2 tabulate==0.9.0 wandb==0.18.7 @@ -107,31 +108,34 @@ wmt = # JAX Core jax_core_deps = - flax==0.10.1 - optax==0.2.4 + # flax==0.10.1 + flax + # optax==0.2.4 + optax # Fix chex (optax dependency) version. # Not fixing it can raise dependency issues with our # jax version. # Todo(kasimbeg): verify if this is necessary after we # upgrade jax. - chex==0.1.87 + # chex==0.1.87 + chex ml_dtypes==0.4.1 protobuf==4.25.5 # JAX CPU jax_cpu = - jax==0.4.36 - jaxlib==0.4.36 + jax==0.4.26 + jaxlib==0.4.26 %(jax_core_deps)s # JAX GPU # Note this installs both jax and jaxlib. jax_gpu = - jax==0.4.36 - jaxlib==0.4.36 - jax-cuda12-plugin[with_cuda]==0.4.36 - jax-cuda12-pjrt==0.4.36 + jax==0.4.26 + jaxlib==0.4.26 + jax-cuda12-plugin[with_cuda]==0.4.26 + jax-cuda12-pjrt==0.4.26 %(jax_core_deps)s # PyTorch CPU From 1ce6deac8cec6b5e06bd161b2deb5392432d43c5 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 7 Feb 2025 23:22:38 +0000 Subject: [PATCH 094/105] set jax versions --- setup.cfg | 25 ++++++++----------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/setup.cfg b/setup.cfg index 3c435c453..43b31b536 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,8 +42,7 @@ install_requires = tensorflow-datasets==4.9.7 gputil==1.4.0 psutil==6.1.0 - # clu==0.0.12 - clu + clu==0.0.12 matplotlib>=3.9.2 tabulate==0.9.0 wandb==0.18.7 @@ -108,17 +107,9 @@ wmt = # JAX Core jax_core_deps = - # flax==0.10.1 - flax - # optax==0.2.4 - optax - # Fix chex (optax dependency) version. - # Not fixing it can raise dependency issues with our - # jax version. - # Todo(kasimbeg): verify if this is necessary after we - # upgrade jax. - # chex==0.1.87 - chex + flax==0.8.4 + optax==0.2.2 + chex==0.1.86 ml_dtypes==0.4.1 protobuf==4.25.5 @@ -132,10 +123,10 @@ jax_cpu = # JAX GPU # Note this installs both jax and jaxlib. jax_gpu = - jax==0.4.26 - jaxlib==0.4.26 - jax-cuda12-plugin[with_cuda]==0.4.26 - jax-cuda12-pjrt==0.4.26 + jax==0.4.28 + jaxlib==0.4.28 + jax-cuda12-plugin[with_cuda]==0.4.28 + jax-cuda12-pjrt==0.4.28 %(jax_core_deps)s # PyTorch CPU From 082be0311894082409cb5580ecb8cb2f13821b8e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 8 Feb 2025 02:27:45 +0000 Subject: [PATCH 095/105] fix pytorch version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 870cfe99a..6acdd3351 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -116,7 +116,7 @@ jax_gpu = [ ] pytorch_cpu = ["torch==2.5.1", "torchvision==0.20.1"] pytorch_gpu = [ - "torch==2.5.0", + "torch==2.5.1", "torchvision==0.20.1", ] # Note: omit the cuda suffix and installing from the appropriate wheel will result in using locally installed CUDA. wandb = ["wandb==0.16.5"] From 39bb87683aa068dad0040ac36c31d2d2d1769575 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 8 Feb 2025 02:28:45 +0000 Subject: [PATCH 096/105] fix jax versions --- pyproject.toml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6acdd3351..c34ec00f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,13 +103,13 @@ jax_core_deps = [ "protobuf==4.25.5", ] jax_cpu = [ - "jax==0.4.26", - "jaxlib==0.4.26", + "jax==0.4.28", + "jaxlib==0.4.28", "algorithmic_efficiency[jax_core_deps]", ] jax_gpu = [ - "jax==0.4.26", - "jaxlib==0.4.26", + "jax==0.4.28", + "jaxlib==0.4.28", "jax-cuda12-plugin[with_cuda]==0.4.28", "jax-cuda12-pjrt==0.4.28", "algorithmic_efficiency[jax_core_deps]", From b45a69b029f7d2d20e100cca5ea800843fec361a Mon Sep 17 00:00:00 2001 From: init-22 Date: Sat, 8 Feb 2025 14:13:13 +0530 Subject: [PATCH 097/105] fix: adding wandb under 'full' section --- pyproject.toml | 6 ++++-- submission_runner.py | 3 +-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c34ec00f3..b77adaef7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ "clu==0.0.12", "matplotlib>=3.9.2", "tabulate==0.9.0", + ] [build-system] @@ -70,7 +71,7 @@ version_file = "algorithmic_efficiency/_version.py" [project.optional-dependencies] # All workloads full = [ - "algorithmic_efficiency[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt]", + "algorithmic_efficiency[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt,wandb]", ] # All workloads plus development dependencies full_dev = ["algorithmic_efficiency[full,dev]"] @@ -83,6 +84,8 @@ dev = [ "pre-commit==4.0.1", ] +wandb = ["wandb==0.16.5"] + # Workloads criteo1tb = ["scikit-learn==1.5.2"] fastmri = ["h5py==3.12.0", "scikit-image==0.24.0"] @@ -119,7 +122,6 @@ pytorch_gpu = [ "torch==2.5.1", "torchvision==0.20.1", ] # Note: omit the cuda suffix and installing from the appropriate wheel will result in using locally installed CUDA. -wandb = ["wandb==0.16.5"] ############################################################################### # Linting Configurations # diff --git a/submission_runner.py b/submission_runner.py index 495fd2039..2753a604b 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -21,7 +21,6 @@ import itertools import json import os - import struct import time from types import MappingProxyType @@ -685,7 +684,7 @@ def main(_): base_workload = workloads.get_base_workload_name(FLAGS.workload) if base_workload == 'librispeech_conformer': os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' - + if FLAGS.set_pytorch_max_split_size: os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' From b719a6e5a5bb05abf85da768a60ea43183468684 Mon Sep 17 00:00:00 2001 From: init-22 Date: Mon, 10 Feb 2025 21:58:51 +0530 Subject: [PATCH 098/105] fix: wandb version upgrade --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b77adaef7..b4840b35c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,7 +84,7 @@ dev = [ "pre-commit==4.0.1", ] -wandb = ["wandb==0.16.5"] +wandb = ["wandb==0.19.6"] # Workloads criteo1tb = ["scikit-learn==1.5.2"] From 1ce3e624de7055af44d90b9fb0fab9a28ea268dd Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 10 Feb 2025 21:26:54 +0000 Subject: [PATCH 099/105] remove wandb from full --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b4840b35c..9130f733f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,7 +71,7 @@ version_file = "algorithmic_efficiency/_version.py" [project.optional-dependencies] # All workloads full = [ - "algorithmic_efficiency[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt,wandb]", + "algorithmic_efficiency[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt]", ] # All workloads plus development dependencies full_dev = ["algorithmic_efficiency[full,dev]"] From 5be969bb6510ab379cd80716ab99bb06663c9020 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 10 Feb 2025 23:53:30 +0000 Subject: [PATCH 100/105] fix isort version in test --- .github/workflows/linting.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml index e49686358..699289029 100644 --- a/.github/workflows/linting.yml +++ b/.github/workflows/linting.yml @@ -34,7 +34,7 @@ jobs: - name: Install isort run: | python -m pip install --upgrade pip - pip install isort + pip install isort==5.12.0 - name: Run isort run: | isort . --check --diff From a12733afad3bf0b1656bea409bf1ee658c9c88fe Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 11 Feb 2025 00:04:19 +0000 Subject: [PATCH 101/105] revert isort version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 2cc9dfdc8..cc404f4b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,7 +77,7 @@ full = [ full_dev = ["algoperf[full,dev]"] # Dependencies for developing the package dev = [ - "isort==5.13.0", + "isort==5.12.0", "pylint==2.17.4", "pytest==8.3.3", "yapf==0.32.0", From bee0e3f03a6b072519dcbdb522ec4afcca6fc3cc Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 11 Feb 2025 00:05:11 +0000 Subject: [PATCH 102/105] revert import order changes --- .../imagenet_jax/randaugment.py | 3 +-- algoperf/workloads/imagenet_vit/workload.py | 3 ++- .../librispeech_jax/workload.py | 10 ++++----- .../librispeech_pytorch/workload.py | 21 +++++++++---------- algoperf/workloads/ogbg/workload.py | 3 ++- 5 files changed, 20 insertions(+), 20 deletions(-) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py index 41002ff9b..98e6e0f8e 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -7,14 +7,13 @@ import inspect import math -import tensorflow as tf - from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ rotate_img from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ transform from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ translate +import tensorflow as tf # This signifies the max integer that the controller RNN could predict for the # augmentation scheme. diff --git a/algoperf/workloads/imagenet_vit/workload.py b/algoperf/workloads/imagenet_vit/workload.py index 9c885ca7c..f249ddee8 100644 --- a/algoperf/workloads/imagenet_vit/workload.py +++ b/algoperf/workloads/imagenet_vit/workload.py @@ -3,7 +3,8 @@ from typing import Dict, Iterator, Optional from algoperf import spec -from algoperf.workloads.imagenet_resnet.workload import BaseImagenetResNetWorkload +from algoperf.workloads.imagenet_resnet.workload import \ + BaseImagenetResNetWorkload def decode_variant(variant: str) -> Dict[str, int]: diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 1cadebf45..d3b616f43 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -1,15 +1,15 @@ import functools from typing import Dict, Optional, Tuple +from flax import jax_utils import jax import jax.numpy as jnp import numpy as np -from flax import jax_utils -from algoperf import param_utils, spec -from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import ( - LibriSpeechConformerWorkload, -) +from algoperf import param_utils +from algoperf import spec +from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \ + LibriSpeechConformerWorkload from algoperf.workloads.librispeech_deepspeech.librispeech_jax import models diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index c72c1daee..e5387f5cb 100644 --- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -3,18 +3,17 @@ import torch from torch.nn.parallel import DistributedDataParallel as DDP -from algoperf import param_utils, spec +from algoperf import param_utils +from algoperf import spec from algoperf.pytorch_utils import pytorch_setup -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import ( - initialize, -) -from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import ( - LibriSpeechConformerWorkload, -) -from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import ( - DeepspeechConfig, - DeepspeechEncoderDecoder, -) +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \ + initialize +from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \ + LibriSpeechConformerWorkload +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ + DeepspeechConfig +from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ + DeepspeechEncoderDecoder USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py index ca123f885..971e7f0f6 100644 --- a/algoperf/workloads/ogbg/workload.py +++ b/algoperf/workloads/ogbg/workload.py @@ -9,7 +9,8 @@ from algoperf import random_utils as prng from algoperf import spec -from algoperf.workloads.ogbg import input_pipeline, metrics +from algoperf.workloads.ogbg import input_pipeline +from algoperf.workloads.ogbg import metrics class BaseOgbgWorkload(spec.Workload): From 1f72cb3f6df53d5d0330f363576832217ef0f537 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 11 Feb 2025 00:08:07 +0000 Subject: [PATCH 103/105] remove temporary testing for upgrades --- .../regression_tests_python_upgrade.yml | 183 ------------------ 1 file changed, 183 deletions(-) delete mode 100644 .github/workflows/regression_tests_python_upgrade.yml diff --git a/.github/workflows/regression_tests_python_upgrade.yml b/.github/workflows/regression_tests_python_upgrade.yml deleted file mode 100644 index 783395353..000000000 --- a/.github/workflows/regression_tests_python_upgrade.yml +++ /dev/null @@ -1,183 +0,0 @@ -name: Containerized Regression Tests Python Upgrades - -on: - pull_request: - branches: - - 'python_test_env_upgrade' - -jobs: - build_and_push_jax_docker_image: - runs-on: self-hosted - steps: - - uses: actions/checkout@v2 - - name: Build and push docker images - run: | - GIT_BRANCH=${{ github.head_ref || github.ref_name }} - FRAMEWORK=jax - IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}" - cd $HOME/algorithmic-efficiency/docker - docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH - BUILD_RETURN=$? - if [[ ${BUILD_RETURN} != 0 ]]; then exit ${BUILD_RETURN}; fi - docker tag $IMAGE_NAME us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME - docker push us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME - build_and_push_pytorch_docker_image: - runs-on: self-hosted - steps: - - uses: actions/checkout@v2 - - name: Build and push docker images - run: | - GIT_BRANCH=${{ github.head_ref || github.ref_name }} - FRAMEWORK=pytorch - IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}" - cd $HOME/algorithmic-efficiency/docker - docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH - BUILD_RETURN=$? - if [[ ${BUILD_RETURN} != 0 ]]; then exit ${BUILD_RETURN}; fi - docker tag $IMAGE_NAME us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME - docker push us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME - fastmri_jax: - runs-on: self-hosted - needs: build_and_push_jax_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d fastmri -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w fastmri -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - imagenet_resnet_jax: - runs-on: self-hosted - needs: build_and_push_jax_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d imagenet -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w imagenet_resnet -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - imagenet_vit_jax: - runs-on: self-hosted - needs: build_and_push_jax_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d imagenet -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w imagenet_vit -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - ogbg_jax: - runs-on: self-hosted - needs: build_and_push_jax_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d ogbg -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w ogbg -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - criteo_jax: - runs-on: self-hosted - needs: build_and_push_jax_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d criteo1tb -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w criteo1tb -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - librispeech_conformer_jax: - runs-on: self-hosted - needs: build_and_push_jax_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d librispeech -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w librispeech_conformer -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - librispeech_deepspeech_jax: - runs-on: self-hosted - needs: build_and_push_jax_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d librispeech -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w librispeech_deepspeech -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - wmt_jax: - runs-on: self-hosted - needs: build_and_push_jax_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d wmt -f jax -s reference_algorithms/paper_baselines/adamw/jax/submission.py -w wmt -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - fastmri_pytorch: - runs-on: self-hosted - needs: build_and_push_pytorch_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d fastmri -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w fastmri -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - imagenet_resnet_pytorch: - runs-on: self-hosted - needs: build_and_push_pytorch_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d imagenet -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w imagenet_resnet -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - imagenet_vit_pytorch: - runs-on: self-hosted - needs: build_and_push_pytorch_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d imagenet -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w imagenet_vit -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - ogbg_pytorch: - runs-on: self-hosted - needs: build_and_push_pytorch_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d ogbg -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w ogbg -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - criteo_pytorch: - runs-on: self-hosted - needs: build_and_push_pytorch_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w criteo1tb -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - exit $? - librispeech_conformer_pytorch: - runs-on: self-hosted - needs: build_and_push_pytorch_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d librispeech -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w librispeech_conformer -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - librispeech_deepspeech_pytorch: - runs-on: self-hosted - needs: build_and_push_pytorch_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d librispeech -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w librispeech_deepspeech -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false - wmt_pytorch: - runs-on: self-hosted - needs: build_and_push_pytorch_docker_image - steps: - - uses: actions/checkout@v2 - - name: Run containerized workload - run: | - docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} - docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d wmt -f pytorch -s reference_algorithms/paper_baselines/adamw/pytorch/submission.py -w wmt -t reference_algorithms/paper_baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false From 4dadef0b93b8b08606c51571389efbcdd2552698 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 11 Feb 2025 00:34:02 +0000 Subject: [PATCH 104/105] update import path in randaugment.py --- .../workloads/imagenet_resnet/imagenet_jax/randaugment.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py index 98e6e0f8e..accd9b4a9 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -7,11 +7,11 @@ import inspect import math -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ +from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ rotate_img -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ +from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ transform -from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ +from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ translate import tensorflow as tf From f375099648d9d9904d16d2008146644a793b3bd7 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 11 Feb 2025 01:33:02 +0000 Subject: [PATCH 105/105] isort changes --- algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py index accd9b4a9..c68e2de33 100644 --- a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -7,13 +7,14 @@ import inspect import math +import tensorflow as tf + from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ rotate_img from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ transform from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \ translate -import tensorflow as tf # This signifies the max integer that the controller RNN could predict for the # augmentation scheme.