In this tutorial, we build an end-to-end 3D medical image segmentation pipeline using MONAI to segment the spleen in the Medical Segmentation Decathlon Task09 dataset. We work with volumetric CT scans and apply medical imaging transforms such as orientation alignment, voxel-spacing normalization, intensity windowing, foreground cropping, and patch-based sampling. We then train a 3D UNet model for binary organ segmentation. We also use mixed-precision training, DiceCE loss, sliding-window inference, Dice-based validation, and qualitative visualization to understand how the model learns and how its predictions compare with the ground-truth masks. Overall, we move from raw medical volumes to a complete train–validate–visualize segmentation system.
!pip install -q "monai[nibabel,tqdm,matplotlib]==1.5.2" 2>/dev/null
import os, time, glob, tempfile, warnings
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.amp import autocast, GradScaler
from monai.apps import DecathlonDataset
from monai.data import DataLoader, decollate_batch
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference
from monai.utils import set_determinism
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, EnsureTyped, Orientationd,
    Spacingd, ScaleIntensityRanged, CropForegroundd, RandCropByPosNegLabeld,
    RandFlipd, RandRotate90d, RandShiftIntensityd, AsDiscrete,
)
warnings.filterwarnings("ignore")
We start by installing MONAI with the required medical imaging and visualization dependencies. We then import PyTorch, NumPy, Matplotlib, and the main MONAI modules needed for datasets, transforms, model training, metrics, and inference. We also suppress warnings to keep the notebook output clean while we focus on the segmentation workflow.
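QUICK_RUN = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
root_dir = tempfile.mkdtemp()
roi_size = (96, 96, 96)  # training patch size and sliding-window ROI, in voxels
num_samples = 4  # random patches drawn from each volume per iteration
batch_size = 2
max_epochs = 15 if QUICK_RUN else 200
val_every = 3
train_cache = 8 if QUICK_RUN else 24
val_cache = 2 if QUICK_RUN else 6
set_determinism(seed=0)
print(f"Device: {device} | epochs: {max_epochs} | data dir: {root_dir}")

# Deterministic preprocessing shared by training and validation.
# NOTE: the original definition of this list was lost; the axcodes, pixdim and
# HU window below are assumed values borrowed from MONAI's official spleen tutorial.
common = [
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    ScaleIntensityRanged(keys=["image"], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
    CropForegroundd(keys=["image", "label"], source_key="image"),
]
train_transforms = Compose(common + [
    # Draw num_samples patches per volume; pos=1/neg=1 (assumed) balances
    # spleen-centered and background-centered patches.
    RandCropByPosNegLabeld(keys=["image", "label"], label_key="label",
                           spatial_size=roi_size, pos=1, neg=1,
                           num_samples=num_samples,
                           image_key="image", image_threshold=0),
    RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=0),
    RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=1),
    RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=2),
    RandRotate90d(keys=["image", "label"], prob=0.2, max_k=3),
    RandShiftIntensityd(keys=["image"], offsets=0.10, prob=0.5),
    EnsureTyped(keys=["image", "label"]),
])
val_transforms = Compose(common + [EnsureTyped(keys=["image", "label"])])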
We define the main configuration for the tutorial, including the device, dataset directory, patch size, batch size, number of epochs, and cache settings. We then build the preprocessing pipeline for CT volumes: loading images, aligning orientation, resampling voxel spacing, windowing and scaling intensities, and cropping the foreground. Finally, we define the training and validation transforms; the training pipeline adds random patch crops, flips, 90-degree rotations, and intensity shifts.
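To make three of the deterministic preprocessing steps concrete, here is a plain NumPy sketch: HU windowing in the style of `ScaleIntensityRanged` (a linear rescale of [a_min, a_max] onto [b_min, b_max], then clipping), the change in voxel counts caused by resampling to a new spacing, and a tight foreground bounding-box crop. This is an illustration, not MONAI's implementation. The helper names are hypothetical, the -57/164 HU window and the example spacings are assumed values, and MONAI's exact rounding in `Spacingd` and its cropping options (margins, custom `select_fn`) may differ.

```python
import numpy as np

def window_hu(img, a_min=-57.0, a_max=164.0, b_min=0.0, b_max=1.0):
    """Map [a_min, a_max] HU linearly onto [b_min, b_max] and clip (like clip=True)."""
    scaled = (img - a_min) / (a_max - a_min) * (b_max - b_min) + b_min
    return np.clip(scaled, b_min, b_max)

def resampled_shape(shape, spacing, new_spacing):
    """Approximate voxel counts after resampling: the physical extent in mm is preserved."""
    extent_mm = np.asarray(shape, dtype=float) * np.asarray(spacing, dtype=float)
    return tuple(int(n) for n in np.round(extent_mm / np.asarray(new_spacing, dtype=float)))

def crop_foreground(img, threshold=0.0):
    """Crop to the tight bounding box of voxels > threshold (no margin)."""
    coords = np.argwhere(img > threshold)
    if coords.size == 0:  # nothing above threshold: return the input unchanged
        return img, np.zeros(img.ndim, dtype=int), np.array(img.shape)
    start, end = coords.min(axis=0), coords.max(axis=0) + 1
    return img[tuple(slice(s, e) for s, e in zip(start, end))], start, end

# Air, lower window edge, window midpoint, upper window edge, bright tissue above the window
hu = np.array([-1000.0, -57.0, 53.5, 164.0, 500.0])
windowed = window_hu(hu)

# A hypothetical 512x512x90 scan at 0.79x0.79x5.0 mm resampled to 1.5x1.5x2.0 mm
new_shape = resampled_shape((512, 512, 90), (0.79, 0.79, 5.0), (1.5, 1.5, 2.0))
print(new_shape)  # (270, 270, 225)

# Toy volume with a single bright block; the crop should keep only that block
vol = np.zeros((10, 10, 10))
vol[2:5, 3:7, 4:6] = 1.0
cropped, start, end = crop_foreground(vol)
print(cropped.shape)  # (3, 4, 2)
```

The ordering matters: because windowing runs before `CropForegroundd`, everything at or below the lower HU bound becomes 0. Assuming the default foreground selector keeps positive voxels, the crop then discards the surrounding air and other empty field of view before patches are sampled.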
train_ds = DecathlonDataset(
    root_dir=root_dir, task="Task09_Spleen", section="training",
    transform=train_transforms, download=True, val_frac=0.2,
    cache_num=train_cache, num_workers=2, seed=0)
val_ds = DecathlonDataset(
    root_dir=root_dir, task="Task09_Spleen", section="validation",
    transform=val_transforms, download=False, val_frac=0.2,
    cache_num=val_cache, num_workers=2, seed=0)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                          num_workers=2, pin_memory=torch.cuda.is_available())
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False,
                        num_workers=1, pin_memory=torch.cuda.is_available())
print(f"Train volumes: {len(train_ds)} | Val volumes: {len(val_ds)}")

# 3D UNet with 2 output channels (background, spleen).
# NOTE: the original model definition was lost; these hyperparameters are
# assumed values that follow MONAI's official spleen tutorial.
model = UNet(
    spatial_dims=3, in_channels=1, out_channels=2,
    channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2),
    num_res_units=2, norm=Norm.BATCH,
).to(device)

loss_fn = DiceCELoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)
scaler = GradScaler("cuda", enabled=torch.cuda.is_available())
dice_metric = DiceMetric(include_background=False, reduction="mean")
post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])  # logits -> one-hot mask
post_label = Compose([AsDiscrete(to_onehot=2)])              # label -> one-hot mask
We load the Medical Segmentation Decathlon Task09 Spleen dataset using MONAI’s DecathlonDataset. We split the data into training and validation sections, apply the matching transforms, and wrap both datasets in PyTorch-style data loaders. We then create a 3D UNet model, define the DiceCE loss, and set up the AdamW optimizer, cosine learning-rate scheduler, mixed-precision gradient scaler, Dice metric, and post-processing steps.
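The loss and metric objects above are easier to reason about with a worked example. The sketch below re-implements the same ideas in NumPy on a tiny 2D case: `dice_metric` mimics the `AsDiscrete(argmax=True, to_onehot=2)` and `AsDiscrete(to_onehot=2)` post-processing followed by Dice with the background channel dropped, and `dice_ce_loss` adds a soft Dice term on softmax probabilities to the mean cross-entropy. It assumes integer labels without a channel axis and a smoothing constant of 1e-5. MONAI's `DiceCELoss` and `DiceMetric` have further options (class weights, empty-mask handling, batch-wise reduction) that this sketch ignores, so treat it as an approximation rather than a drop-in equivalent.

```python
import numpy as np

def one_hot(labels, num_classes):
    """(N, *spatial) integer labels -> (N, C, *spatial) one-hot floats."""
    oh = np.eye(num_classes)[labels]  # (N, *spatial, C)
    return np.moveaxis(oh, -1, 1)     # move the class axis to position 1

def softmax(x, axis=1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))  # subtract max for numerical stability
    return e / e.sum(axis=axis, keepdims=True)

def dice_metric(logits, labels, num_classes=2, eps=1e-8):
    """Hard Dice on foreground channels only (like include_background=False), averaged."""
    pred = one_hot(np.argmax(logits, axis=1), num_classes)  # argmax + one-hot, as post_pred does
    gt = one_hot(labels, num_classes)                        # one-hot, as post_label does
    pred, gt = pred[:, 1:], gt[:, 1:]                        # drop background channel 0
    axes = tuple(range(2, pred.ndim))                        # spatial axes
    inter = (pred * gt).sum(axis=axes)
    denom = pred.sum(axis=axes) + gt.sum(axis=axes)
    return float(np.mean(2.0 * inter / (denom + eps)))

def dice_ce_loss(logits, labels, num_classes=2, smooth=1e-5):
    """Soft Dice loss over all channels plus mean cross-entropy, equally weighted."""
    probs = softmax(logits, axis=1)
    gt = one_hot(labels, num_classes)
    axes = tuple(range(2, probs.ndim))
    inter = (probs * gt).sum(axis=axes)
    denom = probs.sum(axis=axes) + gt.sum(axis=axes)
    dice_loss = np.mean(1.0 - (2.0 * inter + smooth) / (denom + smooth))
    ce = -np.mean(np.sum(gt * np.log(np.clip(probs, 1e-12, 1.0)), axis=1))
    return float(dice_loss + ce)

# Tiny 1-sample, 2-class, 2x2 example: channel 1 holds the "spleen" logit.
labels = np.array([[[0, 1], [1, 1]]])
logits = np.zeros((1, 2, 2, 2))
logits[0, 1] = [[-1.0, 1.0], [1.0, -1.0]]  # predicts foreground at (0,1) and (1,0) only
print(dice_metric(logits, labels))  # about 0.8: 2 * 2 overlap / (2 predicted + 3 true)
confident = np.where(one_hot(labels, 2) == 1, 20.0, -20.0)  # near-perfect logits
print(dice_ce_loss(logits, labels), dice_ce_loss(confident, labels))
```

Excluding the background channel from the metric matters here because background dominates CT volumes; including it would push the reported Dice toward 1 even for a poor spleen mask.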
best_dice, best_epoch = -1.0, -1
loss_hist, dice_hist, dice_epochs = [], [], []
best_path = os.path.join(root_dir, "best_spleen_unet.pth")
for epoch in range(1, max_epochs + 1):
    model.train(); epoch_loss, t0 = 0.0, time.time()
    for batch in train_loader:
        x, y = batch["image"].to(device), batch["label"].to(device)
        optimizer.zero_grad(set_to_none=True)
        # Mixed precision forward pass (a no-op on CPU)
        with autocast("cuda", enabled=torch.cuda.is_available()):
            logits = model(x)
            loss = loss_fn(logits, y)
        scaler.scale(loss).backward()
        scaler.step(optimizer); scaler.update()
        epoch_loss += loss.item()
    scheduler.step()
    epoch_loss /= len(train_loader); loss_hist.append(epoch_loss)
    print(f"[{epoch:3d}/{max_epochs}] loss={epoch_loss:.4f} "
          f"lr={scheduler.get_last_lr()[0]:.2e} ({time.time()-t0:.0f}s)")
    if epoch % val_every == 0 or epoch == max_epochs:
        model.eval(); dice_metric.reset()
        with torch.no_grad():
            for vb in val_loader:
                vx, vy = vb["image"].to(device), vb["label"].to(device)
                with autocast("cuda", enabled=torch.cuda.is_available()):
                    # Full-volume prediction from roi_size windows, 4 windows per forward pass
                    vout = sliding_window_inference(vx, roi_size, 4, model,
                                                    overlap=0.5)
                vout = [post_pred(o) for o in decollate_batch(vout)]
                vlab = [post_label(o) for o in decollate_batch(vy)]
                dice_metric(y_pred=vout, y=vlab)
        d = dice_metric.aggregate().item()
        dice_hist.append(d); dice_epochs.append(epoch)
        if d > best_dice:
            best_dice, best_epoch = d, epoch
            torch.save(model.state_dict(), best_path)
        print(f"  >> val Dice={d:.4f} (best={best_dice:.4f} @ {best_epoch})")
print(f"\nDone. Best mean Dice {best_dice:.4f} at epoch {best_epoch}.")
We run the full training loop, in which each epoch trains the 3D UNet on cropped volumetric patches from the spleen dataset. We use automatic mixed precision to reduce memory usage and speed up training when a GPU is available. We also validate the model at regular intervals with sliding-window inference, track the Dice score, and save the best-performing checkpoint.
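`sliding_window_inference` is what lets a network trained on 96³ patches predict on full CT volumes. Here is a minimal NumPy sketch of the idea, assuming uniform ("constant") blending and a volume at least as large as the ROI on every axis: place windows with a stride of `roi * (1 - overlap)`, shift the last window so it ends flush with the border, run the predictor on each window, and average wherever windows overlap. The helper names are hypothetical. MONAI additionally pads inputs smaller than the ROI, batches windows (`sw_batch_size`, the `4` in the calls above), and offers Gaussian blending, and its scan-interval arithmetic may not match this sketch exactly.

```python
import itertools
import numpy as np

def window_starts(size, roi, overlap):
    """Window start indices along one axis; the last window ends flush with the border."""
    if size < roi:
        raise ValueError("sketch assumes size >= roi on every axis (MONAI would pad instead)")
    step = max(int(roi * (1 - overlap)), 1)
    starts = list(range(0, size - roi + 1, step))
    if starts[-1] + roi < size:
        starts.append(size - roi)
    return starts

def sliding_window_predict(volume, roi, overlap, predictor):
    """Apply `predictor` to every ROI-sized window and average overlapping outputs."""
    out = np.zeros(volume.shape, dtype=float)
    count = np.zeros(volume.shape, dtype=float)
    grids = [window_starts(s, r, overlap) for s, r in zip(volume.shape, roi)]
    for corner in itertools.product(*grids):
        sl = tuple(slice(c, c + r) for c, r in zip(corner, roi))
        out[sl] += predictor(volume[sl])  # uniform ("constant") weighting
        count[sl] += 1.0
    return out / count  # every voxel is covered at least once by construction

print(window_starts(226, 96, 0.5))  # [0, 48, 96, 130]

rng = np.random.default_rng(0)
vol = rng.normal(size=(10, 12, 7))
# Toy "model" that doubles its input, so the stitched result should equal 2 * vol.
stitched = sliding_window_predict(vol, (4, 4, 4), 0.5, lambda patch: 2.0 * patch)
```

With `overlap=0.5`, most interior voxels are covered by several windows, which smooths seams between patches at the cost of more forward passes per volume.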
fig, ax = plt.subplots(1, 2, figsize=(12, 4))
ax[0].plot(range(1, len(loss_hist) + 1), loss_hist, "-o", ms=3)
ax[0].set(title="Training loss", xlabel="epoch", ylabel="DiceCE loss")
ax[1].plot(dice_epochs, dice_hist, "-o", color="seagreen", ms=4)
ax[1].set(title="Validation mean Dice", xlabel="epoch", ylabel="Dice"); ax[1].set_ylim(0, 1)
plt.tight_layout(); plt.show()

# Reload the best checkpoint and predict one validation volume
model.load_state_dict(torch.load(best_path, map_location=device)); model.eval()
with torch.no_grad():
    sample = next(iter(val_loader))
    img = sample["image"].to(device)
    with autocast("cuda", enabled=torch.cuda.is_available()):
        pred = sliding_window_inference(img, roi_size, 4, model, overlap=0.5)
    pred = torch.argmax(pred, dim=1).cpu().numpy()[0]
img_np, lab_np = img.cpu().numpy()[0, 0], sample["label"].cpu().numpy()[0, 0]
z = int(np.argmax(lab_np.sum(axis=(0, 1))))  # slice with the most spleen voxels
fig, ax = plt.subplots(1, 3, figsize=(13, 5))
ax[0].imshow(img_np[:, :, z], cmap="gray"); ax[0].set_title("CT slice")
ax[1].imshow(lab_np[:, :, z], cmap="viridis"); ax[1].set_title("Ground truth")
ax[2].imshow(pred[:, :, z], cmap="viridis"); ax[2].set_title("Prediction")
for a in ax: a.axis("off")
plt.tight_layout(); plt.show()
We first plot the training loss and validation Dice score to see how the model improves over time. We then reload the best saved checkpoint and run sliding-window inference on a single validation volume. Finally, we display the slice along the last axis that contains the most ground-truth spleen voxels, showing the CT slice, ground-truth mask, and predicted segmentation side by side to inspect the model’s qualitative performance.
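The slice-selection rule is a one-liner, and a number next to the pictures is often more informative than the pictures alone. Below is a small NumPy sketch of the same `z` selection plus a hypothetical `slice_dice` helper that scores the prediction on that slice. Both functions are illustrative additions, not part of the original notebook, and they assume binary masks shaped (H, W, D).

```python
import numpy as np

def most_labeled_slice(label):
    """Index along the last axis with the most foreground voxels (same rule as z above)."""
    return int(np.argmax(label.sum(axis=(0, 1))))

def slice_dice(pred, label, z, eps=1e-8):
    """2D Dice between binary prediction and label on slice z of the last axis."""
    p = pred[:, :, z] > 0
    g = label[:, :, z] > 0
    inter = np.logical_and(p, g).sum()
    return float(2.0 * inter / (p.sum() + g.sum() + eps))

# Toy label: small blobs on slices 2 and 3, a larger 6x6 blob on slice 4
label = np.zeros((8, 8, 6))
label[2:6, 2:6, 2] = 1
label[2:6, 2:6, 3] = 1
label[1:7, 1:7, 4] = 1
pred = label.copy()
pred[1:7, 1, 4] = 0  # the "model" misses one 6-voxel column on slice 4

z = most_labeled_slice(label)      # 4
score = slice_dice(pred, label, z)  # 2 * 30 / (30 + 36), roughly 0.909
```

In the notebook, such a helper could be used in a panel title, for example `ax[2].set_title(f"Prediction (slice Dice {slice_dice(pred, lab_np, z):.3f})")`. Keep in mind that a single-slice Dice can differ noticeably from the volumetric Dice reported during validation.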
In conclusion, we built a practical MONAI-based workflow for 3D spleen segmentation with a 3D UNet model. We prepared the Medical Segmentation Decathlon dataset, transformed and augmented the CT volumes, trained the model with DiceCE loss, validated it with sliding-window inference, and tracked both loss and Dice score over time. We also inspected the final prediction visually by comparing the CT slice, ground-truth label, and model output side by side. We now have a clear picture of how MONAI supports medical segmentation tasks, from data loading and preprocessing to model training, evaluation, checkpointing, and qualitative analysis.
The post A Coding Implementation on MONAI for End-to-End 3D Spleen Segmentation Using UNet on Medical CT Volumes appeared first on MarkTechPost.
