@@ -56,4 +56,40 @@ function Base.show(io::IO, tgcn::TGCNCell)
56
56
print (io, " TGCNCell($(tgcn. in_dims) => $(tgcn. out_dims) )" )
57
57
end
58
58
59
- TGCN (ch:: Pair{Int, Int} ; kwargs... ) = GNNLux. StatefulRecurrentCell (TGCNCell (ch; kwargs... ))
59
+ TGCN (ch:: Pair{Int, Int} ; kwargs... ) = GNNLux. StatefulRecurrentCell (TGCNCell (ch; kwargs... ))
60
+
61
+ @concrete struct A3TGCN <: GNNContainerLayer{(:tgcn, :dense1, :dense2)}
62
+ in_dims:: Int
63
+ out_dims:: Int
64
+ tgcn
65
+ dense1
66
+ dense2
67
+ end
68
+
69
+ function A3TGCN (ch:: Pair{Int, Int} ; use_bias = true , init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32, add_self_loops = false , use_edge_weight = true )
70
+ in_dims, out_dims = ch
71
+ tgcn = TGCN (ch; use_bias, init_weight, init_state, init_bias, add_self_loops, use_edge_weight)
72
+ dense1 = Dense (out_dims, out_dims)
73
+ dense2 = Dense (out_dims, out_dims)
74
+ return A3TGCN (in_dims, out_dims, tgcn, dense1, dense2)
75
+ end
76
+
77
+ function (l:: A3TGCN )(g, x, ps, st)
78
+ dense1 = StatefulLuxLayer {true} (l. dense1, ps. dense1, _getstate (st, :dense1 ))
79
+ dense2 = StatefulLuxLayer {true} (l. dense2, ps. dense2, _getstate (st, :dense2 ))
80
+ h, st = l. tgcn (g, x, ps. tgcn, st. tgcn)
81
+ x = dense1 (h)
82
+ x = dense2 (x)
83
+ a = NNlib. softmax (x, dims = 3 )
84
+ c = sum (a .* h , dims = 3 )
85
+ if length (size (c)) == 3
86
+ c = dropdims (c, dims = 3 )
87
+ end
88
+ return c, st
89
+ end
90
+
91
+ LuxCore. outputsize (l:: A3TGCN ) = (l. out_dims,)
92
+
93
+ function Base. show (io:: IO , l:: A3TGCN )
94
+ print (io, " A3TGCN($(l. in_dims) => $(l. out_dims) )" )
95
+ end
0 commit comments