 )
 
 
+def custom_huber_loss(predictions, targets, delta=1.0):
+    error = targets - predictions
+    abs_error = paddle.abs(error)
+    quadratic_loss = 0.5 * paddle.pow(error, 2)
+    linear_loss = delta * (abs_error - 0.5 * delta)
+    loss = paddle.where(abs_error <= delta, quadratic_loss, linear_loss)
+    return paddle.mean(loss)
+
+
 class EnergyStdLoss(TaskLoss):
     def __init__(
         self,
@@ -44,6 +53,8 @@ def __init__(
         numb_generalized_coord: int = 0,
         use_l1_all: bool = False,
         inference=False,
+        use_huber=False,
+        huber_delta=0.01,
         **kwargs,
     ):
         r"""Construct a layer to compute loss on energy, force and virial.
@@ -88,6 +99,14 @@ def __init__(
             Whether to use L1 loss, if False (default), it will use L2 loss.
         inference : bool
             If true, it will output all losses found in output, ignoring the pre-factors.
+        use_huber : bool
+            Enables Huber loss calculation for energy/force/virial/atomic-energy terms with a user-defined threshold delta (D).
+            The loss function smoothly transitions between L2 and L1 loss:
+            - For absolute prediction errors within D: quadratic loss (0.5 * (error**2))
+            - For absolute errors exceeding D: linear loss (D * (|error| - 0.5 * D))
+            Formula: loss = 0.5 * (error**2) if |error| <= D else D * (|error| - 0.5 * D).
+        huber_delta : float
+            The threshold delta (D) used for Huber loss, controlling the transition between L2 and L1 loss.
         **kwargs
             Other keyword arguments.
         """
@@ -121,6 +140,14 @@ def __init__(
         )
         self.use_l1_all = use_l1_all
         self.inference = inference
+        self.use_huber = use_huber
+        self.huber_delta = huber_delta
+        if self.use_huber and (
+            self.has_pf or self.has_gf or self.relative_f is not None
+        ):
+            raise RuntimeError(
+                "Huber loss is not implemented for force with atom_pref, generalized force and relative force. "
+            )
 
     def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
         """Return loss on energy and force.
@@ -183,7 +210,15 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
                     more_loss["l2_ener_loss"] = self.display_if_exist(
                         l2_ener_loss.detach(), find_energy
                     )
-                loss += atom_norm * (pref_e * l2_ener_loss)
+                if not self.use_huber:
+                    loss += atom_norm * (pref_e * l2_ener_loss)
+                else:
+                    l_huber_loss = custom_huber_loss(
+                        atom_norm * model_pred["energy"],
+                        atom_norm * label["energy"],
+                        delta=self.huber_delta,
+                    )
+                    loss += pref_e * l_huber_loss
                 rmse_e = l2_ener_loss.sqrt() * atom_norm
                 more_loss["rmse_e"] = self.display_if_exist(
                     rmse_e.detach(), find_energy
@@ -238,7 +273,15 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
                     more_loss["l2_force_loss"] = self.display_if_exist(
                         l2_force_loss.detach(), find_force
                     )
-                loss += (pref_f * l2_force_loss).to(GLOBAL_PD_FLOAT_PRECISION)
+                if not self.use_huber:
+                    loss += (pref_f * l2_force_loss).to(GLOBAL_PD_FLOAT_PRECISION)
+                else:
+                    l_huber_loss = custom_huber_loss(
+                        force_pred.reshape([-1]),
+                        force_label.reshape([-1]),
+                        delta=self.huber_delta,
+                    )
+                    loss += pref_f * l_huber_loss
                 rmse_f = l2_force_loss.sqrt()
                 more_loss["rmse_f"] = self.display_if_exist(
                     rmse_f.detach(), find_force
@@ -317,7 +360,15 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
                 more_loss["l2_virial_loss"] = self.display_if_exist(
                     l2_virial_loss.detach(), find_virial
                 )
-            loss += atom_norm * (pref_v * l2_virial_loss)
+            if not self.use_huber:
+                loss += atom_norm * (pref_v * l2_virial_loss)
+            else:
+                l_huber_loss = custom_huber_loss(
+                    atom_norm * model_pred["virial"].reshape([-1]),
+                    atom_norm * label["virial"].reshape([-1]),
+                    delta=self.huber_delta,
+                )
+                loss += pref_v * l_huber_loss
             rmse_v = l2_virial_loss.sqrt() * atom_norm
             more_loss["rmse_v"] = self.display_if_exist(rmse_v.detach(), find_virial)
             if mae:
@@ -338,7 +389,15 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
                 more_loss["l2_atom_ener_loss"] = self.display_if_exist(
                     l2_atom_ener_loss.detach(), find_atom_ener
                 )
-            loss += (pref_ae * l2_atom_ener_loss).to(GLOBAL_PD_FLOAT_PRECISION)
+            if not self.use_huber:
+                loss += (pref_ae * l2_atom_ener_loss).to(GLOBAL_PD_FLOAT_PRECISION)
+            else:
+                l_huber_loss = custom_huber_loss(
+                    atom_ener_reshape,
+                    atom_ener_label_reshape,
+                    delta=self.huber_delta,
+                )
+                loss += pref_ae * l_huber_loss
             rmse_ae = l2_atom_ener_loss.sqrt()
             more_loss["rmse_ae"] = self.display_if_exist(
                 rmse_ae.detach(), find_atom_ener
@@ -436,7 +495,7 @@ def serialize(self) -> dict:
         """
         return {
             "@class": "EnergyLoss",
-            "@version": 1,
+            "@version": 2,
             "starter_learning_rate": self.starter_learning_rate,
             "start_pref_e": self.start_pref_e,
             "limit_pref_e": self.limit_pref_e,
@@ -453,6 +512,8 @@ def serialize(self) -> dict:
             "start_pref_gf": self.start_pref_gf,
             "limit_pref_gf": self.limit_pref_gf,
             "numb_generalized_coord": self.numb_generalized_coord,
+            "use_huber": self.use_huber,
+            "huber_delta": self.huber_delta,
         }
 
     @classmethod
@@ -470,6 +531,6 @@ def deserialize(cls, data: dict) -> "TaskLoss":
             The deserialized loss module
         """
         data = data.copy()
-        check_version_compatibility(data.pop("@version"), 1, 1)
+        check_version_compatibility(data.pop("@version"), 2, 1)
         data.pop("@class")
         return cls(**data)
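
For reference, the following standalone sketch (not part of the commit) reproduces the piecewise behaviour described in the use_huber docstring above. It only assumes PaddlePaddle is installed, re-declares custom_huber_loss so it runs outside the diff, and uses hypothetical error values with delta = 0.01, the default huber_delta.

import paddle


def custom_huber_loss(predictions, targets, delta=1.0):
    # Quadratic within +/- delta, linear (slope delta) outside: less sensitive to outliers than L2.
    error = targets - predictions
    abs_error = paddle.abs(error)
    quadratic_loss = 0.5 * paddle.pow(error, 2)
    linear_loss = delta * (abs_error - 0.5 * delta)
    loss = paddle.where(abs_error <= delta, quadratic_loss, linear_loss)
    return paddle.mean(loss)


# One error inside the threshold, one outside (hypothetical values):
pred = paddle.to_tensor([0.0, 0.0])
target = paddle.to_tensor([0.005, 0.1])
# Quadratic branch: 0.5 * 0.005**2            = 1.25e-5
# Linear branch:    0.01 * (0.1 - 0.5 * 0.01) = 9.5e-4
# Mean of the two:  ~4.8125e-4
print(float(custom_huber_loss(pred, target, delta=0.01)))

In the commit itself, this function is applied to per-atom-normalized energies, flattened forces, per-atom-normalized virials, and atomic energies, and the resulting scalar is weighted by the corresponding prefactor (pref_e, pref_f, pref_v, pref_ae).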