Skip to content

Commit abf840c

Browse files
authored
Merge pull request #106 from FFTYYY/master
update loss & a small change in requirements
2 parents 5ec58e3 + 3cadd5a commit abf840c

File tree

4 files changed

+509
-56
lines changed

4 files changed

+509
-56
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ fastNLP is a modular Natural Language Processing system based on PyTorch, for fa
1414
## Requirements
1515

1616
- numpy>=1.14.2
17-
- torch==0.4.0
17+
- torch>=0.4.0
1818
- torchvision>=0.1.8
1919
- tensorboardX
2020

fastNLP/core/loss.py

Lines changed: 192 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,196 @@
11
import torch
22

3+
def squash(predict, truth, **kwargs):
    '''Flatten model output and targets into the layout PyTorch losses expect.

    All leading dimensions of predict are merged into one while its last
    dimension is kept; truth is flattened to a single dimension.

    reshape() is used instead of view() so that non-contiguous inputs (for
    example the result of a transpose) are accepted as well; for contiguous
    inputs the result is the same as with view().

    :param predict : Tensor, model output, [... , tag_size]
    :param truth : Tensor, truth from dataset
    :param **kwargs : extra arguments (ignored)

    :return predict , truth: predict as [-1 , tag_size] & truth as [-1]
    '''
    return predict.reshape(-1, predict.size(-1)), truth.reshape(-1)
13+
14+
def unpad(predict, truth, **kwargs):
    '''Drop padded positions from predict and truth before computing the loss.

    Uses pack_padded_sequence(): the batch is sorted by length (longest
    first), packed, and the packed data is returned. The outputs are therefore
    already flattened the way squash() would flatten them. Elements come out in
    packed order, which is the same for predict and truth.

    :param predict : Tensor, [batch_size , max_len , tag_size]
    :param truth : Tensor, [batch_size , max_len]
    :param **kwargs : extra arguments, kwargs["lens"] is expected to exist
        kwargs["lens"] : list or integer Tensor, [batch_size]
            the i-th element is the true length of the i-th sequence

    :return predict , truth: predict & truth after processing; both are
        returned unchanged when kwargs["lens"] is missing or None
    '''
    lens = kwargs.get("lens")
    if lens is None:
        return predict, truth

    # Normalise the lengths to a CPU LongTensor. torch.LongTensor(tensor)
    # rejects tensors of another integer dtype or device, and
    # pack_padded_sequence() expects tensor lengths to live on the CPU.
    if torch.is_tensor(lens):
        lens = lens.long().cpu()
    else:
        lens = torch.LongTensor(lens)

    # pack_padded_sequence() needs the batch ordered by decreasing length.
    lens, idx = torch.sort(lens, descending=True)

    pack = torch.nn.utils.rnn.pack_padded_sequence
    # Index with idx on the same device as the data being reordered.
    predict = pack(predict[idx.to(predict.device)], lens, batch_first=True).data
    truth = pack(truth[idx.to(truth.device)], lens, batch_first=True).data
    return predict, truth
34+
35+
def unpad_mask(predict, truth, **kwargs):
    '''Drop padded positions from predict and truth with a length mask.

    A mask is built from the sequence lengths with make_mask() and applied
    with mask(), which also flattens the tensors as squash() does.

    :param predict : Tensor, [batch_size , max_len , tag_size]
    :param truth : Tensor, [batch_size , max_len]
    :param **kwargs : extra arguments, kwargs["lens"] is expected to exist
        kwargs["lens"] : list or LongTensor, [batch_size]
            the i-th element is the true length of the i-th sequence

    :return predict , truth: predict & truth after processing; both are
        returned unchanged when kwargs["lens"] is missing or None
    '''
    seq_lens = kwargs.get("lens")
    if seq_lens is not None:
        # The mask covers the padded length of truth (its second dimension).
        keep = make_mask(seq_lens, truth.size()[1])
        return mask(predict, truth, mask=keep)
    return predict, truth
52+
53+
def mask(predict, truth, **kwargs):
    '''Select specific elements from predict and truth.

    This method contains squash(): both tensors are flattened first, then only
    the positions whose mask value is non-zero are kept.

    :param predict : Tensor, [batch_size , max_len , tag_size]
    :param truth : Tensor, [batch_size , max_len]
    :param **kwargs : extra arguments, kwargs["mask"] is expected to exist
        kwargs["mask"] : ByteTensor or BoolTensor, [batch_size , max_len]
            the mask Tensor , the positions that are 1 will be selected

    :return predict , truth: predict as [n_selected , tag_size] & truth as
        [n_selected]; both are returned unchanged when kwargs["mask"] is
        missing or None
    '''
    selector = kwargs.get("mask")
    if selector is None:
        return predict, truth

    predict, truth = squash(predict, truth)

    # Flatten the mask the same way and move it next to the data, so a mask
    # built on the CPU (e.g. by make_mask) also works with GPU tensors.
    selector = selector.reshape(-1).to(predict.device)
    # Newer PyTorch versions expect boolean masks; versions without
    # torch.bool work with the ByteTensor directly.
    if hasattr(torch, "bool"):
        selector = selector.to(torch.bool)

    # Row-wise boolean indexing keeps whole [tag_size] rows in their original
    # order, which is what the permute/masked_select/view sequence produced.
    predict = predict[selector]
    truth = truth[selector]

    return predict, truth
76+
77+
def make_mask(lens, tar_len):
    '''Generate a mask that selects [:lens[i]] for the i-th element.

    Adapted from fastNLP.models.sequence_modeling.seq_mask.

    The mask is computed with a single broadcast comparison, so tar_len == 0
    yields an empty [batch_size , 0] mask instead of an error.

    :param lens : list or integer Tensor, [batch_size]
    :param tar_len : int, the (padded) length each row of the mask covers

    :return mask : ByteTensor (BoolTensor on PyTorch versions whose
        comparisons return bool), [batch_size , tar_len];
        mask[i][j] is set iff j < lens[i]
    '''
    # Accept integer tensors of any dtype/device as well as plain lists.
    if torch.is_tensor(lens):
        lens = lens.long()
    else:
        lens = torch.LongTensor(lens)

    positions = torch.arange(tar_len, dtype=torch.long, device=lens.device)
    # [1 , tar_len] < [batch_size , 1]  ->  [batch_size , tar_len]
    return torch.lt(positions.unsqueeze(0), lens.unsqueeze(1))
90+
91+
# Lookup table from the name of a pre-processing method to the function that
# implements it, so callers can refer to the methods by string.
method_dict = dict(
    squash=squash,
    unpad=unpad,
    unpad_mask=unpad_mask,
    mask=mask,
)
98+
99+
# Maps a normalised loss name (lower case, ending with "loss") to the
# torch.nn class that implements it. A class whose name does not already end
# with "Loss" (NLLLoss2d) gets the suffix appended, so that every key ends
# with "loss".
loss_function_name = {
    (cls_name if cls_name.endswith("Loss") else cls_name + "Loss").lower():
        getattr(torch.nn, cls_name)
    for cls_name in (
        "L1Loss",
        "BCELoss",
        "MSELoss",
        "NLLLoss",
        "KLDivLoss",
        "NLLLoss2d",
        "SmoothL1Loss",
        "SoftMarginLoss",
        "PoissonNLLLoss",
        "MultiMarginLoss",
        "CrossEntropyLoss",
        "BCEWithLogitsLoss",
        "MarginRankingLoss",
        "TripletMarginLoss",
        "HingeEmbeddingLoss",
        "CosineEmbeddingLoss",
        "MultiLabelMarginLoss",
        "MultiLabelSoftMarginLoss",
    )
}
3119

4120
class Loss(object):
    '''A callable object that represents a loss function.

    It wraps a PyTorch loss function selected by name. Before the loss is
    evaluated, predict and truth are passed through a chain of pre-processing
    functions (pre_pro), e.g. to flatten them or to drop padded positions.
    '''

    def __init__(self, loss_name, pre_pro=None, **kwargs):
        '''

        :param loss_name: str or None , the name of loss function
            with None no loss function is created
        :param pre_pro : list of function or str, methods to reform parameters before calculating loss
            the strings will be auto translated to pre-defined functions
            defaults to [squash]
        :param **kwargs: kwargs for torch loss function

        pre_pro functions should have three arguments: predict, truth, **arg
            predict and truth are the necessary parameters of the loss function
            kwargs are the extra parameters passed in when calling the loss function
        pre_pro functions should return two objects, respectively predict and truth after processing

        :raise NotImplementedError: if loss_name is neither None nor a str
        :raise KeyError: if loss_name does not name a supported loss function
        '''
        if loss_name is None:
            # this is useful when Trainer.__init__ performs type check
            self._loss = None
        elif isinstance(loss_name, str):
            self._loss = self._get_loss(loss_name, **kwargs)
        else:
            raise NotImplementedError(
                "loss_name must be a str or None, got %s" % type(loss_name).__name__
            )

        # None stands for the default chain; a list literal in the signature
        # would be a single object shared by every call.
        if pre_pro is None:
            pre_pro = [squash]
        # Names that are not in method_dict become None and are skipped in
        # __call__.
        self.pre_pro = [f if callable(f) else method_dict.get(f) for f in pre_pro]

    def add_pre_pro(self, func):
        '''Append a pre_pro function to the end of the chain.

        :param func: a function or str, methods to reform parameters before calculating loss
            the strings will be auto translated to pre-defined functions;
            a str that names no pre-defined function is ignored
        '''
        if not callable(func):
            func = method_dict.get(func)
            if func is None:
                return
        self.pre_pro.append(func)

    @staticmethod
    def _get_loss(loss_name, **kwargs):
        '''Get loss function from torch.

        The name is matched case-insensitively, underscores are ignored and a
        trailing "loss" is optional, so "cross_entropy", "CrossEntropy" and
        "CrossEntropyLoss" all select torch.nn.CrossEntropyLoss.

        :param loss_name: str, the name of loss function
        :param **kwargs: kwargs for torch loss function
        :return: A callable loss function object
        :raise KeyError: if the name matches no entry of loss_function_name
        '''
        key = loss_name.strip().lower().replace("_", "")
        if not key.endswith("loss"):
            key += "loss"

        if key not in loss_function_name:
            raise KeyError(
                "unknown loss function %r; supported names: %s"
                % (loss_name, ", ".join(sorted(loss_function_name)))
            )
        return loss_function_name[key](**kwargs)

    def get(self):
        '''Return this object itself.

        This method exists only so that existing code which calls get() keeps
        running without errors.
        '''
        return self

    def __call__(self, predict, truth, **kwargs):
        '''Call the loss function.

        predict and truth will be processed by pre_pro methods in order of addition

        :param predict : Tensor, model output
        :param truth : Tensor, truth from dataset
        :param **kwargs : extra arguments, passed to pre_pro functions
            for example, if unpad_mask() is used in pre_pro, there should be a kwarg named lens
        :return: the value returned by the underlying loss function
        :raise TypeError: if this Loss was created with loss_name=None
        '''
        if self._loss is None:
            # Same exception type as calling None would raise, but with a
            # message that names the actual cause.
            raise TypeError(
                "this Loss was created with loss_name=None and has no loss function to call"
            )

        for f in self.pre_pro:
            if f is None:
                continue
            predict, truth = f(predict, truth, **kwargs)

        return self._loss(predict, truth)

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
numpy>=1.14.2
2-
torch==0.4.0
2+
torch>=0.4.0
33
torchvision>=0.1.8
44
tensorboardX

0 commit comments

Comments
 (0)