System info:
Ubuntu 20.04
Python 3.11
NVIDIA RTX 3080 Ti
JAX versions:
jax==0.4.35
jax-cuda12-pjrt==0.4.35
jax-cuda12-plugin==0.4.35
jaxlib==0.4.35
I get a pmap error when running the following command:
python3 submission_runner.py \
    --framework=pytorch \
    --workload=mnist \
    --experiment_dir=$HOME/experiments \
    --experiment_name=my_first_experiment \
    --submission_path=reference_algorithms/paper_baselines/adamw/jax/submission.py \
    --tuning_search_space=reference_algorithms/paper_baselines/adamw/tuning_search_space.json
Here is the full traceback:
Traceback (most recent call last):
  File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/jax/_src/api.py", line 1020, in _get_axis_size
    return shape[axis]
           ~~~~~^^^^^^
IndexError: tuple index out of range
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/submission_runner.py", line 714, in <module>
    app.run(main)
  File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/submission_runner.py", line 682, in main
    score = score_submission_on_workload(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/submission_runner.py", line 587, in score_submission_on_workload
    timing, metrics = train_once(workload, workload_name,
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/submission_runner.py", line 351, in train_once
    optimizer_state, model_params, model_state = update_params(
                                                 ^^^^^^^^^^^^^^
  File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/reference_algorithms/paper_baselines/adamw/jax/submission.py", line 139, in update_params
    outputs = pmapped_train_step(workload,
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: pmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
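The final error says that pmap maps each argument along axis 0 by default, so every argument needs a leading device axis. Below is a minimal sketch of that behavior. It assumes a toy `train_step` (hypothetical, not the repository's `pmapped_train_step`). Passing a rank-0 array reproduces the same kind of ValueError, and broadcasting that array to a leading axis of size `jax.local_device_count()` lets the call succeed. The traceback does not show which argument reached `pmapped_train_step` without that axis.

```python
import jax
import jax.numpy as jnp


def train_step(x):
    # Hypothetical stand-in for a per-device training step.
    return x * 2.0


# in_axes defaults to 0: each argument is split along its leading axis,
# one slice per device.
pmapped_step = jax.pmap(train_step)

scalar = jnp.asarray(1.0)  # shape (), rank 0: there is no axis 0 to split

error_message = None
try:
    pmapped_step(scalar)
except ValueError as err:
    # Same class of error as in the traceback above.
    error_message = str(err)
print(error_message)

# Give the argument a leading axis whose size equals the number of local
# devices, so pmap has one slice to hand to each device.
n_devices = jax.local_device_count()
replicated = jnp.broadcast_to(scalar, (n_devices,))
result = pmapped_step(replicated)
print(result.shape)
```

On a machine with a single visible GPU, `n_devices` is 1, so the leading axis has size 1.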