From d93e0eada74affd5282d15c34a00e283d456a39c Mon Sep 17 00:00:00 2001
From: susanbao
Date: Wed, 17 Dec 2025 22:14:49 +0000
Subject: [PATCH 1/4] change quantization calibration method

---
 src/maxdiffusion/configs/base_wan_14b.yml      |  6 ++++--
 src/maxdiffusion/pipelines/wan/wan_pipeline.py | 12 ++++++------
 2 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml
index 8cd7e70f..a11ce9b5 100644
--- a/src/maxdiffusion/configs/base_wan_14b.yml
+++ b/src/maxdiffusion/configs/base_wan_14b.yml
@@ -319,8 +319,10 @@ quantization: ''
 quantization_local_shard_count: -1
 compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
 use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the transformer of WAN will be quantized using qwix.
-# Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
-quantization_calibration_method: "absmax"
+# Quantization calibration methods used for weights, activations, and the backward pass (bwd). Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
+weight_quantization_calibration_method: "absmax"
+act_quantization_calibration_method: "absmax"
+bwd_quantization_calibration_method: "absmax"
 qwix_module_path: ".*"
 
 # Eval model on per eval_every steps. -1 means don't eval.
diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
index 153c225d..d00c0116 100644
--- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py
+++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -302,9 +302,9 @@ def get_fp8_config(cls, config: HyperParameters):
           act_qtype=jnp.float8_e4m3fn,
           bwd_qtype=jnp.float8_e5m2,
           disable_channelwise_axes=True,  # per_tensor calibration
-          weight_calibration_method=config.quantization_calibration_method,
-          act_calibration_method=config.quantization_calibration_method,
-          bwd_calibration_method=config.quantization_calibration_method,
+          weight_calibration_method=config.weight_quantization_calibration_method,
+          act_calibration_method=config.act_quantization_calibration_method,
+          bwd_calibration_method=config.bwd_quantization_calibration_method,
           op_names=("dot_general", "einsum"),
         ),
         qwix.QtRule(
@@ -313,9 +313,9 @@
           act_qtype=jnp.float8_e4m3fn,
           bwd_qtype=jnp.float8_e4m3fn,
           disable_channelwise_axes=True,  # per_tensor calibration
-          weight_calibration_method=config.quantization_calibration_method,
-          act_calibration_method=config.quantization_calibration_method,
-          bwd_calibration_method=config.quantization_calibration_method,
+          weight_calibration_method=config.weight_quantization_calibration_method,
+          act_calibration_method=config.act_quantization_calibration_method,
+          bwd_calibration_method=config.bwd_quantization_calibration_method,
           op_names=("conv_general_dilated"),
         ),
       ]
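For reference, a minimal sketch of how the three new config keys reach qwix after this patch. It mirrors the QtRule arguments visible in the hunks above; the weight_qtype values and the standalone build_fp8_rules helper are assumptions for illustration (in the source these rules are built inside WanPipeline.get_fp8_config):

    import jax.numpy as jnp
    import qwix

    def build_fp8_rules(config):
        # One calibration method per tensor class, read from the new keys.
        return [
            qwix.QtRule(
                module_path=config.qwix_module_path,
                weight_qtype=jnp.float8_e4m3fn,  # assumed; not shown in the hunk
                act_qtype=jnp.float8_e4m3fn,
                bwd_qtype=jnp.float8_e5m2,  # e5m2 gradients for matmul-like ops
                disable_channelwise_axes=True,  # per_tensor calibration
                weight_calibration_method=config.weight_quantization_calibration_method,
                act_calibration_method=config.act_quantization_calibration_method,
                bwd_calibration_method=config.bwd_quantization_calibration_method,
                op_names=("dot_general", "einsum"),
            ),
            qwix.QtRule(
                module_path=config.qwix_module_path,
                weight_qtype=jnp.float8_e4m3fn,  # assumed; not shown in the hunk
                act_qtype=jnp.float8_e4m3fn,
                bwd_qtype=jnp.float8_e4m3fn,  # e4m3 gradients for convolutions
                disable_channelwise_axes=True,  # per_tensor calibration
                weight_calibration_method=config.weight_quantization_calibration_method,
                act_calibration_method=config.act_quantization_calibration_method,
                bwd_calibration_method=config.bwd_quantization_calibration_method,
                # Trailing comma makes this a one-element tuple; the diff's bare
                # ("conv_general_dilated") evaluates to a plain string instead.
                op_names=("conv_general_dilated",),
            ),
        ]

Splitting the single method into three keys lets weights, activations, and gradients calibrate independently, which the later patches in this series rely on.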
From 2d4d61189737ccdadd1572f09862d307bd56a32a Mon Sep 17 00:00:00 2001
From: susanbao
Date: Wed, 17 Dec 2025 22:34:23 +0000
Subject: [PATCH 2/4] fix unit test

---
 src/maxdiffusion/tests/wan_transformer_test.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py
index 34f0ef64..ef530a39 100644
--- a/src/maxdiffusion/tests/wan_transformer_test.py
+++ b/src/maxdiffusion/tests/wan_transformer_test.py
@@ -332,7 +332,9 @@ def create_real_rule_instance(*args, **kwargs):
     config_fp8_full = Mock(spec=HyperParameters)
     config_fp8_full.use_qwix_quantization = True
     config_fp8_full.quantization = "fp8_full"
-    config_fp8_full.quantization_calibration_method = "absmax"
+    config_fp8_full.weight_quantization_calibration_method = "absmax"
+    config_fp8_full.act_quantization_calibration_method = "absmax"
+    config_fp8_full.bwd_quantization_calibration_method = "absmax"
     config_fp8_full.qwix_module_path = ".*"
     provider_fp8_full = WanPipeline.get_qt_provider(config_fp8_full)
     self.assertIsNotNone(provider_fp8_full)
@@ -343,9 +345,9 @@ def create_real_rule_instance(*args, **kwargs):
         act_qtype=jnp.float8_e4m3fn,
         bwd_qtype=jnp.float8_e5m2,
         disable_channelwise_axes=True,  # per_tensor calibration
-        weight_calibration_method=config_fp8_full.quantization_calibration_method,
-        act_calibration_method=config_fp8_full.quantization_calibration_method,
-        bwd_calibration_method=config_fp8_full.quantization_calibration_method,
+        weight_calibration_method=config_fp8_full.weight_quantization_calibration_method,
+        act_calibration_method=config_fp8_full.act_quantization_calibration_method,
+        bwd_calibration_method=config_fp8_full.bwd_quantization_calibration_method,
         op_names=("dot_general", "einsum"),
       ),
       call(
@@ -354,9 +356,9 @@ def create_real_rule_instance(*args, **kwargs):
         act_qtype=jnp.float8_e4m3fn,
         bwd_qtype=jnp.float8_e4m3fn,
         disable_channelwise_axes=True,  # per_tensor calibration
-        weight_calibration_method=config_fp8_full.quantization_calibration_method,
-        act_calibration_method=config_fp8_full.quantization_calibration_method,
-        bwd_calibration_method=config_fp8_full.quantization_calibration_method,
+        weight_calibration_method=config_fp8_full.weight_quantization_calibration_method,
+        act_calibration_method=config_fp8_full.act_quantization_calibration_method,
+        bwd_calibration_method=config_fp8_full.bwd_quantization_calibration_method,
         op_names=("conv_general_dilated"),
       ),
     ]

From 15b61a876d22a0561fc80ec41fa6ea2ea598bd0b Mon Sep 17 00:00:00 2001
From: susanbao
Date: Wed, 17 Dec 2025 22:35:06 +0000
Subject: [PATCH 3/4] change calibration settings in unit test

---
 src/maxdiffusion/tests/wan_transformer_test.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py
index ef530a39..a5ca3156 100644
--- a/src/maxdiffusion/tests/wan_transformer_test.py
+++ b/src/maxdiffusion/tests/wan_transformer_test.py
@@ -332,8 +332,8 @@ def create_real_rule_instance(*args, **kwargs):
     config_fp8_full = Mock(spec=HyperParameters)
     config_fp8_full.use_qwix_quantization = True
     config_fp8_full.quantization = "fp8_full"
-    config_fp8_full.weight_quantization_calibration_method = "absmax"
-    config_fp8_full.act_quantization_calibration_method = "absmax"
+    config_fp8_full.weight_quantization_calibration_method = "fixed,-224,224"
+    config_fp8_full.act_quantization_calibration_method = "fixed,-224,224"
     config_fp8_full.bwd_quantization_calibration_method = "absmax"
     config_fp8_full.qwix_module_path = ".*"
     provider_fp8_full = WanPipeline.get_qt_provider(config_fp8_full)
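Patch 3 moves the weight and activation methods in the test from the data-derived "absmax" to a pinned range. A toy sketch of the arithmetic a "fixed,<min>,<max>" method implies, under the assumption that it pins the calibration range to the given bounds rather than measuring it per tensor (qwix's actual behavior is defined in the qconfig.py lines linked from the config comment in patch 1):

    import jax.numpy as jnp

    E4M3_MAX = 448.0  # largest finite float8_e4m3fn value

    def fixed_scale(lo: float, hi: float) -> float:
        # A pinned range ignores the tensor's observed values and maps
        # max(|lo|, |hi|) onto the fp8 representable range.
        return max(abs(lo), abs(hi)) / E4M3_MAX

    scale = fixed_scale(-224.0, 224.0)  # "fixed,-224,224" -> 224 / 448 = 0.5
    x = jnp.array([100.0, -3.0, 0.25])
    x_q = (x / scale).astype(jnp.float8_e4m3fn)  # quantize with the fixed scale
    x_dq = x_q.astype(jnp.float32) * scale       # dequantize for comparison

Note that bwd stays "absmax" in the test: gradients keep a data-derived range while the forward tensors use the pinned one.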
From 3668dded83c3da51c9d325a73ba26452fd31b0e8 Mon Sep 17 00:00:00 2001
From: susanbao
Date: Wed, 17 Dec 2025 23:07:31 +0000
Subject: [PATCH 4/4] fix unit test

---
 src/maxdiffusion/tests/wan_transformer_test.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py
index a5ca3156..71b7ce6e 100644
--- a/src/maxdiffusion/tests/wan_transformer_test.py
+++ b/src/maxdiffusion/tests/wan_transformer_test.py
@@ -383,7 +383,9 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
     mock_config.quantization = "fp8_full"
     mock_config.qwix_module_path = ".*"
     mock_config.per_device_batch_size = 1
-    mock_config.quantization_calibration_method = "absmax"
+    mock_config.weight_quantization_calibration_method = "fixed,-224,224"
+    mock_config.act_quantization_calibration_method = "fixed,-224,224"
+    mock_config.bwd_quantization_calibration_method = "absmax"
     mock_model = Mock(spec=WanModel)
     mock_pipeline = Mock()
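After the series, a minimal end-to-end sketch of driving the three keys, following the Mock pattern used in wan_transformer_test.py (the HyperParameters import path is an assumption; get_qt_provider and the attribute names come from the tests above):

    from unittest.mock import Mock

    from maxdiffusion.pyconfig import HyperParameters  # assumed import path
    from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline

    config = Mock(spec=HyperParameters)
    config.use_qwix_quantization = True
    config.quantization = "fp8_full"
    config.qwix_module_path = ".*"
    config.weight_quantization_calibration_method = "fixed,-224,224"
    config.act_quantization_calibration_method = "fixed,-224,224"
    config.bwd_quantization_calibration_method = "absmax"

    provider = WanPipeline.get_qt_provider(config)
    assert provider is not None  # rules were built from the three new keys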