Scale ML workloads using Ray

This document provides details on how to run machine learning (ML) workloads on TPUs using Ray. There are two modes for using TPUs with Ray: Device-centric mode (PyTorch/XLA) and Host-centric mode (JAX).

This document assumes that you already have a TPU environment set up. For more information, see the following resources:

Device-centric mode (PyTorch/XLA)

Device-centric mode retains much of the programmatic style of classic PyTorch. In this mode, PyTorch/XLA adds a new XLA device type that works like any other PyTorch device. Each individual process interacts with one XLA device.

This mode is ideal if you are already familiar with PyTorch with GPUs and want to use similar coding abstractions.
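For example, the following minimal sketch (not one of the setup steps below; it assumes torch and torch_xla are already installed) shows the XLA device being used exactly like "cpu" or "cuda":

    import torch
    import torch_xla.core.xla_model as xm

    # Select the XLA device the same way you would select "cpu" or "cuda".
    device = xm.xla_device()

    # Modules and tensors move to the TPU chip with the usual .to() and device= APIs.
    model = torch.nn.Linear(4, 2).to(device)
    out = model(torch.randn(8, 4, device=device))
    print(out.device)  # xla:0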

The following sections describe how to run a PyTorch/XLA workload on one or more devices without using Ray, then how to run the same workload on multiple hosts using Ray.

Create a TPU

  1. Create environment variables for TPU creation parameters.

     export PROJECT_ID=your-project-id
     export TPU_NAME=your-tpu-name
     export ZONE=europe-west4-b
     export ACCELERATOR_TYPE=v5p-8
     export RUNTIME_VERSION=v2-alpha-tpuv5

    Environment variable descriptions

    • PROJECT_ID: Your Google Cloud project ID. Use an existing project or create a new one.
    • TPU_NAME: The name of the TPU.
    • ZONE: The zone in which to create the TPU VM. For more information about supported zones, see TPU regions and zones.
    • ACCELERATOR_TYPE: The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
    • RUNTIME_VERSION: The Cloud TPU software version.
  2. Use the following command to create a v5p TPU VM with 8 cores:

    gcloud compute tpus tpu-vm create $TPU_NAME \
        --zone=$ZONE \
        --accelerator-type=$ACCELERATOR_TYPE \
        --version=$RUNTIME_VERSION
  3. Connect to the TPU VM using the following command:

    gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE

If you're using GKE, see the KubeRay on GKE guide for setup information.

Install requirements

Run the following commands on your TPU VM to install required dependencies:

  1. Save the following to a file. For example, requirements.txt .

     --find-links https://storage.googleapis.com/libtpu-releases/index.html
     --find-links https://storage.googleapis.com/libtpu-wheels/index.html
     torch~=2.6.0
     torch_xla[tpu]~=2.6.0
     ray[default]==2.40.0
  2. To install required dependencies, run:

     pip install -r requirements.txt

If you're running your workload on GKE, we recommend creating a Dockerfile that installs the required dependencies. For an example, see Run your workload on TPU slice nodes in the GKE documentation.

Run a PyTorch/XLA workload on a single device

The following example demonstrates how to create an XLA tensor on a single device, which is a TPU chip. This is similar to how PyTorch handles other device types.

  1. Save the following code snippet to a file. For example, workload.py .

      import torch
      import torch_xla
      import torch_xla.core.xla_model as xm

      t = torch.randn(2, 2, device=xm.xla_device())
      print(t.device)
      print(t)

     The import torch_xla statement initializes PyTorch/XLA, and the xm.xla_device() function returns the current XLA device, a TPU chip.

  2. Set the PJRT_DEVICE environment variable to TPU.

      export PJRT_DEVICE=TPU
  3. Run the script.

     python workload.py

    The output looks similar to the following. Make sure that the output indicates that the XLA device is found.

     xla:0
     tensor([[ 0.6220, -1.4707],
             [-1.2112,  0.7024]], device='xla:0')

Run PyTorch/XLA on multiple devices

  1. Update the code snippet from the previous section to run on multiple devices.

      import torch
      import torch_xla
      import torch_xla.core.xla_model as xm

      def _mp_fn(index):
          t = torch.randn(2, 2, device=xm.xla_device())
          print(t.device)
          print(t)

      if __name__ == '__main__':
          torch_xla.launch(_mp_fn, args=())
  2. Run the script.

     python workload.py

    If you run the code snippet on a TPU v5p-8, the output looks similar to the following:

     xla:0
     xla:0
     xla:0
     tensor([[ 1.2309,  0.9896],
             [ 0.5820, -1.2950]], device='xla:0')
     xla:0
     tensor([[ 1.2309,  0.9896],
             [ 0.5820, -1.2950]], device='xla:0')
     tensor([[ 1.2309,  0.9896],
             [ 0.5820, -1.2950]], device='xla:0')
     tensor([[ 1.2309,  0.9896],
             [ 0.5820, -1.2950]], device='xla:0')

torch_xla.launch() takes two arguments: a function and a list of parameters. It creates a process for each available XLA device and calls the function specified in the arguments. In this example, there are 4 TPU devices available, so torch_xla.launch() creates 4 processes and calls _mp_fn() on each device. Each process only has access to one device, so each device has the index 0, and xla:0 is printed for all processes.
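As an illustration, the following sketch varies the snippet above to forward an extra positional argument to each spawned process. The prefix parameter is hypothetical and only shows how the args tuple is passed through after the per-process index:

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm

    def _mp_fn(index, prefix):
        # index is supplied by torch_xla.launch(); prefix comes from the args tuple below.
        t = torch.randn(2, 2, device=xm.xla_device())
        print(f"{prefix} process {index}: {t.device}")

    if __name__ == '__main__':
        # Spawns one process per XLA device and calls _mp_fn(index, 'demo') in each.
        torch_xla.launch(_mp_fn, args=('demo',))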

Run PyTorch/XLA on multiple hosts with Ray

The following sections show how to run the same code snippet on a larger multi-host TPU slice. For more information about the multi-host TPU architecture, see System architecture.

In this example, you manually set up Ray. If you are already familiar with setting up Ray, you can skip to the last section, Run a Ray workload. For more information about setting up Ray for a production environment, see the following resources:

Create a multi-host TPU VM

  1. Create environment variables for TPU creation parameters.

     export PROJECT_ID=your-project-id
     export TPU_NAME=your-tpu-name
     export ZONE=europe-west4-b
     export ACCELERATOR_TYPE=v5p-16
     export RUNTIME_VERSION=v2-alpha-tpuv5

    Environment variable descriptions

    • PROJECT_ID: Your Google Cloud project ID. Use an existing project or create a new one.
    • TPU_NAME: The name of the TPU.
    • ZONE: The zone in which to create the TPU VM. For more information about supported zones, see TPU regions and zones.
    • ACCELERATOR_TYPE: The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
    • RUNTIME_VERSION: The Cloud TPU software version.
  2. Create a multi-host TPU v5p with 2 hosts (a v5p-16, with 4 TPU chips on each host) using the following command:

    gcloud compute tpus tpu-vm create $TPU_NAME \
        --zone=$ZONE \
        --accelerator-type=$ACCELERATOR_TYPE \
        --version=$RUNTIME_VERSION

Set up Ray

A TPU v5p-16 has 2 TPU hosts, each with 4 TPU chips. In this example, you will start the Ray head node on one host and add the second host as a worker node to the Ray cluster.

  1. Connect to the first host using SSH.

    gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE --worker=0
  2. Install dependencies with the same requirements file as in the Install requirements section.

     pip install -r requirements.txt
  3. Start the Ray process.

     ray start --head --port=6379

    The output looks similar to the following:

     Enable usage stats collection? This prompt will auto-proceed in 10 seconds to avoid blocking cluster startup. Confirm [Y/n]: y
     Usage stats collection is enabled. To disable this, add `--disable-usage-stats` to the command that starts the cluster, or run the following command: `ray disable-usage-stats` before starting the cluster. See https://docs.ray.io/en/master/cluster/usage-stats.html for more details.

     Local node IP: 10.130.0.76

     --------------------
     Ray runtime started.
     --------------------

     Next steps
       To add another node to this Ray cluster, run
         ray start --address='10.130.0.76:6379'

       To connect to this Ray cluster:
         import ray
         ray.init()

       To terminate the Ray runtime, run
         ray stop

       To view the status of the cluster, use
         ray status

    This TPU host is now the Ray head node. Make a note of the lines that show how to add another node to the Ray cluster, similar to the following:

     To add another node to this Ray cluster, run
       ray start --address='10.130.0.76:6379'

    You will use this command in a later step.

  4. Check the Ray cluster status:

     ray status

    The output looks similar to the following:

     ======== Autoscaler status: 2025-01-14 22:03:39.385610 ========
     Node status
     ---------------------------------------------------------------
     Active:
      1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79
     Pending:
      (no pending nodes)
     Recent failures:
      (no failures)

     Resources
     ---------------------------------------------------------------
     Usage:
      0.0/208.0 CPU
      0.0/4.0 TPU
      0.0/1.0 TPU-v5p-16-head
      0B/268.44GiB memory
      0B/119.04GiB object_store_memory
      0.0/1.0 your-tpu-name

     Demands:
      (no resource demands)

    The cluster only contains 4 TPUs (0.0/4.0 TPU) because you've only added the head node so far.

    Now that the head node is running, you can add the second host to the cluster.

  5. Connect to the second host using SSH.

    gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE --worker=1
  6. Install dependencies with the same requirements file as in the Install requirements section.

     pip install -r requirements.txt
  7. Start the Ray process. To add this node to the existing Ray cluster, use the command from the output of the ray start command. Make sure to replace the IP address and port in the following command:

    ray start --address='10.130.0.76:6379'

    The output looks similar to the following:

     Local node IP: 10.130.0.80
     [2025-01-14 22:30:07,397 W 75572 75572] global_state_accessor.cc:463: Retrying to get node with node ID 35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1

     --------------------
     Ray runtime started.
     --------------------

     To terminate the Ray runtime, run
       ray stop
  8. Check the Ray status again:

     ray status

    The output looks similar to the following:

     ======== Autoscaler status: 2025-01-14 22:45:21.485617 ========
     Node status
     ---------------------------------------------------------------
     Active:
      1 node_bc0c62819ddc0507462352b76cc06b462f0e7f4898a77e5133c16f79
      1 node_35f9ac0675c91429805cdc1b97c3713422d97eee783ccb0c0304f5c1
     Pending:
      (no pending nodes)
     Recent failures:
      (no failures)

     Resources
     ---------------------------------------------------------------
     Usage:
      0.0/416.0 CPU
      0.0/8.0 TPU
      0.0/1.0 TPU-v5p-16-head
      0B/546.83GiB memory
      0B/238.35GiB object_store_memory
      0.0/2.0 your-tpu-name

     Demands:
      (no resource demands)

    The second TPU host is now a node in the cluster. The list of available resources now shows 8 TPUs (0.0/8.0 TPU).
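    Optionally, you can confirm the same totals from Python on the head node. This is a sketch using Ray's standard API, not a required step:

      import ray

      # Attach to the cluster started with `ray start --head` on this host.
      ray.init(address="auto")

      # Expect 8.0 TPU chips once both v5p-16 hosts have joined.
      print(ray.cluster_resources().get("TPU"))
      ray.shutdown()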

Run a Ray workload

  1. Update the code snippet to run on the Ray cluster:

      import os
      import torch
      import torch_xla
      import torch_xla.core.xla_model as xm
      import ray
      import torch.distributed as dist
      import torch_xla.runtime as xr
      from torch_xla._internal import pjrt

      # Defines the local PJRT world size, the number of processes per host.
      LOCAL_WORLD_SIZE = 4
      # Defines the number of hosts in the Ray cluster.
      NUM_OF_HOSTS = 2
      GLOBAL_WORLD_SIZE = LOCAL_WORLD_SIZE * NUM_OF_HOSTS

      def init_env():
          local_rank = int(os.environ['TPU_VISIBLE_CHIPS'])

          pjrt.initialize_multiprocess(local_rank, LOCAL_WORLD_SIZE)
          xr._init_world_size_ordinal()

      # This decorator signals to Ray that the `print_tensor()` function should be run on a single TPU chip.
      @ray.remote(resources={"TPU": 1})
      def print_tensor():
          # Initializes the runtime environment on each Ray worker. Equivalent to
          # the `torch_xla.launch` call in the Run PyTorch/XLA on multiple devices section.
          init_env()

          t = torch.randn(2, 2, device=xm.xla_device())
          print(t.device)
          print(t)

      ray.init()

      # Uses Ray to dispatch the function call across available nodes in the cluster.
      tasks = [print_tensor.remote() for _ in range(GLOBAL_WORLD_SIZE)]
      ray.get(tasks)

      ray.shutdown()
  2. Run the script on the Ray head node. Replace ray-workload.py with the path to your script.

    python ray-workload.py

    The output looks similar to the following:

     WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
     xla:0
     xla:0
     xla:0
     xla:0
     xla:0
     tensor([[ 0.6220, -1.4707],
             [-1.2112,  0.7024]], device='xla:0')
     tensor([[ 0.6220, -1.4707],
             [-1.2112,  0.7024]], device='xla:0')
     xla:0
     xla:0
     tensor([[ 0.6220, -1.4707],
             [-1.2112,  0.7024]], device='xla:0')
     tensor([[ 0.6220, -1.4707],
             [-1.2112,  0.7024]], device='xla:0')
     tensor([[ 0.6220, -1.4707],
             [-1.2112,  0.7024]], device='xla:0')
     tensor([[ 0.6220, -1.4707],
             [-1.2112,  0.7024]], device='xla:0')
     tensor([[ 0.6220, -1.4707],
             [-1.2112,  0.7024]], device='xla:0')
     xla:0
     tensor([[ 0.6220, -1.4707],
             [-1.2112,  0.7024]], device='xla:0')

    The output indicates that the function was successfully called on each XLA device (8 devices in this example) in the multi-host TPU slice.

Host-centric mode (JAX)

The following sections describe the host-centric mode with JAX. JAX uses a functional programming paradigm and supports higher-level single program, multiple data (SPMD) semantics. Instead of having each process interact with a single XLA device, JAX code is designed to operate across multiple devices on a single host concurrently.

JAX is designed for high-performance computing and can efficiently use TPUs for large-scale training and inference. This mode is ideal if you're familiar with functional programming concepts and want to take advantage of JAX's full potential.
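For example, the following minimal sketch (run directly on a single TPU host, outside of Ray; it assumes JAX is installed with TPU support) shows one JAX process driving all local TPU devices at once:

    import jax
    import jax.numpy as jnp

    # A single process sees every TPU chip attached to the host.
    print(jax.device_count())  # for example, 4 on a v6e-4

    # pmap runs the same function on all local devices in SPMD fashion.
    doubled = jax.pmap(lambda x: x * 2)(jnp.arange(jax.local_device_count()))
    print(doubled)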

These instructions assume that you already have a Ray and TPU environment set up, including a software environment that includes JAX and other related packages. To create a Ray TPU cluster, follow the instructions in Start Google Cloud GKE cluster with TPUs for KubeRay. For more information about using TPUs with KubeRay, see Use TPUs with KubeRay.

Run a JAX workload on a single-host TPU

The following example script demonstrates how to run a JAX function on a Ray cluster with a single-host TPU, such as a v6e-4. If you have a multi-host TPU, this script stops responding due to JAX's multi-controller execution model. For more information about running Ray on a multi-host TPU, see Run a JAX workload on a multi-host TPU.

  1. Create environment variables for TPU creation parameters.

     export PROJECT_ID=your-project-id
     export TPU_NAME=your-tpu-name
     export ZONE=europe-west4-a
     export ACCELERATOR_TYPE=v6e-4
     export RUNTIME_VERSION=v2-alpha-tpuv6e

    Environment variable descriptions

    • PROJECT_ID: Your Google Cloud project ID. Use an existing project or create a new one.
    • TPU_NAME: The name of the TPU.
    • ZONE: The zone in which to create the TPU VM. For more information about supported zones, see TPU regions and zones.
    • ACCELERATOR_TYPE: The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
    • RUNTIME_VERSION: The Cloud TPU software version.
  2. Use the following command to create a v6e TPU VM with 4 cores:

    gcloud compute tpus tpu-vm create $TPU_NAME \
        --zone=$ZONE \
        --accelerator-type=$ACCELERATOR_TYPE \
        --version=$RUNTIME_VERSION
  3. Connect to the TPU VM using the following command:

    gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE
  4. Install JAX and Ray on your TPU.

     pip install ray jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  5. Save the following code to a file. For example, ray-jax-single-host.py .

      import ray
      import jax

      @ray.remote(resources={"TPU": 4})
      def my_function() -> int:
          return jax.device_count()

      h = my_function.remote()
      print(ray.get(h))  # => 4

    If you're used to running Ray with GPUs, there are some key differences when using TPUs:

    • Rather than setting num_gpus, specify TPU as a custom resource and set the number of TPU chips.
    • Specify the TPU resource as the number of chips per Ray worker node. For example, if you're using a v6e-4, running a remote function with TPU set to 4 consumes the entire TPU host.
    • This is different from how GPUs typically run, with one process per host. Setting TPU to a number that isn't 4 is not recommended.
      • Exception: If you have a single-host v6e-8 or v5litepod-8, set this value to 8, as shown in the sketch after this list.
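
    The following sketch shows the exception in practice; it assumes a single-host v6e-8 rather than the v6e-4 used above:

      import ray
      import jax

      # On a single-host v6e-8, the Ray worker node exposes 8 chips, so request all 8.
      @ray.remote(resources={"TPU": 8})
      def my_function() -> int:
          return jax.device_count()

      print(ray.get(my_function.remote()))  # => 8
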
  6. Run the script.

    python ray-jax-single-host.py

Run a JAX workload on a multi-host TPU

The following example script demonstrates how to run a JAX function on a Ray cluster with a multi-host TPU. The example script uses a v6e-16.

  1. Create environment variables for TPU creation parameters.

     export PROJECT_ID=your-project-id
     export TPU_NAME=your-tpu-name
     export ZONE=europe-west4-a
     export ACCELERATOR_TYPE=v6e-16
     export RUNTIME_VERSION=v2-alpha-tpuv6e

    Environment variable descriptions

    • PROJECT_ID: Your Google Cloud project ID. Use an existing project or create a new one.
    • TPU_NAME: The name of the TPU.
    • ZONE: The zone in which to create the TPU VM. For more information about supported zones, see TPU regions and zones.
    • ACCELERATOR_TYPE: The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
    • RUNTIME_VERSION: The Cloud TPU software version.
  2. Use the following command to create a v6e TPU VM with 16 cores:

    gcloud compute tpus tpu-vm create $TPU_NAME \
        --zone=$ZONE \
        --accelerator-type=$ACCELERATOR_TYPE \
        --version=$RUNTIME_VERSION
  3. Install JAX and Ray on all TPU workers.

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
        --zone=$ZONE \
        --worker=all \
        --command="pip install ray jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
  4. Save the following code to a file. For example, ray-jax-multi-host.py .

      import ray
      import jax

      @ray.remote(resources={"TPU": 4})
      def my_function() -> int:
          return jax.device_count()

      ray.init()
      num_tpus = ray.available_resources()["TPU"]
      # Each v6e-16 host exposes 4 chips, so divide the chip count by 4 to get the host count.
      num_hosts = int(num_tpus) // 4  # 4
      h = [my_function.remote() for _ in range(num_hosts)]
      print(ray.get(h))  # [16, 16, 16, 16]

    The same TPU resource considerations from the single-host example apply here: each remote function that sets TPU to 4 consumes one entire TPU host.

  5. Copy the script to all TPU workers.

    gcloud compute tpus tpu-vm scp ray-jax-multi-host.py $TPU_NAME: \
        --zone=$ZONE \
        --worker=all
  6. Run the script.

    gcloud compute tpus tpu-vm ssh $TPU_NAME \
        --zone=$ZONE \
        --worker=all \
        --command="python ray-jax-multi-host.py"

Run a Multislice JAX workload

Multislice lets you run workloads that span multiple TPU slices within a single TPU Pod or in multiple pods over the data center network.

You can use the ray-tpu package to simplify Ray's interactions with TPU slices.

Install ray-tpu using pip.

pip install ray-tpu

For more information about using the ray-tpu package, see Getting started in the GitHub repository. For an example using Multislice, see Running on Multislice.

Orchestrate workloads using Ray and MaxText

For more information about using Ray with MaxText, see Run a training job with MaxText.

TPU and Ray resources

Ray treats TPUs differently from GPUs to accommodate the difference in usage. In the following example, there are nine Ray nodes in total:

  • The Ray head node is running on an n1-standard-16 VM.
  • The Ray worker nodes are running on two v6e-16 TPUs. Each TPU slice contributes four worker nodes, one per TPU host.
 $ ray status
 ======== Autoscaler status: 2024-10-17 09:30:00.854415 ========
 Node status
 ---------------------------------------------------------------
 Active:
  1 node_e54a65b81456cee40fcab16ce7b96f85406637eeb314517d9572dab2
  1 node_9a8931136f8d2ab905b07d23375768f41f27cc42f348e9f228dcb1a2
  1 node_c865cf8c0f7d03d4d6cae12781c68a840e113c6c9b8e26daeac23d63
  1 node_435b1f8f1fbcd6a4649c09690915b692a5bac468598e9049a2fac9f1
  1 node_3ed19176e9ecc2ac240c818eeb3bd4888fbc0812afebabd2d32f0a91
  1 node_6a88fe1b74f252a332b08da229781c3c62d8bf00a5ec2b90c0d9b867
  1 node_5ead13d0d60befd3a7081ef8b03ca0920834e5c25c376822b6307393
  1 node_b93cb79c06943c1beb155d421bbd895e161ba13bccf32128a9be901a
  1 node_9072795b8604ead901c5268ffcc8cc8602c662116ac0a0272a7c4e04
 Pending:
  (no pending nodes)
 Recent failures:
  (no failures)

 Resources
 ---------------------------------------------------------------
 Usage:
  0.0/727.0 CPU
  0.0/32.0 TPU
  0.0/2.0 TPU-v6e-16-head
  0B/5.13TiB memory
  0B/1.47TiB object_store_memory
  0.0/4.0 tpu-group-0
  0.0/4.0 tpu-group-1

 Demands:
  (no resource demands)

Resource usage field descriptions:

  • CPU: The total number of CPUs available in the cluster.
  • TPU: The number of TPU chips in the cluster.
  • TPU-v6e-16-head: A special identifier for the resource that corresponds with worker 0 of a TPU slice. This is important for accessing individual TPU slices.
  • memory: Worker heap memory used by your application.
  • object_store_memory: Memory used when your application creates objects in the object store using ray.put and when it returns values from remote functions.
  • tpu-group-0 and tpu-group-1: Unique identifiers for the individual TPU slices. This is important for running jobs on slices. These fields are set to 4 because there are 4 hosts per TPU slice in a v6e-16.
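
For example, the following sketch shows how these labels can be requested in @ray.remote to pin work to a particular slice or to a slice's worker 0. The resource names tpu-group-0 and TPU-v6e-16-head come from the ray status output above and are specific to this example cluster, so substitute the labels that your own cluster reports:

    import ray
    import socket

    ray.init()

    # Requesting one unit of the slice label schedules the task on a host in TPU slice 0.
    @ray.remote(resources={"TPU": 4, "tpu-group-0": 1})
    def on_slice_0() -> str:
        return socket.gethostname()

    # Requesting the head resource schedules the task on worker 0 of a slice.
    @ray.remote(resources={"TPU-v6e-16-head": 1})
    def on_slice_head() -> str:
        return socket.gethostname()

    print(ray.get([on_slice_0.remote(), on_slice_head.remote()]))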