TypeError: JAX encountered invalid PRNG key data: expected key_data.dtype = uint32; got dtype=int32 #801

@init-22

I was trying to run `submission_runner.py` inside the Docker container and got a `TypeError`. Use these commands to reproduce the error:

```bash
sudo docker run -it -v <PATH>/algorithmic-efficiency:/algorithmic-efficiency --runtime=nvidia algoperf_pytorch /bin/bash

cd algorithmic-efficiency

python3 submission_runner.py \
    --framework=pytorch \
    --workload=mnist \
    --experiment_dir=$HOME/experiments \
    --experiment_name=my_first_experiment \
    --submission_path=reference_algorithms/paper_baselines/adamw/jax/submission.py \
    --tuning_search_space=reference_algorithms/paper_baselines/adamw/tuning_search_space.json
```

Here is the traceback:

```
Traceback (most recent call last):
  File "submission_runner.py", line 714, in <module>
    app.run(main)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "submission_runner.py", line 682, in main
    score = score_submission_on_workload(
  File "submission_runner.py", line 587, in score_submission_on_workload
    timing, metrics = train_once(workload, workload_name,
  File "submission_runner.py", line 351, in train_once
    optimizer_state, model_params, model_state = update_params(
  File "/algorithmic-efficiency/reference_algorithms/paper_baselines/adamw/jax/submission.py", line 130, in update_params
    per_device_rngs = jax.random.split(rng, jax.local_device_count())
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/random.py", line 217, in split
    key, wrapped = _check_prng_key(key)
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/random.py", line 79, in _check_prng_key
    return prng.random_wrap(key, impl=default_prng_impl()), True
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/prng.py", line 907, in random_wrap
    _check_prng_key_data(impl, base_arr)
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/prng.py", line 119, in _check_prng_key_data
    raise TypeError("JAX encountered invalid PRNG key data: expected key_data.dtype = uint32; "
TypeError: JAX encountered invalid PRNG key data: expected key_data.dtype = uint32; got dtype=int32
```

Am I missing something?
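
For anyone triaging this, the failing check can be reproduced outside `submission_runner.py`. The snippet below is a minimal sketch, assuming a recent JAX install; `bad_key`, `try_split`, and the key values are hypothetical and do not reflect how the runner actually builds `rng`. It shows that `jax.random.split` rejects raw key data of dtype `int32` with the same `TypeError` as in the traceback, and accepts the same values once they are cast to `uint32`.

```python
import jax
import numpy as np

# Hypothetical raw key data with the wrong dtype (int32); the values are arbitrary.
bad_key = np.array([0, 42], dtype=np.int32)


def try_split(key, num):
    """Hypothetical helper: call jax.random.split and capture a TypeError if raised."""
    try:
        return jax.random.split(key, num), None
    except TypeError as err:
        return None, err


# int32 key data fails the same dtype check that appears in the traceback.
_, err = try_split(bad_key, 2)
print(err)

# The same (non-negative) values cast to uint32 pass the check and split normally.
keys, err_ok = try_split(bad_key.astype(np.uint32), 2)
print(keys.shape, keys.dtype)
```

The cast keeps non-negative values unchanged but would wrap negative `int32` values around, so it is a way to isolate the dtype check rather than a drop-in fix for wherever `rng` is created.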
