Feat: Speculative Decoding export with quantization support #913
Conversation
📝 Walkthrough

This pull request introduces a class-based refactoring of the speculative decoding export pipeline, replacing legacy procedural functions with `EagleExporter` and `EagleMedusaExporter` classes. A new public `export_speculative_decoding` API exports spec-optimized models independently, with updated key naming schemes, configuration templates, and early-exit integration points in the HF export flow.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Export as Export Flow
    participant Check as has_spec_opt()
    participant Exporter as EagleExporter/<br/>EagleMedusaExporter
    participant SaveState as Save State Dict
    participant SaveConfig as Save Config
    Export->>Check: Check if spec-optimized
    alt Spec-Optimized Model
        Check-->>Export: True
        Export->>Exporter: Create exporter instance
        Exporter->>Exporter: extract_state_dict()
        Exporter-->>SaveState: Filtered state dict
        SaveState->>SaveState: model.safetensors
        Exporter->>Exporter: export_config()
        Exporter-->>SaveConfig: Resolved config
        SaveConfig->>SaveConfig: config.json
    else Standard Model
        Check-->>Export: False
        Export->>Export: Continue standard export
    end
```
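The diagram translates roughly to the following Python sketch. `has_spec_opt`, `EagleExporter`, `EagleMedusaExporter`, `extract_state_dict`, and `export_config` are names from this PR; the save helpers and glue code are illustrative assumptions, not the PR's actual implementation.

```python
import json
import os

from safetensors.torch import save_file


def export_spec_decoding_sketch(full_model, export_dir: str) -> None:
    """Illustrative sketch of the export flow shown in the diagram above."""
    if not has_spec_opt(full_model):
        return  # standard export continues elsewhere

    exporter = EagleExporter(full_model)  # or EagleMedusaExporter

    # Filtered drafter weights -> model.safetensors
    drafter_sd = exporter.extract_state_dict()
    save_file(drafter_sd, os.path.join(export_dir, "model.safetensors"))

    # Resolved drafter config -> config.json
    config = exporter.export_config()
    with open(os.path.join(export_dir, "config.json"), "w") as f:
        json.dump(config, f, indent=2)
```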
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes. 🚥 Pre-merge checks: ✅ 3 passed.
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@           Coverage Diff           @@
##             main     #913   +/-   ##
=======================================
  Coverage   73.10%   73.10%
=======================================
  Files         205      205
  Lines       22281    22281
=======================================
  Hits        16288    16288
  Misses       5993     5993
```

☔ View full report in Codecov by Sentry.
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Force-pushed from d9926e9 to 1b73de3
Actionable comments posted: 4
🧹 Nitpick comments (1)
modelopt/torch/export/plugins/hf_spec_export.py (1)
185-214: Validation bypass is documented as temporary.

The `_check_valid_sd = lambda *args, **kwargs: None` on line 194 effectively disables state dict validation for parallel draft exports. The `NOTE: tmp:` comment indicates this is intentional but temporary. Consider tracking this with a TODO or issue reference to ensure validation is properly implemented for parallel draft exports before the feature is considered stable.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/export/plugins/hf_spec_export.py` around lines 185 - 214, The code currently disables state-dict validation by setting self._check_valid_sd = lambda *args, **kwargs: None in the EagleMedusaExporter __init__, which is marked only as a temporary NOTE; replace this silent bypass with a tracked TODO and a visible reminder: restore validation by implementing proper checks for parallel_draft_step in extract_state_dict and call the original EagleExporter._check_valid_sd (or raise/log a clear warning/error) until full validation is implemented; specifically update the EagleMedusaExporter class to remove the no-op lambda, add a TODO/issue-ID comment referencing the missing validation work, and ensure any call sites (e.g., extract_state_dict) invoke the proper _check_valid_sd behavior so state-dict validation is not permanently skipped.
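A minimal sketch of the kind of fix this nitpick suggests; `EagleExporter`, `EagleMedusaExporter`, and `_check_valid_sd` are names from the PR, while the body, the warning text, and the issue placeholder are illustrative:

```python
import warnings


class EagleMedusaExporter(EagleExporter):
    """Sketch only: replaces the silent no-op lambda with a visible stop-gap."""

    def _check_valid_sd(self, state_dict):
        # TODO(<issue-id>): validate parallel_draft_step keys here, then
        # delegate to EagleExporter._check_valid_sd instead of warning.
        warnings.warn(
            "State-dict validation is not yet implemented for parallel draft "
            "exports; the exported checkpoint is unvalidated.",
            stacklevel=2,
        )
```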
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/llm_ptq/hf_ptq.py`:
- Around line 571-574: The early return after calling has_spec_opt(full_model)
and export_speculative_decoding(full_model, export_dir=export_path) skips the
subsequent tokenizer save and timing/export message; update the
speculative-decoding branch so it either (a) calls the same tokenizer save
routine (e.g., tokenizer.save_pretrained or the existing tokenizer save logic)
and prints the export/timing confirmation before returning, or (b) moves the
return to after those steps, and if skipping is intentional add a concise
comment explaining why; reference has_spec_opt, export_speculative_decoding,
full_model and export_path so the change is applied to the correct branch.
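A minimal sketch of option (a) from this comment, saving the tokenizer and printing the confirmation before returning; `has_spec_opt`, `export_speculative_decoding`, `full_model`, `export_path`, and the tokenizer save are named in the comment, and the wrapper function is a hypothetical frame so the fragment is self-contained:

```python
def run_spec_export_branch(full_model, tokenizer, export_path: str) -> bool:
    """Hypothetical wrapper; returns True when the spec-decoding branch ran."""
    if not has_spec_opt(full_model):
        return False
    export_speculative_decoding(full_model, export_dir=export_path)
    # Save the tokenizer and print the confirmation before returning,
    # instead of skipping them as the current early return does.
    tokenizer.save_pretrained(export_path)
    print(f"Exported speculative decoding checkpoint to {export_path}")
    return True
```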
In `@modelopt/torch/export/plugins/hf_spec_export.py`:
- Around line 180-182: Fix the typo in the docstring of export_quant_config:
change "hf_quant_coinfig.json" to "hf_quant_config.json" in the docstring for
the function export_quant_config which returns copy(self.hf_quant_config).
- Around line 144-178: In export_config, using copy(template_config) creates
only a shallow copy so nested dicts (e.g., eagle config data) are mutated on
assignment; replace the shallow copy with a deep copy (use copy.deepcopy) when
copying the selected template (referencing template_config,
llama_eagle_template_config, kimik2_eagle_template_config in the export_config
method) so modifications to nested keys do not alter the original imported
templates across multiple calls; ensure the copy module's deepcopy is
imported/used accordingly.
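A minimal sketch of the copy semantics at issue; `llama_eagle_template_config` is one of the templates named above, and the keys and values are purely illustrative:

```python
import copy

llama_eagle_template_config = {"eagle_config": {"num_layers": 1}}  # illustrative

shallow = copy.copy(llama_eagle_template_config)
shallow["eagle_config"]["num_layers"] = 4
# The nested dict is shared, so the module-level template is now mutated:
assert llama_eagle_template_config["eagle_config"]["num_layers"] == 4

llama_eagle_template_config["eagle_config"]["num_layers"] = 1  # reset
deep = copy.deepcopy(llama_eagle_template_config)
deep["eagle_config"]["num_layers"] = 4
# The template is untouched; per-export changes stay local:
assert llama_eagle_template_config["eagle_config"]["num_layers"] == 1
```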
In `@modelopt/torch/export/unified_export_hf.py`:
- Around line 994-996: The comment above the state-dict export is incorrect:
change the misleading "Export config.json" comment that precedes the lines using
exporter.extract_state_dict(), drafter_sd, and save_file(...,
"model.safetensors") to accurately describe exporting the model state dict
(e.g., "Export model state dict to model.safetensors"), leaving the actual
config.json export block (using save_file for config.json) unchanged.
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
What does this PR do?
Type of change: ?
Overview:
Main changes:
- Refactored speculative decoding export logic into class `EagleExporter` to improve cohesion.
- Separated the speculative decoding export entry point from the quantization export (`export_hf_checkpoint()`) due to their fundamental differences.

Usage
To export a regular bf16 eagle checkpoint without quantization, the commands are the same:
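(presumably the export invocation listed under Testing below:)

```sh
python scripts/export_hf_checkpoint.py --model_path <x> --export_path <x>
```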
To run PTQ on an online-trained eagle checkpoint and export it:
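(presumably the PTQ invocation listed under Testing below:)

```sh
python hf_ptq.py --pyt_ckpt_path <x> --qformat fp8 --export_path <x>
```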
The above two commands will produce a drafter checkpoint for deployment, in the same format.
Testing
Tested setting:
```sh
python scripts/export_hf_checkpoint.py --model_path <x> --export_path <x>
python hf_ptq.py --pyt_ckpt_path <x> --qformat fp8 --export_path <x>
```

Before your PR is "Ready for review"
Additional Information