Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 65 additions & 11 deletions deepmd/dpmodel/loss/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@
)


def custom_huber_loss(predictions, targets, delta=1.0):
    """Compute the mean Huber loss between ``predictions`` and ``targets``.

    The loss is quadratic (0.5 * error**2) where the absolute error is at
    most ``delta`` and linear (delta * (|error| - 0.5 * delta)) beyond it,
    giving an L2/L1 hybrid that is less sensitive to outliers than pure L2.

    Parameters
    ----------
    predictions
        Predicted values (array-API compatible array).
    targets
        Reference values, same shape as ``predictions``.
    delta : float
        Threshold at which the loss switches from quadratic to linear.

    Returns
    -------
    Scalar mean Huber loss over all elements.
    """
    # Resolve the array namespace from the inputs (numpy, torch, ...).
    xp = array_api_compat.array_namespace(predictions, targets)
    residual = targets - predictions
    abs_residual = xp.abs(residual)
    quad_part = 0.5 * residual**2
    lin_part = delta * (abs_residual - 0.5 * delta)
    within_delta = abs_residual <= delta
    return xp.mean(xp.where(within_delta, quad_part, lin_part))


class EnergyLoss(Loss):
def __init__(
self,
Expand All @@ -36,6 +46,8 @@ def __init__(
start_pref_gf: float = 0.0,
limit_pref_gf: float = 0.0,
numb_generalized_coord: int = 0,
use_huber=False,
huber_delta=0.01,
**kwargs,
) -> None:
self.starter_learning_rate = starter_learning_rate
Expand Down Expand Up @@ -64,6 +76,14 @@ def __init__(
raise RuntimeError(
"When generalized force loss is used, the dimension of generalized coordinates should be larger than 0"
)
self.use_huber = use_huber
self.huber_delta = huber_delta
if self.use_huber and (
self.has_pf or self.has_gf or self.relative_f is not None
):
raise RuntimeError(
"Huber loss is not implemented for force with atom_pref, generalized force and relative force. "
)

def call(
self,
Expand Down Expand Up @@ -144,15 +164,31 @@ def call(
self.limit_pref_pf + (self.start_pref_pf - self.limit_pref_pf) * lr_ratio
)

l2_loss = 0
loss = 0
more_loss = {}
if self.has_e:
l2_ener_loss = xp.mean(xp.square(energy - energy_hat))
l2_loss += atom_norm_ener * (pref_e * l2_ener_loss)
if not self.use_huber:
loss += atom_norm_ener * (pref_e * l2_ener_loss)
else:
l_huber_loss = custom_huber_loss(
atom_norm_ener * energy,
atom_norm_ener * energy_hat,
delta=self.huber_delta,
)
loss += pref_e * l_huber_loss
more_loss["l2_ener_loss"] = self.display_if_exist(l2_ener_loss, find_energy)
if self.has_f:
l2_force_loss = xp.mean(xp.square(diff_f))
l2_loss += pref_f * l2_force_loss
if not self.use_huber:
loss += pref_f * l2_force_loss
else:
l_huber_loss = custom_huber_loss(
xp.reshape(force, [-1]),
xp.reshape(force_hat, [-1]),
delta=self.huber_delta,
)
loss += pref_f * l_huber_loss
more_loss["l2_force_loss"] = self.display_if_exist(
l2_force_loss, find_force
)
Expand All @@ -162,7 +198,15 @@ def call(
l2_virial_loss = xp.mean(
xp.square(virial_hat_reshape - virial_reshape),
)
l2_loss += atom_norm * (pref_v * l2_virial_loss)
if not self.use_huber:
loss += atom_norm * (pref_v * l2_virial_loss)
else:
l_huber_loss = custom_huber_loss(
atom_norm * virial_reshape,
atom_norm * virial_hat_reshape,
delta=self.huber_delta,
)
loss += pref_v * l_huber_loss
more_loss["l2_virial_loss"] = self.display_if_exist(
l2_virial_loss, find_virial
)
Expand All @@ -172,7 +216,15 @@ def call(
l2_atom_ener_loss = xp.mean(
xp.square(atom_ener_hat_reshape - atom_ener_reshape),
)
l2_loss += pref_ae * l2_atom_ener_loss
if not self.use_huber:
loss += pref_ae * l2_atom_ener_loss
else:
l_huber_loss = custom_huber_loss(
atom_ener_reshape,
atom_ener_hat_reshape,
delta=self.huber_delta,
)
loss += pref_ae * l_huber_loss
more_loss["l2_atom_ener_loss"] = self.display_if_exist(
l2_atom_ener_loss, find_atom_ener
)
Expand All @@ -181,7 +233,7 @@ def call(
l2_pref_force_loss = xp.mean(
xp.multiply(xp.square(diff_f), atom_pref_reshape),
)
l2_loss += pref_pf * l2_pref_force_loss
loss += pref_pf * l2_pref_force_loss
more_loss["l2_pref_force_loss"] = self.display_if_exist(
l2_pref_force_loss, find_atom_pref
)
Expand All @@ -203,14 +255,14 @@ def call(
self.limit_pref_gf
+ (self.start_pref_gf - self.limit_pref_gf) * lr_ratio
)
l2_loss += pref_gf * l2_gen_force_loss
loss += pref_gf * l2_gen_force_loss
more_loss["l2_gen_force_loss"] = self.display_if_exist(
l2_gen_force_loss, find_drdq
)

self.l2_l = l2_loss
self.l2_l = loss
self.l2_more = more_loss
return l2_loss, more_loss
return loss, more_loss

@property
def label_requirement(self) -> list[DataRequirementItem]:
Expand Down Expand Up @@ -300,7 +352,7 @@ def serialize(self) -> dict:
"""
return {
"@class": "EnergyLoss",
"@version": 1,
"@version": 2,
"starter_learning_rate": self.starter_learning_rate,
"start_pref_e": self.start_pref_e,
"limit_pref_e": self.limit_pref_e,
Expand All @@ -317,6 +369,8 @@ def serialize(self) -> dict:
"start_pref_gf": self.start_pref_gf,
"limit_pref_gf": self.limit_pref_gf,
"numb_generalized_coord": self.numb_generalized_coord,
"use_huber": self.use_huber,
"huber_delta": self.huber_delta,
}

@classmethod
Expand All @@ -334,6 +388,6 @@ def deserialize(cls, data: dict) -> "Loss":
The deserialized loss module
"""
data = data.copy()
check_version_compatibility(data.pop("@version"), 1, 1)
check_version_compatibility(data.pop("@version"), 2, 1)
data.pop("@class")
return cls(**data)
73 changes: 67 additions & 6 deletions deepmd/pd/loss/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@
)


def custom_huber_loss(predictions, targets, delta=1.0):
    """Compute the mean Huber loss between ``predictions`` and ``targets``.

    Quadratic (0.5 * error**2) for absolute errors up to ``delta`` and
    linear (delta * (|error| - 0.5 * delta)) beyond it, making the loss
    less sensitive to outliers than a pure L2 loss.

    Parameters
    ----------
    predictions : paddle.Tensor
        Predicted values.
    targets : paddle.Tensor
        Reference values, same shape as ``predictions``.
    delta : float
        Threshold at which the loss switches from quadratic to linear.

    Returns
    -------
    paddle.Tensor
        Scalar mean Huber loss over all elements.
    """
    residual = targets - predictions
    abs_residual = paddle.abs(residual)
    quad_part = 0.5 * paddle.pow(residual, 2)
    lin_part = delta * (abs_residual - 0.5 * delta)
    small = abs_residual <= delta
    return paddle.mean(paddle.where(small, quad_part, lin_part))


class EnergyStdLoss(TaskLoss):
def __init__(
self,
Expand All @@ -44,6 +53,8 @@ def __init__(
numb_generalized_coord: int = 0,
use_l1_all: bool = False,
inference=False,
use_huber=False,
huber_delta=0.01,
**kwargs,
):
r"""Construct a layer to compute loss on energy, force and virial.
Expand Down Expand Up @@ -88,6 +99,14 @@ def __init__(
Whether to use L1 loss, if False (default), it will use L2 loss.
inference : bool
If true, it will output all losses found in output, ignoring the pre-factors.
use_huber : bool
Enables Huber loss calculation for energy/force/virial terms with user-defined threshold delta (D).
The loss function smoothly transitions between L2 and L1 loss:
- For absolute prediction errors within D: quadratic loss (0.5 * (error**2))
- For absolute errors exceeding D: linear loss (D * (|error| - 0.5 * D))
Formula: loss = 0.5 * (error**2) if |error| <= D else D * (|error| - 0.5 * D).
huber_delta : float
The threshold value used for Huber loss, controlling transition between L2 and L1 loss.
**kwargs
Other keyword arguments.
"""
Expand Down Expand Up @@ -121,6 +140,14 @@ def __init__(
)
self.use_l1_all = use_l1_all
self.inference = inference
self.use_huber = use_huber
self.huber_delta = huber_delta
if self.use_huber and (
self.has_pf or self.has_gf or self.relative_f is not None
):
raise RuntimeError(
"Huber loss is not implemented for force with atom_pref, generalized force and relative force. "
)

def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
"""Return loss on energy and force.
Expand Down Expand Up @@ -183,7 +210,15 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
more_loss["l2_ener_loss"] = self.display_if_exist(
l2_ener_loss.detach(), find_energy
)
loss += atom_norm * (pref_e * l2_ener_loss)
if not self.use_huber:
loss += atom_norm * (pref_e * l2_ener_loss)
else:
l_huber_loss = custom_huber_loss(
atom_norm * model_pred["energy"],
atom_norm * label["energy"],
delta=self.huber_delta,
)
loss += pref_e * l_huber_loss
rmse_e = l2_ener_loss.sqrt() * atom_norm
more_loss["rmse_e"] = self.display_if_exist(
rmse_e.detach(), find_energy
Expand Down Expand Up @@ -238,7 +273,15 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
more_loss["l2_force_loss"] = self.display_if_exist(
l2_force_loss.detach(), find_force
)
loss += (pref_f * l2_force_loss).to(GLOBAL_PD_FLOAT_PRECISION)
if not self.use_huber:
loss += (pref_f * l2_force_loss).to(GLOBAL_PD_FLOAT_PRECISION)
else:
l_huber_loss = custom_huber_loss(
force_pred.reshape(-1),
force_label.reshape(-1),
delta=self.huber_delta,
)
loss += pref_f * l_huber_loss
rmse_f = l2_force_loss.sqrt()
more_loss["rmse_f"] = self.display_if_exist(
rmse_f.detach(), find_force
Expand Down Expand Up @@ -317,7 +360,15 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
more_loss["l2_virial_loss"] = self.display_if_exist(
l2_virial_loss.detach(), find_virial
)
loss += atom_norm * (pref_v * l2_virial_loss)
if not self.use_huber:
loss += atom_norm * (pref_v * l2_virial_loss)
else:
l_huber_loss = custom_huber_loss(
atom_norm * model_pred["virial"].reshape(-1),
atom_norm * label["virial"].reshape(-1),
delta=self.huber_delta,
)
loss += pref_v * l_huber_loss
rmse_v = l2_virial_loss.sqrt() * atom_norm
more_loss["rmse_v"] = self.display_if_exist(rmse_v.detach(), find_virial)
if mae:
Expand All @@ -338,7 +389,15 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
more_loss["l2_atom_ener_loss"] = self.display_if_exist(
l2_atom_ener_loss.detach(), find_atom_ener
)
loss += (pref_ae * l2_atom_ener_loss).to(GLOBAL_PD_FLOAT_PRECISION)
if not self.use_huber:
loss += (pref_ae * l2_atom_ener_loss).to(GLOBAL_PD_FLOAT_PRECISION)
else:
l_huber_loss = custom_huber_loss(
atom_ener_reshape,
atom_ener_label_reshape,
delta=self.huber_delta,
)
loss += pref_ae * l_huber_loss
rmse_ae = l2_atom_ener_loss.sqrt()
more_loss["rmse_ae"] = self.display_if_exist(
rmse_ae.detach(), find_atom_ener
Expand Down Expand Up @@ -436,7 +495,7 @@ def serialize(self) -> dict:
"""
return {
"@class": "EnergyLoss",
"@version": 1,
"@version": 2,
"starter_learning_rate": self.starter_learning_rate,
"start_pref_e": self.start_pref_e,
"limit_pref_e": self.limit_pref_e,
Expand All @@ -453,6 +512,8 @@ def serialize(self) -> dict:
"start_pref_gf": self.start_pref_gf,
"limit_pref_gf": self.limit_pref_gf,
"numb_generalized_coord": self.numb_generalized_coord,
"use_huber": self.use_huber,
"huber_delta": self.huber_delta,
}

@classmethod
Expand All @@ -470,6 +531,6 @@ def deserialize(cls, data: dict) -> "TaskLoss":
The deserialized loss module
"""
data = data.copy()
check_version_compatibility(data.pop("@version"), 1, 1)
check_version_compatibility(data.pop("@version"), 2, 1)
data.pop("@class")
return cls(**data)
Loading
Loading