@@ -844,6 +844,7 @@ function Base.show(io::IO, l::ResGatedGraphConv)
844844 l. use_bias || print (io, " , use_bias=false" )
845845 print (io, " )" )
846846end
847+
847848@concrete struct TransformerConv <: GNNContainerLayer{(:W1, :W2, :W3, :W4, :W5, :W6, :FF, :BN1, :BN2)}
848849 in_dims:: NTuple{2, Int}
849850 out_dims:: Int
864865end
865866
866867function TransformerConv (ch:: Pair{Int, Int} , args... ; kws... )
867- TransformerConv ((ch[1 ], 0 ) => ch[2 ], args... ; kws... )
868+ return TransformerConv ((ch[1 ], 0 ) => ch[2 ], args... ; kws... )
868869end
869870
870871function TransformerConv (ch:: Pair{NTuple{2, Int}, Int} ;
@@ -880,21 +881,19 @@ function TransformerConv(ch::Pair{NTuple{2, Int}, Int};
880881 skip_connection:: Bool = false ,
881882 batch_norm:: Bool = false ,
882883 ff_channels:: Int = 0 )
883-
884884 (in, ein), out = ch
885885
886886 if add_self_loops
887887 @assert iszero (ein) " Using edge features and setting add_self_loops=true at the same time is not yet supported."
888888 end
889889
890- W1 = root_weight ?
891- Dense (in => out * (concat ? heads : 1 ); use_bias = bias_root, init_weight, init_bias) : nothing
892- W2 = Dense (in => out * heads; use_bias = bias_qkv, init_weight, init_bias)
893- W3 = Dense (in => out * heads; use_bias = bias_qkv, init_weight, init_bias)
894- W4 = Dense (in => out * heads; use_bias = bias_qkv, init_weight, init_bias)
890+ W1 = root_weight ? Dense (in => out * (concat ? heads : 1 ); use_bias= bias_root, init_weight, init_bias) : nothing
891+ W2 = Dense (in => out * heads; use_bias= bias_qkv, init_weight, init_bias)
892+ W3 = Dense (in => out * heads; use_bias= bias_qkv, init_weight, init_bias)
893+ W4 = Dense (in => out * heads; use_bias= bias_qkv, init_weight, init_bias)
895894 out_mha = out * (concat ? heads : 1 )
896- W5 = gating ? Dense (3 * out_mha => 1 , sigmoid; use_bias = false , init_weight, init_bias) : nothing
897- W6 = ein > 0 ? Dense (ein => out * heads; use_bias = bias_qkv, init_weight, init_bias) : nothing
895+ W5 = gating ? Dense (3 * out_mha => 1 , sigmoid; use_bias= false , init_weight, init_bias) : nothing
896+ W6 = ein > 0 ? Dense (ein => out * heads; use_bias= bias_qkv, init_weight, init_bias) : nothing
898897 FF = ff_channels > 0 ?
899898 Chain (Dense (out_mha => ff_channels, relu; init_weight, init_bias),
900899 Dense (ff_channels => out_mha; init_weight, init_bias)) : nothing
@@ -905,10 +904,10 @@ function TransformerConv(ch::Pair{NTuple{2, Int}, Int};
905904 skip_connection, Float32 (√ out), W1, W2, W3, W4, W5, W6, FF, BN1, BN2)
906905end
907906
908- LuxCore. outputsize (l:: TransformerConv ) = (l. out_dims,)
907+ LuxCore. outputsize (l:: TransformerConv ) = (l. concat ? l . out_dims * l . heads : l . out_dims,)
909908
910909function (l:: TransformerConv )(g, x, ps, st)
911- l (g, x, nothing , ps, st)
910+ return l (g, x, nothing , ps, st)
912911end
913912
914913function (l:: TransformerConv )(g, x, e, ps, st)
@@ -933,21 +932,21 @@ function (l::TransformerConv)(g, x, e, ps, st)
933932end
934933
935934function LuxCore. parameterlength (l:: TransformerConv )
936- n = parameterlength (l. W1 ) + parameterlength (l. W2 ) +
937- parameterlength (l . W3) + parameterlength (l. W4) +
938- parameterlength ( l. W5) + parameterlength (l. W6)
939-
935+ n = parameterlength (l. W2 ) + parameterlength (l. W3 ) + parameterlength (l . W4)
936+ n += l . W1 === nothing ? 0 : parameterlength (l. W1)
937+ n += l. W5 === nothing ? 0 : parameterlength (l. W5)
938+ n += l . W6 === nothing ? 0 : parameterlength (l . W6)
940939 n += l. FF === nothing ? 0 : parameterlength (l. FF)
941940 n += l. BN1 === nothing ? 0 : parameterlength (l. BN1)
942941 n += l. BN2 === nothing ? 0 : parameterlength (l. BN2)
943942 return n
944943end
945944
946945function LuxCore. statelength (l:: TransformerConv )
947- n = statelength (l. W1 ) + statelength (l. W2 ) +
948- statelength (l . W3) + statelength (l. W4) +
949- statelength ( l. W5) + statelength (l. W6)
950-
946+ n = statelength (l. W2 ) + statelength (l. W3 ) + statelength (l . W4)
947+ n += l . W1 === nothing ? 0 : statelength (l. W1)
948+ n += l. W5 === nothing ? 0 : statelength (l. W5)
949+ n += l . W6 === nothing ? 0 : statelength (l . W6)
951950 n += l. FF === nothing ? 0 : statelength (l. FF)
952951 n += l. BN1 === nothing ? 0 : statelength (l. BN1)
953952 n += l. BN2 === nothing ? 0 : statelength (l. BN2)
0 commit comments