Fix: dtype-safe weight initialization for quantized models (skip non-floating tensors) #42857
Summary
Fixes #39366. Prevents crashes when loading quantized (e.g., W8A8) models by making weight initialization dtype-safe: only initialize floating-point or complex tensors, and skip non-floating dtypes (like int8).
Motivation
Quantized weights (e.g., int8) can appear in several model flows (compression, W8A8 loading, conversions). The generic initialization routines call operations such as nn.init.normal_, which require floating-point or complex dtypes. On integer tensors this raises runtime errors (e.g., "expected a floating-point or complex dtype, but got dtype=torch.int8" or "normal_kernel_cpu not implemented for 'Char'"), which blocks model loading.
Reproduction
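A minimal, self-contained sketch of the failure mode described above (not the exact reproduction from #39366): a module whose weight is stored as int8, as in a W8A8 checkpoint, hits a generic float initializer.

```python
import torch
import torch.nn as nn

# Mimic a W8A8-style module whose weight has been replaced by int8 values.
layer = nn.Linear(4, 4)
layer.weight = nn.Parameter(
    torch.zeros(4, 4, dtype=torch.int8), requires_grad=False
)

# Generic weight initialization calls nn.init.normal_ unconditionally,
# which fails on integer dtypes:
#   RuntimeError: "normal_kernel_cpu" not implemented for 'Char'
nn.init.normal_(layer.weight, mean=0.0, std=0.02)
```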
Solution
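The core of the fix, sketched below as a hypothetical standalone helper (the actual change lives in the existing initialization path): run the initializer only for floating-point or complex tensors and leave non-floating tensors untouched.

```python
import torch
import torch.nn as nn

def init_weight_dtype_safe(
    tensor: torch.Tensor, mean: float = 0.0, std: float = 0.02
) -> None:
    """Initialize only floating-point/complex tensors; skip others (e.g. int8).

    `init_weight_dtype_safe` is a hypothetical helper illustrating the guard,
    not the exact function touched by this PR.
    """
    if tensor.dtype.is_floating_point or tensor.dtype.is_complex:
        nn.init.normal_(tensor, mean=mean, std=std)
    # Non-floating dtypes (e.g. quantized int8 weights) are skipped, so model
    # loading no longer crashes and the quantized values are preserved.
```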
Additional Changes
Impact
Tests & Verification
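A rough illustration of the behavior being verified, assuming the hypothetical `init_weight_dtype_safe` helper from the Solution sketch is in scope: floating-point tensors are still initialized, while int8 tensors are left unchanged instead of raising.

```python
import torch

# Assumes init_weight_dtype_safe from the Solution sketch above is importable.
def test_dtype_safe_init():
    # Floating-point weights are still initialized (values move away from zero).
    w_fp = torch.zeros(8, 8)
    init_weight_dtype_safe(w_fp)
    assert w_fp.abs().sum() > 0

    # Int8 weights are skipped and preserved instead of raising a RuntimeError.
    w_int8 = torch.arange(16, dtype=torch.int8).reshape(4, 4)
    before = w_int8.clone()
    init_weight_dtype_safe(w_int8)
    assert torch.equal(w_int8, before)
```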
Checklist