From 2605047eaf956bff6a3a4f9d409c7031172bcd3f Mon Sep 17 00:00:00 2001 From: davidberenstein1957 Date: Tue, 28 Apr 2026 14:59:31 +0200 Subject: [PATCH 1/2] feat(text-metrics): split oneig_alignment into dedicated branch Adds oneig_alignment metric implementation, its focused tests, and benchmark subset wiring while keeping reasoning and text-rendering metrics for later stacked PRs. Made-with: Cursor --- src/pruna/evaluation/benchmarks.py | 31 ++- .../metrics/metric_oneig_alignment.py | 234 ++++++++++++++++++ tests/evaluation/test_text_metrics.py | 136 ++++++++++ 3 files changed, 394 insertions(+), 7 deletions(-) create mode 100644 src/pruna/evaluation/metrics/metric_oneig_alignment.py create mode 100644 tests/evaluation/test_text_metrics.py diff --git a/src/pruna/evaluation/benchmarks.py b/src/pruna/evaluation/benchmarks.py index de005c9b..e3f58164 100644 --- a/src/pruna/evaluation/benchmarks.py +++ b/src/pruna/evaluation/benchmarks.py @@ -272,13 +272,30 @@ def list(cls, task_type: str | None = None) -> list[str]: reference="https://arxiv.org/abs/2504.17761", ), Benchmark( - name="OneIG", - description=( - "Omni-dimensional benchmark for text-to-image evaluation. Six dataset categories " - "(Anime_Stylization, General_Object, Knowledge_Reasoning, Multilingualism, Portrait, " - "Text_Rendering) plus fine-grained style classes. Includes alignment questions." - ), - metrics=[], # Paper uses dimension-specific metrics; not in Pruna + name="OneIG Anime Stylization", + description="OneIG subset: anime and stylized imagery.", + metrics=["oneig_alignment"], + task_type="text_to_image", + reference="https://arxiv.org/abs/2506.07977", + ), + Benchmark( + name="OneIG General Object", + description="OneIG subset: everyday objects and scenes.", + metrics=["oneig_alignment"], + task_type="text_to_image", + reference="https://arxiv.org/abs/2506.07977", + ), + Benchmark( + name="OneIG Multilingualism", + description="OneIG subset: multilingual prompts (incl. Chinese splits).", + metrics=["oneig_alignment"], + task_type="text_to_image", + reference="https://arxiv.org/abs/2506.07977", + ), + Benchmark( + name="OneIG Portrait", + description="OneIG subset: people and portraits.", + metrics=["oneig_alignment"], task_type="text_to_image", reference="https://arxiv.org/abs/2506.07977", ), diff --git a/src/pruna/evaluation/metrics/metric_oneig_alignment.py b/src/pruna/evaluation/metrics/metric_oneig_alignment.py new file mode 100644 index 00000000..ad443283 --- /dev/null +++ b/src/pruna/evaluation/metrics/metric_oneig_alignment.py @@ -0,0 +1,234 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OneIG alignment scoring with dependency masking (parent ``No`` gates children).""" + +from __future__ import annotations + +from typing import Any, Mapping + +import torch + +from pruna.evaluation.metrics.metric_qa_accuracy import QAAccuracyMetric +from pruna.evaluation.metrics.registry import MetricRegistry +from pruna.evaluation.metrics.utils import metric_data_processor +from pruna.evaluation.metrics.vlm_utils import _process_images + + +def _int_dict_keys(mapping: Mapping[Any, Any]) -> dict[int, Any]: + return {int(k): v for k, v in mapping.items()} + + +def _normalize_dependencies(deps: Any) -> dict[int, list[int]]: + if not isinstance(deps, Mapping): + return {} + out: dict[int, list[int]] = {} + for k, v in deps.items(): + key = int(k) + if isinstance(v, list): + out[key] = [int(p) for p in v] + else: + out[key] = [] + return out + + +def _active_oneig_question_ids(qmap: dict[int, Any]) -> list[int]: + """Question ids with real prompt text (excludes HF ``datasets`` padding and empty slots).""" + active: list[int] = [] + for qi in sorted(qmap): + text = qmap[qi] + if text is None: + continue + s = str(text).strip() + if not s or s == "None": + continue + active.append(qi) + return active + + +def apply_oneig_dependency_mask( + raw_scores: Mapping[int, float], + dependencies: Mapping[int, list[int]], +) -> dict[int, float]: + """ + Apply OneIG ``filter_score`` logic per dependency graph (single grid cell). + + Parents with semantic answer ``No`` (score ``0``) force dependent question + scores to ``0``. Parent id ``0`` is ignored, matching the reference script. + + Parameters + ---------- + raw_scores : Mapping[int, float] + Map question id → VLM score in ``{0, 1}`` (or float) before masking. + dependencies : Mapping[int, list[int]] + Map child question id → list of parent question ids (use ``[0]`` for roots). + + Returns + ------- + dict[int, float] + Copy of scores with dependent questions zeroed when any non-zero parent + scored ``0``. + """ + filtered = {int(k): float(v) for k, v in raw_scores.items()} + deps = _normalize_dependencies(dependencies) + raw = dict(filtered) + for child_id, parent_ids in deps.items(): + if child_id not in filtered: + continue + any_parent_no = False + for parent_id in parent_ids: + if parent_id == 0: + continue + if parent_id not in raw: + continue + if raw[parent_id] == 0.0: + any_parent_no = True + break + if any_parent_no: + filtered[child_id] = 0.0 + return filtered + + +def aggregate_oneig_alignment_per_cell(filtered_scores: Mapping[int, float], question_ids: list[int]) -> float: + """ + Mean filtered score over all questions in the prompt (one grid cell). + + Parameters + ---------- + filtered_scores : Mapping[int, float] + Post-mask scores for each question id. + question_ids : list[int] + Ordered ids (typically sorted ascending) defining the denominator. + + Returns + ------- + float + Average score in ``[0, 1]`` if inputs are binary; ``0.0`` if ``question_ids`` is empty. + """ + if not question_ids: + return 0.0 + s = sum(float(filtered_scores[qid]) for qid in question_ids) + return s / float(len(question_ids)) + + +@MetricRegistry.register("oneig_alignment") +class OneIGAlignmentMetric(QAAccuracyMetric): + """ + OneIG alignment with dependency-aware aggregation. + + Reuses :class:`QAAccuracyMetric` VLM Yes/No scoring but aggregates like + ``OneIG-Benchmark`` ``alignment_score.py`` for a **single** grid cell (no + ``split_mxn_grid``): question ids are sorted numerically, raw scores are + masked when any non-root parent is ``No``, then the mean over all questions + is stored per image. Entries with null or blank question text (HF ``datasets`` + schema padding) are omitted from scoring. + + Numerical parity with upstream also depends on the VLM (e.g. ``openai/gpt-4o`` via + litellm vs reference Qwen2.5-VL). + + Parameters + ---------- + *args : Any + Additional positional arguments for :class:`QAAccuracyMetric`. + vlm : BaseVLM | None, optional + Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored. + vlm_type : {"litellm", "transformers"}, optional + VLM backend. Default is ``"litellm"``. + model_name : str | None, optional + Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not + provided (e.g. ``openai/gpt-4o``). + vlm_kwargs : dict, optional + Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models, + set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options. + structured_output : bool, optional + Use structured generation (litellm pydantic; transformers outlines when applicable). + Default is True. + device : str | torch.device | None, optional + Device for transformers VLM. + api_key : str | None, optional + API key for litellm. + call_type : str, optional + Call type for the metric. + **kwargs : Any + Additional keyword arguments for :class:`QAAccuracyMetric`. + + Examples + -------- + Same ``hosted`` / ``local`` pattern as ``QAAccuracyMetric`` and + :func:`~pruna.evaluation.metrics.vlm_base.get_vlm`: + + .. code-block:: python + + import torch + + from pruna.evaluation.metrics import OneIGAlignmentMetric + + hosted = OneIGAlignmentMetric(vlm_type="litellm", model_name="openai/gpt-4o") + local = OneIGAlignmentMetric( + vlm_type="transformers", + model_name="HuggingFaceTB/SmolVLM-256M-Instruct", + device="cpu", + vlm_kwargs={"model_load_kwargs": {"torch_dtype": torch.float32}}, + ) + """ + + metric_name: str = "oneig_alignment" + metric_units: str = "alignment" + + def update(self, x: list[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None: + """ + Score each question with the VLM, apply dependency masking, append per-cell mean. + + Parameters + ---------- + x : list[Any] | torch.Tensor + Unused batch metadata (kept for metric interface). + gt : torch.Tensor + Ground-truth slot holding per-sample aux dicts with ``questions`` and + optionally ``dependencies``. + outputs : torch.Tensor + Model outputs (images) evaluated against the questions. + """ + inputs = metric_data_processor(x, gt, outputs, self.call_type) + images = _process_images(inputs[0]) + aux_list = inputs[1] if len(inputs) > 1 else [] + if isinstance(aux_list, torch.Tensor): + aux_list = aux_list.tolist() + for i, image in enumerate(images): + aux = aux_list[i] if i < len(aux_list) else {} + if not isinstance(aux, dict): + raise ValueError( + "oneig_alignment requires aux[{}] to be a dict with 'questions'. Got: {!r}.".format(i, type(aux)) + ) + qs = aux.get("questions") + if not isinstance(qs, dict) or not qs: + raise ValueError( + f"oneig_alignment requires 'questions' as a non-empty dict on aux. Got keys: {list(aux.keys())}." + ) + qmap = _int_dict_keys(qs) + qids = _active_oneig_question_ids(qmap) + if not qids: + self.scores.append(0.0) + continue + question_texts = [str(qmap[qi]) for qi in qids] + deps = _normalize_dependencies(aux.get("dependencies", {})) + raw_scores_list = self.vlm.score( + [image] * len(question_texts), + question_texts, + ["Yes"] * len(question_texts), + response_format=self.response_format, + ) + raw_map = {qid: float(raw_scores_list[j]) for j, qid in enumerate(qids)} + filtered = apply_oneig_dependency_mask(raw_map, deps) + self.scores.append(aggregate_oneig_alignment_per_cell(filtered, qids)) diff --git a/tests/evaluation/test_text_metrics.py b/tests/evaluation/test_text_metrics.py new file mode 100644 index 00000000..12705e91 --- /dev/null +++ b/tests/evaluation/test_text_metrics.py @@ -0,0 +1,136 @@ +"""Tests for OneIG alignment masking and wiring.""" + +from unittest.mock import MagicMock + +import pytest +import torch + +from pruna.data.datasets.prompt import _to_oneig_record +from pruna.evaluation.metrics.metric_oneig_alignment import ( + OneIGAlignmentMetric, + _active_oneig_question_ids, + aggregate_oneig_alignment_per_cell, + apply_oneig_dependency_mask, +) +from pruna.evaluation.metrics.vlm_base import BaseVLM + + +def test_apply_oneig_dependency_mask_parent_no_zeros_child() -> None: + """Parent ``No`` forces dependent question score to zero.""" + raw = {1: 0.0, 2: 1.0} + deps = {1: [0], 2: [1]} + out = apply_oneig_dependency_mask(raw, deps) + assert out[1] == 0.0 + assert out[2] == 0.0 + assert aggregate_oneig_alignment_per_cell(out, [1, 2]) == 0.0 + + +def test_apply_oneig_dependency_mask_parent_yes_keeps_child() -> None: + """All ``Yes`` yields nonzero child and mean 1.0 over two questions.""" + raw = {1: 1.0, 2: 1.0} + deps = {1: [0], 2: [1]} + out = apply_oneig_dependency_mask(raw, deps) + assert out == {1: 1.0, 2: 1.0} + assert aggregate_oneig_alignment_per_cell(out, [1, 2]) == 1.0 + + +def test_apply_oneig_dependency_mask_uses_raw_parent_not_filtered_for_chain() -> None: + r"""Grandchild may stay 1 when parent's **raw** VLM score is Yes even if parent was masked to 0.""" + raw = {1: 0.0, 2: 1.0, 3: 1.0} + deps = {1: [0], 2: [1], 3: [2]} + out = apply_oneig_dependency_mask(raw, deps) + assert out[1] == 0.0 + assert out[2] == 0.0 + assert out[3] == 1.0 + + +def test_apply_oneig_dependency_mask_grandchild_chain() -> None: + """3-level chain: grandparent No masks parent; grandchild uses raw parent (stays 1.0).""" + raw_scores = {1: 0.0, 2: 1.0, 3: 1.0} + dependencies = {2: [1], 3: [2]} + filtered = apply_oneig_dependency_mask(raw_scores, dependencies) + assert filtered[2] == 0.0 + assert filtered[3] == 1.0 + assert filtered[1] == 0.0 + + +def test_active_oneig_question_ids_skips_padding() -> None: + """Padded ``None`` and blank slots are excluded; numeric order preserved.""" + qmap = {1: "a", 21: None, 3: " ", 2: "b"} + assert _active_oneig_question_ids(qmap) == [1, 2] + + +def test_active_oneig_question_ids_skips_literal_none_string() -> None: + r"""The literal ``\"None\"`` string is treated as a missing label (legacy / bad rows).""" + assert _active_oneig_question_ids({1: "None", 2: "ok"}) == [2] + + +@pytest.mark.cpu +def test_oneig_alignment_metric_respects_question_id_order() -> None: + """Questions are scored in numeric id order; masking uses aligned raw scores.""" + mock_vlm = MagicMock(spec=BaseVLM) + mock_vlm.score.return_value = [0.0, 1.0] + + metric = OneIGAlignmentMetric(vlm=mock_vlm, vlm_type="litellm", device="cpu") + images = torch.rand(1, 3, 64, 64) + aux = { + "questions": {"2": "second", "1": "first"}, + "dependencies": {"1": [0], "2": [1]}, + } + metric.update(["p"], [aux], images) + result = metric.compute() + assert result.name == "oneig_alignment" + assert result.higher_is_better is True + assert result.metric_units == "alignment" + assert result.result == 0.0 + call = mock_vlm.score.call_args + assert call[0][1] == ["first", "second"] + + +@pytest.mark.cpu +def test_oneig_alignment_skips_none_question_texts() -> None: + """HF ``datasets`` schema padding (``None`` question text) is not sent to the VLM.""" + mock_vlm = MagicMock(spec=BaseVLM) + mock_vlm.score.return_value = [1.0] + + metric = OneIGAlignmentMetric(vlm=mock_vlm, vlm_type="litellm", device="cpu") + images = torch.rand(1, 3, 64, 64) + aux = { + "questions": {"1": "first", "21": None}, + "dependencies": {"1": [0], "21": [0]}, + } + metric.update(["p"], [aux], images) + result = metric.compute() + assert result.name == "oneig_alignment" + assert result.result == 1.0 + mock_vlm.score.assert_called_once() + assert mock_vlm.score.call_args[0][1] == ["first"] + + +@pytest.mark.cpu +def test_oneig_alignment_all_padding_questions_yields_zero_without_vlm() -> None: + """When every slot is padding, score is 0.0 and the VLM is not called.""" + mock_vlm = MagicMock(spec=BaseVLM) + metric = OneIGAlignmentMetric(vlm=mock_vlm, vlm_type="litellm", device="cpu") + aux = {"questions": {"1": None, "2": None}, "dependencies": {}} + metric.update(["p"], [aux], torch.rand(1, 3, 64, 64)) + assert metric.compute().result == 0.0 + mock_vlm.score.assert_not_called() + + +def test_to_oneig_record_strips_null_questions_and_dependencies() -> None: + """Null-valued Q_D entries are filtered out at record construction time.""" + row = {"category": "Anime_Stylization", "id": "001", "class": "None", "prompt_en": "a cat"} + questions_by_key = { + "anime_001": { + "questions": {"1": "Is there a cat?", "21": None}, + "dependencies": {"1": [0], "21": None}, + } + } + record = _to_oneig_record(row, questions_by_key, {}, {}) + assert "21" not in record["questions"] + assert "21" not in record["dependencies"] + assert record["questions"] == {"1": "Is there a cat?"} + assert record["dependencies"] == {"1": [0]} + + From c916183123ed062fd2fb0b1af729b36e563db061 Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Tue, 2 Jun 2026 19:27:08 +0200 Subject: [PATCH 2/2] feat(metrics): paper-faithful OneIG alignment scoring Split 2x2 grids, one VLM question per call, Qwen2.5-VL default, and strict list[dict] aux validation. Co-authored-by: Cursor --- src/pruna/evaluation/metrics/__init__.py | 2 + .../metrics/metric_oneig_alignment.py | 171 +++++++++++------- tests/evaluation/test_text_metrics.py | 13 +- 3 files changed, 120 insertions(+), 66 deletions(-) diff --git a/src/pruna/evaluation/metrics/__init__.py b/src/pruna/evaluation/metrics/__init__.py index 389b9533..49cfe904 100644 --- a/src/pruna/evaluation/metrics/__init__.py +++ b/src/pruna/evaluation/metrics/__init__.py @@ -23,6 +23,7 @@ from pruna.evaluation.metrics.metric_memory import DiskMemoryMetric, InferenceMemoryMetric, TrainingMemoryMetric from pruna.evaluation.metrics.metric_model_architecture import TotalMACsMetric, TotalParamsMetric from pruna.evaluation.metrics.metric_pairwise_clip import PairwiseClipScore +from pruna.evaluation.metrics.metric_oneig_alignment import OneIGAlignmentMetric from pruna.evaluation.metrics.metric_qa_accuracy import QAAccuracyMetric from pruna.evaluation.metrics.metric_rapiddata import RapidataMetric as RapidataMetric from pruna.evaluation.metrics.metric_sharpness import SharpnessMetric @@ -54,6 +55,7 @@ "SharpnessMetric", "AestheticLAION", "LMEvalMetric", + "OneIGAlignmentMetric", "QAAccuracyMetric", "RapidataMetric", "BaseVLM", diff --git a/src/pruna/evaluation/metrics/metric_oneig_alignment.py b/src/pruna/evaluation/metrics/metric_oneig_alignment.py index ad443283..0f372f4f 100644 --- a/src/pruna/evaluation/metrics/metric_oneig_alignment.py +++ b/src/pruna/evaluation/metrics/metric_oneig_alignment.py @@ -16,14 +16,17 @@ from __future__ import annotations -from typing import Any, Mapping +from typing import Any, Literal, Mapping import torch +from PIL import Image from pruna.evaluation.metrics.metric_qa_accuracy import QAAccuracyMetric from pruna.evaluation.metrics.registry import MetricRegistry from pruna.evaluation.metrics.utils import metric_data_processor -from pruna.evaluation.metrics.vlm_utils import _process_images +from pruna.evaluation.metrics.vlm_utils import _process_images, split_mxn_grid + +_DEFAULT_ONEIG_ALIGNMENT_VLM = "Qwen/Qwen2.5-VL-7B-Instruct" def _int_dict_keys(mapping: Mapping[Any, Any]) -> dict[int, Any]: @@ -122,73 +125,146 @@ def aggregate_oneig_alignment_per_cell(filtered_scores: Mapping[int, float], que return s / float(len(question_ids)) +def _aux_list_from_gt(aux_slot: Any, batch_size: int) -> list[dict[str, Any]]: + if isinstance(aux_slot, torch.Tensor): + raise ValueError( + "oneig_alignment expects gt as list[dict] with 'questions' and optional 'dependencies'. " + f"Got tensor with shape {tuple(aux_slot.shape)}." + ) + if not isinstance(aux_slot, (list, tuple)): + return [{} for _ in range(batch_size)] + out: list[dict[str, Any]] = [] + for i in range(batch_size): + row = aux_slot[i] if i < len(aux_slot) else {} + if not isinstance(row, dict): + raise ValueError(f"oneig_alignment requires aux[{i}] to be a dict. Got: {type(row)!r}.") + out.append(row) + return out + + @MetricRegistry.register("oneig_alignment") class OneIGAlignmentMetric(QAAccuracyMetric): """ OneIG alignment with dependency-aware aggregation. - Reuses :class:`QAAccuracyMetric` VLM Yes/No scoring but aggregates like - ``OneIG-Benchmark`` ``alignment_score.py`` for a **single** grid cell (no - ``split_mxn_grid``): question ids are sorted numerically, raw scores are - masked when any non-root parent is ``No``, then the mean over all questions - is stored per image. Entries with null or blank question text (HF ``datasets`` - schema padding) are omitted from scoring. + Matches ``OneIG-Benchmark`` ``alignment_score.py``: split an ``m x n`` output grid + (default ``2 x 2``), score **one question per VLM call** across all cells, apply + dependency masking per cell, then average cell scores. - Numerical parity with upstream also depends on the VLM (e.g. ``openai/gpt-4o`` via - litellm vs reference Qwen2.5-VL). + Scoring semantics + ----------------- + OneIG Q_D probes are phrased so **Yes = aligned**. Each call requests + :meth:`~pruna.evaluation.metrics.vlm_base.BaseVLM.score` with expected answer + ``"Yes"`` (probability of Yes). Low scores act as semantic **No** for dependency + masking. Parameters ---------- - *args : Any - Additional positional arguments for :class:`QAAccuracyMetric`. + grid_size : tuple[int, int], optional + ``(columns, rows)`` for :func:`~pruna.evaluation.metrics.vlm_utils.split_mxn_grid`. + Default ``(2, 2)`` per OneIG. Use ``(1, 1)`` to score the full image without splitting. vlm : BaseVLM | None, optional Custom VLM instance. If provided, ``vlm_type`` and ``model_name`` are ignored. vlm_type : {"litellm", "transformers"}, optional - VLM backend. Default is ``"litellm"``. + VLM backend. Default is ``"transformers"`` (paper-faithful Qwen2.5-VL). model_name : str | None, optional - Litellm model id or HuggingFace checkpoint id. **Required** when ``vlm`` is not - provided (e.g. ``openai/gpt-4o``). + HuggingFace or litellm model id. Default ``Qwen/Qwen2.5-VL-7B-Instruct``. vlm_kwargs : dict, optional - Forwarded by ``get_vlm`` to ``LitellmVLM`` or ``TransformersVLM``. For local models, - set ``model_load_kwargs`` for ``from_pretrained``; for litellm, pass extra API options. + Forwarded by ``get_vlm``. structured_output : bool, optional - Use structured generation (litellm pydantic; transformers outlines when applicable). - Default is True. + Use structured generation when applicable. device : str | torch.device | None, optional Device for transformers VLM. api_key : str | None, optional API key for litellm. call_type : str, optional Call type for the metric. + aggregation : str, optional + Unused; kept for registry compatibility with :class:`QAAccuracyMetric`. **kwargs : Any Additional keyword arguments for :class:`QAAccuracyMetric`. Examples -------- - Same ``hosted`` / ``local`` pattern as ``QAAccuracyMetric`` and - :func:`~pruna.evaluation.metrics.vlm_base.get_vlm`: - .. code-block:: python - import torch - from pruna.evaluation.metrics import OneIGAlignmentMetric - hosted = OneIGAlignmentMetric(vlm_type="litellm", model_name="openai/gpt-4o") - local = OneIGAlignmentMetric( - vlm_type="transformers", - model_name="HuggingFaceTB/SmolVLM-256M-Instruct", - device="cpu", - vlm_kwargs={"model_load_kwargs": {"torch_dtype": torch.float32}}, - ) + paper = OneIGAlignmentMetric(device="cuda") + api = OneIGAlignmentMetric(vlm_type="litellm", model_name="openai/gpt-4o") """ metric_name: str = "oneig_alignment" metric_units: str = "alignment" + def __init__( + self, + *args: Any, + grid_size: tuple[int, int] = (2, 2), + vlm: Any | None = None, + vlm_type: Literal["litellm", "transformers"] = "transformers", + model_name: str | None = _DEFAULT_ONEIG_ALIGNMENT_VLM, + vlm_kwargs: dict | None = None, + structured_output: bool = True, + device: str | torch.device | None = None, + api_key: str | None = None, + call_type: str | None = None, + **kwargs: Any, + ) -> None: + super().__init__( + *args, + vlm=vlm, + vlm_type=vlm_type, + model_name=model_name, + vlm_kwargs=vlm_kwargs, + structured_output=structured_output, + device=device, + api_key=api_key, + call_type=call_type if call_type is not None else "y_gt", + **kwargs, + ) + self.grid_size = (int(grid_size[0]), int(grid_size[1])) + + def _score_sample(self, image: Any, aux: dict[str, Any]) -> float: + if not isinstance(image, Image.Image): + if isinstance(image, torch.Tensor): + from pruna.evaluation.metrics.vlm_utils import _tensor_to_pil + + image = _tensor_to_pil(image) + else: + image = Image.fromarray(image).convert("RGB") + cells = split_mxn_grid(image, self.grid_size) + qs = aux.get("questions") + if not isinstance(qs, dict) or not qs: + raise ValueError( + f"oneig_alignment requires 'questions' as a non-empty dict on aux. Got keys: {list(aux.keys())}." + ) + qmap = _int_dict_keys(qs) + qids = _active_oneig_question_ids(qmap) + if not qids: + return 0.0 + deps = _normalize_dependencies(aux.get("dependencies", {})) + per_question_cell_scores: dict[int, list[float]] = {} + n_cells = len(cells) + for qid in qids: + qtext = str(qmap[qid]) + raw_scores_list = self.vlm.score( + cells, + [qtext] * n_cells, + ["Yes"] * n_cells, + response_format=self.response_format, + ) + per_question_cell_scores[qid] = [float(s) for s in raw_scores_list] + cell_means: list[float] = [] + for cell_i in range(n_cells): + raw_map = {qid: per_question_cell_scores[qid][cell_i] for qid in qids} + filtered = apply_oneig_dependency_mask(raw_map, deps) + cell_means.append(aggregate_oneig_alignment_per_cell(filtered, qids)) + return float(sum(cell_means) / len(cell_means)) + def update(self, x: list[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.Tensor) -> None: """ - Score each question with the VLM, apply dependency masking, append per-cell mean. + Score each prompt image with OneIG alignment (grid split + per-question VLM calls). Parameters ---------- @@ -202,33 +278,6 @@ def update(self, x: list[Any] | torch.Tensor, gt: torch.Tensor, outputs: torch.T """ inputs = metric_data_processor(x, gt, outputs, self.call_type) images = _process_images(inputs[0]) - aux_list = inputs[1] if len(inputs) > 1 else [] - if isinstance(aux_list, torch.Tensor): - aux_list = aux_list.tolist() + aux_list = _aux_list_from_gt(inputs[1] if len(inputs) > 1 else [], len(images)) for i, image in enumerate(images): - aux = aux_list[i] if i < len(aux_list) else {} - if not isinstance(aux, dict): - raise ValueError( - "oneig_alignment requires aux[{}] to be a dict with 'questions'. Got: {!r}.".format(i, type(aux)) - ) - qs = aux.get("questions") - if not isinstance(qs, dict) or not qs: - raise ValueError( - f"oneig_alignment requires 'questions' as a non-empty dict on aux. Got keys: {list(aux.keys())}." - ) - qmap = _int_dict_keys(qs) - qids = _active_oneig_question_ids(qmap) - if not qids: - self.scores.append(0.0) - continue - question_texts = [str(qmap[qi]) for qi in qids] - deps = _normalize_dependencies(aux.get("dependencies", {})) - raw_scores_list = self.vlm.score( - [image] * len(question_texts), - question_texts, - ["Yes"] * len(question_texts), - response_format=self.response_format, - ) - raw_map = {qid: float(raw_scores_list[j]) for j, qid in enumerate(qids)} - filtered = apply_oneig_dependency_mask(raw_map, deps) - self.scores.append(aggregate_oneig_alignment_per_cell(filtered, qids)) + self.scores.append(self._score_sample(image, aux_list[i])) diff --git a/tests/evaluation/test_text_metrics.py b/tests/evaluation/test_text_metrics.py index 12705e91..a5931bae 100644 --- a/tests/evaluation/test_text_metrics.py +++ b/tests/evaluation/test_text_metrics.py @@ -71,7 +71,8 @@ def test_oneig_alignment_metric_respects_question_id_order() -> None: mock_vlm = MagicMock(spec=BaseVLM) mock_vlm.score.return_value = [0.0, 1.0] - metric = OneIGAlignmentMetric(vlm=mock_vlm, vlm_type="litellm", device="cpu") + metric = OneIGAlignmentMetric(vlm=mock_vlm, vlm_type="litellm", device="cpu", grid_size=(1, 1)) + mock_vlm.score.side_effect = [[0.0], [1.0]] images = torch.rand(1, 3, 64, 64) aux = { "questions": {"2": "second", "1": "first"}, @@ -83,8 +84,9 @@ def test_oneig_alignment_metric_respects_question_id_order() -> None: assert result.higher_is_better is True assert result.metric_units == "alignment" assert result.result == 0.0 - call = mock_vlm.score.call_args - assert call[0][1] == ["first", "second"] + assert mock_vlm.score.call_count == 2 + assert mock_vlm.score.call_args_list[0][0][1] == ["first"] + assert mock_vlm.score.call_args_list[1][0][1] == ["second"] @pytest.mark.cpu @@ -93,7 +95,8 @@ def test_oneig_alignment_skips_none_question_texts() -> None: mock_vlm = MagicMock(spec=BaseVLM) mock_vlm.score.return_value = [1.0] - metric = OneIGAlignmentMetric(vlm=mock_vlm, vlm_type="litellm", device="cpu") + metric = OneIGAlignmentMetric(vlm=mock_vlm, vlm_type="litellm", device="cpu", grid_size=(1, 1)) + mock_vlm.score.return_value = [1.0] images = torch.rand(1, 3, 64, 64) aux = { "questions": {"1": "first", "21": None}, @@ -111,7 +114,7 @@ def test_oneig_alignment_skips_none_question_texts() -> None: def test_oneig_alignment_all_padding_questions_yields_zero_without_vlm() -> None: """When every slot is padding, score is 0.0 and the VLM is not called.""" mock_vlm = MagicMock(spec=BaseVLM) - metric = OneIGAlignmentMetric(vlm=mock_vlm, vlm_type="litellm", device="cpu") + metric = OneIGAlignmentMetric(vlm=mock_vlm, vlm_type="litellm", device="cpu", grid_size=(1, 1)) aux = {"questions": {"1": None, "2": None}, "dependencies": {}} metric.update(["p"], [aux], torch.rand(1, 3, 64, 64)) assert metric.compute().result == 0.0