DenseNet training (dictionary-based transforms)
In [ ]:
Copied!
# Environment setup:
#   1. download the IXI T1/T2 brain MRI archives into ./dataset and unpack them,
#   2. install wandb-addons (editable install, with its MONAI extra) to get the
#      W&B handlers used later in this notebook.
# NOTE(review): every cell on this page appears twice — a docs-export artifact
# (note the "Copied!" markers); running both copies merely repeats the same
# shell commands.
!mkdir dataset
%cd dataset
!wget http://biomedic.doc.ic.ac.uk/brain-development/downloads/IXI/IXI-T1.tar
!wget http://biomedic.doc.ic.ac.uk/brain-development/downloads/IXI/IXI-T2.tar
!tar -xf IXI-T1.tar && tar -xf IXI-T2.tar && rm -rf IXI-T1.tar && rm -rf IXI-T2.tar
%cd ..
!git clone https://github.com/soumik12345/wandb-addons
!pip install -q --upgrade pip setuptools
!pip install -q -e wandb-addons[monai]
!mkdir dataset
%cd dataset
!wget http://biomedic.doc.ic.ac.uk/brain-development/downloads/IXI/IXI-T1.tar
!wget http://biomedic.doc.ic.ac.uk/brain-development/downloads/IXI/IXI-T2.tar
!tar -xf IXI-T1.tar && tar -xf IXI-T2.tar && rm -rf IXI-T1.tar && rm -rf IXI-T2.tar
%cd ..
!git clone https://github.com/soumik12345/wandb-addons
!pip install -q --upgrade pip setuptools
!pip install -q -e wandb-addons[monai]
In [ ]:
Copied!
import os
import sys
from glob import glob
import numpy as np
import wandb
import torch
from ignite.engine import Events, _prepare_batch, create_supervised_evaluator, create_supervised_trainer
from ignite.handlers import EarlyStopping, ModelCheckpoint
import monai
from monai.data import decollate_batch, DataLoader
from monai.handlers import ROCAUC, StatsHandler, TensorBoardStatsHandler, stopping_fn_from_metric
from monai.transforms import Activations, AsDiscrete, Compose, LoadImaged, RandRotate90d, Resized, ScaleIntensityd
from wandb_addons.monai import WandbStatsHandler, WandbModelCheckpointHandler
monai.config.print_config()
import os
import sys
from glob import glob
import numpy as np
import wandb
import torch
from ignite.engine import Events, _prepare_batch, create_supervised_evaluator, create_supervised_trainer
from ignite.handlers import EarlyStopping, ModelCheckpoint
import monai
from monai.data import decollate_batch, DataLoader
from monai.handlers import ROCAUC, StatsHandler, TensorBoardStatsHandler, stopping_fn_from_metric
from monai.transforms import Activations, AsDiscrete, Compose, LoadImaged, RandRotate90d, Resized, ScaleIntensityd
from wandb_addons.monai import WandbStatsHandler, WandbModelCheckpointHandler
monai.config.print_config()
In [ ]:
Copied!
# Start the W&B run. tensorboard.patch + sync_tensorboard=True forward the
# TensorBoard event files written under ./runs (by the TensorBoard handlers
# below) into the same W&B run; save_code=True uploads the source as well.
wandb.tensorboard.patch(root_logdir="./runs")
wandb.init(project="monai-integration", sync_tensorboard=True, save_code=True)
wandb.tensorboard.patch(root_logdir="./runs")
wandb.init(project="monai-integration", sync_tensorboard=True, save_code=True)
In [ ]:
Copied!
# Build a 20-sample demo dataset: pair the first 20 files found under
# ./dataset with a hard-coded binary label vector (the labels are demo
# values, not derived from the scans). First 10 pairs -> training,
# last 10 -> validation.
# NOTE(review): glob() ordering is filesystem-dependent, so the image/label
# pairing is arbitrary — acceptable for a demo, not for a real experiment.
images = glob("./dataset/*")[:20]
labels = np.array([0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=np.int64)
train_files = [{"img": img, "label": label} for img, label in zip(images[:10], labels[:10])]
val_files = [{"img": img, "label": label} for img, label in zip(images[-10:], labels[-10:])]
images = glob("./dataset/*")[:20]
labels = np.array([0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=np.int64)
train_files = [{"img": img, "label": label} for img, label in zip(images[:10], labels[:10])]
val_files = [{"img": img, "label": label} for img, label in zip(images[-10:], labels[-10:])]
In [ ]:
Copied!
# Dictionary-based MONAI preprocessing, keyed on "img":
#   - load each volume with a leading channel dimension,
#   - scale intensities,
#   - resize to 96x96x96,
#   - (training only) random 90-degree rotation about spatial axes 0 and 2
#     with probability 0.8, as light augmentation.
# Validation uses the same pipeline minus the random rotation, so metrics are
# computed on deterministic inputs.
train_transforms = Compose(
[
LoadImaged(keys=["img"], ensure_channel_first=True),
ScaleIntensityd(keys=["img"]),
Resized(keys=["img"], spatial_size=(96, 96, 96)),
RandRotate90d(keys=["img"], prob=0.8, spatial_axes=[0, 2]),
]
)
val_transforms = Compose(
[
LoadImaged(keys=["img"], ensure_channel_first=True),
ScaleIntensityd(keys=["img"]),
Resized(keys=["img"], spatial_size=(96, 96, 96)),
]
)
train_transforms = Compose(
[
LoadImaged(keys=["img"], ensure_channel_first=True),
ScaleIntensityd(keys=["img"]),
Resized(keys=["img"], spatial_size=(96, 96, 96)),
RandRotate90d(keys=["img"], prob=0.8, spatial_axes=[0, 2]),
]
)
val_transforms = Compose(
[
LoadImaged(keys=["img"], ensure_channel_first=True),
ScaleIntensityd(keys=["img"]),
Resized(keys=["img"], spatial_size=(96, 96, 96)),
]
)
In [ ]:
Copied!
# Sanity check: pull a single batch through the training pipeline and print
# the image tensor shape and labels before building the model.
check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())
check_data = monai.utils.misc.first(check_loader)
print(check_data["img"].shape, check_data["label"])
check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
check_loader = DataLoader(check_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())
check_data = monai.utils.misc.first(check_loader)
print(check_data["img"].shape, check_data["label"])
In [ ]:
Copied!
# Model and optimisation: a 3D DenseNet-121 classifier (1 input channel,
# 2 output logits) on GPU when available, trained with cross-entropy and
# Adam at a fixed learning rate of 1e-5.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device)
loss = torch.nn.CrossEntropyLoss()
lr = 1e-5
opt = torch.optim.Adam(net.parameters(), lr)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device)
loss = torch.nn.CrossEntropyLoss()
lr = 1e-5
opt = torch.optim.Adam(net.parameters(), lr)
In [ ]:
Copied!
def prepare_batch(batch, device=None, non_blocking=False):
    """Adapt MONAI's dict-style batch to the (input, target) tuple ignite expects.

    FIX: the exported page had stripped the body's leading whitespace, which
    is a syntax error in Python; the indentation is restored here.
    """
    return _prepare_batch((batch["img"], batch["label"]), device, non_blocking)


# Ignite trainer driving net/opt/loss; deterministic=False (positional arg).
# NOTE(review): this cell is duplicated by the docs export; the second copy
# simply rebinds the same names.
trainer = create_supervised_trainer(net, opt, loss, device, False, prepare_batch=prepare_batch)


def prepare_batch(batch, device=None, non_blocking=False):
    """Adapt MONAI's dict-style batch to the (input, target) tuple ignite expects."""
    return _prepare_batch((batch["img"], batch["label"]), device, non_blocking)


trainer = create_supervised_trainer(net, opt, loss, device, False, prepare_batch=prepare_batch)
In [ ]:
Copied!
# Training-side handlers:
#   - WandbModelCheckpointHandler saves {net, opt} under ./runs_dict/ after
#     every epoch (keeping up to 10 checkpoints; require_empty=False allows
#     reusing the directory),
#   - StatsHandler / TensorBoardStatsHandler / WandbStatsHandler log the
#     per-iteration training loss (output_transform=lambda x: x passes the
#     trainer output through unchanged).
checkpoint_handler = WandbModelCheckpointHandler("./runs_dict/", "net", n_saved=10, require_empty=False)
trainer.add_event_handler(
event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={"net": net, "opt": opt}
)
train_stats_handler = StatsHandler(name="trainer", output_transform=lambda x: x)
train_stats_handler.attach(trainer)
train_tensorboard_stats_handler = TensorBoardStatsHandler(output_transform=lambda x: x)
train_tensorboard_stats_handler.attach(trainer)
# WandbStatsHandler logs loss at every iteration
train_wandb_stats_handler = WandbStatsHandler(output_transform=lambda x: x)
train_wandb_stats_handler.attach(trainer)
checkpoint_handler = WandbModelCheckpointHandler("./runs_dict/", "net", n_saved=10, require_empty=False)
trainer.add_event_handler(
event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={"net": net, "opt": opt}
)
train_stats_handler = StatsHandler(name="trainer", output_transform=lambda x: x)
train_stats_handler.attach(trainer)
train_tensorboard_stats_handler = TensorBoardStatsHandler(output_transform=lambda x: x)
train_tensorboard_stats_handler.attach(trainer)
# WandbStatsHandler logs loss at every iteration
train_wandb_stats_handler = WandbStatsHandler(output_transform=lambda x: x)
train_wandb_stats_handler.attach(trainer)
In [ ]:
Copied!
# Validation configuration: evaluate every epoch, scoring ROC AUC.
# post_pred applies softmax to the model logits; post_label one-hot encodes
# the integer targets (2 classes) so ROCAUC receives matching shapes.
# set parameters for validation
validation_every_n_epochs = 1
metric_name = "AUC"
# add evaluation metric to the evaluator engine
val_metrics = {metric_name: ROCAUC()}
post_label = Compose([AsDiscrete(to_onehot=2)])
post_pred = Compose([Activations(softmax=True)])
# Ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration,
# user can add output_transform to return other values
# The output_transform decollates each batch into per-sample tensors and
# applies the post transforms, yielding (pred_list, label_list) per iteration.
evaluator = create_supervised_evaluator(
net,
val_metrics,
device,
True,
prepare_batch=prepare_batch,
output_transform=lambda x, y, y_pred: (
[post_pred(i) for i in decollate_batch(y_pred)],
[post_label(i) for i in decollate_batch(y, detach=False)],
),
)
# set parameters for validation
validation_every_n_epochs = 1
metric_name = "AUC"
# add evaluation metric to the evaluator engine
val_metrics = {metric_name: ROCAUC()}
post_label = Compose([AsDiscrete(to_onehot=2)])
post_pred = Compose([Activations(softmax=True)])
# Ignite evaluator expects batch=(img, label) and returns output=(y_pred, y) at every iteration,
# user can add output_transform to return other values
evaluator = create_supervised_evaluator(
net,
val_metrics,
device,
True,
prepare_batch=prepare_batch,
output_transform=lambda x, y, y_pred: (
[post_pred(i) for i in decollate_batch(y_pred)],
[post_label(i) for i in decollate_batch(y, detach=False)],
),
)
In [ ]:
Copied!
# Validation-side logging: console stats, TensorBoard, and Weights & Biases.
# Each handler reports once per validation run and stamps entries with the
# *trainer* epoch (global_epoch_transform) so validation curves line up with
# the training metrics; per-iteration output is disabled (no val loss).
val_stats_handler = StatsHandler(
    name="evaluator",
    output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
    global_epoch_transform=lambda x: trainer.state.epoch,  # fetch global epoch number from trainer
)
val_stats_handler.attach(evaluator)
# add handler to record metrics to TensorBoard at every epoch
val_tensorboard_stats_handler = TensorBoardStatsHandler(
    output_transform=lambda x: None,
    global_epoch_transform=lambda x: trainer.state.epoch,
)
val_tensorboard_stats_handler.attach(evaluator)
# add handler to record metrics to Weights & Biases at every epoch.
# BUG FIX: this handler must observe the *evaluator* — the engine that
# computes the validation AUC — like the two handlers above. The original
# attached it to the trainer, so the validation metric never reached W&B.
val_wandb_stats_handler = WandbStatsHandler(
    output_transform=lambda x: None,
    global_epoch_transform=lambda x: trainer.state.epoch,
)
val_wandb_stats_handler.attach(evaluator)
# (duplicate copy below is a docs-export artifact; kept in step with the fix)
val_stats_handler = StatsHandler(
    name="evaluator",
    output_transform=lambda x: None,
    global_epoch_transform=lambda x: trainer.state.epoch,
)
val_stats_handler.attach(evaluator)
val_tensorboard_stats_handler = TensorBoardStatsHandler(
    output_transform=lambda x: None,
    global_epoch_transform=lambda x: trainer.state.epoch,
)
val_tensorboard_stats_handler.attach(evaluator)
val_wandb_stats_handler = WandbStatsHandler(
    output_transform=lambda x: None,
    global_epoch_transform=lambda x: trainer.state.epoch,
)
val_wandb_stats_handler.attach(evaluator)
In [ ]:
Copied!
# Stop training when the validation AUC has not improved for 4 consecutive
# validation runs; the EarlyStopping handler terminates the linked trainer.
# add early stopping handler to evaluator
early_stopper = EarlyStopping(patience=4, score_function=stopping_fn_from_metric(metric_name), trainer=trainer)
evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)
# add early stopping handler to evaluator
early_stopper = EarlyStopping(patience=4, score_function=stopping_fn_from_metric(metric_name), trainer=trainer)
evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)
In [ ]:
Copied!
# Validation data loader: deterministic transforms, no shuffling.
# create a validation data loader
val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())
# create a validation data loader
val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=2, num_workers=4, pin_memory=torch.cuda.is_available())
In [ ]:
Copied!
@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
def run_validation(engine):
    """Run the evaluator over the validation loader after each training epoch.

    The evaluator's attached handlers log the AUC. FIX: the exported page had
    stripped the body's leading whitespace (a syntax error); restored here.
    """
    evaluator.run(val_loader)


# NOTE(review): duplicated by the docs export; re-registering replaces the
# definition and adds a second identical epoch hook.
@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
def run_validation(engine):
    evaluator.run(val_loader)
In [ ]:
Copied!
# Training data loader: shuffled each epoch; pin memory only when CUDA is up.
# create a training data loader
train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4, pin_memory=torch.cuda.is_available())
# create a training data loader
train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4, pin_memory=torch.cuda.is_available())
In [ ]:
Copied!
# Train for up to 30 epochs (early stopping may end the run sooner), print
# the final ignite State, and close the W&B run so all logs are flushed.
train_epochs = 30
state = trainer.run(train_loader, train_epochs)
print(state)
wandb.finish()
train_epochs = 30
state = trainer.run(train_loader, train_epochs)
print(state)
wandb.finish()
In [ ]:
Copied!