Description
In the SegFormer paper, the architecture diagram looks like this:
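For reference, the paper defines the all-MLP decoder with the equations below. $F_i$ is the $i$-th of the four hierarchical feature maps produced by the Mix Transformer (MiT) encoder, at 1/4 to 1/32 of the input resolution, with $C_i$ channels. $C$ is the shared embedding width and $N_{cls}$ is the number of classes. This is a transcription for convenience. Note that the decoder sees the encoder only through $F_i$ and $C_i$.

$$
\begin{aligned}
\hat{F}_i &= \mathrm{Linear}(C_i, C)(F_i), \quad \forall i \\
\hat{F}_i &= \mathrm{Upsample}\left(\tfrac{H}{4} \times \tfrac{W}{4}\right)(\hat{F}_i), \quad \forall i \\
F &= \mathrm{Linear}(4C, C)\left(\mathrm{Concat}(\hat{F}_i)\right), \quad \forall i \\
M &= \mathrm{Linear}(C, N_{cls})(F)
\end{aligned}
$$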

But in this repo, the code is written as below. Why does it take an `encoder_name` argument? The original design doesn't have a separate CNN feature extractor.
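```python
from typing import Any, Callable, Optional, Union

# Import paths as in segmentation_models_pytorch >= 0.4
from segmentation_models_pytorch.base import (
    ClassificationHead,
    SegmentationHead,
    SegmentationModel,
)
from segmentation_models_pytorch.base.hub_mixin import supports_config_loading
from segmentation_models_pytorch.decoders.segformer.decoder import SegformerDecoder
from segmentation_models_pytorch.encoders import get_encoder


class Segformer(SegmentationModel):
    @supports_config_loading
    def __init__(
        self,
        encoder_name: str = "resnet34",  # default backbone is a CNN
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        decoder_segmentation_channels: int = 256,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[Union[str, Callable]] = None,
        upsampling: int = 4,
        aux_params: Optional[dict] = None,
        **kwargs: dict[str, Any],
    ):
        super().__init__()
        # Any registered encoder can be plugged in by name
        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
            **kwargs,
        )
        # The decoder is built from the channel counts the chosen encoder reports
        self.decoder = SegformerDecoder(
            encoder_channels=self.encoder.out_channels,
            encoder_depth=encoder_depth,
            segmentation_channels=decoder_segmentation_channels,
        )
        # 1x1 conv to class logits, then upsample (x4 by default) to input size
        self.segmentation_head = SegmentationHead(
            in_channels=decoder_segmentation_channels,
            out_channels=classes,
            activation=activation,
            kernel_size=1,
            upsampling=upsampling,
        )
        # Optional image-level classification head on the deepest encoder features
        if aux_params is not None:
            self.classification_head = ClassificationHead(
                in_channels=self.encoder.out_channels[-1], **aux_params
            )
        else:
            self.classification_head = None
        self.name = "segformer-{}".format(encoder_name)
        self.initialize()
```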
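To illustrate why the backbone can be swapped, here is a shape-level sketch in pure Python (no PyTorch) of the paper's decoder equations. `trace_decoder` is a hypothetical helper, not part of smp. It assumes an encoder that exposes four stages at strides 4, 8, 16 and 32, and uses `embed_dim=256` to mirror the `decoder_segmentation_channels` default above. It does not claim that smp's `SegformerDecoder` performs these exact steps.

```python
from typing import List, Sequence, Tuple

Shape = Tuple[int, int, int]  # (channels, height, width)


def trace_decoder(
    encoder_channels: Sequence[int],
    image_hw: Tuple[int, int],
    embed_dim: int = 256,
    num_classes: int = 1,
) -> List[Tuple[str, Shape]]:
    """Hypothetical helper: trace shapes through an all-MLP SegFormer-style decoder.

    Assumes the encoder yields four feature maps at strides 4, 8, 16 and 32;
    only their channel counts (C_i) matter, not how they were computed.
    """
    h, w = image_hw
    strides = (4, 8, 16, 32)
    if len(encoder_channels) != len(strides):
        raise ValueError("expected one channel count per stage (4 stages)")
    quarter = (h // 4, w // 4)
    steps: List[Tuple[str, Shape]] = []
    for c_i, s in zip(encoder_channels, strides):
        steps.append(("F_i", (c_i, h // s, w // s)))                   # encoder output
        steps.append(("Linear(C_i, C)", (embed_dim, h // s, w // s)))  # per-pixel projection
        steps.append(("Upsample", (embed_dim, *quarter)))              # to H/4 x W/4
    steps.append(("Concat", (len(strides) * embed_dim, *quarter)))     # 4C channels
    steps.append(("Linear(4C, C)", (embed_dim, *quarter)))             # fuse
    steps.append(("Linear(C, N_cls)", (num_classes, *quarter)))        # per-pixel logits
    return steps


# MiT-B0 stage widths (from the paper) vs. ResNet-34's last four stages.
mit_b0 = trace_decoder([32, 64, 160, 256], (512, 512), num_classes=19)
resnet34 = trace_decoder([64, 128, 256, 512], (512, 512), num_classes=19)
print(mit_b0[-1])    # ('Linear(C, N_cls)', (19, 128, 128))
print(resnet34[-1])  # ('Linear(C, N_cls)', (19, 128, 128))
```

Both backbones produce the same H/4 × W/4 logit map, which the `SegmentationHead` above then brings back to full resolution with `upsampling=4`. The only thing the decoder needs from the encoder is its list of per-stage channel counts, which is what `self.encoder.out_channels` supplies.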