Fine-tune an LLM using TPUs on GKE with JAX

This tutorial shows you how to fine-tune a large language model (LLM) using Tensor Processing Units (TPUs) on Google Kubernetes Engine (GKE) with JAX. Fine-tuning lets you adapt a foundation model like Gemma 3 to a specific domain or task. This process improves the precision and accuracy of the model by updating its parameters with your own specialized dataset.

This guide is a good starting point if you need the granular control, customization, scalability, resilience, portability, and cost-effectiveness of managed Kubernetes when fine-tuning your AI/ML workloads.

Background

By using TPUs on GKE with JAX to fine-tune an LLM, you can build a robust, production-ready fine-tuning solution with all the benefits of managed Kubernetes.

Gemma

Gemma is a set of openly available, lightweight, generative AI/ML multimodal models released under an open license. These AI models are available to run in your applications, hardware, mobile devices, or hosted services. Gemma 3 introduces multimodality, and it supports vision-language input and text outputs. It handles context windows of up to 128,000 tokens and supports over 140 languages. Gemma 3 also offers improved math, reasoning, and chat capabilities, including structured outputs and function calling.

You can use the Gemma models for text generation, or you can also tune these models for specialized tasks.

For more information, see the Gemma documentation.

TPUs

TPUs are application-specific integrated circuits (ASICs) that Google custom-developed to accelerate machine learning and AI models that are built using frameworks such as TensorFlow, PyTorch, and JAX.

Before you use TPUs in GKE, we recommend that you complete the following learning path:

  1. Learn about current TPU version availability with the Cloud TPU system architecture.
  2. Learn about TPUs in GKE.

JAX

JAX is a high-performance machine learning framework that is designed to be used with TPUs and GPUs. JAX provides an API for building and training machine learning models.

To learn more, see the JAX repository.
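JAX's API centers on composable function transformations such as jax.jit (compilation) and jax.grad (automatic differentiation). The following toy example is illustrative only and is not part of this tutorial's training script:

```python
import jax
import jax.numpy as jnp

@jax.jit
def loss(w):
    # A toy quadratic "loss" with a minimum at w = 3.
    return jnp.sum((w - 3.0) ** 2)

# jax.grad returns a new function that computes d(loss)/dw.
grad_fn = jax.grad(loss)
print(grad_fn(jnp.array(0.0)))  # 2 * (0 - 3) = -6.0
```

The training script later in this tutorial uses the same jax.jit and gradient machinery (through jax.value_and_grad) on the Gemma model's loss.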

Objectives

This tutorial covers the following steps:

  1. Create a GKE Autopilot or Standard cluster with the recommended TPU topology, based on the model characteristics. During this tutorial, you perform the fine-tuning on single-host node pools.
  2. Add data to a Cloud Storage bucket and mount it to the container through Cloud Storage FUSE.
  3. Deploy the LLM fine-tuning Job on GKE.
  4. Monitor the fine-tuning Job and view the logs.

Before you begin

  • Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  • In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Roles required to select or create a project

    • Select a project: Selecting a project doesn't require a specific IAM role—you can select any project that you've been granted a role on.
    • Create a project: To create a project, you need the Project Creator role (roles/resourcemanager.projectCreator), which contains the resourcemanager.projects.create permission. Learn how to grant roles.

    Go to project selector

  • Verify that billing is enabled for your Google Cloud project.

  • Enable the required API.

    Roles required to enable APIs

    To enable APIs, you need the Service Usage Admin IAM role (roles/serviceusage.serviceUsageAdmin), which contains the serviceusage.services.enable permission. Learn how to grant roles.

    Enable the API


  • Make sure that you have the following roles on the project: roles/container.admin, roles/iam.serviceAccountAdmin, roles/storage.admin

    Check for the roles

    1. In the Google Cloud console, go to the IAM page.

      Go to IAM
    2. Select the project.
    3. In the Principal column, find all rows that identify you or a group that you're included in. To learn which groups you're included in, contact your administrator.

    4. For all rows that specify or include you, check the Role column to see whether the list of roles includes the required roles.

    Grant the roles

    1. In the Google Cloud console, go to the IAM page.

      Go to IAM
    2. Select the project.
    3. Click Grant access.
    4. In the New principals field, enter your user identifier. This is typically the email address for a Google Account.

    5. Click Select a role, then search for the role.
    6. To grant additional roles, click Add another role and add each additional role.
    7. Click Save.
  • Ensure that you have sufficient quota for 16 TPU Trillium (v6e) chips. In this tutorial, you use a node pool configuration that requires 16 chips and on-demand instances.
  • Ensure that you have a Docker repository. If you don't have one, create a standard repository in Artifact Registry.

Prepare the environment

In this tutorial, you use Cloud Shell to manage resources hosted on Google Cloud. Cloud Shell comes preinstalled with the software you need for this tutorial, including kubectl and the Google Cloud CLI.

To set up your environment with Cloud Shell, follow these steps:

  1. In the Google Cloud console, launch a Cloud Shell session by clicking Activate Cloud Shell. This action launches a session in the bottom pane of the Google Cloud console.

  2. Set the default environment variables:

     gcloud config set project PROJECT_ID
     gcloud config set billing/quota_project PROJECT_ID
     export PROJECT_ID=$(gcloud config get project)
     export CLUSTER_NAME=CLUSTER_NAME
     export REGION=CONTROL_PLANE_LOCATION
     export ZONE=ZONE
     export GCS_BUCKET_NAME=BUCKET_NAME

    Replace the following values:

    • PROJECT_ID: your Google Cloud project ID.
    • CLUSTER_NAME: the name of your GKE cluster.
    • CONTROL_PLANE_LOCATION: the Compute Engine region where your GKE cluster and TPU nodes are located. The region must contain zones where TPU Trillium (v6e) machine types are available.
    • ZONE: a zone within your selected CONTROL_PLANE_LOCATION region where TPU Trillium (v6e) machine types are available. To list zones where TPU Trillium (v6e) TPUs are available, run the following command:

         
       gcloud compute accelerator-types list --filter="name~ct6e" --format="value(zone)"
    • BUCKET_NAME: the name of the Cloud Storage bucket that contains your training data.

  3. Clone the sample repository:

     git clone https://github.com/GoogleCloudPlatform/kubernetes-engine-samples.git
     cd kubernetes-engine-samples
  4. Navigate to the working directory:

     cd ai-ml/llm-training-jax-tpu-gemma3

Create and configure Google Cloud resources

In this section, you create and configure Google Cloud resources.

Create a GKE cluster

You can fine-tune an LLM on TPUs in a GKE Autopilot or Standard cluster. We recommend that you use an Autopilot cluster for a fully managed Kubernetes experience. To choose the GKE mode of operation that's the best fit for your workloads, see Choose a GKE mode of operation.

Autopilot

Create a GKE Autopilot cluster that uses Workload Identity Federation for GKE and has Cloud Storage FUSE enabled.

gcloud container clusters create-auto ${CLUSTER_NAME} \
    --location=${REGION}

The cluster creation might take several minutes.

Standard

  1. Create a regional GKE Standard cluster that uses Workload Identity Federation for GKE and has Cloud Storage FUSE enabled.

     gcloud container clusters create ${CLUSTER_NAME} \
         --enable-ip-alias \
         --addons GcsFuseCsiDriver \
         --machine-type=n2-standard-4 \
         --num-nodes=2 \
         --workload-pool=${PROJECT_ID}.svc.id.goog \
         --location=${REGION}

    The cluster creation might take several minutes.

  2. Create a single-host node pool:

     gcloud container node-pools create jax-tpu-nodepool \
         --cluster=${CLUSTER_NAME} \
         --machine-type=ct6e-standard-1t \
         --num-nodes=1 \
         --location=${REGION} \
         --node-locations=${ZONE} \
         --workload-metadata=GKE_METADATA

GKE creates a TPU Trillium node pool with a 1x1 topology and one node. The --workload-metadata=GKE_METADATA flag configures the node pool to use the GKE metadata server.

Install JobSet

  1. Configure kubectl to communicate with your cluster:

     gcloud container clusters get-credentials ${CLUSTER_NAME} \
         --location=${REGION}
  2. Install the latest released version of JobSet :

     kubectl apply --server-side -f https://github.com/kubernetes-sigs/jobset/releases/download/JOBSET_VERSION/manifests.yaml

    Replace JOBSET_VERSION with the latest released version of JobSet, for example, v0.11.0.

  3. Verify the JobSet installation:

     kubectl get pods -n jobset-system
    

    The output is similar to the following:

     NAME                                         READY   STATUS    RESTARTS   AGE
    jobset-controller-manager-6c56668494-l4dhc   1/1     Running   0          4m45s 
    

    You might need to add more nodes if JobSet is waiting for resources.

Configure Cloud Storage FUSE

To fine-tune the LLM, you need to provide training data. In this tutorial, you use the TinyStories dataset from Hugging Face. This dataset contains short stories, synthetically generated by GPT-3.5 and GPT-4, that use a limited vocabulary.

This section covers the steps to configure Cloud Storage FUSE to read data from a Cloud Storage bucket.
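The fine-tuning script that you review later in this tutorial splits this file into individual stories by using the <|endoftext|> delimiter. A minimal sketch of that split, with invented sample text:

```python
# Stories in TinyStories-train.txt are separated by an '<|endoftext|>'
# marker; the training script splits on it and drops empty fragments.
raw = "Once upon a time, there was a dog.<|endoftext|>Tom had a red ball.<|endoftext|>"
stories = [story for story in raw.split('<|endoftext|>') if story.strip()]
print(len(stories))  # 2
```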

  1. Download the dataset:

     wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories-train.txt?download=true -O TinyStories-train.txt
  2. Upload the data into a new Cloud Storage bucket:

     gcloud storage buckets create gs://${GCS_BUCKET_NAME} \
         --location=${REGION} \
         --enable-hierarchical-namespace \
         --uniform-bucket-level-access
     gcloud storage cp TinyStories-train.txt gs://${GCS_BUCKET_NAME}
  3. To allow your workload to read data through Cloud Storage FUSE, create a Kubernetes service account (KSA) and add the required permissions. Run the permissionsetup.sh script:

     #!/bin/bash
     # Copyright 2026 Google LLC
     #
     # Licensed under the Apache License, Version 2.0 (the "License");
     # you may not use this file except in compliance with the License.
     # You may obtain a copy of the License at
     #
     #     http://www.apache.org/licenses/LICENSE-2.0
     #
     # Unless required by applicable law or agreed to in writing, software
     # distributed under the License is distributed on an "AS IS" BASIS,
     # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     # See the License for the specific language governing permissions and
     # limitations under the License.

     # --- Configuration Variables ---
     # Kubernetes Service Account details
     export KSA_NAME="jaxserviceaccout"
     export NAMESPACE="default"
     # Google Cloud IAM Service Account details
     export GSA_NAME="<GSA_NAME>"
     # Automatically get the current project ID
     export PROJECT_ID=$(gcloud config get-value project)
     export GSA_DESCRIPTION="GKE Service Account to read GCS bucket for ${KSA_NAME}"
     # GCS Bucket details
     export GCS_BUCKET_NAME="<GCS_BUCKET_NAME>"  # <--- IMPORTANT: Update this to your bucket name

     # Derived Variables
     export GSA_EMAIL="${GSA_NAME}@${PROJECT_ID}.iam.gserviceaccount.com"
     export WI_MEMBER="serviceAccount:${PROJECT_ID}.svc.id.goog[${NAMESPACE}/${KSA_NAME}]"

     # --- Check if PROJECT_ID is set ---
     if [ -z "${PROJECT_ID}" ]; then
       echo "Error: PROJECT_ID is not set. Please set it using 'gcloud config set project YOUR_PROJECT_ID'"
       exit 1
     fi

     echo "--- Configuration ---"
     echo "KSA_NAME: ${KSA_NAME}"
     echo "NAMESPACE: ${NAMESPACE}"
     echo "GSA_NAME: ${GSA_NAME}"
     echo "PROJECT_ID: ${PROJECT_ID}"
     echo "GSA_EMAIL: ${GSA_EMAIL}"
     echo "GCS_BUCKET_NAME: ${GCS_BUCKET_NAME}"
     echo "WI_MEMBER: ${WI_MEMBER}"
     echo "--------------------"
     read -p "Press enter to continue..."

     # --- Command Execution ---
     echo "[1/5] Creating Google Cloud IAM Service Account (GSA): ${GSA_NAME}"
     gcloud iam service-accounts create "${GSA_NAME}" \
         --project="${PROJECT_ID}" \
         --description="${GSA_DESCRIPTION}" \
         --display-name="${GSA_NAME}"

     echo "[2/5] Granting GSA '${GSA_EMAIL}' read access (roles/storage.objectViewer) to bucket 'gs://${GCS_BUCKET_NAME}'"
     gcloud storage buckets add-iam-policy-binding "gs://${GCS_BUCKET_NAME}" \
         --member="serviceAccount:${GSA_EMAIL}" \
         --role="roles/storage.objectViewer" \
         --project="${PROJECT_ID}"

     echo "[3/5] Creating Kubernetes Service Account (KSA): ${KSA_NAME} in namespace ${NAMESPACE}"
     kubectl create serviceaccount "${KSA_NAME}" --namespace "${NAMESPACE}"

     echo "[4/5] Allowing KSA to impersonate GSA (Workload Identity Binding): ${GSA_EMAIL}"
     gcloud iam service-accounts add-iam-policy-binding "${GSA_EMAIL}" \
         --role roles/iam.workloadIdentityUser \
         --member "${WI_MEMBER}" \
         --project="${PROJECT_ID}"

     echo "[5/5] Annotating KSA '${KSA_NAME}' to link with GSA '${GSA_EMAIL}'"
     kubectl annotate serviceaccount "${KSA_NAME}" \
         --namespace "${NAMESPACE}" \
         iam.gke.io/gcp-service-account="${GSA_EMAIL}"

     echo "--- Setup Complete ---"
     echo "Pods in namespace '${NAMESPACE}' using serviceAccount '${KSA_NAME}' can now authenticate as '${GSA_EMAIL}' and have read access to 'gs://${GCS_BUCKET_NAME}'."

     After you run this script, the following resources are configured in your Google Cloud project and GKE cluster:

     • A new IAM service account, with the name that you set in GSA_NAME, is created in your project.
     • The created Google Cloud service account (GSA) is granted the roles/storage.objectViewer role on the Cloud Storage bucket specified by ${GCS_BUCKET_NAME}. This permission allows the GSA to read objects from the bucket.
     • A new KSA named jaxserviceaccout is created in the default namespace within your GKE cluster.
     • The IAM policy of the GSA is updated to grant the roles/iam.workloadIdentityUser role to the KSA. This permission allows the KSA to impersonate the GSA.
     • The KSA is annotated to link it to the GSA. This annotation tells GKE which GSA the KSA should impersonate by using Workload Identity.

       Any Pod running in the default namespace of your GKE cluster that uses the jaxserviceaccout service account can now authenticate as the GSA. These Pods have read access to the objects stored in the gs://${GCS_BUCKET_NAME} bucket, which the fine-tuning Job needs in order to access the dataset by using Cloud Storage FUSE.

Create the fine-tuning script

In this section, you explore the training script that performs a fine-tuning operation on a Gemma 3 model. This script uses the Gemma3Tokenizer.

Review the following Gemma3LLMTrain.py fine-tuning script:

# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import grain.python as pygrain
import jax
import jax.numpy as jnp
import optax
import pandas as pd
import time
import argparse
from dataclasses import dataclass
from functools import partial
from gemma import gm
from flax.training import train_state
from jax.sharding import Mesh, PartitionSpec, NamedSharding

jax.distributed.initialize()
print("Global device count:", jax.device_count())
print("jax version:", jax.__version__)

tokenizer = gm.text.Gemma3Tokenizer()
num_epochs = 1
learning_rate = 2e-5


@dataclass
class TextDataset:
    data: list
    maxlen: int
    tokenizer: gm.text.Gemma3Tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int):
        encoding = self.tokenizer.encode(self.data[idx])[:self.maxlen]  # Tokenize and truncate
        return encoding + [0] * (self.maxlen - len(encoding))  # Pad to maxlen


def load_and_preprocess_data(file_path, batch_size, maxlen, datacount, tokenizer):
    with open(file_path, 'r') as f:
        text = f.read()
    stories = text.split('<|endoftext|>')
    stories = [story for story in stories if story.strip()][:datacount]
    df = pd.DataFrame({'text': stories})
    data = df['text'].dropna().tolist()
    dataset = TextDataset(data, maxlen, tokenizer)
    sampler = pygrain.IndexSampler(
        len(dataset),
        shuffle=False,
        seed=42,
        shard_options=pygrain.NoSharding(),
        num_epochs=num_epochs,
    )
    dataloader = pygrain.DataLoader(
        data_source=dataset,
        sampler=sampler,
        operations=[pygrain.Batch(batch_size=batch_size, drop_remainder=True)],
    )
    return dataloader


def generate_text(model, params, tokenizer, prompt):
    sampler = gm.text.Sampler(
        model=model,
        params=params,
        tokenizer=tokenizer,
    )
    print("Generating response for: " + prompt)
    out = sampler.sample(prompt, max_new_tokens=32)
    print("Response:\n" + out + "\n")
    return out


prep_target_batch = jax.vmap(lambda tokens: jnp.concatenate((tokens[1:], jnp.array([0]))))


@partial(jax.jit, donate_argnums=(0,))
def train_step(state, batch):
    """Performs one supervised fine-tuning step."""
    def loss_fn(params):
        # Run the forward pass. The model returns logits.
        logits = state.apply_fn({'params': params}, batch[0]).logits
        # Calculate the cross-entropy loss.
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=batch[1]
        ).mean()
        return loss

    # Compute gradients
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    # Update the model state
    state = state.apply_gradients(grads=grads)
    metrics = {'loss': loss}
    return state, metrics


def train_model(state, text_dl, num_epochs, sharding):
    batchCount = 0
    start_time = time.time()
    for epoch in range(num_epochs):
        start_time = time.time()
        for batch in text_dl:
            if len(batch) % len(jax.devices()) != 0:
                continue  # skip the remaining elements
            input_batch = jnp.array(jnp.array(batch).T)
            target_batch = prep_target_batch(input_batch)
            state, metrics = train_step(state, jax.device_put((input_batch, target_batch), sharding))
            if batchCount % 10 == 0:
                print(f"Loss after batch {batchCount}: {metrics['loss']}")
            batchCount += 1
    end_time = time.time()
    print(f"Completed training model. Total time for training {end_time - start_time} seconds\n")
    return state


def run_training(maxlen, batch_size, datacount):
    print(f"Batch size: {batch_size}, Max length: {maxlen}, Data count: {datacount}")
    # Load the training data
    tiny_stories_dl = load_and_preprocess_data('/data/TinyStories-train.txt', batch_size, maxlen, datacount, tokenizer)
    # Get the Gemma3 model
    model = gm.nn.Gemma3_270M()
    # Load the pretrained parameters
    params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_270M_PT)
    # Create an optimizer
    optimizer = optax.adamw(learning_rate=learning_rate)
    # Define sharding for data parallel training
    mesh = Mesh(jax.devices(), ('batch',))
    sharding = NamedSharding(mesh, PartitionSpec('batch', None))
    # Testing out current state of the model
    test_prompt = "Once upon a time, there was a girl named Amy."
    generate_text(model, params, tokenizer, test_prompt)
    state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)
    # Perform post training
    print("Start training model")
    state = train_model(state, tiny_stories_dl, num_epochs, sharding)
    # Final text generation
    generate_text(model, state.params, tokenizer, test_prompt)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train Gemma model with custom parameters.')
    parser.add_argument('--maxlen', type=int, default=256, help='Maximum sequence length')
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size')
    parser.add_argument('--datacount', type=int, default=296000, help='Number of data samples to use')
    args = parser.parse_args()
    run_training(maxlen=args.maxlen, batch_size=args.batch_size, datacount=args.datacount)

In this script, the following applies:

  • A Gemma3Tokenizer converts text data into tokens that the model can process.
  • The load_and_preprocess_data function reads the training data from a file, splits it into individual stories, and uses the tokenizer to convert the text into padded sequences of tokens.
  • The generate_text function takes the model, its parameters, and a prompt to generate text.
  • The train_step function defines a single iteration of training that includes the forward pass, loss calculation (using cross-entropy), gradient computation, and parameter updates.
  • The train_model function iterates through the dataset for a specified number of epochs, which calls the train_step function for each batch.
  • The run_training function orchestrates the entire process: it loads the data, initializes the Gemma 3 model (Gemma3_270M) and the optimizer, loads the pretrained parameters, sets up data sharding for parallel processing, runs a test generation, executes the training loop, and performs a final text generation to demonstrate the effect of fine-tuning.
  • The script uses the argparse library to accept command-line arguments for the maxlen, batch_size, and datacount parameters.
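The tokenization and target-preparation steps described above can be sketched in plain Python (no JAX required). The pad_to_maxlen and shift_targets helper names, and the token IDs, are invented for illustration:

```python
def pad_to_maxlen(encoding, maxlen, pad_id=0):
    # Truncate, then right-pad with the pad token, as TextDataset.__getitem__ does.
    encoding = encoding[:maxlen]
    return encoding + [pad_id] * (maxlen - len(encoding))

def shift_targets(tokens, pad_id=0):
    # For next-token prediction, the label at position i is the token at
    # position i + 1, matching the script's prep_target_batch transform.
    return tokens[1:] + [pad_id]

inputs = pad_to_maxlen([11, 12, 13], maxlen=5)
print(inputs)                 # [11, 12, 13, 0, 0]
print(shift_targets(inputs))  # [12, 13, 0, 0, 0]
```

The cross-entropy loss in train_step then compares the model's logits at each position against these shifted labels.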

Now that you have explored the fine-tuning script, containerize it to run on GKE.

Containerize the fine-tuning script

Before you run the fine-tuning script in a GKE cluster, you need to containerize it. This tutorial uses a JAX AI image as the base image.

  1. Open the Dockerfile in the same directory as the Gemma3LLMTrain.py file:

      # Copyright 2026 Google LLC
      #
      # Licensed under the Apache License, Version 2.0 (the "License");
      # you may not use this file except in compliance with the License.
      # You may obtain a copy of the License at
      #
      #     http://www.apache.org/licenses/LICENSE-2.0
      #
      # Unless required by applicable law or agreed to in writing, software
      # distributed under the License is distributed on an "AS IS" BASIS,
      # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
      # See the License for the specific language governing permissions and
      # limitations under the License.
      FROM us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.7.2-rev1

      RUN apt-get update && apt-get install -y wget && rm -rf /var/lib/apt/lists/*
      RUN pip install --upgrade pip
      RUN pip install gemma grain

      WORKDIR /app

      # Copy your training script into the container
      COPY Gemma3LLMTrain.py .

    This Dockerfile installs the necessary dependencies and copies the Gemma3LLMTrain.py file into the container.

  2. Build the Docker image and push it to an image repository:

     export REPOSITORY=REPOSITORY_NAME
     export IMAGE_NAME="jax-gemma3-training"
     export IMAGE_TAG="latest"
     export DOCKERFILE_PATH="./Dockerfile"
     export IMAGE_URI="${REGION}-docker.pkg.dev/${PROJECT_ID}/${REPOSITORY}/${IMAGE_NAME}:${IMAGE_TAG}"
     docker build -t "${IMAGE_URI}" -f "${DOCKERFILE_PATH}" .
     gcloud auth configure-docker "${REGION}-docker.pkg.dev" -q
     docker push "${IMAGE_URI}"

    Replace REPOSITORY_NAME with the name of your Artifact Registry repository.

  3. Add role bindings to the service account:

     export PROJECT_NUMBER=$(gcloud projects describe $PROJECT_ID --format 'get(projectNumber)')
     gcloud artifacts repositories add-iam-policy-binding ${REPOSITORY} \
         --project=${PROJECT_ID} \
         --location=${REGION} \
         --member="serviceAccount:${PROJECT_NUMBER}-compute@developer.gserviceaccount.com" \
         --role="roles/artifactregistry.reader"

With the image in the repository, you can now deploy the fine-tuning Job into a GKE cluster.

Deploy the LLM fine-tuning Job

This section shows you how to deploy the LLM fine-tuning Job to your GKE cluster.

  1. Open the training_singlehost.yaml manifest:

      # Copyright 2026 Google LLC
     #
     # Licensed under the Apache License, Version 2.0 (the "License");
     # you may not use this file except in compliance with the License.
     # You may obtain a copy of the License at
     #
     #     http://www.apache.org/licenses/LICENSE-2.0
     #
     # Unless required by applicable law or agreed to in writing, software
     # distributed under the License is distributed on an "AS IS" BASIS,
     # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     # See the License for the specific language governing permissions and
     # limitations under the License.
     apiVersion: batch/v1
     kind: Job
     metadata:
       name: jax-gemma3-train-singlehost
     spec:
       template:
         metadata:
           annotations:
             gke-gcsfuse/volumes: "true"
         spec:
           serviceAccountName: jaxserviceaccout
           containers:
           - name: training-container
             image: ${IMAGE_URI}
             imagePullPolicy: "Always"
             command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "355120"]
             resources:
               limits:
                 google.com/tpu: 1
             volumeMounts:
             - name: gcs-fuse-csi-ephemeral
               mountPath: /data
           nodeSelector:
             cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
             cloud.google.com/gke-tpu-topology: 1x1
           restartPolicy: Never
           volumes:
           - name: gcs-fuse-csi-ephemeral
             csi:
               driver: gcsfuse.csi.storage.gke.io
               volumeAttributes:
                 bucketName: ${GCS_BUCKET_NAME}
                 mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
       backoffLimit: 1
     
    
  2. Apply the manifest:

     envsubst < training_singlehost.yaml | kubectl apply -f -
    

GKE creates a Job that launches a Pod on a TPU Trillium (v6e) node. This Pod runs the Python fine-tuning script, which accesses the fine-tuning data from the specified Cloud Storage bucket mounted at the /data path by using Cloud Storage FUSE. The script then fine-tunes the Gemma model.
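The `envsubst` step replaces `${IMAGE_URI}` and `${GCS_BUCKET_NAME}` placeholders in the manifest with values from your shell environment before the manifest reaches kubectl. The same substitution can be sketched in Python with the standard library (placeholder values shown; in the tutorial these come from your shell exports):

```python
import os
from string import Template

# A fragment of the manifest with envsubst-style placeholders.
manifest = "image: ${IMAGE_URI}\nbucketName: ${GCS_BUCKET_NAME}\n"

# Placeholder values standing in for your shell exports.
os.environ["IMAGE_URI"] = "us-central1-docker.pkg.dev/my-project/repo/jax-gemma3-training:latest"
os.environ["GCS_BUCKET_NAME"] = "my-training-bucket"

# Template uses the same ${VAR} syntax that envsubst expands.
rendered = Template(manifest).substitute(os.environ)
print(rendered)
```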

Monitor the training Job

In this section, you monitor the progress of the fine-tuning Job and its performance.

See fine-tuning progress

  1. List the Pods:

      # Find the Pods
     kubectl get pods
    
  2. Follow the log output:

     kubectl logs -f pods/POD_NAME
     
    

    Replace POD_NAME with the name of your Pod.

    The output is similar to the following:

     Global device count: 1
    Batch size: 128, Max length: 256, Data count: 96000
    I1028 00:12:55.925999 1387 google_auth_provider.cc:181] Running on GCE, using service account ...
    Generating response for: Once upon a time, there was a girl named Amy.
    Response:
    Amy lived in a small house. The house was in a big field. Amy liked to play in the big field. She
    Start training model
    Loss after batch 0: 10.25
    Loss after batch 10: 4.3125
    .
    .
    .
    Loss after batch 740: 1.41406
    Completed training model. Total time for training 294.6791355609894 seconds
    Generating response for: Once upon a time, there was a girl named Amy.
    Response:
    She loved to play with her toys. One day, Amy's mom told her that she had to go to the store to 
    
  3. Analyze the output:

    • The Global device count: 1 line indicates the number of TPU cores used.
    • The model generates reasonable text before this fine-tuning run because it loads from a pre-trained checkpoint.
    • The output generated after fine-tuning shows more resemblance to the start of a short story, indicating the model is learning from the new dataset.
    • Fine-tuning on the full dataset should produce even more refined outputs.
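From the log output you can also estimate training throughput. The following rough sketch uses the numbers reported in the example output (batch size 128, sequence length 256, 750 batches for 96,000 examples, 294.68 seconds); your own figures will differ between runs:

```python
batch_size = 128      # from the "Batch size" log line
max_len = 256         # maximum sequence length
num_batches = 750     # 96000 examples / batch size of 128
total_seconds = 294.68

# Total tokens pushed through the model during training.
tokens_processed = batch_size * max_len * num_batches
tokens_per_second = tokens_processed / total_seconds
print(f"~{tokens_per_second:,.0f} tokens/s")
```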

Observe metrics

See the performance of the fine-tuning Job by checking the TPU and CPU metrics. To view observability metrics for your cluster, perform the steps in View cluster and workload observability metrics .

Alternative fine-tuning configurations

This section outlines alternative configurations for your fine-tuning workload.

Model selection

This tutorial used the Gemma3_270M model, which is a small model that fits into a single-host TPU Trillium (v6e) node pool. For larger models that require more memory and compute for fine-tuning, you can use multi-host or multislice node pool configurations.
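As a rough rule of thumb for model selection, you can estimate accelerator memory from the parameter count. The following is a simplified back-of-the-envelope sketch for the Gemma3_270M model used in this tutorial; it ignores activations, gradients, and framework overhead, and the byte sizes are assumptions (bfloat16 weights, two fp32 optimizer states per parameter, as in Adam-style optimizers):

```python
params = 270e6          # Gemma3_270M parameter count (approximate)
bytes_per_param = 2     # bfloat16 weights

weights_gb = params * bytes_per_param / 1e9
# Adam-style optimizers keep two extra fp32 states per parameter.
optimizer_gb = params * 2 * 4 / 1e9
print(f"weights ≈ {weights_gb:.2f} GB, optimizer states ≈ {optimizer_gb:.2f} GB")
```

Even with this overhead, the model fits comfortably on a single TPU Trillium chip, which is why the single-host configuration suffices here.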

For a complete list of available models, see the Gemma documentation .

Node pool configurations

This tutorial used a single-host node pool. You can also create multi-host TPU slice node pools or multislice node pools , depending on your needs.
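TPU topology strings such as 1x1 or 2x4 encode the slice shape: the product of the dimensions gives the chip count, and dividing by the chips per VM gives the node count. The following sketch illustrates the arithmetic; the `slice_shape` helper is hypothetical, and the default of 4 chips per node is an assumption based on the ct6e-standard-4t machine type used below:

```python
import math

def slice_shape(topology: str, chips_per_node: int = 4) -> tuple[int, int]:
    """Return (total chips, number of nodes) for a TPU topology string like '2x4'."""
    chips = 1
    for dim in topology.split("x"):
        chips *= int(dim)
    nodes = max(1, math.ceil(chips / chips_per_node))
    return chips, nodes

print(slice_shape("1x1"))  # single-host: (1, 1)
print(slice_shape("2x4"))  # multi-host: (8, 2)
```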

The following tabs show how to create multi-host and multislice node pools:

Multi-host

  1. In Cloud Shell, run the following command:

     gcloud container node-pools create jax-tpu-multihost1 \
         --cluster=${CLUSTER_NAME} \
         --machine-type=ct6e-standard-4t \
         --num-nodes=2 \
         --tpu-topology=2x4 \
         --location=${REGION} \
         --node-locations=${ZONE}
     
    

    GKE creates a TPU Trillium node pool with a 2x4 topology and two nodes.

  2. Open the training_multihost_jobset.yaml Job definition:

      # Copyright 2026 Google LLC
     #
     # Licensed under the Apache License, Version 2.0 (the "License");
     # you may not use this file except in compliance with the License.
     # You may obtain a copy of the License at
     #
     #     http://www.apache.org/licenses/LICENSE-2.0
     #
     # Unless required by applicable law or agreed to in writing, software
     # distributed under the License is distributed on an "AS IS" BASIS,
     # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     # See the License for the specific language governing permissions and
     # limitations under the License.
     apiVersion: jobset.x-k8s.io/v1alpha2
     kind: JobSet
     metadata:
       name: jax-gemma3-train-multihost
     spec:
       replicatedJobs:
       - name: trainers
         replicas: 1
         template:
           spec:
             parallelism: 2
             completions: 2
             backoffLimit: 1
             template:
               metadata:
                 annotations:
                   gke-gcsfuse/volumes: "true"
               spec:
                 serviceAccountName: jaxserviceaccout
                 nodeSelector:
                   cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                   cloud.google.com/gke-tpu-topology: 2x4
                   cloud.google.com/gke-nodepool: jax-tpu-multihost1
                 containers:
                 - name: training-container
                   image: ${IMAGE_URI}
                   imagePullPolicy: "Always"
                   ports:
                   - containerPort: 8471
                   command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "5120"]
                   resources:
                     limits:
                       google.com/tpu: 4
                   volumeMounts:
                   - name: gcs-fuse-csi-ephemeral
                     mountPath: /data
                 volumes:
                 - name: gcs-fuse-csi-ephemeral
                   csi:
                     driver: gcsfuse.csi.storage.gke.io
                     volumeAttributes:
                       bucketName: ${GCS_BUCKET_NAME}
                       mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
     
    
  3. Deploy the fine-tuning Job:

     envsubst < training_multihost_jobset.yaml | kubectl apply -f -
    

Multislice

  1. In Cloud Shell, run the following command:

     gcloud container node-pools create jax-tpu-multihost1 \
         --cluster=${CLUSTER_NAME} \
         --machine-type=ct6e-standard-4t \
         --num-nodes=2 \
         --tpu-topology=2x4 \
         --location=${REGION} \
         --node-locations=${ZONE}
     gcloud container node-pools create jax-tpu-multihost2 \
         --cluster=${CLUSTER_NAME} \
         --machine-type=ct6e-standard-4t \
         --num-nodes=2 \
         --tpu-topology=2x4 \
         --location=${REGION} \
         --node-locations=${ZONE}
     
    

    GKE creates two TPU Trillium node pools. Each node pool has a 2x4 topology and two nodes.

  2. Open the training_multislice_jobset.yaml Job definition:

      # Copyright 2026 Google LLC
     #
     # Licensed under the Apache License, Version 2.0 (the "License");
     # you may not use this file except in compliance with the License.
     # You may obtain a copy of the License at
     #
     #     http://www.apache.org/licenses/LICENSE-2.0
     #
     # Unless required by applicable law or agreed to in writing, software
     # distributed under the License is distributed on an "AS IS" BASIS,
     # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     # See the License for the specific language governing permissions and
     # limitations under the License.
     apiVersion: jobset.x-k8s.io/v1alpha2
     kind: JobSet
     metadata:
       name: jax-gemma3-train-multislice
     spec:
       replicatedJobs:
       - name: trainers
         replicas: 2
         template:
           spec:
             parallelism: 2
             completions: 2
             backoffLimit: 1
             template:
               metadata:
                 annotations:
                   gke-gcsfuse/volumes: "true"
               spec:
                 serviceAccountName: jaxserviceaccout
                 nodeSelector:
                   cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                   cloud.google.com/gke-tpu-topology: 2x4
                 containers:
                 - name: training-container
                   image: ${IMAGE_URI}
                   imagePullPolicy: "Always"
                   ports:
                   - containerPort: 8471
                   command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "5120"]
                   resources:
                     limits:
                       google.com/tpu: 4
                   volumeMounts:
                   - name: gcs-fuse-csi-ephemeral
                     mountPath: /data
                 volumes:
                 - name: gcs-fuse-csi-ephemeral
                   csi:
                     driver: gcsfuse.csi.storage.gke.io
                     volumeAttributes:
                       bucketName: ${GCS_BUCKET_NAME}
                       mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:100"
     
    
  3. Deploy the fine-tuning Job:

     envsubst < training_multislice_jobset.yaml | kubectl apply -f -
    

Performance analysis and optimization

To analyze and optimize the performance of your machine learning fine-tuning, you can use XProf . XProf is a suite of tools that profiles and inspects ML workloads built with JAX, TensorFlow, or PyTorch/XLA. By showing execution traces, memory usage, and other data, XProf lets you fine-tune your models and training setup for better efficiency and faster training.

To analyze the performance of your fine-tuning workload by using XProf, you complete the following steps in this section:

  • Install the xprof package. Modify your training script to start the XProf server.
  • Modify your Kubernetes Job manifest to include a volume mount for XProf logs.
  • Grant the service account permissions to write XProf logs to a Cloud Storage bucket.
  • Run XProf within your Pod and set up port forwarding to access the XProf dashboard.

Install the XProf package

  1. Navigate to the directory that contains the XProf samples:

       
     cd ai-ml/llm-training-jax-tpu-gemma3/xprof-enabled
    
  2. Build the Docker image and push it to an image repository:

      export REPOSITORY=REPOSITORY_NAME
     export IMAGE_NAME="jax-gemma3-training-xp"
     export IMAGE_TAG="latest"
     export DOCKERFILE_PATH="./Dockerfile"
     export IMAGE_URI="${REGION}-docker.pkg.dev/${PROJECT_ID}/${REPOSITORY}/${IMAGE_NAME}:${IMAGE_TAG}"
     docker build -t "${IMAGE_URI}" -f "${DOCKERFILE_PATH}" .
     gcloud auth configure-docker "${REGION}-docker.pkg.dev" -q
     docker push "${IMAGE_URI}"
     
    

    Replace REPOSITORY_NAME with the name of your Artifact Registry repository.

  3. Review the Dockerfile:

      # Copyright 2026 Google LLC
     #
     # Licensed under the Apache License, Version 2.0 (the "License");
     # you may not use this file except in compliance with the License.
     # You may obtain a copy of the License at
     #
     #     http://www.apache.org/licenses/LICENSE-2.0
     #
     # Unless required by applicable law or agreed to in writing, software
     # distributed under the License is distributed on an "AS IS" BASIS,
     # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     # See the License for the specific language governing permissions and
     # limitations under the License.
     FROM us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.7.2-rev1
     RUN apt-get update && apt-get install -y wget && rm -rf /var/lib/apt/lists/*
     RUN pip install --upgrade pip
     RUN pip install gemma grain equinox
     RUN pip install xprof
     WORKDIR /app
     # Copy your training script into the container
     COPY Gemma3LLMTrain.py .
    

    This Dockerfile installs XProf dependencies.

Add a volume mount for XProf logs

In this section, you create and apply a Kubernetes Job manifest that includes the necessary volume mounts for XProf logs.

  1. Open the training_singlehost.yaml Job definition:

      # Copyright 2026 Google LLC
     #
     # Licensed under the Apache License, Version 2.0 (the "License");
     # you may not use this file except in compliance with the License.
     # You may obtain a copy of the License at
     #
     #     http://www.apache.org/licenses/LICENSE-2.0
     #
     # Unless required by applicable law or agreed to in writing, software
     # distributed under the License is distributed on an "AS IS" BASIS,
     # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     # See the License for the specific language governing permissions and
     # limitations under the License.
     apiVersion: batch/v1
     kind: Job
     metadata:
       name: jax-gemma3-train-singlehost
     spec:
       template:
         metadata:
           annotations:
             gke-gcsfuse/volumes: "true"
         spec:
           serviceAccountName: jaxserviceaccout
           containers:
           - name: training-container
             image: ${IMAGE_URI}
             imagePullPolicy: "Always"
             command: ["python", "Gemma3LLMTrain.py", "--maxlen", "256", "--batch_size", "64", "--datacount", "851200"]
             resources:
               limits:
                 google.com/tpu: 1
             volumeMounts:
             - name: gcs-fuse-csi-ephemeral
               mountPath: /data
             - name: gcs-fuse-csi-ephemeral2
               mountPath: /xprof
           nodeSelector:
             cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
             cloud.google.com/gke-tpu-topology: 1x1
           restartPolicy: Never
           volumes:
           - name: gcs-fuse-csi-ephemeral
             csi:
               driver: gcsfuse.csi.storage.gke.io
               volumeAttributes:
                 bucketName: ${GCS_BUCKET_NAME}
                 mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
           - name: gcs-fuse-csi-ephemeral2
             csi:
               driver: gcsfuse.csi.storage.gke.io
               volumeAttributes:
                 bucketName: ${XPROF_GCS_BUCKET_NAME}
                 mountOptions: "implicit-dirs,file-cache:enable-parallel-downloads:true,file-cache:parallel-downloads-per-file:100,file-cache:max-parallel-downloads:-1,file-cache:download-chunk-size-mb:10,file-cache:max-size-mb:-1"
       backoffLimit: 1
     
    
  2. Apply the manifest:

     envsubst < training_singlehost.yaml | kubectl apply -f -
    

Grant the service account permissions to write XProf logs

  1. To enable the service account to read and write XProf logs, add the "roles/storage.objectUser" role:

      export GSA_NAME="GSA_NAME"  # Same as used in initial setup
     # Automatically get the current project ID
     export PROJECT_ID=$(gcloud config get-value project)
     # Cloud Storage Bucket details
     export XPROF_GCS_BUCKET_NAME="XPROF_GCS_BUCKET_NAME"
     # Derived Variables
     export GSA_EMAIL="${GSA_NAME}@${PROJECT_ID}.iam.gserviceaccount.com"
     gcloud storage buckets add-iam-policy-binding "gs://${XPROF_GCS_BUCKET_NAME}" \
         --member="serviceAccount:${GSA_EMAIL}" \
         --role="roles/storage.objectUser" \
         --project="${PROJECT_ID}"
     
    

    Replace the following:

    • GSA_NAME : the name of the Google Service Account to which to grant the role.
    • XPROF_GCS_BUCKET_NAME : the name of the bucket to which to grant the role.
  2. Run XProf within your Pod:

      # exec into the container
     kubectl exec POD_NAME -c training-container -it -- bash
     # start xprof
     xprof --port 9001 --logdir /xprof
     
    

    Replace POD_NAME with the name of your Pod.

Access the XProf dashboard

  1. Set up port forwarding to the XProf server in the Pod:

     kubectl port-forward POD_NAME 9001:9001
    
  2. In your browser's address bar, enter the following:

     http://localhost:9001/ 
    

    The XProf Trace Viewer opens.

  3. In the TensorBoard window, click Capture profile.

  4. In the Profile Service URL(s) or TPU name field, enter localhost:9002.

  5. To capture more details, in the Host Trace (TraceMe) Level field, select verbose and enable Python trace logging.

  6. To view the dashboard, click Capture.

    TensorBoard captures the profile and lets you analyze the performance of the training script. The graph shows the execution timeline for both TPU and CPU performance profiles:

An example of the XProf trace viewer that shows a performance matrix graph

For more profiling options to analyze your training workload performance, see the JAX documentation on Profiling computation .

Fine-tuning in production environments

This tutorial showed you how to test JAX-based training in a distributed environment. For optimized LLM fine-tuning in production, use the MaxText library. If you are interested in diffusion models, use MaxDiffusion implementations.

For long-running training or fine-tuning workloads in production, set up workload checkpointing to minimize progress loss during a failure. To learn more about setting up multi-tier checkpointing, see Train large-scale machine learning models on GKE with Multi-Tier Checkpointing .
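Checkpointing in a training loop is commonly gated on step count, so a crash loses at most the work since the last save. The following minimal illustration of that idea uses a hypothetical helper, not the tutorial's training script, and an assumed interval of 100 steps:

```python
def should_checkpoint(step: int, every_n_steps: int = 100) -> bool:
    """Decide whether to save a checkpoint at this training step."""
    return step > 0 and step % every_n_steps == 0

# For a 750-step run, checkpoints would land at steps 100, 200, ..., 700.
saved = [s for s in range(750) if should_checkpoint(s)]
print(f"checkpoints at steps {saved[:3]} ... {saved[-1]}")
```

In practice, the checkpoint interval trades off write overhead against the amount of progress lost on failure; multi-tier checkpointing refines this by writing frequent fast local checkpoints and less frequent durable ones.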

Clean up

To avoid incurring charges to your Google Cloud account for the resources used in this tutorial, either delete the project that contains the resources, or keep the project and delete the individual resources.

Delete the individual resources

To avoid incurring charges to your Google Cloud account for the resources used in this tutorial, either delete the project that contains the resources, or keep the project and delete the individual resources by running the following commands:

  1. Delete the resources you created in this tutorial:

     gcloud container clusters delete ${CLUSTER_NAME} --location=${REGION}
     gcloud storage rm --recursive gs://${GCS_BUCKET_NAME}
     gcloud artifacts docker images delete ${IMAGE_URI} --delete-tags
    
  2. If you don't need the data generated by XProf, remove the Cloud Storage bucket used by XProf:

     gcloud storage rm --recursive gs://${XPROF_GCS_BUCKET_NAME}
     
    

What's next
