Skip to content

Commit 694a33e

Browse files
authored
SpMM message passing CUDA support for coalesced COO graphs (#617)
* Enhance CUDA support by updating adjacency_matrix and propagate functions for COO graphs * Swap edge encoding order in coalesce to fix CUDA.jl issue * Update comments to clarify coalesce behavior * Add custom _adjacency_matrix for propagate CUDA COO graphs - Leave public adjacency_matrix interface uniform, always returning a sparse adjacency_matrix - Implement custom _adjacency_matrix for propagate copy_xj for CUDA COO graphs, converting to dense when more efficient * Fix imports * Update GPU compatibility checks for COO CUDA * Add @non_differentiable annotation to _adjacency_matrix function * Add tests for coalesced COO graphs * Remove debug statements
1 parent a0a7fad commit 694a33e

File tree

10 files changed

+80
-20
lines changed

10 files changed

+80
-20
lines changed

GNNGraphs/ext/GNNGraphsCUDAExt.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@ using Random, Statistics, LinearAlgebra
55
using GNNGraphs
66
using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T
77
using SparseArrays
8+
using Graphs
89

910
const CUMAT_T = Union{CUDA.AnyCuMatrix, CUDA.CUSPARSE.CuSparseMatrix}
11+
const CUDA_COO_T = Tuple{T, T, V} where {T <: AnyCuArray{<:Integer}, V <: Union{Nothing, AnyCuArray}}
1012

1113
# Query
1214

@@ -35,5 +37,31 @@ function sort_edge_index(u::AnyCuArray, v::AnyCuArray)
3537
sort_edge_index(u, v) |> dev
3638
end
3739

40+
# Convert
41+
42+
function GNNGraphs.to_sparse(coo::CUDA_COO_T, T = nothing; dir = :out, num_nodes = nothing,
43+
weighted = true, is_coalesced = false)
44+
s, t, eweight = coo
45+
T = T === nothing ? (eweight === nothing ? eltype(s) : eltype(eweight)) : T
46+
47+
if eweight === nothing || !weighted
48+
eweight = fill!(similar(s, T), 1)
49+
end
50+
51+
num_nodes::Int = isnothing(num_nodes) ? max(maximum(s), maximum(t)) : num_nodes
52+
53+
# if coalesced build directly sparse coo matrix
54+
if is_coalesced
55+
A = CUDA.CUSPARSE.CuSparseMatrixCOO{T,eltype(s)}(s, t, eweight, (num_nodes, num_nodes))
56+
else
57+
A = sparse(s, t, eweight, num_nodes, num_nodes)
58+
end
59+
60+
num_edges::Int = nnz(A)
61+
if eltype(A) != T
62+
A = T.(A)
63+
end
64+
return A, num_nodes, num_edges
65+
end
3866

3967
end #module

GNNGraphs/src/gnngraph.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ struct GNNGraph{T <: Union{COO_T, ADJMAT_T}} <: AbstractGNNGraph{T}
113113
ndata::DataStore
114114
edata::DataStore
115115
gdata::DataStore
116-
is_coalesced::Bool # only for :coo, true if the graph is coalesced, i.e., indices ordered by row and no multi edges
116+
is_coalesced::Bool # only for :coo, true if the graph is coalesced, i.e., no multi edges and indices ordered by target, then source
117117
end
118118

119119
# GNNGraph constructor setting the is_coalesced field to false

GNNGraphs/src/query.jl

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -231,13 +231,7 @@ If `weighted=true`, the `A` will contain the edge weights if any, otherwise the
231231
"""
232232
function Graphs.adjacency_matrix(g::GNNGraph{<:COO_T}, T::DataType = eltype(g); dir = :out,
233233
weighted = true)
234-
if iscuarray(g.graph[1])
235-
# Revisit after
236-
# https://github.com/JuliaGPU/CUDA.jl/issues/1113
237-
A, n, m = to_dense(g.graph, T; num_nodes = g.num_nodes, weighted)
238-
else
239-
A, n, m = to_sparse(g.graph, T; num_nodes = g.num_nodes, weighted)
240-
end
234+
A, n, m = to_sparse(g.graph, T; num_nodes = g.num_nodes, weighted)
241235
@assert size(A) == (n, n)
242236
return dir == :out ? A : A'
243237
end

GNNGraphs/src/transform.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ end
148148
"""
149149
coalesce(g::GNNGraph; aggr=+)
150150
151-
Return a new GNNGraph where all multiple edges between the same pair of nodes are merged (using aggr for edge weights and features), and the edge indices are sorted lexicographically (by source, then target).
151+
Return a new GNNGraph where all multiple edges between the same pair of nodes are merged (using aggr for edge weights and features), and the edge indices are sorted lexicographically (by target, then by source).
152152
This method is only applicable to graphs of type `:coo`.
153153
154154
`aggr` can take value `+`,`min`, `max` or `mean`.
@@ -158,7 +158,8 @@ function Base.coalesce(g::GNNGraph{<:COO_T}; aggr = +)
158158
w = get_edge_weight(g)
159159
edata = g.edata
160160
num_edges = g.num_edges
161-
idxs, idxmax = edge_encoding(s, t, g.num_nodes)
161+
# order by target first and then source as a workaround of CUDA.jl issue: https://github.com/JuliaGPU/CUDA.jl/issues/2820
162+
idxs, idxmax = edge_encoding(t, s, g.num_nodes)
162163

163164
perm = sortperm(idxs)
164165
idxs = idxs[perm]

GNNGraphs/test/gnngraph.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,13 +99,14 @@ end
9999
mat_gpu = adjacency_matrix(g_gpu)
100100
@test mat_gpu isa AbstractMatrix{Int}
101101
@test get_device(mat_gpu) isa AbstractGPUDevice
102-
@test Array(mat_gpu) == adj_mat
102+
# Convert to float first because poor Int support in CUSPARSE, throws an error
103+
@test Array(Float32.(mat_gpu)) == Float32.(adj_mat)
103104
end
104105
end
105106

106107
@testset "normalized_laplacian" begin
107108
mat = normalized_laplacian(g)
108-
if TEST_GPU && !(dev isa MetalDevice) && GRAPH_T != :sparse
109+
if TEST_GPU && !(dev isa MetalDevice) && GRAPH_T != :sparse && GRAPH_T != :coo
109110
mat_gpu = normalized_laplacian(g_gpu)
110111
@test mat_gpu isa AbstractMatrix{Float32}
111112
@test get_device(mat_gpu)isa AbstractGPUDevice
@@ -114,7 +115,7 @@ end
114115
end
115116

116117
@testset "scaled_laplacian" begin
117-
if TEST_GPU && !(dev isa MetalDevice) && GRAPH_T != :sparse
118+
if TEST_GPU && !(dev isa MetalDevice) && GRAPH_T != :sparse && GRAPH_T != :coo
118119
mat = scaled_laplacian(g)
119120
mat_gpu = scaled_laplacian(g_gpu)
120121
@test mat_gpu isa AbstractMatrix{Float32}

GNNGraphs/test/transform.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -456,8 +456,10 @@ end
456456

457457
s2, t2 = edge_index(g2)
458458
w2 = get_edge_weight(g2)
459-
@test s2 == [1, 2, 2, 3, 3, 4, 4]
460-
@test t2 == [2, 1, 3, 2, 4, 3, 4]
459+
# @test s2 == [1, 2, 2, 3, 3, 4, 4]
460+
# @test t2 == [2, 1, 3, 2, 4, 3, 4]
461+
@test s2 == [2, 1, 3, 2, 4, 3, 4]
462+
@test t2 == [1, 2, 2, 3, 3, 4, 4]
461463
@test w2 == [1, 1, 2, 2, 3.5, 3.5, 5]
462464
@test g2.edata.e == [10.0, 10.0, 20.0, 20.0, 35.0, 35.0, 50.0]
463465
end

GNNlib/ext/GNNlibCUDAExt.jl

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@ module GNNlibCUDAExt
33
using CUDA
44
using Random, Statistics, LinearAlgebra
55
using GNNlib: GNNlib, propagate, copy_xj, e_mul_xj, w_mul_xj
6-
using GNNGraphs: GNNGraph, COO_T, SPARSE_T
6+
using GNNGraphs: GNNGraph, COO_T, SPARSE_T, to_dense, to_sparse
7+
using ChainRulesCore: @non_differentiable
8+
9+
const CUDA_COO_T = Tuple{T, T, V} where {T <: AnyCuArray{<:Integer}, V <: Union{Nothing, AnyCuArray}}
710

811
###### PROPAGATE SPECIALIZATIONS ####################
912

@@ -12,7 +15,9 @@ using GNNGraphs: GNNGraph, COO_T, SPARSE_T
1215
## avoid the fast path on gpu until we have better cuda support
1316
function GNNlib.propagate(::typeof(copy_xj), g::GNNGraph{<:COO_T}, ::typeof(+),
1417
xi, xj::AnyCuMatrix, e)
15-
propagate((xi, xj, e) -> copy_xj(xi, xj, e), g, +, xi, xj, e)
18+
A = _adjacency_matrix(g, eltype(xj); weighted = false)
19+
20+
return xj * A
1621
end
1722

1823
## E_MUL_XJ
@@ -42,4 +47,21 @@ end
4247

4348
# Flux.Zygote.@nograd compute_degree
4449

50+
## CUSTOM ADJACENCY_MATRIX IMPLEMENTATION FOR CUDA COO GRAPHS, returning dense matrix when not coalesced, more efficient
51+
52+
function _adjacency_matrix(g::GNNGraph{<:CUDA_COO_T}, T::DataType = eltype(g); dir = :out,
53+
weighted = true)
54+
if !g.is_coalesced
55+
# Revisit after
56+
# https://github.com/JuliaGPU/CUDA.jl/issues/1113
57+
A, n, m = to_dense(g.graph, T; num_nodes = g.num_nodes, weighted) # if not coalesced, construction of sparse matrix is slow
58+
else
59+
A, n, m = to_sparse(g.graph, T; num_nodes = g.num_nodes, weighted, is_coalesced = true)
60+
end
61+
@assert size(A) == (n, n)
62+
return dir == :out ? A : A'
63+
end
64+
65+
@non_differentiable _adjacency_matrix(x...)
66+
4567
end #module

GNNlib/test/test_module.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ function test_gradients(
150150
return true
151151
end
152152

153-
function generate_test_graphs(graph_type)
153+
function generate_test_graphs(graph_type; do_coalesce=false)
154154
adj1 = [0 1 0 1
155155
1 0 1 0
156156
0 1 0 1
@@ -168,12 +168,18 @@ function generate_test_graphs(graph_type)
168168
g_single_vertex = GNNGraph(adj_single_vertex,
169169
ndata = rand(Float32, D_IN, 4);
170170
graph_type)
171+
172+
if graph_type == :coo && do_coalesce
173+
g1 = coalesce(g1)
174+
g_single_vertex = coalesce(g_single_vertex)
175+
end
171176

172177
return (g1, g_single_vertex)
173178
end
174179

175180
GRAPH_TYPES = [:coo, :dense, :sparse]
176181
TEST_GRAPHS = [generate_test_graphs(:coo)...,
182+
generate_test_graphs(:coo, do_coalesce=true)...,
177183
generate_test_graphs(:dense)...,
178184
generate_test_graphs(:sparse)...]
179185

GraphNeuralNetworks/test/layers/conv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ end
108108

109109
if gpu_backend() == "AMDGPU"
110110
broken = true
111-
elseif gpu_backend() == "CUDA" && get_graph_type(g) == :sparse
111+
elseif gpu_backend() == "CUDA" && get_graph_type(g) in [:coo, :sparse]
112112
broken = true
113113
else
114114
broken = false

GraphNeuralNetworks/test/test_module.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ function test_gradients(
157157
end
158158

159159

160-
function generate_test_graphs(graph_type)
160+
function generate_test_graphs(graph_type; do_coalesce=false)
161161
adj1 = [0 1 0 1
162162
1 0 1 0
163163
0 1 0 1
@@ -175,12 +175,18 @@ function generate_test_graphs(graph_type)
175175
g_single_vertex = GNNGraph(adj_single_vertex,
176176
ndata = rand(Float32, D_IN, 4);
177177
graph_type)
178+
179+
if graph_type == :coo && do_coalesce
180+
g1 = coalesce(g1)
181+
g_single_vertex = coalesce(g_single_vertex)
182+
end
178183

179184
return (g1, g_single_vertex)
180185
end
181186

182187
GRAPH_TYPES = [:coo, :dense, :sparse]
183188
TEST_GRAPHS = [generate_test_graphs(:coo)...,
189+
generate_test_graphs(:coo, do_coalesce=true)...,
184190
generate_test_graphs(:dense)...,
185191
generate_test_graphs(:sparse)...]
186192

0 commit comments

Comments
 (0)