Run TPU workloads in a Docker container

Docker containers make configuring applications easier by combining your code and all needed dependencies in one distributable package. You can run Docker containers within TPU VMs to simplify configuring and sharing your Cloud TPU applications. This document describes how to set up a Docker container for each ML framework supported by Cloud TPU.

Train a PyTorch model in a Docker container

TPU device

  1. Create a Cloud TPU VM

    gcloud compute tpus tpu-vm create your-tpu-name \
      --zone=europe-west4-a \
      --accelerator-type=v2-8 \
      --version=tpu-ubuntu2204-base

  2. Connect to the TPU VM using SSH

    gcloud compute tpus tpu-vm ssh your-tpu-name \
      --zone=europe-west4-a

  3. Make sure your Google Cloud user has been granted the Artifact Registry Reader role. For more information, see Granting Artifact Registry roles.

  4. Start a container in the TPU VM using the PyTorch/XLA r2.6.0 release image

    sudo docker run --net=host -ti --rm --name your-container-name --privileged \
      us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11 \
      bash

  5. Configure the TPU runtime

    There are two PyTorch/XLA runtime options: PJRT and XRT. We recommend you use PJRT unless you have a reason to use XRT; XRT is deprecated in recent PyTorch/XLA releases. To learn more about the different runtime configurations, see the PJRT runtime documentation.

    PJRT

      export PJRT_DEVICE=TPU

    XRT

      export XRT_TPU_CONFIG="localservice;0;localhost:51011"

  6. Clone the PyTorch XLA repository

    git clone --recursive https://github.com/pytorch/xla.git

  7. Train ResNet50

    python3 xla/test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1
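Steps 4 through 7 can also be run non-interactively by passing the runtime variable with `-e` and the clone and training commands with `bash -c`, instead of opening a shell in the container. The following is a minimal sketch under that assumption, using the same image and test script as above; the function name `run_resnet50` and the `DRY_RUN` switch (print the command instead of running it) are hypothetical additions for illustration.

```shell
#!/usr/bin/env bash
# Sketch: run the ResNet50 fake-data training from steps 4-7 in a single
# non-interactive container. run_resnet50 and DRY_RUN are hypothetical names;
# when DRY_RUN is set, the docker command is printed instead of executed.
set -euo pipefail

IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11"

run_resnet50() {
  # -e PJRT_DEVICE=TPU replaces the interactive `export` from step 5.
  # bash -c clones the repository and starts training inside the container.
  ${DRY_RUN:+echo} sudo docker run --net=host --rm --privileged \
    -e PJRT_DEVICE=TPU \
    "${IMAGE}" \
    bash -c "git clone --recursive https://github.com/pytorch/xla.git && python3 xla/test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1"
}
```

After sourcing the sketch in an SSH session on the TPU VM, call `run_resnet50`, or `DRY_RUN=1 run_resnet50` to see the exact command first. Because the container is started with `--rm`, it is removed when training finishes.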
    

When the training script completes, make sure you clean up the resources.

  1. Type exit to exit the Docker container
  2. Type exit to exit the TPU VM
  3. Delete the TPU VM

    gcloud compute tpus tpu-vm delete your-tpu-name \
      --zone=europe-west4-a

TPU slice

When you run PyTorch code on a TPU slice, you must run your code on all TPU workers at the same time. One way to do this is to use the gcloud compute tpus tpu-vm ssh command with the --worker=all and --command flags. The following procedure shows you how to use a prebuilt PyTorch/XLA Docker image to make setting up each TPU worker easier.
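The --worker=all pattern can be wrapped in a small helper so that each per-worker step below reads as one call. This is a sketch, not part of the gcloud CLI: the function name `run_on_all_workers` and the `DRY_RUN` switch are hypothetical, and the only gcloud flags used are the ones that already appear on this page.

```shell
#!/usr/bin/env bash
# Hypothetical helper: run one shell command on every worker of a TPU slice
# through `gcloud compute tpus tpu-vm ssh` with --worker=all and --command.
# Set DRY_RUN=1 to print the gcloud invocation instead of running it.
set -euo pipefail

run_on_all_workers() {
  local tpu_name="$1"  # for example, your-tpu-name
  local zone="$2"      # for example, us-central2-b
  local cmd="$3"       # command run by the remote shell on each worker
  ${DRY_RUN:+echo} gcloud compute tpus tpu-vm ssh "${tpu_name}" \
    --zone="${zone}" \
    --worker=all \
    --command="${cmd}"
}
```

For example, step 2 below becomes `run_on_all_workers your-tpu-name us-central2-b 'sudo usermod -a -G docker $USER'`. Single quotes keep `$USER` unexpanded locally so that each worker substitutes its own value.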

  1. Create a TPU VM

    gcloud compute tpus tpu-vm create your-tpu-name \
      --zone=us-central2-b \
      --accelerator-type=v4-32 \
      --version=tpu-ubuntu2204-base

  2. Add the current user to the Docker group

    gcloud compute tpus tpu-vm ssh your-tpu-name \
      --zone=us-central2-b \
      --worker=all \
      --command='sudo usermod -a -G docker $USER'

  3. Clone the PyTorch XLA repository

    gcloud compute tpus tpu-vm ssh your-tpu-name \
      --worker=all \
      --zone=us-central2-b \
      --command="git clone --recursive https://github.com/pytorch/xla.git"

  4. Run the training script in a container on all TPU workers

    gcloud compute tpus tpu-vm ssh your-tpu-name \
      --worker=all \
      --zone=us-central2-b \
      --command="docker run --rm --privileged --net=host -v ~/xla:/xla -e PJRT_DEVICE=TPU us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11 python /xla/test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1"

    Docker command flags:

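    • --rm removes the container after its process terminates.
    • --privileged exposes the TPU device to the container.
    • --net=host makes the container use the TPU VM's host network, which allows communication between the hosts in the slice.
    • -v mounts the xla repository cloned on each worker (~/xla) into the container at /xla.
    • -e sets environment variables; here, PJRT_DEVICE=TPU selects the PJRT TPU runtime.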
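The flags described above can be kept in a bash array, one per line with a comment, and joined into the single string that --command expects. A minimal sketch, assuming the same image and training script as step 4; the variable names `DOCKER_ARGS`, `TRAIN_CMD`, and `REMOTE_CMD` are hypothetical.

```shell
#!/usr/bin/env bash
# Sketch: assemble the docker command from step 4 one flag at a time.
set -euo pipefail

IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11"

DOCKER_ARGS=(
  --rm                 # remove the container when training exits
  --privileged         # expose the TPU device to the container
  --net=host           # use the TPU VM's host network
  -v '~/xla:/xla'      # mount the repository cloned on each worker
  -e PJRT_DEVICE=TPU   # select the PJRT TPU runtime
)
TRAIN_CMD="python /xla/test/test_train_mp_imagenet.py --fake_data --model=resnet50 --num_epochs=1"

# Join everything into one string for --command. The tilde stays literal here
# (it is single-quoted above) so that the remote shell on each worker expands it.
REMOTE_CMD="docker run ${DOCKER_ARGS[*]} ${IMAGE} ${TRAIN_CMD}"
```

You would then pass the string as `--command="${REMOTE_CMD}"` to the same `gcloud compute tpus tpu-vm ssh your-tpu-name --worker=all --zone=us-central2-b` call shown in step 4.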

When the training script completes, make sure you clean up the resources.

Delete the TPU VM using the following command:

gcloud compute tpus tpu-vm delete your-tpu-name \
  --zone=us-central2-b

Train a JAX model in a Docker container

TPU device

  1. Create the TPU VM

    gcloud compute tpus tpu-vm create your-tpu-name \
      --zone=europe-west4-a \
      --accelerator-type=v2-8 \
      --version=tpu-ubuntu2204-base

  2. Connect to the TPU VM using SSH

    gcloud compute tpus tpu-vm ssh your-tpu-name \
      --zone=europe-west4-a

  3. Start the Docker daemon in the TPU VM

    sudo systemctl start docker

  4. Start a Docker container

    sudo docker run --net=host -ti --rm --name your-container-name --privileged \
      us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11 \
      bash

  5. Install JAX

    pip install "jax[tpu]"

  6. Install Flax

    pip install --upgrade clu
    git clone https://github.com/google/flax.git
    pip install --user -e flax

  7. Install the tensorflow and tensorflow-datasets packages

    pip install tensorflow
    pip install tensorflow-datasets

  8. Run the Flax MNIST training script

    cd flax/examples/mnist
    python3 main.py --workdir=/tmp/mnist \
      --config=configs/default.py \
      --config.learning_rate=0.05 \
      --config.num_epochs=5

When the training script completes, make sure you clean up the resources.

  1. Type exit to exit the Docker container
  2. Type exit to exit the TPU VM
  3. Delete the TPU VM

    gcloud compute tpus tpu-vm delete your-tpu-name \
      --zone=europe-west4-a

TPU slice

When you run JAX code on a TPU slice, you must run your JAX code on all TPU workers at the same time. One way to do this is to use the gcloud compute tpus tpu-vm ssh command with the --worker=all and --command flags. The following procedure shows you how to create a Docker image to make setting up each TPU worker easier.

  1. Create a file named Dockerfile in your current directory and paste the following text

    FROM python:3.10
    RUN pip install "jax[tpu]"
    RUN pip install --upgrade clu
    RUN git clone https://github.com/google/flax.git
    RUN pip install --user -e flax
    RUN pip install tensorflow
    RUN pip install tensorflow-datasets
    WORKDIR ./flax/examples/mnist

  2. Prepare an Artifact Registry repository

    gcloud artifacts repositories create your-repo \
      --repository-format=docker \
      --location=europe-west4 \
      --description="Docker repository" \
      --project=your-project

    gcloud artifacts repositories list \
      --project=your-project

    gcloud auth configure-docker europe-west4-docker.pkg.dev

  3. Build the Docker image

    docker build -t your-image-name .

  4. Add a tag to your Docker image before pushing it to the Artifact Registry. For more information on working with Artifact Registry, see Work with container images.

    docker tag your-image-name \
      europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag

  5. Push your Docker image to the Artifact Registry

    docker push europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag

  6. Create a TPU VM

    gcloud compute tpus tpu-vm create your-tpu-name \
      --zone=europe-west4-a \
      --accelerator-type=v2-8 \
      --version=tpu-ubuntu2204-base

  7. Pull the Docker image from the Artifact Registry on all TPU workers

    gcloud compute tpus tpu-vm ssh your-tpu-name \
      --worker=all \
      --zone=europe-west4-a \
      --command='sudo usermod -a -G docker ${USER}'

    gcloud compute tpus tpu-vm ssh your-tpu-name \
      --worker=all \
      --zone=europe-west4-a \
      --command="gcloud auth configure-docker europe-west4-docker.pkg.dev --quiet"

    gcloud compute tpus tpu-vm ssh your-tpu-name \
      --worker=all \
      --zone=europe-west4-a \
      --command="docker pull europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag"

  8. Run the container on all TPU workers

    gcloud compute tpus tpu-vm ssh your-tpu-name \
      --worker=all \
      --zone=europe-west4-a \
      --command="docker run -ti -d --privileged --net=host --name your-container-name europe-west4-docker.pkg.dev/your-project/your-repo/your-image-name:your-tag bash"

  9. Run the training script on all TPU workers

    gcloud compute tpus tpu-vm ssh your-tpu-name \
      --worker=all \
      --zone=europe-west4-a \
      --command="docker exec --privileged your-container-name python3 main.py --workdir=/tmp/mnist --config=configs/default.py --config.learning_rate=0.05 --config.num_epochs=5"
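Steps 4 through 9 repeat the same Artifact Registry image path, which has the form LOCATION-docker.pkg.dev/PROJECT/REPOSITORY/IMAGE:TAG. One way to keep every copy consistent is to build the path once from its parts. The sketch below assumes the placeholder values used on this page; the variable names are hypothetical.

```shell
#!/usr/bin/env bash
# Sketch: compose the Artifact Registry image path used in steps 4-9 once.
# Values are the placeholders from this page; replace them with your own.
set -euo pipefail

LOCATION="europe-west4"
PROJECT="your-project"
REPO="your-repo"
IMAGE_NAME="your-image-name"
TAG="your-tag"

# LOCATION-docker.pkg.dev/PROJECT/REPOSITORY/IMAGE:TAG
IMAGE_URI="${LOCATION}-docker.pkg.dev/${PROJECT}/${REPO}/${IMAGE_NAME}:${TAG}"

# Remote commands for steps 7 and 8, derived from the same path.
PULL_CMD="docker pull ${IMAGE_URI}"
RUN_CMD="docker run -ti -d --privileged --net=host --name your-container-name ${IMAGE_URI} bash"
```

You can then use `docker tag your-image-name "${IMAGE_URI}"` and `docker push "${IMAGE_URI}"` locally, and pass `--command="${PULL_CMD}"` or `--command="${RUN_CMD}"` to the per-worker gcloud commands.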
    

When the training script completes, make sure you clean up the resources.

  1. Shut down the container on all workers

    gcloud compute tpus tpu-vm ssh your-tpu-name \
      --worker=all \
      --zone=europe-west4-a \
      --command="docker kill your-container-name"

  2. Delete the TPU VM

    gcloud compute tpus tpu-vm delete your-tpu-name \
      --zone=europe-west4-a
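The two cleanup steps can be combined into one function so that they are always run in order. This is a minimal sketch: the function name `cleanup_tpu` and the `DRY_RUN` switch are hypothetical, and `--quiet` is gcloud's global flag for skipping the interactive confirmation prompt.

```shell
#!/usr/bin/env bash
# Hypothetical cleanup helper for the JAX slice walkthrough: stop the
# container on every worker, then delete the TPU VM.
# Set DRY_RUN=1 to print the commands instead of running them.
set -euo pipefail

cleanup_tpu() {
  local tpu_name="$1" zone="$2" container="$3"
  # Step 1: stop the container on all workers.
  ${DRY_RUN:+echo} gcloud compute tpus tpu-vm ssh "${tpu_name}" \
    --worker=all --zone="${zone}" --command="docker kill ${container}"
  # Step 2: delete the TPU VM without prompting for confirmation.
  ${DRY_RUN:+echo} gcloud compute tpus tpu-vm delete "${tpu_name}" \
    --zone="${zone}" --quiet
}
```

For example: `cleanup_tpu your-tpu-name europe-west4-a your-container-name`.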
    

Train a JAX model in a Docker container using JAX Stable Stack

You can build the MaxText and MaxDiffusion Docker images using the JAX Stable Stack base image.

JAX Stable Stack provides a consistent environment for MaxText and MaxDiffusion by bundling JAX with core packages like orbax, flax, optax, and libtpu.so. These libraries are tested to ensure compatibility and provide a stable foundation to build and run MaxText and MaxDiffusion. This eliminates potential conflicts due to incompatible package versions.

JAX Stable Stack includes a fully released and qualified libtpu.so, the core library that drives TPU program compilation, execution, and ICI network configuration. The libtpu release replaces the nightly build previously used by JAX and ensures consistent functionality of XLA computations on TPU with PJRT-level qualification tests in HLO/StableHLO IRs.

To build the MaxText and MaxDiffusion Docker image with JAX Stable Stack, when you run the docker_build_dependency_image.sh script, set the MODE variable to stable_stack and set the BASEIMAGE variable to the base image you want to use.

docker_build_dependency_image.sh is located in the MaxDiffusion GitHub repo and in the MaxText GitHub repo. Clone the repository you want to use and run the docker_build_dependency_image.sh script from that repository to build the Docker image.

git clone https://github.com/AI-Hypercomputer/maxdiffusion.git
git clone https://github.com/AI-Hypercomputer/maxtext.git

The following command generates a Docker image for use with MaxText and MaxDiffusion using us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1 as the base image.

sudo bash docker_build_dependency_image.sh MODE=stable_stack BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1
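Putting the clone and build steps together for one repository looks roughly like the following sketch. It assumes the same base image as above and only the MODE and BASEIMAGE arguments described on this page; the function names `run` and `build_stable_stack_image` and the `DRY_RUN` switch are hypothetical.

```shell
#!/usr/bin/env bash
# Sketch: clone MaxText or MaxDiffusion and build its dependency image on a
# JAX Stable Stack base image. Set DRY_RUN=1 to print the commands only.
set -euo pipefail

BASEIMAGE="us-docker.pkg.dev/cloud-tpu-images/jax-stable-stack/tpu:jax0.4.35-rev1"

# Run a command, or only print it when DRY_RUN is set.
run() {
  if [[ -n "${DRY_RUN:-}" ]]; then
    echo "$*"
  else
    "$@"
  fi
}

# The body runs in a subshell, so the cd does not change the caller's directory.
build_stable_stack_image() (
  repo_url="$1"                                # MaxText or MaxDiffusion URL
  repo_dir="$(basename "${repo_url}" .git)"   # for example, maxtext
  run git clone "${repo_url}"
  # docker_build_dependency_image.sh is run from the repository root.
  run cd "${repo_dir}"
  run sudo bash docker_build_dependency_image.sh \
    MODE=stable_stack BASEIMAGE="${BASEIMAGE}"
)
```

For example, `build_stable_stack_image https://github.com/AI-Hypercomputer/maxtext.git` builds the MaxText image; change `BASEIMAGE` to use a different JAX Stable Stack base image.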

For a list of available JAX Stable Stack base images, see JAX Stable Stack images in Artifact Registry.
