Skip to content

KerasCV Object Detection Integrations

Utilities and callbacks integrating Weights & Biases with the object detection systems present in KerasCV.

Examples Link WandB Run
Easy and Simple Object-detection using KerasCV and Weights & Biases Open In Colab
Visualizing Object-detection datasets using Weights & Biases Open In Colab
Training an object detection model using KerasCV and Weights & Biases Open In Colab

WandBDetectionVisualizationCallback

Bases: Callback

Callback for visualizing ground-truth and predicted bounding boxes in an epoch-wise manner for an object-detection task for KerasCV. The callback logs a wandb.Table with columns for the epoch, the images overlaid with an interactive bounding box overlay corresponding to the ground-truth and predicted bounding boxes, the number of ground-truth bounding boxes and the predicted mean-confidence for each class.

Parameters:

Name Type Description Default
dataset Dataset

A batched dataset consisting of Ragged Tensors. This can be obtained by applying ragged_batch() on a tf.data.Dataset.

required
class_mapping Dict[int, str]

A dictionary that maps the index of the classes to the corresponding class names.

required
max_batches_to_visualize Optional[int]

Maximum number of batches from the dataset to be visualized.

1
iou_threshold float

IoU threshold for non-max suppression during prediction.

0.01
confidence_threshold float

Confidence threshold for non-max suppression during prediction.

0.01
source_bounding_box_format str

Format of the source bounding box, one of "xyxy" or "xywh".

'xywh'
title str

Title under which the table will be logged to the Weights & Biases workspace.

'Evaluation-Table'
Source code in wandb_addons/keras/detection/callback.py
class WandBDetectionVisualizationCallback(keras.callbacks.Callback):
    """Callback for visualizing ground-truth and predicted bounding boxes in an
    epoch-wise manner for an object-detection task for
    [KerasCV](https://github.com/keras-team/keras-cv). The callback logs a
    [`wandb.Table`](https://docs.wandb.ai/guides/tables) with columns for the epoch,
    the images overlaid with an interactive bounding box overlay corresponding to the
    ground-truth and predicted bounding boxes, the number of ground-truth bounding
    boxes and the predicted mean-confidence for each class.

    !!! example "Examples:"
        - [Fine-tuning an Object Detection Model using KerasCV](../examples/train_retinanet).
        - [Sample Results for Fine-tuning an Object Detection Model using KerasCV](https://wandb.ai/geekyrakshit/keras-cv-callbacks/reports/Keras-CV-Integration--Vmlldzo1MTU4Nzk3)

    Arguments:
        dataset (tf.data.Dataset): A batched dataset consisting of Ragged Tensors.
            This can be obtained by applying `ragged_batch()` on a `tf.data.Dataset`.
            Each element is unpacked as an `(image_batch, y_true_batch)` pair.
        class_mapping (Dict[int, str]): A dictionary that maps the index of the classes
            to the corresponding class names.
        max_batches_to_visualize (Optional[int]): Maximum number of batches from the
            dataset to be visualized. `None` visualizes every batch in the dataset.
        iou_threshold (float): IoU threshold for non-max suppression during prediction.
        confidence_threshold (float): Confidence threshold for non-max suppression
            during prediction.
        source_bounding_box_format (str): Format of the source bounding box, one of
            `"xyxy"` or `"xywh"`.
        title (str): Title under which the table will be logged to the Weights & Biases
            workspace.
    """

    def __init__(
        self,
        dataset: tf_data.Dataset,
        class_mapping: dict,
        max_batches_to_visualize: Optional[int] = 1,
        iou_threshold: float = 0.01,
        confidence_threshold: float = 0.01,
        source_bounding_box_format: str = "xywh",
        title: str = "Evaluation-Table",
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        # `None` means "no limit": keep the dataset as-is instead of calling
        # `take(None)`, which is not a valid count.
        if max_batches_to_visualize is None:
            self.dataset = dataset
        else:
            self.dataset = dataset.take(max_batches_to_visualize)
        self.class_mapping = class_mapping
        self.max_batches_to_visualize = max_batches_to_visualize
        self.iou_threshold = iou_threshold
        self.confidence_threshold = confidence_threshold
        self.source_bounding_box_format = source_bounding_box_format
        self.title = title
        # Decoder swapped into the model while the callback is producing
        # predictions for visualization (see `on_epoch_end`).
        self.prediction_decoder = keras_cv.layers.MultiClassNonMaxSuppression(
            bounding_box_format=self.source_bounding_box_format,
            from_logits=True,
            iou_threshold=self.iou_threshold,
            confidence_threshold=self.confidence_threshold,
        )
        # A single table accumulates rows across all epochs and is logged once
        # in `on_train_end`.
        self.table = wandb.Table(
            columns=[
                "Epoch",
                "Image",
                "Number-of-Ground-Truth-Boxes",
                "Mean-Confidence",
            ]
        )

    def plot_prediction(self, epoch, image_batch, y_true_batch):
        """Run the model on one batch and append one table row per image.

        Arguments:
            epoch (int): Epoch index stored in the "Epoch" column.
            image_batch: Batch of images passed to `self.model.predict`.
            y_true_batch: Mapping with `"boxes"` and `"classes"` entries holding
                the ground-truth annotations for `image_batch`.
        """
        y_pred_batch = self.model.predict(image_batch, verbose=0)
        y_pred = keras_cv.bounding_box.to_ragged(y_pred_batch)
        image_batch = keras_cv.utils.to_numpy(image_batch).astype(np.uint8)
        ground_truth_bounding_boxes = keras_cv.utils.to_numpy(
            keras_cv.bounding_box.convert_format(
                y_true_batch["boxes"],
                source=self.source_bounding_box_format,
                target="xyxy",
                images=image_batch,
            )
        )
        ground_truth_classes = keras_cv.utils.to_numpy(y_true_batch["classes"])
        predicted_bounding_boxes = keras_cv.utils.to_numpy(
            keras_cv.bounding_box.convert_format(
                y_pred["boxes"],
                source=self.source_bounding_box_format,
                target="xyxy",
                images=image_batch,
            )
        )
        for idx in range(image_batch.shape[0]):
            image = image_batch[idx]
            # Assumes channels-last images, i.e. (height, width, channels) --
            # TODO confirm for non-default Keras image data formats.
            image_height, image_width = image.shape[0], image.shape[1]

            num_detections = y_pred["num_detections"][idx].item()
            predicted_boxes = predicted_bounding_boxes[idx][:num_detections]
            confidences = keras_cv.utils.to_numpy(
                y_pred["confidence"][idx][:num_detections]
            )
            predicted_classes = keras_cv.utils.to_numpy(
                y_pred["classes"][idx][:num_detections]
            )

            gt_classes = [
                int(class_idx) for class_idx in ground_truth_classes[idx].tolist()
            ]
            gt_boxes = ground_truth_bounding_boxes[idx]
            # A class id of -1 marks padding; keep only the entries before it.
            if -1 in gt_classes:
                gt_classes = gt_classes[: gt_classes.index(-1)]

            # Predicted boxes are logged as fractions of the image size, so x
            # coordinates are divided by the width and y coordinates by the
            # height.
            wandb_prediction_boxes = []
            for box_idx in range(num_detections):
                box = predicted_boxes[box_idx]
                class_id = int(predicted_classes[box_idx])
                wandb_prediction_boxes.append(
                    {
                        "position": {
                            "minX": float(box[0]) / image_width,
                            "minY": float(box[1]) / image_height,
                            "maxX": float(box[2]) / image_width,
                            "maxY": float(box[3]) / image_height,
                        },
                        "class_id": class_id,
                        "box_caption": self.class_mapping[class_id],
                        "scores": {"confidence": float(confidences[box_idx])},
                    }
                )

            # Ground-truth boxes are logged in absolute pixel coordinates
            # ("domain": "pixel"), so no normalization is needed.
            wandb_ground_truth_boxes = [
                {
                    "position": {
                        "minX": int(gt_box[0]),
                        "minY": int(gt_box[1]),
                        "maxX": int(gt_box[2]),
                        "maxY": int(gt_box[3]),
                    },
                    "class_id": gt_class,
                    "box_caption": self.class_mapping[int(gt_class)],
                    "domain": "pixel",
                }
                for gt_class, gt_box in zip(gt_classes, gt_boxes)
            ]
            wandb_image = wandb.Image(
                image,
                boxes={
                    "ground-truth": {
                        "box_data": wandb_ground_truth_boxes,
                        "class_labels": self.class_mapping,
                    },
                    "predictions": {
                        "box_data": wandb_prediction_boxes,
                        "class_labels": self.class_mapping,
                    },
                },
            )
            mean_confidence_dict = get_mean_confidence_per_class(
                confidences, predicted_classes, self.class_mapping
            )
            self.table.add_data(
                epoch, wandb_image, len(gt_classes), mean_confidence_dict
            )

    def on_epoch_end(self, epoch, logs=None):
        """Add rows for every visualization batch at the end of an epoch.

        The model's prediction decoder is temporarily replaced with the
        callback's own decoder and is always restored afterwards, even if
        visualization raises.
        """
        # NOTE(review): the original decoder is read from the private
        # `_prediction_decoder` attribute but written through
        # `prediction_decoder`; presumably the latter is a property setter on
        # the model -- confirm against the KerasCV model in use.
        original_prediction_decoder = self.model._prediction_decoder
        self.model.prediction_decoder = self.prediction_decoder
        try:
            # Iterate the dataset itself so each step yields the *next* batch;
            # building a fresh iterator per step would replay the first batch.
            for image_batch, y_true_batch in tqdm(
                self.dataset, total=self.max_batches_to_visualize
            ):
                self.plot_prediction(epoch, image_batch, y_true_batch)
        finally:
            self.model.prediction_decoder = original_prediction_decoder

    def on_train_end(self, logs=None):
        """Log the accumulated table under `self.title` once training ends."""
        wandb.log({self.title: self.table})

log_predictions_to_wandb(image_batch, prediction_batch, class_mapping, source_bbox_format='xywh')

Function to log inference results to a wandb.Table with images overlaid with an interactive bounding box overlay corresponding to the predicted boxes.

Example notebooks:

Parameters:

Name Type Description Default
image_batch Union[KerasTensor, array]

The batch of resized and batched images that is also passed to the model.

required
prediction_batch Union[KerasTensor, array]

The prediction batch that is the output of the detection model.

required
class_mapping Dict[int, str]

A dictionary that maps the index of the classes to the corresponding class names.

required
source_bbox_format str

Format of the source bounding box, one of "xyxy" or "xywh".

'xywh'
Source code in wandb_addons/keras/detection/inference.py
def log_predictions_to_wandb(
    image_batch: np.ndarray,
    prediction_batch: Dict[str, np.ndarray],
    class_mapping: Dict[int, str],
    source_bbox_format: str = "xywh",
):
    """Function to log inference results to a
    [wandb.Table](https://docs.wandb.ai/guides/data-vis) with images overlaid with an
    interactive bounding box overlay corresponding to the predicted boxes.

    The table is logged under the key `"Prediction-Table"` and has one row per
    image with the columns `"Predictions"` and `"Mean-Confidence"`.

    !!! example "Example notebooks:"
        - [Object Detection using KerasCV](../examples/object_detection_inference).

    Arguments:
        image_batch (Union[backend.KerasTensor, np.array]): The batch of resized and
            batched images that is also passed to the model.
        prediction_batch (Union[backend.KerasTensor, np.array]): The prediction batch
            that is the output of the detection model. It is indexed with the keys
            `"boxes"`, `"num_detections"`, `"confidence"` and `"classes"`.
        class_mapping (Dict[int, str]): A dictionary that maps the index of the classes
            to the corresponding class names.
        source_bbox_format (str): Format of the source bounding box, one of `"xyxy"`
            or `"xywh"`.
    """
    batch_size = prediction_batch["boxes"].shape[0]
    image_batch = keras_cv.utils.to_numpy(image_batch).astype(np.uint8)
    bounding_boxes = keras_cv.utils.to_numpy(
        keras_cv.bounding_box.convert_format(
            prediction_batch["boxes"],
            source=source_bbox_format,
            target="xyxy",
            images=image_batch,
        )
    )
    table = wandb.Table(columns=["Predictions", "Mean-Confidence"])
    for idx in tqdm(range(batch_size)):
        image = image_batch[idx]
        # Assumes channels-last images, i.e. (height, width, channels) --
        # TODO confirm for non-default Keras image data formats.
        image_height, image_width = image.shape[0], image.shape[1]

        # Only the first `num_detections` entries of each output are used.
        num_detections = prediction_batch["num_detections"][idx].item()
        predicted_boxes = bounding_boxes[idx][:num_detections]
        confidences = prediction_batch["confidence"][idx][:num_detections]
        classes = prediction_batch["classes"][idx][:num_detections]

        # Boxes are logged as fractions of the image size, so x coordinates are
        # divided by the width and y coordinates by the height.
        wandb_boxes = []
        for box_idx in range(num_detections):
            box = predicted_boxes[box_idx]
            class_id = int(classes[box_idx])
            wandb_boxes.append(
                {
                    "position": {
                        "minX": float(box[0]) / image_width,
                        "minY": float(box[1]) / image_height,
                        "maxX": float(box[2]) / image_width,
                        "maxY": float(box[3]) / image_height,
                    },
                    "class_id": class_id,
                    "box_caption": class_mapping[class_id],
                    "scores": {"confidence": float(confidences[box_idx])},
                }
            )
        wandb_image = wandb.Image(
            image,
            boxes={
                "predictions": {
                    "box_data": wandb_boxes,
                    "class_labels": class_mapping,
                },
            },
        )
        mean_confidence_dict = get_mean_confidence_per_class(
            confidences, classes, class_mapping
        )
        table.add_data(wandb_image, mean_confidence_dict)
    wandb.log({"Prediction-Table": table})

visualize_dataset(dataset, class_mapping, title, max_batches_to_visualize=1, source_bbox_format='xywh')

Function to visualize a dataset using a wandb.Table with 2 columns, one with the images overlaid with an interactive bounding box overlay corresponding to the ground-truth boxes and another showing the number of bounding boxes corresponding to that image.

Example notebooks:

Parameters:

Name Type Description Default
dataset Dataset

A batched dataset consisting of Ragged Tensors. This can be obtained by applying ragged_batch() on a tf.data.Dataset.

required
class_mapping Dict[int, str]

A dictionary that maps the index of the classes to the corresponding class names.

required
title str

Title under which the table will be logged to the Weights & Biases workspace.

required
max_batches_to_visualize Optional[int]

Maximum number of batches from the dataset to be visualized.

1
source_bbox_format str

Format of the source bounding box, one of "xyxy" or "xywh".

'xywh'
Source code in wandb_addons/keras/detection/dataset.py
def visualize_dataset(
    dataset: tf_data.Dataset,
    class_mapping: Dict[int, str],
    title: str,
    max_batches_to_visualize: Optional[int] = 1,
    source_bbox_format: str = "xywh",
):
    """Function to visualize a dataset using a
    [wandb.Table](https://docs.wandb.ai/guides/data-vis) with 2 columns, one with the
    images overlaid with an interactive bounding box overlay corresponding to the
    ground-truth boxes and another showing the number of bounding boxes corresponding
    to that image.

    !!! example "Example notebooks:"
        - [Object Detection using KerasCV](../examples/visualize_dataset).

    Arguments:
        dataset (tf.data.Dataset): A batched dataset consisting of Ragged Tensors.
            This can be obtained by applying `ragged_batch()` on a `tf.data.Dataset`.
            Each element is read through its `"images"` and `"bounding_boxes"` keys.
        class_mapping (Dict[int, str]): A dictionary that maps the index of the classes
            to the corresponding class names.
        title (str): Title under which the table will be logged to the Weights & Biases
            workspace.
        max_batches_to_visualize (Optional[int]): Maximum number of batches from the
            dataset to be visualized. `None` visualizes every batch in the dataset.
        source_bbox_format (str): Format of the source bounding box, one of `"xyxy"`
            or `"xywh"`.
    """
    table = wandb.Table(columns=["Images", "Number-of-Objects"])
    # Limit the dataset up front and iterate it directly. This avoids querying
    # the cardinality (which may be unknown or infinite) and stops cleanly when
    # the dataset holds fewer batches than requested.
    if max_batches_to_visualize is not None:
        dataset = dataset.take(max_batches_to_visualize)

    for sample in tqdm(dataset, total=max_batches_to_visualize):
        images = keras_cv.utils.to_numpy(sample["images"])
        images = keras_cv.utils.transform_value_range(
            images, original_range=(0, 255), target_range=(0, 255)
        )
        # Build a fresh dict instead of mutating the sample yielded by the
        # dataset.
        bounding_boxes = {
            key: keras_cv.utils.to_numpy(val)
            for key, val in sample["bounding_boxes"].items()
        }
        bounding_boxes["boxes"] = keras_cv.utils.to_numpy(
            keras_cv.bounding_box.convert_format(
                bounding_boxes["boxes"],
                source=source_bbox_format,
                target="xyxy",
                images=images,
            )
        )
        for idx in range(images.shape[0]):
            classes = [
                int(class_idx) for class_idx in bounding_boxes["classes"][idx].tolist()
            ]
            bboxes = bounding_boxes["boxes"][idx]
            # A class id of -1 marks padding; keep only the entries before it.
            if -1 in classes:
                classes = classes[: classes.index(-1)]
            # Boxes are logged in absolute pixel coordinates ("domain": "pixel").
            wandb_boxes = [
                {
                    "position": {
                        "minX": int(bbox[0]),
                        "minY": int(bbox[1]),
                        "maxX": int(bbox[2]),
                        "maxY": int(bbox[3]),
                    },
                    "class_id": class_id,
                    "box_caption": class_mapping[int(class_id)],
                    "domain": "pixel",
                }
                for class_id, bbox in zip(classes, bboxes)
            ]
            wandb_image = wandb.Image(
                images[idx],
                boxes={
                    # Spelling fixed from "gorund-truth" so the overlay group
                    # is labelled correctly in the W&B UI.
                    "ground-truth": {
                        "box_data": wandb_boxes,
                        "class_labels": class_mapping,
                    },
                },
            )
            table.add_data(wandb_image, len(classes))

    wandb.log({title: table})