UNet 3D segmentation
Setup environment¶
In [ ]:
Copied!
# Install dependencies if missing: MONAI weekly build (with gdown/nibabel/
# tqdm/ignite extras), Weights & Biases, and the wandb-addons package that
# provides the MONAI logging handlers used below.
# NOTE(review): the trailing "@" on the wandb-addons URL looks like a
# truncated git ref (branch/tag) — confirm the intended revision.
!python -c "import monai" || pip install -q "monai-weekly[gdown, nibabel, tqdm, ignite]"
!python -c "import wandb" || pip install -q wandb
!pip install -q --upgrade git+https://github.com/soumik12345/wandb-addons@
!python -c "import monai" || pip install -q "monai-weekly[gdown, nibabel, tqdm, ignite]"
!python -c "import wandb" || pip install -q wandb
!pip install -q --upgrade git+https://github.com/soumik12345/wandb-addons@
In [ ]:
Copied!
import glob
import logging
import os
from pathlib import Path
import shutil
import sys
import tempfile
import nibabel as nib
import numpy as np
from monai.config import print_config
from monai.data import (
ArrayDataset,
create_test_image_3d,
decollate_batch,
DataLoader
)
from monai.handlers import (
MeanDice,
StatsHandler,
TensorBoardImageHandler,
TensorBoardStatsHandler,
)
from monai.losses import DiceLoss
from monai.networks.nets import UNet
from monai.transforms import (
Activations,
EnsureChannelFirst,
AsDiscrete,
Compose,
LoadImage,
RandSpatialCrop,
Resize,
ScaleIntensity,
)
from monai.utils import first
import wandb
from wandb_addons.monai import WandbStatsHandler, WandbModelCheckpointHandler
import ignite
import torch
from tqdm.auto import tqdm
print_config()
import glob
import logging
import os
from pathlib import Path
import shutil
import sys
import tempfile
import nibabel as nib
import numpy as np
from monai.config import print_config
from monai.data import (
ArrayDataset,
create_test_image_3d,
decollate_batch,
DataLoader
)
from monai.handlers import (
MeanDice,
StatsHandler,
TensorBoardImageHandler,
TensorBoardStatsHandler,
)
from monai.losses import DiceLoss
from monai.networks.nets import UNet
from monai.transforms import (
Activations,
EnsureChannelFirst,
AsDiscrete,
Compose,
LoadImage,
RandSpatialCrop,
Resize,
ScaleIntensity,
)
from monai.utils import first
import wandb
from wandb_addons.monai import WandbStatsHandler, WandbModelCheckpointHandler
import ignite
import torch
from tqdm.auto import tqdm
print_config()
In [ ]:
Copied!
# Resolve the data directory: honor MONAI_DATA_DIRECTORY when the user set
# it, otherwise fall back to a fresh temporary directory.
directory = os.environ.get("MONAI_DATA_DIRECTORY")
if directory is None:
    root_dir = tempfile.mkdtemp()
else:
    root_dir = directory
print(root_dir)
# Resolve the data directory: MONAI_DATA_DIRECTORY if set, else a temp dir.
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = directory if directory is not None else tempfile.mkdtemp()
print(root_dir)
Setup Weights & Biases run¶
In [ ]:
Copied!
# Patch TensorBoard so anything written under log_dir is mirrored to W&B,
# then start a run that saves the notebook code and syncs tensorboard.
# NOTE(review): wandb.tensorboard.patch is called with a positional
# argument here — confirm it maps to the intended parameter (root_logdir)
# for the wandb version in use.
log_dir = os.path.join(root_dir, "logs")
wandb.tensorboard.patch(log_dir)
wandb.init(project="monai-integration", save_code=True, sync_tensorboard=True)
log_dir = os.path.join(root_dir, "logs")
wandb.tensorboard.patch(log_dir)
wandb.init(project="monai-integration", save_code=True, sync_tensorboard=True)
Setup logging¶
In [ ]:
Copied!
# Route INFO-level log records to stdout so handler output shows in the
# notebook. (Cell content is duplicated by the notebook export.)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
Setup demo data¶
In [ ]:
Copied!
# Generate a synthetic dataset of 40 paired 3D images and segmentation
# masks (single foreground class), saved as NIfTI files under root_dir.
# NOTE: the loop-body indentation was lost in the original paste and is
# restored here.
for i in tqdm(range(40)):
    im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)
    n = nib.Nifti1Image(im, np.eye(4))
    nib.save(n, os.path.join(root_dir, f"im{i}.nii.gz"))
    n = nib.Nifti1Image(seg, np.eye(4))
    nib.save(n, os.path.join(root_dir, f"seg{i}.nii.gz"))

# Collect the generated paths, sorted so im{i}/seg{i} stay paired by index.
images = sorted(glob.glob(os.path.join(root_dir, "im*.nii.gz")))
segs = sorted(glob.glob(os.path.join(root_dir, "seg*.nii.gz")))
# Generate a synthetic dataset of 40 paired 3D images and segmentation
# masks (single foreground class), saved as NIfTI files under root_dir.
# NOTE: the loop-body indentation was lost in the original paste and is
# restored here.
for i in tqdm(range(40)):
    im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)
    n = nib.Nifti1Image(im, np.eye(4))
    nib.save(n, os.path.join(root_dir, f"im{i}.nii.gz"))
    n = nib.Nifti1Image(seg, np.eye(4))
    nib.save(n, os.path.join(root_dir, f"seg{i}.nii.gz"))

# Collect the generated paths, sorted so im{i}/seg{i} stay paired by index.
images = sorted(glob.glob(os.path.join(root_dir, "im*.nii.gz")))
segs = sorted(glob.glob(os.path.join(root_dir, "seg*.nii.gz")))
Setup transforms, dataset¶
In [ ]:
Copied!
# Training transforms. The crop size is shared by both pipelines so image
# and segmentation stay spatially aligned.
roi_size = (96, 96, 96)

# Image pipeline: load volume, rescale intensities, channel-first layout,
# random fixed-size spatial crop.
imtrans = Compose([
    LoadImage(image_only=True),
    ScaleIntensity(),
    EnsureChannelFirst(),
    RandSpatialCrop(roi_size, random_size=False),
])

# Segmentation pipeline: identical, minus intensity scaling so the label
# values are left untouched.
segtrans = Compose([
    LoadImage(image_only=True),
    EnsureChannelFirst(),
    RandSpatialCrop(roi_size, random_size=False),
])

# Pair each image with its segmentation, then pull one batch as a shape
# sanity check.
ds = ArrayDataset(images, imtrans, segs, segtrans)
loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
im, seg = first(loader)
print(im.shape, seg.shape)
# Training transforms. The crop size is shared by both pipelines so image
# and segmentation stay spatially aligned.
roi_size = (96, 96, 96)

# Image pipeline: load volume, rescale intensities, channel-first layout,
# random fixed-size spatial crop.
imtrans = Compose([
    LoadImage(image_only=True),
    ScaleIntensity(),
    EnsureChannelFirst(),
    RandSpatialCrop(roi_size, random_size=False),
])

# Segmentation pipeline: identical, minus intensity scaling so the label
# values are left untouched.
segtrans = Compose([
    LoadImage(image_only=True),
    EnsureChannelFirst(),
    RandSpatialCrop(roi_size, random_size=False),
])

# Pair each image with its segmentation, then pull one batch as a shape
# sanity check.
ds = ArrayDataset(images, imtrans, segs, segtrans)
loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
im, seg = first(loader)
print(im.shape, seg.shape)
Create Model, Loss, Optimizer¶
In [ ]:
Copied!
# Create UNet, DiceLoss and Adam optimizer
# Trains on the first CUDA device — assumes a GPU is available.
device = torch.device("cuda:0")
# 3D UNet: 1 input channel, 1 output channel (binary segmentation),
# 5 resolution levels with stride-2 downsampling and 2 residual units each.
net = UNet(
spatial_dims=3,
in_channels=1,
out_channels=1,
channels=(16, 32, 64, 128, 256),
strides=(2, 2, 2, 2),
num_res_units=2,
).to(device)
# sigmoid=True: the loss applies the sigmoid to the raw network logits.
loss = DiceLoss(sigmoid=True)
lr = 1e-3
opt = torch.optim.Adam(net.parameters(), lr)
# (duplicated cell content below — notebook export artifact)
# Create UNet, DiceLoss and Adam optimizer
device = torch.device("cuda:0")
net = UNet(
spatial_dims=3,
in_channels=1,
out_channels=1,
channels=(16, 32, 64, 128, 256),
strides=(2, 2, 2, 2),
num_res_units=2,
).to(device)
loss = DiceLoss(sigmoid=True)
lr = 1e-3
opt = torch.optim.Adam(net.parameters(), lr)
Create supervised_trainer using ignite¶
In [ ]:
Copied!
# Create trainer
# Supervised training engine over net/opt/loss on `device`; the final
# positional False presumably disables non_blocking transfers — verify
# against the installed ignite version's signature.
trainer = ignite.engine.create_supervised_trainer(net, opt, loss, device, False)
# Create trainer
trainer = ignite.engine.create_supervised_trainer(net, opt, loss, device, False)
Setup event handlers for checkpointing and logging¶
In [ ]:
Copied!
# Optional: checkpointing and logging handlers for the training engine.

# ModelCheckpoint saves network params and optimizer state under log_dir
# after every epoch, keeping the 10 most recent saves.
log_dir = os.path.join(root_dir, "logs")
checkpoint_handler = ignite.handlers.ModelCheckpoint(log_dir, "net", n_saved=10, require_empty=False)
trainer.add_event_handler(
    event_name=ignite.engine.Events.EPOCH_COMPLETED,
    handler=checkpoint_handler,
    to_save={"net": net, "opt": opt},
)

# StatsHandler prints loss at every iteration; output_transform converts
# engine.state.output into the value to log.
train_stats_handler = StatsHandler(name="trainer", output_transform=lambda x: x)
train_stats_handler.attach(trainer)

# TensorBoardStatsHandler plots loss at every iteration.
train_tensorboard_stats_handler = TensorBoardStatsHandler(log_dir=log_dir, output_transform=lambda x: x)
train_tensorboard_stats_handler.attach(trainer)

# WandbStatsHandler mirrors the loss to Weights & Biases.
train_wandb_stats_handler = WandbStatsHandler(output_transform=lambda x: x)
train_wandb_stats_handler.attach(trainer)

# NOTE(review): the original cell also constructed an ignite `Checkpoint`
# with `WandbModelCheckpointSaver` here, but it referenced `evaluator` and
# `metric_name`, which are only defined in the validation cell below —
# running it at this point raises NameError. The checkpoint-to-W&B handler
# is attached (correctly) in the validation cell, so the premature copy
# has been removed.
# Optional: checkpointing and logging handlers for the training engine.

# ModelCheckpoint saves network params and optimizer state under log_dir
# after every epoch, keeping the 10 most recent saves.
log_dir = os.path.join(root_dir, "logs")
checkpoint_handler = ignite.handlers.ModelCheckpoint(log_dir, "net", n_saved=10, require_empty=False)
trainer.add_event_handler(
    event_name=ignite.engine.Events.EPOCH_COMPLETED,
    handler=checkpoint_handler,
    to_save={"net": net, "opt": opt},
)

# StatsHandler prints loss at every iteration; output_transform converts
# engine.state.output into the value to log.
train_stats_handler = StatsHandler(name="trainer", output_transform=lambda x: x)
train_stats_handler.attach(trainer)

# TensorBoardStatsHandler plots loss at every iteration.
train_tensorboard_stats_handler = TensorBoardStatsHandler(log_dir=log_dir, output_transform=lambda x: x)
train_tensorboard_stats_handler.attach(trainer)

# WandbStatsHandler mirrors the loss to Weights & Biases.
train_wandb_stats_handler = WandbStatsHandler(output_transform=lambda x: x)
train_wandb_stats_handler.attach(trainer)

# NOTE(review): the original cell also constructed an ignite `Checkpoint`
# with `WandbModelCheckpointSaver` here, but it referenced `evaluator` and
# `metric_name`, which are only defined in the validation cell below —
# running it at this point raises NameError. The checkpoint-to-W&B handler
# is attached (correctly) in the validation cell, so the premature copy
# has been removed.
Add Validation every N epochs¶
In [ ]:
Copied!
# Optional: run validation every N epochs and log metrics, sample images,
# and the best model checkpoint.
validation_every_n_epochs = 1

# Metric name under which the evaluator reports Mean Dice.
metric_name = "Mean_Dice"
val_metrics = {metric_name: MeanDice()}

# Post-processing: sigmoid + 0.5 threshold turns logits into a binary
# mask; labels are thresholded too so both inputs to Dice are discrete.
post_pred = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
post_label = Compose([AsDiscrete(threshold=0.5)])

# Ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y)
# at every iteration; output_transform decollates the batch and applies
# the post-transforms per item.
evaluator = ignite.engine.create_supervised_evaluator(
    net,
    val_metrics,
    device,
    True,
    output_transform=lambda x, y, y_pred: (
        [post_pred(i) for i in decollate_batch(y_pred)],
        [post_label(i) for i in decollate_batch(y)],
    ),
)

# Validation transforms: deterministic Resize instead of a random crop so
# evaluation is reproducible.
val_imtrans = Compose(
    [
        LoadImage(image_only=True),
        ScaleIntensity(),
        EnsureChannelFirst(),
        Resize((96, 96, 96)),
    ]
)
val_segtrans = Compose(
    [
        LoadImage(image_only=True),
        EnsureChannelFirst(),
        Resize((96, 96, 96)),
    ]
)
# NOTE(review): training uses images[:20] while validation uses
# images[21:], leaving sample 20 unused — confirm whether [20:] was meant.
val_ds = ArrayDataset(images[21:], val_imtrans, segs[21:], val_segtrans)
val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())

# Run the evaluator every `validation_every_n_epochs` training epochs.
# (Indentation of this function body was lost in the original paste and is
# restored here.)
@trainer.on(ignite.engine.Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
def run_validation(engine):
    evaluator.run(val_loader)

# Print validation stats via the evaluator.
val_stats_handler = StatsHandler(
    name="evaluator",
    output_transform=lambda x: None,  # no per-iteration loss to print
    global_epoch_transform=lambda x: trainer.state.epoch,  # report trainer's epoch
)
val_stats_handler.attach(evaluator)

# Record validation metrics to TensorBoard at every validation epoch.
val_tensorboard_stats_handler = TensorBoardStatsHandler(
    log_dir=log_dir,
    output_transform=lambda x: None,
    global_epoch_transform=lambda x: trainer.state.epoch,
)
val_tensorboard_stats_handler.attach(evaluator)

# Mirror the validation metrics to Weights & Biases.
val_wandb_stats_handler = WandbStatsHandler(
    output_transform=lambda x: None,
    global_epoch_transform=lambda x: trainer.state.epoch,
)
val_wandb_stats_handler.attach(evaluator)

# Draw the first image, label and model output of the last batch — the 3D
# volume is rendered as a GIF along the depth axis — at every validation
# epoch.
val_tensorboard_image_handler = TensorBoardImageHandler(
    log_dir=log_dir,
    batch_transform=lambda batch: (batch[0], batch[1]),
    output_transform=lambda output: output[0],
    global_iter_transform=lambda x: trainer.state.epoch,
)
evaluator.add_event_handler(
    event_name=ignite.engine.Events.EPOCH_COMPLETED,
    handler=val_tensorboard_image_handler,
)

# Log the best model checkpoint as a W&B Artifact when evaluation
# completes. Fixes vs. the original: `Checkpoint`, `Events` and
# `global_step_from_engine` are referenced through their imported modules,
# the undefined `model`/`optimizer` are the `net`/`opt` created above, and
# the saver is the `WandbModelCheckpointHandler` imported at the top of
# the file (the original named an unimported `WandbModelCheckpointSaver`).
checkpoint_handler = ignite.handlers.Checkpoint(
    {"model": net, "optimizer": opt},
    WandbModelCheckpointHandler(),
    n_saved=1,
    filename_prefix="best_checkpoint",
    score_name=metric_name,
    global_step_transform=ignite.handlers.global_step_from_engine(trainer),
)
evaluator.add_event_handler(ignite.engine.Events.COMPLETED, checkpoint_handler)
# Optional: run validation every N epochs and log metrics, sample images,
# and the best model checkpoint.
validation_every_n_epochs = 1

# Metric name under which the evaluator reports Mean Dice.
metric_name = "Mean_Dice"
val_metrics = {metric_name: MeanDice()}

# Post-processing: sigmoid + 0.5 threshold turns logits into a binary
# mask; labels are thresholded too so both inputs to Dice are discrete.
post_pred = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
post_label = Compose([AsDiscrete(threshold=0.5)])

# Ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y)
# at every iteration; output_transform decollates the batch and applies
# the post-transforms per item.
evaluator = ignite.engine.create_supervised_evaluator(
    net,
    val_metrics,
    device,
    True,
    output_transform=lambda x, y, y_pred: (
        [post_pred(i) for i in decollate_batch(y_pred)],
        [post_label(i) for i in decollate_batch(y)],
    ),
)

# Validation transforms: deterministic Resize instead of a random crop so
# evaluation is reproducible.
val_imtrans = Compose(
    [
        LoadImage(image_only=True),
        ScaleIntensity(),
        EnsureChannelFirst(),
        Resize((96, 96, 96)),
    ]
)
val_segtrans = Compose(
    [
        LoadImage(image_only=True),
        EnsureChannelFirst(),
        Resize((96, 96, 96)),
    ]
)
# NOTE(review): training uses images[:20] while validation uses
# images[21:], leaving sample 20 unused — confirm whether [20:] was meant.
val_ds = ArrayDataset(images[21:], val_imtrans, segs[21:], val_segtrans)
val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())

# Run the evaluator every `validation_every_n_epochs` training epochs.
# (Indentation of this function body was lost in the original paste and is
# restored here.)
@trainer.on(ignite.engine.Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
def run_validation(engine):
    evaluator.run(val_loader)

# Print validation stats via the evaluator.
val_stats_handler = StatsHandler(
    name="evaluator",
    output_transform=lambda x: None,  # no per-iteration loss to print
    global_epoch_transform=lambda x: trainer.state.epoch,  # report trainer's epoch
)
val_stats_handler.attach(evaluator)

# Record validation metrics to TensorBoard at every validation epoch.
val_tensorboard_stats_handler = TensorBoardStatsHandler(
    log_dir=log_dir,
    output_transform=lambda x: None,
    global_epoch_transform=lambda x: trainer.state.epoch,
)
val_tensorboard_stats_handler.attach(evaluator)

# Mirror the validation metrics to Weights & Biases.
val_wandb_stats_handler = WandbStatsHandler(
    output_transform=lambda x: None,
    global_epoch_transform=lambda x: trainer.state.epoch,
)
val_wandb_stats_handler.attach(evaluator)

# Draw the first image, label and model output of the last batch — the 3D
# volume is rendered as a GIF along the depth axis — at every validation
# epoch.
val_tensorboard_image_handler = TensorBoardImageHandler(
    log_dir=log_dir,
    batch_transform=lambda batch: (batch[0], batch[1]),
    output_transform=lambda output: output[0],
    global_iter_transform=lambda x: trainer.state.epoch,
)
evaluator.add_event_handler(
    event_name=ignite.engine.Events.EPOCH_COMPLETED,
    handler=val_tensorboard_image_handler,
)

# Log the best model checkpoint as a W&B Artifact when evaluation
# completes. Fixes vs. the original: `Checkpoint`, `Events` and
# `global_step_from_engine` are referenced through their imported modules,
# the undefined `model`/`optimizer` are the `net`/`opt` created above, and
# the saver is the `WandbModelCheckpointHandler` imported at the top of
# the file (the original named an unimported `WandbModelCheckpointSaver`).
checkpoint_handler = ignite.handlers.Checkpoint(
    {"model": net, "optimizer": opt},
    WandbModelCheckpointHandler(),
    n_saved=1,
    filename_prefix="best_checkpoint",
    score_name=metric_name,
    global_step_transform=ignite.handlers.global_step_from_engine(trainer),
)
evaluator.add_event_handler(ignite.engine.Events.COMPLETED, checkpoint_handler)
Run training loop¶
In [ ]:
Copied!
# create a training data loader
# Train on the first 20 image/segmentation pairs; the random-crop
# transforms defined earlier provide the augmentation.
train_ds = ArrayDataset(images[:20], imtrans, segs[:20], segtrans)
train_loader = DataLoader(
train_ds,
batch_size=5,
shuffle=True,
num_workers=8,
pin_memory=torch.cuda.is_available(),
)
max_epochs = 10
# Run the ignite training loop; all handlers attached above fire during it.
state = trainer.run(train_loader, max_epochs)
# (duplicated cell content below — notebook export artifact)
# create a training data loader
train_ds = ArrayDataset(images[:20], imtrans, segs[:20], segtrans)
train_loader = DataLoader(
train_ds,
batch_size=5,
shuffle=True,
num_workers=8,
pin_memory=torch.cuda.is_available(),
)
max_epochs = 10
state = trainer.run(train_loader, max_epochs)
In [ ]:
Copied!
# End the Weights & Biases run so buffered logs are flushed and the run is
# marked finished.
wandb.finish()
# Remove the data directory only if we created it ourselves (i.e.
# MONAI_DATA_DIRECTORY was not set). The `if` body indentation was lost in
# the original paste and is restored here.
if directory is None:
    shutil.rmtree(root_dir)
# End the Weights & Biases run so buffered logs are flushed and the run is
# marked finished.
wandb.finish()
# Remove the data directory only if we created it ourselves (i.e.
# MONAI_DATA_DIRECTORY was not set). The `if` body indentation was lost in
# the original paste and is restored here.
if directory is None:
    shutil.rmtree(root_dir)