Image Segmentation: UNet & DeepLab
“Classify every single pixel — semantic masks, instance boundaries, and panoptic understanding”
Pixel-level classification — semantic vs instance vs panoptic segmentation, skip connections in UNet, dilated convolutions in DeepLab, Dice loss for class imbalance, and applications in medical imaging and autonomous driving.
Key Formulas
Dice Loss
Measures overlap between predicted and ground-truth masks — robust to class imbalance
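The formula itself is not shown on this page. A common soft-Dice formulation, written to match the `dice_loss` code later in this lesson, is sketched below. The symbols are assumptions for illustration: p_i is the predicted probability for pixel i, g_i ∈ {0, 1} is the ground-truth label, and ε is a small smoothing constant that avoids division by zero.

```latex
\mathcal{L}_{\text{Dice}} \;=\; 1 - \frac{2\sum_{i} p_i\, g_i + \epsilon}{\sum_{i} p_i + \sum_{i} g_i + \epsilon}
```

Because both numerator and denominator scale with the size of the foreground region, a small object contributes as much to the loss as a large one.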
mIoU
Intersection over Union computed per class, then averaged across classes: the standard segmentation metric
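The formula is not shown on this page. The usual definition, in terms of per-class true positives, false positives, and false negatives counted over pixels, is sketched below. The symbols TP_c, FP_c, FN_c, and C (number of classes) are notation introduced here for illustration.

```latex
\text{IoU}_c \;=\; \frac{TP_c}{TP_c + FP_c + FN_c},
\qquad
\text{mIoU} \;=\; \frac{1}{C}\sum_{c=1}^{C} \text{IoU}_c
```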
Skip Connection
UNet skip connections concatenate encoder features with decoder features at the same resolution, recovering spatial detail lost in downsampling
Segmentation: Pixel-Level Understanding
Object detection returns bounding boxes, which only approximate an object's extent. Segmentation returns a mask that labels every pixel. This matters for medical imaging (delineating a tumour boundary precisely for radiotherapy planning), autonomous driving (distinguishing drivable surface from sidewalk at every pixel), satellite imagery (estimating crop area to about 1 m² where image resolution allows), and portrait mode (separating the person from the background for a bokeh effect).
There are three levels. **Semantic segmentation** labels each pixel with a class and makes no distinction between instances. **Instance segmentation** detects and masks each individual object (for example, Mask R-CNN). **Panoptic segmentation** combines the two: things (countable instances) plus stuff (amorphous regions such as sky and road).
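To make the three levels concrete, here is a toy sketch of how each output might be represented as a tensor. The 4x6 scene, the class ids `ROAD` and `PERSON`, and the (class, instance id) panoptic encoding are assumptions chosen for illustration; real datasets use their own encodings.

```python
import torch

# Toy 4x6 scene: two "person" instances (things) on a "road" background (stuff).
H, W = 4, 6
ROAD, PERSON = 0, 1  # hypothetical class ids

# Instance segmentation: one binary mask per detected object, shape (N, H, W).
instance_masks = torch.zeros(2, H, W, dtype=torch.bool)
instance_masks[0, 1:3, 0:2] = True   # first person, left side
instance_masks[1, 1:3, 4:6] = True   # second person, right side

# Semantic segmentation: one class id per pixel; the two people become
# indistinguishable because both are simply labelled PERSON.
semantic = torch.full((H, W), ROAD, dtype=torch.long)
semantic[instance_masks.any(dim=0)] = PERSON

# Panoptic segmentation: every pixel gets a class id AND an instance id.
# Stuff pixels (road) share instance id 0; each thing gets its own id.
instance_id = torch.zeros(H, W, dtype=torch.long)
for i, mask in enumerate(instance_masks, start=1):
    instance_id[mask] = i
panoptic = torch.stack([semantic, instance_id], dim=-1)  # (H, W, 2)

print(semantic)
print(instance_id)
```

The semantic map alone cannot tell you how many people are in the scene; the instance ids can.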
A radiologist delineating a tumour by hand can take 30–60 minutes per scan. A segmentation model can produce a mask in under a second, so the bottleneck shifts from producing the delineation to verifying the model's output.
Encoder–Decoder Architecture (UNet)
UNet is the canonical segmentation architecture. The encoder (contracting path) is a CNN that progressively downsamples the feature map — learning 'what' is in the image, losing 'where'. The decoder (expanding path) progressively upsamples back to the original resolution using transposed convolutions or bilinear upsampling — recovering 'where'. The key innovation: skip connections that directly concatenate encoder feature maps at each resolution to their decoder counterpart. These skip connections let the model combine high-level semantic features (from deep encoder layers) with low-level spatial detail (from shallow encoder layers) — essential for sharp, accurate boundaries.
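The paragraph above mentions bilinear upsampling as an alternative to transposed convolutions in the decoder. A minimal sketch of one such decoder step follows. The class name `BilinearUp` and the use of a 1x1 convolution to reduce channels after upsampling are assumptions for illustration, not the original UNet design.

```python
import torch
import torch.nn as nn

class BilinearUp(nn.Module):
    """Hypothetical decoder step: bilinear upsample, 1x1 conv, then concat skip."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Sequential(
            # Parameter-free 2x upsampling (no learned weights here).
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            # 1x1 conv brings the channel count down to match the skip tensor.
            nn.Conv2d(in_ch, out_ch, kernel_size=1),
        )

    def forward(self, x, skip):
        x = self.up(x)
        # Skip connection: concatenate along the channel dimension.
        return torch.cat([x, skip], dim=1)

deep = torch.randn(1, 64, 8, 8)       # low-resolution, semantically rich features
shallow = torch.randn(1, 32, 16, 16)  # higher-resolution encoder features
merged = BilinearUp(64, 32)(deep, shallow)
print(merged.shape)  # torch.Size([1, 64, 16, 16])
```

Bilinear upsampling adds no parameters and avoids the checkerboard artifacts that transposed convolutions can produce, at the cost of a less flexible learned upsampler.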
UNet was designed for biomedical segmentation in 2015 with very few training images (~30). The architecture's data efficiency comes from skip connections and heavy data augmentation.
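One detail that matters when augmenting segmentation data is that any geometric transform must be applied identically to the image and its mask. Below is a minimal sketch using only PyTorch; the helper name `random_hflip_pair` is hypothetical, and libraries such as Albumentations offer richer paired transforms.

```python
import torch

def random_hflip_pair(image, mask, p=0.5):
    """Flip image (C, H, W) and mask (H, W) together with probability p."""
    if torch.rand(1).item() < p:
        # Flip both along the width axis so pixels and labels stay aligned.
        image = torch.flip(image, dims=[-1])
        mask = torch.flip(mask, dims=[-1])
    return image, mask

image = torch.arange(3 * 4 * 4, dtype=torch.float32).reshape(3, 4, 4)
mask = torch.zeros(4, 4, dtype=torch.long)
mask[:, 0] = 1  # the object occupies the left column

flipped_image, flipped_mask = random_hflip_pair(image, mask, p=1.0)
print(flipped_mask)  # the object is now in the right column
```

Flipping only the image would silently corrupt the training labels, which is why the two tensors share one random draw.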
Segmentation with torchvision and PyTorch
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms.functional as TF
from PIL import Image
from torchvision.models.segmentation import DeepLabV3_ResNet101_Weights

# ── 1. Pretrained semantic segmentation (DeepLabV3) ──
model = models.segmentation.deeplabv3_resnet101(
    weights=DeepLabV3_ResNet101_Weights.DEFAULT
)
model.eval()

# Inference
img = Image.open("street.jpg").convert("RGB")
img_t = TF.to_tensor(img).unsqueeze(0)  # (1, 3, H, W)
img_t = TF.normalize(img_t, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

with torch.no_grad():
    output = model(img_t)["out"]  # (1, 21, H, W): 21 PASCAL VOC classes
pred_mask = output.argmax(dim=1)[0]  # (H, W) class labels


# ── 2. Minimal UNet ──
class DoubleConv(nn.Module):
    """Two 3x3 conv + BatchNorm + ReLU layers; spatial size is preserved."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):
    def __init__(self, in_channels=3, n_classes=2, base_channels=64):
        super().__init__()
        bc = base_channels
        # Encoder
        self.enc1 = DoubleConv(in_channels, bc)
        self.enc2 = DoubleConv(bc, bc * 2)
        self.enc3 = DoubleConv(bc * 2, bc * 4)
        self.pool = nn.MaxPool2d(2)
        # Bottleneck
        self.bottleneck = DoubleConv(bc * 4, bc * 8)
        # Decoder
        self.up3 = nn.ConvTranspose2d(bc * 8, bc * 4, 2, stride=2)
        self.dec3 = DoubleConv(bc * 8, bc * 4)  # bc*8 because of skip connection
        self.up2 = nn.ConvTranspose2d(bc * 4, bc * 2, 2, stride=2)
        self.dec2 = DoubleConv(bc * 4, bc * 2)
        self.up1 = nn.ConvTranspose2d(bc * 2, bc, 2, stride=2)
        self.dec1 = DoubleConv(bc * 2, bc)
        # Output: 1x1 conv maps features to per-class logits
        self.out = nn.Conv2d(bc, n_classes, 1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        b = self.bottleneck(self.pool(e3))
        # Each decoder stage upsamples, then concatenates the matching encoder map
        d3 = self.dec3(torch.cat([self.up3(b), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return self.out(d1)


unet = UNet(n_classes=2)
x = torch.randn(2, 3, 256, 256)
out = unet(x)
print(f"UNet output: {out.shape}")  # (2, 2, 256, 256)


# ── 3. Dice loss ──
def dice_loss(pred, target, eps=1e-6):
    """pred: (B, C, H, W) raw logits (softmax is applied here), target: (B, H, W) long.

    Returns a (B, C) tensor of per-sample, per-class losses.
    """
    pred_soft = torch.softmax(pred, dim=1)
    # One-hot encode the target to match pred's (B, C, H, W) layout
    target_oh = torch.zeros_like(pred_soft)
    target_oh.scatter_(1, target.unsqueeze(1), 1)
    inter = (pred_soft * target_oh).sum(dim=(2, 3))
    # Sum of the two mask sizes (the Dice denominator, not a true set union)
    denom = pred_soft.sum(dim=(2, 3)) + target_oh.sum(dim=(2, 3))
    return 1 - (2 * inter + eps) / (denom + eps)


pred = torch.randn(2, 2, 64, 64)
target = torch.randint(0, 2, (2, 64, 64))
loss = dice_loss(pred, target).mean()
print(f"Dice loss: {loss.item():.4f}")
Class Imbalance Kills Segmentation Models
In most segmentation tasks the background class dominates: a self-driving scene might be 95% sky and road and 5% pedestrians. Unweighted cross-entropy treats every pixel equally, so a model can predict 'background' everywhere, reach 95% pixel accuracy, and still miss every pedestrian. Solutions: (1) Weighted cross-entropy: weight each class's loss inversely to its frequency. (2) Dice loss: far less sensitive to imbalance because it measures an overlap ratio rather than a pixel count. (3) Focal loss (from RetinaNet): down-weights well-classified pixels so training focuses on hard examples. In practice these are often combined; total_loss = cross_entropy + dice_loss is a common and effective choice in medical imaging.
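The three remedies above can be sketched in a few lines of PyTorch. This is an illustrative sketch, not a tuned recipe: the function names, the inverse-frequency weighting formula, `gamma=2.0`, and the equal weighting of the two loss terms are all assumptions you would adjust for your data.

```python
import torch
import torch.nn.functional as F

def inverse_frequency_weights(target, n_classes):
    """Per-class weights proportional to 1 / pixel frequency (hypothetical helper)."""
    counts = torch.bincount(target.flatten(), minlength=n_classes).float()
    return counts.sum() / (n_classes * counts.clamp(min=1))

def soft_dice_loss(logits, target, eps=1e-6):
    """Mean soft-Dice loss over classes; logits (B, C, H, W), target (B, H, W)."""
    probs = torch.softmax(logits, dim=1)
    target_oh = F.one_hot(target, num_classes=logits.shape[1])
    target_oh = target_oh.permute(0, 3, 1, 2).float()
    inter = (probs * target_oh).sum(dim=(0, 2, 3))
    denom = probs.sum(dim=(0, 2, 3)) + target_oh.sum(dim=(0, 2, 3))
    return (1 - (2 * inter + eps) / (denom + eps)).mean()

def focal_loss(logits, target, gamma=2.0):
    """Cross-entropy scaled by (1 - p_t)^gamma, so easy pixels count less."""
    ce = F.cross_entropy(logits, target, reduction="none")  # (B, H, W)
    p_t = torch.exp(-ce)  # probability assigned to the true class
    return ((1 - p_t) ** gamma * ce).mean()

def combined_loss(logits, target, class_weights=None, dice_weight=1.0):
    """total_loss = (optionally weighted) cross_entropy + dice_loss."""
    ce = F.cross_entropy(logits, target, weight=class_weights)
    return ce + dice_weight * soft_dice_loss(logits, target)

torch.manual_seed(0)
logits = torch.randn(2, 3, 16, 16)
target = torch.randint(0, 3, (2, 16, 16))
weights = inverse_frequency_weights(target, n_classes=3)
total = combined_loss(logits, target, class_weights=weights)
print(f"combined: {total.item():.4f}  focal: {focal_loss(logits, target).item():.4f}")
```

Setting `gamma=0` in `focal_loss` reduces it to plain cross-entropy, which is a quick sanity check when wiring it into a training loop.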
Always check the IoU per class in your validation metrics — global accuracy hides bad performance on small/rare classes that are often the ones that matter most.
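A small sketch of why per-class IoU is more informative than accuracy, using a confusion matrix built with `torch.bincount`. The helper name `per_class_iou` and the toy 10x10 masks are assumptions for illustration.

```python
import torch

def per_class_iou(pred, target, n_classes):
    """IoU for each class from integer label maps; NaN if a class never appears."""
    # Encode each (target, pred) pair as a single index into the confusion matrix.
    idx = target.flatten() * n_classes + pred.flatten()
    conf = torch.bincount(idx, minlength=n_classes ** 2)
    conf = conf.reshape(n_classes, n_classes)  # rows: target, columns: prediction
    inter = conf.diag().float()
    union = (conf.sum(dim=0) + conf.sum(dim=1)).float() - inter
    return inter / union

# Toy example: 90 background pixels, 10 pixels of a rare class.
target = torch.zeros(10, 10, dtype=torch.long)
target[:2, :5] = 1
pred = torch.zeros_like(target)  # a model that predicts background everywhere

accuracy = (pred == target).float().mean().item()
iou = per_class_iou(pred, target, n_classes=2)
miou = torch.nanmean(iou).item()

print(f"accuracy = {accuracy:.2f}")   # 0.90
print(f"per-class IoU = {iou.tolist()}")
print(f"mIoU = {miou:.2f}")           # 0.45
```

The model scores 90% accuracy while the rare class has an IoU of zero, which is exactly the failure that a global metric hides.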