Attention-ResUNet CT Segmentation
PyTorch
Medical Imaging
CT Segmentation
Attention Gates
Residual Learning
Deep learning model for liver and hepatic vessel segmentation from CT images.
Combines UNet's encoder-decoder architecture with residual connections and spatial attention mechanisms.
Trained on the Medical Segmentation Decathlon dataset.
Architecture Overview
Three core components: UNet structure for feature extraction/reconstruction, residual connections
for stable gradient flow in deeper networks, and attention gates for focusing on relevant anatomical regions.
INPUT (1×H×W) ──────────────────────────────────────────────┐
       │                                                    │
       ▼                                                    │
┌─────────────────────── ENCODER ───────────────────────┐   │
│  ResBlock   ResBlock   ResBlock   ResBlock   Bridge   │   │
│  (8 feat)   (16 feat)  (32 feat)  (64 feat)  (128)    │   │
│    ↓pool      ↓pool      ↓pool      ↓pool             │   │
└──────┬──────────┬──────────┬──────────┬───────────────┘   │
       │          │          │          │         │         │
       │   Skip Connections (with Attention Gates)          │
       │          │          │          │         │         │
       ▼          ▼          ▼          ▼         ▼         │
┌─────────────────────── DECODER ───────────────────────┐   │
│  AttnGate   AttnGate   AttnGate   AttnGate            │   │
│  +UpConv    +UpConv    +UpConv    +UpConv             │   │
│  (64 feat)  (32 feat)  (16 feat)  (8 feat)            │   │
└───────────────────────────────────────────────────────┘   │
                            │                               │
                            ▼                               │
                     OUTPUT (3×H×W) ←───────────────────────┘
            [background, liver/vessel, tumor]
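To make the channel and resolution bookkeeping in the diagram concrete, here is a small sketch that lists the feature width and spatial size at each encoder level. It assumes each pooling step halves H and W, that feature widths double per level (as in the 8 → 128 progression shown), and a 256×256 input slice. The input size and the helper name `trace_levels` are illustrative assumptions, not part of the project code.

```python
def trace_levels(base_features=8, height=256, width=256, depth=4):
    """Return (name, channels, H, W) for each encoder level and the bridge.

    Assumes 2x2 pooling between levels and channel doubling per level.
    """
    levels = []
    h, w = height, width
    for i in range(depth):
        levels.append((f"enc{i + 1}", base_features * 2**i, h, w))
        h, w = h // 2, w // 2  # one pooling step
    levels.append(("bridge", base_features * 2**depth, h, w))
    return levels


liver_levels = trace_levels(base_features=8)
for name, channels, h, w in liver_levels:
    print(f"{name:>6}: {channels:4d} channels at {h}x{w}")
```

Under the same assumptions, a base width of 16 would put 256 channels at the bridge.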
Attention Gate Mechanism
Attention gates filter skip connections, learning to suppress irrelevant regions (background)
while highlighting salient features (liver boundaries, vessel structures).
α = σ(ψ(ReLU(W_x·x + W_g·g))), where x is the encoder skip feature map, g is the decoder gating signal, and σ is the sigmoid.
Skip (encoder)         Gating (decoder)
       │                       │
       ▼                       ▼
   ┌───────┐               ┌───────┐
   │  W_x  │ 1×1 conv      │  W_g  │ 1×1 conv
   └───┬───┘               └───┬───┘
       │                       │
       └─────────┬─────────────┘
                 ▼
  ADD → ReLU → ψ (1×1) → Sigmoid
                 │
                 ▼
    α (attention coefficients)
                 │
                 ▼
    output = α ⊙ skip_features
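The gate in the diagram can be sketched as a small PyTorch module. This is a minimal illustration rather than the project's actual code: the class name, the `inter_channels` width, and the bilinear resize of the gating signal (needed when the decoder features are at a coarser resolution than the skip features) are all assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionGate(nn.Module):
    """Additive attention gate: alpha = sigmoid(psi(relu(W_x x + W_g g)))."""

    def __init__(self, skip_channels, gating_channels, inter_channels):
        super().__init__()
        self.w_x = nn.Conv2d(skip_channels, inter_channels, kernel_size=1)
        self.w_g = nn.Conv2d(gating_channels, inter_channels, kernel_size=1)
        self.psi = nn.Conv2d(inter_channels, 1, kernel_size=1)

    def forward(self, skip, gating):
        g = self.w_g(gating)
        # Assumption: bring the gating signal to the skip resolution if they differ.
        if g.shape[-2:] != skip.shape[-2:]:
            g = F.interpolate(g, size=skip.shape[-2:], mode="bilinear", align_corners=False)
        alpha = torch.sigmoid(self.psi(torch.relu(self.w_x(skip) + g)))  # (N, 1, H, W)
        return skip * alpha  # alpha broadcasts over the channel dimension


torch.manual_seed(0)
gate = AttentionGate(skip_channels=64, gating_channels=128, inter_channels=32)
skip = torch.randn(2, 64, 32, 32)      # encoder features
gating = torch.randn(2, 128, 16, 16)   # coarser decoder features
out = gate(skip, gating)
print(out.shape)
```

Because α lies in (0, 1), the gate can only attenuate skip features, never amplify them.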
Residual Block
Input ─────────────────────────┐
  │                            │
  ▼                            │ (1×1 conv if channels differ)
Conv 3×3 → BatchNorm → ReLU    │
  │                            │
  ▼                            │
Conv 3×3 → BatchNorm           │
  │                            │
  ▼                            │
 ADD ←─────────────────────────┘
  │
  ▼
 ReLU
  │
Output
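Residual formulation: y = ReLU(F(x, {W_i}) + W_s·x), where W_s is the identity when input and output channels match and a 1×1 convolution otherwise.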
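A minimal PyTorch sketch of the block drawn above, assuming `padding=1` so the 3×3 convolutions preserve spatial size and bias-free convolutions before BatchNorm. The class name and these details are illustrative choices, not confirmed by the source.

```python
import torch
import torch.nn as nn


class ResBlock(nn.Module):
    """Conv-BN-ReLU-Conv-BN with a shortcut, followed by ReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        # 1x1 projection only when the channel counts differ, as in the diagram.
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        return torch.relu(self.body(x) + self.shortcut(x))


torch.manual_seed(0)
block = ResBlock(in_channels=1, out_channels=8)
y = block(torch.randn(2, 1, 32, 32))
print(y.shape)
```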
Dataset
| Task | Source | Volumes | Train / Val Split |
| Liver + Tumor | Medical Segmentation Decathlon | 54 (filtered to <300 slices) | 43 / 11 |
| Hepatic Vessels | Medical Segmentation Decathlon | 216 (filtered to <80 slices) | 172 / 44 |
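The split sizes in the table are consistent with an 80/20 split rounded down, although the source does not state how the split was made. A hedged sketch under that assumption, where the function name, the seed, and the use of integer IDs in place of volume paths are all hypothetical:

```python
import random


def train_val_split(items, train_fraction=0.8, seed=0):
    """Shuffle deterministically, then cut at floor(len * train_fraction)."""
    items = list(items)
    random.Random(seed).shuffle(items)
    n_train = int(len(items) * train_fraction)
    return items[:n_train], items[n_train:]


# Integer IDs stand in for the filtered volume paths.
liver_train, liver_val = train_val_split(range(54))
vessel_train, vessel_val = train_val_split(range(216))
print(len(liver_train), len(liver_val))
print(len(vessel_train), len(vessel_val))
```

Splitting at the volume level, rather than the slice level, keeps slices from one patient out of both sets at once.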
Training Configuration
| Parameter | Liver Model | Vessel Model |
| Base Features | 8 | 16 |
| Batch Size | 32 | 32 |
| Optimizer | Adam | Adam |
| Initial LR | 1 × 10⁻⁴ | 1 × 10⁻⁴ |
| LR Scheduler | ReduceLROnPlateau (factor=0.5, patience=3) | ReduceLROnPlateau (factor=0.5, patience=3) |
| Early Stopping | Patience = 5 epochs | Patience = 5 epochs |
| Epochs (auto-stopped) | 25 | 18 |
| Loss Function | Dice-Cross Entropy | Dice-Cross Entropy |
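The optimizer, scheduler, and early-stopping settings from the table can be wired together roughly as follows. This sketch replays a made-up list of validation losses instead of training anything. The stand-in model, `MAX_EPOCHS`, and the helper `run_schedule` are assumptions for illustration; only the numeric settings come from the table.

```python
import torch
import torch.nn as nn

MAX_EPOCHS = 100         # hypothetical cap; the source reports only the auto-stopped counts
EARLY_STOP_PATIENCE = 5  # from the table


def run_schedule(val_losses):
    """Feed validation losses to the scheduler and an early-stopping counter.

    Returns (index of the last epoch run, final learning rate).
    """
    model = nn.Conv2d(1, 3, kernel_size=3, padding=1)  # stand-in for the real network
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=3
    )
    best, bad_epochs, last_epoch = float("inf"), 0, -1
    for epoch, val_loss in enumerate(val_losses[:MAX_EPOCHS]):
        last_epoch = epoch
        # Training and validation for this epoch would run here.
        scheduler.step(val_loss)
        if val_loss < best:
            best, bad_epochs = val_loss, 0
        else:
            bad_epochs += 1
            if bad_epochs >= EARLY_STOP_PATIENCE:
                break
    return last_epoch, optimizer.param_groups[0]["lr"]


# Illustrative losses: improvement for three epochs, then a plateau.
stopped_at, final_lr = run_schedule([1.0, 0.9, 0.8] + [0.8] * 10)
print(stopped_at, final_lr)
```

With these illustrative losses the learning rate is halved once during the plateau, and the loop stops after five epochs without improvement.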
Results
| Task | Structure | Dice Score |
| Liver Segmentation | Liver | 0.88 |
| Liver Segmentation | Tumor | 0.73 |
| Vessel Segmentation | Hepatic Vessels | 0.83 |
| Vessel Segmentation | Tumor | 0.66 |
Potential improvements: Class imbalance remains the main limitation. Background pixels dominate CT volumes, so tumor pixels contribute little to an unweighted loss, which is consistent with the lower tumor Dice scores above.
A class-weighted Dice loss or focal loss could improve tumor segmentation by penalizing errors on minority classes more heavily.
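For reference, the two suggested losses could look like the sketch below. These are generic formulations, not the project's loss. The function names, `gamma=2.0`, and the example class weights are assumptions that would need tuning.

```python
import torch
import torch.nn.functional as F


def focal_loss(logits, target, gamma=2.0, class_weights=None):
    """Multi-class focal loss. logits: (N, C, H, W); target: (N, H, W) int64."""
    log_probs = F.log_softmax(logits, dim=1)
    log_pt = log_probs.gather(1, target.unsqueeze(1)).squeeze(1)  # log-prob of the true class
    loss = -((1.0 - log_pt.exp()) ** gamma) * log_pt              # down-weights easy pixels
    if class_weights is not None:
        loss = loss * class_weights.to(logits.device)[target]
    return loss.mean()


def weighted_dice_loss(logits, target, class_weights, eps=1e-6):
    """Soft Dice per class, combined with normalized class weights."""
    num_classes = logits.shape[1]
    probs = torch.softmax(logits, dim=1)
    one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).float()
    dims = (0, 2, 3)  # sum over batch and spatial dimensions
    intersection = (probs * one_hot).sum(dims)
    denominator = probs.sum(dims) + one_hot.sum(dims)
    dice = (2.0 * intersection + eps) / (denominator + eps)
    weights = class_weights.to(logits.device) / class_weights.sum()
    return 1.0 - (weights * dice).sum()


torch.manual_seed(0)
logits = torch.randn(2, 3, 8, 8)
target = torch.randint(0, 3, (2, 8, 8))
weights = torch.tensor([0.2, 1.0, 2.0])  # hypothetical: background, liver/vessel, tumor
print(focal_loss(logits, target).item(), weighted_dice_loss(logits, target, weights).item())
```

With `gamma=0` and no class weights, the focal loss reduces to ordinary cross-entropy, which makes a convenient sanity check.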
Dice Coefficient
Dice = 2 × |P ∩ T| / (|P| + |T|)
P = predicted mask
T = ground truth mask
Range: [0, 1] — higher is better
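A minimal NumPy sketch of the metric for binary masks. Returning 1.0 when both masks are empty is a convention chosen here, not something the source specifies, and the function name is illustrative.

```python
import numpy as np


def dice_coefficient(pred, target):
    """Dice = 2|P ∩ T| / (|P| + |T|) for binary masks of equal shape."""
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    total = pred.sum() + target.sum()
    if total == 0:
        return 1.0  # assumption: two empty masks count as a perfect match
    return 2.0 * np.logical_and(pred, target).sum() / total


pred_mask = np.array([[1, 1, 0], [0, 1, 0]])
true_mask = np.array([[1, 0, 0], [0, 1, 1]])
score = dice_coefficient(pred_mask, true_mask)  # 2 overlapping pixels, |P| = |T| = 3
print(round(float(score), 4))
```

For a multi-class label map, the per-class score is obtained by comparing `pred == c` with `target == c` for each class index `c`.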
Usage
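# Load the models and run slice-wise inference
import nibabel as nib
import torch

from modules.func import normalize_image
from modules.model_loader import load_models

liver_model, vessel_model, device = load_models()
liver_model.eval()  # use running BatchNorm statistics at inference time

# Read the NIfTI volume as a float array of shape (H, W, Z)
img = nib.load("scan.nii.gz")
data = img.get_fdata()

# Segment slice by slice; each slice becomes a (1, 1, H, W) float tensor
with torch.no_grad():
    for z in range(data.shape[2]):
        slice_np = normalize_image(data[:, :, z])
        slice_tensor = torch.tensor(slice_np, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)
        logits = liver_model(slice_tensor)
        # argmax over logits matches argmax over softmax probabilities
        prediction = torch.argmax(logits, dim=1)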
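The loop above produces one label map per slice. To get a volume back, the per-slice predictions need to be collected and stacked along the slice axis. A small sketch, assuming each prediction has been converted to an (H, W) array, for example with `prediction.squeeze(0).cpu().numpy()`; the helper name `stack_slices` is hypothetical.

```python
import numpy as np


def stack_slices(slice_predictions):
    """Stack per-slice (H, W) label maps into an (H, W, Z) uint8 volume."""
    return np.stack([np.asarray(p, dtype=np.uint8) for p in slice_predictions], axis=2)


# Three dummy 4x4 slices filled with labels 0, 1, and 2
volume = stack_slices([np.full((4, 4), label) for label in range(3)])
print(volume.shape, volume.dtype)
```

The stacked mask can then be written next to the scan with `nib.save(nib.Nifti1Image(volume, img.affine), "liver_mask.nii.gz")`, reusing the affine of the input image; the output filename is only an example.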