Skip to content

Commit deaf4f6

Browse files
puririshi98pre-commit-ci[bot]akihironitta
authored
Clean Up CuGraph-ops (#10383)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Akihiro Nitta <[email protected]>
1 parent 449e104 commit deaf4f6

File tree

6 files changed

+43
-11
lines changed

6 files changed

+43
-11
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
9090

9191
### Changed
9292

93+
- Added `edge_attr` in `CuGraphGATConv` ([#10383](https://github.com/pyg-team/pytorch_geometric/pull/10383))
9394
- Adapt `dgcnn_classification` example to work with `ModelNet` and `MedShapeNet` Datasets ([#9823](https://github.com/pyg-team/pytorch_geometric/pull/9823))
9495
- Chained exceptions explicitly instead of implicitly ([#10242](https://github.com/pyg-team/pytorch_geometric/pull/10242))
9596
- Updated cuGraph examples to use buffered sampling which keeps data in memory and is significantly faster than the deprecated buffered sampling ([#10079](https://github.com/pyg-team/pytorch_geometric/pull/10079))

docs/source/install/installation.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,13 @@ If :conda:`null` :obj:`conda` does not pick up the correct CUDA version of :pyg:
187187
188188
conda install pyg=*=*cu* -c pyg
189189
190+
Enabling Accelerated cuGraph GNNs
191+
---------------------------------
192+
193+
Currently, NVIDIA recommends `NVIDIA PyG Container <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pyg/tags>_` to use cuGraph integration in PyG.
194+
This functionality is planned to be enabled through cuDNN which is part of PyTorch builds. We still recommend using the NVIDIA PyG Container regardless to have the fastest and most stable build of the NVIDIA CUDA stack combined with PyTorch and PyG.
195+
196+
190197
Frequently Asked Questions
191198
--------------------------
192199

test/nn/conv/cugraph/test_cugraph_gat_conv.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@
1111
@pytest.mark.parametrize('bias', [True, False])
1212
@pytest.mark.parametrize('bipartite', [True, False])
1313
@pytest.mark.parametrize('concat', [True, False])
14+
@pytest.mark.parametrize('edge_attr', [True, False])
1415
@pytest.mark.parametrize('heads', [1, 2, 3])
1516
@pytest.mark.parametrize('max_num_neighbors', [8, None])
16-
def test_gat_conv_equality(bias, bipartite, concat, heads, max_num_neighbors):
17+
def test_gat_conv_equality(bias, bipartite, concat, edge_attr, heads,
18+
max_num_neighbors):
1719
in_channels, out_channels = 5, 2
1820
kwargs = dict(bias=bias, concat=concat)
1921

@@ -32,17 +34,27 @@ def test_gat_conv_equality(bias, bipartite, concat, heads, max_num_neighbors):
3234
conv2.lin.weight.data[:, :] = conv1.lin.weight.data
3335
conv2.att.data[:heads * out_channels] = conv1.att_src.data.flatten()
3436
conv2.att.data[heads * out_channels:] = conv1.att_dst.data.flatten()
37+
if edge_attr and not bipartite:
38+
e_attrs = torch.randn(size=(edge_index.size(1), 10))
39+
out1 = conv1(x, edge_index, edge_attr=e_attrs)
3540

36-
if bipartite:
37-
out1 = conv1((x, x[:size[1]]), edge_index)
41+
out2 = conv2(
42+
x,
43+
EdgeIndex(edge_index, sparse_size=size),
44+
max_num_neighbors=max_num_neighbors,
45+
edge_attr=e_attrs,
46+
)
3847
else:
39-
out1 = conv1(x, edge_index)
48+
if bipartite:
49+
out1 = conv1((x, x[:size[1]]), edge_index)
50+
else:
51+
out1 = conv1(x, edge_index)
4052

41-
out2 = conv2(
42-
x,
43-
EdgeIndex(edge_index, sparse_size=size),
44-
max_num_neighbors=max_num_neighbors,
45-
)
53+
out2 = conv2(
54+
x,
55+
EdgeIndex(edge_index, sparse_size=size),
56+
max_num_neighbors=max_num_neighbors,
57+
)
4658
assert torch.allclose(out1, out2, atol=1e-3)
4759

4860
grad_output = torch.rand_like(out1)

torch_geometric/nn/conv/cugraph/gat_conv.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ class CuGraphGATConv(CuGraphModule): # pragma: no cover
2626
:class:`~torch_geometric.nn.conv.GATConv` based on the :obj:`cugraph-ops`
2727
package that fuses message passing computation for accelerated execution
2828
and lower memory footprint.
29+
The current method to enable :obj:`cugraph-ops`
30+
is to use `The NVIDIA PyG Container
31+
<https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pyg>`_.
2932
"""
3033
def __init__(
3134
self,
@@ -67,6 +70,7 @@ def forward(
6770
self,
6871
x: Tensor,
6972
edge_index: EdgeIndex,
73+
edge_attr: Tensor,
7074
max_num_neighbors: Optional[int] = None,
7175
) -> Tensor:
7276
graph = self.get_cugraph(edge_index, max_num_neighbors)
@@ -75,10 +79,12 @@ def forward(
7579

7680
if LEGACY_MODE:
7781
out = GATConvAgg(x, self.att, graph, self.heads, 'LeakyReLU',
78-
self.negative_slope, False, self.concat)
82+
self.negative_slope, False, self.concat,
83+
edge_feat=edge_attr)
7984
else:
8085
out = GATConvAgg(x, self.att, graph, self.heads, 'LeakyReLU',
81-
self.negative_slope, self.concat)
86+
self.negative_slope, self.concat,
87+
edge_feat=edge_attr)
8288

8389
if self.bias is not None:
8490
out = out + self.bias

torch_geometric/nn/conv/cugraph/rgcn_conv.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ class CuGraphRGCNConv(CuGraphModule): # pragma: no cover
2929
:class:`~torch_geometric.nn.conv.RGCNConv` based on the :obj:`cugraph-ops`
3030
package that fuses message passing computation for accelerated execution
3131
and lower memory footprint.
32+
The current method to enable :obj:`cugraph-ops`
33+
is to use `The NVIDIA PyG Container
34+
<https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pyg>`_.
3235
"""
3336
def __init__(self, in_channels: int, out_channels: int, num_relations: int,
3437
num_bases: Optional[int] = None, aggr: str = 'mean',

torch_geometric/nn/conv/cugraph/sage_conv.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ class CuGraphSAGEConv(CuGraphModule): # pragma: no cover
2727
:class:`~torch_geometric.nn.conv.SAGEConv` based on the :obj:`cugraph-ops`
2828
package that fuses message passing computation for accelerated execution
2929
and lower memory footprint.
30+
The current method to enable :obj:`cugraph-ops`
31+
is to use `The NVIDIA PyG Container
32+
<https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pyg>`_.
3033
"""
3134
def __init__(
3235
self,

0 commit comments

Comments
 (0)