We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 8680766 commit e4589baCopy full SHA for e4589ba
examples/pointnet++.py
@@ -37,13 +37,17 @@ def forward(self, data):
37
pos, batch = data.pos, data.batch
38
39
idx = fps(pos, batch, ratio=0.5) # 512 points
40
- edge_index = radius(pos[idx], pos, 0.1, batch[idx], batch, 64)
41
- x = F.relu(self.local_sa1(None, pos, edge_index))
+ edge_index = radius(
+ pos[idx], pos, 0.1, batch[idx], batch, max_num_neighbors=64)
42
+ N, M = pos.size(0), idx.size(0)
43
+ x = F.relu(self.local_sa1(None, pos, edge_index, size=(N, M)))
44
pos, batch = pos[idx], batch[idx]
45
46
idx = fps(pos, batch, ratio=0.25) # 128 points
- edge_index = radius(pos[idx], pos, 0.2, batch[idx], batch, 64)
- x = F.relu(self.local_sa2(x, pos, edge_index))
47
48
+ pos[idx], pos, 0.2, batch[idx], batch, max_num_neighbors=64)
49
50
+ x = F.relu(self.local_sa2(x, pos, edge_index, size=(N, M)))
51
52
53
x = self.global_sa(torch.cat([x, pos], dim=1))
torch_geometric/nn/conv/point_conv.py
@@ -40,10 +40,13 @@ def reset_parameters(self):
reset(self.local_nn)
reset(self.global_nn)
- def forward(self, x, pos, edge_index):
+ def forward(self, x, pos, edge_index, size=None):
""""""
- N, M = edge_index[0].max().item() + 1, edge_index[1].max().item() + 1
- return self.propagate(edge_index, size=(N, M), x=x, pos=pos)
+ if size is None:
+ N = edge_index[0].max().item() + 1
+ M = edge_index[1].max().item() + 1
+ size = (N, M)
+ return self.propagate(edge_index, size=size, x=x, pos=pos)
def message(self, x_j, pos_j, pos_i):
msg = pos_j - pos_i
0 commit comments