Fine-tune a TorchVision Model with Keras and WandB

Using Keras to fine-tune a pre-trained model from Torchvision.

April 9, 2023

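This notebook demonstrates how to fine-tune a pre-trained torchvision model using KerasCore with the PyTorch backend, and how to track the experiment with Weights & Biases.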

Installing and Importing the Dependencies

  • We install the main branch of KerasCore, which lets us use the latest features merged into KerasCore.
  • We also install wandb-addons, a library that hosts backend-agnostic callbacks compatible with KerasCore.
# install the `main` branch of KerasCore
!pip install -qq namex
!apt install python3.10-venv
!git clone && cd keras-core && python --install

# install wandb-addons
!pip install -qq git+

We select torch as the Keras backend by setting the environment variable KERAS_BACKEND before importing Keras.

import os
os.environ["KERAS_BACKEND"] = "torch"

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
from torchvision import datasets, models, transforms

import wandb
from wandb_addons.keras import WandbMetricsLogger, WandbModelCheckpoint
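
KerasCore reads KERAS_BACKEND once, at import time, so the variable has to be set before the first import of keras_core. Below is a minimal sketch of a guard that enforces this ordering. The helper name `select_keras_backend` is hypothetical and not part of KerasCore.

```python
import os
import sys


def select_keras_backend(name="torch"):
    """Set KERAS_BACKEND, refusing if keras_core has already been imported.

    Hypothetical helper for illustration; not part of KerasCore.
    """
    if "keras_core" in sys.modules:
        raise RuntimeError(
            "keras_core is already imported; set KERAS_BACKEND before importing it"
        )
    os.environ["KERAS_BACKEND"] = name
    return os.environ["KERAS_BACKEND"]


# Call this at the very top of the notebook, before `import keras_core`.
print(select_keras_backend("torch"))
```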

We initialize a wandb run and set the configs for the experiment.


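# The project name below is a placeholder; use your own W&B project
wandb.init(project="keras-torchvision-finetune")
config = wandb.config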
config.batch_size = 4
config.num_epochs = 25

A PyTorch-based Input Pipeline

We will be using the Imagenette dataset for this experiment. Imagenette is a subset of 10 easily classified classes from ImageNet (tench, English springer, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute).

First, let’s download this dataset.

!wget -P imagenette
!tar zxf imagenette/imagenette2-320.tgz -C imagenette
!gzip -d imagenette/imagenette2-320.tgz

Now, we create our standard torch-based data loading pipeline.

# Define pre-processing and augmentation transforms for the train and validation sets
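data_transforms = {
    # The crop, flip, and resize steps follow the standard torchvision
    # transfer-learning recipe (assumed here)
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}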

# Define the train and validation datasets
data_dir = 'imagenette/imagenette2-320'
image_datasets = {
    x: datasets.ImageFolder(
        os.path.join(data_dir, x), data_transforms[x]
    )
    for x in ['train', 'val']
}

# Define the torch dataloaders corresponding to the train and validation dataset
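dataloaders = {
    # shuffle=True is an assumption
    x: torch.utils.data.DataLoader(
        image_datasets[x], batch_size=config.batch_size, shuffle=True
    )
    for x in ['train', 'val']
}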
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

# Specify the global device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
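
datasets.ImageFolder expects one sub-directory per class under each split and assigns integer labels by sorting the directory names. The sketch below mimics that label assignment using only the standard library. The helper `find_classes_like_imagefolder` and the toy folder names are hypothetical, and this illustrates the convention rather than torchvision's actual implementation.

```python
import tempfile
from pathlib import Path


def find_classes_like_imagefolder(root):
    """Return (classes, class_to_idx) using sorted sub-directory names."""
    classes = sorted(p.name for p in Path(root).iterdir() if p.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx


# Build a toy `train/` split with three made-up class folders.
with tempfile.TemporaryDirectory() as tmp:
    train_dir = Path(tmp) / "train"
    for class_dir in ["n03", "n01", "n02"]:
        (train_dir / class_dir).mkdir(parents=True)
    classes, class_to_idx = find_classes_like_imagefolder(train_dir)

print(classes)
print(class_to_idx)
```

For Imagenette the folder names are WordNet IDs such as n01440764, so class_names holds those IDs rather than human-readable labels.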

Let’s take a look at a few of the samples.

def imshow(inp, title=None):
    """Display image for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated

# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))
print(inputs.shape, classes.shape)

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

imshow(out, title=[class_names[x] for x in classes])
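
The imshow helper undoes transforms.Normalize, which maps each channel as (x - mean) / std. The small NumPy sketch below uses a made-up 2x2 image to check that `std * inp + mean` recovers the original pixel values.

```python
import numpy as np

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# Made-up image in height x width x channel layout with values in [0, 1].
rng = np.random.default_rng(0)
image = rng.random((2, 2, 3))

# Per-channel normalization; broadcasting applies mean/std along the last axis.
normalized = (image - mean) / std
# The inverse mapping that imshow applies before plotting.
restored = std * normalized + mean

print(np.allclose(restored, image))
```

The np.clip call in imshow only guards against small floating-point overshoot outside [0, 1] after this inverse mapping.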

Creating and Training our Classifier

We typically define a model in PyTorch using torch.nn.Module instances, which act as the building blocks of stateful computation. Even though Keras supports PyTorch as a backend, this does not mean that we can nest torch modules directly inside a keras_core.Model, because trainable variables inside a Keras Model are tracked exclusively via Keras Layers.

KerasCore provides a feature called TorchModuleWrapper that enables us to do exactly this. The TorchModuleWrapper is a Keras Layer that accepts a torch module and tracks its trainable variables, essentially converting the torch module into a Keras Layer. This lets us put any torch module inside a Keras Model and train it with a single call to model.fit()!

The idea of the TorchModuleWrapper was proposed by Keras’ creator François Chollet on this issue thread.

import keras_core as keras
from keras_core.utils import TorchModuleWrapper

class ResNet18Classifier(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Define the pre-trained ResNet18 module from torchvision
        resnet_18 = models.resnet18(weights='IMAGENET1K_V1')
        num_ftrs = resnet_18.fc.in_features

        # Set the classification head of the pre-trained ResNet18
        # module to an identity module
        resnet_18.fc = nn.Identity()
        # Set the trainable ResNet18 backbone to be a Keras Layer
        # using `TorchModuleWrapper`
        self.backbone = TorchModuleWrapper(resnet_18)

        # Set this to `False` if you want to freeze the backbone
        self.backbone.trainable = True

        # Note that we don't convert nn.Dropout to a Keras Layer
        # because it doesn't consist of trainable variables.
        self.dropout = nn.Dropout(p=0.5)

        # We define the classification head as a Keras Layer
        self.fc = keras.layers.Dense(10)

    def call(self, inputs):
        x = self.backbone(inputs)
        x = self.dropout(x)
        x = self.fc(x)
        return keras.activations.softmax(x, axis=1)
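
Because call ends with a softmax over axis=1, the model returns one probability distribution per image rather than raw logits. Here is a NumPy sketch of that operation on made-up logits. The `softmax` function is a hand-written stand-in, not the Keras implementation.

```python
import numpy as np


def softmax(logits, axis=1):
    """Numerically stable softmax along `axis`."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


# Made-up logits for a batch of 2 images and 10 classes.
logits = np.arange(20, dtype=np.float64).reshape(2, 10)
probs = softmax(logits, axis=1)

print(probs.shape)        # one row of class probabilities per image
print(probs.sum(axis=1))  # each row sums to 1 (up to rounding)
```

This matters when picking a loss: a loss applied to these outputs should expect probabilities, not logits.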

Note: It is actually possible to use torch modules inside a Keras Model without explicitly wrapping them with the TorchModuleWrapper, as evidenced by this tweet from François Chollet. However, this did not seem to work at the time this example was created, as reported in this issue.

# Now, we define the model and pass a random tensor to check the output shape
model = ResNet18Classifier()
model(torch.ones(1, 3, 224, 224).to(device)).shape

Now, in standard Keras fashion, all we need to do is compile the model and call model.fit()!

# Create exponential decay learning rate scheduler
decay_steps = config.num_epochs * len(dataloaders["train"]) // config.batch_size
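lr_scheduler = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-3, decay_steps=decay_steps, decay_rate=0.1,
)

# Compile the model (Adam and the accuracy metric are assumptions)
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=keras.optimizers.Adam(lr_scheduler),
    metrics=["accuracy"],
)

# Define the backend-agnostic WandB callbacks for KerasCore
callbacks = [
    # Track experiment metrics
    WandbMetricsLogger(log_freq="batch"),  # arguments assumed
    # Track and version model checkpoints
    WandbModelCheckpoint("model.keras"),  # file name assumed
]

# Train the model by calling model.fit()
model.fit(
    dataloaders["train"], validation_data=dataloaders["val"],
    epochs=config.num_epochs, callbacks=callbacks,
)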
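
Without staircase mode, ExponentialDecay computes initial_learning_rate * decay_rate ** (step / decay_steps). Below is a plain-Python sketch of that formula. The function name and the default of 1000 decay steps are illustrative assumptions, not values from this experiment.

```python
def exponential_decay(step, initial_learning_rate=1e-3, decay_steps=1000, decay_rate=0.1):
    """Learning rate after `step` optimizer updates (non-staircase form)."""
    return initial_learning_rate * decay_rate ** (step / decay_steps)


# Print the learning rate at a few illustrative steps.
for step in (0, 500, 1000, 2000):
    print(step, exponential_decay(step))
```

With decay_rate=0.1, the learning rate drops by a factor of ten every decay_steps updates, so the choice of decay_steps decides how many such drops happen over the whole run.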

To learn more about the backend-agnostic Keras callbacks for Weights & Biases, check out the docs for wandb-addons.