File size: 6,791 Bytes
da716ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
from typing import Callable, List, Optional

import numpy as np
import torch
import tqdm

from pytorch_grad_cam.base_cam import BaseCAM
from pytorch_grad_cam.utils.find_layers import replace_layer_recursive
from pytorch_grad_cam.ablation_layer import AblationLayer
""" Implementation of AblationCAM
https://openaccess.thecvf.com/content_WACV_2020/papers/Desai_Ablation-CAM_Visual_Explanations_for_Deep_Convolutional_Network_via_Gradient-free_Localization_WACV_2020_paper.pdf
Ablate individual activations, and then measure the drop in the target score.
In the current implementation, the target layer activations are cached, so they won't be re-computed.
However, layers before it, if any, will not be cached.
This means that if the target layer is a large block, for example model.features (in vgg), there will
be a large saving in run time.
Since we have to go over many channels and ablate them, and every channel ablation requires a forward pass,
it would be nice if we could avoid doing that for channels that won't contribute anyway, making it much faster.
The parameter ratio_channels_to_ablate controls how many channels should be ablated, using an experimental method
(to be improved). The default 1.0 value means that all channels will be ablated.
"""
class AblationCAM(BaseCAM):
    """Gradient-free CAM that weights channels by the score drop on ablation.

    For every image the target layer is temporarily replaced with an
    ablation layer, channels are ablated in batches, and the relative drop
    of the target score is used as the weight of each channel.
    """

    def __init__(self,
                 model: torch.nn.Module,
                 target_layers: List[torch.nn.Module],
                 use_cuda: bool = False,
                 reshape_transform: Optional[Callable] = None,
                 ablation_layer: Optional[torch.nn.Module] = None,
                 batch_size: int = 32,
                 ratio_channels_to_ablate: float = 1.0) -> None:
        """Configure the CAM.

        Args:
            model: The model to explain.
            target_layers: Layers whose activations are ablated.
            use_cuda: Forwarded to ``BaseCAM``.
            reshape_transform: Forwarded to ``BaseCAM``.
            ablation_layer: Layer that is swapped in for the target layer
                while ablating. When ``None`` a fresh ``AblationLayer()`` is
                created for this instance.
            batch_size: Number of ablated copies of an image evaluated in
                one forward pass.
            ratio_channels_to_ablate: Passed to the ablation layer when it
                selects channels; 1.0 is the default.
        """
        super().__init__(model,
                         target_layers,
                         use_cuda,
                         reshape_transform,
                         uses_gradients=False)
        self.batch_size = batch_size
        # The ablation layer carries per-run state, so a default instance
        # must not be shared between AblationCAM objects.
        if ablation_layer is None:
            ablation_layer = AblationLayer()
        self.ablation_layer = ablation_layer
        self.ratio_channels_to_ablate = ratio_channels_to_ablate

    def save_activation(self, module, input, output) -> None:
        """Forward hook that stores the raw output of the target layer."""
        self.activations = output

    def assemble_ablation_scores(self,
                                 new_scores: list,
                                 original_score: float,
                                 ablated_channels: np.ndarray,
                                 number_of_channels: int) -> list:
        """Build one score per channel.

        Channels listed in ``ablated_channels`` get their measured score from
        ``new_scores`` (same order as ``ablated_channels``); every other
        channel gets ``original_score``.

        Returns:
            A list of length ``number_of_channels``.
        """
        # Sort both arrays by channel index so they can be merged with a
        # single forward scan over the channel range.
        sorted_indices = np.argsort(ablated_channels)
        ablated_channels = ablated_channels[sorted_indices]
        new_scores = np.float32(new_scores)[sorted_indices]

        index = 0
        result = []
        for channel in range(number_of_channels):
            if (index < len(ablated_channels)
                    and ablated_channels[index] == channel):
                weight = new_scores[index]
                index += 1
            else:
                # Skipped channel: pretend ablating it changed nothing.
                weight = original_score
            result.append(weight)
        return result

    def get_cam_weights(self,
                        input_tensor: torch.Tensor,
                        target_layer: torch.nn.Module,
                        targets: List[Callable],
                        activations: torch.Tensor,
                        grads: torch.Tensor) -> np.ndarray:
        """Compute the per-channel weights for one target layer.

        Args:
            input_tensor: Batch of input images.
            target_layer: Layer that is ablated.
            targets: One scoring callable per image in the batch.
            activations: Target layer activations; only the first two
                dimensions (batch, channels) are used for shapes here.
            grads: Unused; this method does not rely on gradients.

        Returns:
            Array of shape ``activations.shape[:2]`` holding
            ``(original_score - ablated_score) / original_score``.
        """
        # Do a forward pass, compute the target scores, and cache the
        # activations. The hook is removed even if the forward pass fails.
        handle = target_layer.register_forward_hook(self.save_activation)
        try:
            with torch.no_grad():
                outputs = self.model(input_tensor)
        finally:
            handle.remove()
        original_scores = np.float32(
            [target(output).cpu().item()
             for target, output in zip(targets, outputs)])

        # Replace the layer with the ablation layer. The finally clause
        # below restores it, so the original model is left unchanged even
        # when an ablation pass raises.
        ablation_layer = self.ablation_layer
        replace_layer_recursive(self.model, target_layer, ablation_layer)
        try:
            number_of_channels = activations.shape[1]
            weights = []
            # This is a "gradient free" method, so no gradients are needed.
            with torch.no_grad():
                # Loop over each of the batch images and ablate its
                # activations.
                for batch_index, (target, tensor) in enumerate(
                        zip(targets, input_tensor)):
                    new_scores = []
                    batch_tensor = tensor.repeat(self.batch_size, 1, 1, 1)

                    # Check which channels should be ablated. Normally this
                    # is all channels, but a low ratio_channels_to_ablate
                    # can be used to speed things up.
                    channels_to_ablate = \
                        ablation_layer.activations_to_be_ablated(
                            activations[batch_index, :],
                            self.ratio_channels_to_ablate)
                    number_channels_to_ablate = len(channels_to_ablate)

                    for i in tqdm.tqdm(range(0,
                                             number_channels_to_ablate,
                                             self.batch_size)):
                        # The last batch may be smaller than batch_size.
                        if i + self.batch_size > number_channels_to_ablate:
                            batch_tensor = batch_tensor[
                                :(number_channels_to_ablate - i)]
                        current_batch_size = batch_tensor.size(0)

                        # Change the state of the ablation layer so it
                        # ablates the next channels.
                        # TBD: Move this into the ablation layer forward
                        # pass.
                        ablation_layer.set_next_batch(
                            input_batch_index=batch_index,
                            activations=self.activations,
                            num_channels_to_ablate=current_batch_size)
                        new_scores.extend(
                            target(o).cpu().item()
                            for o in self.model(batch_tensor))
                        # Drop the indices that were consumed by this batch.
                        ablation_layer.indices = \
                            ablation_layer.indices[current_batch_size:]

                    weights.extend(self.assemble_ablation_scores(
                        new_scores,
                        original_scores[batch_index],
                        channels_to_ablate,
                        number_of_channels))
        finally:
            # Replace the model back to the original state.
            replace_layer_recursive(self.model, ablation_layer, target_layer)

        weights = np.float32(weights)
        weights = weights.reshape(activations.shape[:2])
        original_scores = original_scores[:, None]
        # NOTE: an original score of exactly 0 yields inf/nan weights here.
        weights = (original_scores - weights) / original_scores
        return weights
|