Skip to content

Commit 1636365

Browse files
fix path (#73)
1 parent 2dd88e5 commit 1636365

File tree

4 files changed

+17
-13
lines changed

4 files changed

+17
-13
lines changed

encoding/models/base.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
class BaseNet(nn.Module):
2626
def __init__(self, nclass, backbone, aux, se_loss, dilated=True, norm_layer=None,
27-
mean=[.485, .456, .406], std=[.229, .224, .225]):
27+
mean=[.485, .456, .406], std=[.229, .224, .225], root='~/.encoding/models'):
2828
super(BaseNet, self).__init__()
2929
self.nclass = nclass
3030
self.aux = aux
@@ -33,11 +33,14 @@ def __init__(self, nclass, backbone, aux, se_loss, dilated=True, norm_layer=None
3333
self.std = std
3434
# copying modules from pretrained models
3535
if backbone == 'resnet50':
36-
self.pretrained = resnet.resnet50(pretrained=True, dilated=dilated, norm_layer=norm_layer)
36+
self.pretrained = resnet.resnet50(pretrained=True, dilated=dilated,
37+
norm_layer=norm_layer, root=root)
3738
elif backbone == 'resnet101':
38-
self.pretrained = resnet.resnet101(pretrained=True, dilated=dilated, norm_layer=norm_layer)
39+
self.pretrained = resnet.resnet101(pretrained=True, dilated=dilated,
40+
norm_layer=norm_layer, root=root)
3941
elif backbone == 'resnet152':
40-
self.pretrained = resnet.resnet152(pretrained=True, dilated=dilated, norm_layer=norm_layer)
42+
self.pretrained = resnet.resnet152(pretrained=True, dilated=dilated,
43+
norm_layer=norm_layer, root=root)
4144
else:
4245
raise RuntimeError('unknown backbone: {}'.format(backbone))
4346
# bilinear upsample options

encoding/models/encnet.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
class EncNet(BaseNet):
2020
def __init__(self, nclass, backbone, aux=True, se_loss=True, lateral=False,
2121
norm_layer=nn.BatchNorm2d, **kwargs):
22-
super(EncNet, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer)
22+
super(EncNet, self).__init__(nclass, backbone, aux, se_loss,
23+
norm_layer=norm_layer, **kwargs)
2324
self.head = EncHead(self.nclass, in_channels=2048, se_loss=se_loss,
2425
lateral=lateral, norm_layer=norm_layer,
2526
up_kwargs=self._up_kwargs)
@@ -142,7 +143,7 @@ def get_encnet(dataset='pascal_voc', backbone='resnet50', pretrained=False,
142143
kwargs['lateral'] = True if dataset.lower() == 'pcontext' else False
143144
# infer number of classes
144145
from ..datasets import datasets, VOCSegmentation, VOCAugSegmentation, ADE20KSegmentation
145-
model = EncNet(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, **kwargs)
146+
model = EncNet(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, root=root, **kwargs)
146147
if pretrained:
147148
from .model_store import get_model_file
148149
model.load_state_dict(torch.load(

encoding/models/fcn.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class FCN(BaseNet):
3939
>>> print(model)
4040
"""
4141
def __init__(self, nclass, backbone, aux=True, se_loss=False, norm_layer=nn.BatchNorm2d, **kwargs):
42-
super(FCN, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer)
42+
super(FCN, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer, **kwargs)
4343
self.head = FCNHead(2048, nclass, norm_layer)
4444
if aux:
4545
self.auxlayer = FCNHead(1024, nclass, norm_layer)
@@ -97,7 +97,7 @@ def get_fcn(dataset='pascal_voc', backbone='resnet50', pretrained=False,
9797
}
9898
# infer number of classes
9999
from ..datasets import datasets, VOCSegmentation, VOCAugSegmentation, ADE20KSegmentation
100-
model = FCN(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, **kwargs)
100+
model = FCN(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, root=root, **kwargs)
101101
if pretrained:
102102
from .model_store import get_model_file
103103
model.load_state_dict(torch.load(
@@ -122,7 +122,7 @@ def get_fcn_resnet50_pcontext(pretrained=False, root='~/.encoding/models', **kwa
122122
>>> model = get_fcn_resnet50_pcontext(pretrained=True)
123123
>>> print(model)
124124
"""
125-
return get_fcn('pcontext', 'resnet50', pretrained, aux=False, **kwargs)
125+
return get_fcn('pcontext', 'resnet50', pretrained, root=root, aux=False, **kwargs)
126126

127127
def get_fcn_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
128128
r"""EncNet-PSP model from the paper `"Context Encoding for Semantic Segmentation"
@@ -141,4 +141,4 @@ def get_fcn_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
141141
>>> model = get_fcn_resnet50_ade(pretrained=True)
142142
>>> print(model)
143143
"""
144-
return get_fcn('ade20k', 'resnet50', pretrained, **kwargs)
144+
return get_fcn('ade20k', 'resnet50', pretrained, root=root, **kwargs)

encoding/models/psp.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
class PSP(BaseNet):
1818
def __init__(self, nclass, backbone, aux=True, se_loss=False, norm_layer=nn.BatchNorm2d, **kwargs):
19-
super(PSP, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer)
19+
super(PSP, self).__init__(nclass, backbone, aux, se_loss, norm_layer=norm_layer, **kwargs)
2020
self.head = PSPHead(2048, nclass, norm_layer, self._up_kwargs)
2121
if aux:
2222
self.auxlayer = FCNHead(1024, nclass, norm_layer)
@@ -59,7 +59,7 @@ def get_psp(dataset='pascal_voc', backbone='resnet50', pretrained=False,
5959
}
6060
# infer number of classes
6161
from ..datasets import datasets, VOCSegmentation, VOCAugSegmentation, ADE20KSegmentation
62-
model = PSP(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, **kwargs)
62+
model = PSP(datasets[dataset.lower()].NUM_CLASS, backbone=backbone, root=root, **kwargs)
6363
if pretrained:
6464
from .model_store import get_model_file
6565
model.load_state_dict(torch.load(
@@ -83,4 +83,4 @@ def get_psp_resnet50_ade(pretrained=False, root='~/.encoding/models', **kwargs):
8383
>>> model = get_psp_resnet50_ade(pretrained=True)
8484
>>> print(model)
8585
"""
86-
return get_psp('ade20k', 'resnet50', pretrained)
86+
return get_psp('ade20k', 'resnet50', pretrained, root=root, **kwargs)

0 commit comments

Comments
 (0)