The GradCAM module provides visual shortcut detection for image models. It generates Grad-CAM attention heatmaps and compares attention patterns between groups of inputs.
::: shortcut_detect.gradcam.GradCAMHeatmapGenerator
    options:
      show_root_heading: true
      show_source: true
```python
GradCAMHeatmapGenerator(
    model: torch.nn.Module,
    target_layer: torch.nn.Module,
    device: str = 'cuda',
    use_guided: bool = False
)
```

| Parameter | Type | Default | Description |
|---|---|---|---|
| `model` | `nn.Module` | required | PyTorch model |
| `target_layer` | `nn.Module` | required | Layer for GradCAM (typically the last convolutional block, e.g. `model.layer4[-1]` for ResNet) |
| `device` | `str` | `'cuda'` | Computation device |
| `use_guided` | `bool` | `False` | Use Guided GradCAM |
```python
def generate(
    input_tensor: torch.Tensor,
    target_class: int = None
) -> np.ndarray
```

Generate a GradCAM heatmap for a single input.

**Parameters:**

| Parameter | Type | Description |
|---|---|---|
| `input_tensor` | `Tensor` | Shape `(C, H, W)` or `(1, C, H, W)` |
| `target_class` | `int` | Class to explain (`None` = the predicted class) |

**Returns:** heatmap array of shape `(H, W)`.
```python
def generate_batch(
    inputs: torch.Tensor,
    target_classes: list[int] = None
) -> list[np.ndarray]
```

Generate heatmaps for a batch of inputs.
```python
def visualize(
    input_tensor: torch.Tensor,
    heatmap: np.ndarray,
    alpha: float = 0.4,
    colormap: str = 'jet',
    save_path: str = None
) -> np.ndarray
```

Overlay a heatmap on the input image.
```python
def compare_groups(
    heatmaps: np.ndarray,
    group_labels: np.ndarray
) -> AttentionOverlapResult
```

Compare attention patterns between groups.

**Returns:** an `AttentionOverlapResult` dataclass.
```python
@dataclass
class AttentionOverlapResult:
    overlap_score: float            # Attention overlap (0-1)
    group_heatmaps: dict            # Average heatmap per group
    divergence_regions: np.ndarray  # Regions with different attention
    summary: str                    # Human-readable summary
```

```python
from shortcut_detect import GradCAMHeatmapGenerator
import torch

# Loading a fully pickled model needs weights_only=False on PyTorch >= 2.6
# (the default became True). Only load files you trust this way.
model = torch.load("model.pth", weights_only=False)
model.eval()
target_layer = model.layer4[-1]

gradcam = GradCAMHeatmapGenerator(model, target_layer)

# image_tensor: a preprocessed image of shape (C, H, W) or (1, C, H, W)
heatmap = gradcam.generate(image_tensor)
gradcam.visualize(image_tensor, heatmap, save_path="attention.png")
```

```python
import numpy as np

# Generate heatmaps for all images
heatmaps = np.stack([gradcam.generate(img) for img in images])

# Compare groups (group_labels: one group label per image, aligned with `images`)
result = gradcam.compare_groups(heatmaps, group_labels)
print(f"Overlap: {result.overlap_score:.2f}")
print(result.summary)
```

```python
from torch.utils.data import DataLoader

all_heatmaps = []
for images, labels in DataLoader(dataset, batch_size=32):
    # generate_batch takes target classes as list[int]
    heatmaps = gradcam.generate_batch(images.cuda(), labels.tolist())
    all_heatmaps.extend(heatmaps)
```

Typical `target_layer` choices:

```python
# ResNet
target_layer = model.layer4[-1]
# VGG
target_layer = model.features[-1]
# DenseNet
target_layer = model.features.denseblock4
# EfficientNet
target_layer = model.features[-1]
```