For the purpose of this document, batch workloads are defined as JAX workloads that execute to completion and are deployed within the same GKE cluster as the Pathways cluster, specifically alongside the Pathways controller components (IFRT proxy server and Pathways resource manager). Completion of the JAX workload also terminates the Pathways cluster components. This guide uses a JAX training workload to demonstrate this.
Before you begin
Make sure you have:
- Created a GKE cluster
- Installed XPK
- Installed Kubernetes tools
- Enabled the TPU API
- Enabled the Google Kubernetes Engine API
- Ensured that your Google Cloud project is allowlisted for Pathways
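Before proceeding, you can run a quick sanity check for the CLIs this guide relies on. This is a minimal sketch; `check_tools` is a hypothetical helper, and `xpk` is assumed to be the name of the XPK binary on your PATH:

```shell
# check_tools prints any of the given binaries that are not on PATH.
check_tools() {
  local missing=""
  for tool in "$@"; do
    command -v "$tool" >/dev/null 2>&1 || missing="$missing $tool"
  done
  echo "$missing"
}

# The CLIs used in this guide:
check_tools gcloud kubectl xpk
```

An empty result means everything is installed; otherwise the missing tool names are printed.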
Build a training image using Maxtext
MaxText is an open-source large language model (LLM) project developed by Google. It's written in JAX and designed to be highly performant and scalable, running efficiently on Google Cloud TPUs and GPUs.
To build a MaxText Docker image by using the latest version of stable JAX from the OSS GitHub repository, run the following command:
```shell
git clone https://github.com/AI-Hypercomputer/maxtext
cd maxtext/
gcloud config set project PROJECT
bash ./docker_build_dependency_image.sh MODE=stable
gcloud auth configure-docker
# This script needs bash version >= 4.2 to execute.
bash ./docker_upload_runner.sh CLOUD_IMAGE_NAME=USER_runner
```
This command pushes the MaxText Docker image to gcr.io/$PROJECT/${USER}_runner. You can use this Docker image to run training on TPUs by using the Pathways backend.
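The registry path follows a simple convention. This small helper is hypothetical, and only mirrors the naming used by the upload script above, showing how the image URI is assembled from your project and user ID:

```shell
# Build the image URI that docker_upload_runner.sh pushes to,
# following the gcr.io/PROJECT/USER_runner convention shown above.
image_uri() {
  local project="$1" user="$2"
  echo "gcr.io/${project}/${user}_runner"
}

image_uri my-project alice   # prints gcr.io/my-project/alice_runner
```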
Run a batch workload using the PathwaysJob API
The following manifest deploys the Pathways components and runs a MaxText workload by using the PathwaysJob API. The workload is encapsulated in the `main` container and exercises `train.py`. Copy the following YAML into a file named `pathways-job-batch-training.yaml` and update the editable values.
```yaml
apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-USER
spec:
  maxRestarts: MAX_RESTARTS
  workers:
  - type: TPU_MACHINE_TYPE
    topology: TOPOLOGY
    numSlices: WORKLOAD_NODEPOOL_COUNT
  pathwaysDir: "gs://BUCKET_NAME"
  controller:
    deploymentMode: default
    template:
      spec:
        containers:
        - name: main
          image: gcr.io/PROJECT/USER_runner
          command:
          - bash
          - -c
          - |
            python3 -m MaxText.train MaxText/configs/base.yml \
              base_output_directory=gs://BUCKET_NAME \
              run_name=RUN_NAME \
              per_device_batch_size=1 \
              enable_checkpointing=false \
              remat_policy=full \
              global_parameter_scale=1 \
              steps=20 \
              max_target_length=2048 \
              use_iota_embed=true \
              reuse_example_batch=1 \
              dataset_type=synthetic \
              attention=flash \
              gcs_metrics=True \
              enable_single_controller=True
```
Replace the following:
- `USER`: your Google Cloud user ID
- `MAX_RESTARTS`: the maximum number of times the Job can be restarted
- `TPU_MACHINE_TYPE`: the TPU machine type
- `TOPOLOGY`: the TPU v4 or later topology. For more information about TPU versions and supported topologies, see TPU versions.
- `WORKLOAD_NODEPOOL_COUNT`: the number of node pools used by a Pathways workload
- `BUCKET_NAME`: a Cloud Storage bucket for storing temporary files
- `PROJECT`: your Google Cloud project ID
- `RUN_NAME`: a user-assigned name to identify the workflow run
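For concreteness, the `workers` section might look like the following once the placeholders are filled in. These are hypothetical example values only; check the TPU versions documentation for the machine types and topologies valid in your zone:

```yaml
# Example values only; substitute your own.
spec:
  maxRestarts: 3
  workers:
  - type: ct4p-hightpu-4t   # hypothetical TPU v4 machine type
    topology: 2x2x2
    numSlices: 1
  pathwaysDir: "gs://my-pathways-bucket"
```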
You can deploy the PathwaysJob YAML as follows:

```shell
kubectl apply -f pathways-job-batch-training.yaml
```
To view the PathwaysJob instance created by the previous command, use:

```shell
kubectl get pathwaysjob
```

The output should look like this:

```
NAME             AGE
pathways-trial   9s
```
To modify an attribute of the PathwaysJob instance, delete the instance, modify the YAML, and apply it to create a new PathwaysJob instance.
You can follow the progress of your workload in the Logs Explorer. Filter the logs for your JAX container by choosing `main` under the Container Name filter. You should see logs like the following, which indicate that training is progressing. The workload will complete after 20 steps.
```
completed step : 1, seconds : 0.484, TFLOP/s/device : 87.349, Tokens/s/device : 2117.382, total_weights : 2945, loss : 10.888
completed step : 2, seconds : 0.407, TFLOP/s/device : 103.699, Tokens/s/device : 2513.735, total_weights : 3253, loss : 9.697
completed step : 3, seconds : 0.248, TFLOP/s/device : 170.300, Tokens/s/device : 4128.167, total_weights : 3154, loss : 9.641
completed step : 4, seconds : 0.216, TFLOP/s/device : 195.122, Tokens/s/device : 4729.880, total_weights : 3119, loss : 9.547
completed step : 5, seconds : 0.272, TFLOP/s/device : 155.298, Tokens/s/device : 3764.512, total_weights : 2837, loss : 10.179
completed step : 6, seconds : 0.472, TFLOP/s/device : 89.489, Tokens/s/device : 2169.266, total_weights : 3069, loss : 9.776
```
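If you capture these logs to a file, a small awk sketch can summarize per-step throughput. This assumes the exact field layout shown above; `avg_tflops` is a hypothetical helper, and `train.log` a hypothetical file name:

```shell
# Average the TFLOP/s/device field across "completed step" log lines.
avg_tflops() {
  awk -F'TFLOP/s/device : ' '/completed step/ {
    split($2, a, ","); sum += a[1]; n++
  } END { if (n) printf "%.3f\n", sum / n }'
}

# Usage: avg_tflops < train.log
```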
To delete the PathwaysJob instance, you can use the following command:

```shell
kubectl delete -f pathways-job-batch-training.yaml
```
Run a batch workload using XPK
Now you can submit the prebuilt MaxText Docker image by using XPK, with the same command you used previously.
```shell
xpk workload create-pathways \
  --workload=WORKLOAD \
  --cluster=CLUSTER \
  --num-slices=WORKLOAD_NODEPOOL_COUNT \
  --tpu-type=TPU_TYPE \
  --project=PROJECT \
  --zone=ZONE \
  --docker-image='gcr.io/PROJECT/USER_runner' \
  --command="python3 -m MaxText.train MaxText/configs/base.yml base_output_directory=gs://BUCKET_NAME per_device_batch_size=1 enable_checkpointing=false remat_policy=full global_parameter_scale=1 steps=20 max_target_length=2048 use_iota_embed=true reuse_example_batch=1 dataset_type=synthetic attention=flash gcs_metrics=True enable_single_controller=True run_name=RUN_NAME-pathways-job"
```
Replace the following:
- `WORKLOAD`: a unique name to identify your workload
- `CLUSTER`: the name of your GKE cluster
- `WORKLOAD_NODEPOOL_COUNT`: the number of node pools used by a Pathways workload
- `TPU_TYPE`: the TPU type, which specifies the version and size of the Cloud TPU you want to create. For more information about supported TPU types for each TPU version, see TPU versions.
- `PROJECT`: your Google Cloud project ID
- `ZONE`: the zone where you plan to run your workload
- `USER`: your Google Cloud user ID
- `RUN_NAME`: a user-assigned name to identify the workflow run
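Long `--command` strings are easy to mistype. One approach, shown here as a sketch with a hypothetical `maxtext_command` helper that covers only a subset of the flags above, is to assemble the string in the shell first:

```shell
# Assemble a (shortened) MaxText command string for xpk's --command flag.
maxtext_command() {
  local bucket="$1" run_name="$2"
  echo "python3 -m MaxText.train MaxText/configs/base.yml" \
       "base_output_directory=gs://${bucket}" \
       "steps=20" \
       "run_name=${run_name}-pathways-job"
}

# Then pass it along, e.g.:
# xpk workload create-pathways ... --command="$(maxtext_command BUCKET_NAME RUN_NAME)"
```

Keeping the flag list in one place makes it easier to reuse between the PathwaysJob manifest and the XPK invocation.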
You should see output like the following:

```
[XPK] Follow your Pathways workload and other resources here : https://console.cloud.google.com/logs/query;query=resource.type%3D"k8s_container"%0Aresource.labels.project_id%3D"<project-name>"%0Aresource.labels.location%3D"<your-zone>"%0Aresource.labels.cluster_name%3D"<your-cluster-name>"%0Aresource.labels.pod_name:"<your-pod-name>"%0Aseverity>%3DDEFAULT
```
Use the link in the output from the previous XPK command to follow the progress
of your workload. You can filter the logs for your JAX container by choosing jax-tpu
under the Container Name filter.
```
completed step : 1, seconds : 0.484, TFLOP/s/device : 87.349, Tokens/s/device : 2117.382, total_weights : 2945, loss : 10.888
completed step : 2, seconds : 0.407, TFLOP/s/device : 103.699, Tokens/s/device : 2513.735, total_weights : 3253, loss : 9.697
completed step : 3, seconds : 0.248, TFLOP/s/device : 170.300, Tokens/s/device : 4128.167, total_weights : 3154, loss : 9.641
completed step : 4, seconds : 0.216, TFLOP/s/device : 195.122, Tokens/s/device : 4729.880, total_weights : 3119, loss : 9.547
completed step : 5, seconds : 0.272, TFLOP/s/device : 155.298, Tokens/s/device : 3764.512, total_weights : 2837, loss : 10.179
completed step : 6, seconds : 0.472, TFLOP/s/device : 89.489, Tokens/s/device : 2169.266, total_weights : 3069, loss : 9.776
```
The workload will complete after the specified number of steps. However, if you want to terminate it prematurely, use the following command:

```shell
xpk workload delete \
  --workload=WORKLOAD \
  --cluster=CLUSTER \
  --project=PROJECT \
  --zone=ZONE
```