
Conversation

@ChiahsinChu
Contributor

@ChiahsinChu ChiahsinChu commented Mar 3, 2025

Currently, the PyTorch dipole model returns the dipole (atomic quantities), the global dipole (the sum of the atomic quantities), and the force/virial obtained by differentiating the global dipole. I am interested in not only the atomic dipole but also the derivative of the total dipole w.r.t. coord, which is relevant to the dielectric response of a system and has been used to train field-response MLPs (https://www.nature.com/articles/s41467-024-52491-3). However, the "global dipole" is not always the real total dipole of the system due to the missing information on charges. If the charges associated with the Wannier centroids are not identical, there is no way for users to get the derivative of (real) total dipole w.r.t. coord from DW models. (I tried to calculate the gradient of output dipole w.r.t. the input coord, but the gradient chain seems to be cut in the model.) Therefore, I add the atomic_weight parameter to model.forward, which allows a user-defined atomic weight.

There are some potential advantages of this implementation:

  • allow users to calculate the derivative of (real) total dipole w.r.t. coord
  • allow further implementation of DPLR method

This implementation is currently only activated for dipole models, but making it work for other models should be straightforward.
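As a self-contained sketch of why the weight is needed (a toy per-atom dipole function stands in for the DW model here; the function and the charge values are purely illustrative, not DeePMD-kit code):

import torch

# Toy stand-in for a DW model's per-atom dipoles; only an illustration of the
# argument above, not the DeePMD-kit API.
def atomic_dipoles(coord: torch.Tensor) -> torch.Tensor:  # coord: nloc x 3
    return torch.tanh(coord) * coord  # nloc x 3, differentiable w.r.t. coord

coord = torch.randn(4, 3, requires_grad=True)
# hypothetical per-atom weights, e.g. the charges of the associated Wannier centroids
charges = torch.tensor([-8.0, -2.0, 0.0, 0.0])

mu = atomic_dipoles(coord)                          # per-atom dipoles
plain_sum = mu.sum(dim=0)                           # what the current "global dipole" is
weighted_sum = (charges[:, None] * mu).sum(dim=0)   # the physically meaningful total dipole

# The derivatives w.r.t. the coordinates differ whenever the per-atom charges
# are not identical.
g_plain = torch.autograd.grad(plain_sum[0], coord, retain_graph=True)[0]
g_weighted = torch.autograd.grad(weighted_sum[0], coord)[0]
print(torch.allclose(g_plain, g_weighted))  # False in general

With the new atomic_weight argument, passing the per-atom charges as the weight makes the returned global dipole, and hence its derivative, correspond to the weighted sum above.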

Summary by CodeRabbit

  • New Features

    • Enabled support for an optional atomic weight input across prediction workflows, allowing users to supply custom atomic weights that scale the computed outputs.
  • Tests

    • Increased test coverage to validate the proper integration and consistent handling of atomic weight data in model predictions.

@github-actions github-actions bot added the Python label Mar 3, 2025
@coderabbitai
Contributor

coderabbitai bot commented Mar 3, 2025

Caution

Review failed

The pull request is closed.

📝 Walkthrough

This pull request introduces an optional atomic_weight parameter across multiple model functions in both PyTorch and JAX implementations. The parameter is added to various forward methods in model classes and utility functions, allowing the atomic weight data to be passed through the prediction workflow. In one case, the previous logic modifying output values based on atomic weights was removed to simplify control flow. Additionally, testing has been updated to validate the correct propagation and application of atomic weights.

Changes

  • deepmd/pt/model/.../dipole_model.py: Added the optional atomic_weight parameter to the forward and forward_lower method signatures; the parameter is forwarded to the common computation methods.
  • deepmd/pt/model/.../make_model.py: Updated the forward_common_lower and forward_common signatures to include atomic_weight; removed logic that modified atomic_ret based on the parameter.
  • deepmd/pt/model/atomic_model/base_atomic_model.py, deepmd/dpmodel/atomic_model/base_atomic_model.py: Added the atomic_weight parameter to the forward_common_atomic and forward methods and updated the documentation to reflect its role in scaling outputs.
  • deepmd/jax/model/.../base_model.py, deepmd/jax/atomic_model/.../dp_atomic_model.py, deepmd/jax/model/dp_model.py: Integrated the atomic_weight parameter into the forward, evaluation, and atomic model functions, ensuring it passes through to the underlying atomic model calls.
  • deepmd/dpmodel/model/make_model.py: Modified several function signatures (model_call_from_call_lower, call, call_lower, input_type_cast, etc.) to consistently include the new atomic_weight parameter.
  • deepmd/jax/jax2tf/... (make_model.py, serialization.py, tfmodel.py): Updated function signatures to include atomic_weight; adjusted internal logic for passing and initializing atomic weight data in the JAX-to-TF conversion workflow.
  • deepmd/jax/utils/serialization.py: Added atomic_weight to the signature of call_lower_with_fixed_do_atomic_virial, ensuring it is forwarded in subsequent calls.
  • source/tests/pt/model/test_dp_atomic_model.py: Introduced tests that generate a random atomic weight tensor and verify the consistency of energy calculations when the parameter is applied.

Sequence Diagram(s)

sequenceDiagram
    participant Caller
    participant ModelLayer
    participant AtomicModel

    Caller->>ModelLayer: forward(..., atomic_weight)
    ModelLayer->>ModelLayer: Process inputs and pass atomic_weight to common routine
    ModelLayer->>AtomicModel: forward_common_atomic(..., atomic_weight)
    AtomicModel-->>ModelLayer: Return atomic predictions
    ModelLayer-->>Caller: Return final results (atomic weight applied)

Possibly related PRs

  • feat(jax): zbl #4301: Introduces the atomic_weight parameter in the forward_common_atomic method of the BaseAtomicModel class, aligning with the current changes.
  • feat(jax): force & virial #4251: Involves adding the atomic_weight parameter in forward routines related to atomic predictions, which is conceptually similar to this PR.
  • feat(jax): energy model (no grad support) #4226: Also introduces an atomic_weight parameter for the forward_common_atomic method in atomic model classes, showing a shared focus on enabling atomic weight handling.

Suggested reviewers

  • njzjz
  • wanghan-iapcm

Tip

⚡🧪 Multi-step agentic review comment chat (experimental)
  • We're introducing multi-step agentic chat in review comments. This experimental feature enhances review discussions with the CodeRabbit agentic chat by enabling advanced interactions, including the ability to create pull requests directly from comments.
    - To enable this feature, set early_access to true in the settings.

📜 Recent review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 088b252 and 66d3e0f.

📒 Files selected for processing (1)
  • deepmd/jax/model/base_model.py (8 hunks)

🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Generate unit testing code for this file.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai generate unit testing code for this file.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and generate unit testing code.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate docstrings to generate docstrings for this PR.
  • @coderabbitai resolve to resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🔭 Outside diff range comments (1)
deepmd/pt/model/model/make_model.py (1)

616-634: ⚠️ Potential issue

Update the forward method to include atomic_weight parameter.

The forward method at the bottom of the class needs to be updated to include the new atomic_weight parameter to ensure consistency with the forward_common method signature.

Apply this change to update the method signature and parameter passing:

def forward(
    self,
    coord,
    atype,
    box: Optional[torch.Tensor] = None,
    fparam: Optional[torch.Tensor] = None,
    aparam: Optional[torch.Tensor] = None,
    do_atomic_virial: bool = False,
+   atomic_weight: Optional[torch.Tensor] = None,
) -> dict[str, torch.Tensor]:
    # directly call the forward_common method when no specific transform rule
    return self.forward_common(
        coord,
        atype,
        box,
        fparam=fparam,
        aparam=aparam,
        do_atomic_virial=do_atomic_virial,
+       atomic_weight=atomic_weight,
    )
🧹 Nitpick comments (1)
deepmd/pt/model/model/make_model.py (1)

299-301: Clarify the atomic weight initialization and usage.

The code uses next(iter(self.atomic_output_def().var_defs.keys())) to get the first key from the atomic output definitions. Consider adding a comment explaining why this key is chosen and ensuring it's the correct key for applying the weight.

Add a clarifying comment:

# add weight to atomic_output
kw = next(iter(self.atomic_output_def().var_defs.keys()))
+# Create default weight tensor using the shape of the first output variable
atomic_weight = torch.ones_like(atomic_ret[kw])
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 80d445b and 275fb52.

📒 Files selected for processing (2)
  • deepmd/pt/model/model/dipole_model.py (4 hunks)
  • deepmd/pt/model/model/make_model.py (4 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (20)
  • GitHub Check: Test Python (4, 3.12)
  • GitHub Check: Test Python (4, 3.9)
  • GitHub Check: Test Python (3, 3.12)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Test Python (3, 3.9)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Test Python (2, 3.9)
  • GitHub Check: Analyze (python)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Analyze (javascript-typescript)
  • GitHub Check: Test C++ (false)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Test Python (1, 3.9)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Test C++ (true)
🔇 Additional comments (7)
deepmd/pt/model/model/dipole_model.py (4)

63-63: Function signature enhancement: New atomic weight parameter added.

This change appropriately adds the optional atomic_weight parameter to the forward method, which allows users to define weights for atomic dipoles.


72-72: Correctly passing atomic_weight to the common implementation.

The parameter is properly forwarded to the underlying implementation.


103-103: Function signature enhancement: New atomic weight parameter added to lower interface.

The atomic_weight parameter is appropriately added to the forward_lower method, maintaining consistency with the higher-level interface.


114-114: Correctly passing atomic_weight to the lower-level implementation.

The parameter is properly forwarded to the underlying implementation.

deepmd/pt/model/model/make_model.py (3)

138-138: Function signature enhancement: New atomic weight parameter added.

This change appropriately adds the optional atomic_weight parameter to the forward_common method, allowing the dipole model to incorporate atomic weights in its predictions.


192-192: Correctly passing atomic_weight to the lower-level implementation.

The parameter is properly forwarded to the forward_common_lower method.


247-247: Function signature enhancement: New atomic weight parameter added to lower interface.

The atomic_weight parameter is appropriately added to the forward_common_lower method.

Comment on lines 299 to 306
# add weight to atomic_output
kw = next(iter(self.atomic_output_def().var_defs.keys()))
atomic_weight = torch.ones_like(atomic_ret[kw])
if atomic_weight is not None:
    # atomic_weight: nf x nloc x dim
    atomic_ret[kw] = atomic_ret[kw] * atomic_weight.reshape(
        *atomic_ret[kw].shape[:-1], -1
    )
Contributor


⚠️ Potential issue

Fix parameter shadowing issue in the atomic weight implementation.

There's a critical bug in the implementation. The function parameter atomic_weight is being shadowed by a local variable with the same name on line 301. This will prevent the externally provided weights from being used.

Apply this fix to correct the parameter shadowing:

# add weight to atomic_output
kw = next(iter(self.atomic_output_def().var_defs.keys()))
-atomic_weight = torch.ones_like(atomic_ret[kw])
+weight_tensor = torch.ones_like(atomic_ret[kw])
if atomic_weight is not None:
    # atomic_weight: nf x nloc x dim
-   atomic_ret[kw] = atomic_ret[kw] * atomic_weight.reshape(
+   atomic_ret[kw] = atomic_ret[kw] * atomic_weight.reshape(
        *atomic_ret[kw].shape[:-1], -1
    )
+else:
+    atomic_ret[kw] = atomic_ret[kw] * weight_tensor

Or, for a cleaner implementation:

# add weight to atomic_output
kw = next(iter(self.atomic_output_def().var_defs.keys()))
-atomic_weight = torch.ones_like(atomic_ret[kw])
if atomic_weight is not None:
    # atomic_weight: nf x nloc x dim
    atomic_ret[kw] = atomic_ret[kw] * atomic_weight.reshape(
        *atomic_ret[kw].shape[:-1], -1
    )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# add weight to atomic_output
kw = next(iter(self.atomic_output_def().var_defs.keys()))
atomic_weight = torch.ones_like(atomic_ret[kw])
if atomic_weight is not None:
    # atomic_weight: nf x nloc x dim
    atomic_ret[kw] = atomic_ret[kw] * atomic_weight.reshape(
        *atomic_ret[kw].shape[:-1], -1
    )
# add weight to atomic_output
kw = next(iter(self.atomic_output_def().var_defs.keys()))
weight_tensor = torch.ones_like(atomic_ret[kw])
if atomic_weight is not None:
    # atomic_weight: nf x nloc x dim
    atomic_ret[kw] = atomic_ret[kw] * atomic_weight.reshape(
        *atomic_ret[kw].shape[:-1], -1
    )
else:
    atomic_ret[kw] = atomic_ret[kw] * weight_tensor

@ChiahsinChu ChiahsinChu force-pushed the devel-dipole_with_atomic_weight branch from 64fbde6 to 50a8a21 on March 3, 2025, 11:58
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🔭 Outside diff range comments (1)
deepmd/pt/model/model/make_model.py (1)

617-634: ⚠️ Potential issue

Update forward method to pass the atomic_weight parameter.

The forward method needs to be updated to accept and pass the atomic_weight parameter to forward_common, otherwise this functionality won't be available when using the base forward method.

def forward(
    self,
    coord,
    atype,
    box: Optional[torch.Tensor] = None,
    fparam: Optional[torch.Tensor] = None,
    aparam: Optional[torch.Tensor] = None,
    do_atomic_virial: bool = False,
+   atomic_weight: Optional[torch.Tensor] = None,
) -> dict[str, torch.Tensor]:
    # directly call the forward_common method when no specific transform rule
    return self.forward_common(
        coord,
        atype,
        box,
        fparam=fparam,
        aparam=aparam,
        do_atomic_virial=do_atomic_virial,
+       atomic_weight=atomic_weight,
    )
♻️ Duplicate comments (1)
deepmd/pt/model/model/make_model.py (1)

300-307: ⚠️ Potential issue

Fix parameter shadowing issue as identified in previous review.

The current implementation introduces a potential parameter shadowing issue if the atomic model fitting net's var_name is the same as "atomic_weight". This could lead to unexpected behavior, similar to the issue identified in a previous review comment.

It's best to use a different variable name for clarity:

if hasattr(self.atomic_model, "fitting_net"):
    if hasattr(self.atomic_model.fitting_net, "var_name"):
        kw = self.atomic_model.fitting_net.var_name
        if atomic_weight is not None:
            # atomic_weight: nf x nloc x dim
-           atomic_ret[kw] = atomic_ret[kw] * atomic_weight.view(
+           atomic_ret[kw] = atomic_ret[kw] * atomic_weight.view(
                [atomic_ret[kw].shape[0], atomic_ret[kw].shape[1], -1]
            )

The variable name doesn't need to be changed in this case as it doesn't actually shadow the parameter (unlike the previous issue), but the concern about having the same name is valid.

🧰 Tools
🪛 Ruff (0.8.2)

300-301: Use a single if statement instead of nested if statements

(SIM102)

🧹 Nitpick comments (2)
deepmd/pt/model/model/make_model.py (2)

138-138: Add the new parameter to docstring.

The parameter atomic_weight is added to the method signature but not documented in the method docstring. For better code maintenance and usability, update the docstring to include a description of this parameter.

@@ -152,6 +152,8 @@
                frame parameter. nf x ndf
            aparam
                atomic parameter. nf x nloc x nda
+           atomic_weight
+               Optional weight for atomic output. nf x nloc x dim
            do_atomic_virial
                If calculate the atomic virial.

247-247: Add the new parameter to docstring.

The parameter atomic_weight is added to the method signature but not documented in the method docstring. Update the docstring to include a description of this parameter.

@@ -271,6 +271,8 @@
                The data needed for communication for parallel inference.
            extra_nlist_sort
                whether to forcibly sort the nlist.
+           atomic_weight
+               Optional weight for atomic output. nf x nloc x dim
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 64fbde6 and 50a8a21.

📒 Files selected for processing (1)
  • deepmd/pt/model/model/make_model.py (4 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
deepmd/pt/model/model/make_model.py

300-301: Use a single if statement instead of nested if statements

(SIM102)

⏰ Context from checks skipped due to timeout of 90000ms (14)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Analyze (python)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Analyze (javascript-typescript)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
  • GitHub Check: Test C++ (false)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Test Python (1, 3.9)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Test C++ (true)
🔇 Additional comments (1)
deepmd/pt/model/model/make_model.py (1)

192-192: Appropriate parameter passing.

The atomic_weight parameter is correctly passed down to the lower-level implementation, ensuring consistent behavior across the API layers.

@codecov

codecov bot commented Mar 3, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 84.78%. Comparing base (5aa7a8f) to head (3cbaa61).

Additional details and impacted files
@@           Coverage Diff           @@
##            devel    #4628   +/-   ##
=======================================
  Coverage   84.77%   84.78%           
=======================================
  Files         688      688           
  Lines       66097    66102    +5     
  Branches     3539     3539           
=======================================
+ Hits        56036    56042    +6     
  Misses       8919     8919           
+ Partials     1142     1141    -1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

Collaborator

@wanghan-iapcm wanghan-iapcm left a comment


I do not get what you mean by "If the charges associated with the Wannier centroids are not identical, there is no way for users to get the derivative of (real) total dipole w.r.t. coord from DW models. (I tried to calculate the gradient of output dipole w.r.t. the input coord, but the gradient chain seems to be cut in the model.)"

Does the type exclusion mechanism solve your issue? You may check

atom_exclude_types: list[int] = [],

@wanghan-iapcm wanghan-iapcm requested a review from iProzd March 3, 2025 14:41
@ChiahsinChu
Contributor Author

ChiahsinChu commented Mar 3, 2025

Does the type exclusion mechanism solve your issue?

No, that is not what I mean. atom_exclude_types, as far as I know, is the equivalent of sel_type in the TF model.

An example of my case:
I want to train a DW model for a LiOH aqueous solution, in which the charges associated with the WCs of O and Li should be -8 and -2, respectively. (No WC for H.) In this case, model_pred["force"], which is the negative derivative of the sum of atomic dipoles w.r.t. the atomic coordinates, does not correspond to the derivative of the total dipole moment w.r.t. coord.
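In symbols (with $q_i$ the WC charge of atom $i$ and $\boldsymbol{\mu}_i$ its predicted atomic dipole; the notation is added here only for clarity), the quantity of interest is

$$
\frac{\partial}{\partial \mathbf{R}} \sum_i q_i\,\boldsymbol{\mu}_i(\mathbf{R}) \;=\; \sum_i q_i\,\frac{\partial \boldsymbol{\mu}_i}{\partial \mathbf{R}},
$$

which coincides with the derivative of the plain sum $\sum_i \boldsymbol{\mu}_i$ (what model_pred["force"] differentiates) only when all $q_i$ are equal.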

@ChiahsinChu ChiahsinChu force-pushed the devel-dipole_with_atomic_weight branch from d48e45f to a47b1e0 on March 4, 2025, 01:23
Comment on lines 300 to 302
if hasattr(self.atomic_model, "fitting_net"):
    if hasattr(self.atomic_model.fitting_net, "var_name"):
        kw = self.atomic_model.fitting_net.var_name
Member


I feel that self.atomic_model.fitting_output_def could be used here

Member


It's a bit confusing since it breaks the consistency of APIs between different models and backends.

Contributor Author


Would it be better to add this keyword for all models and backends?

Member


Not sure what is the best way

Collaborator


I would suggest incorporating the atomic weight into the base atomic model, following a similar approach to how atom masks are managed. This would ensure that atomic weight is consistently supported across all atomic models.
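For example, a rough sketch of what that could look like (the helper name, key handling, and shapes are assumptions here, not the final implementation):

from typing import Optional

import torch

def apply_atomic_weight(
    ret_dict: dict[str, torch.Tensor],
    atomic_weight: Optional[torch.Tensor],  # nf x nloc or nf x nloc x 1
) -> dict[str, torch.Tensor]:
    # Hypothetical helper for the base atomic model: scale every atomic output
    # by a per-atom weight, analogous to how the atom mask is broadcast onto
    # the outputs; the "mask" entry itself is left untouched.
    if atomic_weight is None:
        return ret_dict
    for kk, vv in ret_dict.items():
        if kk == "mask":
            continue
        nf, nloc = vv.shape[0], vv.shape[1]
        ret_dict[kk] = vv * atomic_weight.reshape(nf, nloc, -1)
    return ret_dict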

Contributor

@coderabbitai coderabbitai bot left a comment


Caution

Inline review comments failed to post. This is likely due to GitHub's limits when posting large numbers of comments. If you are seeing this consistently it is likely a permissions issue. Please check "Moderation" -> "Code review limits" under your organization settings.

Actionable comments posted: 5

🔭 Outside diff range comments (1)
deepmd/pd/model/descriptor/se_a.py (1)

745-761: ⚠️ Potential issue

Inconsistent code structure detected

There appears to be redundant or conflicting code blocks in this section. The same conditional logic appears twice with different indentation (lines 748-755 and 754-761), which could lead to unpredictable behavior.

This appears to be a merge conflict or code duplication that wasn't properly resolved. Please consolidate these blocks into a single coherent code path:

-                    "Compressed environment is not implemented yet."
-                )
-            else:
-                # NOTE: control flow with double backward is not supported well yet by paddle.jit
-                if not paddle.framework.in_dynamic_mode() or decomp.numel(rr) > 0:
-                    rr = rr * mm.unsqueeze(2).astype(rr.dtype)
-                    ss = rr[:, :, :1]
-                    if self.compress:
-            else:
-                # NOTE: control flow with double backward is not supported well yet by paddle.jit
-                if not paddle.in_dynamic_mode() or decomp.numel(rr) > 0:
-                    rr = rr * mm.unsqueeze(2).astype(rr.dtype)
-                    ss = rr[:, :, :1]
-                    if self.compress:
-                        raise NotImplementedError(
-                            "Compressed environment is not implemented yet."
-                        )
-                    else:
+                    "Compressed environment is not implemented yet."
+                )
+            else:
+                # NOTE: control flow with double backward is not supported well yet by paddle.jit
+                if not paddle.in_dynamic_mode() or decomp.numel(rr) > 0:
+                    rr = rr * mm.unsqueeze(2).astype(rr.dtype)
+                    ss = rr[:, :, :1]
+                    if self.compress:
+                        raise NotImplementedError(
+                            "Compressed environment is not implemented yet."
+                        )
+                    else:
🧹 Nitpick comments (22)
deepmd/utils/data.py (1)

92-96: Added validation to ensure comprehensive type map coverage

This addition improves input validation by checking if all elements in self.type_map are present in the provided type_map. It raises a descriptive error that helps users identify which elements are missing from their type map, preventing silent errors or unexpected behavior later in the execution pipeline.

Consider adding a unit test to verify this validation works as expected with various input combinations.
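A minimal sketch of such a test, written against a standalone replica of the described check rather than the actual DeepmdData class (all names below are hypothetical):

import pytest

def check_type_map_coverage(data_type_map: list[str], provided_type_map: list[str]) -> None:
    # standalone replica of the described validation: every element of the
    # dataset's own type_map must be present in the provided type_map
    missing = [t for t in data_type_map if t not in provided_type_map]
    if missing:
        raise ValueError(f"Elements {missing} are missing from the provided type_map")

def test_type_map_coverage() -> None:
    # fully covered: no error raised
    check_type_map_coverage(["O", "H"], ["O", "H", "Li"])
    # a missing element raises a descriptive error
    with pytest.raises(ValueError, match="Li"):
        check_type_map_coverage(["O", "H", "Li"], ["O", "H"])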

source/tests/array_api_strict/descriptor/repflows.py (2)

22-38: Consider handling None values before serialization/deserialization.

In this logic, some attributes (e.g., layers, edge_embd, angle_embd) are immediately serialized/deserialized. If value is unexpectedly None, it could lead to runtime errors. A small type check or safety guard can improve robustness.

 def __setattr__(self, name: str, value: Any) -> None:
     if name in {"mean", "stddev"}:
         value = to_array_api_strict_array(value)
+    elif name in {"layers"} and value is not None:
+        value = [RepFlowLayer.deserialize(layer.serialize()) for layer in value]
+    elif name in {"edge_embd", "angle_embd"} and value is not None:
+        value = NativeLayer.deserialize(value.serialize())
     elif name in {"env_mat_edge", "env_mat_angle"}:
         pass
     elif name == "emask":
         value = PairExcludeMask(value.ntypes, value.exclude_types)
     return super().__setattr__(name, value)

41-61: Check for None before calling value.serialize().

In the current implementation, if any of the named attributes are set to None, calling value.serialize() will fail. Consider adding a guard to avoid potential errors.

 def __setattr__(self, name: str, value: Any) -> None:
     if name in {
         "node_self_mlp",
         "node_sym_linear",
         "node_edge_linear",
         "edge_self_linear",
         "a_compress_n_linear",
         "a_compress_e_linear",
         "edge_angle_linear1",
         "edge_angle_linear2",
         "angle_self_linear",
     }:
-        if value is not None:
+        if value is not None and hasattr(value, "serialize"):
             value = NativeLayer.deserialize(value.serialize())
     elif name in {"n_residual", "e_residual", "a_residual"}:
         value = [to_array_api_strict_array(vv) for vv in value]
     return super().__setattr__(name, value)
deepmd/jax/descriptor/repflows.py (2)

25-42: Add a guard for potential None values.

Similar to the array_api_strict variant, deserializing attributes without checking for None can cause errors. Adding a simple check can increase code resilience.

 def __setattr__(self, name: str, value: Any) -> None:
     if name in {"mean", "stddev"}:
         value = to_jax_array(value)
         if value is not None:
             value = ArrayAPIVariable(value)
+    elif name in {"layers"} and value is not None:
+        value = [RepFlowLayer.deserialize(layer.serialize()) for layer in value]
+    elif name in {"edge_embd", "angle_embd"} and value is not None:
+        value = NativeLayer.deserialize(value.serialize())
     elif name in {"env_mat_edge", "env_mat_angle"}:
         pass
     elif name == "emask":
         value = PairExcludeMask(value.ntypes, value.exclude_types)
     return super().__setattr__(name, value)

46-66: Check for None or incorrect types before value.serialize().

When deserializing network layers, confirm that they are valid, non-None objects with a serialize() method. This helps avoid unhandled runtime exceptions.

 if name in {
     "node_self_mlp",
     "node_sym_linear",
     "node_edge_linear",
     "edge_self_linear",
     "a_compress_n_linear",
     "a_compress_e_linear",
     "edge_angle_linear1",
     "edge_angle_linear2",
     "angle_self_linear",
 }:
-    if value is not None:
+    if value is not None and hasattr(value, "serialize"):
         value = NativeLayer.deserialize(value.serialize())
 elif name in {"n_residual", "e_residual", "a_residual"}:
     value = [ArrayAPIVariable(to_jax_array(vv)) for vv in value]
 return super().__setattr__(name, value)
source/tests/pt/model/test_dpa3.py (1)

209-209: Remove or utilize the unused variable model.

Local variable model is assigned but never used. This can be safely removed unless intended for later use.

- model = torch.jit.script(dd0)
🧰 Tools
🪛 Ruff (0.8.2)

209-209: Local variable model is assigned to but never used

Remove assignment to unused variable model

(F841)

deepmd/dpmodel/descriptor/dpa3.py (3)

212-253: Constructor largely follows good practices.

The overall constructor is consistent, properly leveraging helper functions for initialization. However, please review the default argument for exclude_types.

🧰 Tools
🪛 Ruff (0.8.2)

253-253: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


253-253: Replace mutable default argument.

Defaulting exclude_types to a mutable list can introduce unexpected behaviors if the list is modified in-place. Consider using None instead:

-    exclude_types: list[tuple[int, int]] = [],
+    exclude_types: Optional[list[tuple[int, int]]] = None,
     ...
+    if exclude_types is None:
+        exclude_types = []
🧰 Tools
🪛 Ruff (0.8.2)

253-253: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


574-574: Unused local variable.

The variable env_mat is assigned but never used. Removing it clarifies the code and avoids confusion:

-        env_mat = repflow_variable.pop("env_mat")
🧰 Tools
🪛 Ruff (0.8.2)

574-574: Local variable env_mat is assigned to but never used

Remove assignment to unused variable env_mat

(F841)

source/tests/universal/dpmodel/descriptor/test_descriptor.py (1)

467-527: New DescriptorParamDPA3 function appears correct.

However, please address the mutable default argument for exclude_types at line 474 to avoid potential unintended side effects.

-    exclude_types=[],
+    exclude_types: Optional[list] = None,
     ...
+    if exclude_types is None:
+        exclude_types = []
🧰 Tools
🪛 Ruff (0.8.2)

474-474: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)

deepmd/dpmodel/descriptor/repflows.py (3)

440-440: Rename unused loop variable idx to _idx for clarity

The variable idx isn’t referenced in the loop body. Renaming it to _idx clarifies that it’s intentionally unused.

-for idx, ll in enumerate(self.layers):
+for _idx, ll in enumerate(self.layers):
    # node_ebd: nb x nloc x n_dim
    # node_ebd_ext: nb x nall x n_dim
🧰 Tools
🪛 Ruff (0.8.2)

440-440: Loop control variable idx not used within loop body

Rename unused idx to _idx

(B007)


937-937: Remove unused variable nall

The variable nall is assigned but never used, creating dead code. Removing it tidies up the logic.

-        nall = node_ebd_ext.shape[1]
         node_ebd = node_ebd_ext[:, :nloc, :]
         assert (nb, nloc) == node_ebd.shape[:2]
🧰 Tools
🪛 Ruff (0.8.2)

937-937: Local variable nall is assigned to but never used

Remove assignment to unused variable nall

(F841)


1199-1199: Remove unused variable nitem

The nitem variable is never used. Eliminating it prevents confusion and aligns with clean coding practices.

-        nitem = len(update_list)
         uu = update_list[0]
         if update_name == "node":
             for ii, vv in enumerate(self.n_residual):
🧰 Tools
🪛 Ruff (0.8.2)

1199-1199: Local variable nitem is assigned to but never used

Remove assignment to unused variable nitem

(F841)

deepmd/pt/model/descriptor/repflows.py (2)

461-461: Rename unused variable to underscore.

The loop index idx is never referenced. Use _ to indicate it’s unused:

-        for idx, ll in enumerate(self.layers):
+        for _, ll in enumerate(self.layers):
🧰 Tools
🪛 Ruff (0.8.2)

461-461: Loop control variable idx not used within loop body

Rename unused idx to _idx

(B007)


571-575: Use a ternary operator for concise code.

You can simplify:

-            if callable(merged):
-                sampled = merged()
-            else:
-                sampled = merged
+            sampled = merged() if callable(merged) else merged
🧰 Tools
🪛 Ruff (0.8.2)

571-575: Use ternary operator sampled = merged() if callable(merged) else merged instead of if-else-block

(SIM108)

deepmd/pt/model/descriptor/repflow_layer.py (3)

306-307: Remove unused local variable
Line 307 declares e_dim but it is never referenced afterward. Consider removing it to streamline the code and avoid potential confusion.

     nb, nloc, nnei, _ = edge_ebd.shape
-    e_dim = edge_ebd.shape[-1]
🧰 Tools
🪛 Ruff (0.8.2)

307-307: Local variable e_dim is assigned to but never used

Remove assignment to unused variable e_dim

(F841)


534-534: Remove unused local variable
Line 534 introduces nall but it is not used anywhere in the code. Cleaning it up will improve clarity.

-    nall = node_ebd_ext.shape[1]
🧰 Tools
🪛 Ruff (0.8.2)

534-534: Local variable nall is assigned to but never used

Remove assignment to unused variable nall

(F841)


788-788: Remove unused local variable
In list_update_res_residual, the variable nitem is assigned but never utilized. Consider deleting it.

     def list_update_res_residual(
         self, update_list: list[torch.Tensor], update_name: str = "node"
     ) -> torch.Tensor:
-        nitem = len(update_list)
         uu = update_list[0]
🧰 Tools
🪛 Ruff (0.8.2)

788-788: Local variable nitem is assigned to but never used

Remove assignment to unused variable nitem

(F841)

source/tests/consistent/descriptor/test_dpa3.py (1)

39-44: Simplify the conditional logic
The current if/else statements always assign DescrptDPA3PD = None. Using a ternary expression or removing the condition entirely can make the code simpler.

-if INSTALLED_PD:
-    # not supported yet
-    DescrptDPA3PD = None
-else:
-    DescrptDPA3PD = None
+DescrptDPA3PD = None
🧰 Tools
🪛 Ruff (0.8.2)

39-43: Use ternary operator DescrptDPA3PD = None if INSTALLED_PD else None instead of if-else-block

(SIM108)

deepmd/pt/model/descriptor/dpa3.py (3)

339-339: Rename or remove unused loop variable
The loop variable ii is not utilized within the loop body. This can be misleading. Rename it to _ to clarify that it is unused, or remove the enumeration if not needed.

-        for ii, descrpt in enumerate(descrpt_list):
+        for _, descrpt in enumerate(descrpt_list):
             descrpt.compute_input_stats(merged, path)
🧰 Tools
🪛 Ruff (0.8.2)

339-339: Loop control variable ii not used within loop body

Rename unused ii to _ii

(B007)


414-415: Remove assignment to unused variable
env_mat is assigned here but never referenced. Remove it to reduce unnecessary code.

 env_mat = repflow_variable.pop("env_mat")
-env_mat
🧰 Tools
🪛 Ruff (0.8.2)

415-415: Local variable env_mat is assigned to but never used

Remove assignment to unused variable env_mat

(F841)


471-471: Remove unused variable
nall is assigned on line 471 without subsequent usage. Deleting it will keep the code concise.

-        nall = extended_coord.view(nframes, -1).shape[1] // 3
🧰 Tools
🪛 Ruff (0.8.2)

471-471: Local variable nall is assigned to but never used

Remove assignment to unused variable nall

(F841)

🛑 Comments failed to post (5)
deepmd/dpmodel/descriptor/repflows.py (2)

156-156: 🛠️ Refactor suggestion

Avoid using a mutable default value for exclude_types

Using a list as a default argument may lead to unexpected side effects because it is shared across all calls. Consider using None and initializing inside the constructor.

-        exclude_types: list[tuple[int, int]] = [],
+        exclude_types: Optional[list[tuple[int, int]]] = None,

Then, after the constructor starts:

+        if exclude_types is None:
+            exclude_types = []
         self.e_rcut = float(e_rcut)
         self.e_rcut_smth = float(e_rcut_smth)
         ...
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

        exclude_types: Optional[list[tuple[int, int]]] = None,
        other_param: int = 0,  # other parameters...

        # Inside the constructor after the parameters are received:
        if exclude_types is None:
            exclude_types = []
        self.e_rcut = float(e_rcut)
        self.e_rcut_smth = float(e_rcut_smth)
        # ... rest of the constructor code
🧰 Tools
🪛 Ruff (0.8.2)

156-156: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


353-353: 🛠️ Refactor suggestion

Avoid using a mutable default value for exclude_types in reinit_exclude

Same reasoning applies: a default list can cause unexpected behavior. Use None and initialize inside the method.

-def reinit_exclude(
-    self,
-    exclude_types: list[tuple[int, int]] = [],
-) -> None:
+def reinit_exclude(
+    self,
+    exclude_types: Optional[list[tuple[int, int]]] = None,
+) -> None:
+    if exclude_types is None:
+        exclude_types = []

     self.exclude_types = exclude_types
     self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

def reinit_exclude(
    self,
    exclude_types: Optional[list[tuple[int, int]]] = None,
) -> None:
    if exclude_types is None:
        exclude_types = []
    self.exclude_types = exclude_types
    self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types)
🧰 Tools
🪛 Ruff (0.8.2)

353-353: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)

deepmd/pt/model/descriptor/repflows.py (2)

360-360: 🛠️ Refactor suggestion

Avoid mutable default argument in reinit_exclude.

Applying the same reasoning here:

-    def reinit_exclude(
-        self,
-        exclude_types: list[tuple[int, int]] = [],
-    ) -> None:
+    def reinit_exclude(
+        self,
+        exclude_types: Optional[list[tuple[int, int]]] = None,
+    ) -> None:
+        if exclude_types is None:
+            exclude_types = []
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

    def reinit_exclude(
        self,
        exclude_types: Optional[list[tuple[int, int]]] = None,
    ) -> None:
        if exclude_types is None:
            exclude_types = []
🧰 Tools
🪛 Ruff (0.8.2)

360-360: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


178-178: 🛠️ Refactor suggestion

Use None instead of a mutable default argument.

Using a mutable list ([]) as a default argument for exclude_types can lead to unexpected shared state across instances. Consider this refactor:

-        exclude_types: list[tuple[int, int]] = [],
+        exclude_types: Optional[list[tuple[int, int]]] = None,
         ...
     ):
+        if exclude_types is None:
+            exclude_types = []
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

        exclude_types: Optional[list[tuple[int, int]]] = None,
         ...
     ):
        if exclude_types is None:
            exclude_types = []
🧰 Tools
🪛 Ruff (0.8.2)

178-178: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)

deepmd/pt/model/descriptor/dpa3.py (1)

105-105: 🛠️ Refactor suggestion

Avoid mutable default argument
Defining exclude_types: list[tuple[int, int]] = [] can lead to unintended shared state across calls. Consider using None as a default and initializing inside the constructor.

-    exclude_types: list[tuple[int, int]] = [],
+    exclude_types: Optional[list[tuple[int, int]]] = None,
...
def __init__(
    self,
    ...
    exclude_types: Optional[list[tuple[int, int]]] = None,
    ...
):
    if exclude_types is None:
        exclude_types = []
    self.exclude_types = exclude_types
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

from typing import Optional  # Add this if not already imported

class YourClassName:
    def __init__(
        self,
        ...,
-       exclude_types: list[tuple[int, int]] = [],
+       exclude_types: Optional[list[tuple[int, int]]] = None,
        ...
    ):
-       self.exclude_types = exclude_types
+       if exclude_types is None:
+           exclude_types = []
+       self.exclude_types = exclude_types
🧰 Tools
🪛 Ruff (0.8.2)

105-105: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)

@ChiahsinChu ChiahsinChu force-pushed the devel-dipole_with_atomic_weight branch from ab0f1f5 to d14fa86 on March 11, 2025, 15:22
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
deepmd/dpmodel/atomic_model/base_atomic_model.py (1)

152-224: Missing documentation for the atomic_weight parameter.

The docstring for the forward_common_atomic method hasn't been updated to include information about the new atomic_weight parameter and its usage.

Add documentation for the atomic_weight parameter in the method's docstring:

    def forward_common_atomic(
        self,
        extended_coord: np.ndarray,
        extended_atype: np.ndarray,
        nlist: np.ndarray,
        mapping: Optional[np.ndarray] = None,
        fparam: Optional[np.ndarray] = None,
        aparam: Optional[np.ndarray] = None,
        atomic_weight: Optional[np.ndarray] = None,
    ) -> dict[str, np.ndarray]:
        """Common interface for atomic inference.

        This method accept extended coordinates, extended atom typs, neighbor list,
        and predict the atomic contribution of the fit property.

        Parameters
        ----------
        extended_coord
            extended coordinates, shape: nf x (nall x 3)
        extended_atype
            extended atom typs, shape: nf x nall
            for a type < 0 indicating the atomic is virtual.
        nlist
            neighbor list, shape: nf x nloc x nsel
        mapping
            extended to local index mapping, shape: nf x nall
        fparam
            frame parameters, shape: nf x dim_fparam
        aparam
            atomic parameter, shape: nf x nloc x dim_aparam
+       atomic_weight
+           atomic weights for scaling outputs, shape: nf x nloc
+           if provided, all output values will be multiplied by this weight

        Returns
        -------
        ret_dict
            dict of output atomic properties.
            should implement the definition of `fitting_output_def`.
            ret_dict["mask"] of shape nf x nloc will be provided.
            ret_dict["mask"][ff,ii] == 1 indicating the ii-th atom of the ff-th frame is real.
            ret_dict["mask"][ff,ii] == 0 indicating the ii-th atom of the ff-th frame is virtual.

        """
🧰 Tools
🪛 Ruff (0.8.2)

212-212: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ab0f1f5 and d14fa86.

📒 Files selected for processing (4)
  • deepmd/jax/model/base_model.py (1 hunks)
  • deepmd/dpmodel/atomic_model/base_atomic_model.py (5 hunks)
  • deepmd/dpmodel/model/make_model.py (12 hunks)
  • deepmd/jax/model/base_model.py (8 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/jax/model/base_model.py
⏰ Context from checks skipped due to timeout of 90000ms (5)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Analyze (python)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Analyze (c-cpp)
🔇 Additional comments (20)
deepmd/jax/model/base_model.py (5)

30-30: Added atomic_weight parameter for enhanced dipole model.

The addition of the optional atomic_weight parameter aligns with the PR's goal to improve dipole model calculations by allowing atomic weight specification.


39-39: Properly passed atomic_weight to the atomic model.

The parameter is correctly forwarded to the atomic model's forward_common_atomic method.


61-62: Updated eval_output signature and handling of atomic_weight.

The function now properly accepts and processes the atomic_weight parameter, with appropriate None-checking when passing it to the atomic model.

Also applies to: 73-75


111-112: Updated eval_ce signature and handling of atomic_weight.

Similar to eval_output, this function has been properly updated to handle the atomic_weight parameter with appropriate None-checking.

Also applies to: 124-126


88-89: Correctly propagated atomic_weight to derivative calculations.

The parameter is correctly passed to the JAX autodiff functions, ensuring that derivatives properly account for atomic weights.

Also applies to: 144-145

deepmd/dpmodel/model/make_model.py (10)

66-66: Added atomic_weight parameter to model_call_from_call_lower function.

The optional atomic_weight parameter has been added to support weighted dipole calculations.


125-126: Properly propagated atomic_weight to lower-level functions.

The parameter is correctly passed to call_lower, ensuring proper propagation through the call chain.


229-229: Added atomic_weight parameter to the call method.

This addition maintains consistency with the other interface methods and supports the enhanced dipole model functionality.


256-262: Updated input_type_cast to handle atomic_weight parameter.

The function now properly unpacks and includes atomic_weight in its parameter list and return values.


290-290: Added atomic_weight parameter to call_lower method.

This change ensures that the atomic_weight parameter is supported throughout the model's interface.


328-333: Updated input_type_cast handling in call_lower method.

The parameter is now correctly unpacked from the input_type_cast return value.


357-357: Added atomic_weight parameter to forward_common_atomic method.

This change provides consistent support for atomic_weight throughout the model interface.


383-389: Updated input_type_cast function signature and return type.

The function's signature and return type have been properly updated to include the atomic_weight parameter.


399-403: Added atomic_weight handling in input_type_cast conversion logic.

The parameter is correctly included in the type conversion list and unpacked appropriately.


411-411: Added type casting for atomic_weight in precision conversion.

The atomic_weight parameter is properly type-cast when converting precision, maintaining consistency with other parameters.

deepmd/dpmodel/atomic_model/base_atomic_model.py (5)

201-202: Properly passed atomic_weight to forward_atomic method.

The parameter is correctly forwarded to the forward_atomic method.


218-223: Implemented atomic_weight scaling for output calculation.

This implementation correctly scales the atomic output by the provided weights when atomic_weight is not None. The reshaping ensures proper broadcasting regardless of output shape.


235-235: Added atomic_weight parameter to call method.

The parameter has been added to maintain consistency across the API.


244-245: Propagated atomic_weight to forward_common_atomic in call method.

The parameter is correctly passed through the call method to forward_common_atomic.


152-152:

❓ Verification inconclusive

Added atomic_weight parameter to forward_common_atomic method.

The optional atomic_weight parameter has been added to support weighted dipole calculations.


🌐 Web query:

What are atomic weights in molecular dipole modeling?

💡 Result:

Atomic weights, which represent the average mass of an element's isotopes, are not directly used in molecular dipole modeling. Instead, dipole modeling focuses on atomic charges, polarizabilities, and molecular geometry to predict how electron density shifts within a molecule. Here’s how atomic-level properties contribute to dipole moment calculations:

Key Factors in Dipole Modeling:

  1. Atomic Partial Charges:
    Dipole moments arise from uneven charge distribution. Models like MuML (Machine learning for μ) decompose molecular dipoles into atomic contributions, combining local polarization (atomic dipoles) and long-range charge transfer (partial charges)[1][2][15]. These charges are derived from quantum mechanical calculations or machine learning predictions, not atomic weights[10][12].

  2. Atomic Polarizabilities:
    Polarizability—how easily an atom’s electron cloud distorts under an electric field—determines induced dipoles. Thole-style models assign polarizabilities to atoms based on their chemical environment, enabling predictions of molecular polarization[7][16]. For example, oxygen’s lone pairs significantly influence polarizability in molecules like water[7][13].

  3. Molecular Geometry:
    The spatial arrangement of atoms dictates whether bond dipoles cancel or reinforce each other. Linear molecules (e.g., CO₂) may have zero net dipole, while bent shapes (e.g., H₂O) result in measurable dipoles[6][9][13].

Role of Machine Learning:

  • Models like MuML use Gaussian process regression and symmetry-adapted kernels to predict atomic dipoles and charges, achieving accuracy comparable to quantum chemistry methods at lower computational cost[1][2][12].
  • Recent approaches train neural networks on Wannier function centers (representing electron distributions) to predict dipole moments in liquids like methanol, capturing polarization effects from intermolecular interactions[12].

Clarifying Terminology:

  • Atomic Weight vs. Atomic Charge: Atomic weight (mass) is unrelated to charge distribution. Confusion may arise from terms like "weighting" in model parameters, which refers to the influence of atomic contributions, not isotopic mass[3][10].
  • Atomic Polarizability vs. Mass: Polarizability depends on electron cloud flexibility, not nuclear mass. For example, larger atoms like bromine exhibit higher polarizability due to diffuse electron shells[7][16].

In summary, molecular dipole models prioritize electronic properties (charges, polarizabilities) and geometry, while atomic weights play no direct role. Machine learning bridges quantum accuracy and scalability, enabling predictions for complex systems like biomolecules[1][12].

Citations:


Clarification on the atomic_weight Parameter Usage

While the new atomic_weight parameter has been added to support weighted dipole calculations in forward_common_atomic, note that in typical molecular dipole modeling atomic weights (which represent isotopic masses) are not directly used. Dipole calculations usually focus on atomic partial charges, polarizabilities, and molecular geometry. Please confirm that this parameter is intended as a custom weighting factor and, if so, consider updating its naming or documentation to clearly distinguish its purpose from the conventional concept of atomic weight.

  • File: deepmd/dpmodel/atomic_model/base_atomic_model.py (line 152)
  • Action: Verify the intended role of atomic_weight and adjust its name/documentation as needed to avoid confusion.

@ChiahsinChu ChiahsinChu force-pushed the devel-dipole_with_atomic_weight branch from d14fa86 to ad063bd on March 11, 2025, 15:53
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (2)
deepmd/pt/model/atomic_model/base_atomic_model.py (2)

275-278: Consider adding shape validation for atomic_weight

The current implementation assumes that atomic_weight can be reshaped to [out_shape[0], out_shape[1], -1]. If the provided atomic_weight has an incompatible shape, this could lead to runtime errors.

Consider adding validation to check that the atomic_weight tensor has a compatible shape before reshaping, or at least document the expected shape more explicitly in the docstring.

            if atomic_weight is not None:
+               expected_batch_size = out_shape[0]
+               expected_atoms = out_shape[1]
+               if atomic_weight.shape[0] != expected_batch_size or atomic_weight.shape[1] != expected_atoms:
+                   raise ValueError(f"atomic_weight shape {atomic_weight.shape} incompatible with output shape {out_shape}")
                ret_dict[kk] = ret_dict[kk] * atomic_weight.view(
                    [out_shape[0], out_shape[1], -1]
                )

279-284: Remove duplicated line

The line ret_dict["mask"] = atom_mask appears twice (lines 279 and 282). This is redundant and one of them should be removed.

        ret_dict["mask"] = atom_mask

        return ret_dict
-       ret_dict["mask"] = atom_mask
-
-       return ret_dict
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d14fa86 and 0e3647f.

📒 Files selected for processing (9)
  • deepmd/jax/model/base_model.py (1 hunks)
  • deepmd/pt/model/atomic_model/base_atomic_model.py (4 hunks)
  • deepmd/dpmodel/atomic_model/base_atomic_model.py (4 hunks)
  • deepmd/dpmodel/model/make_model.py (12 hunks)
  • source/tests/pt/model/test_dp_atomic_model.py (2 hunks)
  • deepmd/dpmodel/atomic_model/base_atomic_model.py (1 hunks)
  • deepmd/pt/model/atomic_model/base_atomic_model.py (1 hunks)
  • deepmd/jax/atomic_model/dp_atomic_model.py (2 hunks)
  • deepmd/jax/model/base_model.py (8 hunks)
🚧 Files skipped from review as they are similar to previous changes (4)
  • deepmd/pt/model/atomic_model/base_atomic_model.py
  • deepmd/jax/model/base_model.py
  • deepmd/dpmodel/atomic_model/base_atomic_model.py
  • deepmd/dpmodel/model/make_model.py
⏰ Context from checks skipped due to timeout of 90000ms (17)
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Test C++ (false)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Test C++ (true)
🔇 Additional comments (6)
deepmd/jax/model/base_model.py (2)

50-50: Logic change from c_differentiable to r_differentiable

I noticed the condition changed from checking vdef.c_differentiable to vdef.r_differentiable. This seems to be an intentional logic change since there's an assertion at line 98 that assumes vdef.r_differentiable is true when vdef.c_differentiable is true.

Could you confirm this is the intended behavior and not a bug? This change affects when the derivative calculations are performed.


30-30: Good implementation of atomic_weight parameter

The atomic_weight parameter has been properly added to the function signature and correctly propagated to all relevant function calls within the implementation. This ensures consistent handling of atomic weights throughout the computation path.

Also applies to: 39-39, 61-62, 73-75, 88-88, 111-112, 124-126, 144-144

source/tests/pt/model/test_dp_atomic_model.py (2)

77-82: Well-designed test case for atomic_weight in self_consistency

The test implementation correctly verifies that applying atomic weights scales the energy output as expected. Multiplying the original energy by the reshaped atomic weights should match the energy when the weights are passed to the forward function.
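
To make the described check concrete, here is a self-contained toy version of the consistency it asserts: scaling the unweighted per-atom output should equal the output obtained when the weights are passed into the forward call. The forward_common_atomic defined here is a stand-in written for this sketch, not the model code exercised by the test.

import torch

# Toy stand-in for an atomic model: per-atom "energies" from coordinates, with
# optional per-atom weights applied as described in this PR.
def forward_common_atomic(coord, atomic_weight=None):
    energy = coord.square().sum(dim=-1, keepdim=True)        # nf x nloc x 1
    if atomic_weight is not None:
        energy = energy * atomic_weight.view(*energy.shape[:2], -1)
    return {"energy": energy}

nf, nloc = 1, 5
coord = torch.rand(nf, nloc, 3, dtype=torch.float64)
atomic_weight = torch.rand(nf, nloc, 1, dtype=torch.float64)

ret0 = forward_common_atomic(coord)
ret1 = forward_common_atomic(coord, atomic_weight=atomic_weight)

# Scaling the unweighted result reproduces the weighted forward pass.
torch.testing.assert_close(ret0["energy"] * atomic_weight, ret1["energy"])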


112-118: Good cross-model consistency test for atomic_weight

This test appropriately verifies that the atomic_weight parameter works consistently between different model implementations (DPModel and PyTorch). This is important to ensure consistent behavior across the different backends.

deepmd/pt/model/atomic_model/base_atomic_model.py (1)

206-206: Correct implementation of atomic_weight parameter

The atomic_weight parameter has been properly added to the function signatures with appropriate documentation. The implementation correctly applies the weights to scale the output values.

Also applies to: 230-232, 275-278, 292-292, 302-302

deepmd/jax/atomic_model/dp_atomic_model.py (1)

61-61: Correctly propagated atomic_weight parameter

The atomic_weight parameter has been properly added to the function signature and correctly passed to the superclass method.

Also applies to: 70-70

@ChiahsinChu ChiahsinChu force-pushed the devel-dipole_with_atomic_weight branch from 0e3647f to 088b252 Compare March 11, 2025 16:30
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🔭 Outside diff range comments (1)
deepmd/jax/utils/serialization.py (1)

78-91: ⚠️ Potential issue

Update exported signature to include atomic_weight

The exported function uses jax.ShapeDtypeStruct for various parameters including aparam, but there is no corresponding jax.ShapeDtypeStruct for the newly added atomic_weight parameter. This could lead to serialization issues.

Add a shape-dtype struct for atomic_weight similar to the one for aparam:

                jax.ShapeDtypeStruct((nf, nloc, model.get_dim_aparam()), jnp.float64)
                if model.get_dim_aparam()
                else None,  # aparam
+               jax.ShapeDtypeStruct((nf, nloc), jnp.float64)
+               if atomic_weight is not None
+               else None,  # atomic_weight
            )
🧹 Nitpick comments (4)
deepmd/jax/jax2tf/make_model.py (1)

50-72: Update documentation for atomic_weight parameter

The function docstring should be updated to include information about the new atomic_weight parameter.

    aparam
        atomic parameter. nf x nloc x nda
+   atomic_weight
+       atomic weights for scaling. nf x nloc
    do_atomic_virial
        If calculate the atomic virial.
deepmd/jax/jax2tf/tfmodel.py (1)

82-82: Update docstring to include the new parameter.

The atomic_weight parameter has been added to the method signature, but it's missing from the docstring documentation.

deepmd/jax/jax2tf/serialization.py (2)

85-86: Address the hardcoded dimension in a future update

The comment indicates that the "1" should be replaced with the fitting output dimension. Consider creating a ticket to track this requirement for a future update.


120-121: Ensure consistent handling of the dimension issue

The same comment about replacing "1" with fitting output dimension appears multiple times. Ensure this is consistently addressed across all occurrences when implementing the fix.

Also applies to: 203-204, 229-230

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0e3647f and 088b252.

📒 Files selected for processing (8)
  • deepmd/jax/model/base_model.py (1 hunks)
  • deepmd/jax/atomic_model/dp_atomic_model.py (2 hunks)
  • deepmd/jax/jax2tf/make_model.py (4 hunks)
  • deepmd/jax/jax2tf/serialization.py (11 hunks)
  • deepmd/jax/jax2tf/tfmodel.py (6 hunks)
  • deepmd/jax/model/base_model.py (8 hunks)
  • deepmd/jax/model/dp_model.py (2 hunks)
  • deepmd/jax/utils/serialization.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • deepmd/jax/model/base_model.py
  • deepmd/jax/atomic_model/dp_atomic_model.py
⏰ Context from checks skipped due to timeout of 90000ms (18)
  • GitHub Check: Analyze (python)
  • GitHub Check: Analyze (javascript-typescript)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Test C++ (false)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
  • GitHub Check: Test C++ (true)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
🔇 Additional comments (16)
deepmd/jax/model/base_model.py (4)

29-40: LGTM: atomic_weight parameter appropriately added

The atomic_weight parameter has been correctly added to the function signature and passed to the atomic model's forward method.


73-89: Update atomic_weight in eval_output function

The atomic_weight parameter has been correctly added to the function signature and passed to the atomic model's forward method.


124-144: LGTM: atomic_weight parameter correctly handled in atomic virial calculation

The atomic_weight parameter is properly passed to the forward_common_atomic method in this code section.

🧰 Tools
🪛 Ruff (0.8.2)

130-130: Function definition does not bind loop variable def_ndim

(B023)
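
For context on the Ruff B023 hint above, a generic illustration of the pitfall and the usual fix; def_ndim is the only name taken from the report, the rest is illustrative.

# Pitfall: the inner function captures the loop variable by reference, so every
# closure sees the final value of def_ndim after the loop finishes.
fns = []
for def_ndim in (1, 3, 9):
    fns.append(lambda x: x * def_ndim)
print([f(1) for f in fns])  # [9, 9, 9]: all closures share the last value

# Fix suggested by B023: bind the current value as a default argument.
fns = []
for def_ndim in (1, 3, 9):
    fns.append(lambda x, def_ndim=def_ndim: x * def_ndim)
print([f(1) for f in fns])  # [1, 3, 9]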


50-50:

✅ Verification successful

Verify intent of conditional change from c_differentiable to r_differentiable

The condition was changed from checking vdef.c_differentiable to vdef.r_differentiable. Line 97-98 confirms that r_differentiable is a prerequisite for c_differentiable (assert vdef.r_differentiable). This change potentially broadens which properties will have their derivatives calculated.


🏁 Script executed:

#!/bin/bash
# Check where these differentiability flags are defined and what they represent
rg -A 3 "r_differentiable|c_differentiable" --glob "*.py" | grep -v test

Length of output: 29705


Clarify and Confirm Differentiability Flag Usage in Base Model

The change from checking vdef.c_differentiable to checking vdef.r_differentiable is consistent with the overall design in the codebase. In many modules (e.g., in jax2tf/transform_output.py and tfmodel.py), r_differentiable is used as a prerequisite for c_differentiable—as seen by the assert vdef.r_differentiable accompanying the latter. This indicates that the intent is for properties flagged as r_differentiable to always have their derivatives calculated, while the additional c_differentiable flag further controls specific derivative computations (with the built-in safety check ensuring consistency).

Action Items:

  • Documentation/Comment Update: Consider adding a clarifying comment in deepmd/jax/model/base_model.py (and possibly related modules) to explicitly state that checking r_differentiable is intentional and that it broadens the derivative computation scope, while c_differentiable remains a subset requirement.
  • Review Consistency: Verify that downstream processing and derivative computations align with this broader condition as part of the overall differentiation strategy.
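
A compact sketch of the flag relationship spelled out above, assuming a simple variable-definition object. The field names r_differentiable and c_differentiable come from the review comments; the VarDef class and control flow are illustrative only, not the deepmd output-definition code.

from dataclasses import dataclass

@dataclass
class VarDef:
    name: str
    r_differentiable: bool  # derivative w.r.t. atomic coordinates requested
    c_differentiable: bool  # derivative w.r.t. cell (virial) requested

def compute_derivatives(vdef: VarDef) -> None:
    # Broader gate: any r-differentiable output gets coordinate derivatives.
    if vdef.r_differentiable:
        print(f"compute d{vdef.name}/d(coord)")
        # Cell derivatives are a subset: c_differentiable implies r_differentiable.
        if vdef.c_differentiable:
            assert vdef.r_differentiable
            print(f"compute d{vdef.name}/d(cell)")

compute_derivatives(VarDef("energy", r_differentiable=True, c_differentiable=True))
compute_derivatives(VarDef("dipole", r_differentiable=True, c_differentiable=False))
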
deepmd/jax/jax2tf/make_model.py (3)

47-47: LGTM: atomic_weight parameter added to function signature

The atomic_weight parameter has been properly added to the function signature.


77-78: LGTM: atomic_weight handled in parameter unpacking

The atomic_weight parameter is correctly included in variable unpacking and removed from the del statement.


107-107: LGTM: atomic_weight correctly passed to call_lower

The atomic_weight parameter is properly passed to the call_lower function.

deepmd/jax/utils/serialization.py (2)

60-61: LGTM: atomic_weight parameter added to function signature

The atomic_weight parameter has been correctly added to the call_lower_with_fixed_do_atomic_virial function signature.


69-70: LGTM: atomic_weight correctly passed to call_lower

The atomic_weight parameter is properly passed to the call_lower function.

deepmd/jax/model/dp_model.py (2)

59-60: LGTM: atomic_weight parameter added to function signature

The atomic_weight parameter has been correctly added to the forward_common_atomic method signature with an appropriate type annotation and default value.


70-71: LGTM: atomic_weight correctly passed to forward_common_atomic

The atomic_weight parameter is properly passed to the forward_common_atomic function.

deepmd/jax/jax2tf/tfmodel.py (1)

164-167: LGTM! Empty initialization for atomic_weight

The pattern of initializing atomic_weight to an empty tensor with proper shape when it's None is consistent with how other optional parameters are handled.
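
A hedged sketch of the None-handling pattern this comment refers to, as it might look on the TF side of jax2tf. Whether the fallback uses a zero-sized trailing dimension or a zeros tensor of full shape is an implementation detail not confirmed here; this sketch uses the zero-sized variant, and the function name and dtype are assumptions for illustration only.

import tensorflow as tf

def normalize_atomic_weight(atomic_weight, nframes, nloc):
    """Replace a missing optional per-atom input with an empty, correctly shaped tensor."""
    if atomic_weight is None:
        # A zero-sized trailing dimension signals "not provided" downstream,
        # mirroring how other optional per-atom inputs are padded in this sketch.
        return tf.zeros([nframes, nloc, 0], dtype=tf.float64)
    return tf.convert_to_tensor(atomic_weight, dtype=tf.float64)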

deepmd/jax/jax2tf/serialization.py (4)

43-49: LGTM! Parameter added correctly

The atomic_weight parameter has been correctly added to the function signature.


78-79: LGTM! Consistent shape specification

The shape specification "(nf, nloc, 1)" for atomic_weight is appropriate and consistent with the parameter's expected structure.


215-225: LGTM! Function call updated correctly

The function signature and call to make_call_whether_do_atomic_virial have been properly updated to include the atomic_weight parameter.


241-251: LGTM! Function call updated correctly

The function signature and call to make_call_whether_do_atomic_virial have been properly updated to include the atomic_weight parameter.
