Hi, can I ask how to enable float64 globally? I tried to use:

```python
import jax
import flax.linen as nn


class NN(nn.Module):
    dtype = jax.numpy.float64

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=10, dtype=jax.numpy.float64)(x)
        x = nn.gelu(x)
        x = nn.Dense(features=5, dtype=jax.numpy.float64)(x)
        x = nn.gelu(x)
        x = nn.Dense(features=1, dtype=jax.numpy.float64)(x)
        return x


model = NN()
# key: a jax.random.PRNGKey; dataset.xs: the input data (neither shown here)
params = model.init(key, dataset.xs)
print(params)
```
Answered by soraros, Apr 19, 2023
Answer selected by zgbkdlm
Flax defaults the parameter dtype to `float32`, which in your case can be changed by setting `param_dtype` to `np.float64`.