Image Segmentation: UNet & DeepLab

Classify every single pixel — semantic masks, instance boundaries, and panoptic understanding

Pixel-level classification — semantic vs instance vs panoptic segmentation, skip connections in UNet, dilated convolutions in DeepLab, Dice loss for class imbalance, and applications in medical imaging and autonomous driving.

Prerequisites

Object Detection

Concepts Covered

Semantic Segmentation, Instance Segmentation, UNet, Skip Connections, Dice Loss, mIoU, DeepLab

Key Formulas

Dice Loss

Measures overlap between predicted and ground-truth masks — robust to class imbalance
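
The formula itself is not shown in the text here. One common soft-Dice form, written to match the dice_loss function in the code section of this page, is sketched below. The symbols are assumptions introduced for illustration: p_i is the predicted probability for a class at pixel i, g_i is the one-hot ground truth, and ε is a small smoothing constant.

```latex
\mathcal{L}_{\text{Dice}}
  = 1 - \frac{2 \sum_{i} p_i \, g_i + \epsilon}
             {\sum_{i} p_i + \sum_{i} g_i + \epsilon}
```

The loss is 0 when prediction and ground truth overlap perfectly and approaches 1 when they do not overlap at all. It is computed per class, so a small class contributes as much as a large one.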

mIoU

Intersection over Union computed per class, then averaged across classes; the standard segmentation metric
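
The metric can be written out as follows. This is the usual confusion-matrix form, with symbols chosen here for illustration: TP_c, FP_c and FN_c are the true-positive, false-positive and false-negative pixel counts for class c, and C is the number of classes.

```latex
\text{IoU}_c = \frac{TP_c}{TP_c + FP_c + FN_c},
\qquad
\text{mIoU} = \frac{1}{C} \sum_{c=1}^{C} \text{IoU}_c
```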

Skip Connection

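A UNet skip connection concatenates encoder features with decoder features at the same resolution (channel-wise concatenation, not the additive residual connection used in ResNet). This recovers spatial detail lost in downsampling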

Segmentation: Pixel-Level Understanding

motivation

Object detection gives bounding boxes, which only roughly approximate an object's shape. Segmentation gives a mask that follows the object's outline pixel by pixel. This matters for medical imaging (delineating a tumour boundary precisely for radiotherapy planning), autonomous driving (distinguishing drivable surface from sidewalk at every pixel), satellite imagery (estimating crop area down to the pixel resolution of the imagery), and portrait mode (separating the person from the background for a bokeh effect). There are three levels. **Semantic segmentation** labels each pixel with a class and does not distinguish between instances of the same class. **Instance segmentation** detects individual objects and produces a separate mask for each (for example, Mask R-CNN). **Panoptic segmentation** combines the two: instance masks for things (countable objects) plus class labels for stuff (amorphous regions like sky or road).
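
To make the three levels concrete, here is a minimal NumPy sketch of how the labels for one toy scene might be stored. The scene, the class ids, and the `class_id * 1000 + instance_id` panoptic encoding are all assumptions chosen for illustration; real datasets each define their own encoding.

```python
import numpy as np

# Toy 4x6 scene: class 0 = road ("stuff"), class 1 = car ("thing").
# Semantic segmentation: one class id per pixel, both cars share label 1.
semantic = np.zeros((4, 6), dtype=np.int64)
semantic[1:3, 0:2] = 1   # first car
semantic[1:3, 4:6] = 1   # second car

# Instance segmentation: one binary mask per detected object.
instance_masks = np.zeros((2, 4, 6), dtype=bool)
instance_masks[0, 1:3, 0:2] = True
instance_masks[1, 1:3, 4:6] = True

# Panoptic segmentation: every pixel carries a class AND an instance id.
# Assumed encoding: class_id * 1000 + instance_id, instance 0 for stuff.
instance_ids = np.zeros_like(semantic)
for i, mask in enumerate(instance_masks, start=1):
    instance_ids[mask] = i
panoptic = semantic * 1000 + instance_ids

print("semantic labels:", np.unique(semantic).tolist())
print("instance masks: ", instance_masks.shape)
print("panoptic ids:   ", np.unique(panoptic).tolist())
```

The semantic map cannot tell the two cars apart, the instance masks say nothing about the road, and the panoptic map covers both.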

A radiologist delineating a tumour by hand takes 30–60 minutes per scan. An AI segmentation model does it in < 1 second — the bottleneck shifts to verifying AI output, not producing it.

Encoder–Decoder Architecture (UNet)

intuition

UNet is the canonical segmentation architecture. The encoder (contracting path) is a CNN that progressively downsamples the feature map — learning 'what' is in the image, losing 'where'. The decoder (expanding path) progressively upsamples back to the original resolution using transposed convolutions or bilinear upsampling — recovering 'where'. The key innovation: skip connections that directly concatenate encoder feature maps at each resolution to their decoder counterpart. These skip connections let the model combine high-level semantic features (from deep encoder layers) with low-level spatial detail (from shallow encoder layers) — essential for sharp, accurate boundaries.
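
The paragraph above mentions bilinear upsampling as an alternative to transposed convolutions. A minimal sketch of one such decoder step is shown here; `BilinearUp` and its channel sizes are hypothetical names and values chosen for illustration, not part of any library.

```python
import torch
import torch.nn as nn

class BilinearUp(nn.Module):
    """Hypothetical decoder step: upsample 2x, concatenate the skip, then convolve."""
    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        # Parameter-free upsampling; the convolution after it does the learning.
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        self.fuse = nn.Sequential(
            nn.Conv2d(in_ch + skip_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, skip):
        x = self.up(x)                          # recover spatial size ("where")
        x = torch.cat([x, skip], dim=1)         # skip connection: channel concat
        return self.fuse(x)

deep = torch.randn(1, 128, 16, 16)   # low-resolution, semantically rich features
skip = torch.randn(1, 64, 32, 32)    # high-resolution encoder features
out = BilinearUp(128, 64, 64)(deep, skip)
print(out.shape)  # torch.Size([1, 64, 32, 32])
```

Compared with a transposed convolution, the upsampling step here has no learned weights, which reduces parameters; which option works better is task dependent.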

UNet was designed for biomedical segmentation in 2015 with very few training images (~30). The architecture's data efficiency comes from skip connections and heavy data augmentation.
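
One detail of augmentation for segmentation is that every geometric transform must be applied identically to the image and its mask, or the labels stop lining up with the pixels. The sketch below uses random flips only, as the simplest possible case; `random_flip_pair` is a hypothetical helper, and stronger augmentations such as elastic deformation are beyond this example.

```python
import torch

def random_flip_pair(image, mask, p=0.5):
    """Flip image (C, H, W) and mask (H, W) together so they stay aligned."""
    if torch.rand(1).item() < p:               # horizontal flip (width axis)
        image = torch.flip(image, dims=[-1])
        mask = torch.flip(mask, dims=[-1])
    if torch.rand(1).item() < p:               # vertical flip (height axis)
        image = torch.flip(image, dims=[-2])
        mask = torch.flip(mask, dims=[-2])
    return image, mask

# Toy pair where the mask is derived from the image, so alignment is checkable.
image = torch.arange(3 * 4 * 4, dtype=torch.float32).reshape(3, 4, 4)
mask = (image[0] > 7).long()

aug_image, aug_mask = random_flip_pair(image, mask)
# The mask still matches the image after augmentation.
print(torch.equal(aug_mask, (aug_image[0] > 7).long()))  # True
```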

Segmentation with torchvision: DeepLabV3 Inference, a Minimal UNet, and Dice Loss

code
python
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms.functional as F
from PIL import Image
from torchvision.models.segmentation import DeepLabV3_ResNet101_Weights

# ── 1. Pretrained semantic segmentation (DeepLabV3) ───────────────────────────
model = models.segmentation.deeplabv3_resnet101(
    weights=DeepLabV3_ResNet101_Weights.DEFAULT
)
model.eval()  # inference mode: fixes BatchNorm statistics and disables dropout

# Inference
img = Image.open("street.jpg").convert("RGB")
img_t = F.to_tensor(img).unsqueeze(0)   # (1, 3, H, W), values in [0, 1]
# Normalize with the ImageNet mean/std the backbone was trained with
img_t = F.normalize(img_t, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

with torch.no_grad():
    output = model(img_t)["out"]         # (1, 21, H, W): 20 PASCAL VOC classes + background
pred_mask = output.argmax(dim=1)[0]     # (H, W) class labels

# ── 2. Minimal UNet ────────────────────────────────────────────────────────────
class DoubleConv(nn.Module):
    """Two 3x3 conv + BatchNorm + ReLU layers; padding=1 keeps H and W unchanged."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)

class UNet(nn.Module):
    """Three-level UNet. Input H and W must be divisible by 8 (three 2x poolings)."""
    def __init__(self, in_channels=3, n_classes=2, base_channels=64):
        super().__init__()
        bc = base_channels
        # Encoder
        self.enc1 = DoubleConv(in_channels, bc)
        self.enc2 = DoubleConv(bc, bc*2)
        self.enc3 = DoubleConv(bc*2, bc*4)
        self.pool = nn.MaxPool2d(2)
        # Bottleneck
        self.bottleneck = DoubleConv(bc*4, bc*8)
        # Decoder
        self.up3  = nn.ConvTranspose2d(bc*8, bc*4, 2, stride=2)
        self.dec3 = DoubleConv(bc*8, bc*4)   # bc*8 in: bc*4 upsampled + bc*4 from the skip
        self.up2  = nn.ConvTranspose2d(bc*4, bc*2, 2, stride=2)
        self.dec2 = DoubleConv(bc*4, bc*2)
        self.up1  = nn.ConvTranspose2d(bc*2, bc, 2, stride=2)
        self.dec1 = DoubleConv(bc*2, bc)
        # Output: 1x1 conv maps features to per-class logits
        self.out  = nn.Conv2d(bc, n_classes, 1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        b  = self.bottleneck(self.pool(e3))
        # Each decoder step upsamples, then concatenates the matching encoder output
        d3 = self.dec3(torch.cat([self.up3(b),  e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        return self.out(d1)

unet = UNet(n_classes=2)
x = torch.randn(2, 3, 256, 256)
out = unet(x)
print(f"UNet output: {out.shape}")   # torch.Size([2, 2, 256, 256])

# ── 3. Dice loss ──────────────────────────────────────────────────────────────
def dice_loss(pred, target, eps=1e-6):
    """pred: (B, C, H, W) raw logits, target: (B, H, W) long. Returns a (B, C) loss."""
    pred_soft = torch.softmax(pred, dim=1)               # logits -> class probabilities
    target_oh = torch.zeros_like(pred_soft)
    target_oh.scatter_(1, target.unsqueeze(1), 1)        # one-hot encode the target
    inter = (pred_soft * target_oh).sum(dim=(2, 3))
    # Sum of the two mask sizes (the Dice denominator), not a set union
    denom = pred_soft.sum(dim=(2, 3)) + target_oh.sum(dim=(2, 3))
    return 1 - (2 * inter + eps) / (denom + eps)

pred   = torch.randn(2, 2, 64, 64)
target = torch.randint(0, 2, (2, 64, 64))
loss   = dice_loss(pred, target).mean()                  # average over batch and classes
print(f"Dice loss: {loss.item():.4f}")
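
The DeepLab model loaded in part 1 relies on dilated (atrous) convolutions, which the listing above uses but does not show directly. The following is a minimal sketch of the idea with plain `nn.Conv2d` layers, not DeepLab's actual ASPP module; the channel counts and the `effective_kernel` helper are assumptions for illustration.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 8, 32, 32)

# Same 3x3 kernel; dilation=2 spreads its 9 taps over a 5x5 window.
standard = nn.Conv2d(8, 8, kernel_size=3, padding=1)
dilated  = nn.Conv2d(8, 8, kernel_size=3, padding=2, dilation=2)

# Both layers have identical parameter counts.
n_std = sum(p.numel() for p in standard.parameters())
n_dil = sum(p.numel() for p in dilated.parameters())
print(n_std, n_dil)  # 584 584

# Both preserve spatial resolution, so no downsampling is needed
# to enlarge the receptive field.
print(standard(x).shape, dilated(x).shape)

def effective_kernel(k, d):
    """Spatial extent covered by a k x k kernel with dilation d."""
    return k + (k - 1) * (d - 1)

print(effective_kernel(3, 2))  # 5
print(effective_kernel(3, 4))  # 9
```

A larger receptive field at full resolution is what lets a model see more context without the loss of spatial detail that pooling causes.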

Class Imbalance Kills Segmentation Models

pitfall

In most segmentation tasks the background class dominates: a self-driving scene might be 95% sky and road and 5% pedestrians. Unweighted cross-entropy averages over all pixels equally, so a model can predict 'background' everywhere, reach 95% pixel accuracy, and still miss every pedestrian. Solutions: (1) Weighted cross-entropy: weight each class's loss inversely to its frequency. (2) Dice loss: less sensitive to imbalance because it measures the overlap ratio for each class rather than counting correct pixels. (3) Focal loss (from RetinaNet): down-weights well-classified pixels so training focuses on hard examples. In practice these are often combined; total_loss = cross_entropy + dice_loss is a common and effective default in medical imaging.
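
The three remedies can be sketched in a few lines of PyTorch. This is a minimal illustration, not a tuned recipe: the helper names, the inverse-frequency weighting scheme, `gamma=2.0` and the equal weighting of the two loss terms are all assumptions.

```python
import torch
import torch.nn.functional as F

def soft_dice_loss(logits, target, eps=1e-6):
    """Mean soft Dice loss over classes. logits: (B, C, H, W), target: (B, H, W)."""
    probs = torch.softmax(logits, dim=1)
    target_oh = F.one_hot(target, num_classes=logits.shape[1]).permute(0, 3, 1, 2).float()
    inter = (probs * target_oh).sum(dim=(0, 2, 3))
    denom = probs.sum(dim=(0, 2, 3)) + target_oh.sum(dim=(0, 2, 3))
    return (1 - (2 * inter + eps) / (denom + eps)).mean()

def focal_loss(logits, target, gamma=2.0):
    """Scales per-pixel cross-entropy by (1 - p_t)^gamma to mute easy pixels."""
    ce = F.cross_entropy(logits, target, reduction="none")  # equals -log(p_t)
    p_t = torch.exp(-ce)
    return ((1 - p_t) ** gamma * ce).mean()

def combined_loss(logits, target, class_weights=None, dice_weight=1.0):
    """(Optionally weighted) cross-entropy plus Dice."""
    ce = F.cross_entropy(logits, target, weight=class_weights)
    return ce + dice_weight * soft_dice_loss(logits, target)

torch.manual_seed(0)
logits = torch.randn(2, 2, 64, 64)
target = (torch.rand(2, 64, 64) < 0.05).long()      # roughly 5% foreground

# Inverse-frequency class weights: rare classes get larger weights.
counts = torch.bincount(target.flatten(), minlength=2).float()
class_weights = counts.sum() / (2 * counts.clamp(min=1))

loss = combined_loss(logits, target, class_weights)
print(f"weights: {class_weights.tolist()}")
print(f"combined loss: {loss.item():.4f}  focal: {focal_loss(logits, target).item():.4f}")
```

Setting `gamma=0` reduces the focal loss to plain cross-entropy, which is a quick sanity check on the implementation.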

Always check the IoU per class in your validation metrics — global accuracy hides bad performance on small/rare classes that are often the ones that matter most.
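
A small worked example shows how far apart the two numbers can be. The sketch below computes per-class IoU from a confusion matrix for a model that predicts background everywhere; `per_class_iou` is a hypothetical helper written for this illustration.

```python
import torch

def per_class_iou(pred, target, num_classes):
    """pred, target: integer label tensors of the same shape. Returns (num_classes,) IoU."""
    # Confusion matrix: rows are true classes, columns are predicted classes.
    idx = target.flatten() * num_classes + pred.flatten()
    cm = torch.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)
    tp = cm.diag().float()
    # union = predicted pixels + true pixels - overlap
    union = cm.sum(dim=0).float() + cm.sum(dim=1).float() - tp
    return tp / union          # NaN for a class absent from both pred and target

target = torch.zeros(10, 10, dtype=torch.long)
target[:2, :5] = 1                  # 10 rare-class pixels out of 100
pred = torch.zeros_like(target)     # model predicts background everywhere

acc = (pred == target).float().mean()
iou = per_class_iou(pred, target, num_classes=2)
miou = iou[~iou.isnan()].mean()     # skip classes that never occur

print(f"pixel accuracy: {acc.item():.2f}")   # 0.90
print(f"IoU per class:  {[round(v, 2) for v in iou.tolist()]}")   # [0.9, 0.0]
print(f"mIoU:           {miou.item():.2f}")  # 0.45
```

Pixel accuracy looks healthy at 0.90, while the rare class has an IoU of exactly 0 and mIoU drops to 0.45.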
