Description
Dear authors,
When using revlib with torch AMP (automatic mixed precision), the backward pass fails with the following error:
Traceback (most recent call last):
  File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/revlib/core.py", line 130, in backward
    mod_out = take_0th_tensor(new_mod.wrapped_module(y0, *ctx.args, **ctx.kwargs))
  File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 613, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/fqm/.conda/envs/torch/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 597, in _conv_forward
    return F.conv3d(
RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same