4 changes: 4 additions & 0 deletions DESCRIPTION
@@ -20,6 +20,9 @@ Authors@R: c(
role = c("ctb"),
email = "[email protected]"
),
person(given = "ANAMASGARD",
role = c("ctb")
),
person(family = "RStudio", role = c("cph"))
)
Description: Provides access to datasets, models and preprocessing
@@ -85,6 +88,7 @@ Collate:
'imagenet.R'
'models-alexnet.R'
'models-convnext.R'
'models-convnext_detection.R'
'models-deeplabv3.R'
'models-efficientnet.R'
'models-efficientnetv2.R'
3 changes: 3 additions & 0 deletions NAMESPACE
@@ -107,12 +107,15 @@ export(mnist_dataset)
export(model_alexnet)
export(model_convnext_base_1k)
export(model_convnext_base_22k)
export(model_convnext_base_detection)
export(model_convnext_large_1k)
export(model_convnext_large_22k)
export(model_convnext_small_22k)
export(model_convnext_small_22k1k)
export(model_convnext_small_detection)
export(model_convnext_tiny_1k)
export(model_convnext_tiny_22k)
export(model_convnext_tiny_detection)
export(model_deeplabv3_resnet101)
export(model_deeplabv3_resnet50)
export(model_efficientnet_b0)
4 changes: 4 additions & 0 deletions NEWS.md
@@ -1,5 +1,9 @@
# torchvision (development version)

## New models

* Added `model_convnext_tiny_detection()`, `model_convnext_small_detection()`, and `model_convnext_base_detection()` for object detection with ConvNeXt + FPN backbones (#262, @ANAMASGARD).

## Bug fixes and improvements

* Remove `.getbatch` method from MNIST as it is providing inconsistent tensor dimensions with `.getitem` due
266 changes: 266 additions & 0 deletions R/models-convnext_detection.R
@@ -0,0 +1,266 @@
#' ConvNeXt Detection Models (Faster R-CNN style)
#'
#' @description
#' Object detection models that use a ConvNeXt backbone with a Feature
#' Pyramid Network (FPN) and the same detection head as the Faster R-CNN
#' models implemented in `model_fasterrcnn_*`.
#'
#' These helpers mirror the architecture used in
#' `model_fasterrcnn_resnet50_fpn()`, but swap the ResNet backbone for
#' ConvNeXt variants.
#'
#' @section Available Models:
#' \itemize{
#' \item `model_convnext_tiny_detection()`
#' \item `model_convnext_small_detection()`
#' \item `model_convnext_base_detection()`
#' }
#'
#' @inheritParams model_fasterrcnn_resnet50_fpn
#' @param pretrained_backbone Logical, if `TRUE` the ConvNeXt backbone
#' weights are loaded from ImageNet pretraining.
#'
#' @note The detection head weights are currently randomly initialized, so
#' predicted bounding boxes are essentially random. For meaningful results,
#' fine-tune the detection head on your own data.
#'
#' @examples
#' \dontrun{
#' library(magrittr)
#' norm_mean <- c(0.485, 0.456, 0.406) # ImageNet normalization constants
#' norm_std <- c(0.229, 0.224, 0.225)
#'
#' # Use a publicly available image
#' wmc <- "https://upload.wikimedia.org/wikipedia/commons/thumb/"
#' url <- "e/ea/Morsan_Normande_vache.jpg/120px-Morsan_Normande_vache.jpg"
#' img <- base_loader(paste0(wmc, url))
#'
#' input <- img %>%
#' transform_to_tensor() %>%
#' transform_resize(c(520, 520)) %>%
#' transform_normalize(norm_mean, norm_std)
#' batch <- input$unsqueeze(1) # Add batch dimension (1, 3, H, W)
#'
#' # ConvNeXt Tiny detection
#' model <- model_convnext_tiny_detection(pretrained_backbone = TRUE)
#' model$eval()
#' pred <- model(batch)$detections
#' num_boxes <- as.integer(pred$boxes$size()[1])
#' topk <- pred$scores$topk(k = min(5, num_boxes))[[2]] # indices of up to 5 top-scoring boxes
#' boxes <- pred$boxes[topk, ]
#' labels <- as.character(as.integer(pred$labels[topk]))
#'
#' # `draw_bounding_boxes()` may fail if box coordinates are not valid.
#' if (num_boxes > 0) {
#' boxed <- draw_bounding_boxes(input, boxes, labels = labels)
#' tensor_image_browse(boxed)
#' }
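#'
#' # Fine-tuning sketch: the detection head starts with random weights and must
#' # be trained before predictions are meaningful. `train_dl` and the training
#' # step are placeholders (not part of torchvision); adapt them to your data
#' # and to the interface documented for the `model_fasterrcnn_*` models.
#' model$train()
#' optimizer <- torch::optim_adam(model$parameters, lr = 1e-4)
#' # coro::loop(for (batch in train_dl) { <standard torch training step> })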
#' }
#'
#' @family object_detection_model
#' @name model_convnext_detection
NULL


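# The three builders below wrap a ConvNeXt classifier into an FPN backbone for
# `fasterrcnn_model()`. The inner body module collects the four stage outputs
# (C2..C5) and `fpn_module()` projects them to 256-channel feature maps; the
# `in_channels` values are the per-stage widths of each ConvNeXt variant, and
# `out_channels` is recorded on the returned backbone for the detection model
# to consume.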
convnext_fpn_backbone_tiny <- function(pretrained_backbone = FALSE, ...) {
convnext <- model_convnext_tiny_1k(pretrained = pretrained_backbone, ...)

convnext_body <- torch::nn_module(
initialize = function() {
self$model <- convnext
},
forward = function(x) {
c2 <- x %>%
self$model$downsample_layers[[1]]() %>%
self$model$stages[[1]]()

c3 <- c2 %>%
self$model$downsample_layers[[2]]() %>%
self$model$stages[[2]]()

c4 <- c3 %>%
self$model$downsample_layers[[3]]() %>%
self$model$stages[[3]]()

c5 <- c4 %>%
self$model$downsample_layers[[4]]() %>%
self$model$stages[[4]]()

list(c2, c3, c4, c5)
}
)

backbone_module <- torch::nn_module(
initialize = function() {
self$body <- convnext_body()
self$fpn <- fpn_module(
in_channels = c(96, 192, 384, 768),
out_channels = 256
)()
},
forward = function(x) {
c2_to_c5 <- self$body(x)
self$fpn(c2_to_c5)
}
)

backbone <- backbone_module()
backbone$out_channels <- 256
backbone
}


convnext_fpn_backbone_small <- function(pretrained_backbone = FALSE, ...) {
convnext <- model_convnext_small_22k(pretrained = pretrained_backbone, ...)

convnext_body <- torch::nn_module(
initialize = function() {
self$model <- convnext
},
forward = function(x) {
c2 <- x %>%
self$model$downsample_layers[[1]]() %>%
self$model$stages[[1]]()

c3 <- c2 %>%
self$model$downsample_layers[[2]]() %>%
self$model$stages[[2]]()

c4 <- c3 %>%
self$model$downsample_layers[[3]]() %>%
self$model$stages[[3]]()

c5 <- c4 %>%
self$model$downsample_layers[[4]]() %>%
self$model$stages[[4]]()

list(c2, c3, c4, c5)
}
)

backbone_module <- torch::nn_module(
initialize = function() {
self$body <- convnext_body()
self$fpn <- fpn_module(
in_channels = c(96, 192, 384, 768),
out_channels = 256
)()
},
forward = function(x) {
c2_to_c5 <- self$body(x)
self$fpn(c2_to_c5)
}
)

backbone <- backbone_module()
backbone$out_channels <- 256
backbone
}


convnext_fpn_backbone_base <- function(pretrained_backbone = FALSE, ...) {
convnext <- model_convnext_base_1k(pretrained = pretrained_backbone, ...)

convnext_body <- torch::nn_module(
initialize = function() {
self$model <- convnext
},
forward = function(x) {
c2 <- x %>%
self$model$downsample_layers[[1]]() %>%
self$model$stages[[1]]()

c3 <- c2 %>%
self$model$downsample_layers[[2]]() %>%
self$model$stages[[2]]()

c4 <- c3 %>%
self$model$downsample_layers[[3]]() %>%
self$model$stages[[3]]()

c5 <- c4 %>%
self$model$downsample_layers[[4]]() %>%
self$model$stages[[4]]()

list(c2, c3, c4, c5)
}
)

backbone_module <- torch::nn_module(
initialize = function() {
self$body <- convnext_body()
self$fpn <- fpn_module(
in_channels = c(128, 256, 512, 1024),
out_channels = 256
)()
},
forward = function(x) {
c2_to_c5 <- self$body(x)
self$fpn(c2_to_c5)
}
)

backbone <- backbone_module()
backbone$out_channels <- 256
backbone
}


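# Shared argument check used by the exported constructors below.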
validate_convnext_num_classes <- function(num_classes) {
if (num_classes <= 0) {
cli_abort("{.var num_classes} must be positive")
}
}


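# Exported constructors: each builds the FPN backbone for its ConvNeXt variant
# and hands it to the shared Faster R-CNN detection head. The `num_classes`
# default of 91 matches the COCO label space.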
#' @describeIn model_convnext_detection ConvNeXt Tiny with FPN detection head
#' @export
model_convnext_tiny_detection <- function(num_classes = 91,
pretrained_backbone = FALSE,
...) {
validate_convnext_num_classes(num_classes)

backbone <- convnext_fpn_backbone_tiny(
pretrained_backbone = pretrained_backbone,
...
)

model <- fasterrcnn_model(backbone, num_classes = num_classes)()
model
}


#' @describeIn model_convnext_detection ConvNeXt Small with FPN detection head
#' @export
model_convnext_small_detection <- function(num_classes = 91,
pretrained_backbone = FALSE,
...) {
validate_convnext_num_classes(num_classes)

backbone <- convnext_fpn_backbone_small(
pretrained_backbone = pretrained_backbone,
...
)

model <- fasterrcnn_model(backbone, num_classes = num_classes)()
model
}


#' @describeIn model_convnext_detection ConvNeXt Base with FPN detection head
#' @export
model_convnext_base_detection <- function(num_classes = 91,
pretrained_backbone = FALSE,
...) {
validate_convnext_num_classes(num_classes)

backbone <- convnext_fpn_backbone_base(
pretrained_backbone = pretrained_backbone,
...
)

model <- fasterrcnn_model(backbone, num_classes = num_classes)()
model
}


2 changes: 1 addition & 1 deletion R/models-faster_rcnn.R
@@ -429,7 +429,7 @@ fasterrcnn_model <- function(backbone, num_classes) {
)
)
}
) # <- Removed the () here
)
}


2 changes: 1 addition & 1 deletion R/vision_utils.R
@@ -135,7 +135,7 @@ draw_bounding_boxes.torch_tensor <- function(x,
type_error("`x` should be of dtype `torch_uint8` or `torch_float`")
}
if ((boxes[, 1] >= boxes[, 3])$any() %>% as.logical() || (boxes[, 2] >= boxes[, 4])$any() %>% as.logical()) {
value_error("Boxes need to be in c(xmin, ymin, xmax, ymax) format. Use torchvision$ops$box_convert to convert them")
value_error("Boxes need to be in c(xmin, ymin, xmax, ymax) format. Use `box_convert()` to convert them")
}
num_boxes <- boxes$shape[1]
if (num_boxes == 0) {