Skip to content

Spatial Relationship Metrics

This module implements the spatial relationship metric described in Section 4.2 of T2I-CompBench: A Comprehensive Benchmark for Open-world Compositional Text-to-image Generation.

Using an object-detection model for spatial relationship evaluation as proposed in T2I-CompBench
Weave gives us a holistic view of the evaluations, letting us drill into individual outputs and scores.
Example
import wandb
import weave

from hemm.eval_pipelines import BaseDiffusionModel, EvaluationPipeline
from hemm.metrics.spatial_relationship.judges.detr import DETRSpatialRelationShipJudge
from hemm.metrics.spatial_relationship.spatial_relationship_2d import (
    SpatialRelationshipMetric2D,
)

# Initialize Weave and WandB
wandb.init(project="image-quality-leaderboard", job_type="evaluation")
weave.init(project_name="image-quality-leaderboard")

# Initialize the diffusion model to be evaluated as a `weave.Model` using `BaseDiffusionModel`
model = BaseDiffusionModel(diffusion_model_name_or_path="CompVis/stable-diffusion-v1-4")

# Add the model to the evaluation pipeline
evaluation_pipeline = EvaluationPipeline(model=model)

# Checkpoint and revision of the DETR object detector used as the judge
# (these are the defaults declared on `DETRSpatialRelationShipJudge`)
detr_model_address = "facebook/detr-resnet-50"
detr_revision = "no_timm"

# Define the judge model for 2d spatial relationship metric
judge = DETRSpatialRelationShipJudge(
    model_address=detr_model_address, revision=detr_revision
)

# Add the 2D spatial relationship metric to the evaluation pipeline
metric = SpatialRelationshipMetric2D(judge=judge, name="2d_spatial_relationship_score")
evaluation_pipeline.add_metric(metric)

# Evaluate!
evaluation_pipeline(dataset="t2i_compbench_spatial_prompts:v0")

SpatialRelationshipMetric2D

Spatial relationship metric for 2D images as proposed by Section 4.2 from the paper T2I-CompBench: A Comprehensive Benchmark for Open-world Compositional Text-to-image Generation.

Parameters:

Name Type Description Default
judge Union[Model, DETRSpatialRelationShipJudge]

The judge model to predict the bounding boxes from the generated image.

required
iou_threshold Optional[float]

The IoU threshold for the spatial relationship.

0.1
distance_threshold Optional[float]

The distance threshold for the spatial relationship.

150
name Optional[str]

The name of the metric.

'spatial_relationship_score'
Source code in hemm/metrics/spatial_relationship/spatial_relationship_2d.py
class SpatialRelationshipMetric2D:
    """Spatial relationship metric for 2D images as proposed by Section 4.2 from the paper
    [T2I-CompBench: A Comprehensive Benchmark for Open-world Compositional Text-to-image Generation](https://arxiv.org/pdf/2307.06350).

    A judge model predicts labelled bounding boxes for the generated image. The
    boxes whose labels match the two entities of the prompt are then compared
    (relative position of their centers and their IoU) to score how well the
    requested spatial relationship is depicted.

    Args:
        judge (Union[weave.Model, DETRSpatialRelationShipJudge]): The judge model to predict
            the bounding boxes from the generated image.
        iou_threshold (Optional[float], optional): The IoU threshold for the spatial relationship.
            Directional relationships get full score only when the IoU of the two
            boxes is below this value.
        distance_threshold (Optional[float], optional): The distance threshold for the spatial
            relationship. Proximity relationships get full score when the center
            distance along either axis is below this value.
        name (Optional[str], optional): The name of the metric; used as the key of the
            dictionary returned by `__call__`.
    """

    def __init__(
        self,
        judge: Union[weave.Model, DETRSpatialRelationShipJudge],
        iou_threshold: Optional[float] = 0.1,
        distance_threshold: Optional[float] = 150,
        name: Optional[str] = "spatial_relationship_score",
    ) -> None:
        self.judge = judge
        # Load the judge's underlying models up front so `predict` is ready to use.
        self.judge._initialize_models()
        self.iou_threshold = iou_threshold
        self.distance_threshold = distance_threshold
        self.name = name
        # One entry per evaluated sample (judgement + annotated `wandb.Image`).
        self.scores = []
        self.config = judge.model_dump()

    @weave.op()
    def compose_judgement(
        self,
        prompt: str,
        image: str,
        entity_1: str,
        entity_2: str,
        relationship: str,
        boxes: List[BoundingBox],
    ) -> Dict[str, Any]:
        """Compose the judgement based on the response and the predicted bounding boxes.

        The score is computed as follows (0.0 when either entity is not detected
        or the relationship is not recognised):

        - `"near"`, `"next to"`, `"on side of"`, `"side of"`: 1.0 if the box
          centers are closer than `distance_threshold` along the x or the y axis,
          otherwise `distance_threshold / max(|dx|, |dy|)`.
        - `"on the right of"`, `"on the left of"`, `"on the bottom of"`,
          `"on the top of"`: `entity_1` must lie on the requested side of
          `entity_2` and the offset along that axis must dominate the offset
          along the other axis. The score is then 1.0 if the IoU of the two boxes
          is below `iou_threshold`, otherwise `iou_threshold / iou`.

        Args:
            prompt (str): The prompt using which the image was generated.
            image (str): The base64 encoded image.
            entity_1 (str): First entity.
            entity_2 (str): Second entity.
            relationship (str): Relationship between the entities.
            boxes (List[BoundingBox]): The predicted bounding boxes.

        Returns:
            Dict[str, Any]: The comprehensive spatial relationship judgement, with the
                keys `entity_1_present`, `entity_2_present`, `score` and
                `judge_annotated_image`.
        """
        # `prompt` does not take part in the scoring.
        _ = prompt

        # Determine presence of entities in the judgement
        judgement = {
            "entity_1_present": False,
            "entity_2_present": False,
        }
        entity_1_box: Optional[BoundingBox] = None
        entity_2_box: Optional[BoundingBox] = None
        annotated_image = image
        for box in boxes:
            # If a label is detected several times, the last matching box wins.
            if box.label == entity_1:
                judgement["entity_1_present"] = True
                entity_1_box = box
            elif box.label == entity_2:
                judgement["entity_2_present"] = True
                entity_2_box = box
            # Every detected box is drawn, not only the two entities.
            annotated_image = annotate_with_bounding_box(annotated_image, box)

        judgement["score"] = 0.0
        # assign score based on the spatial relationship inferred from the judgement
        if judgement["entity_1_present"] and judgement["entity_2_present"]:
            center_1 = entity_1_box.box_coordinates_center
            center_2 = entity_2_box.box_coordinates_center
            # Signed offsets from entity_1 to entity_2. They must NOT be wrapped
            # in abs(): the sign encodes on which side entity_1 lies.
            # NOTE(review): assumes image coordinates with the origin at the
            # top-left corner (y grows downwards) -- confirm against the judge.
            center_distance_x = center_2.x - center_1.x
            center_distance_y = center_2.y - center_1.y
            abs_distance_x = abs(center_distance_x)
            abs_distance_y = abs(center_distance_y)
            iou = get_iou(entity_1_box, entity_2_box)

            # relationship -> (entity_1 is on the requested side,
            #                  the offset along the relevant axis dominates)
            directional_checks = {
                "on the right of": (
                    center_distance_x < 0,
                    abs_distance_x > abs_distance_y,
                ),
                "on the left of": (
                    center_distance_x > 0,
                    abs_distance_x > abs_distance_y,
                ),
                "on the bottom of": (
                    center_distance_y < 0,
                    abs_distance_y > abs_distance_x,
                ),
                "on the top of": (
                    center_distance_y > 0,
                    abs_distance_y > abs_distance_x,
                ),
            }

            score = 0.0
            if relationship in ("near", "next to", "on side of", "side of"):
                if (
                    abs_distance_x < self.distance_threshold
                    or abs_distance_y < self.distance_threshold
                ):
                    score = 1.0
                else:
                    # Decay the score as the boxes get further apart.
                    score = self.distance_threshold / max(
                        abs_distance_x, abs_distance_y
                    )
            elif relationship in directional_checks:
                on_requested_side, axis_dominates = directional_checks[relationship]
                if on_requested_side and axis_dominates:
                    if iou < self.iou_threshold:
                        score = 1.0
                    else:
                        # Penalise heavily overlapping boxes.
                        score = self.iou_threshold / iou
            judgement["score"] = score

        # Keep a W&B-loggable copy of the judgement for later reporting.
        self.scores.append(
            {
                **judgement,
                "judge_annotated_image": wandb.Image(
                    base64_decode_image(annotated_image)
                    if isinstance(annotated_image, str)
                    else annotated_image
                ),
            }
        )
        # The returned judgement carries the annotated image base64-encoded.
        return {
            **judgement,
            "judge_annotated_image": (
                base64_encode_image(annotated_image)
                if isinstance(annotated_image, Image.Image)
                else annotated_image
            ),
        }

    @weave.op()
    async def __call__(
        self,
        prompt: str,
        entity_1: str,
        entity_2: str,
        relationship: str,
        model_output: Dict[str, Any],
    ) -> Dict[str, Union[bool, float, int]]:
        """Calculate the spatial relationship score for the given prompt and model output.

        Args:
            prompt (str): The prompt for the model.
            entity_1 (str): The first entity in the spatial relationship.
            entity_2 (str): The second entity in the spatial relationship.
            relationship (str): The spatial relationship between the two entities.
            model_output (Dict[str, Any]): The output from the model; the generated
                image is read from its `"image"` key.

        Returns:
            Dict[str, Union[bool, float, int]]: A single-entry dictionary mapping the
                metric name (`self.name`) to the spatial relationship score.
        """
        image = model_output["image"]
        boxes: List[BoundingBox] = self.judge.predict(image)
        judgement = self.compose_judgement(
            prompt, image, entity_1, entity_2, relationship, boxes
        )
        return {self.name: judgement["score"]}

__call__(prompt, entity_1, entity_2, relationship, model_output) async

Calculate the spatial relationship score for the given prompt and model output.

Parameters:

Name Type Description Default
prompt str

The prompt for the model.

required
entity_1 str

The first entity in the spatial relationship.

required
entity_2 str

The second entity in the spatial relationship.

required
relationship str

The spatial relationship between the two entities.

required
model_output Dict[str, Any]

The output from the model.

required

Returns:

Type Description
Dict[str, Union[bool, float, int]]

Dict[str, Union[bool, float, int]]: The comprehensive spatial relationship judgement.

Source code in hemm/metrics/spatial_relationship/spatial_relationship_2d.py
@weave.op()
async def __call__(
    self,
    prompt: str,
    entity_1: str,
    entity_2: str,
    relationship: str,
    model_output: Dict[str, Any],
) -> Dict[str, Union[bool, float, int]]:
    """Score how well the generated image depicts the requested spatial relationship.

    The judge model predicts bounding boxes for the generated image; the
    judgement composed from those boxes provides the score.

    Args:
        prompt (str): The prompt for the model.
        entity_1 (str): The first entity in the spatial relationship.
        entity_2 (str): The second entity in the spatial relationship.
        relationship (str): The spatial relationship between the two entities.
        model_output (Dict[str, Any]): The output from the model; the generated
            image is read from its `"image"` key.

    Returns:
        Dict[str, Union[bool, float, int]]: A single-entry dictionary mapping the
            metric name (`self.name`) to the spatial relationship score.
    """
    generated_image = model_output["image"]
    detected_boxes: List[BoundingBox] = self.judge.predict(generated_image)
    verdict = self.compose_judgement(
        prompt,
        generated_image,
        entity_1,
        entity_2,
        relationship,
        detected_boxes,
    )
    # Only the score is reported, keyed by the metric name.
    metric_score = verdict["score"]
    return {self.name: metric_score}

compose_judgement(prompt, image, entity_1, entity_2, relationship, boxes)

Compose the judgement based on the response and the predicted bounding boxes.

Parameters:

Name Type Description Default
prompt str

The prompt using which the image was generated.

required
image str

The base64 encoded image.

required
entity_1 str

First entity.

required
entity_2 str

Second entity.

required
relationship str

Relationship between the entities.

required
boxes List[BoundingBox]

The predicted bounding boxes.

required

Returns:

Type Description
Dict[str, Any]

Dict[str, Any]: The comprehensive spatial relationship judgement.

Source code in hemm/metrics/spatial_relationship/spatial_relationship_2d.py
@weave.op()
def compose_judgement(
    self,
    prompt: str,
    image: str,
    entity_1: str,
    entity_2: str,
    relationship: str,
    boxes: List[BoundingBox],
) -> Dict[str, Any]:
    """Compose the judgement based on the response and the predicted bounding boxes.

    The score is computed as follows (0.0 when either entity is not detected
    or the relationship is not recognised):

    - `"near"`, `"next to"`, `"on side of"`, `"side of"`: 1.0 if the box
      centers are closer than `distance_threshold` along the x or the y axis,
      otherwise `distance_threshold / max(|dx|, |dy|)`.
    - `"on the right of"`, `"on the left of"`, `"on the bottom of"`,
      `"on the top of"`: `entity_1` must lie on the requested side of
      `entity_2` and the offset along that axis must dominate the offset
      along the other axis. The score is then 1.0 if the IoU of the two boxes
      is below `iou_threshold`, otherwise `iou_threshold / iou`.

    Args:
        prompt (str): The prompt using which the image was generated.
        image (str): The base64 encoded image.
        entity_1 (str): First entity.
        entity_2 (str): Second entity.
        relationship (str): Relationship between the entities.
        boxes (List[BoundingBox]): The predicted bounding boxes.

    Returns:
        Dict[str, Any]: The comprehensive spatial relationship judgement, with the
            keys `entity_1_present`, `entity_2_present`, `score` and
            `judge_annotated_image`.
    """
    # `prompt` does not take part in the scoring.
    _ = prompt

    # Determine presence of entities in the judgement
    judgement = {
        "entity_1_present": False,
        "entity_2_present": False,
    }
    entity_1_box: Optional[BoundingBox] = None
    entity_2_box: Optional[BoundingBox] = None
    annotated_image = image
    for box in boxes:
        # If a label is detected several times, the last matching box wins.
        if box.label == entity_1:
            judgement["entity_1_present"] = True
            entity_1_box = box
        elif box.label == entity_2:
            judgement["entity_2_present"] = True
            entity_2_box = box
        # Every detected box is drawn, not only the two entities.
        annotated_image = annotate_with_bounding_box(annotated_image, box)

    judgement["score"] = 0.0
    # assign score based on the spatial relationship inferred from the judgement
    if judgement["entity_1_present"] and judgement["entity_2_present"]:
        center_1 = entity_1_box.box_coordinates_center
        center_2 = entity_2_box.box_coordinates_center
        # Signed offsets from entity_1 to entity_2. They must NOT be wrapped
        # in abs(): the sign encodes on which side entity_1 lies.
        # NOTE(review): assumes image coordinates with the origin at the
        # top-left corner (y grows downwards) -- confirm against the judge.
        center_distance_x = center_2.x - center_1.x
        center_distance_y = center_2.y - center_1.y
        abs_distance_x = abs(center_distance_x)
        abs_distance_y = abs(center_distance_y)
        iou = get_iou(entity_1_box, entity_2_box)

        # relationship -> (entity_1 is on the requested side,
        #                  the offset along the relevant axis dominates)
        directional_checks = {
            "on the right of": (
                center_distance_x < 0,
                abs_distance_x > abs_distance_y,
            ),
            "on the left of": (
                center_distance_x > 0,
                abs_distance_x > abs_distance_y,
            ),
            "on the bottom of": (
                center_distance_y < 0,
                abs_distance_y > abs_distance_x,
            ),
            "on the top of": (
                center_distance_y > 0,
                abs_distance_y > abs_distance_x,
            ),
        }

        score = 0.0
        if relationship in ("near", "next to", "on side of", "side of"):
            if (
                abs_distance_x < self.distance_threshold
                or abs_distance_y < self.distance_threshold
            ):
                score = 1.0
            else:
                # Decay the score as the boxes get further apart.
                score = self.distance_threshold / max(
                    abs_distance_x, abs_distance_y
                )
        elif relationship in directional_checks:
            on_requested_side, axis_dominates = directional_checks[relationship]
            if on_requested_side and axis_dominates:
                if iou < self.iou_threshold:
                    score = 1.0
                else:
                    # Penalise heavily overlapping boxes.
                    score = self.iou_threshold / iou
        judgement["score"] = score

    # Keep a W&B-loggable copy of the judgement for later reporting.
    self.scores.append(
        {
            **judgement,
            "judge_annotated_image": wandb.Image(
                base64_decode_image(annotated_image)
                if isinstance(annotated_image, str)
                else annotated_image
            ),
        }
    )
    # The returned judgement carries the annotated image base64-encoded.
    return {
        **judgement,
        "judge_annotated_image": (
            base64_encode_image(annotated_image)
            if isinstance(annotated_image, Image.Image)
            else annotated_image
        ),
    }

DETRSpatialRelationShipJudge

Bases: Model

DETR spatial relationship judge model for 2D images.

Parameters:

Name Type Description Default
model_address str

The address of the model to use.

required
revision str

The revision of the model to use.

required
Source code in hemm/metrics/spatial_relationship/judges/detr.py
class DETRSpatialRelationShipJudge(weave.Model):
    """DETR spatial relationship judge model for 2D images.

    Runs a DETR object detector on an image and returns the detections as
    labelled bounding boxes.

    Args:
        model_address (str, optional): The address of the model to use.
        revision (str, optional): The revision of the model to use.
    """

    model_address: str = "facebook/detr-resnet-50"
    revision: str = "no_timm"
    # Populated by `_initialize_models`; `None` until then.
    _feature_extractor: DetrImageProcessor = None
    _object_detection_model: DetrForObjectDetection = None

    def _initialize_models(self):
        """Load the image processor and the detection model from `model_address` at `revision`."""
        self._feature_extractor = DetrImageProcessor.from_pretrained(
            self.model_address, revision=self.revision
        )
        self._object_detection_model = DetrForObjectDetection.from_pretrained(
            self.model_address, revision=self.revision
        )

    @weave.op()
    def predict(self, image: str) -> List[BoundingBox]:
        """Predict the bounding boxes from the input image.

        Args:
            image (str): The base64 encoded image.

        Returns:
            List[BoundingBox]: The predicted bounding boxes.
        """
        # Load the models on first use so that `predict` also works when
        # `_initialize_models` has not been called explicitly.
        if self._feature_extractor is None or self._object_detection_model is None:
            self._initialize_models()

        pil_image = base64_decode_image(image)
        encoding = self._feature_extractor(pil_image, return_tensors="pt")
        # Pure inference: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            outputs = self._object_detection_model(**encoding)
        # NOTE(review): presumably `pil_image.size` is (width, height), so the
        # reversal yields the (height, width) order -- confirm.
        target_sizes = torch.tensor([pil_image.size[::-1]])
        results = self._feature_extractor.post_process_object_detection(
            outputs, target_sizes=target_sizes, threshold=0.9
        )[0]
        bboxes = []
        for score, label, box in zip(
            results["scores"], results["labels"], results["boxes"]
        ):
            xmin, ymin, xmax, ymax = box.tolist()
            bboxes.append(
                BoundingBox(
                    box_coordinates_min=CartesianCoordinate2D(x=xmin, y=ymin),
                    box_coordinates_max=CartesianCoordinate2D(x=xmax, y=ymax),
                    box_coordinates_center=CartesianCoordinate2D(
                        x=(xmin + xmax) / 2, y=(ymin + ymax) / 2
                    ),
                    # Translate the numeric class id into its text label.
                    label=self._object_detection_model.config.id2label[label.item()],
                    score=score.item(),
                )
            )
        return bboxes

predict(image)

Predict the bounding boxes from the input image.

Parameters:

Name Type Description Default
image str

The base64 encoded image.

required

Returns:

Type Description
List[BoundingBox]

List[BoundingBox]: The predicted bounding boxes.

Source code in hemm/metrics/spatial_relationship/judges/detr.py
@weave.op()
def predict(self, image: str) -> List[BoundingBox]:
    """Predict the bounding boxes from the input image.

    Args:
        image (str): The base64 encoded image.

    Returns:
        List[BoundingBox]: The predicted bounding boxes.
    """
    # Load the models on first use so that `predict` also works when
    # `_initialize_models` has not been called explicitly.
    if self._feature_extractor is None or self._object_detection_model is None:
        self._initialize_models()

    pil_image = base64_decode_image(image)
    encoding = self._feature_extractor(pil_image, return_tensors="pt")
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = self._object_detection_model(**encoding)
    # NOTE(review): presumably `pil_image.size` is (width, height), so the
    # reversal yields the (height, width) order -- confirm.
    target_sizes = torch.tensor([pil_image.size[::-1]])
    results = self._feature_extractor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.9
    )[0]
    bboxes = []
    for score, label, box in zip(
        results["scores"], results["labels"], results["boxes"]
    ):
        xmin, ymin, xmax, ymax = box.tolist()
        bboxes.append(
            BoundingBox(
                box_coordinates_min=CartesianCoordinate2D(x=xmin, y=ymin),
                box_coordinates_max=CartesianCoordinate2D(x=xmax, y=ymax),
                box_coordinates_center=CartesianCoordinate2D(
                    x=(xmin + xmax) / 2, y=(ymin + ymax) / 2
                ),
                # Translate the numeric class id into its text label.
                label=self._object_detection_model.config.id2label[label.item()],
                score=score.item(),
            )
        )
    return bboxes