
Attention-ResUNet CT Segmentation

PyTorch · Medical Imaging · CT Segmentation · Attention Gates · Residual Learning
Deep learning model for liver and hepatic vessel segmentation from CT images. Combines UNet's encoder-decoder architecture with residual connections and spatial attention mechanisms. Trained on the Medical Segmentation Decathlon dataset.
Architecture Overview

Three core components: UNet structure for feature extraction/reconstruction, residual connections for stable gradient flow in deeper networks, and attention gates for focusing on relevant anatomical regions.

INPUT (1×H×W) ──────────────────────────────────────────────────────────────────┐
      │                                                                          │
      ▼                                                                          │
┌─────────────────────── ENCODER ───────────────────────┐                       │
│  ResBlock    ResBlock    ResBlock    ResBlock   Bridge │                       │
│  (8 feat)   (16 feat)   (32 feat)   (64 feat)  (128)  │                       │
│     ↓pool      ↓pool      ↓pool      ↓pool            │                       │
└──────┬──────────┬──────────┬──────────┬───────────────┘                       │
       │          │          │          │          │                             │
       │    Skip Connections (with Attention Gates)                              │
       │          │          │          │          │                             │
       ▼          ▼          ▼          ▼          ▼                             │
┌─────────────────────── DECODER ───────────────────────┐                       │
│  AttnGate   AttnGate   AttnGate   AttnGate            │                       │
│  +UpConv    +UpConv    +UpConv    +UpConv             │                       │
│  (64 feat)  (32 feat)  (16 feat)  (8 feat)            │                       │
└───────────────────────────────────────────────────────┘                       │
                                           │                                     │
                                           ▼                                     │
                                   OUTPUT (3×H×W)  ←─────────────────────────────┘
                                   [background, liver/vessel, tumor]
Attention Gate Mechanism

Attention gates filter skip connections, learning to suppress irrelevant regions (background) while highlighting salient features (liver boundaries, vessel structures).

α = Sigmoid( ψ( ReLU( W_x·skip + W_g·gating ) ) )

     Skip (encoder)          Gating (decoder)
           │                       │
           ▼                       ▼
       ┌───────┐               ┌───────┐
       │ W_x   │ 1×1 conv      │ W_g   │ 1×1 conv
       └───┬───┘               └───┬───┘
           │                       │
           └─────────┬─────────────┘
                     ▼
                   ADD → ReLU → ψ (1×1) → Sigmoid
                     │
                     ▼
                α (attention coefficients)
                     │
                     ▼
              output = α ⊙ skip_features
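
A minimal PyTorch sketch of the gate shown above; layer names (W_x, W_g, psi) follow the diagram, while the channel sizes are illustrative and the gating signal is assumed to already match the skip's spatial resolution:

import torch
import torch.nn as nn

class AttentionGate(nn.Module):
    def __init__(self, skip_channels, gating_channels, inter_channels):
        super().__init__()
        self.W_x = nn.Conv2d(skip_channels, inter_channels, kernel_size=1)    # 1×1 conv on skip features
        self.W_g = nn.Conv2d(gating_channels, inter_channels, kernel_size=1)  # 1×1 conv on gating signal
        self.psi = nn.Conv2d(inter_channels, 1, kernel_size=1)                # ψ: 1×1 conv to one channel
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, skip, gating):
        # ADD → ReLU → ψ → Sigmoid, then rescale the skip features
        alpha = self.sigmoid(self.psi(self.relu(self.W_x(skip) + self.W_g(gating))))
        return alpha * skip                                                   # α ⊙ skip_features

# Example: a 32-channel skip gated by a 64-channel decoder signal of the same H×W
gate = AttentionGate(skip_channels=32, gating_channels=64, inter_channels=16)
out = gate(torch.randn(1, 32, 64, 64), torch.randn(1, 64, 64, 64))
print(out.shape)  # torch.Size([1, 32, 64, 64])
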
Residual Block
     Input ─────────────────────────┐
       │                            │
       ▼                            │ (1×1 conv if channels differ)
  Conv 3×3 → BatchNorm → ReLU       │
       │                            │
       ▼                            │
  Conv 3×3 → BatchNorm              │
       │                            │
       ▼                            │
      ADD ←─────────────────────────┘
       │
       ▼
     ReLU
       │
     Output

Residual formulation: y = x + F(x, {W_i})
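
A matching PyTorch sketch of this block; padding, batch-norm defaults, and the nn.Identity pass-through are assumptions, not the project's exact code:

import torch
import torch.nn as nn

class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # 1×1 conv on the shortcut only if the channel count changes
        self.shortcut = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self, x):
        residual = self.shortcut(x)
        out = self.relu(self.bn1(self.conv1(x)))   # Conv 3×3 → BatchNorm → ReLU
        out = self.bn2(self.conv2(out))            # Conv 3×3 → BatchNorm
        return self.relu(out + residual)           # ADD → ReLU, i.e. y = x + F(x)

# Example: 8 → 16 feature channels, spatial size preserved
block = ResBlock(8, 16)
print(block(torch.randn(1, 8, 128, 128)).shape)  # torch.Size([1, 16, 128, 128])
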
Dataset
Task             Source             Images                      Train/Val Split
Liver + Tumor    Medical Decathlon  54 (filtered <300 slices)   43 / 11
Hepatic Vessels  Medical Decathlon  216 (filtered <80 slices)   172 / 44
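
For illustration, a sketch of how such a slice-count filter could be applied with nibabel; the directory names follow the Medical Decathlon convention, and the filter direction (keeping volumes with fewer slices than the threshold) is an assumption:

from pathlib import Path
import nibabel as nib

def filter_volumes(image_dir, max_slices):
    # Keep volumes whose axial slice count is below the threshold, reading only the header
    kept = []
    for path in sorted(Path(image_dir).glob("*.nii.gz")):
        if nib.load(str(path)).shape[2] < max_slices:
            kept.append(path)
    return kept

liver_cases = filter_volumes("Task03_Liver/imagesTr", max_slices=300)       # paths are placeholders
vessel_cases = filter_volumes("Task08_HepaticVessel/imagesTr", max_slices=80)
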
Training Configuration
Parameter              Liver Model                                 Vessel Model
Base Features          8                                           16
Batch Size             32                                          32
Optimizer              Adam                                        Adam
Initial LR             1 × 10⁻⁴                                    1 × 10⁻⁴
LR Scheduler           ReduceLROnPlateau (factor=0.5, patience=3)  same
Early Stopping         Patience = 5 epochs                         same
Epochs (auto-stopped)  25                                          18
Loss Function          Dice-Cross Entropy                          same
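
A sketch of this configuration in PyTorch; the Dice-CE loss is a hand-rolled combination of soft Dice and cross-entropy, and the placeholder model and dummy validation batch only stand in for the real training loop:

import torch
import torch.nn as nn
import torch.nn.functional as F

def dice_ce_loss(logits, target, smooth=1e-5):
    # Cross-entropy plus soft Dice, standing in for the project's Dice-Cross Entropy loss
    ce = F.cross_entropy(logits, target)
    probs = torch.softmax(logits, dim=1)
    one_hot = F.one_hot(target, num_classes=logits.shape[1]).permute(0, 3, 1, 2).float()
    dims = (0, 2, 3)
    intersection = (probs * one_hot).sum(dims)
    dice = (2 * intersection + smooth) / (probs.sum(dims) + one_hot.sum(dims) + smooth)
    return ce + (1 - dice.mean())

model = nn.Conv2d(1, 3, kernel_size=1)      # placeholder for the Attention-ResUNet
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)

best_val, patience, bad_epochs = float("inf"), 5, 0         # early stopping: 5 epochs
for epoch in range(100):
    # ... one pass over the training loader would go here ...
    with torch.no_grad():                                   # dummy validation batch
        val_loss = dice_ce_loss(model(torch.randn(2, 1, 64, 64)),
                                torch.randint(0, 3, (2, 64, 64))).item()
    scheduler.step(val_loss)                                # halve LR after 3 flat epochs
    if val_loss < best_val:
        best_val, bad_epochs = val_loss, 0
    else:
        bad_epochs += 1
        if bad_epochs >= patience:
            break
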
Results
Task                 Structure        Dice Score
Liver Segmentation   Liver            0.88
Liver Segmentation   Tumor            0.73
Vessel Segmentation  Hepatic Vessels  0.83
Vessel Segmentation  Tumor            0.66
Potential improvements: class imbalance remains a challenge in medical segmentation, since background pixels dominate 3D CT volumes. A weighted Dice loss or focal loss could further improve tumor segmentation by penalizing errors on the minority classes more heavily.
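
For example, a minimal multi-class focal loss (γ = 2 is a common default, not a value tuned for this project):

import torch
import torch.nn.functional as F

def focal_loss(logits, target, gamma=2.0):
    # Down-weights easy (mostly background) pixels so hard minority-class pixels dominate the loss
    ce = F.cross_entropy(logits, target, reduction="none")   # per-pixel cross-entropy
    pt = torch.exp(-ce)                                       # predicted probability of the true class
    return ((1.0 - pt) ** gamma * ce).mean()

# Example on a 3-class, slice-sized batch
loss = focal_loss(torch.randn(2, 3, 64, 64), torch.randint(0, 3, (2, 64, 64)))
print(loss.item())
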
Dice Coefficient
            2 × |P ∩ T|
    Dice = ─────────────
            |P| + |T|

    P = predicted mask
    T = ground truth mask
    Range: [0, 1] — higher is better
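
Computed per structure from binary masks, for example (a minimal NumPy sketch):

import numpy as np

def dice_score(pred, truth, eps=1e-7):
    # Dice = 2|P ∩ T| / (|P| + |T|) for binary masks
    pred, truth = pred.astype(bool), truth.astype(bool)
    intersection = np.logical_and(pred, truth).sum()
    return 2.0 * intersection / (pred.sum() + truth.sum() + eps)

# Example: score one structure (label 1) from multi-class label maps
pred_labels = np.random.randint(0, 3, (256, 256))
true_labels = np.random.randint(0, 3, (256, 256))
print(dice_score(pred_labels == 1, true_labels == 1))
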
Usage
# Load and run inference
import torch
import nibabel as nib

from modules.model_loader import load_models
from modules.func import normalize_image

liver_model, vessel_model, device = load_models()
liver_model.eval()

# Process NIfTI file
img = nib.load("scan.nii.gz")
data = img.get_fdata()

# Segment slice by slice (each slice becomes a 1×1×H×W float tensor)
with torch.no_grad():
    for z in range(data.shape[2]):
        slice_tensor = (
            torch.tensor(normalize_image(data[:, :, z]), dtype=torch.float32)
            .unsqueeze(0)
            .unsqueeze(0)
            .to(device)
        )
        logits = liver_model(slice_tensor)                              # (1, 3, H, W)
        prediction = torch.argmax(torch.softmax(logits, dim=1), dim=1)  # (1, H, W) class labels