Skip to content

Commit 2cd2dae

Browse files
committed
update loss
1 parent 5ec58e3 commit 2cd2dae

File tree

3 files changed

+494
-55
lines changed

3 files changed

+494
-55
lines changed

fastNLP/core/loss.py

Lines changed: 193 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,197 @@
11
import torch
22

3+
def squash(predict, truth, **kwargs):
    """Flatten model output and targets into the layout PyTorch losses expect.

    :param predict: Tensor, model output; all leading dimensions are merged so
        the result has shape ``[N, predict.size(-1)]``
    :param truth: Tensor, truth from dataset; flattened to shape ``[N]``
    :param **kwargs: extra arguments, accepted only so that every
        pre-processing function shares one signature; ignored here

    :return predict, truth: predict & truth after processing
    """
    # reshape() returns a view whenever view() would have worked and only
    # copies for non-contiguous input (e.g. after transpose/permute), where
    # view() raises a RuntimeError.
    return predict.reshape(-1, predict.size(-1)), truth.reshape(-1)
13+
14+
def unpad(predict, truth, **kwargs):
    """Drop padded positions from a batch so the loss only sees real steps.

    Uses pack_padded_sequence(). The packed data is already flat, so the
    result has the same layout squash() produces.

    :param predict: Tensor, [batch_size, max_len, tag_size]
    :param truth: Tensor, [batch_size, max_len]
    :param **kwargs: extra arguments, kwargs["lens"] is expected to exist
        kwargs["lens"]: list or integer Tensor, [batch_size]
            the i-th element is the true length of the i-th sequence

    :return predict, truth: predict & truth after processing, shaped
        [sum(lens), tag_size] and [sum(lens)]. Rows follow the packed order:
        sequences sorted by decreasing length, then step by step.
        If kwargs["lens"] is missing or None the inputs are returned as-is.
    """
    lens = kwargs.get("lens")
    if lens is None:
        return predict, truth

    if torch.is_tensor(lens):
        # torch.LongTensor(tensor) only accepts a CPU int64 tensor; convert
        # explicitly so int32 or GPU length tensors work as well.
        lens = lens.detach().cpu().long()
    else:
        lens = torch.LongTensor(lens)

    # Sort longest-first and reorder the batch to match, as packing requires.
    lens, idx = torch.sort(lens, descending=True)
    pack = torch.nn.utils.rnn.pack_padded_sequence
    predict = pack(predict[idx], lens, batch_first=True).data
    truth = pack(truth[idx], lens, batch_first=True).data
    return predict, truth
34+
35+
def unpad_mask(predict, truth, **kwargs):
    """Drop padded positions from a batch by masking instead of packing.

    Builds a length mask with make_mask() and delegates the selection to
    mask(), which also flattens the tensors.

    :param predict: Tensor, [batch_size, max_len, tag_size]
    :param truth: Tensor, [batch_size, max_len]
    :param **kwargs: extra arguments, kwargs["lens"] is expected to exist
        kwargs["lens"]: list or LongTensor, [batch_size]
            the i-th element is the true length of the i-th sequence

    :return predict, truth: predict & truth after processing; the inputs are
        returned unchanged when kwargs["lens"] is missing or None
    """
    seq_lens = kwargs.get("lens")
    if seq_lens is None:
        return predict, truth

    # The mask must span the padded width of truth (its second dimension).
    max_len = truth.size(1)
    return mask(predict, truth, mask=make_mask(seq_lens, max_len))
52+
53+
def mask(predict, truth, **kwargs):
    """Keep only the positions selected by a mask, flattening batch and time.

    :param predict: Tensor, [batch_size, max_len, tag_size]
    :param truth: Tensor, [batch_size, max_len]
    :param **kwargs: extra arguments, kwargs["mask"] is expected to exist
        kwargs["mask"]: Tensor, [batch_size, max_len]
            the mask Tensor; every non-zero (true) position is selected

    :return predict, truth: predict & truth after processing, shaped
        [n_selected, tag_size] and [n_selected]. If kwargs["mask"] is missing
        or None the inputs are returned as-is.
    """
    selector = kwargs.get("mask")
    if selector is None:
        return predict, truth

    # Same flattening squash() performs: one row per (batch, step) position.
    predict = predict.reshape(-1, predict.size(-1))
    truth = truth.reshape(-1)

    # Comparing with zero yields the comparison dtype of the installed torch,
    # so byte, bool and integer masks are all accepted.
    keep = selector.reshape(-1) != 0

    # Row indexing replaces the permute/masked_select/view/permute chain and
    # also works when nothing is selected (the result is simply empty).
    return predict[keep], truth[keep]
76+
77+
def make_mask(lens, tar_len):
    """Generate a mask that selects [:lens[i]] for the i-th element.

    Adapted from fastNLP.models.sequence_modeling.seq_mask

    :param lens: list or integer Tensor, [batch_size]
    :param tar_len: int, number of positions (columns) in the mask

    :return mask: Tensor, [batch_size, tar_len]; entry (i, j) is set exactly
        when j < lens[i]
    """
    if torch.is_tensor(lens):
        # torch.LongTensor(tensor) rejects other integer dtypes and devices.
        lens = lens.long()
    else:
        lens = torch.LongTensor(lens)

    positions = torch.arange(tar_len, dtype=torch.long, device=lens.device)
    # Broadcast [batch_size, 1] against [1, tar_len]. A single comparison
    # replaces one torch.ge call per position and also handles tar_len == 0.
    return lens.unsqueeze(1) > positions.unsqueeze(0)
90+
91+
# Registry of the built-in pre-processing steps, so callers can refer to them
# by name instead of passing the function objects around.
method_dict = dict(
    squash=squash,
    unpad=unpad,
    unpad_mask=unpad_mask,
    mask=mask,
)
98+
99+
# Lookup table from a normalised loss name (lower case, no underscores,
# always ending in "loss") to the torch.nn class that implements it.
loss_function_name = {
    # Every key must end with "loss" because Loss._get_loss appends that
    # suffix before the lookup; hence "NLLLoss2d" is stored as "nllloss2dloss".
    (name.lower() if name.lower().endswith("loss") else name.lower() + "loss"):
        getattr(torch.nn, name)
    for name in (
        "L1Loss",
        "BCELoss",
        "MSELoss",
        "NLLLoss",
        "KLDivLoss",
        "NLLLoss2d",
        "SmoothL1Loss",
        "SoftMarginLoss",
        "PoissonNLLLoss",
        "MultiMarginLoss",
        "CrossEntropyLoss",
        "BCEWithLogitsLoss",
        "MarginRankingLoss",
        "TripletMarginLoss",
        "HingeEmbeddingLoss",
        "CosineEmbeddingLoss",
        "MultiLabelMarginLoss",
        "MultiLabelSoftMarginLoss",
    )
    # Skip classes the installed torch does not provide, so that a single
    # missing (e.g. deprecated) class cannot break importing this module.
    if hasattr(torch.nn, name)
}
3120

4121
class Loss(object):
    """A callable object that represents a loss function.

    It wraps a torch.nn loss looked up by name and, before every call, runs
    predict and truth through a list of pre-processing functions.
    """

    # NOTE: the default list is only read; __init__ always builds a new list
    # for self.pre_pro, so the shared default is never mutated.
    def __init__(self, loss_name, pre_pro=[squash], **kwargs):
        """
        :param loss_name: str or None, the name of the loss function
        :param pre_pro: list of function or str, methods to reform parameters
            before calculating loss; strings are translated to the pre-defined
            functions in method_dict (unknown strings become None and are
            skipped when the loss is called)
        :param **kwargs: kwargs for the torch loss function

        pre_pro functions take three arguments: predict, truth, **arg
            predict and truth are the necessary parameters of a loss function
            arg holds the extra parameters passed in when calling the loss
        pre_pro functions return two objects: predict and truth after
        processing

        :raise NotImplementedError: if loss_name is neither None nor a str
        :raise KeyError: if loss_name does not name a known loss function
        """
        if loss_name is None:
            # this is useful when Trainer.__init__ performs type check
            self._loss = None
        elif isinstance(loss_name, str):
            self._loss = self._get_loss(loss_name, **kwargs)
        else:
            raise NotImplementedError(
                "loss_name must be a str or None, got %s"
                % type(loss_name).__name__
            )

        self.pre_pro = [
            f if callable(f) else method_dict.get(f) for f in pre_pro
        ]

    def add_pre_pro(self, func):
        """Append a pre_pro function.

        :param func: a function or str, method to reform parameters before
            calculating loss; a string is translated to the pre-defined
            function of that name, and an unknown string is ignored
        """
        if not callable(func):
            func = method_dict.get(func)
            if func is None:
                return
        self.pre_pro.append(func)

    @staticmethod
    def _get_loss(loss_name, **kwargs):
        """Get a loss function from torch.

        The name is matched case-insensitively, underscores are ignored and
        a trailing "loss" is optional, so "cross_entropy", "CrossEntropy" and
        "CrossEntropyLoss" all resolve to the same class.

        :param loss_name: str, the name of the loss function
        :param **kwargs: kwargs for the torch loss function
        :return: a callable loss function object
        :raise KeyError: if the name does not match any known loss function
        """
        key = loss_name.strip().lower().replace("_", "")
        if not key.endswith("loss"):
            key += "loss"

        try:
            loss_cls = loss_function_name[key]
        except KeyError:
            raise KeyError(
                "unknown loss function %r; expected one of: %s"
                % (loss_name, ", ".join(sorted(loss_function_name)))
            )
        return loss_cls(**kwargs)

    def get(self):
        """Return self; kept only so existing code calling get() still runs."""
        return self

    def __call__(self, predict, truth, **kwargs):
        """Call the loss function.

        predict and truth are processed by the pre_pro methods in the order
        they were added.

        :param predict: Tensor, model output
        :param truth: Tensor, truth from dataset
        :param **kwargs: extra arguments, passed to the pre_pro functions;
            for example, if unpad_mask() is used in pre_pro, there should be
            a kwarg named lens
        :return: the value returned by the wrapped loss function
        :raise TypeError: if the object was created with loss_name=None
        """
        if self._loss is None:
            # Same exception type as calling None directly, with a message
            # that points at the actual cause.
            raise TypeError(
                "this Loss was created with loss_name=None "
                "and cannot be called"
            )

        for f in self.pre_pro:
            if f is None:
                # placeholder left by an unknown pre_pro name
                continue
            predict, truth = f(predict, truth, **kwargs)

        return self._loss(predict, truth)

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
numpy>=1.14.2
2-
torch==0.4.0
2+
torch>=0.4.0
33
torchvision>=0.1.8
44
tensorboardX

0 commit comments

Comments
 (0)