Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
5db61eb
📝 Make changes to 05-patch-prediction notebook
gozdeg Dec 15, 2025
e7273b9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 15, 2025
ede9dfe
🐛 Bug fix for model/module assumption - implement helper logic to acc…
gozdeg Dec 16, 2025
407b6f3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 16, 2025
6da5741
📝 Update notebook-5 & fix engine_abc _get_model_attr() type checking
gozdeg Dec 17, 2025
7d47b2b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 17, 2025
cffae15
:bug: Fix pre-commit errors
shaneahmed Dec 17, 2025
a8d28c1
:white_check_mark: Fix rendering issues
shaneahmed Dec 17, 2025
709813b
:rewind: Revert changes to multi_task_segmentor.py and nucleus_instan…
shaneahmed Dec 17, 2025
cfe0d5e
:art: Improve Jupyter Notebook output
shaneahmed Dec 17, 2025
e1afce2
📝 Update-notebook-5
gozdeg Dec 18, 2025
267133e
🐛 Small fix
gozdeg Dec 18, 2025
c960baf
Merge branch 'dev-define-engines-abc' into dev-update-example-notebooks
shaneahmed Dec 19, 2025
4a7a35f
Merge branch 'dev-define-engines-abc' into dev-update-example-notebooks
shaneahmed Jan 8, 2026
372ee3e
[skip ci] :doc: Update documentation and fix links.
shaneahmed Jan 9, 2026
f50ee08
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2026
8bb866a
:white_check_mark: Add test for coverage
shaneahmed Jan 9, 2026
2bfcf92
:memo: Update 05-patch-prediction.ipynb
shaneahmed Jan 9, 2026
5f0bab0
[skip ci] :memo: Update 05-patch-prediction.ipynb
shaneahmed Jan 9, 2026
e600960
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2026
8182442
[skip ci] :doc: Fix typo and update metada.
shaneahmed Jan 9, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
761 changes: 578 additions & 183 deletions examples/05-patch-prediction.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pre-commit/notebook_markdown_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ def main(files: list[Path]) -> None:

"""
for path in files:
notebook = json.loads(path.read_text())
with Path.open(path, encoding="utf-8", errors="ignore") as f:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
with Path.open(path, encoding="utf-8", errors="ignore") as f:
notebook = json.loads(path.read_text())

notebook = json.load(f)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
notebook = json.load(f)

formatted_notebook = format_notebook(copy.deepcopy(notebook))
changed = any(
cell != formatted_cell
Expand Down
1 change: 1 addition & 0 deletions requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ flask-cors>=4.0.0
glymur>=0.12.7
huggingface_hub>=0.33.3
imagecodecs>=2022.9.26
ipywidgets>=8.1.7
joblib>=1.1.1
jupyterlab>=3.5.2
matplotlib>=3.6.2
Expand Down
7 changes: 7 additions & 0 deletions tests/engines/test_engine_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,13 @@ def test_engine_initalization() -> NoReturn:
eng = TestEngineABC(model=model, weights=weights_path)
assert isinstance(eng, EngineABC)

with pytest.raises(AttributeError):
_ = eng._get_model_attr("test_attr")

model.test_attr = True
eng = TestEngineABC(model=model, weights=weights_path)
assert eng._get_model_attr("test_attr") is True


def test_engine_run() -> NoReturn:
"""Test engine run."""
Expand Down
3 changes: 2 additions & 1 deletion tiatoolbox/models/engine/deep_feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,9 @@ def infer_wsi(
probabilities_zarr, coordinates_zarr = None, None

probabilities_used_percent = 0
infer_batch = self._get_model_attr("infer_batch")
for batch_data in tqdm_loop:
batch_output = self.model.infer_batch(
batch_output = infer_batch(
self.model,
batch_data["image"],
device=self.device,
Expand Down
16 changes: 13 additions & 3 deletions tiatoolbox/models/engine/engine_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@

if TYPE_CHECKING: # pragma: no cover
import os
from collections.abc import Callable

from torch.utils.data import DataLoader

Expand Down Expand Up @@ -375,6 +376,14 @@ def _initialize_model_ioconfig(

return model, None

def _get_model_attr(self: EngineABC, attr_name: str) -> Callable:
"""Return a model attribute, unwrapping DataParallel if required."""
try:
return getattr(self.model, attr_name)
except AttributeError:
module = getattr(self.model, "module", None)
return getattr(module, attr_name)

def get_dataloader(
self: EngineABC,
images: str | Path | list[str | Path] | np.ndarray,
Expand Down Expand Up @@ -428,7 +437,7 @@ def get_dataloader(
auto_get_mask=auto_get_mask,
)

dataset.preproc_func = self.model.preproc_func
dataset.preproc_func = self._get_model_attr("preproc_func")

# preprocessing must be defined with the dataset
return torch.utils.data.DataLoader(
Expand All @@ -444,7 +453,7 @@ def get_dataloader(
inputs=images, labels=labels, patch_input_shape=ioconfig.patch_input_shape
)

dataset.preproc_func = self.model.preproc_func
dataset.preproc_func = self._get_model_attr("preproc_func")

# preprocessing must be defined with the dataset
return torch.utils.data.DataLoader(
Expand Down Expand Up @@ -529,8 +538,9 @@ def infer_patches(
else self.dataloader
)

infer_batch = self._get_model_attr("infer_batch")
for batch_data in tqdm_loop:
batch_output = self.model.infer_batch(
batch_output = infer_batch(
self.model,
batch_data["image"],
device=self.device,
Expand Down
3 changes: 2 additions & 1 deletion tiatoolbox/models/engine/patch_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,8 @@ def post_process_patches(
dict[str, da.Array]: Post-processed predictions as a Dask array.

"""
predictions = self.model.postproc_func(raw_predictions["probabilities"])
postproc_func = self._get_model_attr("postproc_func")
predictions = postproc_func(raw_predictions["probabilities"])
raw_predictions["predictions"] = cast_to_min_dtype(predictions)
return raw_predictions

Expand Down
5 changes: 3 additions & 2 deletions tiatoolbox/models/engine/semantic_segmentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def get_dataloader(
auto_get_mask=auto_get_mask,
)

dataset.preproc_func = self.model.preproc_func
dataset.preproc_func = self._get_model_attr("preproc_func")
self.output_locations = dataset.outputs

# preprocessing must be defined with the dataset
Expand Down Expand Up @@ -477,8 +477,9 @@ def infer_wsi(
else dataloader.dataset.outputs
)

infer_batch = self._get_model_attr("infer_batch")
for batch_idx, batch_data in enumerate(tqdm_loop):
batch_output = self.model.infer_batch(
batch_output = infer_batch(
self.model,
batch_data["image"],
device=self.device,
Expand Down