fix(dpmodel/pt/pd/jax): pass trainable to layer & support JAX trainable & support TF tensor fitting trainable #4793
Conversation
1. For dpmodel, pt, and pd, pass the `trainable` parameter to the layer (not actually used in this PR).
2. For JAX, support the `trainable` parameter in the layer.

Signed-off-by: Jinzhe Zeng <[email protected]>
Pull Request Overview
This PR propagates a new trainable flag throughout various descriptor and network components in the DeepMD library (Paddle, PyTorch, and JAX backends).
- Add `trainable` parameter to many layer and block constructors in pd and dpmodel modules
- Ensure `MLPLayer`/`NativeLayer` receive and forward the `trainable` flag
- Update JAX backend to wrap parameters differently based on `trainable`
Reviewed Changes
Copilot reviewed 48 out of 48 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| deepmd/pd/model/descriptor/repformers.py | Added trainable to block and layer constructors |
| deepmd/pd/model/descriptor/repformer_layer.py | Added trainable to attention and MLP layers |
| deepmd/pd/model/descriptor/repflows.py | Added trainable to block and layer constructors |
| deepmd/pd/model/descriptor/repflow_layer.py | Added trainable to repflow layer constructors |
| deepmd/pd/model/descriptor/dpa3.py | Added trainable to subclass initialization |
| deepmd/pd/model/descriptor/dpa2.py | Added trainable to subclass initialization |
| deepmd/pd/model/descriptor/dpa1.py | Added trainable to subclass initialization |
| deepmd/jax/utils/network.py | Support JAX trainable in __setattr__ |
| deepmd/dpmodel/utils/type_embed.py | Added trainable to embedding utility |
| deepmd/dpmodel/utils/network.py | Added trainable to native and fitting networks |
| deepmd/dpmodel/fitting/general_fitting.py | Propagated trainable to general fitting nets |
| deepmd/dpmodel/descriptor/se_t_tebd.py | Added trainable to SE‐TEBD descriptor |
| deepmd/dpmodel/descriptor/se_t.py | Added trainable to SE‐T descriptor |
| deepmd/dpmodel/descriptor/se_r.py | Added trainable to SE‐R descriptor |
| deepmd/dpmodel/descriptor/se_e2_a.py | Added trainable to SE‐E2A descriptor |
| deepmd/dpmodel/descriptor/repformers.py | Added trainable to Repformers block |
| deepmd/dpmodel/descriptor/repflows.py | Added trainable to Repflows block |
| deepmd/dpmodel/descriptor/dpa3.py | Added trainable to DPA3 subclass init |
| deepmd/dpmodel/descriptor/dpa2.py | Added trainable to DPA2 subclass init |
| deepmd/dpmodel/descriptor/dpa1.py | Added trainable to DPA1 subclass init |
Comments suppressed due to low confidence (2)
deepmd/pd/model/descriptor/repformers.py:90

The `trainable` parameter was added to the constructor signature but is not described in the class docstring. Please update the docstring to include `trainable` and its purpose.

trainable: bool = True,
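A minimal sketch of the requested docstring entry, in the numpydoc style used elsewhere in deepmd-kit (the wording is illustrative, not the PR's text):

```python
# Hypothetical docstring fragment for the constructor; other entries elided.
"""
Parameters
----------
trainable : bool, default: True
    Whether the parameters in this block and the layers it constructs
    are trainable, i.e. participate in gradient updates.
"""
```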
deepmd/jax/utils/network.py:48

The `__setattr__` override references `self.trainable` before it may be initialized, which can lead to an AttributeError. Consider setting `self.trainable` in the object's `__init__` before any attribute assignments occur.

if self.trainable:
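A minimal sketch of the defensive pattern the reviewer suggests — illustrative, not the PR's code. It assumes the `ArrayAPIParam`/`ArrayAPIVariable` wrappers from `deepmd.jax.common` and uses a `getattr` default so the check is safe even before `__init__` assigns `trainable`:

```python
from deepmd.jax.common import ArrayAPIParam, ArrayAPIVariable

class NativeLayerSketch:
    """Sketch of a layer whose __setattr__ wraps array attributes."""

    def __setattr__(self, name, value):
        if name in {"w", "b", "idt"} and value is not None:
            # getattr with a default avoids AttributeError when an array
            # attribute is assigned before self.trainable is initialized.
            if getattr(self, "trainable", True):
                value = ArrayAPIParam(value)     # participates in training
            else:
                value = ArrayAPIVariable(value)  # carried along as frozen state
        super().__setattr__(name, value)
```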
## 📝 Walkthrough
A new `trainable` boolean parameter was introduced and propagated across a wide range of descriptor, network, and fitting classes in the codebase. This parameter, defaulting to `True`, enables explicit control over whether the parameters of neural network layers and submodules are trainable. The parameter is threaded through constructors, serialization, and deserialization methods, and is now consistently handled in descriptor, embedding, attention, repformer, repflow, and fitting modules for all backends (TensorFlow, PyTorch, JAX, and custom frameworks). Corresponding test configurations were updated to set `trainable` to `False`.
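As a concrete illustration, freezing a dpmodel network might look like the following (a hypothetical snippet; the constructor arguments are abbreviated and may differ from the actual signatures):

```python
from deepmd.dpmodel.utils.network import EmbeddingNet

# The flag defaults to True, so existing code is unaffected;
# trainable=False is forwarded to every layer the network builds.
net = EmbeddingNet(in_dim=4, neuron=[8, 16], trainable=False)

# The flag survives a serialize/deserialize round trip (the layer
# serialization version was bumped from 1 to 2 to carry it).
restored = EmbeddingNet.deserialize(net.serialize())
```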
## Changes
| File(s) | Change Summary |
|-----------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------|
| deepmd/dpmodel/descriptor/dpa1.py<br>deepmd/pd/model/descriptor/dpa1.py<br>deepmd/pt/model/descriptor/dpa1.py | Added `trainable` parameter to descriptor and submodule constructors; propagated to components. |
| deepmd/dpmodel/descriptor/dpa2.py<br>deepmd/pd/model/descriptor/dpa2.py<br>deepmd/pt/model/descriptor/dpa2.py | Added and propagated `trainable` to all submodules and internal layers in DPA2 descriptors. |
| deepmd/dpmodel/descriptor/dpa3.py<br>deepmd/pd/model/descriptor/dpa3.py<br>deepmd/pt/model/descriptor/dpa3.py | Propagated `trainable` parameter to repflows and type embedding submodules. |
| deepmd/dpmodel/descriptor/repflows.py<br>deepmd/pd/model/descriptor/repflows.py<br>deepmd/pt/model/descriptor/repflows.py | Added `trainable` to repflows block and layers; propagated to all internal layers. |
| deepmd/dpmodel/descriptor/repformers.py<br>deepmd/pd/model/descriptor/repformers.py<br>deepmd/pt/model/descriptor/repformers.py | Added `trainable` to repformer blocks, layers, and attention submodules; propagated accordingly. |
| deepmd/dpmodel/descriptor/se_e2_a.py<br>deepmd/pt/model/descriptor/se_a.py | Propagated `trainable` to embedding network initialization in descriptor blocks. |
| deepmd/dpmodel/descriptor/se_r.py<br>deepmd/pt/model/descriptor/se_r.py | Propagated `trainable` to embedding networks for each atom type in descriptor blocks. |
| deepmd/dpmodel/descriptor/se_t.py<br>deepmd/pt/model/descriptor/se_t.py | Propagated `trainable` to embedding networks in descriptor blocks for each embedding index. |
| deepmd/dpmodel/descriptor/se_t_tebd.py<br>deepmd/pd/model/descriptor/se_t_tebd.py<br>deepmd/pt/model/descriptor/se_t_tebd.py | Added and propagated `trainable` to SeTTebd descriptor and block classes and their embeddings. |
| deepmd/dpmodel/fitting/general_fitting.py<br>deepmd/pt/model/task/fitting.py | Propagated `trainable` to fitting network instances in general fitting classes. |
| deepmd/dpmodel/utils/network.py | Added `trainable` to NativeLayer and related classes; handled in serialization/deserialization. |
| deepmd/dpmodel/utils/type_embed.py<br>deepmd/pd/model/network/network.py<br>deepmd/pt/model/network/network.py | Propagated `trainable` to type embedding network constructors and internal networks. |
| deepmd/jax/utils/network.py | Conditional parameter wrapping in NativeLayer based on `trainable` attribute. |
| deepmd/pd/model/descriptor/repflow_layer.py<br>deepmd/pt/model/descriptor/repflow_layer.py | Added `trainable` to RepFlowLayer; propagated to internal MLP and residuals. |
| deepmd/pd/model/descriptor/repformer_layer.py<br>deepmd/pt/model/descriptor/repformer_layer.py | Added `trainable` to attention and repformer layer classes and residuals; propagated internally. |
| deepmd/pd/model/descriptor/se_a.py | Propagated `trainable` to EmbeddingNet in filter layers. |
| deepmd/pd/model/descriptor/se_atten.py<br>deepmd/pt/model/descriptor/se_atten.py | Added `trainable` to attention block, attention layers, and embedding networks; propagated. |
| deepmd/pd/model/descriptor/se_t.py | Propagated `trainable` to EmbeddingNet in filter layers. |
| deepmd/pd/model/network/mlp.py<br>deepmd/pt/model/network/mlp.py | Added `trainable` to MLPLayer constructor; handled in serialization/deserialization. |
| deepmd/tf/descriptor/se.py<br>deepmd/tf/descriptor/se_t.py | Added `trainable` parameter to `serialize_network` methods; propagated to EmbeddingNet. |
| deepmd/tf/descriptor/se_atten.py | Added `trainable` to serialization of attention layers and strip networks; updated docstrings. |
| deepmd/tf/fit/fitting.py | Added `trainable` to fitting network serialization; propagated to FittingNet. |
| source/tests/consistent/descriptor/test_dpa1.py | Set `"trainable": False` in test configuration dictionary. |
| source/tests/consistent/descriptor/test_dpa2.py<br>source/tests/consistent/descriptor/test_dpa3.py | Changed `"trainable"` in test configuration from `True` to `False`. |
## Sequence Diagram(s)
```mermaid
sequenceDiagram
participant User
participant Descriptor
participant Submodule
participant Layer
User->>Descriptor: DescriptorClass(trainable=...)
Descriptor->>Submodule: Submodule(trainable=...)
Submodule->>Layer: Layer(trainable=...)
Note right of Layer: Layer parameters set as trainable or not
```

```mermaid
sequenceDiagram
participant Descriptor
participant Serializer
Descriptor->>Serializer: serialize(trainable=...)
Serializer->>Descriptor: EmbeddingNet(trainable=...)
Note right of Serializer: Serialization includes trainable flag
```

```mermaid
sequenceDiagram
participant NativeLayer
participant JAXWrapper
NativeLayer->>JAXWrapper: __setattr__('w'/'b'/'idt', value)
alt trainable == True
JAXWrapper->>NativeLayer: wrap with ArrayAPIParam
else trainable == False
JAXWrapper->>NativeLayer: wrap with ArrayAPIVariable
end
```
Actionable comments posted: 4
🔭 Outside diff range comments (14)
deepmd/dpmodel/descriptor/repflows.py (1)
1835-1901: 🛠️ Refactor suggestion — Missing trainable parameter in serialization.

The `serialize()` method doesn't include the `trainable` parameter, which means this setting won't be preserved during model save/load operations. This is inconsistent with the pattern in `dpa2.py`, where the trainable parameter is properly serialized.

Add the trainable parameter to both `DescrptBlockRepflows` and `RepFlowLayer` serialization:

```diff
 # In DescrptBlockRepflows.serialize() method (around line 732):
 return {
     "e_rcut": self.e_rcut,
     "e_rcut_smth": self.e_rcut_smth,
+    "trainable": self.trainable,
     # ... other parameters
 }

 # In RepFlowLayer.serialize() method:
 data = {
     "@class": "RepFlowLayer",
     "@version": 2,
     "e_rcut": self.e_rcut,
+    "trainable": self.trainable,
     # ... other parameters
 }
```

Also ensure the `deserialize()` methods handle the trainable parameter appropriately.
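For reference, a minimal sketch (hypothetical class, not the deepmd-kit implementation) of the round-trip pattern this comment asks for — write the flag in `serialize()` and pop it with a backward-compatible default in `deserialize()`:

```python
class SomeBlock:
    """Sketch of the suggested serialize/deserialize handling of trainable."""

    def __init__(self, trainable: bool = True):
        self.trainable = trainable

    def serialize(self) -> dict:
        return {
            "@version": 2,
            "trainable": self.trainable,  # persist the flag
        }

    @classmethod
    def deserialize(cls, data: dict) -> "SomeBlock":
        data = dict(data)
        data.pop("@version", None)
        # Older (version-1) payloads lack the key; default to trainable.
        return cls(trainable=data.pop("trainable", True))
```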
deepmd/pt/model/descriptor/se_atten.py (3)

771-795: ⚠️ Potential issue — Critical: Missing trainable parameter in serialization.

The `serialize()` method doesn't include the `trainable` parameter in the returned dictionary. This will cause the parameter to be lost during save/load cycles, potentially leading to models being loaded with incorrect trainability settings.

Apply this fix to include the trainable parameter in serialization:

```diff
 return {
     "@class": "NeighborGatedAttention",
     "@version": 1,
     "layer_num": self.layer_num,
     "nnei": self.nnei,
     "embed_dim": self.embed_dim,
     "hidden_dim": self.hidden_dim,
     "dotr": self.dotr,
     "do_mask": self.do_mask,
     "scaling_factor": self.scaling_factor,
     "normalize": self.normalize,
     "temperature": self.temperature,
     "trainable_ln": self.trainable_ln,
     "ln_eps": self.ln_eps,
     "precision": self.precision,
+    "trainable": self.trainable,
     "attention_layers": [layer.serialize() for layer in self.attention_layers],
 }
```

Also add `self.trainable = trainable` in the constructor to store the parameter.
883-922: ⚠️ Potential issue — Critical: Missing trainable parameter in serialization.

Similar to the parent class, the `NeighborGatedAttentionLayer.serialize()` method doesn't include the `trainable` parameter.

Apply this fix:

```diff
 return {
     "nnei": self.nnei,
     "embed_dim": self.embed_dim,
     "hidden_dim": self.hidden_dim,
     "dotr": self.dotr,
     "do_mask": self.do_mask,
     "scaling_factor": self.scaling_factor,
     "normalize": self.normalize,
     "temperature": self.temperature,
     "trainable_ln": self.trainable_ln,
     "ln_eps": self.ln_eps,
     "precision": self.precision,
+    "trainable": self.trainable,
     "attention_layer": self.attention_layer.serialize(),
     "attn_layer_norm": self.attn_layer_norm.serialize(),
 }
```

Also add `self.trainable = trainable` in the constructor.
1073-1113: ⚠️ Potential issue — Critical: Missing trainable parameter in serialization.

The `GatedAttentionLayer.serialize()` method also lacks the `trainable` parameter.

Apply this fix:

```diff
 return {
     "nnei": self.nnei,
     "embed_dim": self.embed_dim,
     "hidden_dim": self.hidden_dim,
     "num_heads": self.num_heads,
     "dotr": self.dotr,
     "do_mask": self.do_mask,
     "scaling_factor": self.scaling_factor,
     "normalize": self.normalize,
     "temperature": self.temperature,
     "bias": self.bias,
     "smooth": self.smooth,
     "precision": self.precision,
+    "trainable": self.trainable,
     "in_proj": self.in_proj.serialize(),
     "out_proj": self.out_proj.serialize(),
 }
```

Also add `self.trainable = trainable` in the constructor.

deepmd/pd/model/descriptor/repformer_layer.py (5)
247-283: ⚠️ Potential issue — Critical: Missing trainable parameter in serialization.

The `Atten2Map.serialize()` method doesn't include the `trainable` parameter, causing state loss during save/load cycles.

Apply this fix:

```diff
 return {
     "@class": "Atten2Map",
     "@version": 1,
     "input_dim": self.input_dim,
     "hidden_dim": self.hidden_dim,
     "head_num": self.head_num,
     "has_gate": self.has_gate,
     "smooth": self.smooth,
     "attnw_shift": self.attnw_shift,
     "precision": self.precision,
+    "trainable": self.trainable,
     "mapqk": self.mapqk.serialize(),
 }
```

Also add `self.trainable = trainable` in the constructor.
338-373: ⚠️ Potential issue — Critical: Missing trainable parameter in serialization.

The `Atten2MultiHeadApply.serialize()` method lacks the `trainable` parameter.

Apply this fix:

```diff
 return {
     "@class": "Atten2MultiHeadApply",
     "@version": 1,
     "input_dim": self.input_dim,
     "head_num": self.head_num,
     "precision": self.precision,
+    "trainable": self.trainable,
     "mapv": self.mapv.serialize(),
     "head_map": self.head_map.serialize(),
 }
```

Also add `self.trainable = trainable` in the constructor.
417-449: ⚠️ Potential issue — Critical: Missing trainable parameter in serialization.

The `Atten2EquiVarApply.serialize()` method also lacks the `trainable` parameter.

Apply this fix:

```diff
 return {
     "@class": "Atten2EquiVarApply",
     "@version": 1,
     "input_dim": self.input_dim,
     "head_num": self.head_num,
     "precision": self.precision,
+    "trainable": self.trainable,
     "head_map": self.head_map.serialize(),
 }
```

Also add `self.trainable = trainable` in the constructor.
551-592: ⚠️ Potential issue — Critical: Missing trainable parameter in serialization.

The `LocalAtten.serialize()` method doesn't include the `trainable` parameter.

Apply this fix:

```diff
 return {
     "@class": "LocalAtten",
     "@version": 1,
     "input_dim": self.input_dim,
     "hidden_dim": self.hidden_dim,
     "head_num": self.head_num,
     "smooth": self.smooth,
     "attnw_shift": self.attnw_shift,
     "precision": self.precision,
+    "trainable": self.trainable,
     "mapq": self.mapq.serialize(),
     "mapkv": self.mapkv.serialize(),
     "head_map": self.head_map.serialize(),
 }
```

Also add `self.trainable = trainable` in the constructor.
1348-1450: ⚠️ Potential issue — Critical: Missing trainable parameter in serialization.

The `RepformerLayer.serialize()` method doesn't include the `trainable` parameter, which is particularly critical given the complexity of this class.

Apply this fix by adding the trainable parameter to the data dictionary:

```diff
 data = {
     "@class": "RepformerLayer",
     "@version": 2,
     "rcut": self.rcut,
     "rcut_smth": self.rcut_smth,
     "sel": self.sel,
     "ntypes": self.ntypes,
     "g1_dim": self.g1_dim,
     "g2_dim": self.g2_dim,
     "axis_neuron": self.axis_neuron,
     "update_chnnl_2": self.update_chnnl_2,
     "update_g1_has_conv": self.update_g1_has_conv,
     "update_g1_has_drrd": self.update_g1_has_drrd,
     "update_g1_has_grrg": self.update_g1_has_grrg,
     "update_g1_has_attn": self.update_g1_has_attn,
     "update_g2_has_g1g1": self.update_g2_has_g1g1,
     "update_g2_has_attn": self.update_g2_has_attn,
     "update_h2": self.update_h2,
     "attn1_hidden": self.attn1_hidden,
     "attn1_nhead": self.attn1_nhead,
     "attn2_hidden": self.attn2_hidden,
     "attn2_nhead": self.attn2_nhead,
     "attn2_has_gate": self.attn2_has_gate,
     "activation_function": self.activation_function,
     "update_style": self.update_style,
     "smooth": self.smooth,
     "precision": self.precision,
     "trainable_ln": self.trainable_ln,
     "use_sqrt_nnei": self.use_sqrt_nnei,
     "g1_out_conv": self.g1_out_conv,
     "g1_out_mlp": self.g1_out_mlp,
     "ln_eps": self.ln_eps,
+    "trainable": self.trainable,
     "linear1": self.linear1.serialize(),
 }
```

Also add `self.trainable = trainable` in the constructor.

deepmd/pt/model/descriptor/repformer_layer.py (5)
244-263: 🛠️ Refactor suggestion — Missing trainable parameter in serialization

The `serialize()` method should include the trainable parameter to ensure it's preserved during model save/load operations.

```diff
 return {
     "@class": "Atten2Map",
     "@version": 1,
     "input_dim": self.input_dim,
     "hidden_dim": self.hidden_dim,
     "head_num": self.head_num,
     "has_gate": self.has_gate,
     "smooth": self.smooth,
     "attnw_shift": self.attnw_shift,
     "precision": self.precision,
+    "trainable": self.trainable,
     "mapqk": self.mapqk.serialize(),
 }
```

Note: You'll also need to store the trainable parameter as an instance variable in the constructor.
333-349: 🛠️ Refactor suggestion — Missing trainable parameter in serialization

The `serialize()` method lacks the trainable parameter, which should be included for proper model persistence.

```diff
 return {
     "@class": "Atten2MultiHeadApply",
     "@version": 1,
     "input_dim": self.input_dim,
     "head_num": self.head_num,
     "precision": self.precision,
+    "trainable": self.trainable,
     "mapv": self.mapv.serialize(),
     "head_map": self.head_map.serialize(),
 }
```
412-427: 🛠️ Refactor suggestion — Missing trainable parameter in serialization

The `serialize()` method should include the trainable parameter to maintain consistency with other serialization implementations.

```diff
 return {
     "@class": "Atten2EquiVarApply",
     "@version": 1,
     "input_dim": self.input_dim,
     "head_num": self.head_num,
     "precision": self.precision,
+    "trainable": self.trainable,
     "head_map": self.head_map.serialize(),
 }
```
541-561: 🛠️ Refactor suggestion — Missing trainable parameter in serialization

The `serialize()` method should include the trainable parameter for consistency and proper model persistence.

```diff
 return {
     "@class": "LocalAtten",
     "@version": 1,
     "input_dim": self.input_dim,
     "hidden_dim": self.hidden_dim,
     "head_num": self.head_num,
     "smooth": self.smooth,
     "attnw_shift": self.attnw_shift,
     "precision": self.precision,
+    "trainable": self.trainable,
     "mapq": self.mapq.serialize(),
     "mapkv": self.mapkv.serialize(),
     "head_map": self.head_map.serialize(),
 }
```
1328-1430: 🛠️ Refactor suggestion — Critical: Missing trainable parameter in serialization

The `serialize()` method is missing the trainable parameter, which is essential for proper model persistence. This is particularly important for the main `RepformerLayer` class as it coordinates all sub-components.

```diff
 data = {
     "@class": "RepformerLayer",
     "@version": 2,
     "rcut": self.rcut,
     "rcut_smth": self.rcut_smth,
     "sel": self.sel,
     "ntypes": self.ntypes,
     "g1_dim": self.g1_dim,
     "g2_dim": self.g2_dim,
     "axis_neuron": self.axis_neuron,
     "update_chnnl_2": self.update_chnnl_2,
     "update_g1_has_conv": self.update_g1_has_conv,
     "update_g1_has_drrd": self.update_g1_has_drrd,
     "update_g1_has_grrg": self.update_g1_has_grrg,
     "update_g1_has_attn": self.update_g1_has_attn,
     "update_g2_has_g1g1": self.update_g2_has_g1g1,
     "update_g2_has_attn": self.update_g2_has_attn,
     "update_h2": self.update_h2,
     "attn1_hidden": self.attn1_hidden,
     "attn1_nhead": self.attn1_nhead,
     "attn2_hidden": self.attn2_hidden,
     "attn2_nhead": self.attn2_nhead,
     "attn2_has_gate": self.attn2_has_gate,
     "activation_function": self.activation_function,
     "update_style": self.update_style,
     "smooth": self.smooth,
     "precision": self.precision,
     "trainable_ln": self.trainable_ln,
     "use_sqrt_nnei": self.use_sqrt_nnei,
     "g1_out_conv": self.g1_out_conv,
     "g1_out_mlp": self.g1_out_mlp,
     "ln_eps": self.ln_eps,
+    "trainable": self.trainable,
     "linear1": self.linear1.serialize(),
 }
```

Also ensure `self.trainable = trainable` is added to the constructor to store the parameter as an instance variable.
🧹 Nitpick comments (4)
deepmd/pt/model/descriptor/repflow_layer.py (1)
67-67: Add documentation for the trainable parameter.
The new `trainable` parameter should be documented in the class docstring to help users understand its purpose and usage.

deepmd/pd/model/descriptor/repflow_layer.py (1)

64-64: Add documentation for the trainable parameter.
The new `trainable` parameter should be documented in the class docstring to help users understand its purpose and usage, consistent with the PyTorch implementation.

deepmd/pd/model/descriptor/repflows.py (1)

170-170: Add documentation for the trainable parameter.
The `trainable` parameter should be documented in the comprehensive class docstring (lines 51-134) to maintain the high documentation standards of this class.

deepmd/pd/model/descriptor/se_atten.py (1)

84-84: Document the new `trainable` parameter in the class docstring.
The `trainable` parameter should be documented in the class docstring along with the other parameters for consistency and clarity.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (47)
- deepmd/dpmodel/descriptor/dpa1.py (13 hunks)
- deepmd/dpmodel/descriptor/dpa2.py (6 hunks)
- deepmd/dpmodel/descriptor/dpa3.py (2 hunks)
- deepmd/dpmodel/descriptor/repflows.py (18 hunks)
- deepmd/dpmodel/descriptor/repformers.py (28 hunks)
- deepmd/dpmodel/descriptor/se_e2_a.py (1 hunk)
- deepmd/dpmodel/descriptor/se_r.py (1 hunk)
- deepmd/dpmodel/descriptor/se_t.py (1 hunk)
- deepmd/dpmodel/descriptor/se_t_tebd.py (5 hunks)
- deepmd/dpmodel/fitting/general_fitting.py (1 hunk)
- deepmd/dpmodel/utils/network.py (11 hunks)
- deepmd/dpmodel/utils/type_embed.py (1 hunk)
- deepmd/jax/utils/network.py (2 hunks)
- deepmd/pd/model/descriptor/dpa1.py (2 hunks)
- deepmd/pd/model/descriptor/dpa2.py (6 hunks)
- deepmd/pd/model/descriptor/dpa3.py (2 hunks)
- deepmd/pd/model/descriptor/repflow_layer.py (14 hunks)
- deepmd/pd/model/descriptor/repflows.py (3 hunks)
- deepmd/pd/model/descriptor/repformer_layer.py (24 hunks)
- deepmd/pd/model/descriptor/repformers.py (3 hunks)
- deepmd/pd/model/descriptor/se_a.py (1 hunk)
- deepmd/pd/model/descriptor/se_atten.py (11 hunks)
- deepmd/pd/model/descriptor/se_t_tebd.py (5 hunks)
- deepmd/pd/model/network/mlp.py (2 hunks)
- deepmd/pd/model/network/network.py (2 hunks)
- deepmd/pt/model/descriptor/dpa1.py (2 hunks)
- deepmd/pt/model/descriptor/dpa2.py (6 hunks)
- deepmd/pt/model/descriptor/dpa3.py (2 hunks)
- deepmd/pt/model/descriptor/repflow_layer.py (14 hunks)
- deepmd/pt/model/descriptor/repflows.py (3 hunks)
- deepmd/pt/model/descriptor/repformer_layer.py (24 hunks)
- deepmd/pt/model/descriptor/repformers.py (4 hunks)
- deepmd/pt/model/descriptor/se_a.py (1 hunk)
- deepmd/pt/model/descriptor/se_atten.py (11 hunks)
- deepmd/pt/model/descriptor/se_r.py (1 hunk)
- deepmd/pt/model/descriptor/se_t.py (1 hunk)
- deepmd/pt/model/descriptor/se_t_tebd.py (5 hunks)
- deepmd/pt/model/network/mlp.py (3 hunks)
- deepmd/pt/model/network/network.py (2 hunks)
- deepmd/pt/model/task/fitting.py (1 hunk)
- deepmd/tf/descriptor/se.py (4 hunks)
- deepmd/tf/descriptor/se_atten.py (5 hunks)
- deepmd/tf/descriptor/se_t.py (4 hunks)
- deepmd/tf/fit/fitting.py (2 hunks)
- source/tests/consistent/descriptor/test_dpa1.py (1 hunk)
- source/tests/consistent/descriptor/test_dpa2.py (1 hunk)
- source/tests/consistent/descriptor/test_dpa3.py (1 hunk)
🧰 Additional context used
🧬 Code Graph Analysis (7)
- deepmd/pd/model/descriptor/repflow_layer.py (1)
  - deepmd/pd/model/network/mlp.py (1): `MLPLayer` (74-298)
- deepmd/jax/utils/network.py (1)
  - deepmd/jax/common.py (1): `ArrayAPIVariable` (86-97)
- deepmd/pt/model/descriptor/repflow_layer.py (1)
  - deepmd/pd/model/network/mlp.py (1): `MLPLayer` (74-298)
- deepmd/pt/model/descriptor/repformers.py (1)
  - deepmd/dpmodel/utils/seed.py (3): `child_seed` (10-10), `child_seed` (14-14), `child_seed` (17-40)
- deepmd/dpmodel/utils/network.py (4)
  - deepmd/utils/version.py (1): `check_version_compatibility` (2-27)
  - source/tests/consistent/descriptor/test_dpa2.py (1): `data` (92-189)
  - source/tests/consistent/descriptor/test_dpa1.py (1): `data` (79-131)
  - source/tests/consistent/descriptor/test_dpa3.py (1): `data` (78-134)
- deepmd/pt/model/descriptor/repflows.py (1)
  - deepmd/pt/model/network/mlp.py (1): `MLPLayer` (72-279)
- deepmd/pt/model/descriptor/repformer_layer.py (3)
  - deepmd/pd/model/network/mlp.py (1): `MLPLayer` (74-298)
  - deepmd/dpmodel/utils/seed.py (3): `child_seed` (10-10), `child_seed` (14-14), `child_seed` (17-40)
  - deepmd/pd/model/descriptor/repformer_layer.py (2): `Atten2MultiHeadApply` (286-373), `Atten2EquiVarApply` (376-449)
⏰ Context from checks skipped due to timeout of 90000ms (28)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Test Python (6, 3.12)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Test Python (6, 3.9)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Test C++ (false)
- GitHub Check: Test C++ (true)
🔇 Additional comments (109)
source/tests/consistent/descriptor/test_dpa1.py (1)
130-130: LGTM! Test configuration updated appropriately.
Adding `"trainable": False` to the test configuration provides valuable test coverage for the non-trainable parameter setting, ensuring the DPA1 descriptor works correctly in both trainable and non-trainable modes.

source/tests/consistent/descriptor/test_dpa2.py (1)
184-184: LGTM! Consistent test configuration update.
The addition of `"trainable": False` maintains consistency with other descriptor tests and ensures proper validation of the non-trainable configuration for DPA2 descriptors.

source/tests/consistent/descriptor/test_dpa3.py (1)
133-133: LGTM! Completes consistent test pattern.
Adding `"trainable": False` to the DPA3 test configuration completes the consistent pattern across all DPA descriptor tests, ensuring comprehensive validation of the trainable parameter feature.

deepmd/pd/model/network/network.py (1)
48-48: LGTM! Well-implemented trainable parameter addition.
The `trainable` parameter is properly added with an appropriate default value (True) for backward compatibility and correctly propagated to the internal `TypeEmbedNetConsistent` instance, enabling trainability control at the embedding network level.
Also applies to: 69-69
deepmd/pt/model/descriptor/se_a.py (1)
528-528: LGTM! Correct trainable parameter propagation.
The `trainable` parameter is properly passed to the `EmbeddingNet` constructor, ensuring that the trainability setting propagates correctly from the descriptor block to its internal embedding networks.

deepmd/pt/model/descriptor/dpa3.py (1)
172-172: LGTM: Proper trainable parameter propagation.
The trainable parameter is correctly propagated to both `DescrptBlockRepflows` and `TypeEmbedNet` submodules, enabling consistent trainability control throughout the descriptor hierarchy.
Also applies to: 188-188
deepmd/dpmodel/descriptor/se_e2_a.py (1)
210-210: LGTM: Consistent trainable parameter propagation to embedding networks.
The trainable parameter is properly passed to each `EmbeddingNet` instance within the embeddings collection, ensuring uniform trainability control across all embedding networks.

deepmd/dpmodel/utils/type_embed.py (1)
96-96: LGTM: Trainable parameter correctly forwarded to embedding network.
The trainable parameter is appropriately passed to the internal `EmbeddingNet`, ensuring the type embedding network's trainability is controlled consistently.

deepmd/pt/model/task/fitting.py (1)
323-323: LGTM: Trainable parameter properly propagated to fitting networks.
The trainable parameter is correctly passed to each `FittingNet` instance, enabling consistent trainability control across all fitting networks in the collection.

deepmd/pd/model/descriptor/dpa3.py (1)
170-170: LGTM: Consistent trainable parameter propagation across Paddle backend.
The trainable parameter is correctly propagated to both `DescrptBlockRepflows` and `TypeEmbedNet` submodules in the Paddle implementation, maintaining consistency with the PyTorch backend.
Also applies to: 186-186
deepmd/pt/model/descriptor/se_r.py (1)
145-145: LGTM! Correct implementation of trainable parameter propagation.
The change properly forwards the `trainable` parameter to the `EmbeddingNet` constructor, enabling fine-grained control over parameter trainability. The implementation is consistent with the existing PyTorch parameter handling logic (lines 149-152) that sets `requires_grad` appropriately.

deepmd/pd/model/descriptor/se_a.py (1)
484-484: LGTM! Proper PaddlePaddle implementation of trainable parameter.
The change correctly propagates the `trainable` parameter to the `EmbeddingNet` constructor. The implementation properly uses PaddlePaddle's `stop_gradient` mechanism (lines 488-491) to control parameter trainability, which is the correct equivalent to PyTorch's `requires_grad`.

deepmd/dpmodel/descriptor/se_t.py (1)
150-150: LGTM! Correct framework-agnostic implementation.
The change properly forwards the `trainable` parameter to the `EmbeddingNet` constructor. This framework-agnostic implementation correctly omits backend-specific parameter handling while maintaining the trainable parameter for serialization and configuration purposes (visible in the serialize method at line 431).

deepmd/dpmodel/fitting/general_fitting.py (1)
201-201: LGTM! Well-designed trainable parameter handling for fitting networks.
The change correctly propagates the `trainable` parameter to the `FittingNet` constructor. The implementation elegantly handles multiple trainable parameter formats (lines 136-139): `None` (defaulting to all trainable), a single boolean (applied to all layers), or a list of booleans (per-layer control). This provides flexible control over fitting network trainability.
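A minimal sketch of that normalization logic, assuming it maps `None`/bool/list onto one boolean per layer (names are illustrative, not the exact deepmd-kit code):

```python
def normalize_trainable(trainable, num_layers: int) -> list[bool]:
    """Expand a trainable spec (None | bool | list[bool]) to per-layer flags."""
    if trainable is None:
        return [True] * num_layers       # default: everything trainable
    if isinstance(trainable, bool):
        return [trainable] * num_layers  # one flag applied to all layers
    if len(trainable) != num_layers:
        raise ValueError("need one trainable flag per layer")
    return list(trainable)               # explicit per-layer control
```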
deepmd/pt/model/descriptor/se_t.py (1)

578-578: LGTM: Correct propagation of trainable parameter.
The `trainable` parameter is properly passed to the `EmbeddingNet` initialization, enabling control over parameter trainability in filter layers.

deepmd/pd/model/descriptor/repformers.py (3)
90-90: LGTM: Trainable parameter added to constructor.
The `trainable` parameter is correctly added with a default value of `True`, maintaining backward compatibility.

227-232: LGTM: Trainable parameter propagated to MLPLayer.
The `trainable` parameter is properly passed to the `MLPLayer` initialization, enabling control over the g2 embedding layer's trainability.

269-269: LGTM: Trainable parameter propagated to RepformerLayer.
The `trainable` parameter is correctly passed to each `RepformerLayer` instance, ensuring consistent trainability control across all layers.

deepmd/pt/model/network/network.py (2)
256-256: LGTM: Trainable parameter added to TypeEmbedNet constructor.
The `trainable` parameter is correctly added with an appropriate default value, maintaining backward compatibility.

277-277: LGTM: Trainable parameter propagated to TypeEmbedNetConsistent.
The `trainable` parameter is properly passed to the internal `TypeEmbedNetConsistent` instance, enabling trainability control for type embedding networks.

deepmd/pt/model/descriptor/repformers.py (4)
114-114: LGTM: Trainable parameter added to constructor.
The `trainable` parameter is correctly added with a default value of `True`, maintaining backward compatibility.

201-202: LGTM: Well-documented trainable parameter.
Good documentation that clearly explains the purpose of the `trainable` parameter.

252-258: LGTM: Trainable parameter propagated to MLPLayer.
The `trainable` parameter is properly passed to the `MLPLayer` initialization along with other required parameters, including the child seed.

295-295: LGTM: Trainable parameter propagated to RepformerLayer.
The `trainable` parameter is correctly passed to each `RepformerLayer` instance in the loop, ensuring consistent trainability control.

deepmd/dpmodel/descriptor/se_r.py (1)
169-169: LGTM: Trainable parameter propagated to EmbeddingNet.
The `trainable` parameter is correctly passed to each type-specific `EmbeddingNet` instance, enabling control over embedding network trainability for all atom types.

deepmd/pd/model/network/mlp.py (2)
88-88: LGTM: Trainable parameter added with appropriate default.
The addition of the `trainable` parameter with a default value of `True` maintains backward compatibility while enabling explicit control over parameter trainability.

281-281: LGTM: Proper handling of trainable parameter in deserialization.
The trainable parameter is correctly extracted from the `NativeLayer` during deserialization and passed to the constructor.
deepmd/pt/model/descriptor/dpa1.py (2)
301-301: LGTM: Trainable parameter properly propagated to attention block.
The `trainable` parameter is correctly passed to the `DescrptBlockSeAtten` component, ensuring consistent trainability control throughout the descriptor hierarchy.

315-315: LGTM: Trainable parameter properly propagated to type embedding.
The `trainable` parameter is correctly passed to the `TypeEmbedNet` component, maintaining consistent trainability control across all subcomponents.

deepmd/pt/model/descriptor/repflows.py (4)
222-222: LGTM: Trainable parameter added with appropriate default.
The `trainable` parameter is properly added to the constructor with a default value of `True`, maintaining backward compatibility while enabling trainability control.

286-292: LGTM: Trainable parameter propagated to edge embedding layer.
The `trainable` parameter is correctly passed to the edge embedding `MLPLayer`, ensuring consistent trainability control for edge representations.

293-300: LGTM: Trainable parameter propagated to angle embedding layer.
The `trainable` parameter is correctly passed to the angle embedding `MLPLayer`, ensuring consistent trainability control for angle representations.

331-331: LGTM: Trainable parameter propagated to RepFlow layers.
The `trainable` parameter is correctly passed to each `RepFlowLayer` instance in the layer stack, ensuring consistent trainability control throughout the entire descriptor block.

deepmd/pt/model/descriptor/repflow_layer.py (1)
125-131: LGTM: Consistent trainable parameter propagation.
The `trainable` parameter is correctly propagated to all `MLPLayer` instances and `get_residual` calls throughout the constructor. This ensures consistent control over parameter trainability across all subcomponents.
Also applies to: 143-152, 163-172, 184-193, 222-232, 233-240, 243-252, 250-259, 270-279
deepmd/pd/model/descriptor/repflow_layer.py (1)
122-128: LGTM: Consistent backend implementation.
The PaddlePaddle implementation correctly mirrors the PyTorch version, with the `trainable` parameter properly propagated to all `MLPLayer` instances and `get_residual` calls. This ensures consistent behavior across different backends.
Also applies to: 143-149, 163-169, 184-190, 222-229, 230-237, 243-249, 250-256, 270-276
deepmd/pd/model/descriptor/repflows.py (1)
226-232: LGTM: Complete trainable parameter integration.
The `trainable` parameter is correctly propagated to both the `MLPLayer` instances (`edge_embd` and `angle_embd`) and all `RepFlowLayer` instances in the layer list. This completes the integration chain from the descriptor block level down to individual layer components.
Also applies to: 233-240, 243-272
deepmd/dpmodel/descriptor/dpa3.py (1)
360-360: LGTM! Proper trainable parameter propagation.
The `trainable` parameter is correctly passed to both `DescrptBlockRepflows` and `TypeEmbedNet` subcomponents, enabling consistent trainability control throughout the DPA3 descriptor.
Also applies to: 378-378
deepmd/dpmodel/descriptor/se_t_tebd.py (2)
160-160: LGTM! Consistent trainable parameter propagation in DescrptSeTTebd.
The `trainable` parameter is properly passed to both `DescrptBlockSeTTebd` and `TypeEmbedNet` instances, maintaining consistency with the broader codebase pattern.
Also applies to: 175-175

502-502: LGTM! Proper trainable parameter integration in DescrptBlockSeTTebd.
The `trainable` parameter is correctly added to the constructor with a sensible default value (True) and properly propagated to all `EmbeddingNet` instances, enabling fine-grained control over parameter trainability.
Also applies to: 548-548, 564-564
deepmd/jax/utils/network.py (2)
19-19: LGTM! Required import for non-trainable parameter support.
The `ArrayAPIVariable` import is necessary for wrapping non-trainable parameters in the JAX backend.

48-51: LGTM! Correct JAX parameter wrapping logic.
The conditional wrapping logic properly distinguishes between trainable and non-trainable parameters by using `ArrayAPIParam` for trainable parameters (which support gradients) and `ArrayAPIVariable` for non-trainable parameters (which do not).
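For intuition on why the wrapper type matters, a generic flax.nnx sketch (not deepmd-kit code; `ArrayAPIParam`/`ArrayAPIVariable` are array-API-aware analogues of these variable types): `nnx.grad` and the standard optimizers update `nnx.Param` state by default, while plain `nnx.Variable` state is carried along untouched.

```python
import jax.numpy as jnp
from flax import nnx

class TinyLayer(nnx.Module):
    def __init__(self, w, b, trainable: bool = True):
        # Param is picked up by nnx.grad/optimizers; Variable is inert state.
        wrap = nnx.Param if trainable else nnx.Variable
        self.w = wrap(jnp.asarray(w))
        self.b = wrap(jnp.asarray(b))

    def __call__(self, x):
        return x @ self.w.value + self.b.value
```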
deepmd/pt/model/network/mlp.py (2)

86-86: LGTM! Proper trainable parameter integration.
The `trainable` parameter is correctly added to the constructor with a sensible default value (True) and stored as an instance attribute.
Also applies to: 89-89

238-238: LGTM! Complete serialization/deserialization support.
The `trainable` attribute is properly included in both serialization and deserialization methods, ensuring the trainability state is preserved across model save/load cycles.
Also applies to: 265-265
deepmd/pd/model/descriptor/dpa1.py (1)
295-295: LGTM! Consistent trainable parameter propagation in DPA1.
The `trainable` parameter is properly passed to both `DescrptBlockSeAtten` and `TypeEmbedNet` instances, maintaining consistency with other descriptor implementations and enabling uniform trainability control across the DPA1 architecture.
Also applies to: 309-309
deepmd/pd/model/descriptor/dpa2.py (4)
92-92: LGTM: Well-designed parameter addition with appropriate default.
The `trainable` parameter addition with a default value of `True` maintains backward compatibility while enabling explicit control over parameter trainability.

187-187: LGTM: Comprehensive parameter propagation to all sub-components.
The `trainable` parameter is correctly passed to all relevant sub-components, including descriptor blocks, type embedding networks, and MLP layers, ensuring consistent trainability control throughout the model hierarchy.
Also applies to: 207-207, 248-248, 276-276, 302-302, 312-312

323-324: LGTM: Correct gradient control implementation for PaddlePaddle.
The gradient control logic correctly uses PaddlePaddle's `stop_gradient` attribute with inverse logic (`not trainable`) to control parameter trainability.

558-558: LGTM: Proper serialization inclusion.
The `trainable` parameter is correctly included in the serialization dictionary to maintain state consistency during model save/load operations.

deepmd/pt/model/descriptor/se_t_tebd.py (5)
138-138: LGTM: Consistent parameter addition across descriptor classes.
The `trainable` parameter is appropriately added to both the main descriptor class and the descriptor block class with proper default values.
Also applies to: 530-530

163-163: LGTM: Complete parameter propagation to embedding networks.
The `trainable` parameter is correctly passed to all embedding network instances, including both main and strip embedding networks when applicable.
Also applies to: 174-174, 583-583, 597-597

184-185: LGTM: Correct gradient control implementation for PyTorch.
The gradient control logic correctly uses PyTorch's `requires_grad` attribute to control parameter trainability. This is the appropriate approach for the PyTorch framework.
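To make the backend difference concrete, an illustrative sketch (not the PR's exact code) of the two idioms the reviewers point at — PyTorch flips `requires_grad`, while PaddlePaddle flips `stop_gradient` with inverted logic:

```python
import torch

def set_trainable_torch(module: torch.nn.Module, trainable: bool) -> None:
    # PyTorch: a parameter is trained iff requires_grad is True.
    for param in module.parameters():
        param.requires_grad = trainable

def set_trainable_paddle(layer, trainable: bool) -> None:
    # PaddlePaddle (paddle.nn.Layer): the flag is inverted;
    # stop_gradient=True freezes the parameter.
    for param in layer.parameters():
        param.stop_gradient = not trainable
```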
370-370: LGTM: Proper serialization inclusion.
The `trainable` parameter is correctly included in the serialization dictionary for state persistence.

deepmd/pd/model/descriptor/se_t_tebd.py (4)
138-138: LGTM: Consistent parameter addition matching PyTorch implementation.
The `trainable` parameter is appropriately added to both descriptor classes with the same signature as the PyTorch version, ensuring cross-framework consistency.
Also applies to: 534-534

163-163: LGTM: Complete parameter propagation to all embedding components.
The `trainable` parameter is correctly passed to all embedding network instances, maintaining consistency with the PyTorch implementation while using PaddlePaddle-specific components.
Also applies to: 177-177, 591-591, 605-605

184-185: LGTM: Correct gradient control implementation for PaddlePaddle.
The gradient control logic correctly uses PaddlePaddle's `stop_gradient` attribute with inverse logic (`not trainable`), which is the appropriate approach for this framework.

374-374: LGTM: Proper serialization inclusion maintaining consistency.
The `trainable` parameter is correctly included in the serialization dictionary, ensuring state consistency across framework implementations.

deepmd/dpmodel/descriptor/dpa2.py (1)
384-384: Excellent implementation of trainable parameter propagation.
The `trainable` parameter is correctly added to the constructor and systematically propagated to all key subcomponents (repinit blocks, repformers, type embedding, and transform layers), ensuring consistent control over parameter trainability throughout the DPA2 descriptor hierarchy.
Also applies to: 477-477, 497-497, 538-538, 568-568, 592-592, 602-602
deepmd/tf/fit/fitting.py (1)
138-138: LGTM: Proper trainable parameter integration in fitting serialization.
The `trainable` parameter is correctly added to the `serialize_network` method and properly propagated to the `FittingNet` constructor, enabling trainability control during network serialization.
Also applies to: 203-203
deepmd/tf/descriptor/se.py (1)
195-195: Well-implemented trainable parameter for SE descriptor serialization.
The `trainable` parameter is properly added with documentation and correctly propagated to all `EmbeddingNet` instances, including both excluded type networks and those initialized from variables, ensuring comprehensive trainability control.
Also applies to: 218-219, 242-242, 250-250, 286-286
deepmd/tf/descriptor/se_t.py (1)
729-729: Consistent trainable parameter implementation for SE-T descriptor.
The `trainable` parameter follows the same well-established pattern as other descriptor modules, with proper documentation and correct propagation to all `EmbeddingNet` constructors in both the helper function and main initialization logic.
Also applies to: 752-753, 777-777, 812-812
deepmd/pt/model/descriptor/dpa2.py (6)
96-96: LGTM! Trainable parameter properly added.
The `trainable` parameter is correctly added with proper typing and a backward-compatible default value.

191-191: LGTM! Consistent parameter propagation to repinit modules.
The `trainable` parameter is correctly propagated to both `DescrptBlockSeAtten` and `DescrptBlockSeTTebd` instances.
Also applies to: 211-211

252-252: LGTM! Parameter propagation to core modules.
The `trainable` parameter is properly passed to `DescrptBlockRepformers` and `TypeEmbedNet` instances, maintaining consistency.
Also applies to: 280-280

306-306: LGTM! MLPLayer instances receive trainable parameter.
The `trainable` parameter is correctly propagated to both the `g1_shape_tranform` and conditional `tebd_transform` MLPLayer instances.
Also applies to: 316-316

288-288: LGTM! Proper parameter storage and gradient control.
Storing the trainable parameter and explicitly setting `requires_grad` on all parameters ensures consistent behavior across the descriptor.
Also applies to: 327-328

556-556: LGTM! Trainable parameter properly serialized.
Including the `trainable` parameter in serialization ensures the setting is preserved during model save/load operations.

deepmd/dpmodel/descriptor/repflows.py (3)
170-171: LGTM! Trainable parameter properly added to DescrptBlockRepflows.
The parameter is correctly added with proper documentation, typing, and a backward-compatible default value.
Also applies to: 210-210

274-288: LGTM! Consistent parameter propagation in DescrptBlockRepflows.
The `trainable` parameter is properly propagated to all `NativeLayer` instances and `RepFlowLayer` instances in the constructor.
Also applies to: 319-319

876-876: LGTM! Comprehensive trainable parameter support in RepFlowLayer.
The `RepFlowLayer` class properly receives the trainable parameter and systematically propagates it to all internal `NativeLayer` instances and residual components.
Also applies to: 934-1097
deepmd/pd/model/descriptor/se_atten.py (6)
209-209: LGTM!
The `trainable` parameter is correctly passed to the `NeighborGatedAttention` constructor.

234-234: LGTM!
The `trainable` parameter is correctly passed to both `EmbeddingNet` constructors for main and strip mode filter layers.
Also applies to: 248-248

698-698: LGTM!
The `trainable` parameter is correctly passed to each `NeighborGatedAttentionLayer` in the attention layers loop.

834-834: LGTM!
The `trainable` parameter is correctly passed to the `GatedAttentionLayer` constructor.

948-948: LGTM!
The `trainable` parameter is correctly passed to both `in_proj` and `out_proj` `MLPLayer` constructors.
Also applies to: 959-959

84-84: Overall implementation is consistent and well-structured.
The `trainable` parameter has been correctly added to all relevant classes and properly propagated through the initialization chain. The implementation aligns well with the PR objectives. However, please address the serialization and documentation issues identified in the previous comments to ensure complete functionality.
Also applies to: 209-209, 234-234, 248-248, 662-662, 698-698, 806-806, 834-834, 915-915, 948-948, 959-959
deepmd/tf/descriptor/se_atten.py (5)
1596-1596: LGTM! Correctly propagates trainable flag to attention layer.
The addition of `trainable=self.trainable` ensures that the in_proj NativeLayer inherits the trainable property from the parent descriptor, maintaining consistency in the attention mechanism serialization.

1615-1615: LGTM! Maintains consistency in trainable flag propagation.
The out_proj NativeLayer now correctly inherits the trainable property, ensuring both in_proj and out_proj layers have consistent trainable settings.

1659-1659: LGTM! Well-designed parameter addition.
The new `trainable` parameter has a proper type annotation, a sensible default value, and follows consistent naming conventions. This enables control over trainability in network serialization.

1685-1686: LGTM! Clear and properly formatted parameter documentation.
The documentation for the `trainable` parameter follows the existing style and clearly describes its purpose, maintaining consistency with the rest of the method's docstring.

1727-1727: LGTM! Completes trainable flag propagation chain.
The `trainable` parameter is correctly passed from the method parameter to the EmbeddingNet constructor, ensuring that the network respects the specified trainable setting during serialization.

deepmd/pt/model/descriptor/se_atten.py (2)
103-103: Parameter addition looks good.
The `trainable` parameter is correctly added with an appropriate default value for backward compatibility.

228-228: Consistent parameter propagation to child components.
The `trainable` parameter is correctly passed to all child components (`NeighborGatedAttention`, `EmbeddingNet` instances).
Also applies to: 253-253, 267-267
deepmd/pd/model/descriptor/repformer_layer.py (3)
45-45: Correct implementation for residual tensor trainability.
The logic `residual.stop_gradient = not trainable` correctly controls whether the residual tensor participates in gradient computation.
Also applies to: 75-75

166-166: Parameter addition and propagation look good.
The `trainable` parameter is correctly added and passed to the `MLPLayer`.
Also applies to: 179-179

630-630: Extensive and consistent parameter propagation.
The `trainable` parameter is correctly propagated to all child components and residual tensors throughout the `RepformerLayer` class. The implementation properly handles all the various layer types and update styles.
Also applies to: 691-691, 701-701, 718-718, 728-728, 737-737, 747-747, 760-760, 769-769, 779-779, 789-789, 799-799, 811-811, 819-820, 836-836, 846-847, 856-856, 867-867, 877-877
deepmd/dpmodel/descriptor/dpa1.py (4)
322-323: LGTM: Proper trainable parameter propagation.
The trainable parameter is correctly passed to both `DescrptBlockSeAtten` and `TypeEmbedNet` components, and properly stored as an instance variable for later use in serialization.
Also applies to: 337-338, 341-341

696-696: LGTM: Consistent trainable parameter implementation.
The trainable parameter is properly added to the `DescrptBlockSeAtten` constructor and correctly propagated to all embedding networks and attention components.
Also applies to: 747-747, 763-763, 782-782

1195-1195: LGTM: Attention layer trainable parameter support.
The trainable parameter is correctly implemented in both `NeighborGatedAttention` and `NeighborGatedAttentionLayer` classes, maintaining consistency with the overall design pattern.
Also applies to: 1229-1229, 1325-1325, 1352-1352

1433-1433: LGTM: Complete trainable parameter chain.
The `GatedAttentionLayer` properly implements trainable parameter support and correctly passes it to the underlying `NativeLayer` components, completing the parameter propagation chain.
Also applies to: 1463-1463, 1472-1472
deepmd/pt/model/descriptor/repformer_layer.py (6)
44-44: LGTM: Correct trainable parameter implementation.
The `get_residual` function properly uses the trainable parameter to control the `requires_grad` attribute of the residual tensor, which is the correct approach for PyTorch parameter trainability.
Also applies to: 71-71

163-163: Good: Trainable parameter properly added to constructor.
The trainable parameter is correctly added to the `Atten2Map` constructor and properly passed to the `MLPLayer`.
Also applies to: 176-176

290-290: LGTM: Proper trainable parameter propagation.
The trainable parameter is correctly added to the constructor and properly passed to both `mapv` and `head_map` MLPLayer instances.
Also applies to: 301-301, 308-308

378-378: LGTM: Trainable parameter correctly implemented.
The trainable parameter is properly added to the constructor and correctly passed to the `head_map` MLPLayer.
Also applies to: 384-390

457-457: LGTM: Comprehensive trainable parameter implementation.
The trainable parameter is correctly added to the `LocalAtten` constructor and properly propagated to all three MLPLayer instances (`mapq`, `mapkv`, and `head_map`).
Also applies to: 469-469, 477-477, 484-484

620-620: Excellent: Comprehensive trainable parameter implementation.
The `RepformerLayer` class demonstrates excellent implementation of the trainable parameter, correctly adding it to the constructor and systematically propagating it to all sub-components including MLPLayers, residual parameters, and attention modules. This ensures complete control over parameter trainability throughout the layer hierarchy.
Also applies to: 681-681, 691-691, 708-708, 718-718, 727-727, 737-737, 750-750, 759-759, 769-769, 779-779, 789-789, 801-801, 809-809, 826-826, 836-836, 846-846, 857-857, 867-867
deepmd/dpmodel/utils/network.py (6)
90-91: LGTM: Clear documentation for the new trainable parameter.
The documentation properly describes the new parameter with appropriate formatting and default value indication.

104-107: Good practice: Setting trainable before array attributes.
The comment and early assignment of `self.trainable` is well-placed, ensuring the trainable state is established before any array operations that might depend on it.

168-168: Proper version compatibility handling.
The version compatibility check correctly allows both version 1 and 2, ensuring backward compatibility.

249-250: Complete the getter interface.
The `__getitem__` method now properly supports the `trainable` key, maintaining consistency with other layer attributes.

440-441: Consistent trainable parameter propagation.
The `trainable` parameter is properly threaded through the class hierarchy from `LayerNorm` to `EmbeddingNet` to `FittingNet`, ensuring all layers in the network respect the trainability setting.
Also applies to: 693-694, 709-710, 800-801, 809-810, 823-824
147-154: ✅ Verification successful

Verify version compatibility across the codebase.

The serialization version has been incremented from 1 to 2 to accommodate the new `trainable` field. This is a breaking change that requires careful coordination.

🏁 Script executed:

```shell
#!/bin/bash
# Description: Verify that all deserialization code can handle the version bump and new trainable field
# Expected: All deserialize methods should handle both version 1 and 2, with proper defaults for trainable

# Check for other classes that might serialize/deserialize NativeLayer
rg -A 10 -B 5 "deserialize.*NativeLayer|NativeLayer.*deserialize"

# Check for hardcoded version references that might need updating
rg -A 3 -B 3 "version.*[\"']1[\"']|[\"']1[\"'].*version" --type py
```

Length of output: 47637

NativeLayer Deserialization Handles Version 1 Default Trainable

The version bump to 2 is safe: `NativeLayer.deserialize` uses `data.pop("@version", 1)` and `check_version_compatibility`, so when deserializing a version-1 payload it defaults to version 1. Since `trainable` only exists in version 2, any code that does `trainable = data.pop("trainable", True)` (or equivalent default logic) will automatically behave as before for version 1. I've confirmed:

- `@version` is popped with default 1 in `deepmd/dpmodel/utils/network.py`
- `check_version_compatibility` allows versions 1–2
- No hard-coded "1" version checks elsewhere in the repo

No follow-up changes are needed.
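A simplified sketch of that backward-compatible gate (illustrative; the real `NativeLayer.deserialize` restores many more fields):

```python
from deepmd.utils.version import check_version_compatibility

class LayerLike:
    """Sketch of a version-tolerant deserializer for the trainable flag."""

    def __init__(self, trainable: bool = True, **kwargs):
        self.trainable = trainable

    @classmethod
    def deserialize(cls, data: dict) -> "LayerLike":
        data = dict(data)
        # Accept version-1 payloads (no "@version"/"trainable" keys) up to 2.
        check_version_compatibility(data.pop("@version", 1), 2, 1)
        data.pop("@class", None)
        return cls(trainable=data.pop("trainable", True), **data)
```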
deepmd/dpmodel/descriptor/repformers.py (8)
167-169: Comprehensive trainable parameter integration.
The `DescrptBlockRepformers` class properly adds the `trainable` parameter with appropriate documentation and default value.
Also applies to: 209-210

257-263: Consistent parameter passing to embedded layers.
The `g2_embd` layer properly receives the `trainable` parameter, maintaining consistency with the overall design.

300-301: Proper propagation to RepformerLayer instances.
The `trainable` parameter is correctly passed to each `RepformerLayer` instance in the loop, ensuring all layers in the stack respect the trainability setting.

858-859: Attention mechanism trainability support.
The `Atten2Map` class properly integrates the `trainable` parameter and passes it to its internal `mapqk` layer.
Also applies to: 871-872

983-984: Multi-head attention trainability.
The `Atten2MultiHeadApply` class correctly propagates the `trainable` parameter to both its `mapv` and `head_map` layers.
Also applies to: 994-995, 1001-1002

1074-1086: Clean multi-line parameter formatting.
The `Atten2EquiVarApply` class uses clean multi-line formatting for the `NativeLayer` constructor call while properly passing the `trainable` parameter.

1155-1156: Local attention trainability.
The `LocalAtten` class systematically passes the `trainable` parameter to all its internal layers (`mapq`, `mapkv`, `head_map`), ensuring complete coverage.
Also applies to: 1167-1168, 1175-1176, 1182-1183

1321-1322: Comprehensive RepformerLayer trainability integration.
The `RepformerLayer` class extensively integrates the `trainable` parameter:
- Passes it to all `NativeLayer` instances
- Includes it in `get_residual` calls for residual connections
- Propagates it to attention mechanism components (`Atten2Map`, `Atten2MultiHeadApply`, `Atten2EquiVarApply`, `LocalAtten`)
- Maintains consistency across all conditional layer instantiations

This ensures that the entire repformer architecture respects the trainability setting at every level.
Also applies to: 1381-1382, 1391-1392, 1408-1409, 1418-1419, 1427-1428, 1437-1438, 1450-1451, 1459-1460, 1469-1470, 1479-1480, 1489-1490, 1501-1502, 1509-1510, 1526-1527, 1536-1537, 1546-1547, 1557-1558, 1567-1568
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```diff
@@           Coverage Diff            @@
##           devel    #4793     +/-  ##
=======================================
- Coverage  84.80%   84.57%   -0.24%
=======================================
  Files        698      699       +1
  Lines      67798    68070     +272
  Branches    3542     3541       -1
=======================================
+ Hits       57494    57567      +73
- Misses      9171     9369     +198
- Partials    1133     1134       +1
```

☔ View full report in Codecov by Sentry.
Summary by CodeRabbit

New Features
- Added a `trainable` parameter to numerous model components, descriptors, embedding and fitting networks, and attention layers, allowing users to enable or disable parameter trainability.
- Propagated the `trainable` flag across submodules and layers in the PyTorch, TensorFlow, and JAX backends.
- Serialization and deserialization now preserve the `trainable` attribute, maintaining trainability settings during model save and load.

Tests
- Updated test configurations to set the `trainable` flag, validating behavior with non-trainable components.