@@ -40,23 +40,39 @@ def __init__(
4040 aa_layer = None ,
4141 drop_block = None ,
4242 drop_path = None ,
43+ device = None ,
44+ dtype = None ,
4345 ):
46+ dd = {'device' : device , 'dtype' : dtype }
4447 super (SelectiveKernelBasic , self ).__init__ ()
4548
4649 sk_kwargs = sk_kwargs or {}
47- conv_kwargs = dict (act_layer = act_layer , norm_layer = norm_layer )
50+ conv_kwargs = dict (act_layer = act_layer , norm_layer = norm_layer , ** dd )
4851 assert cardinality == 1 , 'BasicBlock only supports cardinality of 1'
4952 assert base_width == 64 , 'BasicBlock doest not support changing base width'
5053 first_planes = planes // reduce_first
5154 outplanes = planes * self .expansion
5255 first_dilation = first_dilation or dilation
5356
5457 self .conv1 = SelectiveKernel (
55- inplanes , first_planes , stride = stride , dilation = first_dilation ,
56- aa_layer = aa_layer , drop_layer = drop_block , ** conv_kwargs , ** sk_kwargs )
58+ inplanes ,
59+ first_planes ,
60+ stride = stride ,
61+ dilation = first_dilation ,
62+ aa_layer = aa_layer ,
63+ drop_layer = drop_block ,
64+ ** conv_kwargs ,
65+ ** sk_kwargs ,
66+ )
5767 self .conv2 = ConvNormAct (
58- first_planes , outplanes , kernel_size = 3 , dilation = dilation , apply_act = False , ** conv_kwargs )
59- self .se = create_attn (attn_layer , outplanes )
68+ first_planes ,
69+ outplanes ,
70+ kernel_size = 3 ,
71+ dilation = dilation ,
72+ apply_act = False ,
73+ ** conv_kwargs ,
74+ )
75+ self .se = create_attn (attn_layer , outplanes , ** dd )
6076 self .act = act_layer (inplace = True )
6177 self .downsample = downsample
6278 self .drop_path = drop_path
@@ -101,22 +117,33 @@ def __init__(
101117 aa_layer = None ,
102118 drop_block = None ,
103119 drop_path = None ,
120+ device = None ,
121+ dtype = None ,
104122 ):
123+ dd = {'device' : device , 'dtype' : dtype }
105124 super (SelectiveKernelBottleneck , self ).__init__ ()
106125
107126 sk_kwargs = sk_kwargs or {}
108- conv_kwargs = dict (act_layer = act_layer , norm_layer = norm_layer )
127+ conv_kwargs = dict (act_layer = act_layer , norm_layer = norm_layer , ** dd )
109128 width = int (math .floor (planes * (base_width / 64 )) * cardinality )
110129 first_planes = width // reduce_first
111130 outplanes = planes * self .expansion
112131 first_dilation = first_dilation or dilation
113132
114133 self .conv1 = ConvNormAct (inplanes , first_planes , kernel_size = 1 , ** conv_kwargs )
115134 self .conv2 = SelectiveKernel (
116- first_planes , width , stride = stride , dilation = first_dilation , groups = cardinality ,
117- aa_layer = aa_layer , drop_layer = drop_block , ** conv_kwargs , ** sk_kwargs )
135+ first_planes ,
136+ width ,
137+ stride = stride ,
138+ dilation = first_dilation ,
139+ groups = cardinality ,
140+ aa_layer = aa_layer ,
141+ drop_layer = drop_block ,
142+ ** conv_kwargs ,
143+ ** sk_kwargs ,
144+ )
118145 self .conv3 = ConvNormAct (width , outplanes , kernel_size = 1 , apply_act = False , ** conv_kwargs )
119- self .se = create_attn (attn_layer , outplanes )
146+ self .se = create_attn (attn_layer , outplanes , ** dd )
120147 self .act = act_layer (inplace = True )
121148 self .downsample = downsample
122149 self .drop_path = drop_path
0 commit comments