Skip to content

Commit 7bc5c42

Browse files
authored
Update conv.jl
1 parent 1bc673a commit 7bc5c42

File tree

1 file changed

+18
-19
lines changed

1 file changed

+18
-19
lines changed

GNNLux/src/layers/conv.jl

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -844,6 +844,7 @@ function Base.show(io::IO, l::ResGatedGraphConv)
844844
l.use_bias || print(io, ", use_bias=false")
845845
print(io, ")")
846846
end
847+
847848
@concrete struct TransformerConv <: GNNContainerLayer{(:W1, :W2, :W3, :W4, :W5, :W6, :FF, :BN1, :BN2)}
848849
in_dims::NTuple{2, Int}
849850
out_dims::Int
@@ -864,7 +865,7 @@ end
864865
end
865866

866867
function TransformerConv(ch::Pair{Int, Int}, args...; kws...)
867-
TransformerConv((ch[1], 0) => ch[2], args...; kws...)
868+
return TransformerConv((ch[1], 0) => ch[2], args...; kws...)
868869
end
869870

870871
function TransformerConv(ch::Pair{NTuple{2, Int}, Int};
@@ -880,21 +881,19 @@ function TransformerConv(ch::Pair{NTuple{2, Int}, Int};
880881
skip_connection::Bool = false,
881882
batch_norm::Bool = false,
882883
ff_channels::Int = 0)
883-
884884
(in, ein), out = ch
885885

886886
if add_self_loops
887887
@assert iszero(ein) "Using edge features and setting add_self_loops=true at the same time is not yet supported."
888888
end
889889

890-
W1 = root_weight ?
891-
Dense(in => out * (concat ? heads : 1); use_bias = bias_root, init_weight, init_bias) : nothing
892-
W2 = Dense(in => out * heads; use_bias = bias_qkv, init_weight, init_bias)
893-
W3 = Dense(in => out * heads; use_bias = bias_qkv, init_weight, init_bias)
894-
W4 = Dense(in => out * heads; use_bias = bias_qkv, init_weight, init_bias)
890+
W1 = root_weight ? Dense(in => out * (concat ? heads : 1); use_bias=bias_root, init_weight, init_bias) : nothing
891+
W2 = Dense(in => out * heads; use_bias=bias_qkv, init_weight, init_bias)
892+
W3 = Dense(in => out * heads; use_bias=bias_qkv, init_weight, init_bias)
893+
W4 = Dense(in => out * heads; use_bias=bias_qkv, init_weight, init_bias)
895894
out_mha = out * (concat ? heads : 1)
896-
W5 = gating ? Dense(3 * out_mha => 1, sigmoid; use_bias = false, init_weight, init_bias) : nothing
897-
W6 = ein > 0 ? Dense(ein => out * heads; use_bias = bias_qkv, init_weight, init_bias) : nothing
895+
W5 = gating ? Dense(3 * out_mha => 1, sigmoid; use_bias=false, init_weight, init_bias) : nothing
896+
W6 = ein > 0 ? Dense(ein => out * heads; use_bias=bias_qkv, init_weight, init_bias) : nothing
898897
FF = ff_channels > 0 ?
899898
Chain(Dense(out_mha => ff_channels, relu; init_weight, init_bias),
900899
Dense(ff_channels => out_mha; init_weight, init_bias)) : nothing
@@ -905,10 +904,10 @@ function TransformerConv(ch::Pair{NTuple{2, Int}, Int};
905904
skip_connection, Float32(out), W1, W2, W3, W4, W5, W6, FF, BN1, BN2)
906905
end
907906

908-
LuxCore.outputsize(l::TransformerConv) = (l.out_dims,)
907+
LuxCore.outputsize(l::TransformerConv) = (l.concat ? l.out_dims * l.heads : l.out_dims,)
909908

910909
function (l::TransformerConv)(g, x, ps, st)
911-
l(g, x, nothing, ps, st)
910+
return l(g, x, nothing, ps, st)
912911
end
913912

914913
function (l::TransformerConv)(g, x, e, ps, st)
@@ -933,21 +932,21 @@ function (l::TransformerConv)(g, x, e, ps, st)
933932
end
934933

935934
function LuxCore.parameterlength(l::TransformerConv)
936-
n = parameterlength(l.W1) + parameterlength(l.W2) +
937-
parameterlength(l.W3) + parameterlength(l.W4) +
938-
parameterlength(l.W5) + parameterlength(l.W6)
939-
935+
n = parameterlength(l.W2) + parameterlength(l.W3) + parameterlength(l.W4)
936+
n += l.W1 === nothing ? 0 : parameterlength(l.W1)
937+
n += l.W5 === nothing ? 0 : parameterlength(l.W5)
938+
n += l.W6 === nothing ? 0 : parameterlength(l.W6)
940939
n += l.FF === nothing ? 0 : parameterlength(l.FF)
941940
n += l.BN1 === nothing ? 0 : parameterlength(l.BN1)
942941
n += l.BN2 === nothing ? 0 : parameterlength(l.BN2)
943942
return n
944943
end
945944

946945
function LuxCore.statelength(l::TransformerConv)
947-
n = statelength(l.W1) + statelength(l.W2) +
948-
statelength(l.W3) + statelength(l.W4) +
949-
statelength(l.W5) + statelength(l.W6)
950-
946+
n = statelength(l.W2) + statelength(l.W3) + statelength(l.W4)
947+
n += l.W1 === nothing ? 0 : statelength(l.W1)
948+
n += l.W5 === nothing ? 0 : statelength(l.W5)
949+
n += l.W6 === nothing ? 0 : statelength(l.W6)
951950
n += l.FF === nothing ? 0 : statelength(l.FF)
952951
n += l.BN1 === nothing ? 0 : statelength(l.BN1)
953952
n += l.BN2 === nothing ? 0 : statelength(l.BN2)

0 commit comments

Comments
 (0)