36 changes: 30 additions & 6 deletions docs/tutorials/posttraining/knowledge_distillation.md
@@ -16,6 +16,13 @@

# Knowledge distillation

> [!WARNING]
> **⚠️ Deprecated: Update Coming Soon**
>
> This documentation is deprecated, and the current distillation recipes are no longer actively maintained.
>
> We are prioritizing a migration to **vLLM** and **Tunix**. Please wait for the updated documentation covering these new workflows.

## Overview
Knowledge distillation is a compression technique that transfers knowledge from a larger (teacher) model to a smaller (student) model. This allows the student to approach the teacher's performance while using significantly fewer parameters and less compute.
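
For reference, the sketch below illustrates the classic logit-matching formulation of knowledge distillation: the student is trained against the teacher's temperature-softened output distribution. This is illustrative only and is not MaxText code; the recipe in this tutorial instead distills at the sequence level, by generating completions with the teacher and fine-tuning the student on them.

```python
import jax
import jax.numpy as jnp

def soft_target_loss(student_logits, teacher_logits, temperature=2.0):
  """Cross-entropy of the student against temperature-softened teacher probabilities."""
  teacher_probs = jax.nn.softmax(teacher_logits / temperature, axis=-1)
  student_log_probs = jax.nn.log_softmax(student_logits / temperature, axis=-1)
  # The T**2 factor keeps gradient magnitudes comparable across temperatures.
  return -(temperature ** 2) * jnp.mean(jnp.sum(teacher_probs * student_log_probs, axis=-1))
```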

@@ -49,13 +56,26 @@ export RUN_NAME = <unique name for the run>

```sh
git clone https://github.com/AI-Hypercomputer/maxtext.git
python3 -m venv ~/venv-maxtext
source ~/venv-maxtext/bin/activate
python3 -m venv ~/maxtext_venv
source ~/maxtext_venv/bin/activate
python3 -m pip install uv
cd maxtext
uv pip install -r dependencies/requirements/requirements.txt
```

---

**⚠️ Warning: PyTorch Installation Required for Checkpoint Scripts**

The checkpoint conversion scripts under `MaxText/utils/ckpt_scripts/` (e.g., `llama_or_mistral_ckpt.py`, `convert_deepseek_family_unscanned_ckpt.py`) depend on PyTorch (`torch`). These scripts convert model checkpoints from other formats into a format compatible with MaxText.

If you intend to use these conversion scripts, you must install PyTorch. We recommend `uv` for a fast installation:

```bash
# Example command to install torch
uv pip install torch
```

### 1. Obtain and prepare the teacher model

#### a. Download model from Hugging Face
@@ -94,7 +114,8 @@ JAX_PLATFORMS=cpu \
python3 -m MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt \
--base-model-path ~/llama2-7b-chat \
--maxtext-model-path ${BASE_DIRECTORY}/llama2-7b-chat/scanned \
--model-size llama2-7b
--model-size llama2-7b \
--huggingface-checkpoint true
```

### 3. Generate dataset using the teacher model
@@ -109,7 +130,7 @@ python3 -m MaxText.maxengine_server src/MaxText/configs/base.yml \
tokenizer_path=deepseek-ai/DeepSeek-V2-Lite-chat tokenizer_type=huggingface \
load_parameters_path=${BASE_DIRECTORY}/deepseek2-16-chat/unscanned/0/items \
model_name=deepseek2-16b \
per_device_batch_size=10 ici_tensor_parallelism=4 \
per_device_batch_size=10 ici_tensor_parallelism=8 \
max_target_length=2048 max_prefill_predict_length=64 \
hf_access_token=$HF_TOKEN \
scan_layers=False \
@@ -118,12 +139,15 @@ python3 -m MaxText.maxengine_server src/MaxText/configs/base.yml \

Set `multi_sampling` to `True` to generate multiple independent completions per prompt.
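
As an aside, here is a hypothetical sketch of how multiple sampled completions for a single prompt map onto training rows. The exact output format is produced by the data-generation script; the field names below simply mirror the `prompt`/`completion` columns used in the fine-tuning step later in this tutorial.

```python
# Hypothetical records for illustration only; not the script's actual output schema.
prompt = "Summarize the benefits of knowledge distillation."
sampled_completions = [
    "It lets a small student model approach the quality of a much larger teacher.",
    "It compresses a large model's behavior into a cheaper, faster model.",
]
# Each independent completion becomes its own (prompt, completion) training row.
records = [{"prompt": prompt, "completion": c} for c in sampled_completions]
```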

> **Note on `ici_tensor_parallelism`**
> For inference with JetStream, always set `ici_tensor_parallelism` to the total number of chips in your machine (e.g., `8` for a `v4-8`). Using fewer chips for tensor parallelism forces the system to apply data parallelism (FSDP) across the remaining chips, and DeepSeek MoE models cannot handle data parallelism when the batch size is 1 (a single request), which causes a crash.
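
To make the arithmetic behind this note concrete, here is a minimal worked sketch (assuming, as the note describes, that chips not used for tensor parallelism fall back to data parallelism):

```python
# Worked example for a v4-8 (8 chips).
total_chips = 8

ici_tensor_parallelism = 4                                    # fewer than the total number of chips
data_parallel_degree = total_chips // ici_tensor_parallelism  # -> 2
# With a single request (batch size 1) there is nothing for two data-parallel replicas
# to shard, which is the failure mode described above for DeepSeek MoE models.

ici_tensor_parallelism = 8                                    # recommended: use every chip for TP
data_parallel_degree = total_chips // ici_tensor_parallelism  # -> 1, i.e. no data parallelism
```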


### 3.b. Generate dataset using JetStream server
In a new terminal tab, run the following command to generate a dataset from the teacher model. Note that this is an example command for a `v4-8`:

```bash
python3 -m MaxText.generate_distillation_data \
python3 -m tools.data_generation.generate_distillation_data \
--tokenizer-path deepseek-ai/DeepSeek-V2-Lite-chat \
--dataset-path HuggingFaceH4/ultrachat_200k --data-split train_sft \
--data-columns messages \
@@ -155,7 +179,7 @@ python3 -m MaxText.sft_trainer src/MaxText/configs/sft.yml \
base_output_directory=${BASE_DIRECTORY}/distillation/deepseek2-16b-distill-llama2-7b \
tokenizer_path=meta-llama/Llama-2-7b-chat-hf tokenizer_type=huggingface \
hf_path=${USERNAME_OR_ORG}/${HF_REPO_NAME} \
train_split='train' train_data_columns=['prompt','completion'] \
train_split='train_sft' train_data_columns=['prompt','completion'] \
Collaborator comment: Why are we changing train_split?

load_parameters_path=${BASE_DIRECTORY}/llama2-7b-chat/scanned/0/items \
model_name=llama2-7b \
per_device_batch_size=2 ici_expert_parallelism=-1 ici_fsdp_parallelism=4 \
29 changes: 29 additions & 0 deletions src/MaxText/layers/attention_mla.py
@@ -503,6 +503,35 @@ def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> No
rngs=self.rngs,
)

def init_kv_caches(self, inputs_kv_shape: Tuple):
"""Initializes KVCache for MLA when using naive caching strategy.

Overrides Attention.init_kv_caches to account for the specific
key dimensions of DeepSeek MLA (Content + RoPE).
"""
batch_size, _, _ = inputs_kv_shape
# Placeholder sequence length for initialization
placeholder_seq_len = 1

return kvcache.KVCache(
max_prefill_length=self.max_prefill_predict_length,
max_target_length=self.max_target_length,
batch=batch_size,
key_seq_len=placeholder_seq_len,
value_seq_len=placeholder_seq_len,
key_heads=self.num_kv_heads,
value_heads=self.num_kv_heads,
key_head_size=self.qk_head_dim,
value_head_size=self.v_head_dim,
dtype=self.dtype,
kv_quant=self.kv_quant,
prefill_cache_axis_order=self.prefill_cache_axis_order,
ar_cache_axis_order=self.ar_cache_axis_order,
use_chunked_prefill=self.config.use_chunked_prefill,
model_mode=self.model_mode,
rngs=self.rngs,
)

def mla_query_projection(self, inputs_q: Array, inputs_positions: Array, model_mode) -> Array:
"""Query projection for MLA, e.g. includes LoRA if q_lora_rank > 0."""
# specify query logical name
5 changes: 5 additions & 0 deletions src/MaxText/maxengine.py
@@ -752,8 +752,12 @@ def _prefill_multisampling_jit(
# pytype: disable=attribute-error
token_logps.append(inference_utils.log_prob_of_chosen_token(selected_logits, first_generated_token))
first_generated_tokens = jnp.concatenate(first_generated_tokens, axis=0)

if self.config.return_log_prob:
token_logps = jnp.concatenate(token_logps, axis=0)
else:
# If logprobs are disabled, create zeros matching the token shape.
token_logps = jnp.zeros(first_generated_tokens.shape, dtype=jnp.float32)

all_valid = jnp.ones((num_samples, 1), dtype=jnp.int8)
generated_tokens = jnp.zeros((num_samples, 1), dtype=jnp.int32)
@@ -780,6 +784,7 @@ def _prefill_multisampling_jit(
"next_pos": next_pos,
"generated_tokens": generated_tokens,
"tokens": first_generated_tokens,
"token_logp": token_logps,
}, result

@functools.partial(
4 changes: 2 additions & 2 deletions tools/data_generation/generate_distillation_data.py
@@ -20,7 +20,7 @@
This generated dataset can be used to fine-tune a student model.

Example command:
python3 -m MaxText.generate_distillation_data \
python3 -m tools.data_generation.generate_distillation_data \
--dataset-path HuggingFaceH4/ultrachat_200k --data-split train_sft --data-columns messages \
--tokenizer-path deepseek-ai/DeepSeek-V2-Lite-chat \
--hf-access-token <access token> \
@@ -37,7 +37,7 @@
`max-target-length` is the max length of prompt tokens and expected completion tokens.
Set `--remove-local-dataset-files` to remove dataset files created locally after uploading to Hugging Face or GCS.
`upload-to-hf` will upload the dataset to Hugging Face and `upload-to-gcs` will upload the dataset to GCS.
For more information, check out `python3 -m MaxText.generate_distillation_data --help`.
For more information, check out `python3 -m tools.data_generation.generate_distillation_data --help`.
Note:
Make sure to run maxengine server in a new terminal before executing this command. Example command to run maxengine server:
python3 -m MaxText.maxengine_server src/MaxText/configs/base.yml \