diff --git a/compressai/entropy_models/entropy_models.py b/compressai/entropy_models/entropy_models.py index c3ccb4d9..acb98333 100644 --- a/compressai/entropy_models/entropy_models.py +++ b/compressai/entropy_models/entropy_models.py @@ -54,9 +54,7 @@ def __init__(self, method): if method not in available_entropy_coders(): methods = ", ".join(available_entropy_coders()) - raise ValueError( - f'Unknown entropy coder "{method}"' f" (available: {methods})" - ) + raise ValueError(f'Unknown entropy coder "{method}" (available: {methods})') if method == "ans": from compressai import ans @@ -474,28 +472,18 @@ def forward( if training is None: training = self.training - if not torch.jit.is_scripting(): - # x from B x C x ... to C x B x ... - perm = torch.cat( - ( - torch.tensor([1, 0], dtype=torch.long, device=x.device), - torch.arange(2, x.ndim, dtype=torch.long, device=x.device), - ) - ) - inv_perm = perm - else: - raise NotImplementedError() - # TorchScript in 2D for static inference - # Convert to (channels, ... , batch) format - # perm = (1, 2, 3, 0) - # inv_perm = (3, 0, 1, 2) + D = x.dim() + # B C ... -> C B ... + perm = [1, 0] + list(range(2, D)) + inv_perm = [0] * D + for i, p in enumerate(perm): + inv_perm[p] = i x = x.permute(*perm).contiguous() shape = x.size() values = x.reshape(x.size(0), 1, -1) # Add noise or quantize - outputs = self.quantize( values, "noise" if training else "dequantize", self._get_medians() ) @@ -510,11 +498,8 @@ def forward( # likelihood = torch.zeros_like(outputs) # Convert back to input tensor shape - outputs = outputs.reshape(shape) - outputs = outputs.permute(*inv_perm).contiguous() - - likelihood = likelihood.reshape(shape) - likelihood = likelihood.permute(*inv_perm).contiguous() + outputs = outputs.reshape(shape).permute(*inv_perm).contiguous() + likelihood = likelihood.reshape(shape).permute(*inv_perm).contiguous() return outputs, likelihood