When using Flax’s Conv and ConvTranspose layers as a pair, ConvTranspose does not correctly restore the original input shape, even when the parameters are chosen so that this should be possible. This behavior differs from PyTorch, where ConvXd and ConvTransposeXd used together reliably restore the input shape.
ConvTranspose appears to produce incorrect output shapes, sometimes with spatial dimensions collapsing to zero. This is not just a mismatch with PyTorch; it makes the layer effectively unusable in these cases.
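For context, the standard shape arithmetic (a worked sketch using PyTorch's documented formulas; the variable names here are mine) shows why the round trip should restore the input size:

```python
# Conv output size:            o  = (i + 2p - k) // s + 1
# Transposed-conv output size: i' = (o - 1) * s - 2p + k
i, k, s, p = 4, 3, 1, 0
o = (i + 2 * p - k) // s + 1          # (4 - 3) // 1 + 1 = 2
i_restored = (o - 1) * s - 2 * p + k  # (2 - 1) * 1 + 3 = 4 == i
print(o, i_restored)                  # 2 4
```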
Reproduction Example
```python
from jax import random
from flax import nnx
import torch
from torch import nn

key = random.PRNGKey(42)
batch_size = 4
in_channels = 128
out_channels = 32
i = 4  # spatial input size
k = 3  # kernel size
s = 1  # stride
p = 0  # padding

# ============= Flax ===========================
x = random.uniform(key, shape=(batch_size, i, i, in_channels))
conv = nnx.Conv(in_features=in_channels,
                out_features=out_channels,
                kernel_size=(k, k),
                strides=(s, s),
                padding=p,
                rngs=nnx.Rngs(0))
y = conv(x)
print(y.shape)  # (4, 2, 2, 32)
assert y.shape[2] == 2

tconv = nnx.ConvTranspose(in_features=out_channels,
                          out_features=in_channels,
                          kernel_size=(k, k),
                          strides=(s, s),
                          padding=p,
                          rngs=nnx.Rngs(0))
z = tconv(y)
print(z.shape)  # (4, 0, 0, 128)
if z.shape[2] != i:
    print("Flax ConvTranspose failed to restore the original input shape.")
```
```python
# ============= PyTorch ========================
x = torch.rand(batch_size, in_channels, i, i)
conv = nn.Conv2d(in_channels=in_channels,
                 out_channels=out_channels,
                 kernel_size=k,
                 stride=s,
                 padding=p)
y = conv(x)
print(y.shape)  # torch.Size([4, 32, 2, 2])
assert y.shape == (batch_size, out_channels, 2, 2)

# The stride-1 transposed conv behaves like a forward conv with kernel
# k' = k, stride s' = s, and padding p' = k - 1 - p (= k - 1 here, since
# p = 0), so the expected output size is o' = i' + (k - 1) = 2 + 2 = 4.
kp = k      # k'
sp = s      # s'
pp = k - 1  # p'
ip = 2      # spatial size entering the transposed conv
op = ip + (k - 1)  # expected output size: 4

tconv = nn.ConvTranspose2d(in_channels=out_channels,
                           out_channels=in_channels,
                           kernel_size=k,
                           stride=s,
                           padding=p)
z = tconv(y)
print(z.shape)  # torch.Size([4, 128, 4, 4])
assert z.shape[2] == i
```
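For comparison, the PyTorch round trip also holds for strided cases once `output_padding` compensates for the flooring in the forward conv (a sketch with illustrative parameters, not part of the original report):

```python
# Round-trip shape check across a few (i, k, s, p) configurations.
# output_padding = (i + 2p - k) % s recovers the size lost to flooring.
for i_, k_, s_, p_ in [(4, 3, 1, 0), (8, 3, 2, 1), (16, 5, 2, 2)]:
    out_pad = (i_ + 2 * p_ - k_) % s_
    x_ = torch.rand(1, 8, i_, i_)
    c = nn.Conv2d(8, 4, kernel_size=k_, stride=s_, padding=p_)
    t = nn.ConvTranspose2d(4, 8, kernel_size=k_, stride=s_,
                           padding=p_, output_padding=out_pad)
    assert t(c(x_)).shape[-1] == i_
```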