# install the `main` branch of KerasCore
!pip install -qq namex
!apt install python3.10-venv
!git clone https://github.com/soumik12345/keras-core.git && cd keras-core && python pip_build.py --install
# install wandb-addons
!pip install -qq git+https://github.com/soumik12345/wandb-addons
Fine-tune a TorchVision Model with Keras and WandB
- how we can fine-tune a pre-trained model from torchvision using KerasCore.
- how we can use the backend-agnostic Keras callbacks for Weights & Biases to manage and track our experiment.
Installing and Importing the Dependencies
- We install the `main` branch of KerasCore; this lets us use the latest features merged into KerasCore.
- We also install wandb-addons, a library that hosts the backend-agnostic callbacks compatible with KerasCore.
We select torch as the Keras backend by explicitly setting the environment variable KERAS_BACKEND.
import os
os.environ["KERAS_BACKEND"] = "torch"
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, models, transforms
import wandb
from wandb_addons.keras import WandbMetricsLogger, WandbModelCheckpoint
We initialize a wandb run and set the configs for the experiment.
# Initialize a W&B run and keep the experiment's hyperparameters on its
# config object, from which the rest of the script reads them.
wandb.init(project="keras-torch")
config = wandb.config
config.batch_size, config.num_epochs = 4, 25

# A PyTorch-based Input Pipeline
We will be using the ImageNette dataset for this experiment. Imagenette is a subset of 10 easily classified classes from Imagenet (tench, English springer, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute).
First, let’s download this dataset.
!wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz -P imagenette
!tar zxf imagenette/imagenette2-320.tgz -C imagenette
!gzip -d imagenette/imagenette2-320.tgz
Now, we create our standard torch-based data loading pipeline.
# Input pipeline: per-split transforms -> ImageFolder datasets -> DataLoaders.
splits = ('train', 'val')

# Per-channel mean / std handed to `transforms.Normalize`; both splits use
# the same normalization step.
normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# Training images get a random resized crop and a random horizontal flip;
# validation images get a fixed resize followed by a center crop.
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]),
}

# One ImageFolder dataset per split, each read from its own sub-directory
# of the extracted archive.
data_dir = 'imagenette/imagenette2-320'
image_datasets = {
    split: datasets.ImageFolder(
        root=os.path.join(data_dir, split),
        transform=data_transforms[split],
    )
    for split in splits
}

# One torch DataLoader per split; the batch size comes from the W&B config.
dataloaders = {
    split: torch.utils.data.DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=4,
    )
    for split, dataset in image_datasets.items()
}

# Number of samples per split and the class names of the training split.
dataset_sizes = {split: len(dataset) for split, dataset in image_datasets.items()}
class_names = image_datasets['train'].classes

# Specify the global device: CUDA when available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Let's take a look at a few of the samples.
def imshow(inp, title=None):
    """Display a normalized image tensor with matplotlib.

    Args:
        inp: Channels-first image tensor (axis 0 is moved to the end before
            plotting). The mean/std arrays below have three entries, so a
            three-channel image is expected.
        title: Optional title for the plot; skipped when ``None``.
    """
    # `Tensor.numpy()` only works on CPU tensors that do not require grad,
    # so detach and move to host memory first. For a plain CPU tensor this
    # is a no-op.
    image = inp.detach().cpu().numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    # Undo the mean/std normalization, then clamp to the [0, 1] range that
    # `plt.imshow` expects for float images.
    image = np.clip(std * image + mean, 0, 1)
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated
# Pull a single batch (images and their labels) from the training loader
inputs, classes = next(iter(dataloaders['train']))
print(inputs.shape, classes.shape)

# Tile the batch into one image grid and title it with the class name of
# each sample in the batch
out = torchvision.utils.make_grid(inputs)
batch_titles = [class_names[label] for label in classes]
imshow(out, title=batch_titles)

# Creating and Training our Classifier
We typically define a model in PyTorch using torch.nn.Modules, which act as the building blocks of stateful computation. Even though Keras supports PyTorch as a backend, it does not mean that we can nest torch modules inside a keras_core.Model, because trainable variables inside a Keras Model are tracked exclusively via Keras Layers.
KerasCore provides us with a feature called TorchModuleWrapper which enables us to do exactly this. The TorchModuleWrapper is a Keras Layer that accepts a torch module and tracks its trainable variables, essentially converting the torch module into a Keras Layer. This enables us to put any torch modules inside a Keras Model and train them with a single model.fit()!
The idea of the TorchModuleWrapper was proposed by Keras’ creator François Chollet on this issue thread.
import keras_core as keras
from keras_core.utils import TorchModuleWrapper
class ResNet18Classifier(keras.Model):
    """Keras model wrapping a pre-trained torchvision ResNet18 backbone.

    The torchvision module is wrapped in a `TorchModuleWrapper` so that it
    is used as a Keras Layer; a dropout module and a Keras `Dense` head are
    applied on top of the backbone output.

    Args:
        num_classes: Number of output units of the classification head.
            Defaults to 10, the value that was previously hard-coded.
        *args, **kwargs: Forwarded unchanged to `keras.Model.__init__`.
    """

    def __init__(self, *args, num_classes=10, **kwargs):
        super().__init__(*args, **kwargs)

        # Define the pre-trained ResNet18 module from torchvision
        resnet_18 = models.resnet18(weights='IMAGENET1K_V1')

        # Set the classification head of the pre-trained ResNet18
        # module to an identity module, so the backbone returns the
        # features that previously fed that head
        resnet_18.fc = nn.Identity()

        # Set the trainable ResNet18 backbone to be a Keras Layer
        # using `TorchModuleWrapper`
        self.backbone = TorchModuleWrapper(resnet_18)

        # Set this to `False` if you want to freeze the backbone
        self.backbone.trainable = True

        # Note that we don't convert nn.Dropout to a Keras Layer
        # because it doesn't consist of trainable variables.
        # NOTE(review): as a plain torch module its train/eval mode follows
        # torch's `.train()` / `.eval()` state rather than a Keras
        # `training` argument -- confirm dropout is disabled during
        # validation and inference.
        self.dropout = nn.Dropout(p=0.5)

        # We define the classification head as a Keras Layer
        self.fc = keras.layers.Dense(num_classes)

    def call(self, inputs):
        """Return class probabilities (softmax over axis 1) for `inputs`."""
        x = self.backbone(inputs)
        x = self.dropout(x)
        x = self.fc(x)
        return keras.activations.softmax(x, axis=1)


# Note: It is actually possible to use torch modules inside a Keras Model
# without having to explicitly have them wrapped with the TorchModuleWrapper
# as evident by this tweet from François Chollet. However, this doesn't seem
# to work at the point of time this example was created, as reported in this
# issue.
# Now, we define the model and pass a dummy batch (a tensor of ones) through
# it to check the output shape
model = ResNet18Classifier()
# Use the globally selected `device` instead of a hard-coded "cuda" so this
# check also runs on machines without a GPU.
model(torch.ones(1, 3, 224, 224).to(device)).shape

# Now, in standard Keras fashion, all we need to do is compile the model and call model.fit()!
# Exponential-decay learning-rate schedule starting at 1e-3 with a decay
# rate of 0.1.
# NOTE(review): `len(dataloaders["train"])` is already the number of batches
# per epoch, so the extra `// config.batch_size` makes `decay_steps` a
# fraction (1 / batch_size) of the total number of training steps --
# confirm this is the intended decay period.
batches_per_epoch = len(dataloaders["train"])
decay_steps = config.num_epochs * batches_per_epoch // config.batch_size
lr_scheduler = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-3,
    decay_steps=decay_steps,
    decay_rate=0.1,
)

# Compile the model: Adam driven by the schedule above, sparse categorical
# cross-entropy on the model's softmax output, accuracy as the metric
optimizer = keras.optimizers.Adam(lr_scheduler)
model.compile(
    optimizer=optimizer,
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

# Define the backend-agnostic WandB callbacks for KerasCore
metrics_logger = WandbMetricsLogger(log_freq="batch")  # track experiment metrics
checkpointer = WandbModelCheckpoint("model.keras")  # track and version model checkpoints
callbacks = [metrics_logger, checkpointer]

# Train the model by calling model.fit, feeding the torch dataloaders directly
model.fit(
    dataloaders["train"],
    validation_data=dataloaders["val"],
    epochs=config.num_epochs,
    callbacks=callbacks,
)

# In order to know more about the backend-agnostic Keras callbacks for Weights & Biases, check out the docs for wandb-addons.