Run PyTorch code on TPU slices

Before running the commands in this document, make sure you have followed the instructions in Set up an account and Cloud TPU project.

After you have your PyTorch code running on a single TPU VM, you can scale up your code by running it on a TPU slice. TPU slices are multiple TPU boards connected to each other over dedicated high-speed network connections. This document is an introduction to running PyTorch code on TPU slices.

Create a Cloud TPU slice

  1. Define some environment variables to make the commands easier to use.

     export PROJECT_ID=your-project-id
     export TPU_NAME=your-tpu-name
     export ZONE=europe-west4-b
     export ACCELERATOR_TYPE=v5p-32
     export RUNTIME_VERSION=v2-alpha-tpuv5

    Environment variable descriptions

    Variable            Description
    PROJECT_ID          Your Google Cloud project ID. Use an existing project or create a new one.
    TPU_NAME            The name of the TPU.
    ZONE                The zone in which to create the TPU VM. For more information about supported zones, see TPU regions and zones.
    ACCELERATOR_TYPE    The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
    RUNTIME_VERSION     The Cloud TPU software version.
  2. Create your TPU VM by running the following command:

     $ gcloud compute tpus tpu-vm create ${TPU_NAME} \
         --zone=${ZONE} \
         --project=${PROJECT_ID} \
         --accelerator-type=${ACCELERATOR_TYPE} \
         --version=${RUNTIME_VERSION}
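
     Optionally, before moving on, you can check that the slice was created and has finished provisioning. The following check is not part of the original steps; it uses the standard gcloud describe command, and the output typically includes a state field that reads READY once the slice is usable:

     $ gcloud compute tpus tpu-vm describe ${TPU_NAME} \
         --zone=${ZONE} \
         --project=${PROJECT_ID}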

Install PyTorch/XLA on your slice

After creating the TPU slice, you must install PyTorch on all hosts in the TPU slice. You can do this with the gcloud compute tpus tpu-vm ssh command, using the --worker=all and --command parameters.

If the following commands fail due to an SSH connection error, it might be because the TPU VMs don't have external IP addresses. To access a TPU VM without an external IP address, follow the instructions in Connect to a TPU VM without a public IP address.

  1. Install PyTorch/XLA on all TPU VM workers:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --worker=all \
        --command="pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
  2. Clone XLA on all TPU VM workers:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --worker=all \
        --command="git clone https://github.com/pytorch/xla.git"
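
To confirm that PyTorch/XLA was installed on every worker, you can optionally run a quick version check. This is a sketch that reuses the same ssh pattern as the steps above; the versions printed depend on the wheels you installed:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
        --zone=${ZONE} \
        --project=${PROJECT_ID} \
        --worker=all \
        --command="python3 -c 'import torch, torch_xla; print(torch.__version__, torch_xla.__version__)'"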

Run a training script on your TPU slice

Run the training script on all workers. The training script uses a Single Program Multiple Data (SPMD) sharding strategy. For more information on SPMD, see PyTorch/XLA SPMD User Guide.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --zone=${ZONE} \
    --project=${PROJECT_ID} \
    --worker=all \
    --command="PJRT_DEVICE=TPU python3 ~/xla/test/spmd/test_train_spmd_imagenet.py \
    --fake_data \
    --model=resnet50 \
    --num_epochs=1 2>&1 | tee ~/logs.txt"
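
Because the command pipes each worker's output through tee into ~/logs.txt on that worker, you can optionally follow progress from Cloud Shell while training runs. For example, to look at the log on the first worker (a sketch; --worker=0 selects worker 0):

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --zone=${ZONE} \
    --project=${PROJECT_ID} \
    --worker=0 \
    --command="tail -n 20 ~/logs.txt"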

The training takes about 15 minutes. When it completes, you should see a message similar to the following:

Epoch 1 test end 23:49:15, Accuracy=100.00
     10.164.0.11 [0] Max Accuracy: 100.00%
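
For context, the following is a minimal, self-contained sketch of the SPMD sharding pattern mentioned at the start of this section. It uses the public torch_xla.distributed.spmd API; the mesh shape, axis names, and tensor are illustrative, and it is not an excerpt from test_train_spmd_imagenet.py:

    # spmd_sketch.py -- run on a worker with: PJRT_DEVICE=TPU python3 spmd_sketch.py
    import numpy as np
    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.runtime as xr
    import torch_xla.distributed.spmd as xs
    from torch_xla.distributed.spmd import Mesh

    # Enable SPMD execution mode before creating any XLA tensors.
    xr.use_spmd()

    # Build a logical mesh over every TPU device in the slice.
    num_devices = xr.global_runtime_device_count()
    device_ids = np.arange(num_devices)
    mesh = Mesh(device_ids, mesh_shape=(num_devices, 1), axis_names=("data", "model"))

    # Shard a tensor along the "data" axis; each device holds one slice of dim 0.
    t = torch.randn(num_devices * 2, 128).to(xm.xla_device())
    xs.mark_sharding(t, mesh, ("data", "model"))
    print("global shape", t.shape, "sharded across", num_devices, "devices")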

Clean up

When you are done with your TPU VM, follow these steps to clean up your resources.

  1. Disconnect from the Cloud TPU instance, if you have not already done so:

     (vm)$ exit

    Your prompt should now be username@projectname, showing you are in the Cloud Shell.

  2. Delete your Cloud TPU resources.

     $ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
         --zone=${ZONE} \
         --project=${PROJECT_ID}
  3. Verify the resources have been deleted by running gcloud compute tpus tpu-vm list. The deletion might take several minutes. The output from the following command shouldn't include any of the resources created in this tutorial:

     $ gcloud compute tpus tpu-vm list --zone=${ZONE}