import torch
from torch_geometric.utils import softmax, scatter_

from ..inits import reset


class GlobalAttention(torch.nn.Module):
    r"""Global soft attention layer from the `"Gated Graph Sequence Neural
    Networks" <https://arxiv.org/abs/1511.05493>`_ paper

    .. math::
        \mathbf{r}_i = \sum_{n=1}^{N_i} \mathrm{softmax} \left(
        h_{\mathrm{gate}} ( \mathbf{x}_n ) \right) \odot
        h_{\mathbf{\Theta}} ( \mathbf{x}_n ),

    where :math:`h_{\mathrm{gate}} \colon \mathbb{R}^F \to
    \mathbb{R}` and :math:`h_{\mathbf{\Theta}}` denote neural networks, *i.e.*
    MLPs.

    Args:
        gate_nn (nn.Sequential): Neural network
            :math:`h_{\mathrm{gate}} \colon \mathbb{R}^F \to \mathbb{R}`.
        nn (nn.Sequential, optional): Neural network
            :math:`h_{\mathbf{\Theta}}`. (default: :obj:`None`)
    """

    def __init__(self, gate_nn, nn=None):
        super(GlobalAttention, self).__init__()
        self.gate_nn = gate_nn
        self.nn = nn

        self.reset_parameters()

    def reset_parameters(self):
        reset(self.gate_nn)
        reset(self.nn)

    def forward(self, x, batch, size=None):
        """"""
        x = x.unsqueeze(-1) if x.dim() == 1 else x
        # Infer the number of graphs from the batch vector if not given.
        size = batch[-1].item() + 1 if size is None else size

        # Scalar attention score per node, shape [num_nodes, 1].
        gate = self.gate_nn(x).view(-1, 1)
        # Optionally transform node features before aggregation.
        x = self.nn(x) if self.nn is not None else x
        assert gate.dim() == x.dim() and gate.size(0) == x.size(0)

        # Normalize scores per graph, then sum the weighted node features.
        gate = softmax(gate, batch, size)
        out = scatter_('add', gate * x, batch, size)

        return out

    def __repr__(self):
        return '{}(gate_nn={}, nn={})'.format(self.__class__.__name__,
                                              self.gate_nn, self.nn)
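

# A minimal usage sketch; the feature sizes, the two-layer MLPs and the toy
# two-graph batch below are illustrative assumptions, not fixed by the layer.
if __name__ == '__main__':
    from torch.nn import Sequential, Linear, ReLU

    gate_nn = Sequential(Linear(16, 32), ReLU(), Linear(32, 1))  # h_gate: R^16 -> R
    nn = Sequential(Linear(16, 32), ReLU(), Linear(32, 32))  # h_Theta: R^16 -> R^32
    pool = GlobalAttention(gate_nn, nn)

    x = torch.randn(5, 16)  # 5 nodes with 16 features each
    batch = torch.tensor([0, 0, 0, 1, 1])  # nodes 0-2 in graph 0, nodes 3-4 in graph 1
    out = pool(x, batch)  # shape [2, 32]: one pooled vector per graph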