diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml
index 8cd7e70f..a11ce9b5 100644
--- a/src/maxdiffusion/configs/base_wan_14b.yml
+++ b/src/maxdiffusion/configs/base_wan_14b.yml
@@ -319,8 +319,10 @@ quantization: ''
 quantization_local_shard_count: -1
 compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
 use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the transformer of WAN will be quantized using qwix.
-# Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
-quantization_calibration_method: "absmax"
+# Quantization calibration methods used for weights, activations, and the backward pass (bwd). Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
+weight_quantization_calibration_method: "absmax"
+act_quantization_calibration_method: "absmax"
+bwd_quantization_calibration_method: "absmax"
 qwix_module_path: ".*"

 # Eval model on per eval_every steps. -1 means don't eval.
diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
index 153c225d..d00c0116 100644
--- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py
+++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -302,9 +302,9 @@ def get_fp8_config(cls, config: HyperParameters):
           act_qtype=jnp.float8_e4m3fn,
           bwd_qtype=jnp.float8_e5m2,
           disable_channelwise_axes=True, # per_tensor calibration
-          weight_calibration_method=config.quantization_calibration_method,
-          act_calibration_method=config.quantization_calibration_method,
-          bwd_calibration_method=config.quantization_calibration_method,
+          weight_calibration_method=config.weight_quantization_calibration_method,
+          act_calibration_method=config.act_quantization_calibration_method,
+          bwd_calibration_method=config.bwd_quantization_calibration_method,
           op_names=("dot_general", "einsum"),
       ),
       qwix.QtRule(
@@ -313,9 +313,9 @@
           act_qtype=jnp.float8_e4m3fn,
           bwd_qtype=jnp.float8_e4m3fn,
           disable_channelwise_axes=True, # per_tensor calibration
-          weight_calibration_method=config.quantization_calibration_method,
-          act_calibration_method=config.quantization_calibration_method,
-          bwd_calibration_method=config.quantization_calibration_method,
+          weight_calibration_method=config.weight_quantization_calibration_method,
+          act_calibration_method=config.act_quantization_calibration_method,
+          bwd_calibration_method=config.bwd_quantization_calibration_method,
           op_names=("conv_general_dilated"),
       ),
   ]
diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py
index 34f0ef64..71b7ce6e 100644
--- a/src/maxdiffusion/tests/wan_transformer_test.py
+++ b/src/maxdiffusion/tests/wan_transformer_test.py
@@ -332,7 +332,9 @@ def create_real_rule_instance(*args, **kwargs):
     config_fp8_full = Mock(spec=HyperParameters)
     config_fp8_full.use_qwix_quantization = True
     config_fp8_full.quantization = "fp8_full"
-    config_fp8_full.quantization_calibration_method = "absmax"
+    config_fp8_full.weight_quantization_calibration_method = "fixed,-224,224"
+    config_fp8_full.act_quantization_calibration_method = "fixed,-224,224"
+    config_fp8_full.bwd_quantization_calibration_method = "absmax"
     config_fp8_full.qwix_module_path = ".*"
     provider_fp8_full = WanPipeline.get_qt_provider(config_fp8_full)
     self.assertIsNotNone(provider_fp8_full)
@@ -343,9 +345,9 @@
             act_qtype=jnp.float8_e4m3fn,
             bwd_qtype=jnp.float8_e5m2,
             disable_channelwise_axes=True, # per_tensor calibration
-            weight_calibration_method=config_fp8_full.quantization_calibration_method,
-            act_calibration_method=config_fp8_full.quantization_calibration_method,
-            bwd_calibration_method=config_fp8_full.quantization_calibration_method,
+            weight_calibration_method=config_fp8_full.weight_quantization_calibration_method,
+            act_calibration_method=config_fp8_full.act_quantization_calibration_method,
+            bwd_calibration_method=config_fp8_full.bwd_quantization_calibration_method,
             op_names=("dot_general", "einsum"),
         ),
         call(
@@ -354,9 +356,9 @@
             act_qtype=jnp.float8_e4m3fn,
             bwd_qtype=jnp.float8_e4m3fn,
             disable_channelwise_axes=True, # per_tensor calibration
-            weight_calibration_method=config_fp8_full.quantization_calibration_method,
-            act_calibration_method=config_fp8_full.quantization_calibration_method,
-            bwd_calibration_method=config_fp8_full.quantization_calibration_method,
+            weight_calibration_method=config_fp8_full.weight_quantization_calibration_method,
+            act_calibration_method=config_fp8_full.act_quantization_calibration_method,
+            bwd_calibration_method=config_fp8_full.bwd_quantization_calibration_method,
             op_names=("conv_general_dilated"),
         ),
     ]
@@ -381,7 +383,9 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
     mock_config.quantization = "fp8_full"
     mock_config.qwix_module_path = ".*"
     mock_config.per_device_batch_size = 1
-    mock_config.quantization_calibration_method = "absmax"
+    mock_config.weight_quantization_calibration_method = "fixed,-224,224"
+    mock_config.act_quantization_calibration_method = "fixed,-224,224"
+    mock_config.bwd_quantization_calibration_method = "absmax"
     mock_model = Mock(spec=WanModel)
     mock_pipeline = Mock()
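Reviewer note: below is a minimal standalone sketch (not part of the patch) of how the three new YAML keys flow into a qwix rule. It reuses only the qwix.QtRule kwargs visible in get_fp8_config above; the SimpleNamespace stand-in for HyperParameters is illustrative, and the module_path and weight_qtype arguments sit above the diff context, so their values here are assumptions.

# Hypothetical sketch: wires the three new config keys into the same
# qwix.QtRule kwargs that get_fp8_config uses in the diff above.
from types import SimpleNamespace

import jax.numpy as jnp
import qwix

# Stand-in for HyperParameters, mirroring the keys added in base_wan_14b.yml
# with the values the updated tests exercise.
config = SimpleNamespace(
    qwix_module_path=".*",
    weight_quantization_calibration_method="fixed,-224,224",  # "fixed,min,max" form
    act_quantization_calibration_method="fixed,-224,224",
    bwd_quantization_calibration_method="absmax",
)

# Same shape as the first rule in get_fp8_config: fp8 e4m3 weights/activations,
# e5m2 gradients, per-tensor calibration, applied to matmul-like ops. The
# module_path and weight_qtype values are assumed, not shown in the hunk.
matmul_rule = qwix.QtRule(
    module_path=config.qwix_module_path,
    weight_qtype=jnp.float8_e4m3fn,
    act_qtype=jnp.float8_e4m3fn,
    bwd_qtype=jnp.float8_e5m2,
    disable_channelwise_axes=True,  # per_tensor calibration
    weight_calibration_method=config.weight_quantization_calibration_method,
    act_calibration_method=config.act_quantization_calibration_method,
    bwd_calibration_method=config.bwd_quantization_calibration_method,
    op_names=("dot_general", "einsum"),
)

Splitting the single quantization_calibration_method key is exactly what the updated tests check: weights and activations can pin a fixed calibration range ("fixed,-224,224") while the backward pass keeps dynamic "absmax" calibration, a combination that was impossible when all three shared one key.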