|
23 | 23 | ) |
24 | 24 |
|
25 | 25 |
|
26 | | -@flax_module |
27 | | -class DPAtomicModel(DPAtomicModelDP): |
28 | | - base_descriptor_cls = BaseDescriptor |
29 | | - """The base descriptor class.""" |
30 | | - base_fitting_cls = BaseFitting |
31 | | - """The base fitting class.""" |
32 | | - |
33 | | - def __setattr__(self, name: str, value: Any) -> None: |
34 | | - value = base_atomic_model_set_attr(name, value) |
35 | | - return super().__setattr__(name, value) |
36 | | - |
37 | | - def forward_common_atomic( |
38 | | - self, |
39 | | - extended_coord: jnp.ndarray, |
40 | | - extended_atype: jnp.ndarray, |
41 | | - nlist: jnp.ndarray, |
42 | | - mapping: Optional[jnp.ndarray] = None, |
43 | | - fparam: Optional[jnp.ndarray] = None, |
44 | | - aparam: Optional[jnp.ndarray] = None, |
45 | | - ) -> dict[str, jnp.ndarray]: |
46 | | - return super().forward_common_atomic( |
47 | | - extended_coord, |
48 | | - extended_atype, |
49 | | - jax.lax.stop_gradient(nlist), |
50 | | - mapping=mapping, |
51 | | - fparam=fparam, |
52 | | - aparam=aparam, |
53 | | - ) |
| 26 | +def make_jax_dp_atomic_model_from_dpmodel( |
| 27 | + dpmodel_atomic_model: type[DPAtomicModelDP], |
| 28 | +) -> type[DPAtomicModelDP]: |
| 29 | + """Make a JAX backend DP atomic model from a DPModel backend DP atomic model. |
| 30 | +
|
| 31 | + Parameters |
| 32 | + ---------- |
| 33 | + dpmodel_atomic_model : type[DPAtomicModelDP] |
| 34 | + The DPModel backend DP atomic model. |
| 35 | +
|
| 36 | + Returns |
| 37 | + ------- |
| 38 | + type[DPAtomicModel] |
| 39 | + The JAX backend DP atomic model. |
| 40 | + """ |
| 41 | + |
| 42 | + @flax_module |
| 43 | + class jax_atomic_model(dpmodel_atomic_model): |
| 44 | + base_descriptor_cls = BaseDescriptor |
| 45 | + """The base descriptor class.""" |
| 46 | + base_fitting_cls = BaseFitting |
| 47 | + """The base fitting class.""" |
| 48 | + |
| 49 | + def __setattr__(self, name: str, value: Any) -> None: |
| 50 | + value = base_atomic_model_set_attr(name, value) |
| 51 | + return super().__setattr__(name, value) |
| 52 | + |
| 53 | + def forward_common_atomic( |
| 54 | + self, |
| 55 | + extended_coord: jnp.ndarray, |
| 56 | + extended_atype: jnp.ndarray, |
| 57 | + nlist: jnp.ndarray, |
| 58 | + mapping: Optional[jnp.ndarray] = None, |
| 59 | + fparam: Optional[jnp.ndarray] = None, |
| 60 | + aparam: Optional[jnp.ndarray] = None, |
| 61 | + ) -> dict[str, jnp.ndarray]: |
| 62 | + return super().forward_common_atomic( |
| 63 | + extended_coord, |
| 64 | + extended_atype, |
| 65 | + jax.lax.stop_gradient(nlist), |
| 66 | + mapping=mapping, |
| 67 | + fparam=fparam, |
| 68 | + aparam=aparam, |
| 69 | + ) |
| 70 | + |
| 71 | + return jax_atomic_model |
| 72 | + |
| 73 | + |
| 74 | +class DPAtomicModel(make_jax_dp_atomic_model_from_dpmodel(DPAtomicModelDP)): |
| 75 | + pass |
0 commit comments