Training ResNet-50 on Cloud TPU with PyTorch

This tutorial shows you how to train the ResNet-50 model on a Cloud TPU device with PyTorch. You can apply the same pattern to other TPU-optimized image classification models that use PyTorch and the ImageNet dataset.

The model in this tutorial is based on Deep Residual Learning for Image Recognition, the paper that first introduced the residual network (ResNet) architecture. The tutorial uses the 50-layer variant, ResNet-50, and demonstrates training the model using PyTorch/XLA.
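
The core idea behind ResNet is the skip connection: each block learns a residual function F(x) and adds the input x back to its output, which makes very deep networks easier to optimize. As a rough illustration, a basic residual block in PyTorch might look like the following (a minimal sketch; torchvision's ResNet-50 actually uses three-layer bottleneck blocks):

    import torch.nn as nn

    class BasicResidualBlock(nn.Module):
        """Toy residual block: output = relu(F(x) + x)."""

        def __init__(self, channels):
            super().__init__()
            self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
            self.bn1 = nn.BatchNorm2d(channels)
            self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
            self.bn2 = nn.BatchNorm2d(channels)
            self.relu = nn.ReLU(inplace=True)

        def forward(self, x):
            out = self.relu(self.bn1(self.conv1(x)))
            out = self.bn2(self.conv2(out))
            # Skip connection: add the block's input back before the final activation.
            return self.relu(out + x)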

Objectives

  • Prepare the dataset.
  • Run the training job.
  • Verify the output results.

Costs

In this document, you use the following billable components of Google Cloud:

  • Compute Engine
  • Cloud TPU

To generate a cost estimate based on your projected usage, use the pricing calculator.

New Google Cloud users might be eligible for a free trial.

Before you begin

Before starting this tutorial, check that your Google Cloud project is correctly set up.

  1. Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
  2. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  3. Verify that billing is enabled for your Google Cloud project.

  4. This walkthrough uses billable components of Google Cloud. Check the Cloud TPU pricing page to estimate your costs. Be sure to clean up resources you created when you've finished with them to avoid unnecessary charges.

Create a TPU VM

  1. Open a Cloud Shell window.

    Open Cloud Shell

  2. Create a TPU VM:

    gcloud compute tpus tpu-vm create your-tpu-name \
      --accelerator-type=v3-8 \
      --version=tpu-ubuntu2204-base \
      --zone=us-central1-a \
      --project=your-project

  3. Connect to your TPU VM using SSH:

    gcloud compute tpus tpu-vm ssh your-tpu-name \
      --zone=us-central1-a

  4. Install PyTorch/XLA on your TPU VM:

    (vm)$ pip install torch torch_xla[tpu] torchvision \
      -f https://storage.googleapis.com/libtpu-releases/index.html \
      -f https://storage.googleapis.com/libtpu-wheels/index.html

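    To confirm that the installation can see the TPU, you can run a quick check like the following (a minimal sketch; check_tpu.py is just a hypothetical filename, and the exact device string printed may vary by release):

    # check_tpu.py -- run on the TPU VM with: PJRT_DEVICE=TPU python3 check_tpu.py
    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()   # acquires the XLA device backed by the TPU
    t = torch.randn(2, 2, device=device)
    print(device, t.sum().item())   # should print an xla device and a number
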
  5. Clone the PyTorch/XLA GitHub repository:

    (vm)$ git clone --depth=1 https://github.com/pytorch/xla.git

  6. Run the training script with fake data:

    (vm)$ PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py \
      --fake_data \
      --batch_size=256 \
      --num_epochs=1
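
The --fake_data flag trains on synthetic images, so you can verify the TPU setup without downloading ImageNet. Under the hood, test_train_mp_imagenet.py follows the usual PyTorch/XLA training pattern. A rough single-device sketch of that pattern (not the script itself, which adds multiprocessing, real data loading, and metrics) might look like this:

    import torch
    import torch.nn as nn
    import torch_xla.core.xla_model as xm
    from torchvision import models

    device = xm.xla_device()                 # the TPU exposed as an XLA device
    model = models.resnet50().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    loss_fn = nn.CrossEntropyLoss()

    for step in range(10):
        # Synthetic batch, analogous to the script's --fake_data mode.
        images = torch.randn(16, 3, 224, 224, device=device)
        labels = torch.randint(0, 1000, (16,), device=device)

        optimizer.zero_grad()
        loss = loss_fn(model(images), labels)
        loss.backward()
        # optimizer_step() applies the update; barrier=True also flushes the
        # lazily built XLA graph so the step executes on the device.
        xm.optimizer_step(optimizer, barrier=True)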

Clean up

To avoid incurring charges to your Google Cloud account for the resources used in this tutorial, either delete the project that contains the resources, or keep the project and delete the individual resources.

  1. Disconnect from the TPU VM:

    (vm)$ exit

    Your prompt should now be username@projectname, showing you are in the Cloud Shell.

  2. Delete your TPU VM:

    $ gcloud compute tpus tpu-vm delete your-tpu-name \
      --zone=us-central1-a
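
    Deletion can take a few minutes. To verify that the TPU VM was removed, list the TPUs in the zone; your TPU should no longer appear:

    $ gcloud compute tpus tpu-vm list --zone=us-central1-a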
