Run a batch workload with Pathways

For the purposes of this document, a batch workload is a JAX workload that runs to completion. It is deployed in the same GKE cluster as the Pathways cluster, alongside the Pathways controller components (the IFRT proxy server and the Pathways resource manager). When the JAX workload completes, the Pathways cluster components also terminate. This guide uses a JAX training workload as an example.

Before you begin

Make sure you have:

Build a training image using MaxText

MaxText is an open-source large language model (LLM) project developed by Google. It's written in JAX and designed to be highly performant and scalable, running efficiently on Google Cloud TPUs and GPUs.

To clone MaxText from its open-source GitHub repository and build a Docker image that uses the latest stable version of JAX, run the following commands:

 git clone https://github.com/AI-Hypercomputer/maxtext
 cd maxtext/
 gcloud config set project PROJECT

 bash ./docker_build_dependency_image.sh MODE=stable
 gcloud auth configure-docker
 bash ./docker_upload_runner.sh CLOUD_IMAGE_NAME=USER_runner

 # This script needs bash version >= 4.2 to execute.

These commands build the MaxText Docker image and push it to gcr.io/PROJECT/USER_runner. You can use this image to run training on TPUs with the Pathways backend.

Run a batch workload using the PathwaysJob API

The following manifest uses the PathwaysJob API to deploy the Pathways components and run a MaxText workload. The workload runs in the main container and executes MaxText's train.py.

Copy the following YAML into a file named pathways-job-batch-training.yaml and replace the placeholder values.

 apiVersion: pathways-job.pathways.domain/v1
 kind: PathwaysJob
 metadata:
   name: pathways-USER
 spec:
   maxRestarts: MAX_RESTARTS
   workers:
   - type: TPU_MACHINE_TYPE
     topology: TOPOLOGY
     numSlices: WORKLOAD_NODEPOOL_COUNT
   pathwaysDir: "gs://BUCKET_NAME"
   controller:
     deploymentMode: default
     template:
       spec:
         containers:
         - name: main
           image: gcr.io/PROJECT/USER_runner
           command:
           - bash
           - -c
           - |
             python3 -m MaxText.train MaxText/configs/base.yml \
             base_output_directory=gs://BUCKET_NAME \
             run_name=RUN_NAME \
             per_device_batch_size=1 \
             enable_checkpointing=false \
             remat_policy=full \
             global_parameter_scale=1 \
             steps=20 \
             max_target_length=2048 \
             use_iota_embed=true \
             reuse_example_batch=1 \
             dataset_type=synthetic \
             attention=flash \
             gcs_metrics=True \
             enable_single_controller=True

Replace the following:

  • USER: your Google Cloud user ID
  • MAX_RESTARTS: the maximum number of times the Job can be restarted
  • TPU_MACHINE_TYPE: the TPU machine type
  • TOPOLOGY: the TPU v4 or later topology. For more information about TPU versions and supported topologies, see TPU versions
  • WORKLOAD_NODEPOOL_COUNT: the number of node pools used by a Pathways workload
  • BUCKET_NAME: a Cloud Storage bucket for storing temporary files
  • PROJECT: your Google Cloud project ID
  • RUN_NAME: a user-assigned name to identify the workflow run
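If you generate manifests from a script rather than editing the YAML by hand, the sketch below builds the same PathwaysJob object as a Python dict. The helper name `pathways_job_manifest` and the example values are hypothetical; they are not part of Pathways or MaxText. The field names are copied from the manifest in this guide. The backslash-continued MaxText command is joined into a single line, which bash treats the same way.

```python
import json


# Hypothetical helper (not part of Pathways or MaxText): builds the same
# PathwaysJob object as the YAML manifest in this guide, as a Python dict.
def pathways_job_manifest(user, project, max_restarts, tpu_machine_type,
                          topology, nodepool_count, bucket, run_name,
                          steps=20):
    # MaxText options are passed as key=value overrides after base.yml,
    # mirroring the command in the YAML manifest.
    overrides = {
        "base_output_directory": f"gs://{bucket}",
        "run_name": run_name,
        "per_device_batch_size": 1,
        "enable_checkpointing": "false",
        "remat_policy": "full",
        "global_parameter_scale": 1,
        "steps": steps,
        "max_target_length": 2048,
        "use_iota_embed": "true",
        "reuse_example_batch": 1,
        "dataset_type": "synthetic",
        "attention": "flash",
        "gcs_metrics": "True",
        "enable_single_controller": "True",
    }
    train_cmd = "python3 -m MaxText.train MaxText/configs/base.yml " + " ".join(
        f"{key}={value}" for key, value in overrides.items()
    )
    container = {
        "name": "main",  # the container whose logs you follow in Logs Explorer
        "image": f"gcr.io/{project}/{user}_runner",
        "command": ["bash", "-c", train_cmd],
    }
    return {
        "apiVersion": "pathways-job.pathways.domain/v1",
        "kind": "PathwaysJob",
        "metadata": {"name": f"pathways-{user}"},
        "spec": {
            "maxRestarts": max_restarts,
            "workers": [{
                "type": tpu_machine_type,
                "topology": topology,
                "numSlices": nodepool_count,
            }],
            "pathwaysDir": f"gs://{bucket}",
            "controller": {
                "deploymentMode": "default",
                "template": {"spec": {"containers": [container]}},
            },
        },
    }


if __name__ == "__main__":
    # Placeholder values for illustration only; substitute your own.
    manifest = pathways_job_manifest(
        user="trial", project="my-project", max_restarts=1,
        tpu_machine_type="TPU_MACHINE_TYPE", topology="TOPOLOGY",
        nodepool_count=1, bucket="my-bucket", run_name="my-run")
    print(json.dumps(manifest, indent=2))
```

kubectl accepts JSON manifests as well as YAML. Redirecting the printed output to a file such as pathways-job-batch-training.json and passing that file to kubectl apply -f should therefore behave like applying the YAML version.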

You can deploy the PathwaysJob YAML as follows:

 kubectl apply -f pathways-job-batch-training.yaml

To verify that the previous command created the PathwaysJob instance, run:

 kubectl get pathwaysjob

The output should look like this:

 NAME             AGE 
 pathways-trial   9s 
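To perform this check from a script, you can read the JSON form of the same listing. The sketch below assumes kubectl is configured for the cluster that runs Pathways. It relies on the fact that `kubectl get <resource> -o json` returns a Kubernetes List object with an "items" array. The function names are hypothetical.

```python
import json
import subprocess


def pathways_job_names(kubectl_json):
    """Return metadata.name for each item in `kubectl get ... -o json` output."""
    data = json.loads(kubectl_json)
    # kubectl returns a List object whose "items" holds one entry per resource.
    return [item["metadata"]["name"] for item in data.get("items", [])]


def list_pathways_jobs():
    """Run kubectl against the current context and return PathwaysJob names."""
    result = subprocess.run(
        ["kubectl", "get", "pathwaysjob", "-o", "json"],
        check=True, capture_output=True, text=True,
    )
    return pathways_job_names(result.stdout)
```

For example, `"pathways-trial" in list_pathways_jobs()` should be true after you apply the manifest with USER set to trial.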

To modify an attribute of the PathwaysJob instance, delete the instance, edit the YAML file, and apply it again to create a new PathwaysJob instance.

To follow the progress of your workload, open Logs Explorer and select main under the Container Name filter to view the logs of your JAX container.

You should see logs like the following, which indicate that training is progressing. The workload completes after 20 steps, as set by steps=20 in the manifest.

 completed step: 1, seconds: 0.484, TFLOP/s/device: 87.349, Tokens/s/device: 2117.382, total_weights: 2945, loss: 10.888
 completed step: 2, seconds: 0.407, TFLOP/s/device: 103.699, Tokens/s/device: 2513.735, total_weights: 3253, loss: 9.697
 completed step: 3, seconds: 0.248, TFLOP/s/device: 170.300, Tokens/s/device: 4128.167, total_weights: 3154, loss: 9.641
 completed step: 4, seconds: 0.216, TFLOP/s/device: 195.122, Tokens/s/device: 4729.880, total_weights: 3119, loss: 9.547
 completed step: 5, seconds: 0.272, TFLOP/s/device: 155.298, Tokens/s/device: 3764.512, total_weights: 2837, loss: 10.179
 completed step: 6, seconds: 0.472, TFLOP/s/device: 89.489, Tokens/s/device: 2169.266, total_weights: 3069, loss: 9.776
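To track these metrics outside Logs Explorer, for example after exporting the log lines, the progress lines can be parsed mechanically. The sketch below assumes that each line follows the "completed step: N, name: value, ..." format shown above. Every name in it is hypothetical and not part of MaxText. It reports the last step, the final loss, and the mean Tokens/s/device value. The optional skip_first argument lets you exclude early steps from the average.

```python
import re

# Matches MaxText progress lines in the format shown above, for example:
# "completed step: 1, seconds: 0.484, TFLOP/s/device: 87.349, ..."
STEP_RE = re.compile(r"completed step: (\d+), (.*)")


def parse_step_line(line):
    """Return (step, {metric_name: value}) or None for non-progress lines."""
    match = STEP_RE.search(line)
    if match is None:
        return None
    metrics = {}
    for field in match.group(2).split(", "):
        name, _, value = field.partition(": ")
        metrics[name.strip()] = float(value)  # float() tolerates surrounding whitespace
    return int(match.group(1)), metrics


def summarize(lines, skip_first=0):
    """Summarize progress lines, optionally ignoring the first N parsed steps."""
    steps = [p for p in map(parse_step_line, lines) if p is not None]
    steps = steps[skip_first:]
    if not steps:
        return None
    tokens = [metrics["Tokens/s/device"] for _, metrics in steps]
    last_step, last_metrics = steps[-1]
    return {
        "last_step": last_step,
        "final_loss": last_metrics["loss"],
        "mean_tokens_per_s_per_device": sum(tokens) / len(tokens),
    }
```

Applied to the six lines above, summarize reports step 6 as the last step, a final loss of 9.776, and a mean of about 3237.157 tokens/s/device.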

To delete the PathwaysJob instance, you can use the following command:

 kubectl delete -f pathways-job-batch-training.yaml

Run a batch workload using XPK

You can also run the same MaxText training workload with XPK by submitting the prebuilt MaxText Docker image:

 xpk workload create-pathways \
   --workload=WORKLOAD \
   --cluster=CLUSTER \
   --num-slices=WORKLOAD_NODEPOOL_COUNT \
   --tpu-type=TPU_TYPE \
   --project=PROJECT \
   --zone=ZONE \
   --docker-image='gcr.io/PROJECT/USER_runner' \
   --command="python3 -m MaxText.train MaxText/configs/base.yml base_output_directory=gs://BUCKET_NAME per_device_batch_size=1 enable_checkpointing=false remat_policy=full global_parameter_scale=1 steps=20 max_target_length=2048 use_iota_embed=true reuse_example_batch=1 dataset_type=synthetic attention=flash gcs_metrics=True enable_single_controller=True run_name=RUN_NAME-pathways-job"

Replace the following:

  • WORKLOAD: a unique name to identify your workload
  • CLUSTER: the name of your GKE cluster
  • WORKLOAD_NODEPOOL_COUNT: the number of node pools used by a Pathways workload
  • TPU_TYPE: the TPU type, which specifies the version and size of the Cloud TPU you want to create. For more information about supported TPU types for each TPU version, see TPU versions
  • PROJECT: your Google Cloud project ID
  • ZONE: the zone where you plan to run your workload
  • USER: your Google Cloud user ID
  • BUCKET_NAME: a Cloud Storage bucket for storing the training output
  • RUN_NAME: a user-assigned name to identify the workflow run
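If you launch XPK from a Python script, the command can be built as an argument list so that quoting the long --command value is not an issue. The sketch below uses only the flags that appear in the command above. The helper name and the placeholder values are hypothetical.

```python
import shlex


def xpk_create_pathways_argv(workload, cluster, num_slices, tpu_type,
                             project, zone, docker_image, train_command):
    """Hypothetical helper: argv for the `xpk workload create-pathways` call above."""
    return [
        "xpk", "workload", "create-pathways",
        f"--workload={workload}",
        f"--cluster={cluster}",
        f"--num-slices={num_slices}",
        f"--tpu-type={tpu_type}",
        f"--project={project}",
        f"--zone={zone}",
        f"--docker-image={docker_image}",
        # The entire MaxText invocation is passed as a single --command value.
        f"--command={train_command}",
    ]


if __name__ == "__main__":
    # Placeholder values; the MaxText command is abbreviated here, so add the
    # remaining overrides from the command above before running it.
    train = ("python3 -m MaxText.train MaxText/configs/base.yml "
             "base_output_directory=gs://my-bucket steps=20 "
             "enable_single_controller=True run_name=my-run-pathways-job")
    argv = xpk_create_pathways_argv(
        "my-workload", "my-cluster", 1, "TPU_TYPE", "my-project", "ZONE",
        "gcr.io/my-project/trial_runner", train)
    # shlex.join quotes each argument so the line can be pasted into a shell.
    print(shlex.join(argv))
```

You can also pass the list directly to subprocess.run(argv, check=True). Because no shell is involved in that case, each argument, including the space-separated --command value, reaches xpk unchanged.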

You should see output like the following:

 [XPK] Follow your Pathways workload and other resources here: https://console.cloud.google.com/logs/query;query=resource.type%3D"k8s_container"%0Aresource.labels.project_id%3D"<project-name>"%0Aresource.labels.location%3D"<your-zone>"%0Aresource.labels.cluster_name%3D"<your-cluster-name>"%0Aresource.labels.pod_name:"<your-pod-name>"%0Aseverity>%3DDEFAULT

Use the link in the output of the previous XPK command to follow the progress of your workload. To view the logs of your JAX container, select jax-tpu under the Container Name filter. You should see logs like the following:

 completed step: 1, seconds: 0.484, TFLOP/s/device: 87.349, Tokens/s/device: 2117.382, total_weights: 2945, loss: 10.888
 completed step: 2, seconds: 0.407, TFLOP/s/device: 103.699, Tokens/s/device: 2513.735, total_weights: 3253, loss: 9.697
 completed step: 3, seconds: 0.248, TFLOP/s/device: 170.300, Tokens/s/device: 4128.167, total_weights: 3154, loss: 9.641
 completed step: 4, seconds: 0.216, TFLOP/s/device: 195.122, Tokens/s/device: 4729.880, total_weights: 3119, loss: 9.547
 completed step: 5, seconds: 0.272, TFLOP/s/device: 155.298, Tokens/s/device: 3764.512, total_weights: 2837, loss: 10.179
 completed step: 6, seconds: 0.472, TFLOP/s/device: 89.489, Tokens/s/device: 2169.266, total_weights: 3069, loss: 9.776

The workload completes after the specified number of steps. To terminate it early, run the following command:

 xpk workload delete \
   --workload=WORKLOAD \
   --cluster=CLUSTER \
   --project=PROJECT \
   --zone=ZONE

What's next
