2 changes: 1 addition & 1 deletion end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh
@@ -10,7 +10,7 @@
# Example Usage: export HF_TOKEN=<huggingface_access_token>; export BASE_OUTPUT_PATH=<GCS_bucket_path>; bash test_deepseek.sh

# The golden logit can be generated by:
-# python3 -m MaxText.scratch_code.generate_hf_golden_logits --model-id=deepseek-ai/DeepSeek-V2-Lite --output-path=golden_data_deepseek2-16b.jsonl --prompts='I love to;Today is a;What is the' --hf-model-path=$local_bf16_path --not-trust-remote-code
+# python3 -m MaxText.scratch_code.generate_hf_golden_logits --model-id=deepseek-ai/DeepSeek-V2-Lite --output-path=golden_data_deepseek2-16b.jsonl --prompts='I love to;Today is a;What is the' --hf-model-path=$local_bf16_path --trust-remote-code=False

set -ex

2 changes: 1 addition & 1 deletion end_to_end/tpu/deepseek/v3-671b/2_test_deepseek.sh
@@ -8,7 +8,7 @@
# 2. Run logit check, pre-training, fine-tuning, and decoding.

# The golden logit can be generated by:
-# python3 -m MaxText.scratch_code.generate_hf_golden_logits --model-id=deepseek-ai/DeepSeek-V3 --output-path=golden_data_deepseek3-671b.jsonl --prompts='I love to' --hf-model-path=$local_bf16_path --not-trust-remote-code --hf-load-dtype=bfloat16
+# python3 -m MaxText.scratch_code.generate_hf_golden_logits --model-id=deepseek-ai/DeepSeek-V3 --output-path=golden_data_deepseek3-671b.jsonl --prompts='I love to' --hf-model-path=$local_bf16_path --trust-remote-code=False --hf-load-dtype=bfloat16

set -ex

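In both scripts the flag is forwarded to Hugging Face's `from_pretrained` (see the argparse help text in the next file). A minimal sketch of what `--trust-remote-code=False` means downstream, with a placeholder path standing in for `$local_bf16_path`:

```python
from transformers import AutoModelForCausalLM

# With trust_remote_code=False, transformers will not execute custom modeling
# code shipped in a Hub repo; the checkpoint must load through built-in model
# classes (here, a local, already-dequantized bf16 copy).
model = AutoModelForCausalLM.from_pretrained(
    "/path/to/local_bf16_checkpoint",  # placeholder for $local_bf16_path
    trust_remote_code=False,
)
```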
9 changes: 5 additions & 4 deletions src/MaxText/scratch_code/generate_hf_golden_logits.py
@@ -47,6 +47,7 @@
import numpy as np
from google.cloud import storage
from PIL import Image
+from MaxText.inference_utils import str2bool

# Load the tokenizer and model from Hugging Face

@@ -184,11 +185,11 @@ def main(raw_args=None) -> None:
      default="float32",
      help="model_class.from_pretrained: dtype",
  )
-  # variable `args.trust_remote_code` is True by default, False only if with flag `--not-trust-remote-code`
  parser.add_argument(
-      "--not-trust-remote-code",
-      dest="trust_remote_code",
-      action="store_false",
+      "--trust-remote-code",
+      type=str2bool,
+      required=False,
+      default=True,
      help="model_class.from_pretrained: trust_remote_code",
  )
  parser.add_argument(
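`str2bool` is imported from `MaxText.inference_utils`, whose implementation is not part of this diff. A typical argparse-compatible helper of this kind looks roughly like the following sketch (assumed, not the MaxText source):

```python
import argparse

def str2bool(v) -> bool:
  """Accepts common boolean spellings so `--trust-remote-code=False` parses."""
  if isinstance(v, bool):
    return v
  if v.lower() in ("yes", "true", "t", "1"):
    return True
  if v.lower() in ("no", "false", "f", "0"):
    return False
  raise argparse.ArgumentTypeError(f"expected a boolean value, got {v!r}")
```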
24 changes: 17 additions & 7 deletions src/MaxText/utils/ckpt_conversion/README.md
@@ -6,10 +6,17 @@ This guide provides instructions for using the scripts that convert model checkpoints

The following models are supported:

-- Gemma2 (2B, 9B, 27B).
-- Gemma3 multimodal (4B, 12B, 27B).
-- Qwen3 (0.6B, 4B, 8B, 14B, 32B).
-- Mixtral (8x7B, 8x22B).
+| Model Family | Sizes | HF $\to$ Orbax (scan) | HF $\to$ Orbax (unscan) | Orbax (scan) $\to$ HF | Orbax (unscan) $\to$ HF |
+| :--- | :--- | :---: | :---: | :---: | :---: |
+| **Gemma2** | 2B, 9B, 27B | √ | √ | √ | √ |
+| **Gemma3** (Multimodal) | 4B, 12B, 27B | - | √ | - | √ |
+| **Llama3.1** | 8B, 70B, 405B | √ | √ | √ | √ |
+| **Qwen3** | 0.6B, 4B, 8B, 14B, 32B | √ | √ | √ | √ |
+| **Qwen3 MoE** | 30B, 235B, 480B | √ | √ | √ | √ |
+| **Mixtral** | 8x7B, 8x22B | √ | √ | √ | √ |
+| **GPT-OSS** | 20B, 120B | √ | √ | √ | √ |
+| **DeepSeek3** | 671B | - | - | √ | - |


## Prerequisites
- Hugging Face requires PyTorch.
@@ -42,8 +49,9 @@ python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml
* `use_multimodal`: Indicates if multimodality is used, important for Gemma3.
* `hf_access_token`: Your Hugging Face token.
* `base_output_directory`: The path where the converted Orbax checkpoint will be stored; it can be Google Cloud Storage (GCS) or local. If not set, the default output directory is `Maxtext/tmp`.
* `--lazy_load_tensors` (optional): If `true`, loads Hugging Face weights on demand to minimize RAM usage (see the sketch below).
* `--hf_model_path` (optional): Specifies a local directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/2f77e7b5fcc4b580bc2d109525c362f3d9056ec9/src/MaxText/utils/ckpt_conversion/utils/utils.py#L54-L82) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek.

**Note:** Only the official Hugging Face model versions are converted. The supported versions are listed in `HF_IDS` in `src/MaxText/utils/ckpt_conversion/utils/utils.py`.
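As background for `--lazy_load_tensors`: the README only states that weights are loaded on demand. One common way to get that behavior with Hugging Face checkpoints is safetensors' lazy file handle; a sketch of the general pattern (not necessarily MaxText's actual implementation):

```python
import torch  # framework="pt" returns torch tensors
from safetensors import safe_open

# Only the requested tensor is read from disk, so peak RAM stays near the
# size of a single tensor rather than the whole shard.
with safe_open("model-00001-of-000XX.safetensors", framework="pt") as f:  # placeholder shard name
  for name in f.keys():
    tensor = f.get_tensor(name)  # materialized on demand
    # ...convert `tensor`, write it out, then drop the reference...
```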

## MaxText to Hugging Face

@@ -62,6 +70,7 @@ python3 -m MaxText.utils.ckpt_conversion.to_huggingface src/MaxText/configs/base
scan_layers=false \
use_multimodal=false \
hf_access_token=<your-hf-token> \
+weight_dtype=bfloat16
```

**Key arguments:**
@@ -72,6 +81,7 @@ python3 -m MaxText.utils.ckpt_conversion.to_huggingface src/MaxText/configs/base
* `hf_access_token`: Your Hugging Face token.
* `use_multimodal`: Indicates if multimodality is used, important for Gemma3.
* `base_output_directory`: The path where the converted Hugging Face checkpoint will be stored; it can be Google Cloud Storage (GCS), the Hugging Face Hub, or local. If not set, the default output directory is `Maxtext/tmp`.
+* `weight_dtype`: The dtype used to load the MaxText weights, which determines the dtype of the resulting Hugging Face weights. Defaults to `float32`; we recommend `bfloat16` to save memory and speed up conversion.


## Verifying conversion correctness
@@ -87,11 +97,11 @@ python3 -m tests.forward_pass_logit_checker src/MaxText/configs/base.yml \
model_name=<MODEL_NAME> \
scan_layers=false \
max_prefill_predict_length=4 \
-max_target_length=8 \
+max_target_length=8 \
use_multimodal=false \
--run_hf_model=True \
--hf_model_path=<path-to-HF-checkpoint> \
---max_kl_div=0.015 \
+--max_kl_div=0.015
```
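Here, `--max_kl_div=0.015` bounds the divergence between the two models' next-token distributions. A minimal sketch of that kind of check (assumed tensor shapes and KL direction; not the actual `forward_pass_logit_checker` logic):

```python
import jax
import jax.numpy as jnp

def max_kl_divergence(hf_logits, mt_logits):
  """Largest per-position KL(HF || MaxText) over [seq_len, vocab] logits."""
  log_p = jax.nn.log_softmax(hf_logits, axis=-1)  # reference distribution
  log_q = jax.nn.log_softmax(mt_logits, axis=-1)  # distribution under test
  kl = jnp.sum(jnp.exp(log_p) * (log_p - log_q), axis=-1)
  return jnp.max(kl)

# Conversion is accepted when max_kl_divergence(...) < 0.015.
```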

**Key arguments:**
71 changes: 12 additions & 59 deletions src/MaxText/utils/ckpt_conversion/to_huggingface.py
@@ -72,8 +72,12 @@
)
from MaxText.utils.ckpt_conversion.utils.hf_shape import HF_SHAPE
from MaxText.utils.ckpt_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS
-from MaxText.utils.ckpt_conversion.utils.utils import process_maxtext_param, save_model_files, HF_IDS

+from MaxText.utils.ckpt_conversion.utils.utils import (
+    validate_and_filter_param_map_keys,
+    process_maxtext_param,
+    save_model_files,
+    HF_IDS,
+)

os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16"
@@ -107,59 +111,6 @@ def _get_model_mappings(
  }


-def _check_param_map_keys(param_map_keys, maxtext_state_keys):
-  """Validates map coverage, handles N-to-1 mappings, and filters unused keys.
-
-  Ensures every MaxText checkpoint key (`maxtext_state_keys`) is covered by
-  the flattened parameter map. Keys in the map that are not present in the
-  checkpoint (common for multi-variant maps like gemma3, qwen3, deepseek) are skipped.
-
-  Tuple keys represent N-to-1 mappings (multiple MaxText keys combining into one
-  target key) and are only returned if all constituent keys exist in the checkpoint.
-
-  Args:
-    param_map_keys: Keys from the parameter mapping (strings or N-to-1 tuples).
-    maxtext_state_keys: Set of parameter keys loaded from the MaxText checkpoint.
-
-  Returns:
-    A list of 'filtered' mapping keys (strings or tuples) that are fully present
-    and valid based on `maxtext_state_keys`.
-
-  Raises:
-    ValueError: If `maxtext_state_keys` is NOT a subset of the flattened
-      `param_map_keys`.
-  """
-  flattened_map_keys = set()
-  for key in param_map_keys:
-    if isinstance(key, tuple):
-      flattened_map_keys.update(key)
-    else:
-      flattened_map_keys.add(key)
-
-  # every maxtext state key must be covered by param map
-  missing_keys = maxtext_state_keys - flattened_map_keys
-  if missing_keys:
-    raise ValueError(
-        "maxtext_state_dict must be a subset of flattened param_map"
-        + f"\nparam map\n{param_map_keys}"
-        + f"\nmaxtext:\n{maxtext_state_keys}"
-    )
-
-  # param map may have extra keys
-  extra_keys = flattened_map_keys - maxtext_state_keys
-  if extra_keys:
-    max_logging.log(f"Warning: extra keys in param_map are skipped: {extra_keys}")
-
-  # skip extra keys in param map
-  filtered_map_keys = []
-  for key in param_map_keys:
-    if (isinstance(key, str) and key in maxtext_state_keys) or (
-        isinstance(key, tuple) and all(k in maxtext_state_keys for k in key)
-    ):
-      filtered_map_keys.append(key)
-  return filtered_map_keys
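For a concrete sense of the tuple-key filtering this helper performed (relocated in this PR to `utils` as `validate_and_filter_param_map_keys`), here is a toy run with hypothetical key names:

```python
# Hypothetical keys, for illustration only.
param_map_keys = [
    "decoder.norm.scale",              # 1-to-1 mapping
    ("moe.wi_0", "moe.wi_1"),          # N-to-1: two MaxText keys fold into one HF tensor
    "vision_encoder.patch_embedding",  # variant-specific key, absent below
]
maxtext_state_keys = {"decoder.norm.scale", "moe.wi_0", "moe.wi_1"}

# Every checkpoint key is covered by the flattened map, so no ValueError.
# The vision key is logged as extra and skipped; the tuple survives because
# all of its constituent keys exist in the checkpoint.
filtered = validate_and_filter_param_map_keys(param_map_keys, maxtext_state_keys)
assert filtered == ["decoder.norm.scale", ("moe.wi_0", "moe.wi_1")]
```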


def main(argv: Sequence[str]) -> None:
"""Main function to convert a MaxText checkpoint to HuggingFace format.

Expand All @@ -180,6 +131,7 @@ def main(argv: Sequence[str]) -> None:
config.load_full_state_path == ""
), "This script expects parameters, not a full state. Use generate_param_only_checkpoint first if needed."
max_utils.print_system_information()
overall_start = time.time()

  # Load Maxtext checkpoint
  max_logging.log("\nLoading Orbax checkpoint...")
@@ -189,7 +141,7 @@
  rng, rng_load_params = jax.random.split(rng)
  # load params from maxengine
  loaded_params_from_engine = engine.load_params(rng_load_params)
-  max_logging.log(f"Elapse: {(time.time() - start) / 60:.2f} min")
+  max_logging.log(f"Elapse for checkpoint load: {(time.time() - start) / 60:.2f} min")

  if not config.base_output_directory:
    output_directory = f"tmp/{config.run_name}"
@@ -239,7 +191,7 @@
  # The param_map may contain tuples as keys, which represent N-to-1 mappings from maxtext to huggingface
  # Check maxtext_state_dict is a subset of flattened param_map
  # Skip extra keys from param_map
-  filtered_map_keys = _check_param_map_keys(param_map.keys(), maxtext_state_dict.keys())
+  filtered_map_keys = validate_and_filter_param_map_keys(param_map.keys(), maxtext_state_dict.keys())

  # Iterate through the parameter map to transform and collect weights.
  # This loop handles both simple 1-to-1 mappings and complex N-to-1 mappings
@@ -260,7 +212,7 @@
    processed_params_list.extend(processed_params)

  transformed_hf_weights = dict(processed_params_list)
-  max_logging.log(f"Elapse: {(time.time() - start) / 60:.2f} min")
+  max_logging.log(f"Elapse for transform: {(time.time() - start) / 60:.2f} min")

  # 5. Save in HuggingFace Format
  if not transformed_hf_weights:
@@ -277,7 +229,8 @@
      output_dir=output_directory,
  )
  max_logging.log(f"✅ MaxText model successfully saved in HuggingFace format at {output_directory}")
-  max_logging.log(f"Elapse: {(time.time() - start) / 60:.2f} min")
+  max_logging.log(f"Elapse for save: {(time.time() - start) / 60:.2f} min")
+  max_logging.log(f"Overall Elapse: {(time.time() - overall_start) / 60:.2f} min")


if __name__ == "__main__":