This tutorial shows you how to train large language models (LLMs) like Llama 3 70B on Google Kubernetes Engine (GKE) using MaxText, Ray Train, and Multislice Trillium TPUs. This tutorial provides a complete, end-to-end walkthrough, from configuring the necessary secondary data center networking to submitting and successfully running a distributed training workload across 32 physical TPU chips.
This tutorial is for platform admins, operators, and AI specialists who want to learn how to overcome the memory and networking challenges of training 70-billion-parameter models on distributed, multi-host TPU slices.
Background
The combination of GKE, KubeRay, MaxText, and TPUs provides a powerful and scalable platform for large-scale model training. This section describes the key technologies used in this guide:
JAX
JAX is a Python library for accelerator-oriented array computation and program transformation, utilizing the XLA compiler to create highly optimized code that scales efficiently on accelerators.
MaxText
MaxText is a high-performance, open-source LLM framework designed for scalability and customizability. MaxText is built on top of JAX and is optimized to run efficiently on Cloud TPUs.
TPUs
Tensor Processing Units (TPUs) are custom-designed accelerators created by Google to optimize machine learning workloads. Unlike general-purpose CPUs or parallel-processing GPUs, TPUs are highly specialized for the massive matrix and tensor computations at the foundation of deep learning, making them efficient at this specific task. The primary advantage of TPUs is performance at scale.
This tutorial uses TPU Trillium, the sixth generation of TPUs, in a Multislice deployment pattern. In Cloud TPU Multislice, two or more Cloud TPU slices communicate over the data center network (DCN). Multislice enables full-stack, cost-effective, large-scale training with near-linear scaling up to tens of thousands of TPU chips. For more information about Multislice, see Cloud TPU Multislice Overview.
KubeRay
KubeRay is a Kubernetes operator that provides a unified way to deploy, manage, and monitor Ray applications on Kubernetes. The KubeRay operator is installed and managed through the Ray on GKE add-on, which is the recommended way to deploy and manage Ray clusters on GKE.
GKE Dynamic Resource Allocation Network (DRANET)
GKE DRANET (Dynamic Resource Allocation Network) is a feature that dynamically attaches high-performance network devices to Pods, bypassing standard Kubernetes networking and enabling high performance over the DCN.
Objectives
This tutorial shows you how to do the following:
- Set up a GKE cluster with two multi-host TPU node pools.
- Configure a secondary DCN for cross-slice TPU communication.
- Configure KubeRay to manage the distributed training environment.
- Deploy a RayCluster custom resource by using Dynamic Resource Allocation (DRA) for network attachments.
- Create a Python training script that uses Ray Train's JaxTrainer to orchestrate the MaxText training loop across the TPU slices.
- Run a baseline Llama 3 8B training job.
- Scale up to Llama 3 70B by using 2D sharding (tensor parallelism and FSDP) over the DCN.
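The memory challenge behind the last objective can be estimated with simple arithmetic. The following Python sketch computes a lower bound on per-chip memory for a 70-billion-parameter model's weights and optimizer state when they're fully sharded. It assumes float32 weights plus two float32 AdamW moment buffers (12 bytes per parameter), about 32 GB of HBM per Trillium chip, and a rounded 70 billion parameters. It ignores gradients, activations, and compiler workspace, and the helper name `state_gb_per_chip` is hypothetical.

```python
HBM_GB_PER_CHIP = 32   # assumed HBM capacity of one Trillium (v6e) chip
BYTES_PER_PARAM = 12   # assumed: float32 weight + two float32 AdamW moments


def state_gb_per_chip(num_params: float, num_chips: int,
                      bytes_per_param: int = BYTES_PER_PARAM) -> float:
    """Lower bound on per-chip memory (GB) for weights + optimizer state.

    Assumes the state is evenly sharded across all chips and ignores
    gradients, activations, and compiler workspace, so real usage is higher.
    """
    return num_params * bytes_per_param / num_chips / 1e9


for chips in (16, 32):  # one v6e-16 slice vs. two slices sharded over the DCN
    gb = state_gb_per_chip(70e9, chips)
    print(f"{chips} chips: at least {gb:.2f} GB/chip, "
          f"exceeds HBM: {gb > HBM_GB_PER_CHIP}")
```

Under these assumptions, a single 16-chip slice can't hold even the weights and optimizer state (52.5 GB per chip), while sharding across both slices brings the lower bound to about 26 GB per chip. That leaves little headroom for gradients and activations, which is consistent with the 70B job also setting remat_policy=full.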
Before you begin
- Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
- Install the Google Cloud CLI.
- If you're using an external identity provider (IdP), you must first sign in to the gcloud CLI with your federated identity.
- To initialize the gcloud CLI, run the following command:
  gcloud init
- Create or select a Google Cloud project.
  Roles required to select or create a project
  - Select a project: Selecting a project doesn't require a specific IAM role; you can select any project that you've been granted a role on.
  - Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.
  - Create a Google Cloud project:
    gcloud projects create PROJECT_ID
    Replace PROJECT_ID with a name for the Google Cloud project you are creating.
  - Select the Google Cloud project that you created:
    gcloud config set project PROJECT_ID
    Replace PROJECT_ID with your Google Cloud project name.
- Verify that billing is enabled for your Google Cloud project.
- Enable the required APIs:
  Roles required to enable APIs
  To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.
  gcloud services enable container.googleapis.com cloudbuild.googleapis.com
- Grant roles to your user account. Run the following command once for each of the following IAM roles: roles/container.admin, roles/iam.serviceAccountAdmin, roles/cloudbuild.builds.editor
  gcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE
  Replace the following:
  - PROJECT_ID: Your project ID.
  - USER_IDENTIFIER: The identifier for your user account. For example, myemail@example.com.
  - ROLE: The IAM role that you grant to your user account.
- Because this tutorial uses TPU Trillium (v6e), select a region or zone with availability. For more information, see Cloud TPU quotas.
Prepare your environment
In this tutorial, you use Cloud Shell. Cloud Shell comes preinstalled with the gcloud, helm, and kubectl command-line tools that are used in this tutorial.
- Go to the Google Cloud console.
- At the top of the Google Cloud console window, click the Activate Cloud Shell button. A Cloud Shell session opens inside a new frame in the Google Cloud console and displays a command-line prompt.
- In your terminal, clone the kubernetes-engine-samples repository:
  git clone https://github.com/GoogleCloudPlatform/kubernetes-engine-samples.git
- Change to the directory containing the sample files:
  cd kubernetes-engine-samples/ai-ml/gke-ray/raytrain/maxtext
- Create and activate a Python virtual environment:
  python3 -m venv ray-env
  source ray-env/bin/activate
- Install the Ray CLI:
  pip install "ray[default]==2.55.0"
- Set the following environment variables:
  export PROJECT_ID=$(gcloud config get project)
  export PROJECT_NUMBER=$(gcloud projects describe ${PROJECT_ID} --format="value(projectNumber)")
  export GS_BUCKET=GS_BUCKET
  export KSA_NAME=KSA_NAME
  export NAMESPACE=default
  export CLUSTER_NAME=CLUSTER_NAME
  export REGION=REGION
  export ZONE=ZONE
  export CLUSTER_VERSION=1.35.2-gke.1842000
  Replace the following:
  - GS_BUCKET: the name of the Cloud Storage bucket.
  - KSA_NAME: the name of the Kubernetes ServiceAccount.
  - CLUSTER_NAME: the name of the new cluster.
  - REGION: the region where your TPU Trillium capacity is available.
  - ZONE: the zone where your TPU Trillium capacity is available. For more information, see TPU availability in GKE.
Configure cluster networking for Cloud TPU Multislice
Within a multi-host TPU slice, TPU devices communicate over high-speed inter-chip interconnects (ICI). However, in Multislice jobs, the TPU slices must communicate with each other over the DCN, and standard Kubernetes Pod networking can bottleneck this traffic.
The ct6e-standard-4t machine type is backed by multiple physical network interface cards (NICs). To achieve the best performance, you create two additional VPC networks and use GKE DRANET to attach them directly to the Ray Pods.
- Create the two additional VPC networks with a large maximum transmission unit (MTU):
  gcloud compute networks create ${CLUSTER_NAME}-net-1 \
      --subnet-mode=custom \
      --mtu=8896
  gcloud compute networks create ${CLUSTER_NAME}-net-2 \
      --subnet-mode=custom \
      --mtu=8896
- Create the dedicated subnets:
  gcloud compute networks subnets create tpu-subnet-1 \
      --network=${CLUSTER_NAME}-net-1 \
      --region=${REGION} \
      --range=10.50.0.0/16
  gcloud compute networks subnets create tpu-subnet-2 \
      --network=${CLUSTER_NAME}-net-2 \
      --region=${REGION} \
      --range=10.60.0.0/16
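You can sanity-check the two subnet ranges used in the preceding commands with Python's standard ipaddress module. This sketch only compares the two secondary ranges with each other and confirms that they're private (RFC 1918) addresses; it doesn't check them against any other ranges in your project.

```python
import ipaddress

# Ranges passed to --range in the two subnet-creation commands.
subnets = {
    "tpu-subnet-1": ipaddress.ip_network("10.50.0.0/16"),
    "tpu-subnet-2": ipaddress.ip_network("10.60.0.0/16"),
}

first, second = subnets.values()
if first.overlaps(second):
    raise ValueError("the two secondary subnets overlap")

for name, net in subnets.items():
    # num_addresses counts every address in the CIDR block, including reserved ones.
    print(f"{name}: {net}, {net.num_addresses} addresses, private={net.is_private}")
```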
Create a GKE cluster
You can configure KubeRay on TPUs in a GKE Autopilot or Standard cluster. We recommend that you use an Autopilot cluster for a fully managed Kubernetes experience. To choose the GKE mode of operation that's the best fit for your workloads, see About GKE modes of operation.
To use GKE managed DRANET, your cluster must use version 1.35.2-gke.1842000 or later for Autopilot mode, or 1.34.1-gke.1829001 or later for Standard mode. This tutorial uses version 1.35.2-gke.1842000.
Autopilot
- In Cloud Shell, run the following command:
  gcloud container clusters create-auto $CLUSTER_NAME \
      --enable-ray-operator \
      --machine-type=n1-standard-16 \
      --location=$REGION \
      --cluster-version=${CLUSTER_VERSION}
- To communicate with your cluster, configure kubectl:
  gcloud container clusters get-credentials $CLUSTER_NAME \
      --location=$REGION
Standard
- In Cloud Shell, create a Standard cluster that enables the Ray operator add-on by running the following command:
  gcloud container clusters create $CLUSTER_NAME \
      --addons=RayOperator,GcsFuseCsiDriver \
      --machine-type=n1-standard-16 \
      --enable-dataplane-v2 \
      --workload-pool=$PROJECT_ID.svc.id.goog \
      --location=$ZONE \
      --cluster-version=${CLUSTER_VERSION}
  This command also enables the GcsFuseCsiDriver, which allows Pods to mount Cloud Storage buckets as local file systems. The cluster creation might take several minutes.
- To communicate with your cluster, configure kubectl:
  gcloud container clusters get-credentials $CLUSTER_NAME \
      --location=$ZONE
- Create the first multi-host TPU slice node pool with GKE DRANET enabled:
  gcloud container node-pools create v6e-16-0 \
      --location=$ZONE \
      --cluster=$CLUSTER_NAME \
      --machine-type=ct6e-standard-4t \
      --threads-per-core=1 \
      --tpu-topology=4x4 \
      --num-nodes=4 \
      --additional-node-network=network=${CLUSTER_NAME}-net-1,subnetwork=tpu-subnet-1 \
      --additional-node-network=network=${CLUSTER_NAME}-net-2,subnetwork=tpu-subnet-2 \
      --node-labels=cloud.google.com/gke-networking-dra-driver=true \
      --enable-gvnic \
      --scopes=https://www.googleapis.com/auth/cloud-platform
- Create the second multi-host TPU slice node pool:
  gcloud container node-pools create v6e-16-1 \
      --location=$ZONE \
      --cluster=$CLUSTER_NAME \
      --machine-type=ct6e-standard-4t \
      --threads-per-core=1 \
      --tpu-topology=4x4 \
      --num-nodes=4 \
      --additional-node-network=network=${CLUSTER_NAME}-net-1,subnetwork=tpu-subnet-1 \
      --additional-node-network=network=${CLUSTER_NAME}-net-2,subnetwork=tpu-subnet-2 \
      --node-labels=cloud.google.com/gke-networking-dra-driver=true \
      --enable-gvnic \
      --scopes=https://www.googleapis.com/auth/cloud-platform
Each node pool command provisions four TPU Trillium (v6e) VMs, which are configured together as a multi-host TPU slice with a 4x4 topology. Together, the two node pools provide the two slices for the distributed Multislice training workload.
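The host and chip counts used throughout this tutorial follow from the slice topology. The following sketch reproduces that arithmetic. It assumes four TPU chips per VM (the "4t" in ct6e-standard-4t), and the function name `multislice_layout` is hypothetical.

```python
from math import prod


def multislice_layout(topology: str, chips_per_host: int, num_slices: int) -> dict:
    """Derive host and chip counts for a Multislice job (illustrative helper)."""
    # A topology string such as "4x4" lists the chip-grid dimensions of one slice.
    chips_per_slice = prod(int(dim) for dim in topology.split("x"))
    if chips_per_slice % chips_per_host:
        raise ValueError("slice size must be a multiple of chips per host")
    hosts_per_slice = chips_per_slice // chips_per_host
    return {
        "chips_per_slice": chips_per_slice,
        "hosts_per_slice": hosts_per_slice,           # --num-nodes per node pool
        "total_hosts": hosts_per_slice * num_slices,  # one Ray Train worker per host
        "total_chips": chips_per_slice * num_slices,
    }


layout = multislice_layout("4x4", chips_per_host=4, num_slices=2)
print(layout)
```

Four hosts per slice matches --num-nodes=4 in each node pool, and eight hosts across two slices matches the eight Ray Train workers and 32 chips that the training job uses.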
A GKE cluster with the Ray operator enabled automatically installs KubeRay and the KubeRay TPU webhook.
Configure a Cloud Storage bucket and a service account
- Create a Cloud Storage bucket for shared checkpoints between the multi-host TPU nodes:
  gsutil mb -p ${PROJECT_ID} -c STANDARD -l ${REGION} gs://${GS_BUCKET}
- Create a Kubernetes ServiceAccount for the Pods that access the Cloud Storage bucket:
  kubectl create serviceaccount ${KSA_NAME} --namespace ${NAMESPACE}
- Grant the Kubernetes ServiceAccount access to the Cloud Storage bucket by adding an IAM policy binding:
  gcloud storage buckets add-iam-policy-binding gs://${GS_BUCKET} \
      --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/${NAMESPACE}/sa/${KSA_NAME}" \
      --role "roles/storage.objectUser"
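The --member value in the IAM binding command is a Workload Identity Federation principal identifier built from your project number, project ID, namespace, and Kubernetes ServiceAccount name. The following sketch assembles it from the same pieces, which can help when you debug a typo in the binding. The helper name and example values are hypothetical, and it assumes the workload pool is PROJECT_ID.svc.id.goog, as set by --workload-pool in the Standard cluster command.

```python
def wif_principal(project_number: str, project_id: str,
                  namespace: str, ksa_name: str) -> str:
    """Build the IAM principal for a Kubernetes ServiceAccount (illustrative)."""
    pool = f"{project_id}.svc.id.goog"  # assumed workload pool name
    return (
        f"principal://iam.googleapis.com/projects/{project_number}"
        f"/locations/global/workloadIdentityPools/{pool}"
        f"/subject/ns/{namespace}/sa/{ksa_name}"
    )


# Hypothetical example values; substitute your own.
member = wif_principal("123456789012", "my-project", "default", "maxtext-ksa")
print(member)
```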
Create the training script
The maxtext_multi_slice_trainer.py script in the current directory uses Ray Train's JaxTrainer to run a distributed MaxText training job across two TPU slices. The script configures the training environment for eight multi-host TPU workers and runs the MaxText training job on each worker node. The train_loop_per_worker function wraps the MaxText main entry point and uses Ray's distributed scheduler to run the MaxText trainer across the multi-host TPU slices.
The script defines a JaxTrainer instance that requests eight workers and a topology of 4x4. Internally, Ray provisions a SlicePlacementGroup across the two TPU slices, which helps ensure that the Ray Train workers are scheduled atomically across both slices, with one worker per host.
Train the model
- The ray-cluster.tpu-multi-slice.yaml manifest in the current directory defines the RayCluster custom resource. This manifest includes the DRANET ResourceClaimTemplate that provisions the network devices for GKE DRANET and Multislice.
  The RayCluster spec creates a TPU worker group with two replicas of four workers each (numOfHosts: 4), for eight workers in total. Each worker requests four TPU chips (google.com/tpu: "4"). The four workers in a replica are scheduled on TPU Trillium nodes (tpu-v6e-slice) that belong to the same multi-host slice, and KubeRay scales all four workers in a slice atomically. GKE bootstraps the required JAX environment variables and the Pod affinities for scheduling through a mutating webhook.
- To create the RayCluster, apply the manifest:
  envsubst < ray-cluster.tpu-multi-slice.yaml | kubectl apply -f -
- Verify that the cluster is ready and running:
  kubectl get rayclusters maxtext-tpu-cluster
  The output should be similar to the following:
  NAME                  DESIRED WORKERS   AVAILABLE WORKERS   CPUS   MEMORY         GPUS   STATUS   AGE
  maxtext-tpu-cluster   8                 8                   72     1579277216Ki   0      ready    2m11s
- To access the Ray Dashboard through the Ray head service, establish a port-forwarding session:
  kubectl port-forward svc/maxtext-tpu-cluster-head-svc 8265:8265 2>&1 >/dev/null &
- Verify that the RayCluster is reachable from your local environment:
  ray list nodes --address http://localhost:8265
  The output should be similar to the following:
  2026-04-21 10:20:20,080 - INFO - Note: NumExpr detected 64 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
  2026-04-21 10:20:20,080 - INFO - NumExpr defaulting to 8 threads.
  ======== List: 2026-04-21 10:20:20.945431 ========
  Stats:
  ------------------------------
  Total: 9
  Table:
  ------------------------------
      NODE_ID                                                   NODE_IP    IS_HEAD_NODE  STATE  STATE_MESSAGE  NODE_NAME  RESOURCES_TOTAL                   LABELS
   0  4f0e4d742de5375047c7688f4d2bc64a42d1e5c77c2d8344b3b375a1  10.68.9.5  False         ALIVE                 10.68.9.5  CPU: 8.0                          ray.io/accelerator-type: TPU-V6E
                                                                                                                            TPU: 4.0                          ray.io/node-group: tpu-group
                                                                                                                            accelerator_type:TPU-V6E: 1.0     ray.io/node-id: 4f0e4d742...
                                                                                                                            memory: 186.265 GiB               ray.io/tpu-pod-type: v6e-16
                                                                                                                            node:10.68.9.5: 1.0               ray.io/tpu-slice-name: tpu-group-0
                                                                                                                            object_store_memory: 186.265 GiB  ray.io/tpu-topology: 4x4
                                                                                                                            tpu-group-0: 1.0                  ray.io/tpu-worker-id: '1'
  ...
   6  ce7056807b95831ce107ba1951dac34b80635e6fdbb312e7f9649938  10.68.2.9  True          ALIVE                 10.68.2.9  CPU: 8.0                          ray.io/node-group: headgroup
                                                                                                                            memory: 16.000 GiB                ray.io/node-id: ce7056807...
                                                                                                                            node:10.68.2.9: 1.0
                                                                                                                            node:__internal_head__: 1.0
                                                                                                                            object_store_memory: 4.765 GiB
  ...
- Download the base MaxText configuration file. The training script requires this file to set the model's default hyperparameters:
  curl -O https://raw.githubusercontent.com/google/maxtext/maxtext-v0.2.1/src/maxtext/configs/base.yml
- Submit the JaxTrainer script to the RayCluster and check that the Ray job completes successfully:
  Llama 3 8B
  ray job submit \
      --address http://localhost:8265 \
      --working-dir . \
      --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
      -- python maxtext_multi_slice_trainer.py \
      base.yml \
      base_output_directory=/data/ \
      dataset_type=synthetic \
      per_device_batch_size=4 \
      max_target_length=4096 \
      model_name=llama3-8b \
      steps=100 \
      ici_fsdp_parallelism=4 \
      ici_tensor_parallelism=4 \
      run_name=rayjob-multi-slice
  Llama 3 70B
  ray job submit \
      --address http://localhost:8265 \
      --working-dir . \
      --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
      -- python maxtext_multi_slice_trainer.py \
      base.yml \
      base_output_directory=/data/ \
      dataset_type=synthetic \
      per_device_batch_size=2 \
      max_target_length=4096 \
      model_name=llama3-70b \
      steps=100 \
      ici_tensor_parallelism=4 \
      ici_fsdp_parallelism=4 \
      dcn_fsdp_parallelism=2 \
      dcn_data_parallelism=1 \
      remat_policy=full \
      run_name=rayjob-multi-slice-70b-fsdp
The preceding command submits the Python script, which calls the Ray Train JaxTrainer code, to the RayCluster. The ray job submit command also passes MaxText-specific arguments that override the model configuration in base.yml.
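The ici_* and dcn_* arguments describe a two-level device mesh: the ici_* values shard within a slice over the inter-chip interconnect, and the dcn_* values shard across slices over the DCN. The following sketch checks that arithmetic for both commands. It's a simplification with a hypothetical function name: it assumes that the product of the ici_* values must equal the chips per slice, that the product of the dcn_* values must equal the number of slices, and that at most one axis per level can be -1 and is then inferred. For the 8B command, which sets no dcn_* values, it assumes the cross-slice data axis is inferred.

```python
from math import prod


def resolve_axes(axes: dict, target: int) -> dict:
    """Fill in at most one -1 axis so that all axes multiply to `target`."""
    unknown = [name for name, size in axes.items() if size == -1]
    if len(unknown) > 1:
        raise ValueError("at most one axis per level can be -1")
    known = prod(size for size in axes.values() if size != -1)
    if unknown:
        if target % known:
            raise ValueError(f"cannot infer {unknown[0]}: {target} % {known} != 0")
        axes = {**axes, unknown[0]: target // known}
    total = prod(axes.values())
    if total != target:
        raise ValueError(f"axes {axes} multiply to {total}, expected {target}")
    return axes


CHIPS_PER_SLICE, NUM_SLICES = 16, 2

# Llama 3 70B: tensor x FSDP inside each slice, FSDP across the two slices.
ici_70b = resolve_axes({"fsdp": 4, "tensor": 4}, CHIPS_PER_SLICE)
dcn_70b = resolve_axes({"fsdp": 2, "data": 1}, NUM_SLICES)

# Llama 3 8B: no dcn_* values set; assume the data axis is inferred (-1).
ici_8b = resolve_axes({"fsdp": 4, "tensor": 4}, CHIPS_PER_SLICE)
dcn_8b = resolve_axes({"data": -1}, NUM_SLICES)
print(ici_70b, dcn_70b, dcn_8b)
```

In the 70B command, each 16-chip slice is split into 4-way tensor and 4-way FSDP sharding, and dcn_fsdp_parallelism=2 extends FSDP across the two slices, so the parameters are roughly sharded 32 ways over all chips.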
In your terminal, you should see output similar to the following for the Llama 3 70B job:
[process=5][thread=save_finalize][step=99] CheckpointManager Save Finalize is done on all hosts. [repeated 7x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) [process=5][thread=TrainingThread(train_fn_with_final_checkpoint_flush)][step=99][wait_until_finished] Done waiting for Save Finalize thread (save_finalize) running at step=99. [repeated 7x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) [process=5][thread=TrainingThread(train_fn_with_final_checkpoint_flush)][wait_until_finished] No Save Finalize thread to wait for. Returning. [repeated 6x across cluster]
(RayTrainWorker pid=130520, ip=10.60.7.7) completed step: 99, seconds: 0.693, TFLOP/s/device: 83.171, Tokens/s/device: 11819.175, total_weights: 262144, loss: 0.334 [repeated 6x across cluster]
------------------------------------------
Job 'raysubmit_XwUdZMrhsYRKvjqs' succeeded
------------------------------------------
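The throughput figures in the log are per device. The following sketch parses the final step line from the preceding output and cross-checks it against the 70B command. It assumes that total_weights counts the tokens in one global batch (per_device_batch_size times devices times max_target_length) and that the job ran on 32 devices; the regular expression is written for this sample line and isn't a documented MaxText log format.

```python
import re

# Final step line copied from the 70B job output.
LINE = ("completed step: 99, seconds: 0.693, TFLOP/s/device: 83.171, "
        "Tokens/s/device: 11819.175, total_weights: 262144, loss: 0.334")


def parse_step_metrics(line: str) -> dict:
    """Extract 'name: number' pairs from a step log line like LINE."""
    # Keys start with a letter and may contain word characters, '/', or spaces.
    return {key: float(val)
            for key, val in re.findall(r"([A-Za-z][\w/ ]*?): ([\d.]+)", line)}


m = parse_step_metrics(LINE)

NUM_DEVICES = 32                     # 2 slices x 16 chips in this tutorial
PER_DEVICE_BATCH, SEQ_LEN = 2, 4096  # from the 70B ray job submit command

# Assumption: total_weights is the number of tokens in one global batch.
tokens_per_step = PER_DEVICE_BATCH * NUM_DEVICES * SEQ_LEN
assert tokens_per_step == m["total_weights"]

# Per-device throughput implied by the (rounded) step time, and the cluster total.
implied = tokens_per_step / NUM_DEVICES / m["seconds"]
cluster_tokens_per_s = m["Tokens/s/device"] * NUM_DEVICES
print(f"implied {implied:.0f} vs. logged {m['Tokens/s/device']:.0f} tokens/s/device")
print(f"cluster-wide about {cluster_tokens_per_s:,.0f} tokens/s")
```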
Run Multislice elastic training on Spot VMs
Using Spot VMs for highly sought-after accelerators like TPUs can significantly reduce costs. However, Spot VMs can be preempted unexpectedly.
Ray Train supports elastic training, which allows your job to dynamically scale the number of participating TPU slices up or down without failing. If a slice is preempted, Ray pauses the training loop, waits for the remaining workers to reorganize, restores from the latest MaxText checkpoint, and resumes training on the smaller footprint.
To enable elastic training, change the num_workers parameter in your ScalingConfig from a static integer to a tuple representing (minimum_workers, maximum_workers).
Additionally, add a FailureConfig(max_failures=3) to the RunConfig, which instructs Ray Train to retry the training loop up to three times instead of failing the job entirely when a worker is preempted.
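The following sketch models the worker-count decision that num_workers=(4, 8) implies for this tutorial's four-host slices. It's a hypothetical, simplified model rather than Ray Train's implementation: the function name is invented, it assumes workers are only usable in whole slices (because a multi-host slice is scheduled atomically), and it returns None to represent waiting for capacity when fewer than the minimum are available.

```python
from typing import Optional


def elastic_worker_count(available_hosts: int, min_workers: int,
                         max_workers: int, hosts_per_slice: int) -> Optional[int]:
    """Hypothetical model of choosing a worker count for num_workers=(min, max)."""
    # Only whole slices are usable: a multi-host slice is scheduled atomically.
    whole_slice_hosts = (available_hosts // hosts_per_slice) * hosts_per_slice
    # Cap at max_workers, rounded down to a whole number of slices.
    cap = (max_workers // hosts_per_slice) * hosts_per_slice
    workers = min(whole_slice_hosts, cap)
    # In this model, fewer than min_workers means waiting for capacity.
    return workers if workers >= min_workers else None


# num_workers=(4, 8) with four hosts per v6e-16 slice
print(elastic_worker_count(8, 4, 8, 4))  # both slices healthy
print(elastic_worker_count(7, 4, 8, 4))  # one host preempted, one whole slice left
print(elastic_worker_count(3, 4, 8, 4))  # below the minimum
```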
Update the Ray Train script
- The maxtext_elastic_trainer.py script in the current directory enables elastic training. It sets num_workers=(4, 8), which tells Ray to proceed if at least one 16-chip slice (four workers) is available, but to scale up to two slices (eight workers) when possible. It also includes a FailureConfig that defines the number of retries, which helps ensure that the job survives preemptions.
- Submit the job by using the Ray Job CLI. Be sure to provide a unique run_name so the checkpoints don't conflict with previous runs.
  ray job submit \
      --address http://localhost:8265 \
      --working-dir . \
      --runtime-env-json '{"excludes": ["ray-env", ".git"]}' \
      -- python maxtext_elastic_trainer.py \
      base.yml \
      base_output_directory=/data/ \
      dataset_type=synthetic \
      per_device_batch_size=4 \
      max_target_length=4096 \
      model_name=llama3-8b \
      steps=100 \
      ici_fsdp_parallelism=4 \
      ici_tensor_parallelism=4 \
      run_name=rayjob-elastic-8b
- To simulate a node termination or preemption during training, delete a worker Pod:
  kubectl delete pod $(kubectl get pods -l ray.io/node-type=worker -o jsonpath='{.items[0].metadata.name}')
The terminal logs a worker failure, but the Ray Train controller keeps the job alive and automatically resumes from the latest checkpoint in /data/rayjob-elastic-8b/checkpoints once the minimum topology (one slice, or four workers) is available.
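The checkpoint path in the preceding paragraph combines base_output_directory, run_name, and a checkpoints subdirectory, which is why each run needs a unique run_name. This minimal sketch reproduces that layout. The layout is inferred from the example path in this tutorial rather than taken from MaxText's source, and the helper name `checkpoint_dir` is hypothetical.

```python
import posixpath


def checkpoint_dir(base_output_directory: str, run_name: str) -> str:
    """Checkpoint directory layout inferred from this tutorial's example path."""
    return posixpath.join(base_output_directory, run_name, "checkpoints")


elastic = checkpoint_dir("/data/", "rayjob-elastic-8b")
baseline = checkpoint_dir("/data/", "rayjob-multi-slice")
print(elastic)
print(baseline)
# Distinct run names keep each job's checkpoints in separate directories.
assert elastic != baseline
```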
Because MaxText recalculates the device mesh when training resumes, you don't need to write custom logic to re-shard checkpoints when the topology changes. Orbax, the JAX checkpointing library that MaxText uses, automatically re-shards the saved weights into the new physical layout before the training loop continues. The following output shows the Ray Train controller detecting newly available TPU resources in the cluster and scaling from one slice (four workers) to two slices (eight workers) during training.
...
(pid=, ip=10.68.9.5) W0421 04:19:07.570048 20579 grpc_transport.cc:1930] GetMultiSliceTopology returned with status: UNAVAILABLE: failed to connect to all addresses; last error: UNKNOWN: ipv4:10.68.8.5:9915: connect endpoint failed (Failed to connect to remote host: Connection refused)
...
(TrainController pid=23150) Detected changes in the cluster resources. Deciding to resize the worker group from 4 -> 8 workers.
(TrainController pid=23150) Using SlicePlacementGroup utility to reserve 2 slice(s) with topology '4x4'...
(TrainController pid=23150) Attempting to start training worker group of size 8 with the following resources: [{'TPU': 4, 'accelerator_type:TPU-V6E': 0.001}] * 8
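When the worker group resizes, training restarts on a mesh built for the new device count. The following sketch shows how the elastic 8B job's layout could change between one and two slices. It's a simplified model with a hypothetical function name: it assumes the ici_fsdp_parallelism=4 and ici_tensor_parallelism=4 axes always fill one 16-chip slice, that the unset cross-slice axis is inferred as the number of slices, and that the global batch equals per_device_batch_size times the number of devices.

```python
CHIPS_PER_HOST = 4            # ct6e-standard-4t
CHIPS_PER_SLICE = 16          # one v6e-16 slice with a 4x4 topology
PER_DEVICE_BATCH = 4          # per_device_batch_size in the elastic 8B job
ICI_FSDP, ICI_TENSOR = 4, 4   # ici_fsdp_parallelism, ici_tensor_parallelism


def elastic_mesh(num_slices: int) -> dict:
    """Sketch of the layout after a resize, under the assumptions stated above."""
    assert ICI_FSDP * ICI_TENSOR == CHIPS_PER_SLICE  # ICI axes fill one slice
    devices = CHIPS_PER_SLICE * num_slices
    return {
        "workers": num_slices * CHIPS_PER_SLICE // CHIPS_PER_HOST,
        "devices": devices,
        "dcn_data": num_slices,  # assumed: inferred cross-slice data axis
        "global_batch_sequences": PER_DEVICE_BATCH * devices,
    }


for slices in (1, 2):  # before and after the 4 -> 8 worker resize
    print(slices, elastic_mesh(slices))
```

If the batch assumption holds, the global batch doubles from 64 to 128 sequences when the job scales from four to eight workers, which is worth keeping in mind when you compare loss curves across a resize.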
Clean up
To avoid incurring charges to your Google Cloud account for the resources used in this tutorial, either delete the project that contains the resources, or keep the project and delete the individual resources.
-
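- Delete the RayCluster:
  kubectl delete raycluster maxtext-tpu-cluster
- Delete the GKE cluster:
  gcloud container clusters delete $CLUSTER_NAME --location=$ZONE
  If you created an Autopilot cluster, use --location=$REGION instead.
- Delete the Cloud Storage bucket:
  gsutil rm -r gs://${GS_BUCKET}
- Delete the additional subnets and VPC networks:
  gcloud compute networks subnets delete tpu-subnet-1 tpu-subnet-2 --region=${REGION}
  gcloud compute networks delete ${CLUSTER_NAME}-net-1 ${CLUSTER_NAME}-net-2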
What's next
- Learn about Ray on Kubernetes .
- Learn how to serve vLLM on GKE with TPUs .
- Learn more about TPUs in GKE .

