55from fastNLP .core .metrics import SeqLabelEvaluator
66from fastNLP .core .field import TextField , LabelField
77from fastNLP .core .instance import Instance
8-
98from fastNLP .core .optimizer import Optimizer
109from fastNLP .core .trainer import SeqLabelTrainer
1110from fastNLP .models .sequence_modeling import SeqLabeling
@@ -51,6 +50,8 @@ def test_case_1(self):
5150 print ("loss = %f" % (los ))
5251 print ("r = %f" % (r ))
5352
53+ self .assertEqual (int (los * 1000 ), int (r * 1000 ))
54+
5455 def test_case_2 (self ):
5556 #验证squash()的正确性
5657 print ("----------------------------------" )
@@ -82,12 +83,14 @@ def test_case_2(self):
8283
8384 y = tc .log (y )
8485 los = loss_func (y , gy )
86+ print ("loss = %f" % (los ))
8587
8688 r = - log (.3 ) - log (.3 ) - log (.1 ) - log (.3 ) - log (.7 ) - log (.1 )
8789 r /= 6
88- print ("loss = %f" % (los ))
8990 print ("r = %f" % (r ))
9091
92+ self .assertEqual (int (los * 1000 ), int (r * 1000 ))
93+
9194 def test_case_3 (self ):
9295 #验证pack_padded_sequence()的正确性
9396 print ("----------------------------------" )
@@ -130,6 +133,8 @@ def test_case_3(self):
130133 r /= 6
131134 print ("r = %f" % (r ))
132135
136+ self .assertEqual (int (los * 1000 ), int (r * 1000 ))
137+
133138 def test_case_4 (self ):
134139 #验证unpad()的正确性
135140 print ("----------------------------------" )
@@ -169,6 +174,9 @@ def test_case_4(self):
169174 r /= 7
170175 print ("r = %f" % (r ))
171176
177+
178+ self .assertEqual (int (los * 1000 ), int (r * 1000 ))
179+
172180 def test_case_5 (self ):
173181 #验证mask()和make_mask()的正确性
174182 print ("----------------------------------" )
@@ -217,6 +225,10 @@ def test_case_5(self):
217225 r /= 8
218226 print ("r = %f" % (r ))
219227
228+
229+ self .assertEqual (int (los * 1000 ), int (r * 1000 ))
230+ self .assertEqual (int (los2 * 1000 ), int (r * 1000 ))
231+
220232 def test_case_6 (self ):
221233 #验证unpad_mask()的正确性
222234 print ("----------------------------------" )
@@ -256,6 +268,8 @@ def test_case_6(self):
256268 r /= 7
257269 print ("r = %f" % (r ))
258270
271+ self .assertEqual (int (los * 1000 ), int (r * 1000 ))
272+
259273 def test_case_7 (self ):
260274 #验证一些其他东西
261275 print ("----------------------------------" )
@@ -295,6 +309,7 @@ def test_case_7(self):
295309 r = - log (.3 ) - log (.5 ) - log (.3 )
296310 r /= 3
297311 print ("r = %f" % (r ))
312+ self .assertEqual (int (los * 1000 ), int (r * 1000 ))
298313
299314if __name__ == "__main__" :
300- unittest .main ()
315+ unittest .main ()
0 commit comments