@njzjz njzjz commented Jun 10, 2025

  1. For dpmodel, pt, and pd, pass the trainable parameter to the layer (not actually used in this PR).
  2. For JAX, support the trainable parameter in the layer.
  3. "trainable" is now serialized in dpmodel, tf, pt, pd, etc.
  4. Support trainable in TF dipole & polar fitting.

Summary by CodeRabbit

  • New Features

    • Added a trainable parameter to numerous model components, descriptors, embedding and fitting networks, and attention layers, allowing users to enable or disable parameter trainability.
    • Ensured consistent propagation of the trainable flag across submodules and layers in PyTorch, TensorFlow, and JAX backends.
    • Enhanced serialization and deserialization processes to include the trainable attribute, maintaining trainability settings during model save and load.
    • Implemented conditional parameter wrapping in the JAX backend to differentiate trainable parameters from fixed variables.
  • Tests

    • Updated test configurations to include the trainable flag, validating behavior with non-trainable components.

Copilot AI review requested due to automatic review settings June 10, 2025 11:45
1. For dpmodel, pt, and pd, pass the trainable parameter to the layer (not actually used in this PR).
2. For JAX, support the `trainable` parameter in the layer.

Signed-off-by: Jinzhe Zeng <[email protected]>
Copilot AI left a comment

Pull Request Overview

This PR propagates a new trainable flag throughout various descriptor and network components in the DeepMD library (Paddle, PyTorch, and JAX backends).

  • Add trainable parameter to many layer and block constructors in pd and dpmodel modules
  • Ensure MLPLayer/NativeLayer receive and forward the trainable flag
  • Update JAX backend to wrap parameters differently based on trainable
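The conditional wrapping described in the last bullet can be sketched in plain Python. This is only an illustration under stated assumptions: `Param` and `Variable` are hypothetical stand-ins for the backend's wrapper types, and `ToyLayer` and `trainable_names` are made-up names, not code from this PR.

```python
class Variable:
    """Hypothetical stand-in for a fixed (non-trainable) array wrapper."""

    def __init__(self, value):
        self.value = value


class Param(Variable):
    """Hypothetical stand-in for a trainable array wrapper."""


class ToyLayer:
    """Illustrative layer: wraps w/b/idt on assignment based on `trainable`."""

    _WRAPPED = {"w", "b", "idt"}

    def __init__(self, w, b, trainable=True):
        self.trainable = trainable  # set first so __setattr__ can read it
        self.w = w
        self.b = b
        self.idt = None  # None is stored as-is, without a wrapper

    def __setattr__(self, name, value):
        if name in self._WRAPPED and value is not None:
            wrapper = Param if self.trainable else Variable
            value = wrapper(value)
        super().__setattr__(name, value)


def trainable_names(layer):
    """Return the attribute names an optimizer would be allowed to update."""
    return sorted(k for k, v in vars(layer).items() if isinstance(v, Param))
```

With this arrangement, code that collects parameters by wrapper type sees `w` and `b` for a trainable layer and nothing for a frozen one, without the layer's forward logic having to know about the flag.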

Reviewed Changes

Copilot reviewed 48 out of 48 changed files in this pull request and generated no comments.

Show a summary per file
| File | Description |
|------|-------------|
| deepmd/pd/model/descriptor/repformers.py | Added trainable to block and layer constructors |
| deepmd/pd/model/descriptor/repformer_layer.py | Added trainable to attention and MLP layers |
| deepmd/pd/model/descriptor/repflows.py | Added trainable to block and layer constructors |
| deepmd/pd/model/descriptor/repflow_layer.py | Added trainable to repflow layer constructors |
| deepmd/pd/model/descriptor/dpa3.py | Added trainable to subclass initialization |
| deepmd/pd/model/descriptor/dpa2.py | Added trainable to subclass initialization |
| deepmd/pd/model/descriptor/dpa1.py | Added trainable to subclass initialization |
| deepmd/jax/utils/network.py | Support JAX trainable in __setattr__ |
| deepmd/dpmodel/utils/type_embed.py | Added trainable to embedding utility |
| deepmd/dpmodel/utils/network.py | Added trainable to native and fitting networks |
| deepmd/dpmodel/fitting/general_fitting.py | Propagated trainable to general fitting nets |
| deepmd/dpmodel/descriptor/se_t_tebd.py | Added trainable to SE-TEBD descriptor |
| deepmd/dpmodel/descriptor/se_t.py | Added trainable to SE-T descriptor |
| deepmd/dpmodel/descriptor/se_r.py | Added trainable to SE-R descriptor |
| deepmd/dpmodel/descriptor/se_e2_a.py | Added trainable to SE-E2A descriptor |
| deepmd/dpmodel/descriptor/repformers.py | Added trainable to Repformers block |
| deepmd/dpmodel/descriptor/repflows.py | Added trainable to Repflows block |
| deepmd/dpmodel/descriptor/dpa3.py | Added trainable to DPA3 subclass init |
| deepmd/dpmodel/descriptor/dpa2.py | Added trainable to DPA2 subclass init |
| deepmd/dpmodel/descriptor/dpa1.py | Added trainable to DPA1 subclass init |
Comments suppressed due to low confidence (2)

deepmd/pd/model/descriptor/repformers.py:90

  • The trainable parameter was added to the constructor signature but is not described in the class docstring. Please update the docstring to include trainable and its purpose.
        trainable: bool = True,

deepmd/jax/utils/network.py:48

  • The __setattr__ override references self.trainable before it may be initialized, which can lead to an AttributeError. Consider setting self.trainable in the object's __init__ before any attribute assignments occur.
                if self.trainable:
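
The ordering hazard raised in this suppressed comment can be reproduced with a small, self-contained sketch. The classes below are hypothetical and only mimic the pattern; whether the real layer is affected depends on the order of assignments in its base-class constructor, which is not shown in this thread.

```python
class BrokenLayer:
    """Illustrative only: assigns a weight before `trainable` exists."""

    def __init__(self, w, trainable=True):
        self.w = w  # triggers __setattr__, which reads self.trainable
        self.trainable = trainable

    def __setattr__(self, name, value):
        if name == "w":
            # Tag the value so the chosen branch is visible to the caller.
            value = ("param" if self.trainable else "fixed", value)
        super().__setattr__(name, value)


class FixedLayer(BrokenLayer):
    """Same hook, but the flag is assigned before any weight."""

    def __init__(self, w, trainable=True):
        self.trainable = trainable
        self.w = w


def try_build(cls):
    """Return the constructed layer, or the AttributeError message."""
    try:
        return cls([1.0], trainable=False)
    except AttributeError as exc:
        return str(exc)
```

Here `try_build(BrokenLayer)` returns an error message mentioning `trainable`, while `try_build(FixedLayer)` returns a layer whose weight was tagged as fixed.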

coderabbitai bot commented Jun 10, 2025

📝 Walkthrough
## Walkthrough

A new `trainable` boolean parameter was introduced and propagated across a wide range of descriptor, network, and fitting classes in the codebase. This parameter, defaulting to `True`, enables explicit control over whether the parameters of neural network layers and submodules are trainable. The parameter is threaded through constructors, serialization, and deserialization methods, and is now handled consistently in the descriptor, embedding, attention, repformer, repflow, and fitting modules across the touched backends (TensorFlow, PyTorch, Paddle, JAX, and dpmodel). Corresponding test configurations were updated to set `trainable` to `False`.
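
A minimal sketch of this threading pattern, using hypothetical `Descriptor` and `Layer` classes rather than the real ones: the flag is accepted by the outer constructor, forwarded to each sub-layer, and written into the serialized dictionary at both levels.

```python
class Layer:
    """Hypothetical leaf layer that stores and serializes the flag."""

    def __init__(self, width, trainable=True):
        self.width = width
        self.trainable = trainable

    def serialize(self):
        return {"width": self.width, "trainable": self.trainable}

    @classmethod
    def deserialize(cls, data):
        return cls(**data)


class Descriptor:
    """Hypothetical container that forwards the flag to every layer."""

    def __init__(self, widths, trainable=True):
        self.trainable = trainable
        self.layers = [Layer(w, trainable=trainable) for w in widths]

    def serialize(self):
        return {
            "trainable": self.trainable,
            "layers": [layer.serialize() for layer in self.layers],
        }

    @classmethod
    def deserialize(cls, data):
        # Build an empty container, then restore the saved layers.
        obj = cls([], trainable=data["trainable"])
        obj.layers = [Layer.deserialize(d) for d in data["layers"]]
        return obj
```

Constructing `Descriptor([4, 8], trainable=False)` and passing it through `serialize()` and `deserialize()` yields an object whose layers are all still marked non-trainable.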

## Changes

| File(s)                                                                                 | Change Summary                                                                                   |
|-----------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------|
| deepmd/dpmodel/descriptor/dpa1.py<br>deepmd/pd/model/descriptor/dpa1.py<br>deepmd/pt/model/descriptor/dpa1.py | Added `trainable` parameter to descriptor and submodule constructors; propagated to components.  |
| deepmd/dpmodel/descriptor/dpa2.py<br>deepmd/pd/model/descriptor/dpa2.py<br>deepmd/pt/model/descriptor/dpa2.py | Added and propagated `trainable` to all submodules and internal layers in DPA2 descriptors.      |
| deepmd/dpmodel/descriptor/dpa3.py<br>deepmd/pd/model/descriptor/dpa3.py<br>deepmd/pt/model/descriptor/dpa3.py | Propagated `trainable` parameter to repflows and type embedding submodules.                      |
| deepmd/dpmodel/descriptor/repflows.py<br>deepmd/pd/model/descriptor/repflows.py<br>deepmd/pt/model/descriptor/repflows.py | Added `trainable` to repflows block and layers; propagated to all internal layers.               |
| deepmd/dpmodel/descriptor/repformers.py<br>deepmd/pd/model/descriptor/repformers.py<br>deepmd/pt/model/descriptor/repformers.py | Added `trainable` to repformer blocks, layers, and attention submodules; propagated accordingly. |
| deepmd/dpmodel/descriptor/se_e2_a.py<br>deepmd/pt/model/descriptor/se_a.py              | Propagated `trainable` to embedding network initialization in descriptor blocks.                 |
| deepmd/dpmodel/descriptor/se_r.py<br>deepmd/pt/model/descriptor/se_r.py                 | Propagated `trainable` to embedding networks for each atom type in descriptor blocks.            |
| deepmd/dpmodel/descriptor/se_t.py<br>deepmd/pt/model/descriptor/se_t.py                 | Propagated `trainable` to embedding networks in descriptor blocks for each embedding index.      |
| deepmd/dpmodel/descriptor/se_t_tebd.py<br>deepmd/pd/model/descriptor/se_t_tebd.py<br>deepmd/pt/model/descriptor/se_t_tebd.py | Added and propagated `trainable` to SeTTebd descriptor and block classes and their embeddings.   |
| deepmd/dpmodel/fitting/general_fitting.py<br>deepmd/pt/model/task/fitting.py            | Propagated `trainable` to fitting network instances in general fitting classes.                  |
| deepmd/dpmodel/utils/network.py                                                         | Added `trainable` to NativeLayer and related classes; handled in serialization/deserialization.  |
| deepmd/dpmodel/utils/type_embed.py<br>deepmd/pd/model/network/network.py<br>deepmd/pt/model/network/network.py | Propagated `trainable` to type embedding network constructors and internal networks.             |
| deepmd/jax/utils/network.py                                                             | Conditional parameter wrapping in NativeLayer based on `trainable` attribute.                    |
| deepmd/pd/model/descriptor/repflow_layer.py<br>deepmd/pt/model/descriptor/repflow_layer.py | Added `trainable` to RepFlowLayer; propagated to internal MLP and residuals.                     |
| deepmd/pd/model/descriptor/repformer_layer.py<br>deepmd/pt/model/descriptor/repformer_layer.py | Added `trainable` to attention and repformer layer classes and residuals; propagated internally. |
| deepmd/pd/model/descriptor/se_a.py                                                      | Propagated `trainable` to EmbeddingNet in filter layers.                                         |
| deepmd/pd/model/descriptor/se_atten.py<br>deepmd/pt/model/descriptor/se_atten.py        | Added `trainable` to attention block, attention layers, and embedding networks; propagated.      |
| deepmd/pd/model/descriptor/se_t.py                                                      | Propagated `trainable` to EmbeddingNet in filter layers.                                         |
| deepmd/pd/model/network/mlp.py<br>deepmd/pt/model/network/mlp.py                        | Added `trainable` to MLPLayer constructor; handled in serialization/deserialization.             |
| deepmd/tf/descriptor/se.py<br>deepmd/tf/descriptor/se_t.py                              | Added `trainable` parameter to `serialize_network` methods; propagated to EmbeddingNet.          |
| deepmd/tf/descriptor/se_atten.py                                                        | Added `trainable` to serialization of attention layers and strip networks; updated docstrings.    |
| deepmd/tf/fit/fitting.py                                                                | Added `trainable` to fitting network serialization; propagated to FittingNet.                    |
| source/tests/consistent/descriptor/test_dpa1.py                                         | Set `"trainable": False` in test configuration dictionary.                                       |
| source/tests/consistent/descriptor/test_dpa2.py<br>source/tests/consistent/descriptor/test_dpa3.py | Changed `"trainable"` in test configuration from `True` to `False`.                              |

## Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant Descriptor
    participant Submodule
    participant Layer

    User->>Descriptor: DescriptorClass(trainable=...)
    Descriptor->>Submodule: Submodule(trainable=...)
    Submodule->>Layer: Layer(trainable=...)
    Note right of Layer: Layer parameters set as trainable or not
```

```mermaid
sequenceDiagram
    participant Descriptor
    participant Serializer

    Descriptor->>Serializer: serialize(trainable=...)
    Serializer->>Descriptor: EmbeddingNet(trainable=...)
    Note right of Serializer: Serialization includes trainable flag
```

```mermaid
sequenceDiagram
    participant NativeLayer
    participant JAXWrapper

    NativeLayer->>JAXWrapper: __setattr__('w'/'b'/'idt', value)
    alt trainable == True
        JAXWrapper->>NativeLayer: wrap with ArrayAPIParam
    else trainable == False
        JAXWrapper->>NativeLayer: wrap with ArrayAPIVariable
    end
```

Possibly related PRs

  • pd: support dpa2 #4418: Adds the DescrptDPA2 class and related descriptor modules for Paddle backend support. Related at the domain level but does not overlap in code-level changes with this PR.

Suggested labels

OP, C++, LAMMPS, Docs

Suggested reviewers

  • wanghan-iapcm
  • iProzd

</details>

<!-- walkthrough_end -->


---

<details>
<summary>📜 Recent review details</summary>

**Configuration used: CodeRabbit UI**
**Review profile: CHILL**
**Plan: Pro**


<details>
<summary>📥 Commits</summary>

Reviewing files that changed from the base of the PR and between d5840e66e814e49ba64468568ca8696776cbddae and 94ce3467a93a136fcafecbbf0f9418b50e37dcca.

</details>

<details>
<summary>📒 Files selected for processing (1)</summary>

* `deepmd/dpmodel/utils/network.py` (10 hunks)

</details>

<details>
<summary>🚧 Files skipped from review as they are similar to previous changes (1)</summary>

* deepmd/dpmodel/utils/network.py

</details>

<details>
<summary>⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (29)</summary>

* GitHub Check: Test Python (6, 3.9)
* GitHub Check: Test Python (5, 3.12)
* GitHub Check: Test Python (6, 3.12)
* GitHub Check: Test Python (3, 3.9)
* GitHub Check: Test Python (5, 3.9)
* GitHub Check: Test Python (4, 3.12)
* GitHub Check: Test Python (3, 3.12)
* GitHub Check: Test Python (2, 3.12)
* GitHub Check: Test Python (1, 3.12)
* GitHub Check: Test Python (4, 3.9)
* GitHub Check: Test Python (2, 3.9)
* GitHub Check: Test Python (1, 3.9)
* GitHub Check: Analyze (python)
* GitHub Check: Analyze (c-cpp)
* GitHub Check: Build C++ (clang, clang)
* GitHub Check: Build wheels for cp310-manylinux_aarch64
* GitHub Check: Build C++ (cuda120, cuda)
* GitHub Check: Build C++ (cuda, cuda)
* GitHub Check: Build wheels for cp311-macosx_arm64
* GitHub Check: Build C++ (rocm, rocm)
* GitHub Check: Build wheels for cp311-macosx_x86_64
* GitHub Check: Build C++ (cpu, cpu)
* GitHub Check: Build wheels for cp311-manylinux_x86_64
* GitHub Check: Build wheels for cp311-win_amd64
* GitHub Check: Build wheels for cp311-manylinux_x86_64
* GitHub Check: Test C++ (true)
* GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
* GitHub Check: Test C++ (false)
* GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)

</details>

</details>

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

🔭 Outside diff range comments (14)
deepmd/dpmodel/descriptor/repflows.py (1)

1835-1901: 🛠️ Refactor suggestion

Missing trainable parameter in serialization.

The serialize() method doesn't include the trainable parameter, which means this setting won't be preserved during model save/load operations. This is inconsistent with the pattern in dpa2.py where the trainable parameter is properly serialized.

Add the trainable parameter to both DescrptBlockRepflows and RepFlowLayer serialization:

# In DescrptBlockRepflows.serialize() method (around line 732):
return {
    "e_rcut": self.e_rcut,
    "e_rcut_smth": self.e_rcut_smth,
+   "trainable": self.trainable,
    # ... other parameters
}

# In RepFlowLayer.serialize() method:
data = {
    "@class": "RepFlowLayer",
    "@version": 2,
    "e_rcut": self.e_rcut,
+   "trainable": self.trainable,
    # ... other parameters
}

Also ensure the deserialize() methods handle the trainable parameter appropriately.
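
One way the `deserialize()` side could handle the flag is sketched below. `ToyBlock` is a hypothetical class, and falling back to `True` when the key is absent is an assumption (it matches the constructor default described in the walkthrough, so payloads written before the field existed would still load).

```python
import copy


class ToyBlock:
    """Illustrative class; only the handling of `trainable` matters here."""

    def __init__(self, e_rcut, trainable=True):
        self.e_rcut = e_rcut
        self.trainable = trainable

    def serialize(self):
        return {
            "@class": "ToyBlock",
            "@version": 1,
            "e_rcut": self.e_rcut,
            "trainable": self.trainable,
        }

    @classmethod
    def deserialize(cls, data):
        data = copy.deepcopy(data)  # do not mutate the caller's dict
        data.pop("@class", None)
        data.pop("@version", None)
        # Assumed default: payloads without the key are treated as trainable.
        trainable = data.pop("trainable", True)
        return cls(trainable=trainable, **data)
```

A payload without the key loads as trainable, and a payload saved with `trainable=False` keeps that setting after loading.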

deepmd/pt/model/descriptor/se_atten.py (3)

771-795: ⚠️ Potential issue

Critical: Missing trainable parameter in serialization.

The serialize() method doesn't include the trainable parameter in the returned dictionary. This will cause the parameter to be lost during save/load cycles, potentially leading to models being loaded with incorrect trainability settings.

Apply this fix to include the trainable parameter in serialization:

         return {
             "@class": "NeighborGatedAttention",
             "@version": 1,
             "layer_num": self.layer_num,
             "nnei": self.nnei,
             "embed_dim": self.embed_dim,
             "hidden_dim": self.hidden_dim,
             "dotr": self.dotr,
             "do_mask": self.do_mask,
             "scaling_factor": self.scaling_factor,
             "normalize": self.normalize,
             "temperature": self.temperature,
             "trainable_ln": self.trainable_ln,
             "ln_eps": self.ln_eps,
             "precision": self.precision,
+            "trainable": self.trainable,
             "attention_layers": [layer.serialize() for layer in self.attention_layers],
         }

Also add self.trainable = trainable in the constructor to store the parameter.


883-922: ⚠️ Potential issue

Critical: Missing trainable parameter in serialization.

As with NeighborGatedAttention, the NeighborGatedAttentionLayer.serialize() method doesn't include the trainable parameter.

Apply this fix:

         return {
             "nnei": self.nnei,
             "embed_dim": self.embed_dim,
             "hidden_dim": self.hidden_dim,
             "dotr": self.dotr,
             "do_mask": self.do_mask,
             "scaling_factor": self.scaling_factor,
             "normalize": self.normalize,
             "temperature": self.temperature,
             "trainable_ln": self.trainable_ln,
             "ln_eps": self.ln_eps,
             "precision": self.precision,
+            "trainable": self.trainable,
             "attention_layer": self.attention_layer.serialize(),
             "attn_layer_norm": self.attn_layer_norm.serialize(),
         }

Also add self.trainable = trainable in the constructor.


1073-1113: ⚠️ Potential issue

Critical: Missing trainable parameter in serialization.

The GatedAttentionLayer.serialize() method also lacks the trainable parameter.

Apply this fix:

         return {
             "nnei": self.nnei,
             "embed_dim": self.embed_dim,
             "hidden_dim": self.hidden_dim,
             "num_heads": self.num_heads,
             "dotr": self.dotr,
             "do_mask": self.do_mask,
             "scaling_factor": self.scaling_factor,
             "normalize": self.normalize,
             "temperature": self.temperature,
             "bias": self.bias,
             "smooth": self.smooth,
             "precision": self.precision,
+            "trainable": self.trainable,
             "in_proj": self.in_proj.serialize(),
             "out_proj": self.out_proj.serialize(),
         }

Also add self.trainable = trainable in the constructor.

deepmd/pd/model/descriptor/repformer_layer.py (5)

247-283: ⚠️ Potential issue

Critical: Missing trainable parameter in serialization.

The Atten2Map.serialize() method doesn't include the trainable parameter, causing state loss during save/load cycles.

Apply this fix:

         return {
             "@class": "Atten2Map",
             "@version": 1,
             "input_dim": self.input_dim,
             "hidden_dim": self.hidden_dim,
             "head_num": self.head_num,
             "has_gate": self.has_gate,
             "smooth": self.smooth,
             "attnw_shift": self.attnw_shift,
             "precision": self.precision,
+            "trainable": self.trainable,
             "mapqk": self.mapqk.serialize(),
         }

Also add self.trainable = trainable in the constructor.


338-373: ⚠️ Potential issue

Critical: Missing trainable parameter in serialization.

The Atten2MultiHeadApply.serialize() method lacks the trainable parameter.

Apply this fix:

         return {
             "@class": "Atten2MultiHeadApply",
             "@version": 1,
             "input_dim": self.input_dim,
             "head_num": self.head_num,
             "precision": self.precision,
+            "trainable": self.trainable,
             "mapv": self.mapv.serialize(),
             "head_map": self.head_map.serialize(),
         }

Also add self.trainable = trainable in the constructor.


417-449: ⚠️ Potential issue

Critical: Missing trainable parameter in serialization.

The Atten2EquiVarApply.serialize() method also lacks the trainable parameter.

Apply this fix:

         return {
             "@class": "Atten2EquiVarApply",
             "@version": 1,
             "input_dim": self.input_dim,
             "head_num": self.head_num,
             "precision": self.precision,
+            "trainable": self.trainable,
             "head_map": self.head_map.serialize(),
         }

Also add self.trainable = trainable in the constructor.


551-592: ⚠️ Potential issue

Critical: Missing trainable parameter in serialization.

The LocalAtten.serialize() method doesn't include the trainable parameter.

Apply this fix:

         return {
             "@class": "LocalAtten",
             "@version": 1,
             "input_dim": self.input_dim,
             "hidden_dim": self.hidden_dim,
             "head_num": self.head_num,
             "smooth": self.smooth,
             "attnw_shift": self.attnw_shift,
             "precision": self.precision,
+            "trainable": self.trainable,
             "mapq": self.mapq.serialize(),
             "mapkv": self.mapkv.serialize(),
             "head_map": self.head_map.serialize(),
         }

Also add self.trainable = trainable in the constructor.


1348-1450: ⚠️ Potential issue

Critical: Missing trainable parameter in serialization.

The RepformerLayer.serialize() method doesn't include the trainable parameter, which is particularly critical given the complexity of this class.

Apply this fix by adding the trainable parameter to the data dictionary:

         data = {
             "@class": "RepformerLayer",
             "@version": 2,
             "rcut": self.rcut,
             "rcut_smth": self.rcut_smth,
             "sel": self.sel,
             "ntypes": self.ntypes,
             "g1_dim": self.g1_dim,
             "g2_dim": self.g2_dim,
             "axis_neuron": self.axis_neuron,
             "update_chnnl_2": self.update_chnnl_2,
             "update_g1_has_conv": self.update_g1_has_conv,
             "update_g1_has_drrd": self.update_g1_has_drrd,
             "update_g1_has_grrg": self.update_g1_has_grrg,
             "update_g1_has_attn": self.update_g1_has_attn,
             "update_g2_has_g1g1": self.update_g2_has_g1g1,
             "update_g2_has_attn": self.update_g2_has_attn,
             "update_h2": self.update_h2,
             "attn1_hidden": self.attn1_hidden,
             "attn1_nhead": self.attn1_nhead,
             "attn2_hidden": self.attn2_hidden,
             "attn2_nhead": self.attn2_nhead,
             "attn2_has_gate": self.attn2_has_gate,
             "activation_function": self.activation_function,
             "update_style": self.update_style,
             "smooth": self.smooth,
             "precision": self.precision,
             "trainable_ln": self.trainable_ln,
             "use_sqrt_nnei": self.use_sqrt_nnei,
             "g1_out_conv": self.g1_out_conv,
             "g1_out_mlp": self.g1_out_mlp,
             "ln_eps": self.ln_eps,
+            "trainable": self.trainable,
             "linear1": self.linear1.serialize(),
         }

Also add self.trainable = trainable in the constructor.

deepmd/pt/model/descriptor/repformer_layer.py (5)

244-263: 🛠️ Refactor suggestion

Missing trainable parameter in serialization

The serialize() method should include the trainable parameter to ensure it's preserved during model save/load operations.

        return {
            "@class": "Atten2Map",
            "@version": 1,
            "input_dim": self.input_dim,
            "hidden_dim": self.hidden_dim,
            "head_num": self.head_num,
            "has_gate": self.has_gate,
            "smooth": self.smooth,
            "attnw_shift": self.attnw_shift,
            "precision": self.precision,
+           "trainable": self.trainable,
            "mapqk": self.mapqk.serialize(),
        }

Note: You'll also need to store the trainable parameter as an instance variable in the constructor.


333-349: 🛠️ Refactor suggestion

Missing trainable parameter in serialization

The serialize() method lacks the trainable parameter, which should be included for proper model persistence.

        return {
            "@class": "Atten2MultiHeadApply",
            "@version": 1,
            "input_dim": self.input_dim,
            "head_num": self.head_num,
            "precision": self.precision,
+           "trainable": self.trainable,
            "mapv": self.mapv.serialize(),
            "head_map": self.head_map.serialize(),
        }

412-427: 🛠️ Refactor suggestion

Missing trainable parameter in serialization

The serialize() method should include the trainable parameter to maintain consistency with other serialization implementations.

        return {
            "@class": "Atten2EquiVarApply",
            "@version": 1,
            "input_dim": self.input_dim,
            "head_num": self.head_num,
            "precision": self.precision,
+           "trainable": self.trainable,
            "head_map": self.head_map.serialize(),
        }

541-561: 🛠️ Refactor suggestion

Missing trainable parameter in serialization

The serialize() method should include the trainable parameter for consistency and proper model persistence.

        return {
            "@class": "LocalAtten",
            "@version": 1,
            "input_dim": self.input_dim,
            "hidden_dim": self.hidden_dim,
            "head_num": self.head_num,
            "smooth": self.smooth,
            "attnw_shift": self.attnw_shift,
            "precision": self.precision,
+           "trainable": self.trainable,
            "mapq": self.mapq.serialize(),
            "mapkv": self.mapkv.serialize(),
            "head_map": self.head_map.serialize(),
        }

1328-1430: 🛠️ Refactor suggestion

Critical: Missing trainable parameter in serialization

The serialize() method is missing the trainable parameter, which is essential for proper model persistence. This is particularly important for the main RepformerLayer class as it coordinates all sub-components.

        data = {
            "@class": "RepformerLayer",
            "@version": 2,
            "rcut": self.rcut,
            "rcut_smth": self.rcut_smth,
            "sel": self.sel,
            "ntypes": self.ntypes,
            "g1_dim": self.g1_dim,
            "g2_dim": self.g2_dim,
            "axis_neuron": self.axis_neuron,
            "update_chnnl_2": self.update_chnnl_2,
            "update_g1_has_conv": self.update_g1_has_conv,
            "update_g1_has_drrd": self.update_g1_has_drrd,
            "update_g1_has_grrg": self.update_g1_has_grrg,
            "update_g1_has_attn": self.update_g1_has_attn,
            "update_g2_has_g1g1": self.update_g2_has_g1g1,
            "update_g2_has_attn": self.update_g2_has_attn,
            "update_h2": self.update_h2,
            "attn1_hidden": self.attn1_hidden,
            "attn1_nhead": self.attn1_nhead,
            "attn2_hidden": self.attn2_hidden,
            "attn2_nhead": self.attn2_nhead,
            "attn2_has_gate": self.attn2_has_gate,
            "activation_function": self.activation_function,
            "update_style": self.update_style,
            "smooth": self.smooth,
            "precision": self.precision,
            "trainable_ln": self.trainable_ln,
            "use_sqrt_nnei": self.use_sqrt_nnei,
            "g1_out_conv": self.g1_out_conv,
            "g1_out_mlp": self.g1_out_mlp,
            "ln_eps": self.ln_eps,
+           "trainable": self.trainable,
            "linear1": self.linear1.serialize(),
        }

Also ensure self.trainable = trainable is added to the constructor to store the parameter as an instance variable.
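
A minimal sketch of the pattern requested above, using a hypothetical `ToyLayer` rather than the real `RepformerLayer` (the class name and its fields are assumptions for illustration only). The constructor stores `trainable`, `serialize()` emits it, and `deserialize()` passes it back to the constructor, so the flag survives a round trip:

```python
from __future__ import annotations


class ToyLayer:
    """Hypothetical stand-in for a layer; not the deepmd-kit RepformerLayer."""

    def __init__(
        self, g1_dim: int, precision: str = "float64", trainable: bool = True
    ) -> None:
        self.g1_dim = g1_dim
        self.precision = precision
        # Store the flag as an instance attribute so serialize() can read it.
        self.trainable = trainable

    def serialize(self) -> dict:
        return {
            "@class": "ToyLayer",
            "@version": 1,
            "g1_dim": self.g1_dim,
            "precision": self.precision,
            "trainable": self.trainable,
        }

    @classmethod
    def deserialize(cls, data: dict) -> "ToyLayer":
        data = dict(data)  # copy so the caller's dict is not mutated
        data.pop("@class")
        data.pop("@version")
        return cls(**data)


frozen = ToyLayer(g1_dim=8, trainable=False)
restored = ToyLayer.deserialize(frozen.serialize())
```

If `"trainable"` were omitted from the dictionary, `restored` would silently fall back to the constructor default of `True`, which is the persistence gap the comment describes.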

🧹 Nitpick comments (4)
deepmd/pt/model/descriptor/repflow_layer.py (1)

67-67: Add documentation for the trainable parameter.

The new trainable parameter should be documented in the class docstring to help users understand its purpose and usage.

deepmd/pd/model/descriptor/repflow_layer.py (1)

64-64: Add documentation for the trainable parameter.

The new trainable parameter should be documented in the class docstring to help users understand its purpose and usage, consistent with the PyTorch implementation.

deepmd/pd/model/descriptor/repflows.py (1)

170-170: Add documentation for the trainable parameter.

The trainable parameter should be documented in the comprehensive class docstring (lines 51-134) to maintain the high documentation standards of this class.

deepmd/pd/model/descriptor/se_atten.py (1)

84-84: Document the new trainable parameter in the class docstring.

The trainable parameter should be documented in the class docstring along with the other parameters for consistency and clarity.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c46dc7d and 7d7e043.

📒 Files selected for processing (47)
  • deepmd/dpmodel/descriptor/dpa1.py (13 hunks)
  • deepmd/dpmodel/descriptor/dpa2.py (6 hunks)
  • deepmd/dpmodel/descriptor/dpa3.py (2 hunks)
  • deepmd/dpmodel/descriptor/repflows.py (18 hunks)
  • deepmd/dpmodel/descriptor/repformers.py (28 hunks)
  • deepmd/dpmodel/descriptor/se_e2_a.py (1 hunks)
  • deepmd/dpmodel/descriptor/se_r.py (1 hunks)
  • deepmd/dpmodel/descriptor/se_t.py (1 hunks)
  • deepmd/dpmodel/descriptor/se_t_tebd.py (5 hunks)
  • deepmd/dpmodel/fitting/general_fitting.py (1 hunks)
  • deepmd/dpmodel/utils/network.py (11 hunks)
  • deepmd/dpmodel/utils/type_embed.py (1 hunks)
  • deepmd/jax/utils/network.py (2 hunks)
  • deepmd/pd/model/descriptor/dpa1.py (2 hunks)
  • deepmd/pd/model/descriptor/dpa2.py (6 hunks)
  • deepmd/pd/model/descriptor/dpa3.py (2 hunks)
  • deepmd/pd/model/descriptor/repflow_layer.py (14 hunks)
  • deepmd/pd/model/descriptor/repflows.py (3 hunks)
  • deepmd/pd/model/descriptor/repformer_layer.py (24 hunks)
  • deepmd/pd/model/descriptor/repformers.py (3 hunks)
  • deepmd/pd/model/descriptor/se_a.py (1 hunks)
  • deepmd/pd/model/descriptor/se_atten.py (11 hunks)
  • deepmd/pd/model/descriptor/se_t_tebd.py (5 hunks)
  • deepmd/pd/model/network/mlp.py (2 hunks)
  • deepmd/pd/model/network/network.py (2 hunks)
  • deepmd/pt/model/descriptor/dpa1.py (2 hunks)
  • deepmd/pt/model/descriptor/dpa2.py (6 hunks)
  • deepmd/pt/model/descriptor/dpa3.py (2 hunks)
  • deepmd/pt/model/descriptor/repflow_layer.py (14 hunks)
  • deepmd/pt/model/descriptor/repflows.py (3 hunks)
  • deepmd/pt/model/descriptor/repformer_layer.py (24 hunks)
  • deepmd/pt/model/descriptor/repformers.py (4 hunks)
  • deepmd/pt/model/descriptor/se_a.py (1 hunks)
  • deepmd/pt/model/descriptor/se_atten.py (11 hunks)
  • deepmd/pt/model/descriptor/se_r.py (1 hunks)
  • deepmd/pt/model/descriptor/se_t.py (1 hunks)
  • deepmd/pt/model/descriptor/se_t_tebd.py (5 hunks)
  • deepmd/pt/model/network/mlp.py (3 hunks)
  • deepmd/pt/model/network/network.py (2 hunks)
  • deepmd/pt/model/task/fitting.py (1 hunks)
  • deepmd/tf/descriptor/se.py (4 hunks)
  • deepmd/tf/descriptor/se_atten.py (5 hunks)
  • deepmd/tf/descriptor/se_t.py (4 hunks)
  • deepmd/tf/fit/fitting.py (2 hunks)
  • source/tests/consistent/descriptor/test_dpa1.py (1 hunks)
  • source/tests/consistent/descriptor/test_dpa2.py (1 hunks)
  • source/tests/consistent/descriptor/test_dpa3.py (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (7)
deepmd/pd/model/descriptor/repflow_layer.py (1)
deepmd/pd/model/network/mlp.py (1)
  • MLPLayer (74-298)
deepmd/jax/utils/network.py (1)
deepmd/jax/common.py (1)
  • ArrayAPIVariable (86-97)
deepmd/pt/model/descriptor/repflow_layer.py (1)
deepmd/pd/model/network/mlp.py (1)
  • MLPLayer (74-298)
deepmd/pt/model/descriptor/repformers.py (1)
deepmd/dpmodel/utils/seed.py (3)
  • child_seed (10-10)
  • child_seed (14-14)
  • child_seed (17-40)
deepmd/dpmodel/utils/network.py (4)
deepmd/utils/version.py (1)
  • check_version_compatibility (2-27)
source/tests/consistent/descriptor/test_dpa2.py (1)
  • data (92-189)
source/tests/consistent/descriptor/test_dpa1.py (1)
  • data (79-131)
source/tests/consistent/descriptor/test_dpa3.py (1)
  • data (78-134)
deepmd/pt/model/descriptor/repflows.py (1)
deepmd/pt/model/network/mlp.py (1)
  • MLPLayer (72-279)
deepmd/pt/model/descriptor/repformer_layer.py (3)
deepmd/pd/model/network/mlp.py (1)
  • MLPLayer (74-298)
deepmd/dpmodel/utils/seed.py (3)
  • child_seed (10-10)
  • child_seed (14-14)
  • child_seed (17-40)
deepmd/pd/model/descriptor/repformer_layer.py (2)
  • Atten2MultiHeadApply (286-373)
  • Atten2EquiVarApply (376-449)
⏰ Context from checks skipped due to timeout of 90000ms (28)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Test Python (3, 3.12)
  • GitHub Check: Test Python (6, 3.12)
  • GitHub Check: Test Python (3, 3.9)
  • GitHub Check: Test Python (5, 3.12)
  • GitHub Check: Test Python (5, 3.9)
  • GitHub Check: Test Python (6, 3.9)
  • GitHub Check: Test Python (4, 3.12)
  • GitHub Check: Test Python (1, 3.9)
  • GitHub Check: Test Python (4, 3.9)
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Test Python (2, 3.9)
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Test C++ (false)
  • GitHub Check: Test C++ (true)
🔇 Additional comments (109)
source/tests/consistent/descriptor/test_dpa1.py (1)

130-130: LGTM! Test configuration updated appropriately.

Adding "trainable": False to the test configuration provides valuable test coverage for the non-trainable parameter setting, ensuring the DPA1 descriptor works correctly in both trainable and non-trainable modes.

source/tests/consistent/descriptor/test_dpa2.py (1)

184-184: LGTM! Consistent test configuration update.

The addition of "trainable": False maintains consistency with other descriptor tests and ensures proper validation of the non-trainable configuration for DPA2 descriptors.

source/tests/consistent/descriptor/test_dpa3.py (1)

133-133: LGTM! Completes consistent test pattern.

Adding "trainable": False to the DPA3 test configuration completes the consistent pattern across all DPA descriptor tests, ensuring comprehensive validation of the trainable parameter feature.

deepmd/pd/model/network/network.py (1)

48-48: LGTM! Well-implemented trainable parameter addition.

The trainable parameter is properly added with appropriate default value (True) for backward compatibility and correctly propagated to the internal TypeEmbedNetConsistent instance, enabling trainability control at the embedding network level.

Also applies to: 69-69

deepmd/pt/model/descriptor/se_a.py (1)

528-528: LGTM! Correct trainable parameter propagation.

The trainable parameter is properly passed to the EmbeddingNet constructor, ensuring that the trainability setting propagates correctly from the descriptor block to its internal embedding networks.

deepmd/pt/model/descriptor/dpa3.py (1)

172-172: LGTM: Proper trainable parameter propagation.

The trainable parameter is correctly propagated to both DescrptBlockRepflows and TypeEmbedNet submodules, enabling consistent trainability control throughout the descriptor hierarchy.

Also applies to: 188-188

deepmd/dpmodel/descriptor/se_e2_a.py (1)

210-210: LGTM: Consistent trainable parameter propagation to embedding networks.

The trainable parameter is properly passed to each EmbeddingNet instance within the embeddings collection, ensuring uniform trainability control across all embedding networks.

deepmd/dpmodel/utils/type_embed.py (1)

96-96: LGTM: Trainable parameter correctly forwarded to embedding network.

The trainable parameter is appropriately passed to the internal EmbeddingNet, ensuring the type embedding network's trainability is controlled consistently.

deepmd/pt/model/task/fitting.py (1)

323-323: LGTM: Trainable parameter properly propagated to fitting networks.

The trainable parameter is correctly passed to each FittingNet instance, enabling consistent trainability control across all fitting networks in the collection.

deepmd/pd/model/descriptor/dpa3.py (1)

170-170: LGTM: Consistent trainable parameter propagation across Paddle backend.

The trainable parameter is correctly propagated to both DescrptBlockRepflows and TypeEmbedNet submodules in the Paddle implementation, maintaining consistency with the PyTorch backend.

Also applies to: 186-186

deepmd/pt/model/descriptor/se_r.py (1)

145-145: LGTM! Correct implementation of trainable parameter propagation.

The change properly forwards the trainable parameter to the EmbeddingNet constructor, enabling fine-grained control over parameter trainability. The implementation is consistent with the existing PyTorch parameter handling logic (lines 149-152) that sets requires_grad appropriately.

deepmd/pd/model/descriptor/se_a.py (1)

484-484: LGTM! Proper PaddlePaddle implementation of trainable parameter.

The change correctly propagates the trainable parameter to the EmbeddingNet constructor. The implementation properly uses PaddlePaddle's stop_gradient mechanism (lines 488-491) to control parameter trainability, which is the correct equivalent to PyTorch's requires_grad.

deepmd/dpmodel/descriptor/se_t.py (1)

150-150: LGTM! Correct framework-agnostic implementation.

The change properly forwards the trainable parameter to the EmbeddingNet constructor. This framework-agnostic implementation correctly omits backend-specific parameter handling while maintaining the trainable parameter for serialization and configuration purposes (visible in the serialize method at line 431).

deepmd/dpmodel/fitting/general_fitting.py (1)

201-201: LGTM! Well-designed trainable parameter handling for fitting networks.

The change correctly propagates the trainable parameter to the FittingNet constructor. The implementation elegantly handles multiple trainable parameter formats (lines 136-139): None (defaulting to all trainable), single boolean (applied to all layers), or list of booleans (per-layer control). This provides flexible control over fitting network trainability.
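
The three accepted formats could be normalized roughly as follows. This is a sketch under assumptions: `normalize_trainable` and its `n_layers` argument are hypothetical names, not the helper used in `general_fitting.py`.

```python
from __future__ import annotations

from typing import Optional, Union


def normalize_trainable(
    trainable: Optional[Union[bool, list[bool]]], n_layers: int
) -> list[bool]:
    """Expand a trainable spec into one flag per layer (hypothetical helper)."""
    if trainable is None:
        # No preference given: every layer is trainable.
        return [True] * n_layers
    if isinstance(trainable, bool):
        # A single flag applies to all layers.
        return [trainable] * n_layers
    flags = list(trainable)
    if len(flags) != n_layers:
        raise ValueError(
            f"expected {n_layers} trainable flags, got {len(flags)}"
        )
    return flags
```

Normalizing once in the constructor means downstream code only has to handle the per-layer list form.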

deepmd/pt/model/descriptor/se_t.py (1)

578-578: LGTM: Correct propagation of trainable parameter.

The trainable parameter is properly passed to the EmbeddingNet initialization, enabling control over parameter trainability in filter layers.

deepmd/pd/model/descriptor/repformers.py (3)

90-90: LGTM: Trainable parameter added to constructor.

The trainable parameter is correctly added with a default value of True, maintaining backward compatibility.


227-232: LGTM: Trainable parameter propagated to MLPLayer.

The trainable parameter is properly passed to the MLPLayer initialization, enabling control over the g2 embedding layer's trainability.


269-269: LGTM: Trainable parameter propagated to RepformerLayer.

The trainable parameter is correctly passed to each RepformerLayer instance, ensuring consistent trainability control across all layers.

deepmd/pt/model/network/network.py (2)

256-256: LGTM: Trainable parameter added to TypeEmbedNet constructor.

The trainable parameter is correctly added with appropriate default value, maintaining backward compatibility.


277-277: LGTM: Trainable parameter propagated to TypeEmbedNetConsistent.

The trainable parameter is properly passed to the internal TypeEmbedNetConsistent instance, enabling trainability control for type embedding networks.

deepmd/pt/model/descriptor/repformers.py (4)

114-114: LGTM: Trainable parameter added to constructor.

The trainable parameter is correctly added with a default value of True, maintaining backward compatibility.


201-202: LGTM: Well-documented trainable parameter.

Good documentation that clearly explains the purpose of the trainable parameter.


252-258: LGTM: Trainable parameter propagated to MLPLayer.

The trainable parameter is properly passed to the MLPLayer initialization along with other required parameters including the child seed.


295-295: LGTM: Trainable parameter propagated to RepformerLayer.

The trainable parameter is correctly passed to each RepformerLayer instance in the loop, ensuring consistent trainability control.

deepmd/dpmodel/descriptor/se_r.py (1)

169-169: LGTM: Trainable parameter propagated to EmbeddingNet.

The trainable parameter is correctly passed to each type-specific EmbeddingNet instance, enabling control over embedding network trainability for all atom types.

deepmd/pd/model/network/mlp.py (2)

88-88: LGTM: Trainable parameter added with appropriate default.

The addition of the trainable parameter with a default value of True maintains backward compatibility while enabling explicit control over parameter trainability.


281-281: LGTM: Proper handling of trainable parameter in deserialization.

The trainable parameter is correctly extracted from the NativeLayer during deserialization and passed to the constructor.

deepmd/pt/model/descriptor/dpa1.py (2)

301-301: LGTM: Trainable parameter properly propagated to attention block.

The trainable parameter is correctly passed to the DescrptBlockSeAtten component, ensuring consistent trainability control throughout the descriptor hierarchy.


315-315: LGTM: Trainable parameter properly propagated to type embedding.

The trainable parameter is correctly passed to the TypeEmbedNet component, maintaining consistent trainability control across all subcomponents.

deepmd/pt/model/descriptor/repflows.py (4)

222-222: LGTM: Trainable parameter added with appropriate default.

The trainable parameter is properly added to the constructor with a default value of True, maintaining backward compatibility while enabling trainability control.


286-292: LGTM: Trainable parameter propagated to edge embedding layer.

The trainable parameter is correctly passed to the edge embedding MLPLayer, ensuring consistent trainability control for edge representations.


293-300: LGTM: Trainable parameter propagated to angle embedding layer.

The trainable parameter is correctly passed to the angle embedding MLPLayer, ensuring consistent trainability control for angle representations.


331-331: LGTM: Trainable parameter propagated to RepFlow layers.

The trainable parameter is correctly passed to each RepFlowLayer instance in the layer stack, ensuring consistent trainability control throughout the entire descriptor block.

deepmd/pt/model/descriptor/repflow_layer.py (1)

125-131: LGTM: Consistent trainable parameter propagation.

The trainable parameter is correctly propagated to all MLPLayer instances and get_residual calls throughout the constructor. This ensures consistent control over parameter trainability across all subcomponents.

Also applies to: 143-152, 163-172, 184-193, 222-232, 233-240, 243-252, 250-259, 270-279

deepmd/pd/model/descriptor/repflow_layer.py (1)

122-128: LGTM: Consistent backend implementation.

The PaddlePaddle implementation correctly mirrors the PyTorch version, with the trainable parameter properly propagated to all MLPLayer instances and get_residual calls. This ensures consistent behavior across different backends.

Also applies to: 143-149, 163-169, 184-190, 222-229, 230-237, 243-249, 250-256, 270-276

deepmd/pd/model/descriptor/repflows.py (1)

226-232: LGTM: Complete trainable parameter integration.

The trainable parameter is correctly propagated to both the MLPLayer instances (edge_embd and angle_embd) and all RepFlowLayer instances in the layer list. This completes the integration chain from the descriptor block level down to individual layer components.

Also applies to: 233-240, 243-272

deepmd/dpmodel/descriptor/dpa3.py (1)

360-360: LGTM! Proper trainable parameter propagation.

The trainable parameter is correctly passed to both DescrptBlockRepflows and TypeEmbedNet subcomponents, enabling consistent trainability control throughout the DPA3 descriptor.

Also applies to: 378-378

deepmd/dpmodel/descriptor/se_t_tebd.py (2)

160-160: LGTM! Consistent trainable parameter propagation in DescrptSeTTebd.

The trainable parameter is properly passed to both DescrptBlockSeTTebd and TypeEmbedNet instances, maintaining consistency with the broader codebase pattern.

Also applies to: 175-175


502-502: LGTM! Proper trainable parameter integration in DescrptBlockSeTTebd.

The trainable parameter is correctly added to the constructor with a sensible default value (True) and properly propagated to all EmbeddingNet instances, enabling fine-grained control over parameter trainability.

Also applies to: 548-548, 564-564

deepmd/jax/utils/network.py (2)

19-19: LGTM! Required import for non-trainable parameter support.

The ArrayAPIVariable import is necessary for wrapping non-trainable parameters in the JAX backend.


48-51: LGTM! Correct JAX parameter wrapping logic.

The conditional wrapping logic properly distinguishes between trainable and non-trainable parameters by using ArrayAPIParam for trainable parameters (which support gradients) and ArrayAPIVariable for non-trainable parameters (which do not).
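
The wrapping decision can be illustrated without JAX. In the sketch below, `Param` and `Variable` are plain-Python stand-ins for `ArrayAPIParam` and `ArrayAPIVariable`, and `ToyNativeLayer` is a hypothetical class; none of this is the actual code in `deepmd/jax/utils/network.py`.

```python
class Variable:
    """Stand-in for a fixed (non-trainable) array wrapper."""

    def __init__(self, value):
        self.value = value


class Param(Variable):
    """Stand-in for a trainable array wrapper."""


class ToyNativeLayer:
    """Hypothetical layer that wraps its arrays on assignment."""

    def __init__(self, w, b, trainable=True):
        # Set trainable first: __setattr__ consults it when wrapping w and b.
        self.trainable = trainable
        self.w = w
        self.b = b

    def __setattr__(self, name, value):
        if name in {"w", "b"} and value is not None:
            # Trainable arrays become Param, fixed arrays become Variable.
            wrapper = Param if self.trainable else Variable
            value = wrapper(value)
        super().__setattr__(name, value)


frozen_layer = ToyNativeLayer(w=[1.0], b=[0.0], trainable=False)
free_layer = ToyNativeLayer(w=[1.0], b=[0.0], trainable=True)
```

Because the wrapper type is chosen at assignment time, code that collects only `Param` instances would skip the frozen layer's arrays.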

deepmd/pt/model/network/mlp.py (2)

86-86: LGTM! Proper trainable parameter integration.

The trainable parameter is correctly added to the constructor with a sensible default value (True) and stored as an instance attribute.

Also applies to: 89-89


238-238: LGTM! Complete serialization/deserialization support.

The trainable attribute is properly included in both serialization and deserialization methods, ensuring the trainability state is preserved across model save/load cycles.

Also applies to: 265-265

deepmd/pd/model/descriptor/dpa1.py (1)

295-295: LGTM! Consistent trainable parameter propagation in DPA1.

The trainable parameter is properly passed to both DescrptBlockSeAtten and TypeEmbedNet instances, maintaining consistency with other descriptor implementations and enabling uniform trainability control across the DPA1 architecture.

Also applies to: 309-309

deepmd/pd/model/descriptor/dpa2.py (4)

92-92: LGTM: Well-designed parameter addition with appropriate default.

The trainable parameter addition with a default value of True maintains backward compatibility while enabling explicit control over parameter trainability.


187-187: LGTM: Comprehensive parameter propagation to all sub-components.

The trainable parameter is correctly passed to all relevant sub-components including descriptor blocks, type embedding networks, and MLP layers, ensuring consistent trainability control throughout the model hierarchy.

Also applies to: 207-207, 248-248, 276-276, 302-302, 312-312


323-324: LGTM: Correct gradient control implementation for PaddlePaddle.

The gradient control logic correctly uses PaddlePaddle's stop_gradient attribute with inverse logic (not trainable) to control parameter trainability.
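
The inverse relationship between the two frameworks' flags can be shown with stand-in objects. `TorchLikeParam` and `PaddleLikeParam` below are hypothetical dataclasses that only mimic the `requires_grad` and `stop_gradient` attributes; no framework is imported.

```python
from dataclasses import dataclass


@dataclass
class TorchLikeParam:
    """Mimics the requires_grad attribute of a PyTorch parameter."""

    requires_grad: bool = True


@dataclass
class PaddleLikeParam:
    """Mimics the stop_gradient attribute of a Paddle parameter."""

    stop_gradient: bool = False


def apply_trainable_torch_style(params, trainable: bool) -> None:
    # PyTorch convention: the flag states whether gradients are wanted.
    for p in params:
        p.requires_grad = trainable


def apply_trainable_paddle_style(params, trainable: bool) -> None:
    # Paddle convention: the flag states whether gradients are blocked,
    # so it is the negation of trainable.
    for p in params:
        p.stop_gradient = not trainable


torch_params = [TorchLikeParam(), TorchLikeParam()]
paddle_params = [PaddleLikeParam(), PaddleLikeParam()]
apply_trainable_torch_style(torch_params, trainable=False)
apply_trainable_paddle_style(paddle_params, trainable=False)
```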


558-558: LGTM: Proper serialization inclusion.

The trainable parameter is correctly included in the serialization dictionary to maintain state consistency during model save/load operations.

deepmd/pt/model/descriptor/se_t_tebd.py (4)

138-138: LGTM: Consistent parameter addition across descriptor classes.

The trainable parameter is appropriately added to both the main descriptor class and the descriptor block class with proper default values.

Also applies to: 530-530


163-163: LGTM: Complete parameter propagation to embedding networks.

The trainable parameter is correctly passed to all embedding network instances, including both main and strip embedding networks when applicable.

Also applies to: 174-174, 583-583, 597-597


184-185: LGTM: Correct gradient control implementation for PyTorch.

The gradient control logic correctly uses PyTorch's requires_grad attribute to control parameter trainability. This is the appropriate approach for the PyTorch framework.


370-370: LGTM: Proper serialization inclusion.

The trainable parameter is correctly included in the serialization dictionary for state persistence.

deepmd/pd/model/descriptor/se_t_tebd.py (4)

138-138: LGTM: Consistent parameter addition matching PyTorch implementation.

The trainable parameter is appropriately added to both descriptor classes with the same signature as the PyTorch version, ensuring cross-framework consistency.

Also applies to: 534-534


163-163: LGTM: Complete parameter propagation to all embedding components.

The trainable parameter is correctly passed to all embedding network instances, maintaining consistency with the PyTorch implementation while using PaddlePaddle-specific components.

Also applies to: 177-177, 591-591, 605-605


184-185: LGTM: Correct gradient control implementation for PaddlePaddle.

The gradient control logic correctly uses PaddlePaddle's stop_gradient attribute with inverse logic (not trainable), which is the appropriate approach for this framework.


374-374: LGTM: Proper serialization inclusion maintaining consistency.

The trainable parameter is correctly included in the serialization dictionary, ensuring state consistency across framework implementations.

deepmd/dpmodel/descriptor/dpa2.py (1)

384-384: Excellent implementation of trainable parameter propagation.

The trainable parameter is correctly added to the constructor and systematically propagated to all key subcomponents (repinit blocks, repformers, type embedding, and transform layers), ensuring consistent control over parameter trainability throughout the DPA2 descriptor hierarchy.

Also applies to: 477-477, 497-497, 538-538, 568-568, 592-592, 602-602

deepmd/tf/fit/fitting.py (1)

138-138: LGTM: Proper trainable parameter integration in fitting serialization.

The trainable parameter is correctly added to the serialize_network method and properly propagated to the FittingNet constructor, enabling trainability control during network serialization.

Also applies to: 203-203

deepmd/tf/descriptor/se.py (1)

195-195: Well-implemented trainable parameter for SE descriptor serialization.

The trainable parameter is properly added with documentation and correctly propagated to all EmbeddingNet instances, including both excluded type networks and those initialized from variables, ensuring comprehensive trainability control.

Also applies to: 218-219, 242-242, 250-250, 286-286

deepmd/tf/descriptor/se_t.py (1)

729-729: Consistent trainable parameter implementation for SE-T descriptor.

The trainable parameter follows the same well-established pattern as other descriptor modules, with proper documentation and correct propagation to all EmbeddingNet constructors in both the helper function and main initialization logic.

Also applies to: 752-753, 777-777, 812-812

deepmd/pt/model/descriptor/dpa2.py (6)

96-96: LGTM! Trainable parameter properly added.

The trainable parameter is correctly added with proper typing and a backward-compatible default value.


191-191: LGTM! Consistent parameter propagation to repinit modules.

The trainable parameter is correctly propagated to both DescrptBlockSeAtten and DescrptBlockSeTTebd instances.

Also applies to: 211-211


252-252: LGTM! Parameter propagation to core modules.

The trainable parameter is properly passed to DescrptBlockRepformers and TypeEmbedNet instances, maintaining consistency.

Also applies to: 280-280


306-306: LGTM! MLPLayer instances receive trainable parameter.

The trainable parameter is correctly propagated to both the g1_shape_tranform and conditional tebd_transform MLPLayer instances.

Also applies to: 316-316


288-288: LGTM! Proper parameter storage and gradient control.

Storing the trainable parameter and explicitly setting requires_grad on all parameters ensures consistent behavior across the descriptor.

Also applies to: 327-328


556-556: LGTM! Trainable parameter properly serialized.

Including the trainable parameter in serialization ensures the setting is preserved during model save/load operations.

deepmd/dpmodel/descriptor/repflows.py (3)

170-171: LGTM! Trainable parameter properly added to DescrptBlockRepflows.

The parameter is correctly added with proper documentation, typing, and backward-compatible default value.

Also applies to: 210-210


274-288: LGTM! Consistent parameter propagation in DescrptBlockRepflows.

The trainable parameter is properly propagated to all NativeLayer instances and RepFlowLayer instances in the constructor.

Also applies to: 319-319


876-876: LGTM! Comprehensive trainable parameter support in RepFlowLayer.

The RepFlowLayer class properly receives the trainable parameter and systematically propagates it to all internal NativeLayer instances and residual components.

Also applies to: 934-1097

deepmd/pd/model/descriptor/se_atten.py (6)

209-209: LGTM!

The trainable parameter is correctly passed to the NeighborGatedAttention constructor.


234-234: LGTM!

The trainable parameter is correctly passed to both EmbeddingNet constructors for main and strip mode filter layers.

Also applies to: 248-248


698-698: LGTM!

The trainable parameter is correctly passed to each NeighborGatedAttentionLayer in the attention layers loop.


834-834: LGTM!

The trainable parameter is correctly passed to the GatedAttentionLayer constructor.


948-948: LGTM!

The trainable parameter is correctly passed to both in_proj and out_proj MLPLayer constructors.

Also applies to: 959-959


84-84: Overall implementation is consistent and well-structured.

The trainable parameter has been correctly added to all relevant classes and properly propagated through the initialization chain. The implementation aligns well with the PR objectives. However, please address the serialization and documentation issues identified in the previous comments to ensure complete functionality.

Also applies to: 209-209, 234-234, 248-248, 662-662, 698-698, 806-806, 834-834, 915-915, 948-948, 959-959

deepmd/tf/descriptor/se_atten.py (5)

1596-1596: LGTM! Correctly propagates trainable flag to attention layer.

The addition of trainable=self.trainable ensures that the in_proj NativeLayer inherits the trainable property from the parent descriptor, maintaining consistency in the attention mechanism serialization.


1615-1615: LGTM! Maintains consistency in trainable flag propagation.

The out_proj NativeLayer now correctly inherits the trainable property, ensuring both in_proj and out_proj layers have consistent trainable settings.


1659-1659: LGTM! Well-designed parameter addition.

The new trainable parameter has proper type annotation, sensible default value, and follows consistent naming conventions. This enables control over trainability in network serialization.


1685-1686: LGTM! Clear and properly formatted parameter documentation.

The documentation for the trainable parameter follows the existing style and clearly describes its purpose, maintaining consistency with the rest of the method's docstring.


1727-1727: LGTM! Completes trainable flag propagation chain.

The trainable parameter is correctly passed from the method parameter to the EmbeddingNet constructor, ensuring that the network respects the specified trainable setting during serialization.

deepmd/pt/model/descriptor/se_atten.py (2)

103-103: Parameter addition looks good.

The trainable parameter is correctly added with appropriate default value for backward compatibility.


228-228: Consistent parameter propagation to child components.

The trainable parameter is correctly passed to all child components (NeighborGatedAttention, EmbeddingNet instances).

Also applies to: 253-253, 267-267

deepmd/pd/model/descriptor/repformer_layer.py (3)

45-45: Correct implementation for residual tensor trainability.

The logic residual.stop_gradient = not trainable correctly controls whether the residual tensor participates in gradient computation.

Also applies to: 75-75


166-166: Parameter addition and propagation look good.

The trainable parameter is correctly added and passed to the MLPLayer.

Also applies to: 179-179


630-630: Extensive and consistent parameter propagation.

The trainable parameter is correctly propagated to all child components and residual tensors throughout the RepformerLayer class. The implementation properly handles all the various layer types and update styles.

Also applies to: 691-691, 701-701, 718-718, 728-728, 737-737, 747-747, 760-760, 769-769, 779-779, 789-789, 799-799, 811-811, 819-820, 836-836, 846-847, 856-856, 867-867, 877-877

deepmd/dpmodel/descriptor/dpa1.py (4)

322-323: LGTM: Proper trainable parameter propagation

The trainable parameter is correctly passed to both DescrptBlockSeAtten and TypeEmbedNet components, and properly stored as an instance variable for later use in serialization.

Also applies to: 337-338, 341-341


696-696: LGTM: Consistent trainable parameter implementation

The trainable parameter is properly added to the DescrptBlockSeAtten constructor and correctly propagated to all embedding networks and attention components.

Also applies to: 747-747, 763-763, 782-782


1195-1195: LGTM: Attention layer trainable parameter support

The trainable parameter is correctly implemented in both NeighborGatedAttention and NeighborGatedAttentionLayer classes, maintaining consistency with the overall design pattern.

Also applies to: 1229-1229, 1325-1325, 1352-1352


1433-1433: LGTM: Complete trainable parameter chain

The GatedAttentionLayer properly implements trainable parameter support and correctly passes it to the underlying NativeLayer components, completing the parameter propagation chain.

Also applies to: 1463-1463, 1472-1472

deepmd/pt/model/descriptor/repformer_layer.py (6)

44-44: LGTM: Correct trainable parameter implementation

The get_residual function properly uses the trainable parameter to control the requires_grad attribute of the residual tensor, which is the correct approach for PyTorch parameter trainability.

Also applies to: 71-71


163-163: Good: Trainable parameter properly added to constructor

The trainable parameter is correctly added to the Atten2Map constructor and properly passed to the MLPLayer.

Also applies to: 176-176


290-290: LGTM: Proper trainable parameter propagation

The trainable parameter is correctly added to the constructor and properly passed to both mapv and head_map MLPLayer instances.

Also applies to: 301-301, 308-308


378-378: LGTM: Trainable parameter correctly implemented

The trainable parameter is properly added to the constructor and correctly passed to the head_map MLPLayer.

Also applies to: 384-390


457-457: LGTM: Comprehensive trainable parameter implementation

The trainable parameter is correctly added to the LocalAtten constructor and properly propagated to all three MLPLayer instances (mapq, mapkv, and head_map).

Also applies to: 469-469, 477-477, 484-484


620-620: Excellent: Comprehensive trainable parameter implementation

The RepformerLayer class demonstrates excellent implementation of the trainable parameter, correctly adding it to the constructor and systematically propagating it to all sub-components including MLPLayers, residual parameters, and attention modules. This ensures complete control over parameter trainability throughout the layer hierarchy.

Also applies to: 681-681, 691-691, 708-708, 718-718, 727-727, 737-737, 750-750, 759-759, 769-769, 779-779, 789-789, 801-801, 809-809, 826-826, 836-836, 846-846, 857-857, 867-867

deepmd/dpmodel/utils/network.py (6)

90-91: LGTM: Clear documentation for the new trainable parameter.

The documentation properly describes the new parameter with appropriate formatting and default value indication.


104-107: Good practice: Setting trainable before array attributes.

The comment and early assignment of self.trainable is well-placed, ensuring the trainable state is established before any array operations that might depend on it.
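Why the ordering matters can be shown with a small sketch. It assumes a backend that wraps array attributes at assignment time depending on `self.trainable`, in the spirit of the conditional wrapping this PR describes for JAX. `TrainableParam`, `FixedVariable`, and `OrderedLayer` are hypothetical names for illustration only, not the library's classes.

```python
class TrainableParam:
    """Hypothetical wrapper marking a value as trainable."""

    def __init__(self, value):
        self.value = value


class FixedVariable:
    """Hypothetical wrapper marking a value as frozen."""

    def __init__(self, value):
        self.value = value


class OrderedLayer:
    # Attributes that get wrapped when they are assigned.
    ARRAY_ATTRS = {"w", "b"}

    def __init__(self, w, b, trainable=True):
        # Must come first: __setattr__ reads self.trainable
        # while wrapping w and b below.
        self.trainable = trainable
        self.w = w
        self.b = b

    def __setattr__(self, name, value):
        if name in self.ARRAY_ATTRS and value is not None:
            wrapper = TrainableParam if self.trainable else FixedVariable
            value = wrapper(value)
        super().__setattr__(name, value)
```

If `self.trainable` were assigned after `self.w`, the `__setattr__` hook in this sketch would raise `AttributeError` on the first array assignment, which is the failure mode the early assignment avoids.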


168-168: Proper version compatibility handling.

The version compatibility check correctly allows both version 1 and 2, ensuring backward compatibility.


249-250: Complete the getter interface.

The __getitem__ method now properly supports the trainable key, maintaining consistency with other layer attributes.
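A minimal sketch of such a key-based getter, assuming the layer exposes its attributes through a fixed set of keys. `KeyedLayer` and its attribute names are illustrative and do not reproduce the actual NativeLayer code.

```python
class KeyedLayer:
    """Toy layer exposing selected attributes through item access."""

    _KEYS = ("w", "b", "trainable")

    def __init__(self, w, b=None, trainable=True):
        self.w = w
        self.b = b
        self.trainable = trainable

    def __getitem__(self, key):
        # Only whitelisted keys are readable; anything else is an error.
        if key in self._KEYS:
            return getattr(self, key)
        raise KeyError(key)
```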


440-441: Consistent trainable parameter propagation.

The trainable parameter is properly threaded through the class hierarchy from LayerNorm to EmbeddingNet to FittingNet, ensuring all layers in the network respect the trainability setting.

Also applies to: 693-694, 709-710, 800-801, 809-810, 823-824
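The propagation pattern can be sketched as follows. `ToyLayer` and `ToyNetwork` are hypothetical stand-ins rather than the real NativeLayer or EmbeddingNet classes; the only point is that the network forwards one `trainable` flag to every layer it builds.

```python
class ToyLayer:
    """Stand-in for a single dense layer that records its trainability."""

    def __init__(self, num_in, num_out, trainable=True):
        self.num_in = num_in
        self.num_out = num_out
        self.trainable = trainable


class ToyNetwork:
    """Stand-in for a stacked network that forwards `trainable` downward."""

    def __init__(self, in_dim, neuron, trainable=True):
        self.trainable = trainable
        self.layers = []
        for out_dim in neuron:
            # Every layer receives the same flag as its parent network.
            self.layers.append(ToyLayer(in_dim, out_dim, trainable=trainable))
            in_dim = out_dim
```

Freezing a whole network then takes a single argument at the top level, for example `ToyNetwork(4, [8, 8], trainable=False)`.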


147-154:

✅ Verification successful

Verify version compatibility across the codebase.

The serialization version has been incremented from 1 to 2 to accommodate the new trainable field. This is a breaking change that requires careful coordination.


🏁 Script executed:

#!/bin/bash
# Description: Verify that all deserialization code can handle the version bump and new trainable field
# Expected: All deserialize methods should handle both version 1 and 2, with proper defaults for trainable

# Check for other classes that might serialize/deserialize NativeLayer
rg -A 10 -B 5 "deserialize.*NativeLayer|NativeLayer.*deserialize"

# Check for hardcoded version references that might need updating
rg -A 3 -B 3 "version.*[\"']1[\"']|[\"']1[\"'].*version" --type py

Length of output: 47637


NativeLayer Deserialization Handles Version 1 Default Trainable

The version bump to 2 is safe: NativeLayer.deserialize uses data.pop("@version", 1) and check_version_compatibility, so a payload with no @version key is treated as version 1, and both versions pass the check. Since trainable only exists in version 2, any code that does

trainable = data.pop("trainable", True)

(or equivalent default logic) will automatically behave as before for version 1. I’ve confirmed:

• @version is popped with default 1 in deepmd/dpmodel/utils/network.py
• check_version_compatibility allows versions 1–2
• No hard-coded “1” version checks elsewhere in the repo

No follow-up changes are needed.
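The backward-compatible pattern described above can be sketched in plain Python. `VersionedLayer` and the `check_version` helper are hypothetical simplifications; the real check_version_compatibility may have a different signature, and the real payload carries more fields.

```python
def check_version(actual, maximum, minimum=1):
    """Hypothetical stand-in for a version compatibility check."""
    if not (minimum <= actual <= maximum):
        raise ValueError(
            f"unsupported version {actual}; expected {minimum}..{maximum}"
        )


class VersionedLayer:
    def __init__(self, w, trainable=True):
        self.w = list(w)
        self.trainable = trainable

    def serialize(self):
        # Version 2 payloads carry the trainable flag.
        return {
            "@class": "Layer",
            "@version": 2,
            "trainable": self.trainable,
            "w": list(self.w),
        }

    @classmethod
    def deserialize(cls, data):
        data = dict(data)  # copy so the caller's dict is not mutated
        version = data.pop("@version", 1)
        check_version(version, 2, 1)
        data.pop("@class", None)
        # Version 1 payloads have no "trainable" key; default to True.
        trainable = data.pop("trainable", True)
        return cls(data["w"], trainable=trainable)
```

With this shape, a version 1 payload loads as a trainable layer, a version 2 payload round-trips the stored flag, and any later version is rejected up front.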

deepmd/dpmodel/descriptor/repformers.py (8)

167-169: Comprehensive trainable parameter integration.

The DescrptBlockRepformers class properly adds the trainable parameter with appropriate documentation and default value.

Also applies to: 209-210


257-263: Consistent parameter passing to embedded layers.

The g2_embd layer properly receives the trainable parameter, maintaining consistency with the overall design.


300-301: Proper propagation to RepformerLayer instances.

The trainable parameter is correctly passed to each RepformerLayer instance in the loop, ensuring all layers in the stack respect the trainability setting.


858-859: Attention mechanism trainability support.

The Atten2Map class properly integrates the trainable parameter and passes it to its internal mapqk layer.

Also applies to: 871-872


983-984: Multi-head attention trainability.

The Atten2MultiHeadApply class correctly propagates the trainable parameter to both its mapv and head_map layers.

Also applies to: 994-995, 1001-1002


1074-1086: Clean multi-line parameter formatting.

The Atten2EquiVarApply class uses clean multi-line formatting for the NativeLayer constructor call while properly passing the trainable parameter.


1155-1156: Local attention trainability.

The LocalAtten class systematically passes the trainable parameter to all its internal layers (mapq, mapkv, head_map), ensuring complete coverage.

Also applies to: 1167-1168, 1175-1176, 1182-1183


1321-1322: Comprehensive RepformerLayer trainability integration.

The RepformerLayer class extensively integrates the trainable parameter:

  • Passes it to all NativeLayer instances
  • Includes it in get_residual calls for residual connections
  • Propagates it to attention mechanism components (Atten2Map, Atten2MultiHeadApply, Atten2EquiVarApply, LocalAtten)
  • Maintains consistency across all conditional layer instantiations

This ensures that the entire repformer architecture respects the trainability setting at every level.

Also applies to: 1381-1382, 1391-1392, 1408-1409, 1418-1419, 1427-1428, 1437-1438, 1450-1451, 1459-1460, 1469-1470, 1479-1480, 1489-1490, 1501-1502, 1509-1510, 1526-1527, 1536-1537, 1546-1547, 1557-1558, 1567-1568

njzjz added 6 commits June 10, 2025 20:49
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>

codecov bot commented Jun 13, 2025

Codecov Report

❌ Patch coverage is 87.17949% with 5 lines in your changes missing coverage. Please review.
✅ Project coverage is 84.57%. Comparing base (ab6e300) to head (94ce346).
⚠️ Report is 105 commits behind head on devel.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| deepmd/pd/model/descriptor/repflows.py | 77.77% | 2 Missing ⚠️ |
| deepmd/tf/fit/dipole.py | 83.33% | 1 Missing ⚠️ |
| deepmd/tf/fit/fitting.py | 50.00% | 1 Missing ⚠️ |
| deepmd/tf/fit/polar.py | 83.33% | 1 Missing ⚠️ |

Additional details and impacted files
@@            Coverage Diff             @@
##            devel    #4793      +/-   ##
==========================================
- Coverage   84.80%   84.57%   -0.24%     
==========================================
  Files         698      699       +1     
  Lines       67798    68070     +272     
  Branches     3542     3541       -1     
==========================================
+ Hits        57494    57567      +73     
- Misses       9171     9369     +198     
- Partials     1133     1134       +1     


@njzjz changed the title from "fix(dpmodel/pt/pd/jax): pass trainable to layer & support JAX trainable" to "fix(dpmodel/pt/pd/jax): pass trainable to layer & support JAX trainable & support TF tensor fitting trainable" Jun 13, 2025
@njzjz njzjz requested review from iProzd and wanghan-iapcm June 13, 2025 06:36
Co-authored-by: Duo <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
@njzjz njzjz requested a review from iProzd June 27, 2025 12:42
@njzjz njzjz enabled auto-merge July 8, 2025 08:58
@njzjz njzjz added this pull request to the merge queue Jul 8, 2025
Merged via the queue into deepmodeling:devel with commit c151e04 Jul 8, 2025
60 checks passed

3 participants