Skip to content

Commit c4ebbbf

Browse files
committed
global attention layer
1 parent 9050ae3 commit c4ebbbf

File tree

6 files changed

+88
-4
lines changed

6 files changed

+88
-4
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ In detail, the following methods are currently implemented:
132132
* **[XConv](https://rusty1s.github.io/pytorch_geometric/build/html/modules/nn.html#torch_geometric.nn.conv.XConv)** from Li *et al.*: [PointCNN: Convolution On X-Transformed Points](https://arxiv.org/abs/1801.07791) (NeurIPS 2018)
133133
* **[GMMConv](https://rusty1s.github.io/pytorch_geometric/build/html/modules/nn.html#torch_geometric.nn.conv.GMMConv)** from Monti *et al.*: [Geometric Deep Learning on Graphs and Manifolds using Mixture Model CNNs](https://arxiv.org/abs/1612.00593) (CVPR 2017)
134134
* A **[MetaLayer](https://rusty1s.github.io/pytorch_geometric/build/html/modules/nn.html#torch_geometric.nn.meta.MetaLayer)** for building any kind of graph network similar to the [TensorFlow Graph Nets library](https://github.com/deepmind/graph_nets) from Battaglia *et al.*: [Relational Inductive Biases, Deep Learning, and Graph Networks](https://arxiv.org/abs/1806.01261) (CoRR 2018)
135+
* **[GlobalAttention](https://rusty1s.github.io/pytorch_geometric/build/html/modules/nn.html#torch_geometric.nn.glob.GlobalAttention)** from Li *et al.*: [Gated Graph Sequence Neural Networks](https://arxiv.org/abs/1511.05493) (ICLR 2016)
135136
* **[Set2Set](https://rusty1s.github.io/pytorch_geometric/build/html/modules/nn.html#torch_geometric.nn.glob.Set2Set)** from Vinyals *et al.*: [Order Matters: Sequence to Sequence for Sets](https://arxiv.org/abs/1511.06391) (ICLR 2016)
136137
* **[Sort Pool](https://rusty1s.github.io/pytorch_geometric/build/html/modules/nn.html#torch_geometric.nn.glob.global_sort_pool)** from Zhang *et al.*: [An End-to-End Deep Learning Architecture for Graph Classification](https://www.cse.wustl.edu/~muhan/papers/AAAI_2018_DGCNN.pdf) (AAAI 2018)
137138
* **[Dense Differentiable Pooling](https://rusty1s.github.io/pytorch_geometric/build/html/modules/nn.html#torch_geometric.nn.dense.diff_pool.dense_diff_pool)** from Ying *et al.*: [Hierarchical Graph Representation Learning with Differentiable Pooling](https://arxiv.org/abs/1806.08804) (NeurIPS 2018)

test/nn/glob/test_attention.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import torch
2+
from torch.nn import Sequential as Seq, Linear as Lin, ReLU
3+
from torch_geometric.nn import GlobalAttention
4+
5+
6+
def test_global_attention():
7+
channels, batch_size = (32, 10)
8+
gate_nn = Seq(Lin(channels, channels), ReLU(), Lin(channels, 1))
9+
nn = Seq(Lin(channels, channels), ReLU(), Lin(channels, channels))
10+
11+
glob = GlobalAttention(gate_nn, nn)
12+
assert glob.__repr__() == (
13+
'GlobalAttention(gate_nn=Sequential(\n'
14+
' (0): Linear(in_features=32, out_features=32, bias=True)\n'
15+
' (1): ReLU()\n'
16+
' (2): Linear(in_features=32, out_features=1, bias=True)\n'
17+
'), nn=Sequential(\n'
18+
' (0): Linear(in_features=32, out_features=32, bias=True)\n'
19+
' (1): ReLU()\n'
20+
' (2): Linear(in_features=32, out_features=32, bias=True)\n'
21+
'))')
22+
23+
x = torch.randn((batch_size**2, channels))
24+
batch = torch.arange(batch_size, dtype=torch.long)
25+
batch = batch.view(-1, 1).repeat(1, batch_size).view(-1)
26+
27+
assert glob(x, batch).size() == (batch_size, channels)
28+
assert glob(x, batch, batch_size + 1).size() == (batch_size + 1, channels)

torch_geometric/nn/conv/point_conv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class PointConv(torch.nn.Module):
1414
.. math::
1515
\mathbf{x}^{\prime}_i = \gamma_{\mathbf{\Theta}} \left( \max_{j \in
1616
\mathcal{N}(i) \cup \{ i \}} h_{\mathbf{\Theta}} ( \mathbf{x}_j \,
17-
\Vert \, \mathbf{p}_j - \mathbf{p}_i) \right)
17+
\Vert \, \mathbf{p}_j - \mathbf{p}_i) \right),
1818
1919
where :math:`\gamma_{\mathbf{\Theta}}` and
2020
:math:`h_{\mathbf{\Theta}}` denote neural
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from .glob import global_add_pool, global_mean_pool, global_max_pool
22
from .sort import global_sort_pool
3+
from .attention import GlobalAttention
34
from .set2set import Set2Set
45

56
__all__ = [
67
'global_add_pool',
78
'global_mean_pool',
89
'global_max_pool',
910
'global_sort_pool',
11+
'GlobalAttention',
1012
'Set2Set',
1113
]
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import torch
2+
from torch_geometric.utils import softmax, scatter_
3+
4+
from ..inits import reset
5+
6+
7+
class GlobalAttention(torch.nn.Module):
8+
r"""Global soft attention layer from the `"Gated Graph Sequence Neural
9+
Networks" <https://arxiv.org/abs/1511.05493>`_ paper
10+
11+
.. math::
12+
\mathbf{r}_i = \sum_{n=1}^{N_i} \mathrm{softmax} \left(
13+
h_{\mathrm{gate}} ( \mathbf{x}_n ) \right) \odot
14+
h_{\mathbf{\Theta}} ( \mathbf{x}_n ),
15+
16+
where :math:`h_{\mathrm{gate}} \colon \mathbb{R}^F \to
17+
\mathbb{R}` and :math:`h_{\mathbf{\Theta}}` denote neural networks, *i.e.*
18+
MLPS.
19+
20+
Args:
21+
gate_nn (nn.Sequential): Neural network
22+
:math:`h_{\mathrm{gate}} \colon \mathbb{R}^F \to \mathbb{R}`.
23+
nn (nn.Sequential, optional): Neural network
24+
:math:`h_{\mathbf{\Theta}}`. (default: :obj:`None`)
25+
"""
26+
27+
def __init__(self, gate_nn, nn=None):
28+
super(GlobalAttention, self).__init__()
29+
self.gate_nn = gate_nn
30+
self.nn = nn
31+
32+
self.reset_parameters()
33+
34+
def reset_parameters(self):
35+
reset(self.gate_nn)
36+
reset(self.nn)
37+
38+
def forward(self, x, batch, size=None):
39+
""""""
40+
x = x.unsqueeze(-1) if x.dim() == 1 else x
41+
size = batch[-1].item() + 1 if size is None else size
42+
43+
gate = self.gate_nn(x).view(-1, 1)
44+
x = self.nn(x) if self.nn is not None else x
45+
assert gate.dim() == x.dim() and gate.size(0) == x.size(0)
46+
47+
gate = softmax(gate, batch, size)
48+
out = scatter_('add', gate * x, batch, size)
49+
50+
return out
51+
52+
def __repr__(self):
53+
return '{}(gate_nn={}, nn={})'.format(self.__class__.__name__,
54+
self.gate_nn, self.nn)

torch_geometric/utils/softmax.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,8 @@ def softmax(src, index, num_nodes=None):
1010
Args:
1111
src (Tensor): The source tensor.
1212
index (LongTensor): The indices of elements for applying the softmax.
13-
num_nodes (int, optional): Automatically create output tensor with size
14-
:attr:`num_nodes` in the first dimension. If set to :attr:`None`, a
15-
minimal sized output tensor is returned. (default: :obj:`None`)
13+
num_nodes (int, optional): The number of nodes, *i.e.*
14+
:obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)
1615
1716
:rtype: :class:`Tensor`
1817
"""

0 commit comments

Comments
 (0)