JetStream PyTorch inference on v6e TPU VMs

This tutorial shows how to use JetStream to serve PyTorch models on TPU v6e. JetStream is a throughput- and memory-optimized engine for large language model (LLM) inference on XLA devices (TPUs). In this tutorial, you run the inference benchmark for the Llama2-7B model.

Before you begin

Prepare to provision a TPU v6e with 4 chips:

  1. Follow the Set up the Cloud TPU environment guide to set up a Google Cloud project, configure the Google Cloud CLI, enable the Cloud TPU API, and ensure you have access to use Cloud TPUs.

  2. Authenticate with Google Cloud and configure the default project and zone for Google Cloud CLI.

    gcloud auth login
    gcloud config set project PROJECT_ID
    gcloud config set compute/zone ZONE

Secure capacity

When you are ready to secure TPU capacity, see Cloud TPU Quotas for more information about Cloud TPU quotas. If you have additional questions about securing capacity, contact your Cloud TPU sales or account team.

Provision the Cloud TPU environment

You can provision TPU VMs with GKE, with GKE and XPK, or as queued resources.

Prerequisites

  • Verify that your project has enough TPUS_PER_TPU_FAMILY quota, which specifies the maximum number of chips you can access within your Google Cloud project.
  • Verify that your project has enough quota for:
    • TPU VMs
    • IP addresses
    • Hyperdisk Balanced
  • Verify that you have the required user project permissions.

Create environment variables

In a Cloud Shell, create the following environment variables:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-central2-b
    export ACCELERATOR_TYPE=v6e-4
    export RUNTIME_VERSION=v2-alpha-tpuv6e
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

Environment variable descriptions

  • PROJECT_ID: Your Google Cloud project ID. Use an existing project or create a new one.
  • TPU_NAME: The name of the TPU.
  • ZONE: The zone in which to create the TPU VM. For more information about supported zones, see TPU regions and zones.
  • ACCELERATOR_TYPE: The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
  • RUNTIME_VERSION: The Cloud TPU software version.
  • SERVICE_ACCOUNT: The email address for your service account. You can find it by going to the Service Accounts page in the Google Cloud console. For example: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
  • QUEUED_RESOURCE_ID: The user-assigned text ID of the queued resource request.
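
Because every later command depends on these variables, it can help to confirm that none of them is empty before provisioning. The following is a minimal sketch, not part of the gcloud CLI: require_vars is a hypothetical helper name introduced here for illustration, and it only checks that each named variable has a non-empty value.

```shell
# Hypothetical helper (not part of gcloud): report unset or empty variables.
# Returns 0 when every named variable has a value, 1 otherwise.
require_vars() {
  missing=0
  for name in "$@"; do
    # Indirect lookup of the variable called $name (POSIX-compatible).
    eval "value=\${$name:-}"
    if [ -z "$value" ]; then
      echo "Missing environment variable: $name" >&2
      missing=1
    fi
  done
  return "$missing"
}
```

After defining the function in the same shell as the export commands, one possible call is:

    require_vars PROJECT_ID TPU_NAME ZONE ACCELERATOR_TYPE RUNTIME_VERSION SERVICE_ACCOUNT QUEUED_RESOURCE_ID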

Provision a TPU v6e

  
    gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
        --node-id ${TPU_NAME} \
        --project ${PROJECT_ID} \
        --zone ${ZONE} \
        --accelerator-type ${ACCELERATOR_TYPE} \
        --runtime-version ${RUNTIME_VERSION} \
        --service-account ${SERVICE_ACCOUNT}
  

Use the list or describe commands to query the status of your queued resource.

   
    gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
        --project ${PROJECT_ID} \
        --zone ${ZONE}
 

For a complete list of queued resource request statuses, see the Queued Resources documentation.
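
Rather than rerunning describe by hand, you could poll until the request becomes usable. The sketch below is one possible approach, not an official procedure. It rests on several assumptions: wait_for_queued_resource is a hypothetical helper name, the state is assumed to be readable through the state.state field of the describe output, ACTIVE is taken as the ready state, and FAILED and SUSPENDED are treated as terminal. Check these against the Queued Resources documentation before relying on them.

```shell
# Hypothetical helper: poll a queued resource until it reports ACTIVE.
# Arguments: queued resource ID, max attempts (default 60),
#            seconds between attempts (default 30).
# Assumes PROJECT_ID and ZONE are exported, and that the state is
# available at state.state in the describe output (an assumption).
wait_for_queued_resource() {
  qr_id="$1"
  max_attempts="${2:-60}"
  interval="${3:-30}"
  attempt=1
  while [ "$attempt" -le "$max_attempts" ]; do
    state="$(gcloud alpha compute tpus queued-resources describe "$qr_id" \
      --project "$PROJECT_ID" --zone "$ZONE" \
      --format="value(state.state)" 2>/dev/null)" || state=""
    echo "Attempt $attempt: state=${state:-unknown}"
    case "$state" in
      ACTIVE) return 0 ;;            # ready to use
      FAILED|SUSPENDED) return 1 ;;  # treated as terminal in this sketch
    esac
    attempt=$((attempt + 1))
    sleep "$interval"
  done
  return 1  # gave up before the resource became ACTIVE
}
```

For example, to check every 30 seconds for up to 60 attempts:

    wait_for_queued_resource "${QUEUED_RESOURCE_ID}" 60 30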

Connect to the TPU using SSH

  
    gcloud compute tpus tpu-vm ssh ${TPU_NAME}

Run the JetStream PyTorch Llama2-7B benchmark

To set up JetStream-PyTorch, convert the model checkpoints, and run the inference benchmark, follow the instructions in the GitHub repository .

When the inference benchmark is complete, be sure to clean up the TPU resources.

Clean up

Delete the TPU:

   
    gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
        --project ${PROJECT_ID} \
        --zone ${ZONE} \
        --force \
        --async
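
Because the delete command is issued with --async, it returns before the deletion has finished. A rough way to check progress is to test whether the queued resource can still be described. This is a sketch under stated assumptions: queued_resource_exists is a hypothetical helper name, and it relies on describe exiting with a non-zero status when the resource is no longer found.

```shell
# Hypothetical helper: succeeds (exit 0) while the queued resource can
# still be described. Assumes PROJECT_ID and ZONE are exported.
# A non-zero exit can also come from an authentication or permission
# problem, so treat "not found" as a hint rather than proof of deletion.
queued_resource_exists() {
  gcloud compute tpus queued-resources describe "$1" \
    --project "$PROJECT_ID" --zone "$ZONE" >/dev/null 2>&1
}
```

One possible way to use it after issuing the delete:

    if queued_resource_exists "${QUEUED_RESOURCE_ID}"; then echo "still present"; else echo "no longer found"; fi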