Train a TensorFlow model with Keras on Google Kubernetes Engine

This section shows how to fine-tune a BERT model (bert-base-cased) for sequence classification on the GLUE CoLA dataset, using the Hugging Face transformers library with TensorFlow. The dataset is downloaded into a mounted Parallelstore-backed volume, so the training code reads its data directly from that volume.
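Pointing the dataset cache at the shared volume means later runs that mount the same volume can reuse the downloaded files instead of fetching them again. The stdlib-only sketch below is a simplified model of that cache-on-a-shared-volume pattern. The helper `fetch_cached` and the downloader `fake_download` are hypothetical names. This is not how the datasets library implements its cache internally.

```python
import os
import tempfile
from typing import Callable


def fetch_cached(name: str, cache_dir: str, download: Callable[[], bytes]) -> bytes:
    """Return the bytes for `name`, calling `download` only on a cache miss.

    Hypothetical helper: a simplified stand-in for a library download cache.
    """
    path = os.path.join(cache_dir, name)
    if os.path.exists(path):
        # Cache hit: read straight from the (shared) volume.
        with open(path, "rb") as f:
            return f.read()
    # Cache miss: fetch the data once and persist it for later readers.
    data = download()
    os.makedirs(cache_dir, exist_ok=True)
    tmp_path = path + ".tmp"
    with open(tmp_path, "wb") as f:
        f.write(data)
    # os.replace is atomic on POSIX filesystems, so readers see either
    # no file or the complete file.
    os.replace(tmp_path, path)
    return data


if __name__ == "__main__":
    calls = []

    def fake_download() -> bytes:
        # Hypothetical downloader that records how often it is invoked.
        calls.append(1)
        return b"sentence,label\n"

    # A temporary directory stands in for the volume mounted at /data.
    with tempfile.TemporaryDirectory() as cache_dir:
        first = fetch_cached("cola.csv", cache_dir, fake_download)
        second = fetch_cached("cola.csv", cache_dir, fake_download)
        print(first == second, len(calls))  # True 1
```

The second call never invokes the downloader, which is the effect a shared cache directory aims for when several Jobs train on the same data.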

Prerequisites

Save the following YAML manifest as parallelstore-csi-job-example.yaml for your model training Job. It mounts an existing PersistentVolumeClaim named parallelstore-pvc at /data.

   
```yaml
apiVersion: batch/v1
kind: Job
metadata:
  name: parallelstore-csi-job-example
spec:
  template:
    metadata:
      annotations:
        gke-parallelstore/cpu-limit: "0"
        gke-parallelstore/memory-limit: "0"
    spec:
      # Run as UID 1000 / GID 100, the default user and group of the Jupyter
      # notebook images; fsGroup adds group 100 for volume access.
      securityContext:
        runAsUser: 1000
        runAsGroup: 100
        fsGroup: 100
      containers:
      - name: tensorflow
        image: jupyter/tensorflow-notebook@sha256:173f124f638efe870bb2b535e01a76a80a95217e66ed00751058c51c09d6d85d
        command: ["bash", "-c"]
        # Install the Hugging Face libraries, then run the training script
        # passed to Python through a heredoc.
        args:
        - |
          pip install transformers datasets
          python - <<EOF
          # Download GLUE CoLA into the Parallelstore volume mounted at /data.
          from datasets import load_dataset
          dataset = load_dataset("glue", "cola", cache_dir='/data')
          dataset = dataset["train"]
          from transformers import AutoTokenizer
          import numpy as np
          # Tokenize every sentence, padding to the longest one.
          tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
          tokenized_data = tokenizer(dataset["sentence"], return_tensors="np", padding=True)
          tokenized_data = dict(tokenized_data)
          labels = np.array(dataset["label"])
          from transformers import TFAutoModelForSequenceClassification
          from tensorflow.keras.optimizers import Adam
          model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased")
          # No loss argument: the Hugging Face TF model uses its built-in loss.
          model.compile(optimizer=Adam(3e-5))
          model.fit(tokenized_data, labels)
          EOF
        volumeMounts:
        - name: parallelstore-volume
          mountPath: /data
      # The PersistentVolumeClaim must already exist in the Job's namespace.
      volumes:
      - name: parallelstore-volume
        persistentVolumeClaim:
          claimName: parallelstore-pvc
      restartPolicy: Never
  # Retry at most once if the Pod fails.
  backoffLimit: 1
```
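In the training script, `tokenizer(..., padding=True)` pads every sentence to the length of the longest one in the call. It also returns an `attention_mask` that marks real tokens, so the whole CoLA training split reaches `model.fit` as fixed-shape arrays. The stdlib sketch below illustrates that padding step on made-up token IDs. `pad_batch` is a hypothetical name, and this is not the tokenizer's implementation. Pad ID 0 matches `[PAD]` in BERT's vocabulary.

```python
from typing import Dict, List


def pad_batch(sequences: List[List[int]], pad_id: int = 0) -> Dict[str, List[List[int]]]:
    """Right-pad token ID lists to the longest one and build an attention mask.

    Hypothetical illustration of padding to the longest sequence;
    not the transformers implementation.
    """
    max_len = max((len(s) for s in sequences), default=0)
    input_ids = []
    attention_mask = []
    for seq in sequences:
        n_pad = max_len - len(seq)
        input_ids.append(seq + [pad_id] * n_pad)
        # 1 marks a real token the model attends to, 0 marks padding.
        attention_mask.append([1] * len(seq) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


if __name__ == "__main__":
    # Made-up token IDs for two sentences of different lengths.
    batch = pad_batch([[101, 7592, 102], [101, 7592, 2088, 999, 102]])
    print(batch["input_ids"])       # [[101, 7592, 102, 0, 0], [101, 7592, 2088, 999, 102]]
    print(batch["attention_mask"])  # [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
```

Padding the whole dataset to its single longest sentence keeps the script short. The trade-off is that short sentences carry many padding tokens.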
 

Apply the YAML manifest to the cluster.

```shell
kubectl apply -f parallelstore-csi-job-example.yaml
```
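`kubectl apply -f` also accepts JSON manifests. If you generate training Jobs from a script rather than editing YAML by hand, the stdlib-only sketch below builds the same Job as a Python dict and prints it as JSON. The function `build_job_manifest` and its parameters are hypothetical. The field values mirror parallelstore-csi-job-example.yaml, and the training script is shortened to a placeholder.

```python
import json


def build_job_manifest(name: str, image: str, script: str, claim_name: str,
                       mount_path: str = "/data") -> dict:
    """Build a batch/v1 Job that mounts a PVC, mirroring parallelstore-csi-job-example.yaml.

    Hypothetical helper for illustration; only fields used in that manifest are set.
    """
    return {
        "apiVersion": "batch/v1",
        "kind": "Job",
        "metadata": {"name": name},
        "spec": {
            "backoffLimit": 1,
            "template": {
                "metadata": {
                    "annotations": {
                        "gke-parallelstore/cpu-limit": "0",
                        "gke-parallelstore/memory-limit": "0",
                    }
                },
                "spec": {
                    "securityContext": {"runAsUser": 1000, "runAsGroup": 100, "fsGroup": 100},
                    "containers": [{
                        "name": "tensorflow",
                        "image": image,
                        "command": ["bash", "-c"],
                        "args": [script],
                        # The mount name must match the volume name below.
                        "volumeMounts": [{"name": "parallelstore-volume", "mountPath": mount_path}],
                    }],
                    "volumes": [{
                        "name": "parallelstore-volume",
                        "persistentVolumeClaim": {"claimName": claim_name},
                    }],
                    "restartPolicy": "Never",
                },
            },
        },
    }


if __name__ == "__main__":
    manifest = build_job_manifest(
        name="parallelstore-csi-job-example",
        image="jupyter/tensorflow-notebook",  # pin a digest in practice, as the YAML does
        script="echo 'training script goes here'",  # placeholder for the pip/python steps
        claim_name="parallelstore-pvc",
    )
    print(json.dumps(manifest, indent=2))
```

Usage, with build_job.py as a hypothetical file name: run python build_job.py > parallelstore-csi-job-example.json, then kubectl apply -f parallelstore-csi-job-example.json.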

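Check data loading and model training progress with the following commands. The first finds the name of the Job's Pod, and the second streams the logs of its tensorflow container: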

```shell
POD_NAME=$(kubectl get pod | grep 'parallelstore-csi-job-example' | awk '{print $1}')
kubectl logs -f $POD_NAME -c tensorflow
```
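The `grep`/`awk` pipeline keeps the first column of every `kubectl get pod` line that contains the Job name. The minimal Python sketch below reproduces that filtering, which can help if you script the progress check. `pod_names_for_job` is a hypothetical helper, and the sample output is made up for illustration. Pods created by a Job also carry a `job-name` label, so `kubectl get pods -l job-name=parallelstore-csi-job-example` is a more precise selector than a substring match.

```python
from typing import List


def pod_names_for_job(kubectl_output: str, job_name: str) -> List[str]:
    """Mimic `grep JOB | awk '{print $1}'` on `kubectl get pod` output.

    Hypothetical helper: returns the first column of each line containing job_name.
    """
    names = []
    for line in kubectl_output.splitlines():
        if job_name in line:  # grep: plain substring match on the whole line
            fields = line.split()
            if fields:  # awk '{print $1}': first whitespace-separated field
                names.append(fields[0])
    return names


if __name__ == "__main__":
    # Made-up sample of `kubectl get pod` output, for illustration only.
    sample = (
        "NAME                                  READY   STATUS    RESTARTS   AGE\n"
        "parallelstore-csi-job-example-abcde   1/1     Running   0          2m\n"
        "other-app-7d9f8c6b5-xyz12             1/1     Running   0          1d\n"
    )
    print(pod_names_for_job(sample, "parallelstore-csi-job-example"))
    # ['parallelstore-csi-job-example-abcde']
```

Returning a list rather than a single string makes it visible when a retry (backoffLimit: 1 with restartPolicy: Never) has created a second Pod. In that case the shell variable POD_NAME would hold both names.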