Closed
I was trying to run `submission_runner.py` inside the Docker container and got a `TypeError`. Use these commands to reproduce the error:
```shell
sudo docker run -it -v <PATH>/algorithmic-efficiency:/algorithmic-efficiency --runtime=nvidia algoperf_pytorch /bin/bash
cd algorithmic-efficiency
python3 submission_runner.py \
    --framework=pytorch \
    --workload=mnist \
    --experiment_dir=$HOME/experiments \
    --experiment_name=my_first_experiment \
    --submission_path=reference_algorithms/paper_baselines/adamw/jax/submission.py \
    --tuning_search_space=reference_algorithms/paper_baselines/adamw/tuning_search_space.json
```
Here is the traceback:
```
Traceback (most recent call last):
  File "submission_runner.py", line 714, in <module>
    app.run(main)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "submission_runner.py", line 682, in main
    score = score_submission_on_workload(
  File "submission_runner.py", line 587, in score_submission_on_workload
    timing, metrics = train_once(workload, workload_name,
  File "submission_runner.py", line 351, in train_once
    optimizer_state, model_params, model_state = update_params(
  File "/algorithmic-efficiency/reference_algorithms/paper_baselines/adamw/jax/submission.py", line 130, in update_params
    per_device_rngs = jax.random.split(rng, jax.local_device_count())
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/random.py", line 217, in split
    key, wrapped = _check_prng_key(key)
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/random.py", line 79, in _check_prng_key
    return prng.random_wrap(key, impl=default_prng_impl()), True
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/prng.py", line 907, in random_wrap
    _check_prng_key_data(impl, base_arr)
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/prng.py", line 119, in _check_prng_key_data
    raise TypeError("JAX encountered invalid PRNG key data: expected key_data.dtype = uint32; "
TypeError: JAX encountered invalid PRNG key data: expected key_data.dtype = uint32; got dtype=int32
```
Am I missing something?
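For reference, the same `TypeError` can be reproduced outside the runner. This is a minimal sketch assuming JAX's default (legacy, raw-array) PRNG keys. In that mode, `jax.random.split` accepts a `uint32` key of shape `(2,)` but rejects the same data stored as `int32`. The variable names are illustrative, and the exact error text may differ between JAX versions. The command above passes `--framework=pytorch` together with the JAX submission (`adamw/jax/submission.py`). One plausible, unconfirmed explanation is that the `rng` reaching `update_params` is produced by the PyTorch code path as plain `int32` data rather than as a JAX key.

```python
import jax
import numpy as np

# With JAX's default (legacy) PRNG, a key is a raw uint32 array of shape (2,).
good_key = jax.random.PRNGKey(0)
print(good_key.dtype, good_key.shape)

# Splitting a valid key works: one new (2,)-shaped key per requested split.
subkeys = jax.random.split(good_key, 2)
print(subkeys.shape)

# The same key data stored as int32 is rejected, as in the traceback above.
# (Assumption: JAX's default config, i.e. no custom/typed PRNG keys enabled.)
bad_key = np.zeros(2, dtype=np.int32)
error_message = ""
try:
    jax.random.split(bad_key, 2)
except TypeError as err:
    error_message = str(err)
print(error_message)
```

If the mismatch is the cause, running the submission that matches `--framework` should avoid it.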