JetStream PyTorch inference on v6e TPU VMs
This tutorial shows how to use JetStream to serve PyTorch models on TPU v6e. JetStream is a throughput- and memory-optimized engine for large language model (LLM) inference on XLA devices (TPUs). In this tutorial, you run the inference benchmark for the Llama2-7B model.
Before you begin
Prepare to provision a TPU v6e with 4 chips:
- Follow the Set up the Cloud TPU environment guide to set up a Google Cloud project, configure the Google Cloud CLI, enable the Cloud TPU API, and ensure you have access to use Cloud TPUs.
- Authenticate with Google Cloud and configure the default project and zone for the Google Cloud CLI:
gcloud auth login
gcloud config set project PROJECT_ID
gcloud config set compute/zone ZONE
Secure capacity
When you are ready to secure TPU capacity, see Cloud TPU Quotas for information about Cloud TPU quotas. If you have additional questions about securing capacity, contact your Cloud TPU sales or account team.
Provision the Cloud TPU environment
You can provision TPU VMs with GKE, with GKE and XPK, or as queued resources.
Prerequisites
- Verify that your project has enough TPUS_PER_TPU_FAMILY quota, which specifies the maximum number of chips you can access within your Google Cloud project.
- Verify that your project has enough TPU quota for:
  - TPU VM quota
  - IP address quota
  - Hyperdisk Balanced quota
- User project permissions
  - If you are using GKE with XPK, see Cloud Console Permissions on the user or service account for the permissions needed to run XPK.
Create environment variables
In Cloud Shell, create the following environment variables:
export PROJECT_ID=your-project-id
export TPU_NAME=your-tpu-name
export ZONE=us-central2-b
export ACCELERATOR_TYPE=v6e-4
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=your-service-account
export QUEUED_RESOURCE_ID=your-queued-resource-id
Environment variable descriptions
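- PROJECT_ID: Your Google Cloud project ID.
- TPU_NAME: The name of the TPU.
- ZONE: The zone in which to create the TPU VM.
- ACCELERATOR_TYPE: The accelerator type, which specifies the version and size of the Cloud TPU. v6e-4 is a TPU v6e with 4 chips.
- RUNTIME_VERSION: The Cloud TPU software version.
- SERVICE_ACCOUNT: The email address of your service account. For example: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
- QUEUED_RESOURCE_ID: The user-assigned text ID of the queued resource request.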
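An empty or placeholder value usually surfaces only later, as a confusing gcloud error. Checking the variables before you provision can therefore save time. The following is a minimal sketch and is not part of the official setup. The function name `check_tpu_env` is hypothetical. The sketch assumes the variables were exported as shown above: it reads them with `printenv`, so a variable that is set but not exported counts as missing.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of gcloud or JetStream): verify that every
# variable used in this tutorial is exported and no longer holds a
# "your-..." placeholder value. Returns 0 if all are set, 1 otherwise.
check_tpu_env() {
  local name value missing=0
  for name in PROJECT_ID TPU_NAME ZONE ACCELERATOR_TYPE RUNTIME_VERSION \
              SERVICE_ACCOUNT QUEUED_RESOURCE_ID; do
    # printenv prints nothing if the variable is not exported.
    value="$(printenv "$name")"
    case "$value" in
      "")
        echo "error: ${name} is not set" >&2
        missing=1
        ;;
      your-*)
        echo "error: ${name} still has the placeholder value '${value}'" >&2
        missing=1
        ;;
    esac
  done
  return "$missing"
}
```

For example, `check_tpu_env && echo "environment looks complete"` prints the confirmation only when every variable has a real value. Otherwise, it lists each problem on stderr.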
Provision a TPU v6e
gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
  --node-id ${TPU_NAME} \
  --project ${PROJECT_ID} \
  --zone ${ZONE} \
  --accelerator-type ${ACCELERATOR_TYPE} \
  --runtime-version ${RUNTIME_VERSION} \
  --service-account ${SERVICE_ACCOUNT}
Use the list or describe commands to query the status of your queued resource.
gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
  --project ${PROJECT_ID} \
  --zone ${ZONE}
For a complete list of queued resource request statuses, see the Queued Resources documentation.
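Queued resource provisioning is asynchronous, so you may want to wait until the request is ready before you connect with SSH. The sketch below makes several assumptions. It assumes that `describe` accepts `--format="value(state.state)"` to print only the state name. It also assumes that ACTIVE means ready and that FAILED and SUSPENDED are terminal states. Check these state names against the Queued Resources documentation before relying on them. The helper names `qr_state` and `wait_for_active` are hypothetical.

```shell
#!/usr/bin/env bash
# Sketch: poll a queued resource until it reaches ACTIVE.
# Assumes PROJECT_ID, ZONE, and QUEUED_RESOURCE_ID are exported, and that the
# state names used below match the Queued Resources documentation.

# Print only the state name of the queued resource (for example, ACCEPTED,
# PROVISIONING, or ACTIVE). The --format projection is an assumption.
qr_state() {
  gcloud alpha compute tpus queued-resources describe "${QUEUED_RESOURCE_ID}" \
    --project "${PROJECT_ID}" \
    --zone "${ZONE}" \
    --format="value(state.state)"
}

# Usage: wait_for_active [interval_seconds] [max_attempts]
# Returns 0 once the state is ACTIVE, and 1 on a terminal state or timeout.
wait_for_active() {
  local interval="${1:-30}" max_tries="${2:-60}" i=1 state
  while [ "$i" -le "$max_tries" ]; do
    state="$(qr_state)"
    echo "attempt ${i}: ${state:-UNKNOWN}"
    case "$state" in
      ACTIVE) return 0 ;;
      FAILED | SUSPENDED) return 1 ;;
    esac
    sleep "$interval"
    i=$((i + 1))
  done
  echo "timed out waiting for ${QUEUED_RESOURCE_ID}" >&2
  return 1
}
```

For example, `wait_for_active 60 30` checks once a minute for up to 30 minutes. Keeping the gcloud call in its own function makes it easy to swap in a different query, such as the list command, without changing the polling logic.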
Connect to the TPU using SSH
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
  --project ${PROJECT_ID} \
  --zone ${ZONE}
Run the JetStream PyTorch Llama2-7B benchmark
To set up JetStream-PyTorch, convert the model checkpoints, and run the inference benchmark, follow the instructions in the GitHub repository.
When the inference benchmark is complete, be sure to clean up the TPU resources.
Clean up
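Delete the queued resource and its TPU. The --force flag deletes the TPU VM along with the queued resource, and --async returns without waiting for the deletion to finish:
gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
  --project ${PROJECT_ID} \
  --zone ${ZONE} \
  --force \
  --async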
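With `--async`, the delete command returns before the deletion finishes. You may therefore want to confirm that the queued resource is gone. This sketch treats a failing `describe` call as "deleted", which is an assumption: an authentication or permission error would produce the same failure, so inspect the output if the result looks wrong. The helper names `qr_exists` and `wait_for_deletion` are hypothetical.

```shell
#!/usr/bin/env bash
# Sketch: wait until a queued resource deleted with --async no longer exists.
# Assumes PROJECT_ID, ZONE, and QUEUED_RESOURCE_ID are exported.

# Succeeds while `describe` can still find the queued resource. Any failure,
# including auth errors, is treated as "not found" (an assumption).
qr_exists() {
  gcloud compute tpus queued-resources describe "${QUEUED_RESOURCE_ID}" \
    --project "${PROJECT_ID}" \
    --zone "${ZONE}" >/dev/null 2>&1
}

# Usage: wait_for_deletion [interval_seconds] [max_attempts]
# Returns 0 once the resource is gone, and 1 on timeout.
wait_for_deletion() {
  local interval="${1:-30}" max_tries="${2:-40}" i=1
  while [ "$i" -le "$max_tries" ]; do
    if ! qr_exists; then
      echo "queued resource ${QUEUED_RESOURCE_ID} no longer found"
      return 0
    fi
    sleep "$interval"
    i=$((i + 1))
  done
  echo "timed out waiting for deletion of ${QUEUED_RESOURCE_ID}" >&2
  return 1
}
```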