From 9e10eeeaddf36a7d21d8fdb6016c51b1dc069a5e Mon Sep 17 00:00:00 2001 From: jackyjinjing Date: Mon, 2 Sep 2024 03:04:40 +0000 Subject: [PATCH 1/2] add clip example --- usage_examples/clip_example | 157 ++++++++++++++++++++++++++++++++++++ 1 file changed, 157 insertions(+) create mode 100644 usage_examples/clip_example diff --git a/usage_examples/clip_example b/usage_examples/clip_example new file mode 100644 index 00000000..7ae31515 --- /dev/null +++ b/usage_examples/clip_example @@ -0,0 +1,157 @@ +import argparse +import cv2 +import numpy as np +import torch +from torch import nn +import open_clip + +from pytorch_grad_cam import GradCAM, \ + ScoreCAM, \ + GradCAMPlusPlus, \ + AblationCAM, \ + XGradCAM, \ + EigenCAM, \ + EigenGradCAM, \ + LayerCAM, \ + FullGrad + +from pytorch_grad_cam.utils.image import show_cam_on_image, \ + preprocess_image +from pytorch_grad_cam.ablation_layer import AblationLayerVit + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--use-cuda', action='store_true', default=False, + help='Use NVIDIA GPU acceleration') + parser.add_argument( + '--image-path', + type=str, + default='./examples/both.png', + help='Input image path') + parser.add_argument('--aug_smooth', action='store_true', + help='Apply test time augmentation to smooth the CAM') + parser.add_argument( + '--eigen_smooth', + action='store_true', + help='Reduce noise by taking the first principle componenet' + 'of cam_weights*activations') + + parser.add_argument( + '--method', + type=str, + default='gradcam', + help='Can be gradcam/gradcam++/scorecam/xgradcam/ablationcam') + + args = parser.parse_args() + args.use_cuda = args.use_cuda and torch.cuda.is_available() + if args.use_cuda: + print('Using GPU for acceleration') + else: + print('Using CPU for computation') + + return args + + +def reshape_transform(tensor, height=16, width=16): + result = tensor[:, 1:, :].reshape(tensor.size(0), + height, width, tensor.size(2)) + + # Bring the channels to the first dimension, + # like in CNNs. + result = result.transpose(2, 3).transpose(1, 2) + return result + + +class ImageClassifier(nn.Module): + def __init__(self, num_classes, device="cpu"): + super(ImageClassifier, self).__init__() + + model, _, _ = open_clip.create_model_and_transforms('ViT-L-14', device=device) + + self.encoder = model.visual + + num_features = 1024 + self.encoder.proj = None + + self.classifier = nn.Sequential( + nn.Linear(num_features, 2 * num_features), + nn.ReLU(), + nn.Dropout(0.5), + nn.Linear(2 * num_features, num_classes) + ) + + def forward(self, x): + x = self.encoder(x) + x = self.classifier(x) + return x + + +if __name__ == '__main__': + """ python vit_gradcam.py --image-path + Example usage of using cam-methods on a VIT network. + + """ + + args = get_args() + methods = \ + {"gradcam": GradCAM, + "scorecam": ScoreCAM, + "gradcam++": GradCAMPlusPlus, + "ablationcam": AblationCAM, + "xgradcam": XGradCAM, + "eigencam": EigenCAM, + "eigengradcam": EigenGradCAM, + "layercam": LayerCAM, + "fullgrad": FullGrad} + + if args.method not in list(methods.keys()): + raise Exception(f"method should be one of {list(methods.keys())}") + + model = ImageClassifier(1000) + model.eval() + print(model) + + target_layers = [model.encoder.transformer.resblocks[-1].ln_1] + + if args.method not in methods: + raise Exception(f"Method {args.method} not implemented") + + if args.use_cuda: + model = model.cuda() + + if args.method == "ablationcam": + cam = methods[args.method](model=model, + target_layers=target_layers, + reshape_transform=reshape_transform, + ablation_layer=AblationLayerVit()) + else: + cam = methods[args.method](model=model, + target_layers=target_layers, + reshape_transform=reshape_transform) + + rgb_img = cv2.imread(args.image_path, 1)[:, :, ::-1] + rgb_img = cv2.resize(rgb_img, (224, 224)) + rgb_img = np.float32(rgb_img) / 255 + input_tensor = preprocess_image(rgb_img, mean=[0.5, 0.5, 0.5], + std=[0.5, 0.5, 0.5]) + + # If None, returns the map for the highest scoring category. + # Otherwise, targets the requested category. + targets = None + print(input_tensor.shape) + + # AblationCAM and ScoreCAM have batched implementations. + # You can override the internal batch size for faster computation. + cam.batch_size = 32 + + grayscale_cam = cam(input_tensor=input_tensor, + targets=targets, + eigen_smooth=args.eigen_smooth, + aug_smooth=args.aug_smooth) + + # Here grayscale_cam has only one image in the batch + grayscale_cam = grayscale_cam[0, :] + + cam_image = show_cam_on_image(rgb_img, grayscale_cam) + cv2.imwrite(f'{args.method}_cam.jpg', cam_image) From e0211ffbe16e9e4d2fe7261c932057fd795e0aee Mon Sep 17 00:00:00 2001 From: jackyjinjing Date: Fri, 13 Sep 2024 03:36:22 +0000 Subject: [PATCH 2/2] add clip example --- requirements.txt | 3 +- usage_examples/clip_example | 59 ++++++++++++++++++++++--------------- 2 files changed, 37 insertions(+), 25 deletions(-) diff --git a/requirements.txt b/requirements.txt index 1451e880..ae14dcc1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,5 @@ ttach tqdm opencv-python matplotlib -scikit-learn \ No newline at end of file +scikit-learn +transformers \ No newline at end of file diff --git a/usage_examples/clip_example b/usage_examples/clip_example index 7ae31515..9339d171 100644 --- a/usage_examples/clip_example +++ b/usage_examples/clip_example @@ -1,9 +1,11 @@ import argparse + import cv2 import numpy as np import torch from torch import nn -import open_clip +from transformers import CLIPProcessor, CLIPModel + from pytorch_grad_cam import GradCAM, \ ScoreCAM, \ @@ -29,6 +31,14 @@ def get_args(): type=str, default='./examples/both.png', help='Input image path') + parser.add_argument( + '--labels', + type=str, + nargs='+', + default=["a cat", "a dog", "a car", "a person", "a shoe"], + help='need recognition labels' + ) + parser.add_argument('--aug_smooth', action='store_true', help='Apply test time augmentation to smooth the CAM') parser.add_argument( @@ -64,27 +74,23 @@ def reshape_transform(tensor, height=16, width=16): class ImageClassifier(nn.Module): - def __init__(self, num_classes, device="cpu"): + def __init__(self, labels): super(ImageClassifier, self).__init__() + self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14") + self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14") + self.labels = labels - model, _, _ = open_clip.create_model_and_transforms('ViT-L-14', device=device) - - self.encoder = model.visual + def forward(self, x): + text_inputs = self.processor(text=labels, return_tensors="pt", padding=True) - num_features = 1024 - self.encoder.proj = None + outputs = self.clip(pixel_values=x, input_ids=text_inputs['input_ids'], attention_mask=text_inputs['attention_mask']) - self.classifier = nn.Sequential( - nn.Linear(num_features, 2 * num_features), - nn.ReLU(), - nn.Dropout(0.5), - nn.Linear(2 * num_features, num_classes) - ) + logits_per_image = outputs.logits_per_image + probs = logits_per_image.softmax(dim=1) - def forward(self, x): - x = self.encoder(x) - x = self.classifier(x) - return x + for label, prob in zip(self.labels, probs[0]): + print(f"{label}: {prob:.4f}") + return probs if __name__ == '__main__': @@ -108,11 +114,14 @@ if __name__ == '__main__': if args.method not in list(methods.keys()): raise Exception(f"method should be one of {list(methods.keys())}") - model = ImageClassifier(1000) + labels = args.labels + model = ImageClassifier(labels) + if args.use_cuda: + model.cuda() model.eval() print(model) - target_layers = [model.encoder.transformer.resblocks[-1].ln_1] + target_layers = [model.clip.vision_model.encoder.layers[-1].layer_norm1] if args.method not in methods: raise Exception(f"Method {args.method} not implemented") @@ -120,6 +129,12 @@ if __name__ == '__main__': if args.use_cuda: model = model.cuda() + rgb_img = cv2.imread(args.image_path, 1)[:, :, ::-1] + rgb_img = cv2.resize(rgb_img, (224, 224)) + rgb_img = np.float32(rgb_img) / 255 + input_tensor = preprocess_image(rgb_img, mean=[0.5, 0.5, 0.5], + std=[0.5, 0.5, 0.5]) + if args.method == "ablationcam": cam = methods[args.method](model=model, target_layers=target_layers, @@ -130,11 +145,7 @@ if __name__ == '__main__': target_layers=target_layers, reshape_transform=reshape_transform) - rgb_img = cv2.imread(args.image_path, 1)[:, :, ::-1] - rgb_img = cv2.resize(rgb_img, (224, 224)) - rgb_img = np.float32(rgb_img) / 255 - input_tensor = preprocess_image(rgb_img, mean=[0.5, 0.5, 0.5], - std=[0.5, 0.5, 0.5]) + # If None, returns the map for the highest scoring category. # Otherwise, targets the requested category.