Description
According to the paper, a block of the processor should perform the following updates:
$e'_{ij} \gets f(e_{ij}, v_i, v_j)$
$v'_i \gets f(v_i, \sum_j e'_{ij})$
where $f$ is an MLP with a residual connection (I have omitted the distinction between world and mesh edges for simplicity).
This means that the model should update as follows:
$e'_{ij} \gets e_{ij} + \text{MLP}(e_{ij}, v_i, v_j)$
$v'_i \gets v_i + \text{MLP}(v_i, \sum_j e'_{ij})$
However, the current implementation instead does:
$e^{\text{update}}_{ij} = \text{MLP}(e_{ij}, v_i, v_j)$
$v'_i \gets v_i + \text{MLP}(v_i, \sum_j e^{\text{update}}_{ij})$
$e'_{ij} \gets e_{ij} + e^{\text{update}}_{ij}$
That is, the node MLP is fed the aggregated edge update $e^{\text{update}}_{ij}$ (the difference between the updated and old edge representations) rather than the updated edge representation $e'_{ij}$ itself, which is what the paper describes.
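To make the discrepancy concrete, here is a small NumPy sketch (not the repository's TensorFlow code) that runs both orderings on a toy graph. The graph, the sizes, and the single linear layers standing in for the edge and node MLPs are all hypothetical choices made for illustration only.

```python
import numpy as np

# Toy graph (hypothetical): 3 nodes, 2 edges that both point into node 2.
rng = np.random.default_rng(0)
num_nodes, latent = 3, 4
senders = np.array([0, 1])
receivers = np.array([2, 2])
v = rng.normal(size=(num_nodes, latent))     # node features v_i
e = rng.normal(size=(len(senders), latent))  # edge features e_ij

# Single linear layers stand in for the edge and node MLPs (assumption).
W_edge = rng.normal(size=(3 * latent, latent))
W_node = rng.normal(size=(2 * latent, latent))


def edge_mlp(e, v_send, v_recv):
    return np.concatenate([e, v_send, v_recv], axis=-1) @ W_edge


def node_mlp(v, summed_edges):
    return np.concatenate([v, summed_edges], axis=-1) @ W_node


def aggregate(edge_feats):
    """Sum edge features into their receiver nodes: sum_j e_ij."""
    out = np.zeros((num_nodes, latent))
    np.add.at(out, receivers, edge_feats)
    return out


e_update = edge_mlp(e, v[senders], v[receivers])
e_new = e + e_update  # e'_ij is the same in both orderings

# Paper: the node MLP sees the aggregated updated edges e'_ij.
v_paper = v + node_mlp(v, aggregate(e_new))
# Current code: the node MLP sees only the aggregated edge update.
v_current = v + node_mlp(v, aggregate(e_update))

# Per-node gap: zero for nodes 0 and 1 (no incoming edges), non-zero for node 2.
print(np.abs(v_paper - v_current).max(axis=1))
```

Because the stand-in node function is linear, the gap equals `aggregate(e) @ W_node[latent:]`, i.e. exactly the contribution of the old edge features $\sum_j e_{ij}$ that the current ordering drops. With a real nonlinear MLP there is no such closed form, but the node inputs still differ, so the outputs generally differ too.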
Proposed Fix
The issue can be resolved by modifying the `_build` method in `core_model.py` as follows:
```python
def _build(self, graph):
    """Applies GraphNetBlock and returns updated MultiGraph per MeshGraphNets paper."""
    # Apply edge functions with immediate residual connections
    new_edge_sets = []
    for edge_set in graph.edge_sets:
        # Compute edge update
        edge_update = self._update_edge_features(graph.node_features, edge_set)
        # Apply residual connection immediately, giving e'_ij
        updated_features = edge_set.features + edge_update
        new_edge_sets.append(edge_set._replace(features=updated_features))
    # Apply node function to the aggregated *updated* edge features e'_ij
    node_update = self._update_node_features(graph.node_features, new_edge_sets)
    # Apply residual connection to node features
    new_node_features = graph.node_features + node_update
    return MultiGraph(new_node_features, new_edge_sets)
```
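The proposed ordering can also be checked without TensorFlow. The sketch below is a hypothetical NumPy mirror of the `_build` method above: `EdgeSet` and `MultiGraph` are redefined as namedtuples that are assumed to resemble the repository's, the `mesh`/`world` edge-set names are made up, and random linear maps stand in for the MLPs. It keeps separate edge sets so that the node function receives one aggregated $\sum_j e'_{ij}$ per edge set, as in the paper.

```python
import collections

import numpy as np

# Assumption: these namedtuples are meant to resemble the repository's
# MultiGraph/EdgeSet; they are redefined here so the sketch is self-contained.
EdgeSet = collections.namedtuple(
    "EdgeSet", ["name", "features", "senders", "receivers"])
MultiGraph = collections.namedtuple("Graph", ["node_features", "edge_sets"])


class NumpyGraphNetBlock:
    """Hypothetical NumPy stand-in for GraphNetBlock using the proposed ordering."""

    def __init__(self, edge_fns, node_fn):
        self._edge_fns = edge_fns  # edge-set name -> callable (hypothetical)
        self._node_fn = node_fn    # callable on [v_i, sum_j e'_ij per edge set]

    def _update_edge_features(self, node_features, edge_set):
        # Edge MLP input: sender features, receiver features, edge features.
        features = np.concatenate([node_features[edge_set.senders],
                                   node_features[edge_set.receivers],
                                   edge_set.features], axis=-1)
        return self._edge_fns[edge_set.name](features)

    def _update_node_features(self, node_features, edge_sets):
        # Node MLP input: v_i plus the per-receiver sum of each edge set.
        features = [node_features]
        for edge_set in edge_sets:
            summed = np.zeros((node_features.shape[0], edge_set.features.shape[-1]))
            np.add.at(summed, edge_set.receivers, edge_set.features)
            features.append(summed)
        return self._node_fn(np.concatenate(features, axis=-1))

    def __call__(self, graph):
        # 1) Edge update with the residual applied immediately: e'_ij.
        new_edge_sets = []
        for edge_set in graph.edge_sets:
            edge_update = self._update_edge_features(graph.node_features, edge_set)
            new_edge_sets.append(
                edge_set._replace(features=edge_set.features + edge_update))
        # 2) Node update from the aggregated e'_ij, then the node residual.
        node_update = self._update_node_features(graph.node_features, new_edge_sets)
        return MultiGraph(graph.node_features + node_update, new_edge_sets)


# Usage on a toy graph with hypothetical "mesh" and "world" edge sets.
rng = np.random.default_rng(0)
num_nodes, latent = 3, 4


def linear(in_dim):
    """Random linear map standing in for an MLP (illustration only)."""
    w = rng.normal(size=(in_dim, latent))
    return lambda x: x @ w


block = NumpyGraphNetBlock(
    edge_fns={"mesh": linear(3 * latent), "world": linear(3 * latent)},
    node_fn=linear(3 * latent))  # input: [v_i, sum mesh e', sum world e']
v = rng.normal(size=(num_nodes, latent))
mesh = EdgeSet("mesh", rng.normal(size=(2, latent)), np.array([0, 1]), np.array([1, 2]))
world = EdgeSet("world", rng.normal(size=(1, latent)), np.array([0]), np.array([2]))
out = block(MultiGraph(v, [mesh, world]))
print(out.node_features.shape)  # (3, 4)
```

Computing the edge residual before aggregation is the only structural change relative to the current code; the node residual stays where it was.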