From 11916bb974f5bf9444ed01ebbb95bdd825c6beb6 Mon Sep 17 00:00:00 2001 From: Surbhi Jain Date: Fri, 12 Dec 2025 01:20:45 +0000 Subject: [PATCH] Refactor dockerfiles and deprecate stable_stack and jax_ai_image mode --- .github/workflows/UploadDockerImages.yml | 4 +- .../workflows/build_and_push_docker_image.yml | 2 +- PREFLIGHT.md | 2 +- RESTRUCTURE.md | 7 +- .../dockerfiles/jetstream_pathways.Dockerfile | 54 ----- .../dockerfiles/maxengine_server.Dockerfile | 48 ---- .../maxtext_dependencies.Dockerfile | 55 ----- .../maxtext_jax_ai_image.Dockerfile | 78 ------ .../maxtext_libtpu_path.Dockerfile | 13 - ...text_post_training_dependencies.Dockerfile | 4 +- ...ost_training_local_dependencies.Dockerfile | 5 +- ...le => maxtext_tpu_dependencies.Dockerfile} | 16 +- .../requirements_with_jax_ai_image.txt | 29 --- .../scripts/docker_build_dependency_image.sh | 226 +++++++----------- docs/run_maxtext/run_maxtext_via_pathways.md | 2 +- docs/run_maxtext/run_maxtext_via_xpk.md | 8 +- .../posttraining/rl_on_multi_host.md | 6 +- .../posttraining/sft_on_multi_host.md | 2 +- preflight.sh | 2 +- tools/setup/setup.sh | 14 +- 20 files changed, 128 insertions(+), 449 deletions(-) delete mode 100644 dependencies/dockerfiles/jetstream_pathways.Dockerfile delete mode 100644 dependencies/dockerfiles/maxengine_server.Dockerfile delete mode 100644 dependencies/dockerfiles/maxtext_dependencies.Dockerfile delete mode 100644 dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile delete mode 100644 dependencies/dockerfiles/maxtext_libtpu_path.Dockerfile rename dependencies/dockerfiles/{maxtext_db_dependencies.Dockerfile => maxtext_tpu_dependencies.Dockerfile} (84%) delete mode 100644 dependencies/requirements/requirements_with_jax_ai_image.txt diff --git a/.github/workflows/UploadDockerImages.yml b/.github/workflows/UploadDockerImages.yml index 9433efdf3e..931ac40009 100644 --- a/.github/workflows/UploadDockerImages.yml +++ b/.github/workflows/UploadDockerImages.yml @@ -65,11 +65,11 @@ jobs: - 
device: tpu build_mode: stable image_name: maxtext_jax_stable - dockerfile: ./dependencies/dockerfiles/maxtext_dependencies.Dockerfile + dockerfile: ./dependencies/dockerfiles/maxtext_tpu_dependencies.Dockerfile - device: tpu build_mode: nightly image_name: maxtext_jax_nightly - dockerfile: ./dependencies/dockerfiles/maxtext_dependencies.Dockerfile + dockerfile: ./dependencies/dockerfiles/maxtext_tpu_dependencies.Dockerfile uses: ./.github/workflows/build_and_push_docker_image.yml with: image_name: ${{ matrix.image_name }} diff --git a/.github/workflows/build_and_push_docker_image.yml b/.github/workflows/build_and_push_docker_image.yml index 6b852c38a1..0168dcceaf 100644 --- a/.github/workflows/build_and_push_docker_image.yml +++ b/.github/workflows/build_and_push_docker_image.yml @@ -113,7 +113,7 @@ jobs: DEVICE=${{ inputs.device }} MODE=${{ inputs.build_mode }} JAX_VERSION=NONE - LIBTPU_GCS_PATH=NONE + LIBTPU_VERSION=NONE BASEIMAGE=gcr.io/tpu-prod-env-multipod/maxtext_jax_stable:${{ inputs.image_date }} - name: Add tags to Docker image diff --git a/PREFLIGHT.md b/PREFLIGHT.md index 003bdf79f9..b589bc1f4f 100644 --- a/PREFLIGHT.md +++ b/PREFLIGHT.md @@ -26,7 +26,7 @@ bash preflight.sh PLATFORM=GCE && numactl --membind 0 --cpunodebind=0 python3 -m ``` For GKE, -`numactl` should be built into your docker image from [maxtext_dependencies.Dockerfile](https://github.com/google/maxtext/blob/main/dependencies/dockerfiles/maxtext_dependencies.Dockerfile), so you can use it directly if you built the maxtext docker image. Here is an example +`numactl` should be built into your docker image from [maxtext_tpu_dependencies.Dockerfile](https://github.com/google/maxtext/blob/main/dependencies/dockerfiles/maxtext_tpu_dependencies.Dockerfile), so you can use it directly if you built the maxtext docker image. 
Here is an example ``` bash preflight.sh PLATFORM=GKE && numactl --membind 0 --cpunodebind=0 python3 -m MaxText.train src/MaxText/configs/base.yml run_name=$YOUR_JOB_NAME diff --git a/RESTRUCTURE.md b/RESTRUCTURE.md index d0b1cae45f..5adc169904 100644 --- a/RESTRUCTURE.md +++ b/RESTRUCTURE.md @@ -27,14 +27,9 @@ comments, or questions by creating a new ├── README.md ├── dependencies/ │ ├── dockerfiles/ -│ │ ├── jetstream_pathways.Dockerfile -│ │ ├── maxengine_server.Dockerfile │ │ ├── maxtext_custom_wheels.Dockerfile -│ │ ├── maxtext_db_dependencies.Dockerfile -│ │ ├── maxtext_dependencies.Dockerfile +│ │ ├── maxtext_tpu_dependencies.Dockerfile │ │ ├── maxtext_gpu_dependencies.Dockerfile -│ │ ├── maxtext_jax_ai_image.Dockerfile -│ │ ├── maxtext_libtpu_path.Dockerfile │ │ └── maxtext_runner.Dockerfile │ ├── requirements/ │ │ └── requirements.txt diff --git a/dependencies/dockerfiles/jetstream_pathways.Dockerfile b/dependencies/dockerfiles/jetstream_pathways.Dockerfile deleted file mode 100644 index 669961b8af..0000000000 --- a/dependencies/dockerfiles/jetstream_pathways.Dockerfile +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2023–2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Ubuntu:24.04 -# Use Ubuntu 24.04 from Docker Hub. 
-# https://hub.docker.com/_/ubuntu/tags\?page\=1\&name\=24.04 -FROM ubuntu:24.04 - -ENV DEBIAN_FRONTEND=noninteractive - -RUN apt -y update && apt install -y --no-install-recommends apt-transport-https ca-certificates gnupg git python3.12 python3-pip curl nano vim - -RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 1 -RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list && curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg && apt-get update -y && apt-get install google-cloud-sdk -y - -RUN python3 -m pip install --upgrade pip - -ENV JAX_PLATFORMS=proxy -ENV JAX_BACKEND_TARGET=grpc://localhost:38681 -ENV XCLOUD_ENVIRONMENT=GCP - -ENV MAXTEXT_VERSION=main -ENV JETSTREAM_VERSION=main - -RUN git clone https://github.com/AI-Hypercomputer/JetStream.git && \ -git clone https://github.com/AI-Hypercomputer/maxtext.git - -RUN cd maxtext/ && \ -git checkout ${MAXTEXT_VERSION} && \ -bash ./tools/setup/setup.sh - -RUN cd /JetStream && \ -git checkout ${JETSTREAM_VERSION} && \ -python3 -m pip install -e . - -RUN python3 -m pip install setuptools fastapi uvicorn - -RUN apt -y update && apt-get -y install python3-dev && apt-get -y install build-essential - -COPY jetstream_pathways_entrypoint.sh /usr/bin/ -RUN chmod +x /usr/bin/jetstream_pathways_entrypoint.sh - -ENTRYPOINT ["jetstream_pathways_entrypoint.sh"] diff --git a/dependencies/dockerfiles/maxengine_server.Dockerfile b/dependencies/dockerfiles/maxengine_server.Dockerfile deleted file mode 100644 index c0fe5a3820..0000000000 --- a/dependencies/dockerfiles/maxengine_server.Dockerfile +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2023–2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Ubuntu:24.04 -# Use Ubuntu 24.04 from Docker Hub. -# https://hub.docker.com/_/ubuntu/tags?page=1&name=24.04 -FROM ubuntu:24.04 - -ENV DEBIAN_FRONTEND=noninteractive -ENV MAXTEXT_VERSION=main -ENV JETSTREAM_VERSION=main - -RUN apt -y update && apt install -y --no-install-recommends \ - ca-certificates \ - git \ - python3.12 \ - python3-pip - -RUN update-alternatives --install \ - /usr/bin/python3 python3 /usr/bin/python3.12 1 - -RUN git clone https://github.com/AI-Hypercomputer/maxtext.git && \ -git clone https://github.com/AI-Hypercomputer/JetStream.git - -RUN cd maxtext/ && \ -git checkout ${MAXTEXT_VERSION} && \ -bash ./tools/setup/setup.sh - -RUN cd /JetStream && \ -git checkout ${JETSTREAM_VERSION} && \ -python3 -m pip install -e . 
- -COPY maxengine_server_entrypoint.sh /usr/bin/ - -RUN chmod +x /usr/bin/maxengine_server_entrypoint.sh - -ENTRYPOINT ["/usr/bin/maxengine_server_entrypoint.sh"] diff --git a/dependencies/dockerfiles/maxtext_dependencies.Dockerfile b/dependencies/dockerfiles/maxtext_dependencies.Dockerfile deleted file mode 100644 index aa133f80f0..0000000000 --- a/dependencies/dockerfiles/maxtext_dependencies.Dockerfile +++ /dev/null @@ -1,55 +0,0 @@ -# syntax=docker/dockerfile:experimental -# Use Python 3.12 as the base image -FROM python:3.12-slim-bullseye - -# Install system dependencies -RUN apt-get update && apt-get install -y curl gnupg - -# Add the Google Cloud SDK package repository -RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list -RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add - - -# Install the Google Cloud SDK -RUN apt-get update && apt-get install -y google-cloud-sdk - -# Set the default Python version to 3.12 -RUN update-alternatives --install /usr/bin/python3 python3 /usr/local/bin/python3.12 1 - -# Set environment variables for Google Cloud SDK and Python 3.12 -ENV PATH="/usr/local/google-cloud-sdk/bin:/usr/local/bin/python3.12:${PATH}" - -# Set environment variables via build arguments -ARG MODE -ENV ENV_MODE=$MODE - -ARG JAX_VERSION -ENV ENV_JAX_VERSION=$JAX_VERSION - -ARG LIBTPU_GCS_PATH -ENV ENV_LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH - -ARG DEVICE -ENV ENV_DEVICE=$DEVICE - -ENV MAXTEXT_ASSETS_ROOT=/deps/src/MaxText/assets -ENV MAXTEXT_TEST_ASSETS_ROOT=/deps/src/MaxText/test_assets -ENV MAXTEXT_PKG_DIR=/deps/src/MaxText -ENV MAXTEXT_REPO_ROOT=/deps - -# Set the working directory in the container -WORKDIR /deps - -# Copy setup files and dependency files separately for better caching -COPY tools/setup tools/setup/ -COPY dependencies/requirements/ dependencies/requirements/ 
-COPY src/install_maxtext_extra_deps/extra_deps_from_github.txt src/install_maxtext_extra_deps/ - -# Install dependencies - these steps are cached unless the copied files change -RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION LIBTPU_GCS_PATH=${ENV_LIBTPU_GCS_PATH} DEVICE=${ENV_DEVICE}" -RUN --mount=type=cache,target=/root/.cache/pip bash /deps/tools/setup/setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} LIBTPU_GCS_PATH=${ENV_LIBTPU_GCS_PATH} DEVICE=${ENV_DEVICE} - -# Now copy the remaining code (source files that may change frequently) -COPY . . - -# Install (editable) MaxText -RUN test -f '/tmp/venv_created' && "$(tail -n1 /tmp/venv_created)"/bin/activate ; pip install --no-dependencies -e . diff --git a/dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile b/dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile deleted file mode 100644 index 1807592684..0000000000 --- a/dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile +++ /dev/null @@ -1,78 +0,0 @@ -ARG JAX_AI_IMAGE_BASEIMAGE - -# JAX AI Base Image -FROM $JAX_AI_IMAGE_BASEIMAGE -ARG JAX_AI_IMAGE_BASEIMAGE - -ARG COMMIT_HASH -ENV COMMIT_HASH=$COMMIT_HASH - -ENV MAXTEXT_ASSETS_ROOT=/deps/src/MaxText/assets -ENV MAXTEXT_TEST_ASSETS_ROOT=/deps/src/MaxText/test_assets -ENV MAXTEXT_PKG_DIR=/deps/src/MaxText -ENV MAXTEXT_REPO_ROOT=/deps - -# Set the working directory in the container -WORKDIR /deps - -# Copy setup files and dependency files separately for better caching -COPY tools/setup tools/setup/ -COPY dependencies/requirements/ dependencies/requirements/ -COPY src/install_maxtext_extra_deps/extra_deps_from_github.txt src/install_maxtext_extra_deps/ - -# For JAX AI tpu training images 0.4.37 AND 0.4.35 -# Orbax checkpoint installs the latest version of JAX, -# but the libtpu version in the base image is older. -# This version mismatch can cause compatibility issues -# and break MaxText. 
-# Upgrade libtpu version if using either of the old stable images - -ARG DEVICE -ENV DEVICE=$DEVICE - -RUN if [ "$DEVICE" = "tpu" ] && ([ "$JAX_AI_IMAGE_BASEIMAGE" = "us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.37-rev1" ] || [ "$JAX_AI_IMAGE_BASEIMAGE" = "us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1" ]); then \ - python3 -m pip install --no-cache-dir --upgrade jax[tpu]; fi - -# Install Maxtext requirements with Jax AI Image -RUN apt-get update && apt-get install --yes && apt-get install --yes dnsutils -# TODO(bvandermoon, parambole): Remove this when it's added to JAX AI Image -RUN pip install google-cloud-monitoring - -# Install requirements file that was generated with pipreqs for JSS 0.6.1 using: -# pipreqs --savepath requirements_with_jax_stable_stack_0_6_1_pipreqs.txt -# Otherwise use general requirements_with_jax_ai_image.txt -RUN if [ "$DEVICE" = "tpu" ] && [ "$JAX_STABLE_STACK_BASEIMAGE" = "us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.6.1-rev1" ]; then \ - python3 -m pip install -r /deps/dependencies/requirements/requirements_with_jax_stable_stack_0_6_1_pipreqs.txt; \ - else \ - python3 -m pip install -r /deps/dependencies/requirements/requirements_with_jax_ai_image.txt; \ - fi - -# Install google-tunix for TPU devices, skip for GPU -RUN if [ "$DEVICE" = "tpu" ]; then \ - python3 -m pip install 'google-tunix>=0.1.2'; \ - fi - -# Temporarily downgrade to JAX=0.7.2 for GPU images -RUN if [ "$DEVICE" = "gpu" ]; then \ - python3 -m pip install -U "jax[cuda12]==0.8.1"; \ - python3 -m pip install -U "transformer-engine-cu12" "transformer-engine-jax" "transformer-engine"; \ - fi - -# Now copy the remaining code (source files that may change frequently) -COPY . . - -RUN ls . - -ARG TEST_TYPE -# Copy over test assets if building image for end-to-end tests or unit tests -RUN if [ "$TEST_TYPE" = "xlml" ] || [ "$TEST_TYPE" = "unit_test" ]; then \ - if ! 
gcloud storage cp -r gs://maxtext-test-assets/* "${MAXTEXT_TEST_ASSETS_ROOT}"; then \ - echo "WARNING: Failed to download test assets from GCS. These files are only used for end-to-end tests; you may not have access to the bucket."; \ - fi; \ - fi - -# Run the script available in JAX AI base image to generate the manifest file -RUN bash /jax-ai-image/generate_manifest.sh PREFIX=maxtext COMMIT_HASH=$COMMIT_HASH - -# Install (editable) MaxText -RUN test -f '/tmp/venv_created' && "$(tail -n1 /tmp/venv_created)"/bin/activate ; pip install --no-dependencies -e . diff --git a/dependencies/dockerfiles/maxtext_libtpu_path.Dockerfile b/dependencies/dockerfiles/maxtext_libtpu_path.Dockerfile deleted file mode 100644 index 5c0866b53c..0000000000 --- a/dependencies/dockerfiles/maxtext_libtpu_path.Dockerfile +++ /dev/null @@ -1,13 +0,0 @@ -ARG BASEIMAGE=maxtext_base_image -FROM $BASEIMAGE - -ENV MAXTEXT_ASSETS_ROOT=/deps/src/MaxText/assets -ENV MAXTEXT_TEST_ASSETS_ROOT=/deps/src/MaxText/test_assets -ENV MAXTEXT_PKG_DIR=/deps/src/MaxText -ENV MAXTEXT_REPO_ROOT=/deps - -#FROM maxtext_base_image -# Set the TPU_LIBRARY_PATH -ENV TPU_LIBRARY_PATH='/root/custom_libtpu/libtpu.so' - -WORKDIR /deps diff --git a/dependencies/dockerfiles/maxtext_post_training_dependencies.Dockerfile b/dependencies/dockerfiles/maxtext_post_training_dependencies.Dockerfile index 31781bc757..ff3b3510b0 100644 --- a/dependencies/dockerfiles/maxtext_post_training_dependencies.Dockerfile +++ b/dependencies/dockerfiles/maxtext_post_training_dependencies.Dockerfile @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-ARG BASEIMAGE +ARG BASEIMAGE=maxtext_base_image FROM ${BASEIMAGE} -ARG MODE +ARG MODE ENV MODE=$MODE RUN echo "Installing Post-Training dependencies (vLLM, tpu-inference, tunix) with MODE=${MODE}" diff --git a/dependencies/dockerfiles/maxtext_post_training_local_dependencies.Dockerfile b/dependencies/dockerfiles/maxtext_post_training_local_dependencies.Dockerfile index c3905b7b96..11bfd0b6f8 100644 --- a/dependencies/dockerfiles/maxtext_post_training_local_dependencies.Dockerfile +++ b/dependencies/dockerfiles/maxtext_post_training_local_dependencies.Dockerfile @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -ARG BASEIMAGE +ARG BASEIMAGE=maxtext_base_image FROM ${BASEIMAGE} + ARG MODE ENV MODE=$MODE @@ -31,11 +32,9 @@ COPY tunix /tunix RUN pip uninstall -y google-tunix RUN pip install -e /tunix --no-cache-dir - COPY vllm /vllm RUN VLLM_TARGET_DEVICE="tpu" pip install -e /vllm --no-cache-dir - COPY tpu-inference /tpu-inference RUN pip install -e /tpu-inference --no-cache-dir diff --git a/dependencies/dockerfiles/maxtext_db_dependencies.Dockerfile b/dependencies/dockerfiles/maxtext_tpu_dependencies.Dockerfile similarity index 84% rename from dependencies/dockerfiles/maxtext_db_dependencies.Dockerfile rename to dependencies/dockerfiles/maxtext_tpu_dependencies.Dockerfile index 1f7a72aa6f..1d6db625b9 100644 --- a/dependencies/dockerfiles/maxtext_db_dependencies.Dockerfile +++ b/dependencies/dockerfiles/maxtext_tpu_dependencies.Dockerfile @@ -1,6 +1,7 @@ # syntax=docker/dockerfile:experimental -# Copy benchmark-db -FROM gcr.io/tpu-prod-env-one-vm/benchmark-db:2025-02-14 + +ARG BASEIMAGE=python:3.12-slim-bullseye +FROM $BASEIMAGE # Install system dependencies RUN apt-get update && apt-get install -y curl gnupg @@ -25,8 +26,8 @@ ENV ENV_MODE=$MODE ARG JAX_VERSION ENV ENV_JAX_VERSION=$JAX_VERSION -ARG LIBTPU_GCS_PATH -ENV ENV_LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH +ARG LIBTPU_VERSION +ENV 
ENV_LIBTPU_VERSION=$LIBTPU_VERSION ARG DEVICE ENV ENV_DEVICE=$DEVICE @@ -44,9 +45,12 @@ COPY tools/setup tools/setup/ COPY dependencies/requirements/ dependencies/requirements/ COPY src/install_maxtext_extra_deps/extra_deps_from_github.txt src/install_maxtext_extra_deps/ +# Copy the custom libtpu.so file if it exists inside maxtext repository +COPY libtpu.so* /root/custom_libtpu/ + # Install dependencies - these steps are cached unless the copied files change -RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION LIBTPU_GCS_PATH=${ENV_LIBTPU_GCS_PATH} DEVICE=${ENV_DEVICE}" -RUN --mount=type=cache,target=/root/.cache/pip bash /deps/tools/setup/setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} LIBTPU_GCS_PATH=${ENV_LIBTPU_GCS_PATH} DEVICE=${ENV_DEVICE} +RUN echo "Running command: bash setup.sh MODE=$ENV_MODE JAX_VERSION=$ENV_JAX_VERSION LIBTPU_VERSION=$ENV_LIBTPU_VERSION DEVICE=${ENV_DEVICE}" +RUN --mount=type=cache,target=/root/.cache/pip bash /deps/tools/setup/setup.sh MODE=${ENV_MODE} JAX_VERSION=${ENV_JAX_VERSION} LIBTPU_VERSION=${ENV_LIBTPU_VERSION} DEVICE=${ENV_DEVICE} # Now copy the remaining code (source files that may change frequently) COPY . . diff --git a/dependencies/requirements/requirements_with_jax_ai_image.txt b/dependencies/requirements/requirements_with_jax_ai_image.txt deleted file mode 100644 index 08cf4ef273..0000000000 --- a/dependencies/requirements/requirements_with_jax_ai_image.txt +++ /dev/null @@ -1,29 +0,0 @@ -# Requirements for Building the MaxText Docker Image -# These requirements are additional to the dependencies present in the JAX AI base image. 
-datasets @ https://github.com/huggingface/datasets/archive/6790e138c00b87a1ddc72184f89e7814cf784360.zip -flax>=0.11.0 -google-api-python-client -google-cloud-mldiagnostics -google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip -grain[parquet]>=0.2.15 -jaxtyping -jsonlines -mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip -omegaconf -orbax-checkpoint>=0.11.22 -pathwaysutils>=0.1.1 -pillow>=11.1.0 -pre-commit -protobuf>=5.29.5 -pydantic -pyink -pylint -pytest -pytype -qwix -sentencepiece>=0.2.0 -tensorflow-datasets -tensorflow-text>=2.17.0 -tiktoken -tokamax>=0.0.4 -transformers diff --git a/dependencies/scripts/docker_build_dependency_image.sh b/dependencies/scripts/docker_build_dependency_image.sh index e26dec9196..0bd3cbcde9 100644 --- a/dependencies/scripts/docker_build_dependency_image.sh +++ b/dependencies/scripts/docker_build_dependency_image.sh @@ -23,33 +23,34 @@ # ================================== # Build docker image with stable dependencies -## bash dependencies/scripts/docker_build_dependency_image.sh MODE=stable -## bash dependencies/scripts/docker_build_dependency_image.sh DEVICE={{gpu|tpu}} MODE=stable_stack BASEIMAGE={{JAX_STABLE_STACK BASEIMAGE FROM ARTIFACT REGISTRY}} +## bash dependencies/scripts/docker_build_dependency_image.sh DEVICE={{gpu|tpu}} MODE=stable # Build docker image with nightly dependencies -## bash dependencies/scripts/docker_build_dependency_image.sh MODE=nightly +## bash dependencies/scripts/docker_build_dependency_image.sh DEVICE={{gpu|tpu}} MODE=nightly # Build docker image with stable dependencies and, a pinned JAX_VERSION for TPUs ## bash dependencies/scripts/docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.13 -# Build docker image with stable dependencies and, a pinned JAX_VERSION for GPUs +# Build docker image with a pinned JAX_VERSION and, a pinned LIBTPU_VERSION for TPUs +## bash 
dependencies/scripts/docker_build_dependency_image.sh MODE={{stable|nightly}} JAX_VERSION=0.8.1 LIBTPU_VERSION=0.0.31.dev20251119+nightly + +# Build docker image with a custom libtpu.so for TPUs +# Note: libtpu.so file must be present in the root directory of the MaxText repository +## bash dependencies/scripts/docker_build_dependency_image.sh MODE={{stable|nightly}} + +# Build docker image with nightly dependencies and, a pinned JAX_VERSION for GPUs # Available versions listed at https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jax ## bash dependencies/scripts/docker_build_dependency_image.sh DEVICE=gpu MODE=nightly JAX_VERSION=0.4.36.dev20241109 -# MODE=custom_wheels builds the nightly environment, then reinstalls any -# additional wheels present in the maxtext directory. -# Use this mode to install custom dependencies, such as custom JAX or JAXlib builds. -## bash dependencies/scripts/docker_build_dependency_image.sh MODE=custom_wheels - # ================================== # POST-TRAINING BUILD EXAMPLES # ================================== # Build docker image with stable pre-training dependencies and stable post-training dependencies -## bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training +## bash dependencies/scripts/docker_build_dependency_image.sh WORKFLOW=post-training # Build docker image with stable pre-training dependencies and post-training dependencies from GitHub head -## bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training POST_TRAINING_SOURCE=local +## bash dependencies/scripts/docker_build_dependency_image.sh WORKFLOW=post-training POST_TRAINING_SOURCE=local if [ "${BASH_SOURCE-}" ]; then this_file="${BASH_SOURCE[0]}" @@ -75,13 +76,11 @@ if ! docker info > /dev/null 2>&1; then exit 1 fi -export LOCAL_IMAGE_NAME=maxtext_base_image -echo "Building to $LOCAL_IMAGE_NAME" - # Use Docker BuildKit so we can cache pip packages. 
export DOCKER_BUILDKIT=1 -echo "Starting to build your docker image. This will take a few minutes but the image can be reused as you iterate." +export LOCAL_IMAGE_NAME=maxtext_base_image +echo "Building docker image: $LOCAL_IMAGE_NAME. This will take a few minutes but the image can be reused as you iterate." # Set environment variables for ARGUMENT in "$@"; do @@ -91,149 +90,100 @@ for ARGUMENT in "$@"; do echo "$KEY=$VALUE" done - +# Set default values if not provided if [[ -z ${JAX_VERSION+x} ]] ; then export JAX_VERSION=NONE - echo "Default JAX_VERSION=${JAX_VERSION}" fi - if [[ -z ${MODE} ]]; then export MODE=stable - echo "Default MODE=${MODE}" - export CUSTOM_JAX=0 - export INSTALL_POST_TRAINING=0 -elif [[ ${MODE} == "custom_wheels" ]] ; then +# TODO: Remove 'custom_wheels' mode support when tpu-recipes migration is complete. +elif [[ ${MODE} == "custom_wheels" ]]; then + export WORKFLOW=custom-wheels export MODE=nightly - export CUSTOM_JAX=1 - export INSTALL_POST_TRAINING=0 -elif [[ ${MODE} == "post-training" || ${MODE} == "post-training-experimental" ]] ; then - export INSTALL_POST_TRAINING=1 - export CUSTOM_JAX=0 -else - export CUSTOM_JAX=0 - export INSTALL_POST_TRAINING=0 fi - if [[ -z ${DEVICE} ]]; then export DEVICE=tpu - echo "Default DEVICE=${DEVICE}" fi -# New flag for post-training source -if [[ -z ${POST_TRAINING_SOURCE} ]]; then - export POST_TRAINING_SOURCE=remote # Default to the original Dockerfile - echo "Default POST_TRAINING_SOURCE=${POST_TRAINING_SOURCE}" -fi +# Create docker build arguments array +docker_build_args=( + "DEVICE=${DEVICE}" + "MODE=${MODE}" + "JAX_VERSION=${JAX_VERSION}" +) + +run_docker_build() { + local dockerfile_path="$1" + shift 1 # Move past the first argument, the rest are build-args + docker build --network host $(printf -- '--build-arg %q ' "$@") -f "$dockerfile_path" -t "$LOCAL_IMAGE_NAME" . 
+} -# Function to build with MODE=jax_ai_image -build_ai_image() { - if [[ -z ${BASEIMAGE+x} ]]; then - echo "Error: BASEIMAGE is unset, please set it!" - exit 1 - fi - COMMIT_HASH=$(git rev-parse --short HEAD) - echo "Building JAX AI MaxText Imageat commit hash ${COMMIT_HASH}..." - - docker build \ - --build-arg JAX_AI_IMAGE_BASEIMAGE=${BASEIMAGE} \ - --build-arg COMMIT_HASH=${COMMIT_HASH} \ - --build-arg DEVICE="$DEVICE" \ - --network=host \ - -t ${LOCAL_IMAGE_NAME} \ - -f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_jax_ai_image.Dockerfile' \ - . +# Function to build post-training image +build_post_training_image() { + DOCKERFILE_NAME="" + if [[ ${POST_TRAINING_SOURCE} == "local" ]] ; then + # To install vllm, tunix, tpu-inference from a local path, we copy it into the build context, excluding __pycache__. + # This assumes vllm, tunix, tpu-inference is a sibling directory to the current one (maxtext). + rsync -a --exclude='__pycache__' ../tpu-inference . + rsync -a --exclude='__pycache__' ../vllm . + rsync -a --exclude='__pycache__' ../tunix . + + # The cleanup is set to run even if the build fails to remove the copied directory. 
+ trap "rm -rf ./tpu-inference ./vllm ./tunix" EXIT INT TERM + + DOCKERFILE_NAME='maxtext_post_training_local_dependencies.Dockerfile' + echo "Building local post-training dependencies: $DOCKERFILE_NAME" + else + DOCKERFILE_NAME='maxtext_post_training_dependencies.Dockerfile' + echo "Building remote post-training dependencies: $DOCKERFILE_NAME" + fi + run_docker_build "$MAXTEXT_REPO_ROOT/dependencies/dockerfiles/$DOCKERFILE_NAME" \ + "MODE=${WORKFLOW}" "BASEIMAGE=${LOCAL_IMAGE_NAME}" +} + +# Function to build image for GPUs +build_gpu_image() { + if [[ ${MODE} == "pinned" ]]; then + local base_image=ghcr.io/nvidia/jax:base-2024-12-04 + docker_build_args+=("BASEIMAGE=${base_image}") + fi + + echo "Building docker image with arguments: ${docker_build_args[*]}" + run_docker_build "$MAXTEXT_REPO_ROOT/dependencies/dockerfiles/maxtext_gpu_dependencies.Dockerfile" "${docker_build_args[@]}" } -if [[ -z ${LIBTPU_GCS_PATH+x} ]] ; then - export LIBTPU_GCS_PATH=NONE - echo "Default LIBTPU_GCS_PATH=${LIBTPU_GCS_PATH}" - if [[ ${DEVICE} == "gpu" ]]; then - if [[ ${MODE} == "stable_stack" || ${MODE} == "jax_ai_image" ]]; then - build_ai_image - else - if [[ ${MODE} == "pinned" ]]; then - export BASEIMAGE=ghcr.io/nvidia/jax:base-2024-12-04 - else - export BASEIMAGE=ghcr.io/nvidia/jax:base - fi - docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION \ - --build-arg DEVICE=$DEVICE --build-arg BASEIMAGE=$BASEIMAGE \ - -f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_gpu_dependencies.Dockerfile' \ - -t ${LOCAL_IMAGE_NAME} . 
- fi +# Function to build image for TPUs +build_tpu_image() { + if [[ -n "$LIBTPU_VERSION" ]]; then + docker_build_args+=("LIBTPU_VERSION=${LIBTPU_VERSION}") else - if [[ ${MODE} == "stable_stack" || ${MODE} == "jax_ai_image" ]]; then - build_ai_image - elif [[ ${MANTARAY} == "true" ]]; then - echo "Building with benchmark-db" - docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION \ - --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE \ - -f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_db_dependencies.Dockerfile' \ - -t ${LOCAL_IMAGE_NAME} . - elif [[ ${INSTALL_POST_TRAINING} -eq 1 && ${DEVICE} == "tpu" ]]; then - echo "Installing MaxText stable mode dependencies for post-training" - docker build --network host --build-arg MODE=stable --build-arg JAX_VERSION=$JAX_VERSION \ - --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE \ - -f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_dependencies.Dockerfile' \ - -t ${LOCAL_IMAGE_NAME} . - else - docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION \ - --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH --build-arg DEVICE=$DEVICE \ - -f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_dependencies.Dockerfile' \ - -t ${LOCAL_IMAGE_NAME} . - fi + docker_build_args+=("LIBTPU_VERSION=NONE") fi -else - docker build --network host --build-arg MODE=${MODE} --build-arg JAX_VERSION=$JAX_VERSION \ - --build-arg LIBTPU_GCS_PATH=$LIBTPU_GCS_PATH \ - -f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_dependencies.Dockerfile' \ - -t ${LOCAL_IMAGE_NAME} . - docker build --network host --build-arg CUSTOM_LIBTPU=true \ - -f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_libtpu_path.Dockerfile' \ - -t ${LOCAL_IMAGE_NAME} . 
-fi -if [[ ${INSTALL_POST_TRAINING} -eq 1 ]] ; then - if [[ ${DEVICE} != "tpu" ]] ; then - echo "Error: MODE=post-training is only supported for DEVICE=tpu" - exit 1 + if [[ ${MANTARAY} == "true" ]]; then + local base_image=gcr.io/tpu-prod-env-one-vm/benchmark-db:2025-02-14 + docker_build_args+=("BASEIMAGE=${base_image}") fi - DOCKERFILE_NAME="" - if [[ ${POST_TRAINING_SOURCE} == "local" ]] ; then - - # To install tpu-inference from a local path, we copy it into the build context, excluding __pycache__. - # This assumes vllm, tunix, tpu-inference is a sibling directory to the current one (maxtext). - rsync -a --exclude='__pycache__' ../tpu-inference . - # To install vllm from a local path, we copy it into the build context, excluding __pycache__. - # This assumes vllm is a sibling directory to the current one (maxtext). - rsync -a --exclude='__pycache__' ../vllm . - - rsync -a --exclude='__pycache__' ../tunix . - - # The cleanup is set to run even if the build fails to remove the copied directory. - trap "rm -rf ./tpu-inference ./vllm ./tunix" EXIT INT TERM - - DOCKERFILE_NAME='maxtext_post_training_local_dependencies.Dockerfile' - echo "Using local post-training dependencies Dockerfile: $DOCKERFILE_NAME" - else - DOCKERFILE_NAME='maxtext_post_training_dependencies.Dockerfile' - echo "Using remote post-training dependencies Dockerfile: $DOCKERFILE_NAME" - fi - - docker build \ - --network host \ - --build-arg BASEIMAGE=${LOCAL_IMAGE_NAME} \ - --build-arg MODE=${MODE} \ - -f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/'"$DOCKERFILE_NAME" \ - -t ${LOCAL_IMAGE_NAME} . 
-fi + echo "Building docker image with arguments: ${docker_build_args[*]}" + run_docker_build "$MAXTEXT_REPO_ROOT/dependencies/dockerfiles/maxtext_tpu_dependencies.Dockerfile" "${docker_build_args[@]}" + + # Handle post-training workflow if specified + if [[ ${WORKFLOW} == "post-training" || ${WORKFLOW} == "post-training-experimental" ]]; then + build_post_training_image + fi -if [[ ${CUSTOM_JAX} -eq 1 ]] ; then - echo "Installing custom jax and jaxlib" - docker build --network host \ - -f "$MAXTEXT_REPO_ROOT"'/dependencies/dockerfiles/maxtext_custom_wheels.Dockerfile' \ - -t ${LOCAL_IMAGE_NAME} . + # TODO: Remove 'custom_wheels' mode support when tpu-recipes migration is complete. + if [[ ${WORKFLOW} == "custom-wheels" ]]; then + echo "Building custom wheels dependencies." + run_docker_build "$MAXTEXT_REPO_ROOT/dependencies/dockerfiles/maxtext_custom_wheels.Dockerfile" "BASEIMAGE=${LOCAL_IMAGE_NAME}" + fi +} + +if [[ ${DEVICE} == "gpu" ]]; then + build_gpu_image +else + build_tpu_image fi echo "" diff --git a/docs/run_maxtext/run_maxtext_via_pathways.md b/docs/run_maxtext/run_maxtext_via_pathways.md index ca082f6c47..5ef3e29b1b 100644 --- a/docs/run_maxtext/run_maxtext_via_pathways.md +++ b/docs/run_maxtext/run_maxtext_via_pathways.md @@ -35,7 +35,7 @@ Before you can run a MaxText workload, you must complete the following setup ste Step 1: Build the Docker image for a TPU device. This image contains MaxText and its dependencies. 
```shell - bash dependencies/scripts/docker_build_dependency_image.sh DEVICE=tpu MODE=jax_ai_image BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest + bash dependencies/scripts/docker_build_dependency_image.sh DEVICE=tpu MODE=stable ``` Step 2: Configure Docker to authenticate with Google Cloud diff --git a/docs/run_maxtext/run_maxtext_via_xpk.md b/docs/run_maxtext/run_maxtext_via_xpk.md index 543f5595ee..a61f1dd987 100644 --- a/docs/run_maxtext/run_maxtext_via_xpk.md +++ b/docs/run_maxtext/run_maxtext_via_xpk.md @@ -117,8 +117,6 @@ pip install xpk ## 4. Build the MaxText Docker image -A recommended approach for running MaxText is to build your image from a **JAX AI Image**, which ensures all core libraries are version-matched and stable. - 1. **Clone the MaxText repository** ``` @@ -126,18 +124,18 @@ A recommended approach for running MaxText is to build your image from a **JAX A cd maxtext ``` -2. **Build the image for your target hardware (TPU or GPU)** This script creates a local Docker image named `maxtext_base_image`. You can find a full list of available base images in the [JAX AI Images documentation](https://cloud.google.com/ai-hypercomputer/docs/images). +2. **Build the image for your target hardware (TPU or GPU)** This script creates a local Docker image named `maxtext_base_image`. 
- **For TPUs:** ``` - bash docker_build_dependency_image.sh DEVICE=tpu MODE=jax_ai_image BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.5.2-rev2 + bash dependencies/scripts/docker_build_dependency_image.sh DEVICE=tpu MODE=stable ``` - **For GPUs:** ``` - bash docker_build_dependency_image.sh DEVICE=gpu MODE=jax_ai_image BASEIMAGE=us-central1-docker.pkg.dev/deeplearning-images/jax-ai-image/gpu:jax0.5.1-cuda_dl25.02-rev1 + bash dependencies/scripts/docker_build_dependency_image.sh DEVICE=gpu MODE=stable ``` * * * * * diff --git a/docs/tutorials/posttraining/rl_on_multi_host.md b/docs/tutorials/posttraining/rl_on_multi_host.md index 239d154e92..eaeef3d5e2 100644 --- a/docs/tutorials/posttraining/rl_on_multi_host.md +++ b/docs/tutorials/posttraining/rl_on_multi_host.md @@ -96,16 +96,16 @@ Run the following bash script to create a docker image with all the dependencies In addition to MaxText dependencies, primarily, it installs `vllm-tpu` which is [vllm](https://github.com/vllm-project/vllm) and [tpu-inference](https://github.com/vllm-project/tpu-inference) and thereby providing TPU inference for vLLM, with unified JAX and PyTorch support. This build process takes approximately 10 to 15 minutes. ``` -bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training +bash dependencies/scripts/docker_build_dependency_image.sh WORKFLOW=post-training ``` -You can also use `bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training-experimental` to try out new features via experimental dependencies such as improved pathwaysutils resharding API. +You can also use `bash dependencies/scripts/docker_build_dependency_image.sh WORKFLOW=post-training-experimental` to try out new features via experimental dependencies such as the improved pathwaysutils resharding API. 
### Option 2: Install from locally git cloned repositories You can also locally git clone [tunix](https://github.com/google/tunix), [tpu-inference](https://github.com/vllm-project/tpu-inference), [vllm](https://github.com/vllm-project/vllm.git) and then use the following command to build a docker image using them: ``` -bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training POST_TRAINING_SOURCE=local +bash dependencies/scripts/docker_build_dependency_image.sh WORKFLOW=post-training POST_TRAINING_SOURCE=local ``` ### Upload the dependency docker image along with MaxText code diff --git a/docs/tutorials/posttraining/sft_on_multi_host.md b/docs/tutorials/posttraining/sft_on_multi_host.md index 80a008c5cb..a55c24bb39 100644 --- a/docs/tutorials/posttraining/sft_on_multi_host.md +++ b/docs/tutorials/posttraining/sft_on_multi_host.md @@ -45,7 +45,7 @@ docker run hello-world ``` Then run the following command to create a local Docker image named `maxtext_base_image`. This build process takes approximately 10 to 15 minutes. ```bash -bash dependencies/scripts/docker_build_dependency_image.sh MODE=post-training +bash dependencies/scripts/docker_build_dependency_image.sh WORKFLOW=post-training ``` ### 1.3. Upload the Docker image to Artifact Registry diff --git a/preflight.sh b/preflight.sh index 50da4cf4cd..1eb3025e9d 100644 --- a/preflight.sh +++ b/preflight.sh @@ -6,7 +6,7 @@ echo "Running preflight.sh" # bash preflight.sh # Warning: -# For any dependencies, please add them into `setup.sh` or `maxtext_dependencies.Dockerfile`. +# For any dependencies, please add them into `setup.sh` or `maxtext_tpu_dependencies.Dockerfile`. # You should not install any dependencies in this file. 
# Stop execution if any command exits with error diff --git a/tools/setup/setup.sh b/tools/setup/setup.sh index f26be1ddcc..08b5418cc0 100644 --- a/tools/setup/setup.sh +++ b/tools/setup/setup.sh @@ -33,6 +33,9 @@ # Install dependencies in dependencies/generated_requirements/tpu-requirements.txt + specified jax-nightly, jaxlib-nightly + latest libtpu-nightly ## bash tools/setup/setup.sh MODE=nightly JAX_VERSION=0.8.2.dev20251211 +# Install dependencies in dependencies/generated_requirements/tpu-requirements.txt + specified jax-nightly, jaxlib-nightly + specific libtpu-nightly +## bash tools/setup/setup.sh MODE=nightly JAX_VERSION=0.8.1 LIBTPU_VERSION=0.0.31.dev20251119+nightly + # Install dependencies in dependencies/generated_requirements/tpu-requirements.txt + jax-nightly, jaxlib-nightly + custom libtpu ## bash tools/setup/setup.sh MODE=nightly LIBTPU_GCS_PATH=gs://my_custom_libtpu/libtpu.so @@ -153,7 +156,7 @@ fi # Unset optional variables if set to NONE unset_optional_vars() { - local optional_vars=("JAX_VERSION" "LIBTPU_GCS_PATH") + local optional_vars=("JAX_VERSION" "LIBTPU_VERSION" "LIBTPU_GCS_PATH") for var_name in "${optional_vars[@]}"; do if [[ ${!var_name} == NONE ]]; then unset "$var_name" @@ -221,6 +224,10 @@ if [[ "$MODE" == "stable" ]]; then fi if [[ -n "$LIBTPU_GCS_PATH" ]]; then install_custom_libtpu + elif [[ -n "$LIBTPU_VERSION" ]]; then + echo -e "\nInstalling libtpu ${LIBTPU_VERSION}" + version_mismatch_warning "libtpu" + python3 -m uv pip install -U --no-deps libtpu==${LIBTPU_VERSION} -f https://storage.googleapis.com/jax-releases/libtpu_releases.html fi elif [[ $DEVICE == "gpu" ]]; then if [[ -n "$JAX_VERSION" ]]; then @@ -249,9 +256,12 @@ if [[ $MODE == "nightly" ]]; then echo -e "\nInstalling the latest jax-nightly, jaxlib-nightly" python3 -m uv pip install --pre -U --no-deps jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ fi - if [[ -n "$LIBTPU_GCS_PATH" ]]; then install_custom_libtpu + elif [[ -n 
"$LIBTPU_VERSION" ]]; then + echo -e "\nInstalling libtpu ${LIBTPU_VERSION}" + version_mismatch_warning "libtpu" + python3 -m uv pip install -U --no-deps libtpu==${LIBTPU_VERSION} -f https://storage.googleapis.com/jax-releases/libtpu_releases.html else echo -e "\nInstalling the latest libtpu-nightly" python3 -m uv pip install -U --pre --no-deps libtpu -f https://storage.googleapis.com/jax-releases/libtpu_releases.html