Run a calculation on a Cloud TPU VM using PyTorch

This document provides a brief introduction to working with PyTorch and Cloud TPU.

Before you begin

Before running the commands in this document, you must create a Google Cloud account, install the Google Cloud CLI, and configure the gcloud command. For more information, see Set up the Cloud TPU environment.

Create a Cloud TPU using gcloud

  1. Define some environment variables to make the commands easier to use.

     export PROJECT_ID=your-project-id
     export TPU_NAME=your-tpu-name
     export ZONE=us-east5-a
     export ACCELERATOR_TYPE=v5litepod-8
     export RUNTIME_VERSION=v2-alpha-tpuv5-lite

    Environment variable descriptions

    Variable           Description
    PROJECT_ID         Your Google Cloud project ID. Use an existing project or create a new one.
    TPU_NAME           The name of the TPU.
    ZONE               The zone in which to create the TPU VM. For more information about supported zones, see TPU regions and zones.
    ACCELERATOR_TYPE   The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
    RUNTIME_VERSION    The Cloud TPU software version.
  2. Create your TPU VM by running the following command:

     $ gcloud compute tpus tpu-vm create $TPU_NAME \
         --project=$PROJECT_ID \
         --zone=$ZONE \
         --accelerator-type=$ACCELERATOR_TYPE \
         --version=$RUNTIME_VERSION

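If you provision TPUs from a script rather than an interactive shell, the same invocation can be assembled from those environment variables. A minimal Python sketch (the fallback values simply mirror the examples above; this only builds the command string, it does not run gcloud):

```python
import os

# Read the variables defined above, falling back to this document's example values.
env = {name: os.environ.get(name, default) for name, default in [
    ("TPU_NAME", "your-tpu-name"),
    ("PROJECT_ID", "your-project-id"),
    ("ZONE", "us-east5-a"),
    ("ACCELERATOR_TYPE", "v5litepod-8"),
    ("RUNTIME_VERSION", "v2-alpha-tpuv5-lite"),
]}

# Assemble the same `gcloud compute tpus tpu-vm create` command shown above.
cmd = (
    f"gcloud compute tpus tpu-vm create {env['TPU_NAME']}"
    f" --project={env['PROJECT_ID']}"
    f" --zone={env['ZONE']}"
    f" --accelerator-type={env['ACCELERATOR_TYPE']}"
    f" --version={env['RUNTIME_VERSION']}"
)
print(cmd)
```

To actually execute the assembled command, you could pass it to subprocess, for example `subprocess.run(cmd.split(), check=True)`.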
Connect to your Cloud TPU VM

Connect to your TPU VM over SSH using the following command:

 $ gcloud compute tpus tpu-vm ssh $TPU_NAME \
     --project=$PROJECT_ID \
     --zone=$ZONE

If you fail to connect to a TPU VM using SSH, it might be because the TPU VM doesn't have an external IP address. To access a TPU VM without an external IP address, follow the instructions in Connect to a TPU VM without a public IP address.

Install PyTorch/XLA on your TPU VM

 (vm)$ sudo apt-get update
 (vm)$ sudo apt-get install libopenblas-dev -y
 (vm)$ pip install numpy
 (vm)$ pip install torch torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html

Verify PyTorch can access TPUs

Use the following command to verify PyTorch can access your TPUs:

 (vm)$ PJRT_DEVICE=TPU python3 -c "import torch_xla.core.xla_model as xm; print(xm.get_xla_supported_devices(\"TPU\"))"

The output from the command should look like the following:

['xla:0', 'xla:1', 'xla:2', 'xla:3', 'xla:4', 'xla:5', 'xla:6', 'xla:7']

Perform a basic calculation

  1. Create a file named tpu-test.py in the current directory and copy and paste the following script into it:

      import torch
      import torch_xla.core.xla_model as xm

      # Acquire the default XLA (TPU) device.
      dev = xm.xla_device()

      # Create two random 3x3 tensors on the TPU and add them.
      t1 = torch.randn(3, 3, device=dev)
      t2 = torch.randn(3, 3, device=dev)
      print(t1 + t2)
  2. Run the script:

     (vm)$ PJRT_DEVICE=TPU python3 tpu-test.py

    The output from the script shows the result of the computation:

     tensor([[-0.2121,  1.5589, -0.6951],
            [-0.7886, -0.2022,  0.9242],
            [ 0.8555, -1.8698,  1.4333]], device='xla:1') 
    
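    The script requires a TPU, but the tensor arithmetic itself is device-agnostic. As a sketch, the same computation runs on CPU with only torch installed; on the TPU VM, `dev = xm.xla_device()` takes the place of the CPU device:

    ```python
    import torch

    # Same computation as tpu-test.py, but on CPU so it runs without a TPU.
    # On a TPU VM you would instead use: dev = xm.xla_device()
    dev = torch.device("cpu")

    t1 = torch.randn(3, 3, device=dev)
    t2 = torch.randn(3, 3, device=dev)
    out = t1 + t2  # element-wise addition of two random 3x3 tensors

    print(tuple(out.shape))  # (3, 3); the values themselves differ on every run
    ```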

Clean up

To avoid incurring charges to your Google Cloud account for the resources used on this page, follow these steps.

  1. Disconnect from the Cloud TPU instance, if you have not already done so:

      (vm)$ exit

    Your prompt should now be username@projectname, showing you are in the Cloud Shell.

  2. Delete your Cloud TPU.

     $ gcloud compute tpus tpu-vm delete $TPU_NAME \
         --project=$PROJECT_ID \
         --zone=$ZONE
  3. Verify the resources have been deleted by running the following command. Make sure your TPU is no longer listed. The deletion might take several minutes.

     $ gcloud compute tpus tpu-vm list \
         --zone=$ZONE
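     If you script the cleanup, this final check can be automated by searching the listing output for the TPU name. A minimal Python sketch (the sample_listing string is illustrative, not real gcloud output; in practice you would capture the actual output, for example with subprocess):

     ```python
     # Illustrative stand-in for the output of `gcloud compute tpus tpu-vm list`.
     # After a successful delete, the listing no longer contains the TPU name.
     sample_listing = "NAME  ZONE  ACCELERATOR_TYPE  STATUS\n"

     tpu_name = "your-tpu-name"
     deleted = tpu_name not in sample_listing
     print(deleted)  # True once the TPU no longer appears in the listing
     ```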

What's next

Read more about Cloud TPU VMs:
