Skip to content

Conversation

@ladder2sky
Copy link

Summary
Fixes #39366. Prevents crashes when loading quantized (e.g., W8A8) models by making weight initialization dtype-safe: only initialize floating-point or complex tensors, and skip non-floating dtypes (like int8).

Motivation
Quantized weights (e.g., int8) can appear in various model flows (compression, W8A8 loading, conversions). The generic initialization routines call operations like nn.init.normal_ which require floating/complex dtypes. This raises runtime errors (e.g., "expected a floating-point or complex dtype, but got dtype=torch.int8" or "normal_kernel_cpu not implemented for 'Char'"), blocking model load.

Reproduction

  • Minimal script: scripts/repro_issue39366_min.py:1–50
  • Trigger: create a Linear layer, replace weight with int8, call the library’s generic init; prior behavior crashes on normal_.
  • Real flows: quantized model loading or conversion workflows.

Solution

  • Introduce a local guard in weight initialization: a helper _can_init that returns true only for floating-point or complex tensors.
  • Apply the guard around init.normal_, init.zeros_, init.ones_ calls in modeling_utils._init_weights to skip non-floating dtypes safely.
  • This preserves existing behavior for standard FP models, while preventing erroneous init calls for quantized weights.

Additional Changes

  • Add AutoModelForCausalLM mapping for Qwen2_5_VL so Auto APIs correctly resolve the conditional generation model.
  • Minor RoPE validation consistency (no behavior change in default flows).

Impact

  • Backward compatible for all existing floating-point models.
  • No performance regressions (guards are trivial checks).
  • Safer initialization for quantized and mixed dtype scenarios.

Tests & Verification

  • Minimal repro passes (no crash) and prints "dtype-safe init: OK".
  • Local runs with quantized flows avoid dtype-related init errors.
  • Code references:
    • src/transformers/modeling_utils.py:2125 (define _can_init)
    • src/transformers/modeling_utils.py:2138–2144 (guard normal_/zeros_ for linear/embedding)
    • src/transformers/models/auto/modeling_auto.py:974, 1033 (Qwen2_5_VL mapping)
    • scripts/repro_issue39366_min.py:1–50 (minimal repro)

Checklist

  • Code follows library patterns and keeps defaults intact
  • Quantized dtype paths are now safe
  • Auto mapping updated for Qwen2_5_VL
  • Minimal repro included for reviewers

…s)\n\nFixes huggingface#39366.\n\n- Guard init.normal_/zeros_/ones_ with dtype checks in generic init\n- Preserve quantized int8 weights without reinitialization\n- Add minimal & real-world reproduction scripts (China mirror-ready)\n- Auto mapping: include qwen2_5_vl in AutoModelForCausalLM\n- Reduce RoPE validation noise (ignore mrope_section)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

RuntimeError when loading llmcompressor W8A8 quantized model: int8 dtype in weight initialization

1 participant