Skip to content

Conversation

@yuanyao-nv
Copy link
Contributor

Previously, the bias shape of ConvTranpose was wrong since, unlike Conv, it should using the 1st dimension of weight shape and not the 0th. See described in #1299

In other words,

    if bias is None:
        weight_dim_0 = op.Shape(weight, start=0, end=1)
        bias_shape = op.Expand(weight_dim_0, op.Constant(value_ints=[1]))
        zero = op.CastLike(0.0, input)
        bias = op.Expand(zero, bias_shape)

should be changed to something like:

weight_dim_0 = op.Shape(weight, start=1, end=2) if transposed else op.Shape(weight, start=0, end=1)

However, I think it's more efficient to just eliminate bias altogether if it's not provided instead of filling it with zeros, since the ONNX spec allows bias to be absent.

@justinchuby
Copy link
Collaborator

Amazing, thanks!

@justinchuby justinchuby added the module: torchlib Related to the torch/aten function lib in development label Oct 9, 2024
@justinchuby justinchuby changed the title Fix wrong bias shape of ConvTranspose [torchlib] Fix wrong bias shape of ConvTranspose Oct 9, 2024
@codecov
Copy link

codecov bot commented Oct 9, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 75.09%. Comparing base (a7c797d) to head (afc0b5f).
Report is 1 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1901   +/-   ##
=======================================
  Coverage   75.08%   75.09%           
=======================================
  Files         252      252           
  Lines       27417    27412    -5     
  Branches     3190     3189    -1     
=======================================
- Hits        20587    20584    -3     
+ Misses       5880     5878    -2     
  Partials      950      950           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@justinchuby justinchuby merged commit 12f9209 into microsoft:main Oct 10, 2024
27 of 41 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

module: torchlib Related to the torch/aten function lib in development

Projects

Development

Successfully merging this pull request may close these issues.

2 participants