Commit a1b5089

feat: add huber loss (#4684)
Summary by CodeRabbit

- New Features: Introduced an optional robust loss calculation mode that leverages Huber loss. Users can now choose between the traditional loss and a smoother Huber-based approach, offering improved error handling for key training metrics.
- Documentation: Updated user-facing guides to explain the new robust loss option and its configurable threshold.
- Tests: Expanded test coverage to validate both conventional and Huber loss scenarios, ensuring consistent and reliable performance.
1 parent 5f740f9 commit a1b5089
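
The summary above mentions a configurable threshold. Below is a minimal sketch of how the two new knobs might appear in a training input's loss block; only use_huber and huber_delta are introduced by this commit, while the surrounding keys and values are illustrative assumptions rather than values taken from this diff.

# Hypothetical loss configuration (Python dict mirroring a JSON input block).
# Only "use_huber" and "huber_delta" come from this commit; the other keys
# are placeholders for the usual energy-loss prefactors.
loss_config = {
    "type": "ener",        # assumed loss type name
    "start_pref_e": 0.02,  # placeholder energy prefactor
    "limit_pref_e": 1.0,
    "start_pref_f": 1000,  # placeholder force prefactor
    "limit_pref_f": 1.0,
    "use_huber": True,     # switch the energy/force/virial terms to Huber loss
    "huber_delta": 0.01,   # threshold where the quadratic branch turns linear
}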

File tree

7 files changed: +323 −37 lines changed


deepmd/dpmodel/loss/ener.py

Lines changed: 65 additions & 11 deletions
@@ -17,6 +17,16 @@
 )
 
 
+def custom_huber_loss(predictions, targets, delta=1.0):
+    xp = array_api_compat.array_namespace(predictions, targets)
+    error = targets - predictions
+    abs_error = xp.abs(error)
+    quadratic_loss = 0.5 * error**2
+    linear_loss = delta * (abs_error - 0.5 * delta)
+    loss = xp.where(abs_error <= delta, quadratic_loss, linear_loss)
+    return xp.mean(loss)
+
+
 class EnergyLoss(Loss):
     def __init__(
         self,
@@ -36,6 +46,8 @@ def __init__(
         start_pref_gf: float = 0.0,
         limit_pref_gf: float = 0.0,
         numb_generalized_coord: int = 0,
+        use_huber=False,
+        huber_delta=0.01,
         **kwargs,
     ) -> None:
         self.starter_learning_rate = starter_learning_rate
@@ -64,6 +76,14 @@ def __init__(
             raise RuntimeError(
                 "When generalized force loss is used, the dimension of generalized coordinates should be larger than 0"
             )
+        self.use_huber = use_huber
+        self.huber_delta = huber_delta
+        if self.use_huber and (
+            self.has_pf or self.has_gf or self.relative_f is not None
+        ):
+            raise RuntimeError(
+                "Huber loss is not implemented for force with atom_pref, generalized force and relative force. "
+            )
 
     def call(
         self,
@@ -144,15 +164,31 @@ def call(
             self.limit_pref_pf + (self.start_pref_pf - self.limit_pref_pf) * lr_ratio
         )
 
-        l2_loss = 0
+        loss = 0
         more_loss = {}
         if self.has_e:
             l2_ener_loss = xp.mean(xp.square(energy - energy_hat))
-            l2_loss += atom_norm_ener * (pref_e * l2_ener_loss)
+            if not self.use_huber:
+                loss += atom_norm_ener * (pref_e * l2_ener_loss)
+            else:
+                l_huber_loss = custom_huber_loss(
+                    atom_norm_ener * energy,
+                    atom_norm_ener * energy_hat,
+                    delta=self.huber_delta,
+                )
+                loss += pref_e * l_huber_loss
             more_loss["l2_ener_loss"] = self.display_if_exist(l2_ener_loss, find_energy)
         if self.has_f:
             l2_force_loss = xp.mean(xp.square(diff_f))
-            l2_loss += pref_f * l2_force_loss
+            if not self.use_huber:
+                loss += pref_f * l2_force_loss
+            else:
+                l_huber_loss = custom_huber_loss(
+                    xp.reshape(force, [-1]),
+                    xp.reshape(force_hat, [-1]),
+                    delta=self.huber_delta,
+                )
+                loss += pref_f * l_huber_loss
             more_loss["l2_force_loss"] = self.display_if_exist(
                 l2_force_loss, find_force
             )
@@ -162,7 +198,15 @@ def call(
             l2_virial_loss = xp.mean(
                 xp.square(virial_hat_reshape - virial_reshape),
             )
-            l2_loss += atom_norm * (pref_v * l2_virial_loss)
+            if not self.use_huber:
+                loss += atom_norm * (pref_v * l2_virial_loss)
+            else:
+                l_huber_loss = custom_huber_loss(
+                    atom_norm * virial_reshape,
+                    atom_norm * virial_hat_reshape,
+                    delta=self.huber_delta,
+                )
+                loss += pref_v * l_huber_loss
             more_loss["l2_virial_loss"] = self.display_if_exist(
                 l2_virial_loss, find_virial
             )
@@ -172,7 +216,15 @@ def call(
             l2_atom_ener_loss = xp.mean(
                 xp.square(atom_ener_hat_reshape - atom_ener_reshape),
             )
-            l2_loss += pref_ae * l2_atom_ener_loss
+            if not self.use_huber:
+                loss += pref_ae * l2_atom_ener_loss
+            else:
+                l_huber_loss = custom_huber_loss(
+                    atom_ener_reshape,
+                    atom_ener_hat_reshape,
+                    delta=self.huber_delta,
+                )
+                loss += pref_ae * l_huber_loss
             more_loss["l2_atom_ener_loss"] = self.display_if_exist(
                 l2_atom_ener_loss, find_atom_ener
             )
@@ -181,7 +233,7 @@ def call(
             l2_pref_force_loss = xp.mean(
                 xp.multiply(xp.square(diff_f), atom_pref_reshape),
             )
-            l2_loss += pref_pf * l2_pref_force_loss
+            loss += pref_pf * l2_pref_force_loss
             more_loss["l2_pref_force_loss"] = self.display_if_exist(
                 l2_pref_force_loss, find_atom_pref
             )
@@ -203,14 +255,14 @@ def call(
                 self.limit_pref_gf
                 + (self.start_pref_gf - self.limit_pref_gf) * lr_ratio
             )
-            l2_loss += pref_gf * l2_gen_force_loss
+            loss += pref_gf * l2_gen_force_loss
             more_loss["l2_gen_force_loss"] = self.display_if_exist(
                 l2_gen_force_loss, find_drdq
             )
 
-        self.l2_l = l2_loss
+        self.l2_l = loss
         self.l2_more = more_loss
-        return l2_loss, more_loss
+        return loss, more_loss
 
     @property
     def label_requirement(self) -> list[DataRequirementItem]:
@@ -300,7 +352,7 @@ def serialize(self) -> dict:
         """
         return {
             "@class": "EnergyLoss",
-            "@version": 1,
+            "@version": 2,
             "starter_learning_rate": self.starter_learning_rate,
             "start_pref_e": self.start_pref_e,
             "limit_pref_e": self.limit_pref_e,
@@ -317,6 +369,8 @@ def serialize(self) -> dict:
             "start_pref_gf": self.start_pref_gf,
             "limit_pref_gf": self.limit_pref_gf,
             "numb_generalized_coord": self.numb_generalized_coord,
+            "use_huber": self.use_huber,
+            "huber_delta": self.huber_delta,
         }
 
     @classmethod
@@ -334,6 +388,6 @@ def deserialize(cls, data: dict) -> "Loss":
             The deserialized loss module
         """
         data = data.copy()
-        check_version_compatibility(data.pop("@version"), 1, 1)
+        check_version_compatibility(data.pop("@version"), 2, 1)
         data.pop("@class")
         return cls(**data)
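
For reference, here is a standalone NumPy sketch of the same piecewise rule that custom_huber_loss above encodes (quadratic inside the threshold, linear outside); it is illustrative and not part of the commit. Capping the growth of large residuals to a linear rate is what makes the option robust to outlier frames.

import numpy as np

def huber_loss_sketch(predictions, targets, delta=1.0):
    # Mirrors custom_huber_loss above: 0.5 * error**2 when |error| <= delta,
    # delta * (|error| - 0.5 * delta) otherwise, averaged over all elements.
    error = targets - predictions
    abs_error = np.abs(error)
    quadratic_loss = 0.5 * error**2
    linear_loss = delta * (abs_error - 0.5 * delta)
    return np.mean(np.where(abs_error <= delta, quadratic_loss, linear_loss))

# Small errors stay in the quadratic regime, large errors in the linear one,
# and the two branches agree at |error| == delta (here 0.5 * 0.01**2 = 5e-05).
print(huber_loss_sketch(np.array([0.0]), np.array([0.005]), delta=0.01))  # 1.25e-05
print(huber_loss_sketch(np.array([0.0]), np.array([0.1]), delta=0.01))    # 0.00095
print(huber_loss_sketch(np.array([0.0]), np.array([0.01]), delta=0.01))   # 5e-05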

deepmd/pd/loss/ener.py

Lines changed: 67 additions & 6 deletions
@@ -23,6 +23,15 @@
 )
 
 
+def custom_huber_loss(predictions, targets, delta=1.0):
+    error = targets - predictions
+    abs_error = paddle.abs(error)
+    quadratic_loss = 0.5 * paddle.pow(error, 2)
+    linear_loss = delta * (abs_error - 0.5 * delta)
+    loss = paddle.where(abs_error <= delta, quadratic_loss, linear_loss)
+    return paddle.mean(loss)
+
+
 class EnergyStdLoss(TaskLoss):
     def __init__(
         self,
@@ -44,6 +53,8 @@ def __init__(
         numb_generalized_coord: int = 0,
         use_l1_all: bool = False,
         inference=False,
+        use_huber=False,
+        huber_delta=0.01,
         **kwargs,
     ):
         r"""Construct a layer to compute loss on energy, force and virial.
@@ -88,6 +99,14 @@ def __init__(
             Whether to use L1 loss, if False (default), it will use L2 loss.
         inference : bool
             If true, it will output all losses found in output, ignoring the pre-factors.
+        use_huber : bool
+            Enables Huber loss calculation for energy/force/virial terms with user-defined threshold delta (D).
+            The loss function smoothly transitions between L2 and L1 loss:
+            - For absolute prediction errors within D: quadratic loss (0.5 * (error**2))
+            - For absolute errors exceeding D: linear loss (D * (|error| - 0.5 * D))
+            Formula: loss = 0.5 * (error**2) if |error| <= D else D * (|error| - 0.5 * D).
+        huber_delta : float
+            The threshold delta (D) used for Huber loss, controlling the transition between L2 and L1 loss.
         **kwargs
             Other keyword arguments.
         """
@@ -121,6 +140,14 @@ def __init__(
             )
         self.use_l1_all = use_l1_all
         self.inference = inference
+        self.use_huber = use_huber
+        self.huber_delta = huber_delta
+        if self.use_huber and (
+            self.has_pf or self.has_gf or self.relative_f is not None
+        ):
+            raise RuntimeError(
+                "Huber loss is not implemented for force with atom_pref, generalized force and relative force. "
+            )
 
     def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
         """Return loss on energy and force.
@@ -183,7 +210,15 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
                more_loss["l2_ener_loss"] = self.display_if_exist(
                    l2_ener_loss.detach(), find_energy
                )
-                loss += atom_norm * (pref_e * l2_ener_loss)
+                if not self.use_huber:
+                    loss += atom_norm * (pref_e * l2_ener_loss)
+                else:
+                    l_huber_loss = custom_huber_loss(
+                        atom_norm * model_pred["energy"],
+                        atom_norm * label["energy"],
+                        delta=self.huber_delta,
+                    )
+                    loss += pref_e * l_huber_loss
                rmse_e = l2_ener_loss.sqrt() * atom_norm
                more_loss["rmse_e"] = self.display_if_exist(
                    rmse_e.detach(), find_energy
@@ -238,7 +273,15 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
                more_loss["l2_force_loss"] = self.display_if_exist(
                    l2_force_loss.detach(), find_force
                )
-                loss += (pref_f * l2_force_loss).to(GLOBAL_PD_FLOAT_PRECISION)
+                if not self.use_huber:
+                    loss += (pref_f * l2_force_loss).to(GLOBAL_PD_FLOAT_PRECISION)
+                else:
+                    l_huber_loss = custom_huber_loss(
+                        force_pred.reshape([-1]),
+                        force_label.reshape([-1]),
+                        delta=self.huber_delta,
+                    )
+                    loss += pref_f * l_huber_loss
                rmse_f = l2_force_loss.sqrt()
                more_loss["rmse_f"] = self.display_if_exist(
                    rmse_f.detach(), find_force
@@ -317,7 +360,15 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
                more_loss["l2_virial_loss"] = self.display_if_exist(
                    l2_virial_loss.detach(), find_virial
                )
-                loss += atom_norm * (pref_v * l2_virial_loss)
+                if not self.use_huber:
+                    loss += atom_norm * (pref_v * l2_virial_loss)
+                else:
+                    l_huber_loss = custom_huber_loss(
+                        atom_norm * model_pred["virial"].reshape([-1]),
+                        atom_norm * label["virial"].reshape([-1]),
+                        delta=self.huber_delta,
+                    )
+                    loss += pref_v * l_huber_loss
                rmse_v = l2_virial_loss.sqrt() * atom_norm
                more_loss["rmse_v"] = self.display_if_exist(rmse_v.detach(), find_virial)
                if mae:
@@ -338,7 +389,15 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
                more_loss["l2_atom_ener_loss"] = self.display_if_exist(
                    l2_atom_ener_loss.detach(), find_atom_ener
                )
-                loss += (pref_ae * l2_atom_ener_loss).to(GLOBAL_PD_FLOAT_PRECISION)
+                if not self.use_huber:
+                    loss += (pref_ae * l2_atom_ener_loss).to(GLOBAL_PD_FLOAT_PRECISION)
+                else:
+                    l_huber_loss = custom_huber_loss(
+                        atom_ener_reshape,
+                        atom_ener_label_reshape,
+                        delta=self.huber_delta,
+                    )
+                    loss += pref_ae * l_huber_loss
                rmse_ae = l2_atom_ener_loss.sqrt()
                more_loss["rmse_ae"] = self.display_if_exist(
                    rmse_ae.detach(), find_atom_ener
@@ -436,7 +495,7 @@ def serialize(self) -> dict:
         """
         return {
             "@class": "EnergyLoss",
-            "@version": 1,
+            "@version": 2,
             "starter_learning_rate": self.starter_learning_rate,
             "start_pref_e": self.start_pref_e,
             "limit_pref_e": self.limit_pref_e,
@@ -453,6 +512,8 @@ def serialize(self) -> dict:
             "start_pref_gf": self.start_pref_gf,
             "limit_pref_gf": self.limit_pref_gf,
             "numb_generalized_coord": self.numb_generalized_coord,
+            "use_huber": self.use_huber,
+            "huber_delta": self.huber_delta,
         }
 
     @classmethod
@@ -470,6 +531,6 @@ def deserialize(cls, data: dict) -> "TaskLoss":
             The deserialized loss module
         """
         data = data.copy()
-        check_version_compatibility(data.pop("@version"), 1, 1)
+        check_version_compatibility(data.pop("@version"), 2, 1)
         data.pop("@class")
         return cls(**data)
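
Both implementations bump the serialized @version to 2 and record the two new fields, while deserialize still accepts the previous version (minimum 1). A hypothetical round trip for the dpmodel variant is sketched below; it assumes the remaining EnergyLoss constructor arguments keep their defaults, which is not shown in this diff.

# Hypothetical serialize()/deserialize() round trip; constructor defaults
# other than use_huber/huber_delta are assumed, not taken from this commit.
from deepmd.dpmodel.loss.ener import EnergyLoss

loss = EnergyLoss(starter_learning_rate=1.0e-3, use_huber=True, huber_delta=0.01)
data = loss.serialize()
assert data["@version"] == 2
assert data["use_huber"] is True and data["huber_delta"] == 0.01

restored = EnergyLoss.deserialize(data)  # accepts versions 1 through 2
assert restored.use_huber and restored.huber_delta == 0.01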
