@@ -89,6 +89,30 @@ def forward(self, input):
         output = self.linear(recurrent)  # batch_size x T x output_size
         return output
 
+class VGG_FeatureExtractor(nn.Module):
+    """ FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf) """
+
+    def __init__(self, input_channel, output_channel=256):
+        super(VGG_FeatureExtractor, self).__init__()
+        self.output_channel = [int(output_channel / 8), int(output_channel / 4),
+                               int(output_channel / 2), output_channel]  # [32, 64, 128, 256] when output_channel=256
+        self.ConvNet = nn.Sequential(
+            nn.Conv2d(input_channel, self.output_channel[0], 3, 1, 1), nn.ReLU(True),
+            nn.MaxPool2d(2, 2),  # e.g. 32x100 -> 16x50
+            nn.Conv2d(self.output_channel[0], self.output_channel[1], 3, 1, 1), nn.ReLU(True),
+            nn.MaxPool2d(2, 2),  # 16x50 -> 8x25
+            nn.Conv2d(self.output_channel[1], self.output_channel[2], 3, 1, 1), nn.ReLU(True),
+            nn.Conv2d(self.output_channel[2], self.output_channel[2], 3, 1, 1), nn.ReLU(True),
+            nn.MaxPool2d((2, 1), (2, 1)),  # 8x25 -> 4x25 (halves height only)
+            nn.Conv2d(self.output_channel[2], self.output_channel[3], 3, 1, 1, bias=False),
+            nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True),
+            nn.Conv2d(self.output_channel[3], self.output_channel[3], 3, 1, 1, bias=False),
+            nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True),
+            nn.MaxPool2d((2, 1), (2, 1)),  # 4x25 -> 2x25
+            nn.Conv2d(self.output_channel[3], self.output_channel[3], 2, 1, 0), nn.ReLU(True))  # 2x25 -> 1x24
+
+    def forward(self, input):
+        return self.ConvNet(input)
+
 class ResNet_FeatureExtractor(nn.Module):
     """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """
 
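Quick shape check for the added extractor (a minimal sketch, not part of the diff; the single-channel 32x100 input and batch size of 2 are illustrative assumptions):

    import torch

    model = VGG_FeatureExtractor(input_channel=1, output_channel=256)
    x = torch.zeros(2, 1, 32, 100)   # batch of 2 grayscale 32x100 text crops
    feat = model(x)
    print(feat.shape)                # torch.Size([2, 256, 1, 24])
    # Height: 32 -> 16 -> 8 -> 4 -> 2 -> 1   (four poolings, then the final 2x2 conv)
    # Width:  100 -> 50 -> 25 -> 25 -> 25 -> 24 (only the first two poolings halve W)

The height is collapsed to 1 so the feature map can be squeezed into a width-major sequence and fed to the BidirectionalLSTM above, while the (2, 1) poolings preserve horizontal resolution for per-character features.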