Run PyTorch code on TPU slices
Before running the commands in this document, make sure you have followed the instructions in Set up an account and Cloud TPU project.
After you have your PyTorch code running on a single TPU VM, you can scale up your code by running it on a TPU slice. TPU slices are multiple TPU boards connected to each other over dedicated high-speed network connections. This document is an introduction to running PyTorch code on TPU slices.
Required roles
To get the permissions that you need to create a TPU and connect to it using SSH, ask your administrator to grant you the following IAM roles on your project:
- TPU Admin (roles/tpu.admin)
- Service Account User (roles/iam.serviceAccountUser)
- Compute Viewer (roles/compute.viewer)
For more information about granting roles, see Manage access to projects, folders, and organizations.
You might also be able to get the required permissions through custom roles or other predefined roles.
Create a Cloud TPU slice
- Define some environment variables to make the commands easier to use.

  export PROJECT_ID=your-project-id
  export TPU_NAME=your-tpu-name
  export ZONE=europe-west4-b
  export ACCELERATOR_TYPE=v5p-32
  export RUNTIME_VERSION=v2-alpha-tpuv5
Environment variable descriptions
- PROJECT_ID - Your Google Cloud project ID. Use an existing project or create a new one.
- TPU_NAME - The name of the TPU.
- ZONE - The zone in which to create the TPU VM. For more information about supported zones, see TPU regions and zones.
- ACCELERATOR_TYPE - The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
- RUNTIME_VERSION - The Cloud TPU software version.
- Create your TPU VM by running the following command:
  $ gcloud compute tpus tpu-vm create ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --accelerator-type=${ACCELERATOR_TYPE} \
      --version=${RUNTIME_VERSION}
Install PyTorch/XLA on your slice
After creating the TPU slice, you must install PyTorch/XLA on all hosts in the slice. You can do this using the gcloud compute tpus tpu-vm ssh command with the --worker=all and --command parameters.
If the following commands fail due to an SSH connection error, it might be because the TPU VMs don't have external IP addresses. To access a TPU VM without an external IP address, follow the instructions in Connect to a TPU VM without a public IP address.
- Install PyTorch/XLA on all TPU VM workers:

  gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all \
      --command="pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
- Clone the PyTorch/XLA repository on all TPU VM workers:

  gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all \
      --command="git clone https://github.com/pytorch/xla.git"
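Before moving on, it can help to confirm that PyTorch/XLA can see the TPU devices. The following is a minimal sketch of such a check (the file name check_xla.py is illustrative, not part of the cloned repository); it assumes the torch_xla 2.5 installation from the previous step:

  # check_xla.py - minimal sanity check for a PyTorch/XLA installation (illustrative)
  import torch
  import torch_xla.core.xla_model as xm

  # Acquire the default XLA (TPU) device for this process.
  device = xm.xla_device()

  # Run a small computation on the TPU and copy the result back to the host.
  x = torch.randn(2, 2, device=device)
  y = x @ x
  print("XLA device:", device)
  print("Result:", y.cpu())

You can run this on every worker with the same gcloud compute tpus tpu-vm ssh --worker=all --command pattern used above.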
Run a training script on your TPU slice
Run the training script on all workers. The training script uses a Single Program Multiple Data (SPMD) sharding strategy. For more information on SPMD, see the PyTorch/XLA SPMD User Guide. A short sketch of the SPMD pattern follows the sample output below.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --zone=${ZONE} \
    --project=${PROJECT_ID} \
    --worker=all \
    --command="PJRT_DEVICE=TPU python3 ~/xla/test/spmd/test_train_spmd_imagenet.py \
      --fake_data \
      --model=resnet50 \
      --num_epochs=1 2>&1 | tee ~/logs.txt"
The training takes about 15 minutes. When it completes, you should see a message similar to the following:
Epoch 1 test end 23:49:15, Accuracy=100.00
10.164.0.11 [0] Max Accuracy: 100.00%
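The test script takes care of the SPMD setup for you, but the core pattern is small. The following sketch shows how SPMD sharding is typically enabled with torch_xla 2.5; the mesh shape and the example tensor are illustrative and are not taken from the test script:

  # spmd_sketch.py - illustrative outline of PyTorch/XLA SPMD sharding (not the test script)
  import numpy as np
  import torch
  import torch_xla.core.xla_model as xm
  import torch_xla.runtime as xr
  import torch_xla.distributed.spmd as xs

  # Switch the runtime into SPMD mode before putting tensors on the device.
  xr.use_spmd()

  # Build a one-dimensional device mesh over every TPU chip in the slice.
  num_devices = xr.global_runtime_device_count()
  mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))

  # Shard a batch of inputs along the 'data' axis; each device holds one shard.
  batch = torch.randn(128, 3, 224, 224).to(xm.xla_device())
  xs.mark_sharding(batch, mesh, ('data', None, None, None))

  # Later model code treats `batch` as a single logical tensor; the XLA
  # compiler partitions the computation across the mesh.

Because every worker runs the same program under SPMD, this code is launched unchanged on all hosts, which is why the training command above uses --worker=all.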
Clean up
When you are done with your TPU VM, follow these steps to clean up your resources.
- Disconnect from the Cloud TPU instance, if you have not already done so:

  (vm)$ exit

  Your prompt should now be username@projectname, showing you are in the Cloud Shell.
- Delete your Cloud TPU resources.
  $ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
      --zone=${ZONE}
- Verify the resources have been deleted by running gcloud compute tpus tpu-vm list. The deletion might take several minutes. The output from the following command shouldn't include any of the resources created in this tutorial:

  $ gcloud compute tpus tpu-vm list --zone=${ZONE}

