@shuningjin shuningjin commented Dec 11, 2025

Description

Previously, we added gpt-oss, orbax(scan) -> hf: #2647

Fix: b/459541579

  • gpt-oss, hf -> orbax(scan)
  • gpt-oss, hf -> orbax(unscan)
  • gpt-oss, orbax(unscan) -> hf

Fix: b/452392132

  • implement weight splitting for hf -> orbax (i.e., one hf key to many maxtext keys)

Fix: b/452391921

  • verify interleaved scan pattern for hf->orbax

What this does

to_maxtext.py

  • allow mapping one hf key to many mt keys (see the sketch after this list)
    • assumes the mt keys share the same shape; the hook function returns the per-key tensors stacked along the last dim
    • accommodate lazy tensors: unoptimized, the hf tensor is re-loaded once per mt key
  • allow loading a local hf checkpoint
    • the remote hf checkpoint is quantized, but we use a local de-quantized copy (usually bf16); this applies to both gpt-oss (mxfp4) and deepseek (fp8)
    • accommodate lazy tensors
  • improve the heuristic for single-axis stacking
  • (other: factor out get_maxtext_dict, add timing logs)
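
To make the one-to-many contract concrete, here is a minimal sketch of the splitting path under the assumptions above; assign_many, load_hf_tensor, and split_hook are hypothetical names for illustration, not the actual to_maxtext.py API:

import numpy as np

def assign_many(mt_keys, load_hf_tensor, hook, lazy=False):
  """All mt keys share one shape; the hook returns them stacked on the last dim."""
  out = {}
  if lazy:
    # Lazy mode: the hf tensor is not kept in memory, so it is re-loaded
    # once per target key (correct but unoptimized, as noted above).
    for i, key in enumerate(mt_keys):
      out[key] = hook(load_hf_tensor())[..., i]
  else:
    # Eager mode: load once, then unstack along the last dim.
    stacked = hook(load_hf_tensor())
    for i, key in enumerate(mt_keys):
      out[key] = stacked[..., i]
  return out

# Toy usage: a fused (4, 6) hf weight split into two (4, 3) mt weights.
load = lambda: np.arange(24, dtype=np.float32).reshape(4, 6)
split_hook = lambda t: np.stack([t[:, :3], t[:, 3:]], axis=-1)
out = assign_many(["wi_0", "wi_1"], load, split_hook, lazy=True)
assert out["wi_0"].shape == (4, 3) and out["wi_1"].shape == (4, 3)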

param_mapping.py, gpt-oss

  • implement an interleave function for the one-hf-to-many-mt mapping (sketched below)
  • add unscanned versions of the mapping and hook functions
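
A hedged sketch of the interleave pattern, assuming the fused gpt-oss gate_up_proj stores gate and up columns alternately along the last dim; the function names here are illustrative, not the actual param_mapping.py hooks:

import numpy as np

def deinterleave_gate_up(gate_up):
  """hf -> mt: split the interleaved [..., 2*ffn] tensor into gate and up,
  stacked along a new last dim per the one-hf-to-many-mt contract."""
  return np.stack([gate_up[..., 0::2], gate_up[..., 1::2]], axis=-1)

def interleave_gate_up(gate, up):
  """mt -> hf: the reverse direction re-interleaves the two halves."""
  out = np.empty(gate.shape[:-1] + (2 * gate.shape[-1],), dtype=gate.dtype)
  out[..., 0::2] = gate
  out[..., 1::2] = up
  return out

# Round-trip check on a toy (experts, 2*ffn) tensor.
x = np.random.rand(2, 8).astype(np.float32)
s = deinterleave_gate_up(x)
assert np.array_equal(interleave_gate_up(s[..., 0], s[..., 1]), x)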

to_huggingface.py

  • move _check_param_map_keys to utils.py so it can be reused by to_maxtext.py (see the sketch below)
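
A rough sketch of what the shared check can look like once it lives in utils.py; the real _check_param_map_keys may differ in signature and error reporting:

def check_param_map_keys(param_map, expected_keys):
  """Fail fast if the mapping does not cover the expected keys exactly,
  so both to_huggingface.py and to_maxtext.py catch stale maps early."""
  mapped, expected = set(param_map), set(expected_keys)
  missing, extra = expected - mapped, mapped - expected
  if missing or extra:
    raise ValueError(f"param map mismatch: missing={sorted(missing)}, extra={sorted(extra)}")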

Tests

1 HF -> orbax (gpt-oss-20b)

Since we made non-trivial changes to the lazy tensor implementation, we also test lazy mode.

HF -> orbax (scan), cpu

BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
echo $BASE_OUTPUT_PATH/0/items; \
python -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml model_name=gpt-oss-20b scan_layers=true \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True \
attention=dot_product \
--hf_model_path=/home/shuningjin/gpt-oss-20b/gpt-oss-20b-bf16-v2

https://paste.googleplex.com/4888544332087296

CKPT=gs://runner-maxtext-logs/2025-12-26-21-51
python3 -m tests.forward_pass_logit_checker src/MaxText/configs/base.yml \
model_name=gpt-oss-20b \
load_parameters_path=$CKPT/0/items \
scan_layers=true \
per_device_batch_size=1 max_prefill_predict_length=4 max_target_length=4 \
attention=dot_product sparse_matmul=false \
--max_kl_div=0.015 --atol=0.5 --rtol=0.5 \
--run_hf_model=True \
--hf_model_path=/home/shuningjin/gpt-oss-20b/gpt-oss-20b-bf16-v2 \
tokenizer_path=openai/gpt-oss-20b tokenizer_type=huggingface \
skip_jax_distributed_system=True

https://paste.googleplex.com/5272274628378624

HF -> orbax (scan), cpu, lazy load

BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
echo $BASE_OUTPUT_PATH/0/items; \
python -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml model_name=gpt-oss-20b scan_layers=true \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True \
attention=dot_product \
--hf_model_path=/home/shuningjin/gpt-oss-20b/gpt-oss-20b-bf16-v2 \
--lazy_load_tensors=true

https://paste.googleplex.com/6192468888518656

CKPT=gs://runner-maxtext-logs/2025-12-26-21-58
python3 -m tests.forward_pass_logit_checker src/MaxText/configs/base.yml \
model_name=gpt-oss-20b \
load_parameters_path=$CKPT/0/items \
scan_layers=true \
per_device_batch_size=1 max_prefill_predict_length=4 max_target_length=4 \
attention=dot_product sparse_matmul=false \
--max_kl_div=0.015 --atol=0.5 --rtol=0.5 \
--run_hf_model=True \
--hf_model_path=/home/shuningjin/gpt-oss-20b/gpt-oss-20b-bf16-v2 \
tokenizer_path=openai/gpt-oss-20b tokenizer_type=huggingface \
skip_jax_distributed_system=True

https://paste.googleplex.com/5022180226236416

HF -> orbax (unscan), cpu

BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
echo $BASE_OUTPUT_PATH/0/items; \
python -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml model_name=gpt-oss-20b scan_layers=false \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True \
attention=dot_product \
--hf_model_path=/home/shuningjin/gpt-oss-20b/gpt-oss-20b-bf16-v2

https://paste.googleplex.com/6000687559344128

CKPT=gs://runner-maxtext-logs/2025-12-26-22-20
python3 -m tests.forward_pass_logit_checker src/MaxText/configs/base.yml \
model_name=gpt-oss-20b \
load_parameters_path=$CKPT/0/items \
scan_layers=false \
per_device_batch_size=1 max_prefill_predict_length=4 max_target_length=4 \
attention=dot_product sparse_matmul=false \
--max_kl_div=0.015 --atol=0.5 --rtol=0.5 \
--run_hf_model=True \
--hf_model_path=/home/shuningjin/gpt-oss-20b/gpt-oss-20b-bf16-v2 \
tokenizer_path=openai/gpt-oss-20b tokenizer_type=huggingface \
skip_jax_distributed_system=True

https://paste.googleplex.com/6128993298415616

2 orbax -> HF (gpt-oss-20b)

orbax -> HF (unscan), cpu

ID=$(date +%Y-%m-%d-%H-%M-%S); \
python3 -m MaxText.utils.ckpt_conversion.to_huggingface src/MaxText/configs/base.yml \
model_name=gpt-oss-20b \
load_parameters_path=gs://shuningjin-multipod-dev/gpt-oss-20b/unscan-bf16-v2-2025-09-02-01-16-00/0/items \
base_output_directory=/home/shuningjin/gpt-oss-20b/gpt-oss-20b-hf-$ID \
scan_layers=false \
attention=dot_product skip_jax_distributed_system=True \
weight_dtype=bfloat16 checkpoint_storage_concurrent_gb=1024

https://paste.googleplex.com/5483624130543616

HF_PATH=/home/shuningjin/gpt-oss-20b/gpt-oss-20b-hf-2025-12-26-22-56-40
python3 -m tests.forward_pass_logit_checker src/MaxText/configs/base.yml \
model_name=gpt-oss-20b \
load_parameters_path=gs://shuningjin-multipod-dev/gpt-oss-20b/unscan-bf16-v2-2025-09-02-01-16-00/0/items \
scan_layers=false \
per_device_batch_size=1 max_prefill_predict_length=4 max_target_length=4 \
attention=dot_product sparse_matmul=false \
--max_kl_div=0.015 --atol=0.5 --rtol=0.5 \
--run_hf_model=True \
--hf_model_path=$HF_PATH \
tokenizer_path=openai/gpt-oss-20b tokenizer_type=huggingface \
skip_jax_distributed_system=True

https://paste.googleplex.com/4854884102963200

3 HF -> orbax (sanity check on other models)

qwen3-4b

BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
echo $BASE_OUTPUT_PATH/0/items; \
python -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml model_name=qwen3-4b scan_layers=false \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True

https://paste.googleplex.com/5401590255190016

CKPT=gs://runner-maxtext-logs/2025-12-26-23-25
python3 -m tests.forward_pass_logit_checker src/MaxText/configs/base.yml \
model_name=qwen3-4b attention=dot_product \
override_model_config=true enable_dropout=false tokenizer_type=huggingface \
load_parameters_path=$CKPT/0/items scan_layers=false \
per_device_batch_size=1 max_prefill_predict_length=4 max_target_length=8 \
tokenizer_path=Qwen/Qwen3-4B --run_hf_model=True --hf_model_path=Qwen/Qwen3-4B \
--max_kl_div=0.015 --atol=0.5 --rtol=0.5 \
skip_jax_distributed_system=True

https://paste.googleplex.com/6538164284030976

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@shuningjin shuningjin changed the title [WIP] checkpoint util: gpt-oss, orbax -> hf [WIP] checkpoint util: gpt-oss, hf -> orbax Dec 11, 2025

@shuningjin shuningjin changed the title [WIP] checkpoint util: gpt-oss, hf -> orbax Checkpoint utility: gpt-oss, hf to orbax Dec 26, 2025

@github-actions github-actions bot left a comment

📋 Review Summary

This pull request introduces a significant and valuable set of enhancements to the checkpoint conversion utilities. The refactoring improves code modularity by moving shared functions to a central utils.py, and the new functionality, such as support for local Hugging Face models and unscanned GPT-OSS models, greatly increases the flexibility of these tools.

🔍 General Feedback

  • Positive: The addition of more granular timing logs is a great improvement for performance analysis and debugging. The updated documentation in the README provides a much clearer overview of supported models and conversion paths.
  • Good Refactoring: Moving check_param_map_keys to utils.py and introducing get_maxtext_model_info in to_maxtext.py are excellent changes that improve code organization and reusability.

I have left a couple of minor comments, one regarding a bug in the timing calculation and another for a small docstring clarification. Overall, this is a solid contribution.
