We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
TransformerConv
1 parent 470c988 commit dd0ba85Copy full SHA for dd0ba85
torch_geometric/nn/conv/transformer_conv.py
@@ -126,9 +126,11 @@ def __init__(
126
if isinstance(in_channels, int):
127
in_channels = (in_channels, in_channels)
128
129
- self.lin_key = Linear(in_channels[0], heads * out_channels)
130
- self.lin_query = Linear(in_channels[1], heads * out_channels)
131
- self.lin_value = Linear(in_channels[0], heads * out_channels)
+ self.lin_key = Linear(in_channels[0], heads * out_channels, bias=bias)
+ self.lin_query = Linear(in_channels[1], heads * out_channels,
+ bias=bias)
132
+ self.lin_value = Linear(in_channels[0], heads * out_channels,
133
134
if edge_dim is not None:
135
self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False)
136
else:
0 commit comments