Skip to content

Commit 492a6b7

Browse files
puririshi98pre-commit-ci[bot]akihironitta
authored
Disallow sgformer until disjoint sampling is supported in cugraph examples (#10231)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Akihiro Nitta <[email protected]>
1 parent 2cc28ed commit 492a6b7

File tree

2 files changed

+18
-5
lines changed

2 files changed

+18
-5
lines changed

examples/multi_gpu/ogbn_train_cugraph.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,13 @@ def arg_parse():
8282
"--model",
8383
type=str,
8484
default='GCN',
85-
choices=['SAGE', 'GAT', 'GCN', 'SGFormer'],
85+
choices=[
86+
'SAGE',
87+
'GAT',
88+
'GCN',
89+
# TODO: Uncomment when we add support for disjoint sampling
90+
# 'SGFormer',
91+
],
8692
help="Model used for training, default GCN",
8793
)
8894
parser.add_argument(
@@ -347,6 +353,7 @@ def run_train(rank, args, data, world_size, cugraph_id, model, split_idx,
347353
dataset.num_classes,
348354
)
349355
elif args.model == 'SGFormer':
356+
# TODO add support for this with disjoint sampling
350357
model = torch_geometric.nn.models.SGFormer(
351358
in_channels=dataset.num_features,
352359
hidden_channels=args.hidden_channels,

examples/ogbn_train_cugraph.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,15 @@ def arg_parse():
7676
parser.add_argument(
7777
"--model",
7878
type=str,
79-
default='SGFormer',
80-
choices=['SAGE', 'GAT', 'GCN', 'SGFormer'],
81-
help="Model used for training, default SGFormer",
79+
default='SAGE',
80+
choices=[
81+
'SAGE',
82+
'GAT',
83+
'GCN',
84+
# TODO: Uncomment when we add support for disjoint sampling
85+
# 'SGFormer',
86+
],
87+
help="Model used for training, default SAGE",
8288
)
8389
parser.add_argument(
8490
"--num_heads",
@@ -88,7 +94,6 @@ def arg_parse():
8894
)
8995
parser.add_argument('--tempdir_root', type=str, default=None)
9096
args = parser.parse_args()
91-
9297
return args
9398

9499

@@ -211,6 +216,7 @@ def test(model, loader):
211216
dataset.num_classes,
212217
).cuda()
213218
elif args.model == 'SGFormer':
219+
# TODO add support for this with disjoint sampling
214220
model = torch_geometric.nn.models.SGFormer(
215221
in_channels=dataset.num_features,
216222
hidden_channels=args.hidden_channels,

0 commit comments

Comments
 (0)