From 06354e2ce653e8c64b7f85f4e0557e0235107d87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=9D=B4=ED=83=9C=ED=99=94?= <102410772+studyingeugene@users.noreply.github.com> Date: Wed, 22 Oct 2025 17:23:08 +0900 Subject: [PATCH 1/2] refactor: simplify forward() permutation logic for compile-friendly execution What's changed - Replace tensor-based perm construction with list-based version - Add explicit inverse permutation for correctness - Remove TorchScript-specific branches Why - Compile-friendly: torch.compile/AOTAutograd prefer static Python control flow and index lists over device tensor construction inside forward. Replacing torch.tensor([...]), torch.arange(...), and torch.cat(...) with plain Python lists reduces graph breaks and guard complexity, improving compilation stability and cache reuse. --- compressai/entropy_models/entropy_models.py | 29 ++++++--------------- 1 file changed, 8 insertions(+), 21 deletions(-) diff --git a/compressai/entropy_models/entropy_models.py b/compressai/entropy_models/entropy_models.py index c3ccb4d9..fbfb982f 100644 --- a/compressai/entropy_models/entropy_models.py +++ b/compressai/entropy_models/entropy_models.py @@ -474,28 +474,18 @@ def forward( if training is None: training = self.training - if not torch.jit.is_scripting(): - # x from B x C x ... to C x B x ... - perm = torch.cat( - ( - torch.tensor([1, 0], dtype=torch.long, device=x.device), - torch.arange(2, x.ndim, dtype=torch.long, device=x.device), - ) - ) - inv_perm = perm - else: - raise NotImplementedError() - # TorchScript in 2D for static inference - # Convert to (channels, ... , batch) format - # perm = (1, 2, 3, 0) - # inv_perm = (3, 0, 1, 2) + D = x.dim() + # B C ... -> C B ... + perm = [1, 0] + list(range(2, D)) + inv_perm = [0] * D + for i, p in enumerate(perm): + inv_perm[p] = i x = x.permute(*perm).contiguous() shape = x.size() values = x.reshape(x.size(0), 1, -1) # Add noise or quantize - outputs = self.quantize( values, "noise" if training else "dequantize", self._get_medians() ) @@ -510,11 +500,8 @@ def forward( # likelihood = torch.zeros_like(outputs) # Convert back to input tensor shape - outputs = outputs.reshape(shape) - outputs = outputs.permute(*inv_perm).contiguous() - - likelihood = likelihood.reshape(shape) - likelihood = likelihood.permute(*inv_perm).contiguous() + outputs = outputs.reshape(shape).permute(*inv_perm).contiguous() + likelihood = likelihood.reshape(shape).permute(*inv_perm).contiguous() return outputs, likelihood From 5c9529701e906ef7b47e0f117a6cd207456b62c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=9D=B4=ED=83=9C=ED=99=94?= <102410772+studyingeugene@users.noreply.github.com> Date: Wed, 22 Oct 2025 21:17:17 +0900 Subject: [PATCH 2/2] refactor: fix lint errors Fix lint errors in entropy_models.py --- compressai/entropy_models/entropy_models.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/compressai/entropy_models/entropy_models.py b/compressai/entropy_models/entropy_models.py index fbfb982f..acb98333 100644 --- a/compressai/entropy_models/entropy_models.py +++ b/compressai/entropy_models/entropy_models.py @@ -54,9 +54,7 @@ def __init__(self, method): if method not in available_entropy_coders(): methods = ", ".join(available_entropy_coders()) - raise ValueError( - f'Unknown entropy coder "{method}"' f" (available: {methods})" - ) + raise ValueError(f'Unknown entropy coder "{method}" (available: {methods})') if method == "ans": from compressai import ans @@ -474,7 +472,7 @@ def forward( if training is None: training = self.training - D = x.dim() + D = x.dim() # B C ... -> C B ... perm = [1, 0] + list(range(2, D)) inv_perm = [0] * D