Skip to content

Commit 3cadd5a

Browse files
committed
fix a iterant lossfuntion , and some error in comments
1 parent 07fb61e commit 3cadd5a

File tree

1 file changed

+8
-9
lines changed

1 file changed

+8
-9
lines changed

fastNLP/core/loss.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ def squash(predict , truth , **kwargs):
55
66
:param predict : Tensor, model output
77
:param truth : Tensor, truth from dataset
8-
:param **kwargs : extract arguments
8+
:param **kwargs : extra arguments
99
1010
:return predict , truth: predict & truth after processing
1111
'''
@@ -18,8 +18,8 @@ def unpad(predict , truth , **kwargs):
1818
1919
:param predict : Tensor, [batch_size , max_len , tag_size]
2020
:param truth : Tensor, [batch_size , max_len]
21-
:param **kwargs : extract arguments, kwargs["lens"] is expected to be exsist
22-
arg["lens"] : list or LongTensor, [batch_size]
21+
:param **kwargs : extra arguments, kwargs["lens"] is expected to be exsist
22+
kwargs["lens"] : list or LongTensor, [batch_size]
2323
the i-th element is true lengths of i-th sequence
2424
2525
:return predict , truth: predict & truth after processing
@@ -39,8 +39,8 @@ def unpad_mask(predict , truth , **kwargs):
3939
4040
:param predict : Tensor, [batch_size , max_len , tag_size]
4141
:param truth : Tensor, [batch_size , max_len]
42-
:param **kwargs : extract arguments, kwargs["lens"] is expected to be exsist
43-
arg["lens"] : list or LongTensor, [batch_size]
42+
:param **kwargs : extra arguments, kwargs["lens"] is expected to be exsist
43+
kwargs["lens"] : list or LongTensor, [batch_size]
4444
the i-th element is true lengths of i-th sequence
4545
4646
:return predict , truth: predict & truth after processing
@@ -56,8 +56,8 @@ def mask(predict , truth , **kwargs):
5656
5757
:param predict : Tensor, [batch_size , max_len , tag_size]
5858
:param truth : Tensor, [batch_size , max_len]
59-
:param **kwargs : extract arguments, kwargs["mask"] is expected to be exsist
60-
arg["mask"] : ByteTensor, [batch_size , max_len]
59+
:param **kwargs : extra arguments, kwargs["mask"] is expected to be exsist
60+
kwargs["mask"] : ByteTensor, [batch_size , max_len]
6161
the mask Tensor , the position that is 1 will be selected
6262
6363
:return predict , truth: predict & truth after processing
@@ -112,7 +112,6 @@ def make_mask(lens , tar_len):
112112
"MarginRankingLoss".lower() : torch.nn.MarginRankingLoss,
113113
"TripletMarginLoss".lower() : torch.nn.TripletMarginLoss,
114114
"HingeEmbeddingLoss".lower() : torch.nn.HingeEmbeddingLoss,
115-
"HingeEmbeddingLoss".lower() : torch.nn.HingeEmbeddingLoss,
116115
"CosineEmbeddingLoss".lower() : torch.nn.CosineEmbeddingLoss,
117116
"MultiLabelMarginLoss".lower() : torch.nn.MultiLabelMarginLoss,
118117
"MultiLabelSoftMarginLoss".lower() : torch.nn.MultiLabelSoftMarginLoss,
@@ -132,7 +131,7 @@ def __init__(self , loss_name , pre_pro = [squash], **kwargs):
132131
133132
pre_pro funcsions should have three arguments: predict, truth, **arg
134133
predict and truth is the necessary parameters in loss function
135-
arg is the extra parameters passed-in when calling loss function
134+
kwargs is the extra parameters passed-in when calling loss function
136135
pre_pro functions should return two objects, respectively predict and truth that after processed
137136
138137
'''

0 commit comments

Comments
 (0)