1+ 
2+ using  StableRNGs
3+ using  LuxTestUtils:  test_gradients, AutoReverseDiff, AutoTracker, AutoForwardDiff, AutoEnzyme
4+ 
5+ using  Lux:  Lux, Chain, Dense, GRUCell,
6+            glorot_uniform, zeros32  , 
7+            StatefulLuxLayer
8+ 
9+ import  Reexport:  @reexport 
10+ 
11+ @reexport  using  Test
12+ @reexport  using  GNNLux
13+ @reexport  using  Lux
14+ @reexport  using  StableRNGs
15+ @reexport  using  Random, Statistics
16+ 
17+ using  LuxTestUtils:  test_gradients, AutoReverseDiff, AutoTracker, AutoForwardDiff, AutoEnzyme
18+ 
19+ 
20+ rng =  StableRNG (1234 )
21+     edim =  10 
22+     g =  rand_graph (10 , 40 )
23+     in_dims =  3 
24+     out_dims =  5 
25+     x =  randn (Float32, in_dims, 10 )
26+ 
27+     g2 =  GNNGraph (g, edata =  rand (Float32, edim, g. num_edges)) 
28+ 
29+ @testset  " GINConv" begin 
30+     nn =  Chain (Dense (in_dims =>  out_dims, relu), Dense (out_dims =>  out_dims))
31+     l =  GINConv (nn, 0.5 )
32+     test_lux_layer (rng, l, g, x, sizey= (out_dims,g. num_nodes), container= true )
33+ end 
34+ 
35+ 
36+ 
37+ function  test_lux_layer (rng:: AbstractRNG , l, g:: GNNGraph , x; 
38+             outputsize= nothing , sizey= nothing , container= false ,
39+             atol= 1.0f-2 , rtol= 1.0f-2 , edge_weight= nothing ) 
40+ 
41+     if  container
42+         @test  l isa  GNNContainerLayer
43+     else 
44+         @test  l isa  GNNLayer
45+     end 
46+ 
47+     ps =  LuxCore. initialparameters (rng, l)
48+     st =  LuxCore. initialstates (rng, l)
49+     @test  LuxCore. parameterlength (l) ==  LuxCore. parameterlength (ps)
50+     @test  LuxCore. statelength (l) ==  LuxCore. statelength (st)
51+     
52+     if  edge_weight != =  nothing 
53+         y, st′ =  l (g, x, edge_weight, ps, st)  
54+     else 
55+         y, st′ =  l (g, x, ps, st)  
56+     end 
57+     @test  eltype (y) ==  eltype (x)
58+     if  outputsize != =  nothing 
59+         @test  LuxCore. outputsize (l) ==  outputsize
60+     end 
61+     if  sizey != =  nothing 
62+         @test  size (y) ==  sizey
63+     elseif  outputsize != =  nothing 
64+         @test  size (y) ==  (outputsize... , g. num_nodes)
65+     end 
66+     
67+     loss =  (x, ps) ->  sum (first (l (g, x, ps, st)))
68+     test_gradients (loss, x, ps; atol, rtol, skip_backends= [AutoReverseDiff (), AutoTracker (), AutoForwardDiff (), AutoEnzyme ()])
69+ end 
70+ 
71+ 
72+ """ 
73+ MEGNetConv{Flux.Chain{Tuple{Flux.Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Flux.Chain{Tuple{Flux.Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, typeof(mean)}(Chain(Dense(9 => 5, relu), Dense(5 => 5)), Chain(Dense(8 => 5, relu), Dense(5 => 5)), Statistics.mean) 
74+ """ 
75+ 
76+ g =  rand_graph (10 , 40 , seed= 1234 )
77+     in_dims =  3 
78+     out_dims =  5 
79+     x =  randn (Float32, in_dims, 10 )
80+     rng =  StableRNG (1234 )
81+     l =  MEGNetConv (in_dims =>  out_dims)
82+     l
83+     l isa  GNNContainerLayer
84+     test_lux_layer (rng, l, g, x, sizey= ((out_dims, g. num_nodes), (out_dims, g. num_edges)), container= true )
85+ 
86+ 
87+         ps =  LuxCore. initialparameters (rng, l)
88+         st =  LuxCore. initialstates (rng, l)
89+         edata =  rand (T, in_channel, g. num_edges)
90+ 
91+         (x_new, e_new), st_new =  l (g, x, ps, st)
92+ 
93+         @test  size (x_new) ==  (out_dims, g. num_nodes)
94+         @test  size (e_new) ==  (out_dims, g. num_edges)
95+ 
96+ 
97+ 
98+ 
99+            edim =  10 
100+            in_dims =  3  #  Example
101+            out_dims =  5  #  Example
102+ using  Flux           
103+     g2 =  GNNGraph (g, edata =  rand (Float32, edim, g. num_edges))
104+            nn =  Dense (edim, out_dims *  in_dims) 
105+            l =  NNConv (in_dims =>  out_dims, nn, tanh) 
106+            test_lux_layer (rng, l, g2, x, sizey= (out_dims, g2. num_nodes), container= true , edge_weight= g2. edata. e) 
107+ 
108+ 
109+ 
110+     hin =  6 
111+     hout =  7 
112+     hidden =  8 
113+     l =  EGNNConv (hin =>  hout, hidden)
114+     ps =  LuxCore. initialparameters (rng, l)
115+     st =  LuxCore. initialstates (rng, l)
116+     h =  randn (rng, Float32, hin, g. num_nodes)
117+     (hnew, xnew), stnew =  l (g, h, x, ps, st)
118+     @test  size (hnew) ==  (hout, g. num_nodes)
119+     @test  size (xnew) ==  (in_dims, g. num_nodes)
120+ 
121+ 
122+     l =  MEGNetConv (in_dims =>  out_dims)
123+     l
124+     l isa  GNNContainerLayer
125+     test_lux_layer (rng, l, g, x, sizey= ((out_dims, g. num_nodes), (out_dims, g. num_edges)), container= true )
126+ 
127+ 
128+         ps =  LuxCore. initialparameters (rng, l)
129+         st =  LuxCore. initialstates (rng, l)
0 commit comments