From f4c647bca09ddc054308824629a8eb0c0986054b Mon Sep 17 00:00:00 2001 From: Michael Welborn Date: Tue, 4 Feb 2025 09:36:39 -0600 Subject: [PATCH 01/16] Rename `etl_output_url` to the more accurate `etl_output_uri` --- indico_toolkit/etloutput/__init__.py | 48 ++++++++++++++-------------- indico_toolkit/polling/autoreview.py | 2 +- indico_toolkit/results/document.py | 10 +++--- tests/results/test_predictionlist.py | 2 +- 4 files changed, 31 insertions(+), 31 deletions(-) diff --git a/indico_toolkit/etloutput/__init__.py b/indico_toolkit/etloutput/__init__.py index a9e43e8..64e3858 100644 --- a/indico_toolkit/etloutput/__init__.py +++ b/indico_toolkit/etloutput/__init__.py @@ -26,7 +26,7 @@ def load( - etl_output_url: str, + etl_output_uri: str, *, reader: "Callable[..., Any]", text: bool = True, @@ -34,30 +34,30 @@ def load( tables: bool = False, ) -> EtlOutput: """ - Load `etl_output_url` as an ETL Output dataclass. A `reader` function must be + Load `etl_output_uri` as an ETL Output dataclass. A `reader` function must be supplied to read JSON files from disk, storage API, or Indico client. Use `text`, `tokens`, and `tables` to specify what to load. 
``` - result = results.load(submission.result_file, reader=read_url) + result = results.load(submission.result_file, reader=read_uri) etl_outputs = { - document: etloutput.load(document.etl_output_url, reader=read_url) + document: etloutput.load(document.etl_output_uri, reader=read_uri) for document in result.documents } ``` """ - etl_output = reader(etl_output_url) - tables_url = etl_output_url.replace("etl_output.json", "tables.json") + etl_output = reader(etl_output_uri) + tables_uri = etl_output_uri.replace("etl_output.json", "tables.json") if has(etl_output, str, "pages", 0, "page_info"): - return _load_v1(etl_output, tables_url, reader, text, tokens, tables) + return _load_v1(etl_output, tables_uri, reader, text, tokens, tables) else: - return _load_v3(etl_output, tables_url, reader, text, tokens, tables) + return _load_v3(etl_output, tables_uri, reader, text, tokens, tables) async def load_async( - etl_output_url: str, + etl_output_uri: str, *, reader: "Callable[..., Awaitable[Any]]", text: bool = True, @@ -65,35 +65,35 @@ async def load_async( tables: bool = False, ) -> EtlOutput: """ - Load `etl_output_url` as an ETL Output dataclass. A `reader` coroutine must be + Load `etl_output_uri` as an ETL Output dataclass. A `reader` coroutine must be supplied to read JSON files from disk, storage API, or Indico client. Use `text`, `tokens`, and `tables` to specify what to load. 
``` - result = await results.load_async(submission.result_file, reader=read_url) + result = await results.load_async(submission.result_file, reader=read_uri) etl_outputs = { - document: await etloutput.load_async(document.etl_output_url, reader=read_url) + document: await etloutput.load_async(document.etl_output_uri, reader=read_uri) for document in result.documents } ``` """ - etl_output = await reader(etl_output_url) - tables_url = etl_output_url.replace("etl_output.json", "tables.json") + etl_output = await reader(etl_output_uri) + tables_uri = etl_output_uri.replace("etl_output.json", "tables.json") if has(etl_output, str, "pages", 0, "page_info"): return await _load_v1_async( - etl_output, tables_url, reader, text, tokens, tables + etl_output, tables_uri, reader, text, tokens, tables ) else: return await _load_v3_async( - etl_output, tables_url, reader, text, tokens, tables + etl_output, tables_uri, reader, text, tokens, tables ) def _load_v1( etl_output: "Any", - tables_url: str, + tables_uri: str, reader: "Callable[..., Any]", text: bool, tokens: bool, @@ -111,7 +111,7 @@ def _load_v1( tokens_by_page = () # type: ignore[assignment] if tables: - tables_by_page = reader(tables_url) + tables_by_page = reader(tables_uri) else: tables_by_page = () @@ -120,7 +120,7 @@ def _load_v1( def _load_v3( etl_output: "Any", - tables_url: str, + tables_uri: str, reader: "Callable[..., Any]", text: bool, tokens: bool, @@ -139,7 +139,7 @@ def _load_v3( tokens_by_page = () # type: ignore[assignment] if tables: - tables_by_page = reader(tables_url) + tables_by_page = reader(tables_uri) else: tables_by_page = () @@ -148,7 +148,7 @@ def _load_v3( async def _load_v1_async( etl_output: "Any", - tables_url: str, + tables_uri: str, reader: "Callable[..., Awaitable[Any]]", text: bool, tokens: bool, @@ -166,7 +166,7 @@ async def _load_v1_async( tokens_by_page = () # type: ignore[assignment] if tables: - tables_by_page = await reader(tables_url) + tables_by_page = await 
reader(tables_uri) else: tables_by_page = () @@ -175,7 +175,7 @@ async def _load_v1_async( async def _load_v3_async( etl_output: "Any", - tables_url: str, + tables_uri: str, reader: "Callable[..., Awaitable[Any]]", text: bool, tokens: bool, @@ -194,7 +194,7 @@ async def _load_v3_async( tokens_by_page = () # type: ignore[assignment] if tables: - tables_by_page = await reader(tables_url) + tables_by_page = await reader(tables_uri) else: tables_by_page = () diff --git a/indico_toolkit/polling/autoreview.py b/indico_toolkit/polling/autoreview.py index c93fa9a..4d305d5 100644 --- a/indico_toolkit/polling/autoreview.py +++ b/indico_toolkit/polling/autoreview.py @@ -151,7 +151,7 @@ async def _worker(self, submission_id: "SubmissionId") -> None: logger.info(f"Retrieving etl output for {submission_id=}") etl_outputs = { document: await etloutput.load_async( - document.etl_output_url, + document.etl_output_uri, reader=self._retrieve_storage_object, text=self._load_text, tokens=self._load_tokens, diff --git a/indico_toolkit/results/document.py b/indico_toolkit/results/document.py index 264070a..8522f29 100644 --- a/indico_toolkit/results/document.py +++ b/indico_toolkit/results/document.py @@ -7,7 +7,7 @@ class Document: id: int name: str - etl_output_url: str + etl_output_uri: str # Auto review changes must reproduce all model sections that were present in the # original result file. This may not be possible from the predictions alone--if a @@ -24,13 +24,13 @@ def from_v1_dict(result: object) -> "Document": """ document_results = get(result, dict, "results", "document", "results") model_names = frozenset(document_results.keys()) - etl_output_url = get(result, str, "etl_output") + etl_output_uri = get(result, str, "etl_output") return Document( # v1 result files don't include document IDs or filenames. 
id=None, # type: ignore[arg-type] name=None, # type: ignore[arg-type] - etl_output_url=etl_output_url, + etl_output_uri=etl_output_uri, _model_sections=model_names, ) @@ -41,11 +41,11 @@ def from_v3_dict(document: object) -> "Document": """ model_results = get(document, dict, "model_results", "ORIGINAL") model_ids = frozenset(model_results.keys()) - etl_output_url = get(document, str, "etl_output") + etl_output_uri = get(document, str, "etl_output") return Document( id=get(document, int, "submissionfile_id"), name=get(document, str, "input_filename"), - etl_output_url=etl_output_url, + etl_output_uri=etl_output_uri, _model_sections=model_ids, ) diff --git a/tests/results/test_predictionlist.py b/tests/results/test_predictionlist.py index bbf9b53..a606e0d 100644 --- a/tests/results/test_predictionlist.py +++ b/tests/results/test_predictionlist.py @@ -21,7 +21,7 @@ def document() -> Document: return Document( id=2922, name="1040_filled.tiff", - etl_output_url="indico-file:///storage/submission/2922/etl_output.json", + etl_output_uri="indico-file:///storage/submission/2922/etl_output.json", _model_sections=frozenset({"124", "123", "122", "121"}), ) From 8aa7464090a060fc601d8f65fd93ba7fb1bfa5e5 Mon Sep 17 00:00:00 2001 From: Michael Welborn Date: Tue, 4 Feb 2025 10:07:50 -0600 Subject: [PATCH 02/16] Update comments and formatting --- indico_toolkit/etloutput/cell.py | 2 +- indico_toolkit/etloutput/etloutput.py | 2 +- indico_toolkit/etloutput/table.py | 2 +- indico_toolkit/etloutput/token.py | 2 +- indico_toolkit/results/predictionlist.py | 2 ++ indico_toolkit/results/predictions/documentextraction.py | 2 +- indico_toolkit/results/review.py | 2 +- indico_toolkit/results/utilities.py | 1 - tests/results/test_predictionlist.py | 8 +++++--- 9 files changed, 13 insertions(+), 10 deletions(-) diff --git a/indico_toolkit/etloutput/cell.py b/indico_toolkit/etloutput/cell.py index 4ce97b1..b915665 100644 --- a/indico_toolkit/etloutput/cell.py +++ 
b/indico_toolkit/etloutput/cell.py @@ -42,7 +42,7 @@ def __lt__(self, other: "Cell") -> bool: @staticmethod def from_dict(cell: object, page: int) -> "Cell": """ - Create a `Cell` from a v1 or v3 ETL Ouput cell dictionary. + Create a `Cell` from a v1 or v3 cell dictionary. """ return Cell( type=CellType(get(cell, str, "cell_type")), diff --git a/indico_toolkit/etloutput/etloutput.py b/indico_toolkit/etloutput/etloutput.py index 33c3b46..537f273 100644 --- a/indico_toolkit/etloutput/etloutput.py +++ b/indico_toolkit/etloutput/etloutput.py @@ -33,7 +33,7 @@ def from_pages( table_dicts_by_page: "Iterable[Iterable[object]]", ) -> "EtlOutput": """ - Create an `EtlOutput` from v1 or v3 ETL Ouput pages. + Create an `EtlOutput` from v1 or v3 page lists. """ text_by_page = tuple(text_by_page) tokens_by_page = tuple( diff --git a/indico_toolkit/etloutput/table.py b/indico_toolkit/etloutput/table.py index 8def075..c8d34d2 100644 --- a/indico_toolkit/etloutput/table.py +++ b/indico_toolkit/etloutput/table.py @@ -35,7 +35,7 @@ def __lt__(self, other: "Table") -> bool: @staticmethod def from_dict(table: object) -> "Table": """ - Create a `Table` from a v1 or v3 ETL Ouput table dictionary. + Create a `Table` from a v1 or v3 table dictionary. """ page = get(table, int, "page_num") cells = tuple( diff --git a/indico_toolkit/etloutput/token.py b/indico_toolkit/etloutput/token.py index 26a1d07..ec045df 100644 --- a/indico_toolkit/etloutput/token.py +++ b/indico_toolkit/etloutput/token.py @@ -37,7 +37,7 @@ def __lt__(self, other: "Token") -> bool: @staticmethod def from_dict(token: object) -> "Token": """ - Create a `Token` from a v1 or v3 ETL Ouput token dictionary. + Create a `Token` from a v1 or v3 token dictionary. 
""" return Token( text=get(token, str, "text"), diff --git a/indico_toolkit/results/predictionlist.py b/indico_toolkit/results/predictionlist.py index 14882a2..8612a7e 100644 --- a/indico_toolkit/results/predictionlist.py +++ b/indico_toolkit/results/predictionlist.py @@ -57,8 +57,10 @@ def unbundlings(self) -> "PredictionList[Unbundling]": @overload def __getitem__(self, index: "SupportsIndex", /) -> PredictionType: ... + @overload def __getitem__(self, index: slice, /) -> "PredictionList[PredictionType]": ... + def __getitem__( self, index: "SupportsIndex | slice" ) -> "PredictionType | PredictionList[PredictionType]": diff --git a/indico_toolkit/results/predictions/documentextraction.py b/indico_toolkit/results/predictions/documentextraction.py index 9b2a7c2..15e34e9 100644 --- a/indico_toolkit/results/predictions/documentextraction.py +++ b/indico_toolkit/results/predictions/documentextraction.py @@ -27,7 +27,7 @@ def from_v1_dict( prediction: object, ) -> "DocumentExtraction": """ - Create n `DocumentExtraction` from a v1 prediction dictionary. + Create a `DocumentExtraction` from a v1 prediction dictionary. """ return DocumentExtraction( document=document, diff --git a/indico_toolkit/results/review.py b/indico_toolkit/results/review.py index 3c46461..a1930aa 100644 --- a/indico_toolkit/results/review.py +++ b/indico_toolkit/results/review.py @@ -21,7 +21,7 @@ class Review: @staticmethod def from_dict(review: object) -> "Review": """ - Create a `Review` from a result file review dictionary. + Create a `Review` from a review dictionary. 
""" return Review( id=get(review, int, "review_id"), diff --git a/indico_toolkit/results/utilities.py b/indico_toolkit/results/utilities.py index 1a204d6..3b016bb 100644 --- a/indico_toolkit/results/utilities.py +++ b/indico_toolkit/results/utilities.py @@ -66,7 +66,6 @@ def omit(dictionary: object, *keys: str) -> "dict[str, Value]": """ if not isinstance(dictionary, dict): return {} - return { key: value for key, value in dictionary.items() diff --git a/tests/results/test_predictionlist.py b/tests/results/test_predictionlist.py index a606e0d..4d062bb 100644 --- a/tests/results/test_predictionlist.py +++ b/tests/results/test_predictionlist.py @@ -209,16 +209,17 @@ def test_where_review( assert predictions.where(review=auto_review) == [first_name] assert predictions.where(review=ReviewType.MANUAL) == [last_name] + def test_where_review_in( predictions: "PredictionList[Prediction]", auto_review: Review ) -> None: classification, first_name, last_name = predictions assert predictions.where(review_in={None}) == [classification] assert predictions.where( - review_in={None, auto_review} + review_in={None, auto_review}, ) == [classification, first_name] assert predictions.where( - review_in={auto_review, ReviewType.MANUAL} + review_in={auto_review, ReviewType.MANUAL}, ) == [first_name, last_name] assert predictions.where(review_in={}) == [] @@ -231,7 +232,7 @@ def test_where_label(predictions: "PredictionList[Prediction]") -> None: def test_where_label_in(predictions: "PredictionList[Prediction]") -> None: first_name, last_name = predictions.extractions assert predictions.where( - label_in=("First Name", "Last Name") + label_in=("First Name", "Last Name"), ) == [first_name, last_name] @@ -264,6 +265,7 @@ def test_where_accepted(predictions: "PredictionList[Prediction]") -> None: assert predictions.where(accepted=False) == [] assert predictions.where(accepted=True) == [first_name, last_name] + def test_where_rejected(predictions: "PredictionList[Prediction]") -> None: 
first_name, last_name = predictions.extractions predictions.unreject() From 3b3ad8b670118db041c270e4b7b344a03897a15a Mon Sep 17 00:00:00 2001 From: Michael Welborn Date: Tue, 4 Feb 2025 10:10:35 -0600 Subject: [PATCH 03/16] Reuse `results.utilities` in `etloutput` --- indico_toolkit/etloutput/__init__.py | 2 +- indico_toolkit/etloutput/cell.py | 2 +- indico_toolkit/etloutput/table.py | 2 +- indico_toolkit/etloutput/token.py | 2 +- indico_toolkit/etloutput/utilities.py | 44 --------------------------- 5 files changed, 4 insertions(+), 48 deletions(-) delete mode 100644 indico_toolkit/etloutput/utilities.py diff --git a/indico_toolkit/etloutput/__init__.py b/indico_toolkit/etloutput/__init__.py index 64e3858..5e84a4c 100644 --- a/indico_toolkit/etloutput/__init__.py +++ b/indico_toolkit/etloutput/__init__.py @@ -1,11 +1,11 @@ from typing import TYPE_CHECKING +from ..results.utilities import get, has from .cell import Cell, CellType from .errors import EtlOutputError, TableCellNotFoundError, TokenNotFoundError from .etloutput import EtlOutput from .table import Table from .token import Token -from .utilities import get, has if TYPE_CHECKING: from collections.abc import Awaitable, Callable diff --git a/indico_toolkit/etloutput/cell.py b/indico_toolkit/etloutput/cell.py index b915665..826641d 100644 --- a/indico_toolkit/etloutput/cell.py +++ b/indico_toolkit/etloutput/cell.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from enum import Enum -from .utilities import get, has +from ..results.utilities import get, has class CellType(Enum): diff --git a/indico_toolkit/etloutput/table.py b/indico_toolkit/etloutput/table.py index c8d34d2..d1aacbb 100644 --- a/indico_toolkit/etloutput/table.py +++ b/indico_toolkit/etloutput/table.py @@ -1,7 +1,7 @@ from dataclasses import dataclass +from ..results.utilities import get from .cell import Cell -from .utilities import get @dataclass(frozen=True) diff --git a/indico_toolkit/etloutput/token.py 
b/indico_toolkit/etloutput/token.py index ec045df..b057464 100644 --- a/indico_toolkit/etloutput/token.py +++ b/indico_toolkit/etloutput/token.py @@ -1,6 +1,6 @@ from dataclasses import dataclass -from .utilities import get +from ..results.utilities import get @dataclass(frozen=True) diff --git a/indico_toolkit/etloutput/utilities.py b/indico_toolkit/etloutput/utilities.py deleted file mode 100644 index 7f85117..0000000 --- a/indico_toolkit/etloutput/utilities.py +++ /dev/null @@ -1,44 +0,0 @@ -from typing import TypeVar - -from .errors import EtlOutputError - -Value = TypeVar("Value") - - -def get(nested: object, value_type: "type[Value]", *keys: "str | int") -> Value: - """ - Return the value obtained by traversing `nested` using `keys` as indices if that - value is of type `value_type`. Raise a `EtlOutputError` otherwise. - """ - for key in keys: - if isinstance(key, str) and isinstance(nested, dict) and key in nested: - nested = nested[key] - elif isinstance(key, int) and isinstance(nested, list) and key < len(nested): - nested = nested[key] - else: - raise EtlOutputError( - f"etl output `{type(nested)!r}` does not contain key `{key!r}`" - ) - - if isinstance(nested, value_type): - return nested - else: - raise EtlOutputError( - f"etl output `{type(nested)!r}` does not have a value for " - f"key `{key!r}` of type `{value_type}`" - ) - - -def has(nested: object, value_type: "type[Value]", *keys: "str | int") -> bool: - """ - Check if `nested` can be traversed using `keys` to a value of type `value_type`. 
- """ - for key in keys: - if isinstance(key, str) and isinstance(nested, dict) and key in nested: - nested = nested[key] - elif isinstance(key, int) and isinstance(nested, list) and key < len(nested): - nested = nested[key] - else: - return False - - return isinstance(nested, value_type) From 5383248fdd556346133edfea49b859dc09b9a87b Mon Sep 17 00:00:00 2001 From: Michael Welborn Date: Tue, 4 Feb 2025 10:25:37 -0600 Subject: [PATCH 04/16] Factor out spans and bounding boxes to make them reusable using composition --- indico_toolkit/results/__init__.py | 8 +++ indico_toolkit/results/normalization.py | 21 ++----- .../results/predictions/__init__.py | 6 ++ indico_toolkit/results/predictions/box.py | 58 +++++++++++++++++++ .../results/predictions/documentextraction.py | 46 ++++++++++----- .../results/predictions/extraction.py | 14 ++++- .../results/predictions/formextraction.py | 25 +++----- indico_toolkit/results/predictions/span.py | 40 +++++++++++++ tests/results/test_predictionlist.py | 13 ++--- tests/results/test_predictions.py | 3 +- 10 files changed, 174 insertions(+), 60 deletions(-) create mode 100644 indico_toolkit/results/predictions/box.py create mode 100644 indico_toolkit/results/predictions/span.py diff --git a/indico_toolkit/results/__init__.py b/indico_toolkit/results/__init__.py index bd5650e..5aa20de 100644 --- a/indico_toolkit/results/__init__.py +++ b/indico_toolkit/results/__init__.py @@ -6,6 +6,9 @@ from .model import ModelGroup, TaskType from .predictionlist import PredictionList from .predictions import ( + NULL_BOX, + NULL_SPAN, + Box, Classification, DocumentExtraction, Extraction, @@ -13,6 +16,7 @@ FormExtractionType, Group, Prediction, + Span, Unbundling, ) from .result import Result @@ -24,6 +28,7 @@ __all__ = ( + "Box", "Classification", "Document", "DocumentExtraction", @@ -34,12 +39,15 @@ "load", "load_async", "ModelGroup", + "NULL_BOX", + "NULL_SPAN", "Prediction", "PredictionList", "Result", "ResultError", "Review", "ReviewType", + 
"Span", "TaskType", "Unbundling", ) diff --git a/indico_toolkit/results/normalization.py b/indico_toolkit/results/normalization.py index 6423add..2cd8c9a 100644 --- a/indico_toolkit/results/normalization.py +++ b/indico_toolkit/results/normalization.py @@ -43,17 +43,10 @@ def normalize_v1_result(result: "Any") -> None: if "confidence" not in prediction: prediction["confidence"] = {prediction["label"]: 0} - # Document Extractions added in review may lack spans. - if ( - "text" in prediction - and "type" not in prediction - and "start" not in prediction - ): - prediction["start"] = 0 - prediction["end"] = 0 - # Form Extractions added in review may lack bounding boxes. + # Set values that will equal `NULL_BOX`. if "type" in prediction and "top" not in prediction: + prediction["page_num"] = 0 prediction["top"] = 0 prediction["left"] = 0 prediction["right"] = 0 @@ -110,16 +103,12 @@ def normalize_v3_result(result: "Any") -> None: and "type" not in prediction and "spans" not in prediction ): - prediction["spans"] = [ - { - "page_num": prediction["page_num"], - "start": 0, - "end": 0, - } - ] + prediction["spans"] = [] # Form Extractions added in review may lack bounding boxes. + # Set values that will equal `NULL_BOX`. 
if "type" in prediction and "top" not in prediction: + prediction["page_num"] = 0 prediction["top"] = 0 prediction["left"] = 0 prediction["right"] = 0 diff --git a/indico_toolkit/results/predictions/__init__.py b/indico_toolkit/results/predictions/__init__.py index af325d0..b023488 100644 --- a/indico_toolkit/results/predictions/__init__.py +++ b/indico_toolkit/results/predictions/__init__.py @@ -1,12 +1,14 @@ from typing import TYPE_CHECKING from ..model import TaskType +from .box import NULL_BOX, Box from .classification import Classification from .documentextraction import DocumentExtraction from .extraction import Extraction from .formextraction import FormExtraction, FormExtractionType from .group import Group from .prediction import Prediction +from .span import NULL_SPAN, Span from .unbundling import Unbundling if TYPE_CHECKING: @@ -16,13 +18,17 @@ from ..review import Review __all__ = ( + "Box", "Classification", "DocumentExtraction", "Extraction", "FormExtraction", "FormExtractionType", "Group", + "NULL_BOX", + "NULL_SPAN", "Prediction", + "Span", "Unbundling", ) diff --git a/indico_toolkit/results/predictions/box.py b/indico_toolkit/results/predictions/box.py new file mode 100644 index 0000000..1175b60 --- /dev/null +++ b/indico_toolkit/results/predictions/box.py @@ -0,0 +1,58 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from ..utilities import get + +if TYPE_CHECKING: + from typing import Final + + +@dataclass(frozen=True) +class Box: + page: int + top: int + left: int + right: int + bottom: int + + def __lt__(self, other: "Box") -> bool: + """ + Bounding boxes are sorted with vertical hysteresis. Those on the same line are + sorted left-to-right, even when later tokens are higher than earlier ones, + as long as they overlap vertically. 
+ + ┌──────────────────┐ ┌───────────────────┐ + │ 1 │ │ 2 │ + └──────────────────┘ │ │ + └───────────────────┘ + ┌────────────────┐ + ┌─────────────┐ │ 4 │ ┌─────┐ + │ 3 │ └────────────────┘ │ 5 │ + └─────────────┘ └─────┘ + """ + return ( + self.page < other.page + or (self.page == other.page and self.bottom < other.top) + or ( + self.page == other.page + and self.top < other.bottom + and self.left < other.left + ) + ) + + @staticmethod + def from_dict(box: object) -> "Box": + return Box( + page=get(box, int, "page_num"), + top=get(box, int, "top"), + left=get(box, int, "left"), + right=get(box, int, "right"), + bottom=get(box, int, "bottom"), + ) + + +# It's more ergonomic to represent the lack of a bounding box with a special null box +# object rather than using `None` or raising an error. This lets you e.g. sort by the +# `box` attribute without having to constantly check for `None`, while still allowing +# you to do a "None check" with `extraction.box == NULL_BOX`. +NULL_BOX: "Final" = Box(page=0, top=0, left=0, right=0, bottom=0) diff --git a/indico_toolkit/results/predictions/documentextraction.py b/indico_toolkit/results/predictions/documentextraction.py index 15e34e9..7c87f9b 100644 --- a/indico_toolkit/results/predictions/documentextraction.py +++ b/indico_toolkit/results/predictions/documentextraction.py @@ -5,6 +5,7 @@ from ..utilities import get, has, omit from .extraction import Extraction from .group import Group +from .span import NULL_SPAN, Span if TYPE_CHECKING: from typing import Any @@ -15,9 +16,28 @@ @dataclass class DocumentExtraction(Extraction): - start: int - end: int groups: "set[Group]" + spans: "list[Span]" + + @property + def span(self) -> Span: + """ + Return the first `Span` the document extraction covers else `NULL_SPAN`. + + Post-review, document extractions have no spans. 
+ """ + return self.spans[0] if self.spans else NULL_SPAN + + @span.setter + def span(self, span: Span) -> None: + """ + Overwrite all `spans` with the one provided. + + This is implemented under the assumption that if you're setting the single span, + you want it to be the only one. And if you're working in a context that's + multiple-span sensitive, you'll set `extraction.spans` instead. + """ + self.spans = [span] @staticmethod def from_v1_dict( @@ -35,17 +55,15 @@ def from_v1_dict( review=review, label=get(prediction, str, "label"), confidences=get(prediction, dict, "confidence"), + text=get(prediction, str, "normalized", "formatted"), accepted=( has(prediction, bool, "accepted") and get(prediction, bool, "accepted") ), rejected=( has(prediction, bool, "rejected") and get(prediction, bool, "rejected") ), - text=get(prediction, str, "normalized", "formatted"), - page=get(prediction, int, "page_num"), - start=get(prediction, int, "start"), - end=get(prediction, int, "end"), groups=set(map(Group.from_dict, get(prediction, list, "groupings"))), + spans=[Span.from_dict(prediction)] if has(prediction, int, "start") else [], extras=omit( prediction, "label", @@ -82,15 +100,14 @@ def from_v3_dict( has(prediction, bool, "rejected") and get(prediction, bool, "rejected") ), text=get(prediction, str, "normalized", "formatted"), - page=get(prediction, int, "spans", 0, "page_num"), - start=get(prediction, int, "spans", 0, "start"), - end=get(prediction, int, "spans", 0, "end"), groups=set(map(Group.from_dict, get(prediction, list, "groupings"))), + spans=sorted(map(Span.from_dict, get(prediction, list, "spans"))), extras=omit( prediction, "label", "confidence", "groupings", + "spans", "accepted", "rejected", ), @@ -104,9 +121,9 @@ def to_v1_dict(self) -> "dict[str, Any]": **self.extras, "label": self.label, "confidence": self.confidences, - "page_num": self.page, - "start": self.start, - "end": self.end, + "page_num": self.span.page, + "start": self.span.start, + "end": 
self.span.end, "groupings": [group.to_dict() for group in self.groups], } @@ -133,10 +150,7 @@ def to_v3_dict(self) -> "dict[str, Any]": prediction["normalized"]["formatted"] = self.text prediction["text"] = self.text # 6.10 sometimes reverts to raw text in review. - - prediction["spans"][0]["page_num"] = self.page - prediction["spans"][0]["start"] = self.start - prediction["spans"][0]["end"] = self.end + prediction["spans"] = [span.to_dict() for span in self.spans] if self.accepted: prediction["accepted"] = True diff --git a/indico_toolkit/results/predictions/extraction.py b/indico_toolkit/results/predictions/extraction.py index fb0b060..3e5c1d8 100644 --- a/indico_toolkit/results/predictions/extraction.py +++ b/indico_toolkit/results/predictions/extraction.py @@ -5,10 +5,20 @@ @dataclass class Extraction(Prediction): + text: str accepted: bool rejected: bool - text: str - page: int + + @property + def page(self) -> int: + """ + Convenience property to get an extraction's page without knowing its subclass. + Allows you to do `predictions.extractions.groupby(attrgetter("page"))` et al. 
+ """ + if hasattr(self, "box"): + return self.box.page # type: ignore + else: + return self.span.page # type: ignore def accept(self) -> None: self.accepted = True diff --git a/indico_toolkit/results/predictions/formextraction.py b/indico_toolkit/results/predictions/formextraction.py index 1925641..8185c3f 100644 --- a/indico_toolkit/results/predictions/formextraction.py +++ b/indico_toolkit/results/predictions/formextraction.py @@ -4,6 +4,7 @@ from ..review import Review from ..utilities import get, has, omit +from .box import Box from .extraction import Extraction if TYPE_CHECKING: @@ -22,14 +23,10 @@ class FormExtractionType(Enum): @dataclass class FormExtraction(Extraction): type: FormExtractionType + box: Box checked: bool signed: bool - top: int - left: int - right: int - bottom: int - @staticmethod def _from_dict( document: "Document", @@ -46,6 +43,7 @@ def _from_dict( review=review, label=get(prediction, str, "label"), confidences=get(prediction, dict, "confidence"), + text=get(prediction, str, "normalized", "formatted"), accepted=( has(prediction, bool, "accepted") and get(prediction, bool, "accepted") ), @@ -53,6 +51,7 @@ def _from_dict( has(prediction, bool, "rejected") and get(prediction, bool, "rejected") ), type=FormExtractionType(get(prediction, str, "type")), + box=Box.from_dict(prediction), checked=( has(prediction, bool, "normalized", "structured", "checked") and get(prediction, bool, "normalized", "structured", "checked") @@ -61,12 +60,6 @@ def _from_dict( has(prediction, bool, "normalized", "structured", "signed") and get(prediction, bool, "normalized", "structured", "signed") ), - text=get(prediction, str, "normalized", "formatted"), - page=get(prediction, int, "page_num"), - top=get(prediction, int, "top"), - left=get(prediction, int, "left"), - right=get(prediction, int, "right"), - bottom=get(prediction, int, "bottom"), extras=omit( prediction, "label", @@ -96,11 +89,11 @@ def _to_dict(self) -> "dict[str, Any]": "label": self.label, 
"confidence": self.confidences, "type": self.type.value, - "page_num": self.page, - "top": self.top, - "left": self.left, - "right": self.right, - "bottom": self.bottom, + "page_num": self.box.page, + "top": self.box.top, + "left": self.box.left, + "right": self.box.right, + "bottom": self.box.bottom, } if self.type == FormExtractionType.CHECKBOX: diff --git a/indico_toolkit/results/predictions/span.py b/indico_toolkit/results/predictions/span.py new file mode 100644 index 0000000..7e9c0b3 --- /dev/null +++ b/indico_toolkit/results/predictions/span.py @@ -0,0 +1,40 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from ..utilities import get + +if TYPE_CHECKING: + from typing import Any, Final + + +@dataclass(order=True, frozen=True) +class Span: + page: int + start: int + end: int + + @property + def slice(self) -> slice: + return slice(self.start, self.end) + + @staticmethod + def from_dict(span: object) -> "Span": + return Span( + page=get(span, int, "page_num"), + start=get(span, int, "start"), + end=get(span, int, "end"), + ) + + def to_dict(self) -> "dict[str, Any]": + return { + "page_num": self.page, + "start": self.start, + "end": self.end, + } + + +# It's more ergonomic to represent the lack of spans with a special null span object +# rather than using `None` or raising an error. This lets you e.g. sort by the `span` +# attribute without having to constantly check for `None`, while still allowing you to +# do a "None check" with `extraction.span == NULL_SPAN`. 
+NULL_SPAN: "Final" = Span(page=0, start=0, end=0) diff --git a/tests/results/test_predictionlist.py b/tests/results/test_predictionlist.py index 4d062bb..2a47542 100644 --- a/tests/results/test_predictionlist.py +++ b/tests/results/test_predictionlist.py @@ -12,6 +12,7 @@ PredictionList, Review, ReviewType, + Span, TaskType, ) @@ -91,13 +92,11 @@ def predictions( label="First Name", confidences={"First Name": 0.8}, extras={}, + text="John", accepted=False, rejected=False, - text="John", - start=352, - end=356, - page=0, groups={group_alpha}, + spans=[Span(page=0, start=352, end=356)], ), DocumentExtraction( document=document, @@ -106,13 +105,11 @@ def predictions( label="Last Name", confidences={"Last Name": 0.9}, extras={}, + text="Doe", accepted=False, rejected=False, - text="Doe", - start=357, - end=360, - page=1, groups={group_alpha, group_bravo}, + spans=[Span(page=1, start=357, end=360)], ), ] ) diff --git a/tests/results/test_predictions.py b/tests/results/test_predictions.py index 5ec407a..114329d 100644 --- a/tests/results/test_predictions.py +++ b/tests/results/test_predictions.py @@ -23,9 +23,8 @@ def test_extractions() -> None: review=None, label="Label", confidences={"Label": 0.5}, - text="Value", - page=0, extras=None, # type: ignore[arg-type] + text="Value", accepted=False, rejected=False, ) From 979e85cfb0c78f3219c3717f714854c3ddf8159f Mon Sep 17 00:00:00 2001 From: Michael Welborn Date: Tue, 4 Feb 2025 10:26:19 -0600 Subject: [PATCH 05/16] Make the unspecified review sentinel value a final constant --- indico_toolkit/results/predictionlist.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/indico_toolkit/results/predictionlist.py b/indico_toolkit/results/predictionlist.py index 8612a7e..f48504d 100644 --- a/indico_toolkit/results/predictionlist.py +++ b/indico_toolkit/results/predictionlist.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: from collections.abc import Callable, Collection, Container, Iterable - from typing 
import Any, SupportsIndex + from typing import Any, Final, SupportsIndex from typing_extensions import Self @@ -29,7 +29,7 @@ KeyType = TypeVar("KeyType") # Non-None sentinel value to support `PredictionList.where(review=None)`. -ReviewUnspecified = Review( +REVIEW_UNSPECIFIED: "Final" = Review( id=None, reviewer_id=None, notes=None, rejected=None, type=None # type: ignore[arg-type] ) @@ -146,8 +146,8 @@ def where( document_in: "Container[Document] | None" = None, model: "ModelGroup | TaskType | str | None" = None, model_in: "Container[ModelGroup | TaskType | str] | None" = None, - review: "Review | ReviewType | None" = ReviewUnspecified, - review_in: "Container[Review | ReviewType | None]" = {ReviewUnspecified}, + review: "Review | ReviewType | None" = REVIEW_UNSPECIFIED, + review_in: "Container[Review | ReviewType | None]" = {REVIEW_UNSPECIFIED}, label: "str | None" = None, label_in: "Container[str] | None" = None, page: "int | None" = None, @@ -210,7 +210,7 @@ def where( ) ) - if review is not ReviewUnspecified: + if review is not REVIEW_UNSPECIFIED: predicates.append( lambda prediction: ( prediction.review == review @@ -221,7 +221,7 @@ def where( ) ) - if review_in != {ReviewUnspecified}: + if review_in != {REVIEW_UNSPECIFIED}: predicates.append( lambda prediction: ( prediction.review in review_in From 459ef13950082af0993d52a9d31cbc98ed1811fa Mon Sep 17 00:00:00 2001 From: Michael Welborn Date: Tue, 4 Feb 2025 10:40:24 -0600 Subject: [PATCH 06/16] Reuse span and bounding box classes in etl output using composition --- indico_toolkit/etloutput/__init__.py | 7 +++ indico_toolkit/etloutput/cell.py | 65 +++++++++------------------ indico_toolkit/etloutput/etloutput.py | 56 ++++++++++++----------- indico_toolkit/etloutput/range.py | 30 +++++++++++++ indico_toolkit/etloutput/table.py | 42 +++++------------ indico_toolkit/etloutput/token.py | 42 ++++------------- 6 files changed, 107 insertions(+), 135 deletions(-) create mode 100644 indico_toolkit/etloutput/range.py 
diff --git a/indico_toolkit/etloutput/__init__.py b/indico_toolkit/etloutput/__init__.py index 5e84a4c..51e0012 100644 --- a/indico_toolkit/etloutput/__init__.py +++ b/indico_toolkit/etloutput/__init__.py @@ -1,9 +1,11 @@ from typing import TYPE_CHECKING +from ..results import NULL_BOX, NULL_SPAN, Box, Span from ..results.utilities import get, has from .cell import Cell, CellType from .errors import EtlOutputError, TableCellNotFoundError, TokenNotFoundError from .etloutput import EtlOutput +from .range import Range from .table import Table from .token import Token @@ -12,12 +14,17 @@ from typing import Any __all__ = ( + "Box", "Cell", "CellType", "EtlOutput", "EtlOutputError", "load", "load_async", + "NULL_BOX", + "NULL_SPAN", + "Range", + "Span", "Table", "TableCellNotFoundError", "Token", diff --git a/indico_toolkit/etloutput/cell.py b/indico_toolkit/etloutput/cell.py index 826641d..6624ea1 100644 --- a/indico_toolkit/etloutput/cell.py +++ b/indico_toolkit/etloutput/cell.py @@ -1,7 +1,9 @@ from dataclasses import dataclass from enum import Enum -from ..results.utilities import get, has +from ..results import NULL_SPAN, Box, Span +from ..results.utilities import get +from .range import Range class CellType(Enum): @@ -13,60 +15,33 @@ class CellType(Enum): class Cell: type: CellType text: str - # Span - start: int - end: int - # Bounding box - page: int - top: int - left: int - right: int - bottom: int - # Table coordinates - row: int - rowspan: int - rows: "tuple[int, ...]" - column: int - columnspan: int - columns: "tuple[int, ...]" + box: Box + range: Range + spans: "tuple[Span, ...]" - def __lt__(self, other: "Cell") -> bool: + @property + def span(self) -> Span: """ - By default, cells are sorted in table order (by row, then column). - Cells can also be sorted in span order: `tokens.sort(key=attrgetter("start"))`. + Return the first `Span` the cell covers or `NULL_SPAN` otherwise. + + Empty cells have no spans. 
""" - return self.row < other.row or ( - self.row == other.row and self.column < other.column - ) + return self.spans[0] if self.spans else NULL_SPAN @staticmethod def from_dict(cell: object, page: int) -> "Cell": """ Create a `Cell` from a v1 or v3 cell dictionary. """ + get(cell, dict, "position")["page_num"] = page + + for doc_offset in get(cell, list, "doc_offsets"): + doc_offset["page_num"] = page + return Cell( type=CellType(get(cell, str, "cell_type")), text=get(cell, str, "text"), - # Empty cells have no start and end; so use [0:0] for a valid slice. - start=( - get(cell, int, "doc_offsets", 0, "start") - if has(cell, int, "doc_offsets", 0, "start") - else 0 - ), - end=( - get(cell, int, "doc_offsets", 0, "end") - if has(cell, int, "doc_offsets", 0, "end") - else 0 - ), - page=page, - top=get(cell, int, "position", "top"), - left=get(cell, int, "position", "left"), - right=get(cell, int, "position", "right"), - bottom=get(cell, int, "position", "bottom"), - row=get(cell, int, "rows", 0), - rowspan=len(get(cell, list, "rows")), - rows=tuple(get(cell, list, "rows")), - column=get(cell, int, "columns", 0), - columnspan=len(get(cell, list, "columns")), - columns=tuple(get(cell, list, "columns")), + box=Box.from_dict(get(cell, dict, "position")), + range=Range.from_dict(cell), + spans=tuple(map(Span.from_dict, get(cell, list, "doc_offsets"))), ) diff --git a/indico_toolkit/etloutput/etloutput.py b/indico_toolkit/etloutput/etloutput.py index 537f273..2086f1a 100644 --- a/indico_toolkit/etloutput/etloutput.py +++ b/indico_toolkit/etloutput/etloutput.py @@ -4,6 +4,7 @@ from operator import attrgetter from typing import TYPE_CHECKING +from ..results import Box, Span from .errors import TableCellNotFoundError, TokenNotFoundError from .table import Table from .token import Token @@ -11,7 +12,6 @@ if TYPE_CHECKING: from collections.abc import Iterable - from ..results import DocumentExtraction from .cell import Cell @@ -37,11 +37,11 @@ def from_pages( """ text_by_page 
= tuple(text_by_page) tokens_by_page = tuple( - tuple(map(Token.from_dict, token_dict_page)) + tuple(sorted(map(Token.from_dict, token_dict_page), key=attrgetter("span"))) for token_dict_page in token_dicts_by_page ) tables_by_page = tuple( - tuple(map(Table.from_dict, table_dict_page)) + tuple(sorted(map(Table.from_dict, table_dict_page), key=attrgetter("box"))) for table_dict_page in table_dicts_by_page ) @@ -54,51 +54,55 @@ def from_pages( tables_on_page=tables_by_page, ) - def token_for(self, extraction: "DocumentExtraction") -> Token: + def token_for(self, span: Span) -> Token: """ - Return a `Token` that contains every character from `extraction`. + Return a `Token` that contains every character from `span`. Raise `TokenNotFoundError` if one can't be produced. """ try: - tokens = self.tokens_on_page[extraction.page] - first = bisect_right(tokens, extraction.start, key=attrgetter("end")) - last = bisect_left(tokens, extraction.end, lo=first, key=attrgetter("start")) # fmt: skip # noqa: E501 + tokens = self.tokens_on_page[span.page] + first = bisect_right(tokens, span.start, key=attrgetter("span.end")) + last = bisect_left(tokens, span.end, lo=first, key=attrgetter("span.start")) tokens = tokens[first:last] return Token( - text=self.text[extraction.start : extraction.end], - start=extraction.start, - end=extraction.end, - page=min(token.page for token in tokens), - top=min(token.top for token in tokens), - left=min(token.left for token in tokens), - right=max(token.right for token in tokens), - bottom=max(token.bottom for token in tokens), + text=self.text[span.slice], + span=span, + box=Box( + page=min(token.box.page for token in tokens), + top=min(token.box.top for token in tokens), + left=min(token.box.left for token in tokens), + right=max(token.box.right for token in tokens), + bottom=max(token.box.bottom for token in tokens), + ), ) except (IndexError, ValueError) as error: - raise TokenNotFoundError(f"no token contains {extraction!r}") from error + raise 
TokenNotFoundError(f"no token contains {span!r}") from error def table_cell_for(self, token: Token) -> "tuple[Table, Cell]": """ Return the `Table` and `Cell` that contain the midpoint of `token`. Raise `TableCellNotFoundError` if it's not inside a table cell. """ - token_vmid = (token.top + token.bottom) // 2 - token_hmid = (token.left + token.right) // 2 - - for table in self.tables_on_page[token.page]: - if (table.top <= token_vmid <= table.bottom) and ( - table.left <= token_hmid <= table.right - ): + token_vmid = (token.box.top + token.box.bottom) // 2 + token_hmid = (token.box.left + token.box.right) // 2 + + for table in self.tables_on_page[token.box.page]: + if ( + (table.box.top <= token_vmid <= table.box.bottom) and + (table.box.left <= token_hmid <= table.box.right) + ): # fmt: skip break else: raise TableCellNotFoundError(f"no table contains {token!r}") try: - row_index = bisect_left(table.rows, token_vmid, key=lambda row: row[0].bottom) # fmt: skip # noqa: E501 + row_index = bisect_left( + table.rows, token_vmid, key=lambda row: row[0].box.bottom + ) row = table.rows[row_index] - cell_index = bisect_left(row, token_hmid, key=attrgetter("right")) + cell_index = bisect_left(row, token_hmid, key=attrgetter("box.right")) cell = row[cell_index] except (IndexError, ValueError) as error: raise TableCellNotFoundError(f"no cell contains {token!r}") from error diff --git a/indico_toolkit/etloutput/range.py b/indico_toolkit/etloutput/range.py new file mode 100644 index 0000000..ad2443c --- /dev/null +++ b/indico_toolkit/etloutput/range.py @@ -0,0 +1,30 @@ +from dataclasses import dataclass + +from ..results.utilities import get + + +@dataclass(order=True, frozen=True) +class Range: + row: int + column: int + rowspan: int + columnspan: int + rows: "tuple[int, ...]" + columns: "tuple[int, ...]" + + @staticmethod + def from_dict(cell: object) -> "Range": + """ + Create a `Range` from a v1 or v3 cell dictionary. 
+ """ + rows = get(cell, list, "rows") + columns = get(cell, list, "columns") + + return Range( + row=rows[0], + column=columns[0], + rowspan=len(rows), + columnspan=len(columns), + rows=tuple(rows), + columns=tuple(columns), + ) diff --git a/indico_toolkit/etloutput/table.py b/indico_toolkit/etloutput/table.py index d1aacbb..b592136 100644 --- a/indico_toolkit/etloutput/table.py +++ b/indico_toolkit/etloutput/table.py @@ -1,61 +1,43 @@ from dataclasses import dataclass +from operator import attrgetter +from ..results import Box from ..results.utilities import get from .cell import Cell @dataclass(frozen=True) class Table: - page: int - top: int - left: int - right: int - bottom: int - + box: Box cells: "tuple[Cell, ...]" rows: "tuple[tuple[Cell, ...], ...]" columns: "tuple[tuple[Cell, ...], ...]" - def __lt__(self, other: "Table") -> bool: - """ - By default, tables are sorted in bounding box order with vertical hysteresis. - Those on the same line are sorted left-to-right, even when later tables are - slightly higher than earlier ones. - """ - return ( - self.page < other.page - or (self.page == other.page and self.bottom < other.top) - or ( - self.page == other.page - and self.top < other.bottom - and self.left < other.left - ) - ) - @staticmethod def from_dict(table: object) -> "Table": """ Create a `Table` from a v1 or v3 table dictionary. 
""" page = get(table, int, "page_num") + get(table, dict, "position")["page_num"] = page + cells = tuple( - sorted(Cell.from_dict(cell, page) for cell in get(table, list, "cells")) + sorted( + (Cell.from_dict(cell, page) for cell in get(table, list, "cells")), + key=attrgetter("range"), + ) ) rows = tuple( - tuple(sorted(cell for cell in cells if row in cell.rows)) + tuple(cell for cell in cells if row in cell.range.rows) for row in range(get(table, int, "num_rows")) ) columns = tuple( - tuple(sorted(cell for cell in cells if column in cell.columns)) + tuple(cell for cell in cells if column in cell.range.columns) for column in range(get(table, int, "num_columns")) ) return Table( - page=page, - top=get(table, int, "position", "top"), - left=get(table, int, "position", "left"), - right=get(table, int, "position", "right"), - bottom=get(table, int, "position", "bottom"), + box=Box.from_dict(get(table, dict, "position")), cells=cells, rows=rows, columns=columns, diff --git a/indico_toolkit/etloutput/token.py b/indico_toolkit/etloutput/token.py index b057464..a1cc36f 100644 --- a/indico_toolkit/etloutput/token.py +++ b/indico_toolkit/etloutput/token.py @@ -1,51 +1,25 @@ from dataclasses import dataclass +from ..results import Box, Span from ..results.utilities import get @dataclass(frozen=True) class Token: text: str - # Span - start: int - end: int - # Bounding box - page: int - top: int - left: int - right: int - bottom: int - - def __lt__(self, other: "Token") -> bool: - """ - By default, tokens are sorted in bounding box order with vertical hysteresis. - Those on the same line are sorted left-to-right, even when later tokens are - slightly higher than earlier ones. - - Tokens can also be sorted in span order: `tokens.sort(key=attrgetter("start"))`. 
- """ - return ( - self.page < other.page - or (self.page == other.page and self.bottom < other.top) - or ( - self.page == other.page - and self.top < other.bottom - and self.left < other.left - ) - ) + box: Box + span: Span @staticmethod def from_dict(token: object) -> "Token": """ Create a `Token` from a v1 or v3 token dictionary. """ + get(token, dict, "position")["page_num"] = get(token, int, "page_num") + get(token, dict, "doc_offset")["page_num"] = get(token, int, "page_num") + return Token( text=get(token, str, "text"), - start=get(token, int, "doc_offset", "start"), - end=get(token, int, "doc_offset", "end"), - page=get(token, int, "page_num"), - top=get(token, int, "position", "top"), - left=get(token, int, "position", "left"), - right=get(token, int, "position", "right"), - bottom=get(token, int, "position", "bottom"), + box=Box.from_dict(get(token, dict, "position")), + span=Span.from_dict(get(token, dict, "doc_offset")), ) From 0625261a15d3b177d1b92b81bed85b039c838db4 Mon Sep 17 00:00:00 2001 From: Michael Welborn Date: Fri, 7 Feb 2025 16:21:23 -0600 Subject: [PATCH 07/16] Add GenAI Classification and Extraction Task Types --- indico_toolkit/results/model.py | 2 ++ indico_toolkit/results/predictions/__init__.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/indico_toolkit/results/model.py b/indico_toolkit/results/model.py index b4ae6a0..8e88c90 100755 --- a/indico_toolkit/results/model.py +++ b/indico_toolkit/results/model.py @@ -8,6 +8,8 @@ class TaskType(Enum): CLASSIFICATION = "classification" DOCUMENT_EXTRACTION = "annotation" FORM_EXTRACTION = "form_extraction" + GENAI_CLASSIFICATION = "genai_classification" + GENAI_EXTRACTION = "genai_annotation" UNBUNDLING = "classification_unbundling" diff --git a/indico_toolkit/results/predictions/__init__.py b/indico_toolkit/results/predictions/__init__.py index b023488..6fd9c32 100644 --- a/indico_toolkit/results/predictions/__init__.py +++ 
b/indico_toolkit/results/predictions/__init__.py @@ -61,9 +61,9 @@ def from_v3_dict( """ Create a `Prediction` subclass from a v3 prediction dictionary. """ - if model.task_type == TaskType.CLASSIFICATION: + if model.task_type in (TaskType.CLASSIFICATION, TaskType.GENAI_CLASSIFICATION): return Classification.from_v3_dict(document, model, review, prediction) - elif model.task_type == TaskType.DOCUMENT_EXTRACTION: + elif model.task_type in (TaskType.DOCUMENT_EXTRACTION, TaskType.GENAI_EXTRACTION): return DocumentExtraction.from_v3_dict(document, model, review, prediction) elif model.task_type == TaskType.FORM_EXTRACTION: return FormExtraction.from_v3_dict(document, model, review, prediction) From 05cf2a5a18252906bd6b18e4b01334b3efac6ca2 Mon Sep 17 00:00:00 2001 From: Michael Welborn Date: Fri, 7 Feb 2025 16:48:46 -0600 Subject: [PATCH 08/16] Add `Span.__bool__()` to simplify `extraction.span == NULL_SPAN` to `bool(extraction.span)` --- indico_toolkit/results/predictions/span.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/indico_toolkit/results/predictions/span.py b/indico_toolkit/results/predictions/span.py index 7e9c0b3..4ab8cb6 100644 --- a/indico_toolkit/results/predictions/span.py +++ b/indico_toolkit/results/predictions/span.py @@ -17,6 +17,9 @@ class Span: def slice(self) -> slice: return slice(self.start, self.end) + def __bool__(self) -> bool: + return self != NULL_SPAN + @staticmethod def from_dict(span: object) -> "Span": return Span( @@ -36,5 +39,5 @@ def to_dict(self) -> "dict[str, Any]": # It's more ergonomic to represent the lack of spans with a special null span object # rather than using `None` or raising an error. This lets you e.g. sort by the `span` # attribute without having to constantly check for `None`, while still allowing you do -# a "None check" with `extraction.span == NULL_SPAN`. +# a "None check" with `extraction.span == NULL_SPAN` or `bool(extraction.span)`. 
NULL_SPAN: "Final" = Span(page=0, start=0, end=0) From 33e2636eb5321bd288f0c47a7874937dac471a1e Mon Sep 17 00:00:00 2001 From: Michael Welborn Date: Fri, 7 Feb 2025 17:17:05 -0600 Subject: [PATCH 09/16] Update normalization to use task type names instead of heuristics --- indico_toolkit/results/normalization.py | 32 +++++++++++++++---------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/indico_toolkit/results/normalization.py b/indico_toolkit/results/normalization.py index 2cd8c9a..cbd6fcf 100644 --- a/indico_toolkit/results/normalization.py +++ b/indico_toolkit/results/normalization.py @@ -3,6 +3,7 @@ from .utilities import get, has if TYPE_CHECKING: + from collections.abc import Iterator from typing import Any @@ -84,30 +85,33 @@ def normalize_v3_result(result: "Any") -> None: """ Fix inconsistencies observed in v3 result files. """ - predictions: "Any" = ( - prediction + task_type_by_model_group_id = { + model_group_id: model_group["task_type"] + for model_group_id, model_group in result["modelgroup_metadata"].items() + } + predictions_with_task_type: "Iterator[tuple[Any, str]]" = ( + (prediction, task_type_by_model_group_id[model_group_id]) for submission_result in get(result, list, "submission_results") - for model_result in get(submission_result, dict, "model_results").values() - for review_result in model_result.values() - for prediction in review_result + for review_result in get(submission_result, dict, "model_results").values() + for model_group_id, model_results in review_result.items() + for prediction in model_results ) - for prediction in predictions: - # Predictions added in review lack a `confidence` section. + for prediction, task_type in predictions_with_task_type: + # Predictions added in review may lack a `confidence` section. if "confidence" not in prediction: prediction["confidence"] = {prediction["label"]: 0} # Document Extractions added in review may lack spans. 
if ( - "text" in prediction - and "type" not in prediction + task_type in ("annotation", "genai_annotation") and "spans" not in prediction ): prediction["spans"] = [] # Form Extractions added in review may lack bounding boxes. # Set values that will equal `NULL_BOX`. - if "type" in prediction and "top" not in prediction: + if task_type == "form_extraction" and "top" not in prediction: prediction["page_num"] = 0 prediction["top"] = 0 prediction["left"] = 0 @@ -116,14 +120,16 @@ def normalize_v3_result(result: "Any") -> None: # Prior to 6.11, some Extractions lack a `normalized` section after # review. - if "text" in prediction and "normalized" not in prediction: + if ( + task_type in ("annotation", "form_extraction", "genai_annotation") + and "normalized" not in prediction + ): prediction["normalized"] = {"formatted": prediction["text"]} # Document Extractions that didn't go through a linked labels # transformer lack a `groupings` section. if ( - "text" in prediction - and "type" not in prediction + task_type in ("annotation", "genai_annotation") and "groupings" not in prediction ): prediction["groupings"] = [] From 0f6fb0ecab0dacc6886e065dd80f78c0936a0266 Mon Sep 17 00:00:00 2001 From: Michael Welborn Date: Fri, 7 Feb 2025 17:19:11 -0600 Subject: [PATCH 10/16] Add Summarization prediction type --- indico_toolkit/results/__init__.py | 4 + indico_toolkit/results/model.py | 1 + indico_toolkit/results/normalization.py | 4 + indico_toolkit/results/predictionlist.py | 5 + .../results/predictions/__init__.py | 7 ++ .../results/predictions/citation.py | 47 ++++++++ .../results/predictions/summarization.py | 103 ++++++++++++++++++ 7 files changed, 171 insertions(+) create mode 100644 indico_toolkit/results/predictions/citation.py create mode 100644 indico_toolkit/results/predictions/summarization.py diff --git a/indico_toolkit/results/__init__.py b/indico_toolkit/results/__init__.py index 5aa20de..520322a 100644 --- a/indico_toolkit/results/__init__.py +++ 
b/indico_toolkit/results/__init__.py @@ -7,6 +7,7 @@ from .predictionlist import PredictionList from .predictions import ( NULL_BOX, + NULL_CITATION, NULL_SPAN, Box, Classification, @@ -17,6 +18,7 @@ Group, Prediction, Span, + Summarization, Unbundling, ) from .result import Result @@ -40,6 +42,7 @@ "load_async", "ModelGroup", "NULL_BOX", + "NULL_CITATION", "NULL_SPAN", "Prediction", "PredictionList", @@ -48,6 +51,7 @@ "Review", "ReviewType", "Span", + "Summarization", "TaskType", "Unbundling", ) diff --git a/indico_toolkit/results/model.py b/indico_toolkit/results/model.py index 8e88c90..2df8ff8 100755 --- a/indico_toolkit/results/model.py +++ b/indico_toolkit/results/model.py @@ -10,6 +10,7 @@ class TaskType(Enum): FORM_EXTRACTION = "form_extraction" GENAI_CLASSIFICATION = "genai_classification" GENAI_EXTRACTION = "genai_annotation" + GENAI_SUMMARIZATION = "summarization" UNBUNDLING = "classification_unbundling" diff --git a/indico_toolkit/results/normalization.py b/indico_toolkit/results/normalization.py index cbd6fcf..ae73cdd 100644 --- a/indico_toolkit/results/normalization.py +++ b/indico_toolkit/results/normalization.py @@ -134,6 +134,10 @@ def normalize_v3_result(result: "Any") -> None: ): prediction["groupings"] = [] + # Summarizations may lack citations after review. + if task_type == "summarization" and "citations" not in prediction: + prediction["citations"] = [] + # Prior to 6.8, v3 result files don't include a `reviews` section. 
if not has(result, dict, "reviews"): result["reviews"] = {} diff --git a/indico_toolkit/results/predictionlist.py b/indico_toolkit/results/predictionlist.py index f48504d..8129063 100644 --- a/indico_toolkit/results/predictionlist.py +++ b/indico_toolkit/results/predictionlist.py @@ -9,6 +9,7 @@ Extraction, FormExtraction, Prediction, + Summarization, Unbundling, ) from .review import Review, ReviewType @@ -51,6 +52,10 @@ def extractions(self) -> "PredictionList[Extraction]": def form_extractions(self) -> "PredictionList[FormExtraction]": return self.oftype(FormExtraction) + @property + def summarizations(self) -> "PredictionList[Summarization]": + return self.oftype(Summarization) + @property def unbundlings(self) -> "PredictionList[Unbundling]": return self.oftype(Unbundling) diff --git a/indico_toolkit/results/predictions/__init__.py b/indico_toolkit/results/predictions/__init__.py index 6fd9c32..280bfb9 100644 --- a/indico_toolkit/results/predictions/__init__.py +++ b/indico_toolkit/results/predictions/__init__.py @@ -2,6 +2,7 @@ from ..model import TaskType from .box import NULL_BOX, Box +from .citation import NULL_CITATION, Citation from .classification import Classification from .documentextraction import DocumentExtraction from .extraction import Extraction @@ -9,6 +10,7 @@ from .group import Group from .prediction import Prediction from .span import NULL_SPAN, Span +from .summarization import Summarization from .unbundling import Unbundling if TYPE_CHECKING: @@ -19,6 +21,7 @@ __all__ = ( "Box", + "Citation", "Classification", "DocumentExtraction", "Extraction", @@ -26,9 +29,11 @@ "FormExtractionType", "Group", "NULL_BOX", + "NULL_CITATION", "NULL_SPAN", "Prediction", "Span", + "Summarization", "Unbundling", ) @@ -67,6 +72,8 @@ def from_v3_dict( return DocumentExtraction.from_v3_dict(document, model, review, prediction) elif model.task_type == TaskType.FORM_EXTRACTION: return FormExtraction.from_v3_dict(document, model, review, prediction) + elif 
model.task_type == TaskType.GENAI_SUMMARIZATION: + return Summarization.from_v3_dict(document, model, review, prediction) elif model.task_type == TaskType.UNBUNDLING: return Unbundling.from_v3_dict(document, model, review, prediction) else: diff --git a/indico_toolkit/results/predictions/citation.py b/indico_toolkit/results/predictions/citation.py new file mode 100644 index 0000000..47354cb --- /dev/null +++ b/indico_toolkit/results/predictions/citation.py @@ -0,0 +1,47 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from ..utilities import get +from .span import NULL_SPAN, Span + +if TYPE_CHECKING: + from typing import Any, Final + + +@dataclass(order=True, frozen=True) +class Citation: + start: int + end: int + span: Span + + @property + def slice(self) -> slice: + return slice(self.start, self.end) + + def __bool__(self) -> bool: + return self != NULL_CITATION + + @staticmethod + def from_dict(span: object) -> "Citation": + return Citation( + start=get(span, int, "response", "start"), + end=get(span, int, "response", "end"), + span=Span.from_dict(get(span, dict, "document")), + ) + + def to_dict(self) -> "dict[str, Any]": + return { + "document": self.span.to_dict(), + "response": { + "start": self.start, + "end": self.end, + }, + } + + +# It's more ergonomic to represent the lack of citations with a special null citation +# object rather than using `None` or raising an error. This lets you e.g. sort by the +# `citation` attribute without having to constantly check for `None`, while still +# allowing you do a "None check" with `summarization.citation == NULL_CITATION` or +# `bool(summarization.citation)`. 
+NULL_CITATION: "Final" = Citation(span=NULL_SPAN, start=0, end=0) diff --git a/indico_toolkit/results/predictions/summarization.py b/indico_toolkit/results/predictions/summarization.py new file mode 100644 index 0000000..77176a4 --- /dev/null +++ b/indico_toolkit/results/predictions/summarization.py @@ -0,0 +1,103 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from ..review import Review +from ..utilities import get, has, omit +from .citation import NULL_CITATION, Citation +from .extraction import Extraction + +if TYPE_CHECKING: + from typing import Any + + from ..document import Document + from ..model import ModelGroup + from .span import Span + + +@dataclass +class Summarization(Extraction): + citations: "list[Citation]" + + @property + def citation(self) -> Citation: + """ + Return the first `Citation` the summarization covers else `NULL_CITATION`. + + Post-review, summarizations have no citations. + """ + return self.citations[0] if self.citations else NULL_CITATION + + @citation.setter + def citation(self, citation: Citation) -> None: + """ + Overwrite all `citations` with the one provided. + + This is implemented under the assumption that if you're setting the single + citation, you want it to be the only one. And if you're working in a context + that's multiple-citation sensitive, you'll set `extraction.citations` instead. + """ + self.citations = [citation] + + @property + def span(self) -> "Span": + return self.citation.span + + @span.setter + def span(self, span: "Span") -> None: + self.citations = [Citation(self.citation.start, self.citation.end, span)] + + @staticmethod + def from_v3_dict( + document: "Document", + model: "ModelGroup", + review: "Review | None", + prediction: object, + ) -> "Summarization": + """ + Create a `Summarization` from a v3 prediction dictionary. 
+ """ + return Summarization( + document=document, + model=model, + review=review, + label=get(prediction, str, "label"), + confidences=get(prediction, dict, "confidence"), + text=get(prediction, str, "text"), + accepted=( + has(prediction, bool, "accepted") and get(prediction, bool, "accepted") + ), + rejected=( + has(prediction, bool, "rejected") and get(prediction, bool, "rejected") + ), + citations=sorted( + map(Citation.from_dict, get(prediction, list, "citations")) + ), + extras=omit( + prediction, + "label", + "confidence", + "text", + "accepted", + "rejected", + "citations", + ), + ) + + def to_v3_dict(self) -> "dict[str, Any]": + """ + Create a prediction dictionary for v3 auto review changes. + """ + prediction = { + **self.extras, + "label": self.label, + "confidence": self.confidences, + "text": self.text, + "citations": [citation.to_dict() for citation in self.citations], + } + + if self.accepted: + prediction["accepted"] = True + elif self.rejected: + prediction["rejected"] = True + + return prediction From f26b587bff3c81d335c29e8951f1e665c7074565 Mon Sep 17 00:00:00 2001 From: Michael Welborn Date: Fri, 7 Feb 2025 17:20:14 -0600 Subject: [PATCH 11/16] Update attr order to be consistent --- .../results/predictions/documentextraction.py | 18 +++++++++--------- .../results/predictions/formextraction.py | 10 +++++----- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/indico_toolkit/results/predictions/documentextraction.py b/indico_toolkit/results/predictions/documentextraction.py index 7c87f9b..8ce4b7f 100644 --- a/indico_toolkit/results/predictions/documentextraction.py +++ b/indico_toolkit/results/predictions/documentextraction.py @@ -68,12 +68,12 @@ def from_v1_dict( prediction, "label", "confidence", - "start", - "end", - "page_num", - "groupings", "accepted", "rejected", + "groupings", + "page_num", + "start", + "end", ), ) @@ -93,23 +93,23 @@ def from_v3_dict( review=review, label=get(prediction, str, "label"), 
confidences=get(prediction, dict, "confidence"), + text=get(prediction, str, "normalized", "formatted"), accepted=( has(prediction, bool, "accepted") and get(prediction, bool, "accepted") ), rejected=( has(prediction, bool, "rejected") and get(prediction, bool, "rejected") ), - text=get(prediction, str, "normalized", "formatted"), groups=set(map(Group.from_dict, get(prediction, list, "groupings"))), spans=sorted(map(Span.from_dict, get(prediction, list, "spans"))), extras=omit( prediction, "label", "confidence", - "groupings", - "spans", "accepted", "rejected", + "groupings", + "spans", ), ) @@ -121,10 +121,10 @@ def to_v1_dict(self) -> "dict[str, Any]": **self.extras, "label": self.label, "confidence": self.confidences, + "groupings": [group.to_dict() for group in self.groups], "page_num": self.span.page, "start": self.span.start, "end": self.span.end, - "groupings": [group.to_dict() for group in self.groups], } prediction["normalized"]["formatted"] = self.text @@ -146,11 +146,11 @@ def to_v3_dict(self) -> "dict[str, Any]": "label": self.label, "confidence": self.confidences, "groupings": [group.to_dict() for group in self.groups], + "spans": [span.to_dict() for span in self.spans], } prediction["normalized"]["formatted"] = self.text prediction["text"] = self.text # 6.10 sometimes reverts to raw text in review. 
- prediction["spans"] = [span.to_dict() for span in self.spans] if self.accepted: prediction["accepted"] = True diff --git a/indico_toolkit/results/predictions/formextraction.py b/indico_toolkit/results/predictions/formextraction.py index 8185c3f..6b9ed47 100644 --- a/indico_toolkit/results/predictions/formextraction.py +++ b/indico_toolkit/results/predictions/formextraction.py @@ -64,16 +64,16 @@ def _from_dict( prediction, "label", "confidence", + "accepted", + "rejected", "type", - "checked", - "signed", "page_num", "top", - "bottom", "left", "right", - "accepted", - "rejected", + "bottom", + "checked", + "signed", ), ) From 8fbc33eb4c744a03bb74cfcc92b2a263004ed24e Mon Sep 17 00:00:00 2001 From: Michael Welborn Date: Fri, 7 Feb 2025 17:20:44 -0600 Subject: [PATCH 12/16] Add unit test result file for GenAI Classification, Extraction, and Summarization --- tests/data/results/96127_v3_genai.json | 194 +++++++++++++++++++++++++ 1 file changed, 194 insertions(+) create mode 100644 tests/data/results/96127_v3_genai.json diff --git a/tests/data/results/96127_v3_genai.json b/tests/data/results/96127_v3_genai.json new file mode 100644 index 0000000..1650388 --- /dev/null +++ b/tests/data/results/96127_v3_genai.json @@ -0,0 +1,194 @@ +{ + "file_version": 3, + "submission_id": 96127, + "modelgroup_metadata": { + "5561": { + "id": 5561, + "task_type": "genai_classification", + "name": "GenAI Classification", + "selected_model": { "id": 9045, "model_type": "genai" } + }, + "5562": { + "id": 5562, + "task_type": "genai_annotation", + "name": "GenAI Invoice Extraction", + "selected_model": { "id": 9048, "model_type": "genai" } + }, + "5563": { + "id": 5563, + "task_type": "summarization", + "name": "GenAI Purchase Order Summarization", + "selected_model": { "id": 9050, "model_type": "summarization" } + } + }, + "submission_results": [ + { + "submissionfile_id": 89243, + "etl_output": "indico-file:///storage/submission/5216/96127/89243/etl_output.json", + "input_filename": 
"Invoice.pdf", + "input_filepath": "indico-file:///storage/submission/5216/96127/89243.pdf", + "input_filesize": 426157, + "model_results": { + "ORIGINAL": { + "5562": [ + { + "label": "Invoice Number", + "spans": [ + { "start": 113, "end": 119, "page_num": 0 } + ], + "span_id": "89243:c:18030:idx:0", + "confidence": { + "Invoice Number": 0.23762473235648304 + }, + "field_id": 590168, + "location_type": "exact", + "text": "579266", + "normalized": { + "text": "579266", + "start": 113, + "end": 119, + "structured": null, + "formatted": "579266", + "status": "SUCCESS", + "comparison_type": "string", + "comparison_value": "579266", + "validation": [ + { + "validation_type": "TYPE_CONVERSION", + "error_message": null, + "validation_status": "SUCCESS" + } + ] + } + } + ], + "5561": [ + { + "field_id": 590167, + "confidence": { + "Invoice": 0.9999999857970121 + }, + "label": "Invoice" + } + ] + }, + "FINAL": { + "5562": [ + { + "text": "579266", + "label": "Invoice Number", + "field_id": 590168, + "page_num": 0, + "normalized": { + "end": null, + "text": "579266", + "start": null, + "status": "SUCCESS", + "formatted": "579266", + "structured": {}, + "validation": [ + { + "error_message": null, + "validation_type": "TYPE_CONVERSION", + "validation_status": "SUCCESS" + } + ], + "comparison_type": "string", + "comparison_value": "579266" + } + } + ], + "5561": [ + { + "field_id": 590167, + "confidence": { + "Invoice": 0.9999999857970121 + }, + "label": "Invoice" + } + ], + "5563": [] + } + }, + "component_results": { "ORIGINAL": {}, "FINAL": {} }, + "rejected": { + "models": { "5562": [], "5561": [] }, + "components": {} + } + }, + { + "submissionfile_id": 89244, + "etl_output": "indico-file:///storage/submission/5216/96127/89244/etl_output.json", + "input_filename": "purchase_order.pdf", + "input_filepath": "indico-file:///storage/submission/5216/96127/89244.pdf", + "input_filesize": 80950, + "model_results": { + "ORIGINAL": { + "5563": [ + { + "field_id": 590169, + 
"confidence": { "Purchase Order Number": 1.0 }, + "label": "Purchase Order Number", + "text": "29111525 [1]", + "citations": [ + { + "document": { + "start": 0, + "end": 1329, + "page_num": 0 + }, + "response": { "start": 9, "end": 12 } + } + ] + } + ], + "5561": [ + { + "field_id": 590167, + "confidence": { + "Purchase Order": 0.9999999846134298 + }, + "label": "Purchase Order" + } + ] + }, + "FINAL": { + "5563": [ + { + "text": "29111525", + "label": "Purchase Order Number", + "field_id": 590169, + "page_num": 0, + "confidence": { "Purchase Order Number": 1 } + } + ], + "5561": [ + { + "field_id": 590167, + "confidence": { + "Purchase Order": 0.9999999846134298 + }, + "label": "Purchase Order" + } + ], + "5562": [] + } + }, + "component_results": { "ORIGINAL": {}, "FINAL": {} }, + "rejected": { + "models": { "5563": [], "5561": [] }, + "components": {} + } + } + ], + "reviews": { + "68458": { + "review_id": 68458, + "reviewer_id": 422, + "review_notes": null, + "review_rejected": false, + "review_type": "manual" + } + }, + "errored_files": {} +} From 5fb4d3048651f72c73705e49ef69f6d269ed67f6 Mon Sep 17 00:00:00 2001 From: Michael Welborn Date: Mon, 10 Feb 2025 15:41:01 -0600 Subject: [PATCH 13/16] Update examples --- examples/results_dataclasses.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/examples/results_dataclasses.py b/examples/results_dataclasses.py index 0a9ccd1..0d84a11 100644 --- a/examples/results_dataclasses.py +++ b/examples/results_dataclasses.py @@ -48,11 +48,13 @@ """ -Dataclass Reference +Dataclass Reference Summary + +See class definitions for complete reference. 
""" # Result Dataclass -result.id # Submission ID +result.submission_id # Submission ID result.version # Result file version result.documents # List of documents in this submission result.models # List of documents in this submission @@ -81,8 +83,7 @@ document = result.documents[0] document.id document.name -document.etl_output_url -document.full_text_url +document.etl_output_uri # Prediction list Dataclass @@ -130,9 +131,9 @@ # DocumentExtraction Dataclass (Subclass of Extraction) document_extraction = predictions.document_extractions[0] document_extraction.text -document_extraction.page -document_extraction.start -document_extraction.end +document_extraction.span.page +document_extraction.span.start +document_extraction.span.end document_extraction.groups # Any linked label groups this prediction is a part of document_extraction.accepted document_extraction.rejected @@ -145,12 +146,15 @@ # FormExtraction Dataclass (Subclass of Extraction) form_extraction = predictions.form_extractions[0] +form_extraction.type form_extraction.text -form_extraction.page -form_extraction.top -form_extraction.left -form_extraction.right -form_extraction.bottom +form_extraction.checked +form_extraction.signed +form_extraction.box.page +form_extraction.box.top +form_extraction.box.left +form_extraction.box.right +form_extraction.box.bottom form_extraction.accepted form_extraction.rejected From 07996e79605313676a6c7a0afe24662894e310b4 Mon Sep 17 00:00:00 2001 From: Michael Welborn Date: Mon, 10 Feb 2025 16:21:14 -0600 Subject: [PATCH 14/16] Simplify `ModelGroup.task_type: TaskType` to `ModelGroup.type: ModelGroupType` --- indico_toolkit/results/__init__.py | 4 +-- indico_toolkit/results/model.py | 16 +++++----- indico_toolkit/results/predictionlist.py | 12 ++++---- .../results/predictions/__init__.py | 30 ++++++++++++------- tests/results/test_predictionlist.py | 14 +++++---- 5 files changed, 43 insertions(+), 33 deletions(-) diff --git a/indico_toolkit/results/__init__.py 
b/indico_toolkit/results/__init__.py index 520322a..51b2cc1 100644 --- a/indico_toolkit/results/__init__.py +++ b/indico_toolkit/results/__init__.py @@ -3,7 +3,7 @@ from .document import Document from .errors import ResultError -from .model import ModelGroup, TaskType +from .model import ModelGroup, ModelGroupType from .predictionlist import PredictionList from .predictions import ( NULL_BOX, @@ -41,6 +41,7 @@ "load", "load_async", "ModelGroup", + "ModelGroupType", "NULL_BOX", "NULL_CITATION", "NULL_SPAN", @@ -52,7 +53,6 @@ "ReviewType", "Span", "Summarization", - "TaskType", "Unbundling", ) diff --git a/indico_toolkit/results/model.py b/indico_toolkit/results/model.py index 2df8ff8..36d9b88 100755 --- a/indico_toolkit/results/model.py +++ b/indico_toolkit/results/model.py @@ -4,7 +4,7 @@ from .utilities import get, has -class TaskType(Enum): +class ModelGroupType(Enum): CLASSIFICATION = "classification" DOCUMENT_EXTRACTION = "annotation" FORM_EXTRACTION = "form_extraction" @@ -18,7 +18,7 @@ class TaskType(Enum): class ModelGroup: id: int name: str - task_type: TaskType + type: ModelGroupType @staticmethod def from_v1_section(section: "tuple[str, object]") -> "ModelGroup": @@ -32,20 +32,20 @@ def from_v1_section(section: "tuple[str, object]") -> "ModelGroup": prediction = get(predictions, dict, "pre_review", 0) if has(prediction, str, "type"): - task_type = TaskType.FORM_EXTRACTION + type = ModelGroupType.FORM_EXTRACTION elif has(prediction, str, "text"): - task_type = TaskType.DOCUMENT_EXTRACTION + type = ModelGroupType.DOCUMENT_EXTRACTION else: - task_type = TaskType.CLASSIFICATION + type = ModelGroupType.CLASSIFICATION else: # Likely an extraction model that produced no predictions. - task_type = TaskType.DOCUMENT_EXTRACTION + type = ModelGroupType.DOCUMENT_EXTRACTION return ModelGroup( # v1 result files don't include model IDs. 
id=None, # type: ignore[arg-type] name=name, - task_type=task_type, + type=type, ) @staticmethod @@ -56,5 +56,5 @@ def from_v3_dict(model_group: object) -> "ModelGroup": return ModelGroup( id=get(model_group, int, "id"), name=get(model_group, str, "name"), - task_type=TaskType(get(model_group, str, "task_type")), + type=ModelGroupType(get(model_group, str, "task_type")), ) diff --git a/indico_toolkit/results/predictionlist.py b/indico_toolkit/results/predictionlist.py index 8129063..d2a0cd7 100644 --- a/indico_toolkit/results/predictionlist.py +++ b/indico_toolkit/results/predictionlist.py @@ -2,7 +2,7 @@ from operator import attrgetter from typing import TYPE_CHECKING, List, TypeVar, overload -from .model import TaskType +from .model import ModelGroupType from .predictions import ( Classification, DocumentExtraction, @@ -149,8 +149,8 @@ def where( *, document: "Document | None" = None, document_in: "Container[Document] | None" = None, - model: "ModelGroup | TaskType | str | None" = None, - model_in: "Container[ModelGroup | TaskType | str] | None" = None, + model: "ModelGroup | ModelGroupType | str | None" = None, + model_in: "Container[ModelGroup | ModelGroupType | str] | None" = None, review: "Review | ReviewType | None" = REVIEW_UNSPECIFIED, review_in: "Container[Review | ReviewType | None]" = {REVIEW_UNSPECIFIED}, label: "str | None" = None, @@ -201,7 +201,7 @@ def where( predicates.append( lambda prediction: ( prediction.model == model - or prediction.model.task_type == model + or prediction.model.type == model or prediction.model.name == model ) ) @@ -210,7 +210,7 @@ def where( predicates.append( lambda prediction: ( prediction.model in model_in - or prediction.model.task_type in model_in + or prediction.model.type in model_in or prediction.model.name in model_in ) ) @@ -346,7 +346,7 @@ def to_v1_changes(self, document: "Document") -> "dict[str, Any]": changes: "dict[str, Any]" = {} for model, predictions in self.groupby(attrgetter("model")).items(): - if 
model.task_type == TaskType.CLASSIFICATION: + if model.type == ModelGroupType.CLASSIFICATION: changes[model.name] = predictions[0].to_v1_dict() else: changes[model.name] = [ diff --git a/indico_toolkit/results/predictions/__init__.py b/indico_toolkit/results/predictions/__init__.py index 280bfb9..aa65277 100644 --- a/indico_toolkit/results/predictions/__init__.py +++ b/indico_toolkit/results/predictions/__init__.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING -from ..model import TaskType +from ..model import ModelGroupType from .box import NULL_BOX, Box from .citation import NULL_CITATION, Citation from .classification import Classification @@ -37,6 +37,14 @@ "Unbundling", ) +CLASSIFICATION = ModelGroupType.CLASSIFICATION +DOCUMENT_EXTRACTION = ModelGroupType.DOCUMENT_EXTRACTION +FORM_EXTRACTION = ModelGroupType.FORM_EXTRACTION +GENAI_CLASSIFICATION = ModelGroupType.GENAI_CLASSIFICATION +GENAI_EXTRACTION = ModelGroupType.GENAI_EXTRACTION +GENAI_SUMMARIZATION = ModelGroupType.GENAI_SUMMARIZATION +UNBUNDLING = ModelGroupType.UNBUNDLING + def from_v1_dict( document: "Document", @@ -47,14 +55,14 @@ def from_v1_dict( """ Create a `Prediction` subclass from a v1 prediction dictionary. """ - if model.task_type == TaskType.CLASSIFICATION: + if model.type == CLASSIFICATION: return Classification.from_v1_dict(document, model, review, prediction) - elif model.task_type == TaskType.DOCUMENT_EXTRACTION: + elif model.type == DOCUMENT_EXTRACTION: return DocumentExtraction.from_v1_dict(document, model, review, prediction) - elif model.task_type == TaskType.FORM_EXTRACTION: + elif model.type == FORM_EXTRACTION: return FormExtraction.from_v1_dict(document, model, review, prediction) else: - raise ResultError(f"unsupported v1 task type `{model.task_type!r}`") + raise ResultError(f"unsupported v1 model type `{model.type!r}`") def from_v3_dict( @@ -66,15 +74,15 @@ def from_v3_dict( """ Create a `Prediction` subclass from a v3 prediction dictionary. 
""" - if model.task_type in (TaskType.CLASSIFICATION, TaskType.GENAI_CLASSIFICATION): + if model.type in (CLASSIFICATION, GENAI_CLASSIFICATION): return Classification.from_v3_dict(document, model, review, prediction) - elif model.task_type in (TaskType.DOCUMENT_EXTRACTION, TaskType.GENAI_EXTRACTION): + elif model.type in (DOCUMENT_EXTRACTION, GENAI_EXTRACTION): return DocumentExtraction.from_v3_dict(document, model, review, prediction) - elif model.task_type == TaskType.FORM_EXTRACTION: + elif model.type == FORM_EXTRACTION: return FormExtraction.from_v3_dict(document, model, review, prediction) - elif model.task_type == TaskType.GENAI_SUMMARIZATION: + elif model.type == GENAI_SUMMARIZATION: return Summarization.from_v3_dict(document, model, review, prediction) - elif model.task_type == TaskType.UNBUNDLING: + elif model.type == UNBUNDLING: return Unbundling.from_v3_dict(document, model, review, prediction) else: - raise ResultError(f"unsupported v3 task type `{model.task_type!r}`") + raise ResultError(f"unsupported v3 model type `{model.type!r}`") diff --git a/tests/results/test_predictionlist.py b/tests/results/test_predictionlist.py index 2a47542..56c8c8f 100644 --- a/tests/results/test_predictionlist.py +++ b/tests/results/test_predictionlist.py @@ -8,12 +8,12 @@ DocumentExtraction, Group, ModelGroup, + ModelGroupType, Prediction, PredictionList, Review, ReviewType, Span, - TaskType, ) @@ -30,14 +30,14 @@ def document() -> Document: @pytest.fixture def classification_model() -> ModelGroup: return ModelGroup( - id=121, name="Tax Classification", task_type=TaskType.CLASSIFICATION + id=121, name="Tax Classification", type=ModelGroupType.CLASSIFICATION ) @pytest.fixture def extraction_model() -> ModelGroup: return ModelGroup( - id=122, name="1040 Document Extraction", task_type=TaskType.DOCUMENT_EXTRACTION + id=122, name="1040 Document Extraction", type=ModelGroupType.DOCUMENT_EXTRACTION ) @@ -178,7 +178,7 @@ def test_where_model( ) -> None: (classification,) = 
predictions.classifications assert predictions.where(model=classification_model) == [classification] - assert predictions.where(model=TaskType.CLASSIFICATION) == [classification] + assert predictions.where(model=ModelGroupType.CLASSIFICATION) == [classification] assert predictions.where(model="Tax Classification") == [classification] @@ -187,9 +187,11 @@ def test_where_model_in( ) -> None: classification, first_name, last_name = predictions assert predictions.where(model_in={classification_model}) == [classification] - assert predictions.where(model_in={TaskType.CLASSIFICATION}) == [classification] + assert predictions.where(model_in={ModelGroupType.CLASSIFICATION}) == [ + classification + ] assert predictions.where( - model_in={TaskType.CLASSIFICATION, TaskType.DOCUMENT_EXTRACTION} + model_in={ModelGroupType.CLASSIFICATION, ModelGroupType.DOCUMENT_EXTRACTION} ) == [classification, first_name, last_name] assert predictions.where(model_in={"Tax Classification"}) == [classification] assert predictions.where( From 1ddf8909d5d7c4a8b96cd59aaf53b572b3bf4719 Mon Sep 17 00:00:00 2001 From: Michael Welborn Date: Mon, 10 Feb 2025 17:17:54 -0600 Subject: [PATCH 15/16] Parse failed files into `Result.documents` list --- indico_toolkit/polling/autoreview.py | 1 + indico_toolkit/results/document.py | 27 ++++++++++++++++++++++++ indico_toolkit/results/normalization.py | 11 ++++++++++ indico_toolkit/results/predictionlist.py | 3 +++ indico_toolkit/results/result.py | 6 +++++- tests/data/results/96127_v3_genai.json | 8 ++++++- tests/results/test_predictionlist.py | 3 +++ 7 files changed, 57 insertions(+), 2 deletions(-) diff --git a/indico_toolkit/polling/autoreview.py b/indico_toolkit/polling/autoreview.py index 4d305d5..401a312 100644 --- a/indico_toolkit/polling/autoreview.py +++ b/indico_toolkit/polling/autoreview.py @@ -158,6 +158,7 @@ async def _worker(self, submission_id: "SubmissionId") -> None: tables=self._load_tables, ) for document in result.documents + if not 
document.failed } else: logger.info(f"Skipping etl output for {submission_id=}") diff --git a/indico_toolkit/results/document.py b/indico_toolkit/results/document.py index 8522f29..47c8afc 100644 --- a/indico_toolkit/results/document.py +++ b/indico_toolkit/results/document.py @@ -8,6 +8,9 @@ class Document: id: int name: str etl_output_uri: str + failed: bool + error: str + traceback: str # Auto review changes must reproduce all model sections that were present in the # original result file. This may not be possible from the predictions alone--if a @@ -31,6 +34,9 @@ def from_v1_dict(result: object) -> "Document": id=None, # type: ignore[arg-type] name=None, # type: ignore[arg-type] etl_output_uri=etl_output_uri, + failed=False, + error="", + traceback="", _model_sections=model_names, ) @@ -47,5 +53,26 @@ def from_v3_dict(document: object) -> "Document": id=get(document, int, "submissionfile_id"), name=get(document, str, "input_filename"), etl_output_uri=etl_output_uri, + failed=False, + error="", + traceback="", _model_sections=model_ids, ) + + @staticmethod + def from_v3_errored_file(errored_file: object) -> "Document": + """ + Create a `Document` from a v3 errored file dictionary. 
+ """ + traceback = get(errored_file, str, "error") + error = traceback.split("\n")[-1].strip() + + return Document( + id=get(errored_file, int, "submissionfile_id"), + name=get(errored_file, str, "input_filename"), + etl_output_uri="", + failed=True, + error=error, + traceback=traceback, + _model_sections=frozenset(), + ) diff --git a/indico_toolkit/results/normalization.py b/indico_toolkit/results/normalization.py index ae73cdd..5d14c91 100644 --- a/indico_toolkit/results/normalization.py +++ b/indico_toolkit/results/normalization.py @@ -1,3 +1,4 @@ +import re from typing import TYPE_CHECKING from .utilities import get, has @@ -146,3 +147,13 @@ def normalize_v3_result(result: "Any") -> None: for review_dict in get(result, dict, "reviews").values(): if not has(review_dict, str, "review_notes"): review_dict["review_notes"] = "" + + # Prior to 7.0, v3 result files don't include an `errored_files` section. + if not has(result, dict, "errored_files"): + result["errored_files"] = {} + + # Prior to 7.X, errored files may lack filenames. 
+ for file in get(result, dict, "errored_files").values(): + if not has(file, str, "input_filename") and has(file, str, "reason"): + match = re.search(r"file '([^']*)' with id", get(file, str, "reason")) + file["input_filename"] = match.group(1) if match else "" diff --git a/indico_toolkit/results/predictionlist.py b/indico_toolkit/results/predictionlist.py index d2a0cd7..140ffbd 100644 --- a/indico_toolkit/results/predictionlist.py +++ b/indico_toolkit/results/predictionlist.py @@ -366,6 +366,9 @@ def to_v3_changes(self, documents: "Iterable[Document]") -> "list[dict[str, Any] changes: "list[dict[str, Any]]" = [] for document in documents: + if document.failed: + continue + model_results: "dict[str, Any]" = {} changes.append( { diff --git a/indico_toolkit/results/result.py b/indico_toolkit/results/result.py index 0d6384b..809fc9e 100644 --- a/indico_toolkit/results/result.py +++ b/indico_toolkit/results/result.py @@ -1,5 +1,6 @@ from dataclasses import dataclass from functools import partial +from itertools import chain from typing import TYPE_CHECKING from . 
import predictions as prediction @@ -107,7 +108,10 @@ def from_v3_dict(result: object) -> "Result": modelgroup_metadata = get(result, dict, "modelgroup_metadata") review_metadata = get(result, dict, "reviews") - documents = sorted(map(Document.from_v3_dict, submission_results)) + processed_documents = map(Document.from_v3_dict, submission_results) + errored_files = get(result, dict, "errored_files").values() + failed_documents = map(Document.from_v3_errored_file, errored_files) + documents = sorted(chain(processed_documents, failed_documents)) models = sorted(map(ModelGroup.from_v3_dict, modelgroup_metadata.values())) predictions: "PredictionList[Prediction]" = PredictionList() reviews = sorted(map(Review.from_dict, review_metadata.values())) diff --git a/tests/data/results/96127_v3_genai.json b/tests/data/results/96127_v3_genai.json index 1650388..ab2b1b7 100644 --- a/tests/data/results/96127_v3_genai.json +++ b/tests/data/results/96127_v3_genai.json @@ -190,5 +190,11 @@ "review_type": "manual" } }, - "errored_files": {} + "errored_files": { + "89245": { + "submissionfile_id": 89245, + "error": "Traceback (most recent call last):\n\n File \"/venv/.venv/lib/python3.10/site-packages/mediocris/pdfconverter/converter.py\", line 92, in libre_to_pdf\n raise FileProcessingFailed()\n\nmediocris.pdfconverter.errors.FileProcessingFailed\n\n\nDuring handling of the above exception, another exception occurred:\n\n\nTraceback (most recent call last):\n\n File \"/venv/.venv/lib/python3.10/site-packages/mediocris/pdfconverter/converter.py\", line 92, in libre_to_pdf\n raise FileProcessingFailed()\n\nmediocris.pdfconverter.errors.FileProcessingFailed\n\n\nDuring handling of the above exception, another exception occurred:\n\n\nTraceback (most recent call last):\n\n File \"/readapi/readapi/celery_tasks/submission.py\", line 87, in _\n readapi_input: dict = await get_readapi_client().prepare_for_ocr(\n\n File \"/readapi/readapi/read/readapi.py\", line 193, in prepare_for_ocr\n 
pdf_path: str = await async_convert_file(\n\n File \"/readapi/readapi/read/readapi.py\", line 127, in async_convert_file\n filename: str = await asyncio.wait_for(\n\n File \"/usr/lib/python3.10/asyncio/tasks.py\", line 445, in wait_for\n return fut.result()\n\n File \"/usr/lib/python3.10/asyncio/threads.py\", line 25, in to_thread\n return await loop.run_in_executor(None, func_call)\n\n File \"/usr/lib/python3.10/concurrent/futures/thread.py\", line 58, in run\n result = self.fn(*self.args, **self.kwargs)\n\n File \"/venv/.venv/lib/python3.10/site-packages/mediocris/pdfconverter/convert.py\", line 82, in write_and_convert_file\n convert_to_pdf(\n\n File \"/venv/.venv/lib/python3.10/site-packages/mediocris/pdfconverter/convert.py\", line 278, in convert_to_pdf\n libre_to_pdf(file_path, pdf_file_path, timeout=timeout)\n\n File \"/venv/.venv/lib/python3.10/site-packages/mediocris/pdfconverter/converter.py\", line 105, in libre_to_pdf\n result = libre_to_pdf(\n\n File \"/venv/.venv/lib/python3.10/site-packages/mediocris/pdfconverter/converter.py\", line 114, in libre_to_pdf\n raise InvalidFile(\"Failed to process input file.\")\n\nmediocris.pdfconverter.errors.InvalidFile: Failed to process input file.", + "reason": "Error preparing for OCR, skipping submission file 'technical_questions.docx' with id '89245'" + } + } } diff --git a/tests/results/test_predictionlist.py b/tests/results/test_predictionlist.py index 56c8c8f..5df1229 100644 --- a/tests/results/test_predictionlist.py +++ b/tests/results/test_predictionlist.py @@ -23,6 +23,9 @@ def document() -> Document: id=2922, name="1040_filled.tiff", etl_output_uri="indico-file:///storage/submission/2922/etl_output.json", + failed=False, + error="", + traceback="", _model_sections=frozenset({"124", "123", "122", "121"}), ) From 12b1981d0454ae5d52c81326da32793125fb009d Mon Sep 17 00:00:00 2001 From: Michael Welborn Date: Tue, 11 Feb 2025 08:22:16 -0600 Subject: [PATCH 16/16] Update etloutput docstring examples with 
`document.failed` --- indico_toolkit/etloutput/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/indico_toolkit/etloutput/__init__.py b/indico_toolkit/etloutput/__init__.py index 51e0012..62b41de 100644 --- a/indico_toolkit/etloutput/__init__.py +++ b/indico_toolkit/etloutput/__init__.py @@ -51,6 +51,7 @@ def load( etl_outputs = { document: etloutput.load(document.etl_output_uri, reader=read_uri) for document in result.documents + if not document.failed } ``` """ @@ -82,6 +83,7 @@ async def load_async( etl_outputs = { document: await etloutput.load_async(document.etl_output_uri, reader=read_uri) for document in result.documents + if not document.failed } ``` """