ValueError: pmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ()) #813

@init-22

System info:

Ubuntu 20.04
Python 3.11
NVIDIA RTX 3080 Ti

JAX versions:
jax==0.4.35
jax-cuda12-pjrt==0.4.35
jax-cuda12-plugin==0.4.35
jaxlib==0.4.35

I get the following pmap error when running this command:

python3 submission_runner.py \
    --framework=pytorch \
    --workload=mnist \
    --experiment_dir=$HOME/experiments \
    --experiment_name=my_first_experiment \
    --submission_path=reference_algorithms/paper_baselines/adamw/jax/submission.py \
    --tuning_search_space=reference_algorithms/paper_baselines/adamw/tuning_search_space.json

Here is the full traceback:

Traceback (most recent call last):
  File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/jax/_src/api.py", line 1020, in _get_axis_size
    return shape[axis]
           ~~~~~^^^^^^
IndexError: tuple index out of range

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/submission_runner.py", line 714, in <module>
    app.run(main)
  File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/submission_runner.py", line 682, in main
    score = score_submission_on_workload(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/submission_runner.py", line 587, in score_submission_on_workload
    timing, metrics = train_once(workload, workload_name,
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/submission_runner.py", line 351, in train_once
    optimizer_state, model_params, model_state = update_params(
                                                 ^^^^^^^^^^^^^^
  File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/reference_algorithms/paper_baselines/adamw/jax/submission.py", line 139, in update_params
    outputs = pmapped_train_step(workload,
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: pmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
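For context, the message comes from `jax.pmap`. With the default `in_axes=0`, every positional argument (and every leaf of a pytree argument) needs a leading axis to split across devices, normally of size `jax.local_device_count()`. A rank-0 value such as a Python float or a 0-d array has no axis 0, which triggers this `ValueError`. The filtered traceback does not show which argument of `pmapped_train_step` was rank-0, so the sketch below is only a hypothetical reproduction, not a diagnosis of this issue: `scaled_sum`, `shards`, and `scale` are made-up names, and it runs on whatever devices JAX sees (a single CPU is enough).

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()

# Hypothetical per-device function: x is one shard, scale is a shared scalar.
def scaled_sum(x, scale):
    return jnp.sum(x) * scale

# One row of four ones per device, so axis 0 can be split across devices.
shards = jnp.ones((n_devices, 4))

# Default in_axes=0 tries to map the 0-d scale along axis 0 as well.
# In JAX 0.4.35 this should raise the same ValueError as in the traceback.
mapping_error = None
try:
    jax.pmap(scaled_sum)(shards, jnp.float32(2.0))
except ValueError as err:
    mapping_error = str(err)
    print(mapping_error)

# Fix 1: broadcast the scalar to every device instead of mapping it.
out_broadcast = jax.pmap(scaled_sum, in_axes=(0, None))(shards, jnp.float32(2.0))

# Fix 2: give the scalar an explicit leading device axis (one copy per device).
out_replicated = jax.pmap(scaled_sum)(shards, jnp.full((n_devices,), 2.0))

print(out_broadcast, out_replicated)  # each entry is 4 * 2.0 = 8.0
```

Which fix fits depends on the argument: values shared by all devices (hyperparameters, step counts) are usually passed with `in_axes=None` or as static arguments, while per-device values need the extra leading axis. Printing `jax.tree_util.tree_map(jnp.shape, arg)` for each mapped argument just before the `pmapped_train_step` call is one way to spot the rank-0 leaf.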