Skip to content

Commit 3fe2c76

Browse files
authored
added sageconv lux (#500)
1 parent a034753 commit 3fe2c76

File tree

3 files changed

+54
-1
lines changed

3 files changed

+54
-1
lines changed

GNNLux/src/GNNLux.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ export AGNNConv,
3535
MEGNetConv,
3636
NNConv,
3737
ResGatedGraphConv,
38-
# SAGEConv,
38+
SAGEConv,
3939
SGConv
4040
# TAGConv,
4141
# TransformerConv

GNNLux/src/layers/conv.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -844,3 +844,51 @@ function Base.show(io::IO, l::ResGatedGraphConv)
844844
l.use_bias || print(io, ", use_bias=false")
845845
print(io, ")")
846846
end
847+
848+
@concrete struct SAGEConv <: GNNLayer
849+
in_dims::Int
850+
out_dims::Int
851+
use_bias::Bool
852+
init_weight
853+
init_bias
854+
σ
855+
aggr
856+
end
857+
858+
function SAGEConv(ch::Pair{Int, Int}, σ = identity;
859+
aggr = mean,
860+
init_weight = glorot_uniform,
861+
init_bias = zeros32,
862+
use_bias::Bool = true)
863+
in_dims, out_dims = ch
864+
σ = NNlib.fast_act(σ)
865+
return SAGEConv(in_dims, out_dims, use_bias, init_weight, init_bias, σ, aggr)
866+
end
867+
868+
function LuxCore.initialparameters(rng::AbstractRNG, l::SAGEConv)
869+
weight = l.init_weight(rng, l.out_dims, 2 * l.in_dims)
870+
if l.use_bias
871+
bias = l.init_bias(rng, l.out_dims)
872+
return (; weight, bias)
873+
else
874+
return (; weight)
875+
end
876+
end
877+
878+
LuxCore.parameterlength(l::SAGEConv) = l.use_bias ? l.out_dims * 2 * l.in_dims + l.out_dims :
879+
l.out_dims * 2 * l.in_dims
880+
LuxCore.outputsize(d::SAGEConv) = (d.out_dims,)
881+
882+
function Base.show(io::IO, l::SAGEConv)
883+
print(io, "SAGEConv(", l.in_dims, " => ", l.out_dims)
884+
(l.σ == identity) || print(io, ", ", l.σ)
885+
(l.aggr == mean) || print(io, ", aggr=", l.aggr)
886+
l.use_bias || print(io, ", use_bias=false")
887+
print(io, ")")
888+
end
889+
890+
function (l::SAGEConv)(g, x, ps, st)
891+
m = (; ps.weight, bias = _getbias(ps),
892+
l.σ, l.aggr)
893+
return GNNlib.sage_conv(m, g, x), st
894+
end

GNNLux/test/layers/conv_tests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,4 +134,9 @@
134134
l = ResGatedGraphConv(in_dims => out_dims, tanh)
135135
test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
136136
end
137+
138+
@testset "SAGEConv" begin
139+
l = SAGEConv(in_dims => out_dims, tanh)
140+
test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
141+
end
137142
end

0 commit comments

Comments
 (0)