diff --git a/docs/tutorials/posttraining/knowledge_distillation.md b/docs/tutorials/posttraining/knowledge_distillation.md
index 7723e568be..e040da26f6 100644
--- a/docs/tutorials/posttraining/knowledge_distillation.md
+++ b/docs/tutorials/posttraining/knowledge_distillation.md
@@ -16,6 +16,13 @@

 # Knowledge distillation

+> [!WARNING]
+> **⚠️ Deprecated -- Update Coming Soon**
+>
+> This documentation is deprecated and the current distillation recipes are no longer actively maintained.
+>
+> We are prioritizing a migration to **vLLM** and **Tunix**. Please wait for the updated documentation covering these new workflows.
+
 ## Overview

 Knowledge Distillation is a compression technique that transfers knowledge from a larger (teacher) model to a smaller (student) model. This allows the smaller model to achieve performance levels closer to the larger one, but with significantly fewer parameters and computational resources.
@@ -47,13 +54,19 @@ export RUN_NAME =

 #### b. Install dependencies

-```sh
-git clone https://github.com/AI-Hypercomputer/maxtext.git
-python3 -m venv ~/venv-maxtext
-source ~/venv-maxtext/bin/activate
-python3 -m pip install uv
-cd maxtext
-uv pip install -r dependencies/requirements/requirements.txt
+To install MaxText and its dependencies, follow the instructions in the [Install MaxText](https://maxtext.readthedocs.io/en/latest/install_maxtext.html#install-maxtext) documentation.
+
+---
+
+**⚠️ Warning: PyTorch Installation Required for Checkpoint Scripts**
+
+The checkpoint conversion scripts located under `MaxText/utils/ckpt_scripts/` (e.g., `llama_or_mistral_ckpt.py`, `convert_deepseek_family_unscanned_ckpt.py`) depend on PyTorch (`torch`). These scripts convert model checkpoints from other formats into a MaxText-compatible format.
+
+If you intend to use these conversion scripts, you must install PyTorch. We recommend using `uv` for a fast and efficient installation:
+
+```bash
+# Example command to install torch
+uv pip install torch
 ```

 ### 1. Obtain and prepare the teacher model
@@ -94,7 +107,8 @@ JAX_PLATFORMS=cpu \
 python3 -m MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt \
   --base-model-path ~/llama2-7b-chat \
   --maxtext-model-path ${BASE_DIRECTORY}/llama2-7b-chat/scanned \
-  --model-size llama2-7b
+  --model-size llama2-7b \
+  --huggingface-checkpoint true
 ```

 ### 3. Generate dataset using the teacher model
@@ -109,7 +123,7 @@ python3 -m MaxText.maxengine_server src/MaxText/configs/base.yml \
   tokenizer_path=deepseek-ai/DeepSeek-V2-Lite-chat tokenizer_type=huggingface \
   load_parameters_path=${BASE_DIRECTORY}/deepseek2-16-chat/unscanned/0/items \
   model_name=deepseek2-16b \
-  per_device_batch_size=10 ici_tensor_parallelism=4 \
+  per_device_batch_size=10 ici_tensor_parallelism=8 \
   max_target_length=2048 max_prefill_predict_length=64 \
   hf_access_token=$HF_TOKEN \
   scan_layers=False \
@@ -118,12 +132,15 @@ python3 -m MaxText.maxengine_server src/MaxText/configs/base.yml \

 Set `multi_sampling` to `True` to generate multiple independent completions per prompt.

+> **Note on `ici_tensor_parallelism`**
+> For Inference/JetStream, always set `ici_tensor_parallelism` to the total number of chips in your machine (e.g., `8` for a `v4-8`). Using fewer chips for tensor parallelism forces the system to use data parallelism (FSDP) for the remaining chips. DeepSeek MoE models cannot handle data parallelism when the batch size is 1 (a single request), causing a crash.
+
 ### 3.b. Generate dataset using JetStream server

 In a new tab in your terminal, run the following command to generate dataset from teacher model. Note that this is an example command to run on `v4-8`:

 ```bash
-python3 -m MaxText.generate_distillation_data \
+python3 -m tools.data_generation.generate_distillation_data \
   --tokenizer-path deepseek-ai/DeepSeek-V2-Lite-chat \
   --dataset-path HuggingFaceH4/ultrachat_200k --data-split train_sft \
   --data-columns messages \
@@ -155,7 +172,7 @@ python3 -m MaxText.sft_trainer src/MaxText/configs/sft.yml \
   base_output_directory=${BASE_DIRECTORY}/distillation/deepseek2-16b-distill-llama2-7b \
   tokenizer_path=meta-llama/Llama-2-7b-chat-hf tokenizer_type=huggingface \
   hf_path=${USERNAME_OR_ORG}/${HF_REPO_NAME} \
-  train_split='train' train_data_columns=['prompt','completion'] \
+  train_split='train_sft' train_data_columns=['prompt','completion'] \
   load_parameters_path=${BASE_DIRECTORY}/llama2-7b-chat/scanned/0/items \
   model_name=llama2-7b \
   per_device_batch_size=2 ici_expert_parallelism=-1 ici_fsdp_parallelism=4 \
diff --git a/src/MaxText/layers/attention_mla.py b/src/MaxText/layers/attention_mla.py
index 051396ffd4..b2ad406f27 100644
--- a/src/MaxText/layers/attention_mla.py
+++ b/src/MaxText/layers/attention_mla.py
@@ -503,6 +503,35 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
         rngs=self.rngs,
     )

+  def init_kv_caches(self, inputs_kv_shape: Tuple):
+    """Initializes KVCache for MLA when using naive caching strategy.
+
+    Overrides Attention.init_kv_caches to account for the specific
+    key dimensions of DeepSeek MLA (Content + RoPE).
+    """
+    batch_size, _, _ = inputs_kv_shape
+    # Placeholder sequence length for initialization
+    placeholder_seq_len = 1
+
+    return kvcache.KVCache(
+        max_prefill_length=self.max_prefill_predict_length,
+        max_target_length=self.max_target_length,
+        batch=batch_size,
+        key_seq_len=placeholder_seq_len,
+        value_seq_len=placeholder_seq_len,
+        key_heads=self.num_kv_heads,
+        value_heads=self.num_kv_heads,
+        key_head_size=self.qk_head_dim,
+        value_head_size=self.v_head_dim,
+        dtype=self.dtype,
+        kv_quant=self.kv_quant,
+        prefill_cache_axis_order=self.prefill_cache_axis_order,
+        ar_cache_axis_order=self.ar_cache_axis_order,
+        use_chunked_prefill=self.config.use_chunked_prefill,
+        model_mode=self.model_mode,
+        rngs=self.rngs,
+    )
+
   def mla_query_projection(self, inputs_q: Array, inputs_positions: Array, model_mode) -> Array:
     """Query projection for MLA, e.g. includes LoRA if q_lora_rank > 0."""
     # specify query logical name
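For context on the `init_kv_caches` override added in `attention_mla.py`: in DeepSeek-style MLA the cached key carries both a content part and a RoPE part, while the value carries only the content part, so `key_head_size` and `value_head_size` differ and the base `Attention.init_kv_caches` cannot simply be reused. Below is a minimal illustrative sketch of that asymmetry; it is not part of the patch, and the numeric head dimensions are assumed DeepSeek-V2-Lite defaults rather than values taken from this change.

```python
# Illustrative only -- not part of this patch.
# Assumed DeepSeek-V2-Lite head dimensions:
qk_nope_head_dim = 128  # assumed: "content" (non-positional) part of each key head
qk_rope_head_dim = 64   # assumed: RoPE part of each key head
v_head_dim = 128        # assumed: value head size

# The cached key concatenates the content and RoPE parts, so the key and value
# head sizes differ; this is why the override passes key_head_size=self.qk_head_dim
# and value_head_size=self.v_head_dim instead of a single shared head size.
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim  # 192

assert qk_head_dim != v_head_dim
print(f"key_head_size={qk_head_dim}, value_head_size={v_head_dim}")
```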
diff --git a/src/MaxText/maxengine.py b/src/MaxText/maxengine.py
index dff1382d08..2914ff3796 100644
--- a/src/MaxText/maxengine.py
+++ b/src/MaxText/maxengine.py
@@ -752,8 +752,12 @@ def _prefill_multisampling_jit(  # pytype: disable=attribute-error
       token_logps.append(inference_utils.log_prob_of_chosen_token(selected_logits, first_generated_token))
     first_generated_tokens = jnp.concatenate(first_generated_tokens, axis=0)
+
+    if self.config.return_log_prob:
       token_logps = jnp.concatenate(token_logps, axis=0)
+    else:
+      # If logprobs are disabled, create zeros matching the token shape.
+      token_logps = jnp.zeros(first_generated_tokens.shape, dtype=jnp.float32)

     all_valid = jnp.ones((num_samples, 1), dtype=jnp.int8)
     generated_tokens = jnp.zeros((num_samples, 1), dtype=jnp.int32)
@@ -780,6 +784,7 @@
         "next_pos": next_pos,
         "generated_tokens": generated_tokens,
         "tokens": first_generated_tokens,
+        "token_logp": token_logps,
       },
       result

   @functools.partial(
diff --git a/src/MaxText/utils/ckpt_scripts/__init__.py b/src/MaxText/utils/ckpt_scripts/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tools/data_generation/generate_distillation_data.py b/tools/data_generation/generate_distillation_data.py
index 8636efa782..e0a2a3363b 100644
--- a/tools/data_generation/generate_distillation_data.py
+++ b/tools/data_generation/generate_distillation_data.py
@@ -20,7 +20,7 @@
 This generated dataset can be used to fine-tune a student model.

 Example command:
-  python3 -m MaxText.generate_distillation_data \
+  python3 -m tools.data_generation.generate_distillation_data \
     --dataset-path HuggingFaceH4/ultrachat_200k --data-split train_sft --data-columns messages \
     --tokenizer-path deepseek-ai/DeepSeek-V2-Lite-chat \
     --hf-access-token \
@@ -37,7 +37,7 @@
 `max-target-length` is the max length of prompt tokens and expected completion tokens.
 Set `--remove-local-dataset-files` to remove dataset files created locally after uploading to Hugging Face or GCS.
 `upload-to-hf` will upload the dataset to Hugging Face and `upload-to-gcs` will upload the dataset to GCS.
-For more information, check out `python3 -m MaxText.generate_distillation_data --help`.
+For more information, check out `python3 -m tools.data_generation.generate_distillation_data --help`.
 Note: Make sure to run maxengine server in a new terminal before executing this command.
 Example command to run maxengine server:
 python3 -m MaxText.maxengine_server src/MaxText/configs/base.yml \
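As a quick smoke test of the relocated entry point (a sketch, not part of the patch; it assumes you run from the MaxText repository root with the project dependencies installed), the module should now resolve under its new path:

```bash
# Assumed to be run from the repository root after installing dependencies.
# Prints the CLI help of the relocated data-generation script.
python3 -m tools.data_generation.generate_distillation_data --help
```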