Run JAX code on TPU slices

Before running the commands in this document, make sure you have followed the instructions in Set up an account and Cloud TPU project.

After you have your JAX code running on a single TPU board, you can scale up your code by running it on a TPU slice. TPU slices are multiple TPU boards connected to each other over dedicated high-speed network connections. This document is an introduction to running JAX code on TPU slices; for more in-depth information, see Using JAX in multi-host and multi-process environments.
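
The multi-host model is worth keeping in mind for the steps below: every host in the slice runs its own copy of your program, sees only its locally attached TPU cores, and can query the global device picture once all hosts have started. A minimal sketch of these concepts (the working example you will actually run appears later in this document):

import jax

# Each host in the slice runs this same program as a separate process.
# jax.process_index() identifies this host; jax.process_count() is the
# number of hosts in the slice.
print('process', jax.process_index(), 'of', jax.process_count())

# TPU cores attached to this host versus all cores in the slice.
print('local devices:', jax.local_device_count())
print('global devices:', jax.device_count())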

Create a Cloud TPU slice

  1. Create some environment variables:

     export PROJECT_ID=your-project-id
     export TPU_NAME=your-tpu-name
     export ZONE=europe-west4-b
     export ACCELERATOR_TYPE=v5litepod-32
     export RUNTIME_VERSION=v2-alpha-tpuv5-lite

     Environment variable descriptions

     PROJECT_ID: Your Google Cloud project ID. Use an existing project or create a new one.
     TPU_NAME: The name of the TPU.
     ZONE: The zone in which to create the TPU VM. For more information about supported zones, see TPU regions and zones.
     ACCELERATOR_TYPE: The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
     RUNTIME_VERSION: The Cloud TPU software version.
  2. Create a TPU slice using the gcloud command. For example, to create a v5litepod-32 slice, use the following command:

     $ gcloud compute tpus tpu-vm create ${TPU_NAME} \
         --zone=${ZONE} \
         --project=${PROJECT_ID} \
         --accelerator-type=${ACCELERATOR_TYPE} \
         --version=${RUNTIME_VERSION}

Install JAX on your slice

After creating the TPU slice, you must install JAX on all hosts in the TPU slice. You can do this with the gcloud compute tpus tpu-vm ssh command, using the --worker=all and --command parameters.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --zone=${ZONE} \
    --project=${PROJECT_ID} \
    --worker=all \
    --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
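
To confirm that the installation succeeded on every worker, you can run a quick check. This is a minimal sketch under the assumption that you save it as check_jax.py (a hypothetical file name) and copy and launch it with the same gcloud scp and ssh --worker=all commands shown for example.py below:

import jax

# Report the installed JAX version and the TPU cores visible to this host.
print('jax version:', jax.__version__)
print('local TPU cores on this host:', jax.local_device_count())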

Run JAX code on the slice

To run JAX code on a TPU slice, you must run the code on each host in the TPU slice. The jax.device_count() call blocks until it has been called on each host in the slice. The following example illustrates how to run a JAX calculation on a TPU slice.

Prepare the code

You need gcloud version 344.0.0 or later (for the scp command). Use gcloud --version to check your gcloud version, and run gcloud components update if needed.

Create a file called example.py with the following code:

import jax

# The total number of TPU cores in the slice
device_count = jax.device_count()

# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the slice
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
    print('global device count:', jax.device_count())
    print('local device count:', jax.local_device_count())
    print('pmap result:', r)

Copy example.py to all TPU worker VMs in the slice

$ gcloud compute tpus tpu-vm scp ./example.py ${TPU_NAME}: \
    --worker=all \
    --zone=${ZONE} \
    --project=${PROJECT_ID}

If you have not previously used the scp command, you might see an error similar to the following:

ERROR: (gcloud.alpha.compute.tpus.tpu-vm.scp) SSH Key is not present in the SSH
agent. Please run `ssh-add /.../.ssh/google_compute_engine` to add it, and try
again.

To resolve the error, run the ssh-add command as displayed in the error message and rerun the command.

Run the code on the slice

Launch the example.py program on every VM:

$ gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --zone=${ZONE} \
    --project=${PROJECT_ID} \
    --worker=all \
    --command="python3 ./example.py"

Output (produced with a v5litepod-32 slice):

global device count: 32
local device count: 4
pmap result: [32. 32. 32. 32.]

Each of the 32 TPU cores contributes 1 to the psum, so every entry of the per-host result is 32.

Clean up

When you are done with your TPU VM, follow these steps to clean up your resources.

  1. Delete your Cloud TPU and Compute Engine resources.

     $ gcloud compute tpus tpu-vm delete ${TPU_NAME} \
         --zone=${ZONE} \
         --project=${PROJECT_ID}
  2. Verify the resources have been deleted by running gcloud compute tpus tpu-vm list. The deletion might take several minutes. The output from the following command shouldn't include any of the resources created in this tutorial:

     $ gcloud compute tpus tpu-vm list \
         --zone=${ZONE} \
         --project=${PROJECT_ID}