# Install KerasCore from source: `namex` is installed first, then the
# keras-core repository is cloned and built/installed via its `pip_build.py` script.
!pip install -qq namex
!apt install python3.10-venv
!git clone https://github.com/soumik12345/keras-core.git && cd keras-core && python pip_build.py --install
# Install wandb-addons (provides the WandB callbacks imported from `wandb_addons.keras` below)
!pip install -qq git+https://github.com/soumik12345/wandb-addons
Fine-tune a TorchVision Model with Keras and WandB
In this example, we will see:
- how we can fine-tune a pre-trained model from torchvision using KerasCore.
- how we can use the backend-agnostic Keras callbacks for Weights & Biases to manage and track our experiment.
Installing and Importing the Dependencies
- We install the `main` branch of KerasCore; this lets us use the latest features merged into KerasCore.
- We also install wandb-addons, a library that hosts the backend-agnostic callbacks compatible with KerasCore.
We select `torch` as the Keras backend by explicitly setting the environment variable `KERAS_BACKEND`.
import os

# Select the torch backend for Keras. This must run before `keras_core` is
# imported further down, because the backend is chosen when Keras is imported.
os.environ["KERAS_BACKEND"] = "torch"
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, models, transforms
import wandb
from wandb_addons.keras import WandbMetricsLogger, WandbModelCheckpoint
We initialize a wandb run and set the configs for the experiment.
# Start a W&B run for this experiment.
wandb.init(project="keras-torch")

# Keep the experiment hyperparameters on the run's config object so that the
# rest of the notebook reads them from a single place.
config = wandb.config
config.batch_size = 4
config.num_epochs = 25
A PyTorch-based Input Pipeline
We will be using the ImageNette dataset for this experiment. Imagenette is a subset of 10 easily classified classes from Imagenet (tench, English springer, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute).
First, let’s download this dataset.
# Download the imagenette2-320 archive into the `imagenette` directory
!wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz -P imagenette
# Extract the archive inside the `imagenette` directory
!tar zxf imagenette/imagenette2-320.tgz -C imagenette
# NOTE(review): the archive has already been extracted by the `tar` call above;
# this only decompresses the .tgz itself - confirm whether it is still needed.
!gzip -d imagenette/imagenette2-320.tgz
Now, we create our standard torch-based data loading pipeline.
# Define pre-processing and augmentation transforms for the train and validation sets.
# Both pipelines end with the same per-channel normalization; `imshow` below
# undoes it with the same mean/std values before plotting.
data_transforms = {
    # Training: random crop + horizontal flip as augmentation
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    # Validation: deterministic resize + center crop, no augmentation
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
}
# Define the train and validation datasets.
# Each split lives in its own sub-directory of `data_dir` and is paired with
# the transform pipeline of the same name.
data_dir = 'imagenette/imagenette2-320'
image_datasets = {
    x: datasets.ImageFolder(
        os.path.join(data_dir, x), data_transforms[x]
    )
    for x in ['train', 'val']
}
# Define the torch dataloaders corresponding to the train and validation dataset
dataloaders = {
    x: torch.utils.data.DataLoader(
        image_datasets[x],
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=4,
    )
    for x in ['train', 'val']
}

# Number of samples per split, and the class names taken from the training set
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
# Specify the global device: use the GPU when one is available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Let’s take a look at a few of the samples.
def imshow(inp, title=None):
    """Display image for Tensor.

    Args:
        inp: A tensor whose first axis is moved last before plotting; it is
            expected to have been normalized with the mean/std used in the
            data transforms, which this function undoes.
        title: Optional title for the plot; nothing is set when `None`.
    """
    # Move the first axis to the end so the 3-element mean/std below
    # broadcast over the last axis.
    inp = inp.numpy().transpose((1, 2, 0))

    # Undo the normalization applied by the data transforms
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean

    # Clamp to the [0, 1] range before plotting
    inp = np.clip(inp, 0, 1)

    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated
# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))
print(inputs.shape, classes.shape)

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

# Show the grid, titled with the class name of every sample in the batch
imshow(out, title=[class_names[x] for x in classes])
Creating and Training our Classifier
We typically define a model in PyTorch using `torch.nn.Module`s, which act as the building blocks of stateful computation. Even though Keras supports PyTorch as a backend, it does not mean that we can nest torch modules inside a `keras_core.Model`, because trainable variables inside a Keras Model are tracked exclusively via Keras Layers.
KerasCore provides us with a feature called `TorchModuleWrapper` which enables us to do exactly this. The `TorchModuleWrapper` is a Keras Layer that accepts a torch module and tracks its trainable variables, essentially converting the torch module into a Keras Layer. This enables us to put any torch modules inside a Keras Model and train them with a single `model.fit()`!
The idea of the `TorchModuleWrapper` was proposed by Keras’ creator François Chollet on this issue thread.
import keras_core as keras
from keras_core.utils import TorchModuleWrapper
class ResNet18Classifier(keras.Model):
    """Keras model wrapping a pre-trained torchvision ResNet18 backbone.

    The backbone's own classification head is replaced by an identity module,
    and a new 10-way Keras `Dense` head (followed by softmax) is put on top.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Define the pre-trained ResNet18 module from torchvision
        resnet_18 = models.resnet18(weights='IMAGENET1K_V1')

        # Set the classification head of the pre-trained ResNet18
        # module to an identity module
        resnet_18.fc = nn.Identity()

        # Set the trainable ResNet18 backbone to be a Keras Layer
        # using `TorchModuleWrapper`
        self.backbone = TorchModuleWrapper(resnet_18)

        # Set this to `False` if you want to freeze the backbone
        self.backbone.trainable = True

        # Note that we don't convert nn.Dropout to a Keras Layer
        # because it doesn't consist of trainable variables.
        # NOTE(review): as a raw torch module, it is not obvious that Keras
        # switches this dropout to eval mode during validation/inference -
        # confirm, or consider `keras.layers.Dropout` instead.
        self.dropout = nn.Dropout(p=0.5)

        # We define the classification head as a Keras Layer
        self.fc = keras.layers.Dense(10)

    def call(self, inputs):
        """Run the backbone, dropout and head; return softmax over axis 1."""
        x = self.backbone(inputs)
        x = self.dropout(x)
        x = self.fc(x)
        return keras.activations.softmax(x, axis=1)
Note: It is actually possible to use torch modules inside a Keras Model without having to explicitly wrap them with the `TorchModuleWrapper`, as evident from this tweet by François Chollet. However, this did not seem to work at the time this example was created, as reported in this issue.
# Now, we define the model and pass a random tensor to check the output shape.
# The dummy input is placed on the global `device` chosen earlier (instead of
# a hard-coded "cuda") so this check also runs on machines without a GPU.
model = ResNet18Classifier()
model(torch.ones(1, 3, 224, 224).to(device)).shape
Now, in standard Keras fashion, all we need to do is compile the model and call `model.fit()`!
# Create exponential decay learning rate scheduler.
# `len(dataloaders["train"])` is already the number of batches per epoch, so
# the total number of optimizer steps is epochs * batches-per-epoch. Dividing
# by the batch size again would make the learning rate decay `batch_size`
# times faster than intended.
decay_steps = config.num_epochs * len(dataloaders["train"])
lr_scheduler = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-3,
    decay_steps=decay_steps,
    decay_rate=0.1,
)
# Compile the model.
# The sparse loss is used because the dataloaders yield integer class labels,
# and the model's `call` already returns softmax probabilities.
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=keras.optimizers.Adam(lr_scheduler),
    metrics=["accuracy"],
)
# Define the backend-agnostic WandB callbacks for KerasCore
callbacks = [
    # Track experiment metrics
    WandbMetricsLogger(log_freq="batch"),
    # Track and version model checkpoints
    WandbModelCheckpoint("model.keras"),
]
# Train the model by calling model.fit, feeding it the torch dataloaders directly
model.fit(
    dataloaders["train"],
    validation_data=dataloaders["val"],
    epochs=config.num_epochs,
    callbacks=callbacks,
)
In order to know more about the backend-agnostic Keras callbacks for Weights & Biases, check out the docs for wandb-addons.