Skip to content

Commit 4b32e2f

Browse files
committed
fixing
1 parent d136de0 commit 4b32e2f

File tree

2 files changed

+129
-3
lines changed

2 files changed

+129
-3
lines changed

GNNLux/src/layers/conv.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -666,11 +666,8 @@ function (l::NNConv)(g, x, edge_weight, ps, st)
666666
return y, stnew
667667
end
668668

669-
LuxCore.outputsize(d::NNConv) = (d.out_dims,)
670-
671669
function Base.show(io::IO, l::NNConv)
672670
print(io, "NNConv($(l.nn)")
673-
print(io, ", $(l.ϵ)")
674671
l.σ == identity || print(io, ", ", l.σ)
675672
l.use_bias || print(io, ", use_bias=false")
676673
l.add_self_loops || print(io, ", add_self_loops=false")

temp.jl

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
2+
using StableRNGs
3+
using LuxTestUtils: test_gradients, AutoReverseDiff, AutoTracker, AutoForwardDiff, AutoEnzyme
4+
5+
using Lux: Lux, Chain, Dense, GRUCell,
6+
glorot_uniform, zeros32 ,
7+
StatefulLuxLayer
8+
9+
import Reexport: @reexport
10+
11+
@reexport using Test
12+
@reexport using GNNLux
13+
@reexport using Lux
14+
@reexport using StableRNGs
15+
@reexport using Random, Statistics
16+
17+
using LuxTestUtils: test_gradients, AutoReverseDiff, AutoTracker, AutoForwardDiff, AutoEnzyme
18+
19+
20+
rng = StableRNG(1234)
21+
edim = 10
22+
g = rand_graph(10, 40)
23+
in_dims = 3
24+
out_dims = 5
25+
x = randn(Float32, in_dims, 10)
26+
27+
g2 = GNNGraph(g, edata = rand(Float32, edim, g.num_edges))
28+
29+
@testset "GINConv" begin
30+
nn = Chain(Dense(in_dims => out_dims, relu), Dense(out_dims => out_dims))
31+
l = GINConv(nn, 0.5)
32+
test_lux_layer(rng, l, g, x, sizey=(out_dims,g.num_nodes), container=true)
33+
end
34+
35+
36+
37+
function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x;
38+
outputsize=nothing, sizey=nothing, container=false,
39+
atol=1.0f-2, rtol=1.0f-2, edge_weight=nothing)
40+
41+
if container
42+
@test l isa GNNContainerLayer
43+
else
44+
@test l isa GNNLayer
45+
end
46+
47+
ps = LuxCore.initialparameters(rng, l)
48+
st = LuxCore.initialstates(rng, l)
49+
@test LuxCore.parameterlength(l) == LuxCore.parameterlength(ps)
50+
@test LuxCore.statelength(l) == LuxCore.statelength(st)
51+
52+
if edge_weight !== nothing
53+
y, st′ = l(g, x, edge_weight, ps, st)
54+
else
55+
y, st′ = l(g, x, ps, st)
56+
end
57+
@test eltype(y) == eltype(x)
58+
if outputsize !== nothing
59+
@test LuxCore.outputsize(l) == outputsize
60+
end
61+
if sizey !== nothing
62+
@test size(y) == sizey
63+
elseif outputsize !== nothing
64+
@test size(y) == (outputsize..., g.num_nodes)
65+
end
66+
67+
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
68+
test_gradients(loss, x, ps; atol, rtol, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
69+
end
70+
71+
72+
"""
73+
MEGNetConv{Flux.Chain{Tuple{Flux.Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Flux.Chain{Tuple{Flux.Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, typeof(mean)}(Chain(Dense(9 => 5, relu), Dense(5 => 5)), Chain(Dense(8 => 5, relu), Dense(5 => 5)), Statistics.mean)
74+
"""
75+
76+
g = rand_graph(10, 40, seed=1234)
77+
in_dims = 3
78+
out_dims = 5
79+
x = randn(Float32, in_dims, 10)
80+
rng = StableRNG(1234)
81+
l = MEGNetConv(in_dims => out_dims)
82+
l
83+
l isa GNNContainerLayer
84+
test_lux_layer(rng, l, g, x, sizey=((out_dims, g.num_nodes), (out_dims, g.num_edges)), container=true)
85+
86+
87+
ps = LuxCore.initialparameters(rng, l)
88+
st = LuxCore.initialstates(rng, l)
89+
edata = rand(T, in_channel, g.num_edges)
90+
91+
(x_new, e_new), st_new = l(g, x, ps, st)
92+
93+
@test size(x_new) == (out_dims, g.num_nodes)
94+
@test size(e_new) == (out_dims, g.num_edges)
95+
96+
97+
98+
99+
edim = 10
100+
in_dims = 3 # Example
101+
out_dims = 5 # Example
102+
using Flux
103+
g2 = GNNGraph(g, edata = rand(Float32, edim, g.num_edges))
104+
nn = Dense(edim, out_dims * in_dims)
105+
l = NNConv(in_dims => out_dims, nn, tanh)
106+
test_lux_layer(rng, l, g2, x, sizey=(out_dims, g2.num_nodes), container=true, edge_weight=g2.edata.e)
107+
108+
109+
110+
hin = 6
111+
hout = 7
112+
hidden = 8
113+
l = EGNNConv(hin => hout, hidden)
114+
ps = LuxCore.initialparameters(rng, l)
115+
st = LuxCore.initialstates(rng, l)
116+
h = randn(rng, Float32, hin, g.num_nodes)
117+
(hnew, xnew), stnew = l(g, h, x, ps, st)
118+
@test size(hnew) == (hout, g.num_nodes)
119+
@test size(xnew) == (in_dims, g.num_nodes)
120+
121+
122+
l = MEGNetConv(in_dims => out_dims)
123+
l
124+
l isa GNNContainerLayer
125+
test_lux_layer(rng, l, g, x, sizey=((out_dims, g.num_nodes), (out_dims, g.num_edges)), container=true)
126+
127+
128+
ps = LuxCore.initialparameters(rng, l)
129+
st = LuxCore.initialstates(rng, l)

0 commit comments

Comments
 (0)