Skip to content

Commit 07fb61e

Browse files
committed
Update test_loss
1 parent 1f15b52 commit 07fb61e

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

test/core/test_loss.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from fastNLP.core.metrics import SeqLabelEvaluator
66
from fastNLP.core.field import TextField, LabelField
77
from fastNLP.core.instance import Instance
8-
98
from fastNLP.core.optimizer import Optimizer
109
from fastNLP.core.trainer import SeqLabelTrainer
1110
from fastNLP.models.sequence_modeling import SeqLabeling
@@ -51,6 +50,8 @@ def test_case_1(self):
5150
print ("loss = %f" % (los))
5251
print ("r = %f" % (r))
5352

53+
self.assertEqual(int(los * 1000), int(r * 1000))
54+
5455
def test_case_2(self):
5556
#验证squash()的正确性
5657
print ("----------------------------------")
@@ -82,12 +83,14 @@ def test_case_2(self):
8283

8384
y = tc.log(y)
8485
los = loss_func(y , gy)
86+
print ("loss = %f" % (los))
8587

8688
r = -log(.3) - log(.3) - log(.1) - log(.3) - log(.7) - log(.1)
8789
r /= 6
88-
print ("loss = %f" % (los))
8990
print ("r = %f" % (r))
9091

92+
self.assertEqual(int(los * 1000), int(r * 1000))
93+
9194
def test_case_3(self):
9295
#验证pack_padded_sequence()的正确性
9396
print ("----------------------------------")
@@ -130,6 +133,8 @@ def test_case_3(self):
130133
r /= 6
131134
print ("r = %f" % (r))
132135

136+
self.assertEqual(int(los * 1000), int(r * 1000))
137+
133138
def test_case_4(self):
134139
#验证unpad()的正确性
135140
print ("----------------------------------")
@@ -169,6 +174,9 @@ def test_case_4(self):
169174
r /= 7
170175
print ("r = %f" % (r))
171176

177+
178+
self.assertEqual(int(los * 1000), int(r * 1000))
179+
172180
def test_case_5(self):
173181
#验证mask()和make_mask()的正确性
174182
print ("----------------------------------")
@@ -217,6 +225,10 @@ def test_case_5(self):
217225
r /= 8
218226
print ("r = %f" % (r))
219227

228+
229+
self.assertEqual(int(los * 1000), int(r * 1000))
230+
self.assertEqual(int(los2 * 1000), int(r * 1000))
231+
220232
def test_case_6(self):
221233
#验证unpad_mask()的正确性
222234
print ("----------------------------------")
@@ -256,6 +268,8 @@ def test_case_6(self):
256268
r /= 7
257269
print ("r = %f" % (r))
258270

271+
self.assertEqual(int(los * 1000), int(r * 1000))
272+
259273
def test_case_7(self):
260274
#验证一些其他东西
261275
print ("----------------------------------")
@@ -295,6 +309,7 @@ def test_case_7(self):
295309
r = - log(.3) - log(.5) - log(.3)
296310
r /= 3
297311
print ("r = %f" % (r))
312+
self.assertEqual(int(los * 1000), int(r * 1000))
298313

299314
if __name__ == "__main__":
300-
unittest.main()
315+
unittest.main()

0 commit comments

Comments
 (0)