|
| 1 | +# SPDX-License-Identifier: LGPL-3.0-or-later |
| 2 | +from pathlib import ( |
| 3 | + Path, |
| 4 | +) |
| 5 | + |
| 6 | +import orbax.checkpoint as ocp |
| 7 | + |
| 8 | +from deepmd.jax.env import ( |
| 9 | + jax, |
| 10 | + nnx, |
| 11 | +) |
| 12 | +from deepmd.jax.model.model import ( |
| 13 | + BaseModel, |
| 14 | + get_model, |
| 15 | +) |
| 16 | + |
| 17 | + |
| 18 | +def deserialize_to_file(model_file: str, data: dict) -> None: |
| 19 | + """Deserialize the dictionary to a model file. |
| 20 | +
|
| 21 | + Parameters |
| 22 | + ---------- |
| 23 | + model_file : str |
| 24 | + The model file to be saved. |
| 25 | + data : dict |
| 26 | + The dictionary to be deserialized. |
| 27 | + """ |
| 28 | + if model_file.endswith(".jax"): |
| 29 | + model = BaseModel.deserialize(data["model"]) |
| 30 | + model_def_script = data["model_def_script"] |
| 31 | + _, state = nnx.split(model) |
| 32 | + with ocp.Checkpointer( |
| 33 | + ocp.CompositeCheckpointHandler("state", "model_def_script") |
| 34 | + ) as checkpointer: |
| 35 | + checkpointer.save( |
| 36 | + Path(model_file).absolute(), |
| 37 | + ocp.args.Composite( |
| 38 | + state=ocp.args.StandardSave(state.to_pure_dict()), |
| 39 | + model_def_script=ocp.args.JsonSave(model_def_script), |
| 40 | + ), |
| 41 | + ) |
| 42 | + else: |
| 43 | + raise ValueError("JAX backend only supports converting .jax directory") |
| 44 | + |
| 45 | + |
| 46 | +def serialize_from_file(model_file: str) -> dict: |
| 47 | + """Serialize the model file to a dictionary. |
| 48 | +
|
| 49 | + Parameters |
| 50 | + ---------- |
| 51 | + model_file : str |
| 52 | + The model file to be serialized. |
| 53 | +
|
| 54 | + Returns |
| 55 | + ------- |
| 56 | + dict |
| 57 | + The serialized model data. |
| 58 | + """ |
| 59 | + if model_file.endswith(".jax"): |
| 60 | + with ocp.Checkpointer( |
| 61 | + ocp.CompositeCheckpointHandler("state", "model_def_script") |
| 62 | + ) as checkpointer: |
| 63 | + data = checkpointer.restore( |
| 64 | + Path(model_file).absolute(), |
| 65 | + ocp.args.Composite( |
| 66 | + state=ocp.args.StandardRestore(), |
| 67 | + model_def_script=ocp.args.JsonRestore(), |
| 68 | + ), |
| 69 | + ) |
| 70 | + state = data.state |
| 71 | + |
| 72 | + # convert str "1" to int 1 key |
| 73 | + def convert_str_to_int_key(item: dict): |
| 74 | + for key, value in item.copy().items(): |
| 75 | + if isinstance(value, dict): |
| 76 | + convert_str_to_int_key(value) |
| 77 | + if key.isdigit(): |
| 78 | + item[int(key)] = item.pop(key) |
| 79 | + |
| 80 | + convert_str_to_int_key(state) |
| 81 | + |
| 82 | + model_def_script = data.model_def_script |
| 83 | + abstract_model = get_model(model_def_script) |
| 84 | + graphdef, abstract_state = nnx.split(abstract_model) |
| 85 | + abstract_state.replace_by_pure_dict(state) |
| 86 | + model = nnx.merge(graphdef, abstract_state) |
| 87 | + model_dict = model.serialize() |
| 88 | + data = { |
| 89 | + "backend": "JAX", |
| 90 | + "jax_version": jax.__version__, |
| 91 | + "model": model_dict, |
| 92 | + "model_def_script": model_def_script, |
| 93 | + "@variables": {}, |
| 94 | + } |
| 95 | + return data |
| 96 | + else: |
| 97 | + raise ValueError("JAX backend only supports converting .jax directory") |
0 commit comments