# # Graph Classification with Graph Neural Networks

# *This tutorial is a Julia adaptation of the PyTorch Geometric tutorials that can be found [here](https://pytorch-geometric.readthedocs.io/en/latest/notes/colabs.html).*

# In this tutorial session we will take a closer look at how to apply **Graph Neural Networks (GNNs) to the task of graph classification**.
# Graph classification refers to the problem of classifying entire graphs (in contrast to nodes), given a **dataset of graphs**, based on some structural graph properties and possibly on some input node features.
# Here, we want to embed entire graphs, and we want to embed those graphs in such a way that they are linearly separable for the task at hand.
# We will use a graph convolutional network to create a vector embedding of the input graph, and then apply a simple linear classification head to perform the final classification.

# A common graph classification task is **molecular property prediction**, in which molecules are represented as graphs, and the task may be to infer whether a molecule inhibits HIV virus replication or not.

# TU Dortmund University has collected a wide range of graph classification datasets, known as the [**TUDatasets**](https://chrsmrrs.github.io/datasets/), which are also accessible via MLDatasets.jl.
# Let's import the necessary packages. Then we'll load and inspect one of the smaller ones, the **MUTAG dataset**:

using Flux, GraphNeuralNetworks
using Flux: onecold, onehotbatch, logitcrossentropy, DataLoader
using MLDatasets, MLUtils
using LinearAlgebra, Random, Statistics

ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # don't ask for dataset download confirmation
Random.seed!(42); # for reproducibility
#

dataset = TUDataset("MUTAG")

#
dataset.graph_data.targets |> union # the distinct graph-level classes

#
g1, y1 = dataset[1] # get the first graph and target

#
reduce(vcat, g.node_data.targets for (g, _) in dataset) |> union # the distinct node labels

#
reduce(vcat, g.edge_data.targets for (g, _) in dataset) |> union # the distinct edge labels
# This dataset provides **188 different graphs**, and the task is to classify each graph into **one out of two classes**.

# By inspecting the first graph object of the dataset, we can see that it comes with **17 nodes** and **38 edges**.
# It also comes with exactly **one graph label**, and provides additional node labels (7 classes) and edge labels (4 classes).
# However, for the sake of simplicity, we will not make use of edge labels.

# We now convert the `MLDatasets.jl` graph types to our `GNNGraph`s and we also one-hot encode both the node labels (which will be used as input features) and the graph labels (what we want to predict):

graphs = mldataset2gnngraph(dataset)
graphs = [GNNGraph(g,
                   ndata = Float32.(onehotbatch(g.ndata.targets, 0:6)),
                   edata = nothing)
          for g in graphs]
y = onehotbatch(dataset.graph_data.targets, [-1, 1])
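
# As a quick sanity check (the expected shapes follow from the dataset statistics above), each graph should now carry a `7 × num_nodes` feature matrix and `y` should be a `2 × 188` one-hot matrix:

size(graphs[1].ndata.x), size(y)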


# We have some useful utilities for working with graph datasets, *e.g.*, we can shuffle the dataset and use the first 150 graphs as training graphs, while using the remaining ones for testing:

train_data, test_data = splitobs((graphs, y), at = 150, shuffle = true) |> getobs


train_loader = DataLoader(train_data, batchsize = 32, shuffle = true)
test_loader = DataLoader(test_data, batchsize = 32, shuffle = false)

# Here, we opt for a `batchsize` of 32, leading to 5 (randomly shuffled) mini-batches, containing all $4 \cdot 32 + 22 = 150$ graphs.
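
# We can verify this directly, since the number of mini-batches is just the length of the loader:

length(train_loader) # 5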


# ## Mini-batching of graphs

# Since graphs in graph classification datasets are usually small, a good idea is to **batch the graphs** before inputting them into a Graph Neural Network to guarantee full GPU utilization.
# In the image or language domain, this procedure is typically achieved by **rescaling** or **padding** each example into a set of equally-sized shapes, and examples are then grouped in an additional dimension.
# The length of this dimension is then equal to the number of examples grouped in a mini-batch and is typically referred to as the `batchsize`.


# However, for GNNs the two approaches described above are either not feasible or may result in a lot of unnecessary memory consumption.
# Therefore, GraphNeuralNetworks.jl opts for another approach to achieve parallelization across a number of examples. Here, adjacency matrices are stacked in a diagonal fashion (creating a giant graph that holds multiple isolated subgraphs), and node and target features are simply concatenated in the node dimension (the last dimension).

# This procedure has some crucial advantages over other batching procedures:

# 1. GNN operators that rely on a message passing scheme do not need to be modified since messages are not exchanged between two nodes that belong to different graphs.

# 2. There is no computational or memory overhead since adjacency matrices are saved in a sparse fashion holding only non-zero entries, *i.e.*, the edges.

# GraphNeuralNetworks.jl can **batch multiple graphs into a single giant graph**:


vec_gs, _ = first(train_loader)

#
MLUtils.batch(vec_gs)


# Each batched graph object is equipped with a **`graph_indicator` vector**, which maps each node to its respective graph in the batch:

# ```math
# \textrm{graph\_indicator} = [1, \ldots, 1, 2, \ldots, 2, 3, \ldots ]
# ```
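
# We can check this on the mini-batch drawn above (`vec_gs` from the previous cell); `graph_indicator` is a field of the batched `GNNGraph`:

gbatch = MLUtils.batch(vec_gs)
gbatch.graph_indicator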


# ## Training a Graph Neural Network (GNN)

# Training a GNN for graph classification usually follows a simple recipe:

# 1. Embed each node by performing multiple rounds of message passing
# 2. Aggregate node embeddings into a unified graph embedding (**readout layer**)
# 3. Train a final classifier on the graph embedding

# There exist multiple **readout layers** in the literature, but the most common one is to simply take the average of node embeddings:

# ```math
# \mathbf{x}_{\mathcal{G}} = \frac{1}{|\mathcal{V}|} \sum_{v \in \mathcal{V}} \mathbf{x}^{(L)}_v
# ```

# GraphNeuralNetworks.jl provides this functionality via `GlobalPool(mean)`, which takes in the node embeddings of all nodes in the mini-batch and the assignment vector `graph_indicator` to compute a graph embedding of size `[hidden_channels, batchsize]`.
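
# As a small sanity check of the shapes involved, here is a minimal sketch: the random matrix below merely stands in for 64-dimensional node embeddings of the batched graph from above.

pool = GlobalPool(mean)
gb = MLUtils.batch(vec_gs)
size(pool(gb, rand(Float32, 64, gb.num_nodes))) # (64, number of graphs in the batch)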

# The final architecture for applying GNNs to the task of graph classification then looks as follows and allows for complete end-to-end training:

function create_model(nin, nh, nout)
    GNNChain(GCNConv(nin => nh, relu),
             GCNConv(nh => nh, relu),
             GCNConv(nh => nh),
             GlobalPool(mean),
             Dropout(0.5),
             Dense(nh, nout))
end;


# Here, we again make use of the `GCNConv` with $\mathrm{ReLU}(x) = \max(x, 0)$ activation for obtaining localized node embeddings, before we apply our final classifier on top of a graph readout layer.

# Let's train our network for a few epochs to see how well it performs on the training as well as the test set:



function eval_loss_accuracy(model, data_loader, device)
    loss = 0.0
    acc = 0.0
    ntot = 0
    for (g, y) in data_loader
        g, y = MLUtils.batch(g) |> device, y |> device
        n = numobs(y) # number of graphs in this mini-batch
        ŷ = model(g, g.ndata.x)
        loss += logitcrossentropy(ŷ, y) * n
        acc += mean(onecold(ŷ) .== onecold(y)) * n
        ntot += n
    end
    return (loss = round(loss / ntot, digits = 4),
            acc = round(acc * 100 / ntot, digits = 2))
end


function train!(model; epochs = 200, η = 1e-3, infotime = 10)
    ## device = Flux.gpu # uncomment this for GPU training
    device = Flux.cpu
    model = model |> device
    opt = Flux.setup(Adam(η), model)

    function report(epoch)
        train = eval_loss_accuracy(model, train_loader, device)
        test = eval_loss_accuracy(model, test_loader, device)
        @info (; epoch, train, test)
    end

    report(0)
    for epoch in 1:epochs
        for (g, y) in train_loader
            g, y = MLUtils.batch(g) |> device, y |> device
            grad = Flux.gradient(model) do model
                ŷ = model(g, g.ndata.x)
                logitcrossentropy(ŷ, y)
            end
            Flux.update!(opt, model, grad[1])
        end
        epoch % infotime == 0 && report(epoch)
    end
end


nin = 7  # input features: one-hot encoded node labels
nh = 64  # hidden dimension
nout = 2 # number of graph classes
model = create_model(nin, nh, nout)
train!(model)



# As one can see, our model reaches around **75% test accuracy**.
# The fluctuations in accuracy can be explained by the rather small dataset (only 38 test graphs), and usually disappear once one applies GNNs to larger datasets.

# ## (Optional) Exercise

# Can we do better than this?
# As multiple papers pointed out ([Xu et al. (2018)](https://arxiv.org/abs/1810.00826), [Morris et al. (2018)](https://arxiv.org/abs/1810.02244)), applying **neighborhood normalization decreases the expressivity of GNNs in distinguishing certain graph structures**.
# An alternative formulation ([Morris et al. (2018)](https://arxiv.org/abs/1810.02244)) omits neighborhood normalization completely and adds a simple skip-connection to the GNN layer in order to preserve central node information:

# ```math
# \mathbf{x}_i^{(\ell+1)} = \mathbf{W}^{(\ell + 1)}_1 \mathbf{x}_i^{(\ell)} + \mathbf{W}^{(\ell + 1)}_2 \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j^{(\ell)}
# ```

# This layer is implemented under the name `GraphConv` in GraphNeuralNetworks.jl.

# As an exercise, you are invited to adapt our model so that it makes use of `GraphConv` rather than `GCNConv`; a possible starting point is sketched below.
# This should bring you close to **82% test accuracy**.
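
# Below is a minimal sketch of such a model (one possible variant, mirroring `create_model` above; `GraphConv`'s aggregation is left at its default, which sums over neighbors as in the formula). Training it is left to you:

function create_graphconv_model(nin, nh, nout)
    GNNChain(GraphConv(nin => nh, relu),
             GraphConv(nh => nh, relu),
             GraphConv(nh => nh),
             GlobalPool(mean),
             Dropout(0.5),
             Dense(nh, nout))
end;

## model = create_graphconv_model(nin, nh, nout); train!(model)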

# ## Conclusion

# In this chapter, you have learned how to apply GNNs to the task of graph classification.
# You have learned how graphs can be batched together for better GPU utilization, and how to apply readout layers for obtaining graph embeddings rather than node embeddings.