@@ -5,7 +5,7 @@ def squash(predict , truth , **kwargs):
55
66 :param predict : Tensor, model output
77 :param truth : Tensor, truth from dataset
8- :param **kwargs : extract arguments
8+ :param **kwargs : extra arguments
99
1010 :return predict , truth: predict & truth after processing
1111 '''
@@ -18,8 +18,8 @@ def unpad(predict , truth , **kwargs):
1818
1919 :param predict : Tensor, [batch_size , max_len , tag_size]
2020 :param truth : Tensor, [batch_size , max_len]
21- :param **kwargs : extract arguments, kwargs["lens"] is expected to be exsist
22- arg ["lens"] : list or LongTensor, [batch_size]
21+ :param **kwargs : extra arguments, kwargs["lens"] is expected to be exsist
22+ kwargs ["lens"] : list or LongTensor, [batch_size]
2323 the i-th element is true lengths of i-th sequence
2424
2525 :return predict , truth: predict & truth after processing
@@ -39,8 +39,8 @@ def unpad_mask(predict , truth , **kwargs):
3939
4040 :param predict : Tensor, [batch_size , max_len , tag_size]
4141 :param truth : Tensor, [batch_size , max_len]
42- :param **kwargs : extract arguments, kwargs["lens"] is expected to be exsist
43- arg ["lens"] : list or LongTensor, [batch_size]
42+ :param **kwargs : extra arguments, kwargs["lens"] is expected to be exsist
43+ kwargs ["lens"] : list or LongTensor, [batch_size]
4444 the i-th element is true lengths of i-th sequence
4545
4646 :return predict , truth: predict & truth after processing
@@ -56,8 +56,8 @@ def mask(predict , truth , **kwargs):
5656
5757 :param predict : Tensor, [batch_size , max_len , tag_size]
5858 :param truth : Tensor, [batch_size , max_len]
59- :param **kwargs : extract arguments, kwargs["mask"] is expected to be exsist
60- arg ["mask"] : ByteTensor, [batch_size , max_len]
59+ :param **kwargs : extra arguments, kwargs["mask"] is expected to be exsist
60+ kwargs ["mask"] : ByteTensor, [batch_size , max_len]
6161 the mask Tensor , the position that is 1 will be selected
6262
6363 :return predict , truth: predict & truth after processing
@@ -112,7 +112,6 @@ def make_mask(lens , tar_len):
112112 "MarginRankingLoss" .lower () : torch .nn .MarginRankingLoss ,
113113 "TripletMarginLoss" .lower () : torch .nn .TripletMarginLoss ,
114114 "HingeEmbeddingLoss" .lower () : torch .nn .HingeEmbeddingLoss ,
115- "HingeEmbeddingLoss" .lower () : torch .nn .HingeEmbeddingLoss ,
116115 "CosineEmbeddingLoss" .lower () : torch .nn .CosineEmbeddingLoss ,
117116 "MultiLabelMarginLoss" .lower () : torch .nn .MultiLabelMarginLoss ,
118117 "MultiLabelSoftMarginLoss" .lower () : torch .nn .MultiLabelSoftMarginLoss ,
@@ -132,7 +131,7 @@ def __init__(self , loss_name , pre_pro = [squash], **kwargs):
132131
133132 pre_pro funcsions should have three arguments: predict, truth, **arg
134133 predict and truth is the necessary parameters in loss function
135- arg is the extra parameters passed-in when calling loss function
134+ kwargs is the extra parameters passed-in when calling loss function
136135 pre_pro functions should return two objects, respectively predict and truth that after processed
137136
138137 '''
0 commit comments