Skip to content

Commit dd0ba85

Browse files
authored
Add bias term in TransformerConv (#10177)
close #10163 and #10130
1 parent 470c988 commit dd0ba85

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

torch_geometric/nn/conv/transformer_conv.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,11 @@ def __init__(
126126
if isinstance(in_channels, int):
127127
in_channels = (in_channels, in_channels)
128128

129-
self.lin_key = Linear(in_channels[0], heads * out_channels)
130-
self.lin_query = Linear(in_channels[1], heads * out_channels)
131-
self.lin_value = Linear(in_channels[0], heads * out_channels)
129+
self.lin_key = Linear(in_channels[0], heads * out_channels, bias=bias)
130+
self.lin_query = Linear(in_channels[1], heads * out_channels,
131+
bias=bias)
132+
self.lin_value = Linear(in_channels[0], heads * out_channels,
133+
bias=bias)
132134
if edge_dim is not None:
133135
self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False)
134136
else:

0 commit comments

Comments
 (0)