Skip to content

Commit 02a3048

Browse files
njzjzYour Name
andauthored
fix(tf): fix argcheck when compressing a model converted from other backends (#4331)
When the model is converted from other backends, the input script only contains the `model` section. This PR sets the default for any necessary argument. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Enhanced the data structure for model compression by adding default keys for training steps and learning rate. - **Bug Fixes** - Improved error handling with more informative runtime exceptions for missing training scripts. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Co-authored-by: Your Name <[email protected]>
1 parent dcbf607 commit 02a3048

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

deepmd/tf/entrypoints/compress.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ def compress(
147147
10 * step,
148148
int(frequency),
149149
]
150+
jdata.setdefault("training", {"numb_steps": 0})
151+
jdata.setdefault("learning_rate", {})
150152
jdata["training"]["save_ckpt"] = os.path.join("model-compression", "model.ckpt")
151153
jdata = update_deepmd_input(jdata)
152154
jdata = normalize(jdata)

0 commit comments

Comments
 (0)