Train a model using TPU v6e

This document guides you through training models on Cloud TPU v6e (also called Trillium), covering environment setup, performance optimization, and practical training examples using JAX and PyTorch/XLA.

TPU v6e, also called Trillium, is Google's sixth-generation TPU. On all technical surfaces, such as the API and logs, and throughout this document, Trillium is referred to as v6e. TPU v6e has 256 chips per Pod, and its architecture shares many similarities with v5e. TPU v6e is optimized for transformer, text-to-image, and convolutional neural network (CNN) training, fine-tuning, and serving. For more information about the TPU v6e system architecture and configurations, see TPU v6e.

For information about running inference on Cloud TPU v6e, see the following tutorials:

Before you begin

Before you begin, you need to:

  • Create a Google Cloud account and project with billing enabled
  • Install Google Cloud CLI alpha components
  • Enable the Cloud TPU API
  • Create a Cloud TPU service agent
  • Create a Cloud TPU service account and grant permissions

For more information, see Set up the Cloud TPU environment.

Verify quota and permissions

Verify that your project has the following quotas:

If you're using GKE with XPK, you need additional permissions in the Google Cloud console. For more information, see Permissions needed on Google Cloud console.

Provision TPUs

You can provision and manage TPU v6e using the following methods:

  • GKE: You can use GKE to provision and manage TPUs as a pool of accelerators for your containerized machine learning workloads. For more information, see About TPUs in GKE.
  • GKE and XPK: XPK is a command-line tool that simplifies cluster creation and workload execution on GKE. It's designed for ML practitioners to provision TPUs and run training jobs without needing deep Kubernetes expertise. For more information, see the XPK GitHub repository.
  • Cloud TPU queued resources: Queued resources let you request TPU capacity that is provisioned when it becomes available. They're ideal for batch jobs and fault-tolerant workloads that can wait in a queue. You can specify a time window for your request. For more information, see Manage queued resources.

Provision v6e Cloud TPUs with GKE and XPK

If you're using GKE with v6e, you can use Kubernetes commands or XPK to provision Cloud TPUs and train or serve models. To learn how to plan your Cloud TPU configurations in GKE clusters, see Plan for Cloud TPUs in GKE. The following sections provide commands to create an XPK cluster with single-NIC support and with multi-NIC support.

Create an XPK cluster with single-NIC support

export CLUSTER_NAME=xpk-cluster-name
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2
export NETWORK_NAME=${CLUSTER_NAME}-mtu9k
export NETWORK_FW_NAME=${NETWORK_NAME}-fw

gcloud compute networks create ${NETWORK_NAME} \
  --mtu=8896 \
  --project=${PROJECT_ID} \
  --subnet-mode=auto \
  --bgp-routing-mode=regional

gcloud compute firewall-rules create ${NETWORK_FW_NAME} \
  --network=${NETWORK_NAME} \
  --allow tcp,icmp,udp \
  --project=${PROJECT_ID}

export CLUSTER_ARGUMENTS="--network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}"

python3 xpk.py cluster create --cluster=${CLUSTER_NAME} \
  --cluster-cpu-machine-type=e2-standard-8 \
  --num-slices=${NUM_SLICES} \
  --tpu-type=${TPU_TYPE} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --on-demand \
  --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
  --create-vertex-tensorboard

Command flag descriptions

  • CLUSTER_NAME: The user-assigned name for the XPK cluster.
  • PROJECT_ID: Your Google Cloud project ID. Use an existing project or create a new one. For more information, see Set up your Google Cloud project.
  • ZONE: The zone in which to create the cluster. See the Cloud TPU regions and zones document for the supported zones.
  • TPU_TYPE: The accelerator type, which specifies the version and size of the Cloud TPU slice (for example, v6e-256). For more information, see TPU versions.
  • NUM_SLICES: The number of slices to create.
  • CLUSTER_ARGUMENTS: The network and subnetwork to use. For example: --network=${NETWORK_NAME} --subnetwork=${NETWORK_NAME}
  • NETWORK_NAME: The name of the network, with an MTU of 8,896, that the cluster uses.
  • NETWORK_FW_NAME: The name of the firewall rule for that network.
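An empty variable, such as an unset ZONE or PROJECT_ID, silently produces a malformed gcloud or XPK command. The following is a minimal sketch of a preflight check you could run before the commands above. The require_vars function is a hypothetical helper, not part of XPK or the gcloud CLI, and it assumes the variables were set with export, as in the preceding example.

```shell
# Hypothetical preflight helper (not part of XPK or the gcloud CLI).
# Prints the names of any listed variables that are unset or empty and
# returns 1; returns 0 when every variable has a value.
require_vars() {
  local missing="" name value
  for name in "$@"; do
    # printenv only sees exported variables, matching the export commands above.
    value=$(printenv "$name" || true)
    if [ -z "$value" ]; then
      missing="$missing $name"
    fi
  done
  if [ -n "$missing" ]; then
    echo "Missing variables:$missing" >&2
    return 1
  fi
  return 0
}

# Example usage before creating the cluster:
#   require_vars CLUSTER_NAME ZONE PROJECT_ID TPU_TYPE NUM_SLICES \
#     NETWORK_NAME NETWORK_FW_NAME || exit 1
```

Checking all variables in one pass reports every missing name at once, rather than failing on the first gcloud call.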

Create an XPK cluster with multi-NIC support

export CLUSTER_NAME=xpk-cluster-name
export REGION=your-region
export ZONE=us-east1-d
export PROJECT_ID=your-project-id
export TPU_TYPE=v6e-256
export NUM_SLICES=2

export NETWORK_NAME_1=${CLUSTER_NAME}-mtu9k-1-${ZONE}
export SUBNET_NAME_1=${CLUSTER_NAME}-privatesubnet-1-${ZONE}
export NETWORK_FW_NAME_1=${NETWORK_NAME_1}-fw-1-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-1-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-1-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-1-${ZONE}

gcloud compute networks create ${NETWORK_NAME_1} \
  --mtu=8896 \
  --bgp-routing-mode=regional \
  --subnet-mode=custom \
  --project=${PROJECT_ID}

gcloud compute networks subnets create ${SUBNET_NAME_1} \
  --network=${NETWORK_NAME_1} \
  --range=10.11.0.0/18 \
  --region=${REGION} \
  --project=${PROJECT_ID}

gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
  --network=${NETWORK_NAME_1} \
  --allow tcp,icmp,udp \
  --project=${PROJECT_ID}

gcloud compute routers create ${ROUTER_NAME} \
  --project=${PROJECT_ID} \
  --network=${NETWORK_NAME_1} \
  --region=${REGION}

gcloud compute routers nats create ${NAT_CONFIG} \
  --router=${ROUTER_NAME} \
  --region=${REGION} \
  --auto-allocate-nat-external-ips \
  --nat-all-subnet-ip-ranges \
  --project=${PROJECT_ID} \
  --enable-logging

# Secondary subnet for multi-NIC experience.
# Needs custom IP routing that is different from the first network's subnet.
export NETWORK_NAME_2=${CLUSTER_NAME}-privatenetwork-2-${ZONE}
export SUBNET_NAME_2=${CLUSTER_NAME}-privatesubnet-2-${ZONE}
export FIREWALL_RULE_NAME=${CLUSTER_NAME}-privatefirewall-2-${ZONE}
export ROUTER_NAME=${CLUSTER_NAME}-network-2-${ZONE}
export NAT_CONFIG=${CLUSTER_NAME}-natconfig-2-${ZONE}

gcloud compute networks create ${NETWORK_NAME_2} \
  --mtu=8896 \
  --bgp-routing-mode=regional \
  --subnet-mode=custom \
  --project=${PROJECT_ID}

gcloud compute networks subnets create ${SUBNET_NAME_2} \
  --network=${NETWORK_NAME_2} \
  --range=10.10.0.0/18 \
  --region=${REGION} \
  --project=${PROJECT_ID}

gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
  --network=${NETWORK_NAME_2} \
  --allow tcp,icmp,udp \
  --project=${PROJECT_ID}

gcloud compute routers create ${ROUTER_NAME} \
  --project=${PROJECT_ID} \
  --network=${NETWORK_NAME_2} \
  --region=${REGION}

gcloud compute routers nats create ${NAT_CONFIG} \
  --router=${ROUTER_NAME} \
  --region=${REGION} \
  --auto-allocate-nat-external-ips \
  --nat-all-subnet-ip-ranges \
  --project=${PROJECT_ID} \
  --enable-logging

export CLUSTER_ARGUMENTS="--enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}"
export NODE_POOL_ARGUMENTS="--additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}"

python3 xpk.py cluster create \
  --cluster=${CLUSTER_NAME} \
  --cluster-cpu-machine-type=e2-standard-8 \
  --num-slices=${NUM_SLICES} \
  --tpu-type=${TPU_TYPE} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --on-demand \
  --custom-cluster-arguments="${CLUSTER_ARGUMENTS}" \
  --custom-nodepool-arguments="${NODE_POOL_ARGUMENTS}" \
  --create-vertex-tensorboard

Command flag descriptions

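  • CLUSTER_NAME: The user-assigned name for the XPK cluster.
  • PROJECT_ID: Your Google Cloud project ID. Use an existing project or create a new one. For more information, see Set up your Google Cloud project.
  • REGION: The region that contains ZONE (for example, us-east1 for the us-east1-d zone).
  • ZONE: The zone in which to create the cluster. See the Cloud TPU regions and zones document for the supported zones.
  • TPU_TYPE: The accelerator type, which specifies the version and size of the Cloud TPU slice (for example, v6e-256). For more information, see TPU versions.
  • NUM_SLICES: The number of slices to create (needed for Multislice only).
  • CLUSTER_ARGUMENTS: The cluster networking flags, including the primary network and subnetwork to use. For example: --enable-dataplane-v2 --enable-ip-alias --enable-multi-networking --network=${NETWORK_NAME_1} --subnetwork=${SUBNET_NAME_1}
  • NODE_POOL_ARGUMENTS: The additional node network to use. For example: --additional-node-network network=${NETWORK_NAME_2},subnetwork=${SUBNET_NAME_2}
  • NETWORK_NAME_1, SUBNET_NAME_1: The primary network and subnetwork for the cluster.
  • NETWORK_NAME_2, SUBNET_NAME_2: The secondary network and subnetwork attached to the TPU node pool.
  • FIREWALL_RULE_NAME, ROUTER_NAME, NAT_CONFIG: The firewall rule, Cloud Router, and Cloud NAT configuration created for each network.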
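The comment in the multi-NIC script notes that the secondary subnet needs an IP range that differs from the first network's subnet (10.11.0.0/18 and 10.10.0.0/18 in the example). The following POSIX shell sketch checks whether two IPv4 CIDR ranges overlap before you create the subnets. The function names are hypothetical, and the code handles IPv4 only.

```shell
# Hypothetical IPv4-only helpers for checking that two subnet ranges don't overlap.

# Convert a dotted-quad address such as 10.10.0.0 to a 32-bit integer.
ip_to_int() {
  local ip a b c d
  ip=$1
  a=${ip%%.*}; ip=${ip#*.}
  b=${ip%%.*}; ip=${ip#*.}
  c=${ip%%.*}; d=${ip#*.}
  echo $(( ($a << 24) + ($b << 16) + ($c << 8) + $d ))
}

# Print the first and last address of a CIDR block as two integers.
cidr_bounds() {
  local addr prefix size base
  addr=${1%/*}
  prefix=${1#*/}
  size=$(( 1 << (32 - $prefix) ))
  # Clear the host bits to get the network's first address.
  base=$(( $(ip_to_int "$addr") & ~($size - 1) ))
  echo "$base $(( $base + $size - 1 ))"
}

# Return 0 if the two CIDR blocks share at least one address.
cidrs_overlap() {
  # Word-split the two "first last" pairs into $1..$4.
  set -- $(cidr_bounds "$1") $(cidr_bounds "$2")
  [ "$1" -le "$4" ] && [ "$3" -le "$2" ]
}

# Example: the two subnet ranges used in the multi-NIC script.
if cidrs_overlap 10.11.0.0/18 10.10.0.0/18; then
  echo "Subnet ranges overlap; choose a different range for SUBNET_NAME_2." >&2
else
  echo "Subnet ranges don't overlap."
fi
```

Two ranges overlap exactly when each one starts before the other ends, which is the single comparison in cidrs_overlap.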

Set up JAX or PyTorch

The following resources show how to set up JAX or PyTorch on your Cloud TPU, depending on which provisioning and management method you use:

To set up and run XPK with MaxText , see Running MaxText at Scale with XPK .

Optimize network performance

This section describes how to optimize your network performance by configuring the maximum transmission unit (MTU), using multi-NIC for Multislice environments, and improving TCP settings.

Configure MTU

For the best network performance, use a network with an MTU of 8,896 bytes.

By default, a Virtual Private Cloud (VPC) network uses an MTU of 1,460 bytes, which results in suboptimal network performance. You can set a VPC network's MTU to any value from 1,300 bytes to 8,896 bytes (inclusive). Common custom MTU sizes are 1,500 bytes (standard Ethernet) or 8,896 bytes (the maximum possible). For more information, see Valid VPC network MTU sizes.

For more information about changing the MTU setting for an existing or default network, see Change the MTU setting of a VPC network.
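As a minimal sketch, the following hypothetical shell function encodes the valid custom MTU range described above, so that a script can reject a value such as 9000 before calling gcloud. The commented command shows one way to read an existing network's MTU with the gcloud CLI.

```shell
# Hypothetical helper: succeed only for a valid custom VPC MTU, that is, an
# integer from 1300 to 8896 bytes inclusive.
valid_vpc_mtu() {
  case $1 in
    ''|*[!0-9]*) return 1 ;;  # reject empty or non-numeric input
  esac
  [ "$1" -ge 1300 ] && [ "$1" -le 8896 ]
}

# To read the MTU of an existing network (requires the gcloud CLI):
#   gcloud compute networks describe ${NETWORK_NAME} --format="value(mtu)"
```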

The following example creates a network with 8,896 MTU and a corresponding firewall rule that allows TCP, ICMP, and UDP traffic within the network.

export RESOURCE_NAME=your-resource-name
export NETWORK_NAME=${RESOURCE_NAME}-privatenetwork
export NETWORK_FW_NAME=${RESOURCE_NAME}-privatefirewall

gcloud compute networks create ${NETWORK_NAME} \
  --mtu=8896 \
  --project=${PROJECT_ID} \
  --subnet-mode=auto \
  --bgp-routing-mode=regional

gcloud compute firewall-rules create ${NETWORK_FW_NAME} \
  --network=${NETWORK_NAME} \
  --allow tcp,icmp,udp \
  --project=${PROJECT_ID}

Replace your-resource-name with a base name for the network and firewall.

Use the multi-NIC option for Multislice

If you're using a Multislice environment, set the following environment variables, which are required for a secondary subnet:

export NETWORK_NAME_2=${RESOURCE_NAME}
export SUBNET_NAME_2=${RESOURCE_NAME}
export FIREWALL_RULE_NAME=${RESOURCE_NAME}
export ROUTER_NAME=${RESOURCE_NAME}-network-2
export NAT_CONFIG=${RESOURCE_NAME}-natconfig-2
export REGION=your-region

Use the following commands to create custom IP routing for the network and subnet.

  1. Create the secondary network.

    gcloud compute networks create ${NETWORK_NAME_2} \
      --mtu=8896 \
      --bgp-routing-mode=regional \
      --subnet-mode=custom \
      --project=${PROJECT_ID}

  2. Create a subnetwork for the secondary network.

    gcloud compute networks subnets create ${SUBNET_NAME_2} \
      --network=${NETWORK_NAME_2} \
      --range=10.10.0.0/18 \
      --region=${REGION} \
      --project=${PROJECT_ID}

  3. Create a firewall rule to allow traffic within the new subnetwork.

    gcloud compute firewall-rules create ${FIREWALL_RULE_NAME} \
      --network=${NETWORK_NAME_2} \
      --allow tcp,icmp,udp \
      --source-ranges 10.10.0.0/18 \
      --project=${PROJECT_ID}

  4. Create a Cloud Router for the secondary network.

    gcloud compute routers create ${ROUTER_NAME} \
      --project=${PROJECT_ID} \
      --network=${NETWORK_NAME_2} \
      --region=${REGION}

  5. Create a NAT configuration for the Cloud Router.

    gcloud compute routers nats create ${NAT_CONFIG} \
      --router=${ROUTER_NAME} \
      --region=${REGION} \
      --auto-allocate-nat-external-ips \
      --nat-all-subnet-ip-ranges \
      --project=${PROJECT_ID} \
      --enable-logging

After you create a multi-network slice, you can validate that both network interface cards (NICs) are being used by setting up an XPK cluster and adding the --command ifconfig flag to the XPK workload creation command.

  1. Use the following workload create command to display the output of the ifconfig command in Google Cloud console logs and check that both eth0 and eth1 have MTU set to 8,896. Specify either the --base-docker-image flag or the --docker-image flag.

    python3 xpk.py workload create \
      --cluster CLUSTER_NAME \
      {--base-docker-image maxtext_base_image | --docker-image your-cloud-image-name} \
      --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
      --tpu-type=${ACCELERATOR_TYPE} \
      --num-slices=${NUM_SLICES} \
      --on-demand \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --command "ifconfig"

    If you want to enable debug logs or use Vertex AI TensorBoard, add the following optional arguments to the command:

    --enable-debug-logs \
    --use-vertex-tensorboard
  2. Verify that both eth0 and eth1 have MTU set to 8,896 by checking the output of the XPK workload in Google Cloud console logs.
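If you copy the ifconfig output from the logs into a file, you can check the MTU values with a script instead of by eye. The following sketch assumes the net-tools ifconfig output format, where each interface line looks like "eth0: flags=... mtu 8896"; older ifconfig versions print "MTU:8896" instead and aren't handled. The function names and the ifconfig.log file name are hypothetical.

```shell
# Hypothetical helpers for checking MTU values in saved ifconfig output.
# Assumes the net-tools format, for example:
#   eth0: flags=4163<UP,BROADCAST,RUNNING,MULTICAST>  mtu 8896

# Print the MTU reported for interface $1 in ifconfig text read from stdin.
mtu_of() {
  awk -v ifc="$1:" '$1 == ifc {
    for (i = 2; i < NF; i++) if ($i == "mtu") { print $(i + 1); exit }
  }'
}

# Succeed if every interface named in the arguments reports an MTU of 8896.
all_mtu_8896() {
  local text ifc mtu
  text=$(cat)
  for ifc in "$@"; do
    mtu=$(printf '%s\n' "$text" | mtu_of "$ifc")
    if [ "$mtu" != 8896 ]; then
      echo "$ifc: MTU is ${mtu:-not found}, expected 8896" >&2
      return 1
    fi
  done
  return 0
}

# Example usage with saved log output:
#   all_mtu_8896 eth0 eth1 < ifconfig.log && echo "Both NICs use MTU 8896"
```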

Improve TCP settings

If you provisioned your Cloud TPUs using queued resources, you can run the following command to improve network performance by increasing TCP receive buffer limits.

gcloud alpha compute tpus queued-resources ssh "${QUEUED_RESOURCE_ID}" \
  --project "${PROJECT_ID}" \
  --zone "${ZONE}" \
  --node=all \
  --worker=all \
  --command='sudo sh -c "echo \"4096 41943040 314572800\" > /proc/sys/net/ipv4/tcp_rmem"'
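The three values written to tcp_rmem are the minimum, default, and maximum TCP receive buffer sizes in bytes. The following sketch, using a hypothetical helper, converts them into more readable units: the command above sets a 4 KiB minimum, a 40 MiB default, and a 300 MiB maximum. Values written to /proc/sys this way don't persist across a reboot; the commented sysctl command is an equivalent way to apply the same setting.

```shell
# Hypothetical helper: describe a tcp_rmem triple ("min default max", in bytes).
# Uses integer division, so values that aren't exact multiples are rounded down.
describe_rmem() {
  # Word-split the triple into $1 (min), $2 (default), and $3 (max).
  set -- $1
  echo "min=$(( $1 / 1024 ))KiB default=$(( $2 / 1048576 ))MiB max=$(( $3 / 1048576 ))MiB"
}

describe_rmem "4096 41943040 314572800"

# On a TPU VM, compare with the live setting:
#   describe_rmem "$(cat /proc/sys/net/ipv4/tcp_rmem)"
# An equivalent way to apply the setting (also not persistent across reboots):
#   sudo sysctl -w net.ipv4.tcp_rmem="4096 41943040 314572800"
```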

Optimize memory allocation performance

The tcmalloc library is used by default on Cloud TPU VMs to improve performance for models with sizable, frequent memory allocations. This is configured through the LD_PRELOAD environment variable.

However, for some workloads (for example, DLRM with very large embedding table allocations), tcmalloc can cause a slowdown. In such cases, you can revert to the standard malloc function by unsetting the LD_PRELOAD variable in your shell session before running your training script:

unset LD_PRELOAD
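Unsetting LD_PRELOAD affects every later command in the same shell session. If you only want to disable tcmalloc for a single training run, one option is to remove the variable from that one process's environment. The following sketch uses env -u, which GNU coreutils and most other env implementations support; without_tcmalloc and your_training_script.py are hypothetical names.

```shell
# Hypothetical wrapper: run one command with LD_PRELOAD removed from its
# environment only. The current shell keeps its LD_PRELOAD (and tcmalloc)
# for every other command.
without_tcmalloc() {
  env -u LD_PRELOAD "$@"
}

# Example usage:
#   without_tcmalloc python3 your_training_script.py
```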

Use SkyPilot

You can use Cloud TPU v6e with SkyPilot. SkyPilot is an open-source framework that simplifies running, managing, and scaling AI workloads. You can add v6e-related location and pricing information to SkyPilot. For more information, see the SkyPilot TPU v6e example.

Training examples

The following sections provide examples for training MaxText, MaxDiffusion, and PyTorch models on Cloud TPU v6e.

These examples have been tested with the following software versions:

  • Python 3.10 or later
  • Nightly software versions:
    • Nightly JAX 0.4.32.dev20240912
    • Nightly LibTPU 0.1.dev20240912+nightly
  • Stable software versions:
    • JAX and jaxlib v0.4.37

Train MaxText and MaxDiffusion on Cloud TPU v6e

The following sections cover the training lifecycle of the MaxText and MaxDiffusion models.

In general, the high-level steps are:

  1. Build the workload base image.
  2. Run your workload using XPK.
    1. Build the training command for the workload.
    2. Deploy the workload.
  3. Follow the workload and view metrics.
  4. Delete the XPK workload if it isn't needed.
  5. Delete the XPK cluster when it's no longer needed.

Build base image

Install MaxText or MaxDiffusion and build the Docker image:

  1. Clone the repository you want to use and change to the directory for the repository:

    MaxText:

    git clone https://github.com/google/maxtext.git && cd maxtext

    MaxDiffusion:

    git clone https://github.com/google/maxdiffusion.git && cd maxdiffusion && git checkout 4a8155ec0129512812b31930f0a91c6d5a141103

  2. Configure Docker to use the Google Cloud CLI:

    gcloud auth configure-docker

  3. Build the Docker image using the following command or using a JAX AI image. For more information about JAX AI images, see JAX AI images.

    MaxText:

    bash docker_build_dependency_image.sh MODE=stable JAX_VERSION=0.4.35

    MaxDiffusion:

    bash .github/workflows/build_and_upload_images.sh \
      CLOUD_IMAGE_NAME=maxdiffusion_jax_stable_stack \
      MODE=jax_ai_image \
      PROJECT=${PROJECT_ID} \
      LOCAL_IMAGE_NAME=maxdiffusion_jax_stable_stack \
      BASEIMAGE=us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest

  4. Set your project ID in your active gcloud CLI configuration:

    gcloud config set project ${PROJECT_ID}

  5. If you're launching the workload from a machine that doesn't have the image built locally, upload the image.

    1. Set the CLOUD_IMAGE_NAME environment variable:

       export CLOUD_IMAGE_NAME=${USER}_runner

    2. Upload the image:

       bash docker_upload_runner.sh ${CLOUD_IMAGE_NAME}

Run your workload using XPK

  1. Set the following environment variables if you're not using the default values set by MaxText or MaxDiffusion:

    export BASE_OUTPUT_DIR=gs://YOUR_BUCKET
    export PER_DEVICE_BATCH_SIZE=2
    export NUM_STEPS=30
    export MAX_TARGET_LENGTH=8192
  2. Build your model script. In a later step, you pass this script to XPK as the training command.

    Don't execute the model script yet.

    MaxText

    MaxText is a high-performance, highly scalable, open-source LLM codebase written in pure Python and JAX that targets Google Cloud TPUs and GPUs for training and inference.

    export JAX_PLATFORMS=tpu,cpu \
      ENABLE_PJRT_COMPATIBILITY=true \
      TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true \
      TPU_SLICE_BUILDER_DUMP_ICI=true && \
    python3 -m MaxText.train MaxText/configs/base.yml \
      base_output_directory=${BASE_OUTPUT_DIR} \
      dataset_type=synthetic \
      per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
      enable_checkpointing=false \
      gcs_metrics=true \
      profiler=xplane \
      skip_first_n_steps_for_profiler=5 \
      steps=${NUM_STEPS}  # attention='dot_product'
    

    Gemma2

    Gemma is a family of open-weights LLMs developed by Google DeepMind, based on Gemini research and technology.

    python3 -m MaxText.train MaxText/configs/base.yml \
      model_name=gemma2-27b \
      run_name=gemma2-27b-run \
      base_output_directory=${BASE_OUTPUT_DIR} \
      max_target_length=${MAX_TARGET_LENGTH} \
      per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
      steps=${NUM_STEPS} \
      enable_checkpointing=false \
      use_iota_embed=true \
      gcs_metrics=true \
      dataset_type=synthetic \
      profiler=xplane \
      attention=flash

    Mixtral 8x7b

    Mixtral is a state-of-the-art AI model developed by Mistral AI, utilizing a sparse mixture-of-experts (MoE) architecture.

    python3 -m MaxText.train MaxText/configs/base.yml \
      base_output_directory=${BASE_OUTPUT_DIR} \
      per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
      model_name=mixtral-8x7b \
      steps=${NUM_STEPS} \
      max_target_length=${MAX_TARGET_LENGTH} \
      tokenizer_path=assets/tokenizer.mistral-v1 \
      attention=flash \
      dtype=bfloat16 \
      dataset_type=synthetic \
      profiler=xplane

    Llama3-8b

    Llama is a family of open-weights LLMs developed by Meta.

    For an example of how to run Llama3 on PyTorch, see torch_xla models in the torchprime repository.

    MaxDiffusion

    MaxDiffusion is a collection of reference implementations of various latent diffusion models written in pure Python and JAX that run on XLA devices including Cloud TPUs and GPUs. Stable Diffusion is a latent text-to-image model that generates photo-realistic images from any text input.

    You need to check out a specific Git commit of MaxDiffusion, as shown in the following training script.

    git clone https://github.com/google/maxdiffusion.git && \
    cd maxdiffusion && \
    git checkout 4a8155ec0129512812b31930f0a91c6d5a141103 && \
    pip install -r requirements.txt && \
    pip install . && \
    pip install huggingface_hub==0.30.2 && \
    OUT_DIR=${BASE_OUTPUT_DIR} && \
    python src/maxdiffusion/train_sdxl.py src/maxdiffusion/configs/base_xl.yml \
      revision=refs/pr/95 \
      activations_dtype=bfloat16 \
      weights_dtype=bfloat16 \
      resolution=1024 \
      per_device_batch_size=1 \
      output_dir=${OUT_DIR} \
      jax_cache_dir=${OUT_DIR}/cache_dir/ \
      max_train_steps=200 \
      attention=flash \
      run_name=sdxl-ddp-v6e
  3. Export the following variables:

    export CLUSTER_NAME=CLUSTER_NAME
    export ACCELERATOR_TYPE=ACCELERATOR_TYPE
    export NUM_SLICES=NUM_SLICES
    export YOUR_MODEL_SCRIPT=YOUR_MODEL_SCRIPT

    Environment variable descriptions

    • CLUSTER_NAME: The name of your XPK cluster.
    • ACCELERATOR_TYPE: The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
    • NUM_SLICES: The number of TPU slices.
    • YOUR_MODEL_SCRIPT: The model script to execute as a training command.
  4. Run the model using the script you created in the previous step. You must either specify the --base-docker-image flag to use the MaxText base image or specify the --docker-image flag and the image you want to use.

    You can choose to add the following optional flags:

    • You can enable debug logging by including the --enable-debug-logs flag. For more information, see Debug JAX on MaxText.
    • You can create a Vertex AI Experiment to upload data to Vertex AI TensorBoard by including the --use-vertex-tensorboard flag. For more information, see Monitor JAX on MaxText using Vertex AI.

    python3 xpk.py workload create \
      --cluster ${CLUSTER_NAME} \
      {--base-docker-image maxtext_base_image | --docker-image gcr.io/${PROJECT_ID}/${CLOUD_IMAGE_NAME}:latest} \
      --workload=${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES} \
      --tpu-type=${ACCELERATOR_TYPE} \
      --num-slices=${NUM_SLICES} \
      --on-demand \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --command="${YOUR_MODEL_SCRIPT}"

    The output includes a link to follow your workload. Open the link and click the Logs tab to track your workload in real time.
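To reason about the scale of a run, it can help to compute the global batch size implied by PER_DEVICE_BATCH_SIZE. The following sketch assumes that MaxText's global batch size is per_device_batch_size multiplied by the total number of chips (v6e exposes one JAX device per chip), with no gradient accumulation, and that the number in a v6e accelerator type such as v6e-256 is the chip count. The helper names are hypothetical.

```shell
# Hypothetical helpers, assuming global batch = per-device batch x total chips
# (one JAX device per v6e chip) and no gradient accumulation.

# Chip count from a v6e accelerator type such as v6e-256.
chips_in() {
  echo "${1##*-}"
}

# $1 = per-device batch size, $2 = chips per slice, $3 = number of slices
global_batch() {
  echo $(( $1 * $2 * $3 ))
}

# $1 = global batch size (sequences), $2 = max_target_length (tokens per sequence)
tokens_per_step() {
  echo $(( $1 * $2 ))
}

# Example: PER_DEVICE_BATCH_SIZE=2, two v6e-256 slices, MAX_TARGET_LENGTH=8192.
gb=$(global_batch 2 "$(chips_in v6e-256)" 2)
echo "Global batch: $gb sequences, $(tokens_per_step "$gb" 8192) tokens per step"
```

With the example values set earlier in this procedure, this prints a global batch of 1024 sequences, or 8,388,608 tokens per step.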

Debug JAX on MaxText

Use supplemental XPK commands to diagnose why the cluster or workload isn't running:

Monitor JAX on MaxText using Vertex AI

To use TensorBoard, your Google Cloud user account must have the aiplatform.user role. Run the following command to grant this role:

gcloud projects add-iam-policy-binding your-project-id \
  --member='user:your-email' \
  --role='roles/aiplatform.user'

View scalar and profile data through the Vertex AI managed TensorBoard.

  1. Increase the resource management (CRUD) requests quota for the zone you're using from 600 to 5,000. This might not be necessary for small workloads that use fewer than 16 VMs.

  2. Install dependencies such as cloud-accelerator-diagnostics for Vertex AI:

    # Installing xpk also installs cloud-accelerator-diagnostics for Vertex AI.
    cd ~/xpk
    pip install .

  3. Create your XPK cluster using the --create-vertex-tensorboard flag, as documented in Create Vertex AI TensorBoard. You can also run this command on existing clusters.

  4. Create your Vertex AI experiment when running your XPK workload using the --use-vertex-tensorboard flag and the optional --experiment-name flag. For the full list of steps, see Create Vertex AI Experiment to upload data to Vertex AI TensorBoard.

The logs include a link to a Vertex AI TensorBoard, similar to the following:

View your TensorBoard at https://us-central1.tensorboard.googleusercontent.com/experiment/project_id+locations+us-central1+tensorboards+hash+experiments+name

You can also find the Vertex AI TensorBoard link in the Google Cloud console. Go to Vertex AI Experiments in the Google Cloud console. Select the appropriate region from the drop-down.

The TensorBoard directory is also written to the Cloud Storage bucket that you specified with ${BASE_OUTPUT_DIR}.

Delete your XPK workload

Use the xpk workload delete command to delete one or more workloads based on the job prefix or job status. This command is useful if you submitted XPK workloads that no longer need to run, or if you have jobs that are stuck in the queue.
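The workload create commands in this document name workloads ${USER}-xpk-${ACCELERATOR_TYPE}-${NUM_SLICES}. The following sketch rebuilds that name and prints, rather than runs, a matching delete command so you can review it first. It assumes that xpk workload delete accepts --workload and --cluster; check python3 xpk.py workload delete --help for the exact flags in your XPK version, including the filters for job prefix and status.

```shell
# Hypothetical helpers; the naming pattern matches the workload create
# commands in this document.

# $1 = user, $2 = accelerator type, $3 = number of slices
workload_name() {
  echo "$1-xpk-$2-$3"
}

# Print (don't run) a delete command for that workload on cluster $4.
print_workload_delete() {
  echo "python3 xpk.py workload delete --workload=$(workload_name "$1" "$2" "$3") --cluster=$4"
}

# Example usage:
#   print_workload_delete "${USER}" "${ACCELERATOR_TYPE}" "${NUM_SLICES}" "${CLUSTER_NAME}"
```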

Delete your XPK cluster

Use the xpk cluster delete command to delete your cluster:

python3 xpk.py cluster delete --cluster ${CLUSTER_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID}

MaxDiffusion benchmarking results

We ran the training script for MaxDiffusion on a v6e-4, a v6e-16, and two v6e-16 slices. The following table shows the measured throughput.

                                   v6e-4    v6e-16   Two v6e-16
Training step time (seconds)       0.069    0.073    0.13
Global batch size                  8        32       64
Throughput (examples/sec)          115.9    438.4    492.3
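The throughput row matches the global batch size divided by the per-step value in the first row, which is why that row is read as seconds per step. The following sketch recomputes throughput from the first two rows and adds throughput per chip, assuming 4, 16, and 32 chips for the three configurations. By this measure, per-chip throughput is roughly 27 to 29 examples per second on a single slice and about 15 on two v6e-16 slices.

```shell
# Recompute the MaxDiffusion table: throughput = global batch size / step time,
# plus throughput per chip (assumed chip counts: 4, 16, and 2 x 16 = 32).
throughput_table() {
  awk 'BEGIN {
    split("v6e-4 v6e-16 2xv6e-16", config, " ")
    split("0.069 0.073 0.13", step_time, " ")
    split("8 32 64", batch, " ")
    split("4 16 32", chips, " ")
    for (i = 1; i <= 3; i++) {
      tput = batch[i] / step_time[i]
      printf "%s %.1f examples/sec, %.1f per chip\n", config[i], tput, tput / chips[i]
    }
  }'
}

throughput_table
```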

Train Llama models using PyTorch/XLA on Cloud TPU v6e

This section describes how to train Llama models with PyTorch/XLA on Cloud TPU v6e, using the WikiText dataset.

Get access to Hugging Face and the Llama 3 model

You need a Hugging Face user access token for this example. For information about creating user access tokens, see the Hugging Face documentation on user access tokens .

You also need permission to access the Llama-3-8B model on Hugging Face. To get access, go to the Meta-Llama-3-8B model on Hugging Face and request access.

Create a Cloud TPU VM

Create a Cloud TPU v6e with 8 chips for this example.

  1. Set up environment variables:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-east1-d
    export ACCELERATOR_TYPE=v6e-8
    export RUNTIME_VERSION=v2-alpha-tpuv6e

    Environment variable descriptions

    • PROJECT_ID: Your Google Cloud project ID. Use an existing project or create a new one.
    • TPU_NAME: The name of the TPU.
    • ZONE: The zone in which to create the TPU VM. For more information about supported zones, see TPU regions and zones.
    • ACCELERATOR_TYPE: The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
    • RUNTIME_VERSION: The Cloud TPU software version.
  2. Create a Cloud TPU VM:

     gcloud alpha compute tpus tpu-vm create ${TPU_NAME} \
         --version=${RUNTIME_VERSION} \
         --accelerator-type=${ACCELERATOR_TYPE} \
         --zone=${ZONE} \
         --project=${PROJECT_ID}
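If you script these steps, a small Python helper can check that the environment variables from step 1 are set and assemble the same gcloud command as an argument list, which sidesteps shell-quoting mistakes. This is a minimal sketch: the `tpu_create_command` helper is hypothetical, and it reproduces only the flags shown in the command above.

```python
import os
import shlex

# The variables exported in step 1.
REQUIRED_VARS = ("PROJECT_ID", "TPU_NAME", "ZONE", "ACCELERATOR_TYPE", "RUNTIME_VERSION")


def tpu_create_command(env):
    """Return the `gcloud ... tpu-vm create` command as an argument list."""
    missing = [name for name in REQUIRED_VARS if not env.get(name)]
    if missing:
        raise KeyError("unset environment variables: " + ", ".join(missing))
    return [
        "gcloud", "alpha", "compute", "tpus", "tpu-vm", "create", env["TPU_NAME"],
        "--version=" + env["RUNTIME_VERSION"],
        "--accelerator-type=" + env["ACCELERATOR_TYPE"],
        "--zone=" + env["ZONE"],
        "--project=" + env["PROJECT_ID"],
    ]


if __name__ == "__main__":
    try:
        cmd = tpu_create_command(os.environ)
    except KeyError as err:
        print(err)
    else:
        # Print rather than run; subprocess.run(cmd, check=True) would execute it.
        print(shlex.join(cmd))
```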
    

Installation

Install the pytorch-tpu/transformers fork of Hugging Face Transformers and its dependencies. This example was tested with the following dependency versions:

  • torch: compatible with 2.5.0
  • torch_xla[tpu]: compatible with 2.5.0
  • jax: 0.4.33
  • jaxlib: 0.4.33
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
cd transformers
sudo pip3 install -e .
pip3 install datasets
pip3 install evaluate
pip3 install scikit-learn
pip3 install accelerate
pip install torch~=2.6.0 torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/'
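The install command pins torch and torch_xla with pip's compatible-release operator (`~=2.6.0`) and jax with an exact pin (`==0.4.38`). Under PEP 440, `~=2.6.0` means `>=2.6.0, ==2.6.*`, so 2.6.3 qualifies but 2.7.0 does not. The following minimal sketch illustrates that rule and reports what is installed on a VM. It handles only plain numeric versions such as `2.6.0` or `2.6.0+cpu` (the local `+...` suffix is dropped), and the helper names are hypothetical; for complete PEP 440 handling, the `packaging` library's `SpecifierSet` is the usual choice.

```python
from importlib.metadata import PackageNotFoundError, version


def _release(v):
    """Parse a plain numeric version such as '2.6.0' or '2.6.0+cpu' into ints."""
    return tuple(int(part) for part in v.split("+", 1)[0].split("."))


def compatible_release(installed, spec):
    """Return True if `installed` satisfies `~=spec` (PEP 440 compatible release)."""
    want = _release(spec)
    if len(want) < 2:
        raise ValueError("~= needs at least two version components")
    have = _release(installed)
    width = max(len(have), len(want))
    have_padded = have + (0,) * (width - len(have))
    want_padded = want + (0,) * (width - len(want))
    prefix = want[:-1]  # ~=2.6.0 fixes the 2.6 prefix
    return have_padded >= want_padded and have_padded[:len(prefix)] == prefix


def installed_version(dist):
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return version(dist)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    for dist, spec in (("torch", "2.6.0"), ("torch_xla", "2.6.0")):
        have = installed_version(dist)
        try:
            ok = have is not None and compatible_release(have, spec)
        except ValueError:
            ok = None  # pre-release or other non-numeric version: not handled here
        print(f"{dist}: installed={have}, satisfies ~={spec}: {ok}")
    for dist, pinned in (("jax", "0.4.38"), ("jaxlib", "0.4.38")):
        have = installed_version(dist)
        print(f"{dist}: installed={have}, matches =={pinned}: {have == pinned}")
```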

Set up model configuration files

The training command in the next section, Run the model, uses two JSON configuration files to define the model parameters and the Fully Sharded Data Parallel (FSDP) configuration. FSDP lets you use a larger batch size during training by sharding your model weights across multiple TPU chips. When training smaller models, data parallelism, which replicates the weights on each device, might be sufficient. For more information about how to shard tensors across devices in PyTorch/XLA, see the PyTorch/XLA SPMD user guide.
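To make the memory argument concrete, the following minimal sketch compares the weight memory each chip holds under replication and under FSDP-style sharding. It assumes bfloat16 weights (2 bytes per parameter, matching the `torch_dtype` used in this example), an illustrative model size of 8 billion parameters, and the 8 chips of a v6e-8. It counts weights only, ignoring gradients, optimizer state, and activations, and the `weight_bytes_per_device` helper is hypothetical.

```python
BYTES_PER_BF16 = 2  # bfloat16 stores each value in 2 bytes


def weight_bytes_per_device(num_params, num_devices, sharded,
                            bytes_per_param=BYTES_PER_BF16):
    """Approximate bytes of model weights held on each device.

    Hypothetical helper for illustration only: it counts weights, not
    gradients, optimizer state, or activations.
    """
    total = num_params * bytes_per_param
    if not sharded:
        # Data parallelism: every device holds a full replica of the weights.
        return total
    # FSDP-style sharding: each device holds about 1/num_devices of the weights
    # (integer ceiling division, so no bytes are dropped).
    return (total + num_devices - 1) // num_devices


if __name__ == "__main__":
    params = 8_000_000_000  # illustrative: roughly an 8B-parameter model
    chips = 8               # a v6e-8, as used in this example
    for sharded in (False, True):
        gb = weight_bytes_per_device(params, chips, sharded) / 1e9
        mode = "sharded" if sharded else "replicated"
        print(f"{mode:>10}: {gb:.1f} GB of weights per chip")
```

Under these assumptions sharding cuts per-chip weight memory from 16 GB to 2 GB, and the freed memory is what leaves room for a larger batch.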

  1. Create the model parameter configuration file. The following is the model parameter configuration for Llama-3-8B. For other models, find the configuration file on Hugging Face. For example, see the Llama-2-7B config.

     cat > llama-config.json << EOF
     {
       "architectures": ["LlamaForCausalLM"],
       "attention_bias": false,
       "attention_dropout": 0.0,
       "bos_token_id": 128000,
       "eos_token_id": 128001,
       "hidden_act": "silu",
       "hidden_size": 4096,
       "initializer_range": 0.02,
       "intermediate_size": 14336,
       "max_position_embeddings": 8192,
       "model_type": "llama",
       "num_attention_heads": 32,
       "num_hidden_layers": 32,
       "num_key_value_heads": 8,
       "pretraining_tp": 1,
       "rms_norm_eps": 1e-05,
       "rope_scaling": null,
       "rope_theta": 500000.0,
       "tie_word_embeddings": false,
       "torch_dtype": "bfloat16",
       "transformers_version": "4.40.0.dev0",
       "use_cache": false,
       "vocab_size": 128256
     }
     EOF
  2. Create the FSDP configuration file:

     cat > fsdp-config.json << EOF
     {
       "fsdp_transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"],
       "xla": true,
       "xla_fsdp_v2": true,
       "xla_fsdp_grad_ckpt": true
     }
     EOF

    For more information about FSDP, see Fully Sharded Data Parallel using SPMD.

  3. Upload the configuration files to your Cloud TPU VMs using the following command:

     gcloud alpha compute tpus tpu-vm scp llama-config.json fsdp-config.json ${TPU_NAME}:. \
         --worker=all \
         --project=${PROJECT_ID} \
         --zone=${ZONE}
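As a sanity check on llama-config.json, the following minimal sketch estimates the parameter count of a Llama-style decoder from its configuration values. It assumes the standard Llama layout: bias-free attention and MLP projections, grouped-query attention (8 key/value heads shared by 32 query heads), two RMSNorm weight vectors per layer plus a final norm, and a separate LM head because `tie_word_embeddings` is false. The dictionary copies only the fields the estimate needs, and the `llama_param_count` helper is hypothetical.

```python
# Fields copied from llama-config.json.
LLAMA_3_8B = {
    "hidden_size": 4096,
    "intermediate_size": 14336,
    "num_attention_heads": 32,
    "num_hidden_layers": 32,
    "num_key_value_heads": 8,
    "tie_word_embeddings": False,
    "vocab_size": 128256,
}


def llama_param_count(cfg):
    """Estimate parameters of a Llama-style decoder (no biases, GQA, RMSNorm)."""
    h = cfg["hidden_size"]
    heads = cfg["num_attention_heads"]
    kv_heads = cfg["num_key_value_heads"]
    if h % heads or heads % kv_heads:
        raise ValueError("hidden_size and head counts must divide evenly")
    head_dim = h // heads                             # 4096 / 32 = 128
    attn = 2 * h * h + 2 * h * kv_heads * head_dim    # q, o full size; k, v reduced by GQA
    mlp = 3 * h * cfg["intermediate_size"]            # gate, up, and down projections
    norms = 2 * h                                     # two RMSNorm weight vectors per layer
    per_layer = attn + mlp + norms

    embed = cfg["vocab_size"] * h                     # token embedding table
    lm_head = 0 if cfg["tie_word_embeddings"] else cfg["vocab_size"] * h
    final_norm = h
    return cfg["num_hidden_layers"] * per_layer + embed + lm_head + final_norm


if __name__ == "__main__":
    n = llama_param_count(LLAMA_3_8B)
    print(f"{n:,} parameters, {n * 2 / 1e9:.2f} GB in bfloat16")
```

For these values the estimate is 8,030,261,248 parameters, consistent with the "8B" in the model name.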
    

Run the model

Using the configuration files you created in the previous section, run the run_clm.py script to train the Llama-3-8B model on the WikiText dataset. The training script takes approximately 10 minutes to run on a Cloud TPU v6e-8.

  1. Sign in to Hugging Face on your Cloud TPU using the following command. Replace HUGGING_FACE_TOKEN with your Hugging Face user access token.

     gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
         --project=${PROJECT_ID} \
         --zone=${ZONE} \
         --worker=all \
         --command='
     pip3 install "huggingface_hub[cli]"
     huggingface-cli login --token HUGGING_FACE_TOKEN
     '
    
  2. Run the model training:

     gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
         --project=${PROJECT_ID} \
         --zone=${ZONE} \
         --worker=all \
         --command='
     export PJRT_DEVICE=TPU
     export XLA_USE_SPMD=1
     export ENABLE_PJRT_COMPATIBILITY=true
     # Optional variables for debugging:
     export XLA_IR_DEBUG=1
     export XLA_HLO_DEBUG=1
     export PROFILE_EPOCH=0
     export PROFILE_STEP=3
     export PROFILE_DURATION_MS=100000
     # Replace PROFILE_PATH with a local VM path or gs://my-bucket/profile_path
     export PROFILE_LOGDIR=PROFILE_PATH

     python3 transformers/examples/pytorch/language-modeling/run_clm.py \
       --dataset_name wikitext \
       --dataset_config_name wikitext-2-raw-v1 \
       --per_device_train_batch_size 16 \
       --do_train \
       --output_dir /home/$USER/tmp/test-clm \
       --overwrite_output_dir \
       --config_name /home/$USER/llama-config.json \
       --cache_dir /home/$USER/cache \
       --tokenizer_name meta-llama/Meta-Llama-3-8B \
       --block_size 8192 \
       --optim adafactor \
       --save_strategy no \
       --logging_strategy no \
       --fsdp "full_shard" \
       --fsdp_config /home/$USER/fsdp-config.json \
       --torch_dtype bfloat16 \
       --dataloader_drop_last yes \
       --flash_attention \
       --max_steps 20'
    

Troubleshooting PyTorch/XLA

If you set the optional debugging variables in the previous section, the model's profile is stored at the location specified by PROFILE_LOGDIR. You can extract the xplane.pb file from this location and view the profile in your browser with TensorBoard by following the TensorBoard instructions.
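If PROFILE_LOGDIR points to a local path on the VM, a short script can locate the captured xplane.pb files before you copy them or point TensorBoard at the directory. This is a minimal sketch: it assumes only that the profiler writes files whose names end in `xplane.pb` somewhere under the log directory, and the `find_xplane_files` helper is hypothetical. It does not handle gs:// locations.

```python
import os
from pathlib import Path


def find_xplane_files(logdir):
    """Return files under `logdir` whose names end in 'xplane.pb', newest first."""
    files = [p for p in Path(logdir).rglob("*xplane.pb") if p.is_file()]
    return sorted(files, key=lambda p: p.stat().st_mtime, reverse=True)


if __name__ == "__main__":
    logdir = os.environ.get("PROFILE_LOGDIR", "")
    if not logdir or logdir.startswith("gs://"):
        # pathlib cannot walk Cloud Storage; list gs:// paths with other tooling.
        print("Set PROFILE_LOGDIR to a local directory to search it.")
    else:
        for path in find_xplane_files(logdir):
            print(path)
```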

If PyTorch/XLA isn't performing as expected, see the Troubleshooting guide, which has suggestions for debugging, profiling, and optimizing your model.
