JetStream MaxText inference on v6e TPU VMs

This tutorial shows how to use JetStream to serve MaxText models on TPU v6e. JetStream is a throughput- and memory-optimized engine for large language model (LLM) inference on XLA devices (TPUs). In this tutorial, you run the inference benchmark for the Llama2-7B model.

Before you begin

Prepare to provision a TPU v6e with 4 chips:

  1. Follow the Set up the Cloud TPU environment guide to set up a Google Cloud project, configure the Google Cloud CLI, enable the Cloud TPU API, and ensure you have access to use Cloud TPUs.

  2. Authenticate with Google Cloud and configure the default project and zone for the Google Cloud CLI.

     gcloud auth login
     gcloud config set project PROJECT_ID
     gcloud config set compute/zone ZONE
    

Secure capacity

When you are ready to secure TPU capacity, see Cloud TPU Quotas for more information about Cloud TPU quotas. If you have additional questions about securing capacity, contact your Cloud TPU sales or account team.

Provision the Cloud TPU environment

You can provision TPU VMs with GKE, with GKE and XPK, or as queued resources. This tutorial provisions a TPU as a queued resource.

Prerequisites

  • Verify that your project has enough TPUS_PER_TPU_FAMILY quota, which specifies the maximum number of chips you can access within your Google Cloud project.
  • Verify that your project has enough TPU quota for the following (a command-line spot check is sketched after this list):
    • TPU VM quota
    • IP address quota
    • Hyperdisk Balanced quota
  • User project permissions
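
If you want a quick command-line view of some of these quotas, the following sketch may help. It is only a rough spot check: it lists Compute Engine regional quotas (such as in-use IP addresses) for the us-east5 region used later in this tutorial and assumes the your-project-id placeholder; TPU chip quota itself is reviewed on the Quotas page in the Google Cloud console.

# Rough spot check: Compute Engine regional quotas, filtered to address metrics.
# TPU chip quota (TPUS_PER_TPU_FAMILY) is managed on the console Quotas page.
gcloud compute regions describe us-east5 --project your-project-id | grep -i -B 1 -A 1 "addresses"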

Create environment variables

In a Cloud Shell, create the following environment variables:

export PROJECT_ID=your-project-id
export TPU_NAME=your-tpu-name
export ZONE=us-east5-b
export ACCELERATOR_TYPE=v6e-4
export RUNTIME_VERSION=v2-alpha-tpuv6e
export SERVICE_ACCOUNT=your-service-account
export QUEUED_RESOURCE_ID=your-queued-resource-id

Environment variable descriptions

  • PROJECT_ID: Your Google Cloud project ID. Use an existing project or create a new one.
  • TPU_NAME: The name of the TPU.
  • ZONE: The zone in which to create the TPU VM. For more information about supported zones, see TPU regions and zones.
  • ACCELERATOR_TYPE: The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
  • RUNTIME_VERSION: The Cloud TPU software version.
  • SERVICE_ACCOUNT: The email address for your service account. You can find it by going to the Service Accounts page in the Google Cloud console. For example: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
  • QUEUED_RESOURCE_ID: The user-assigned text ID of the queued resource request.

Provision a TPU v6e

Use the following command to provision a TPU v6e:

gcloud alpha compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
    --node-id=${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --accelerator-type=${ACCELERATOR_TYPE} \
    --runtime-version=${RUNTIME_VERSION} \
    --service-account=${SERVICE_ACCOUNT}

Use the list or describe commands to query the status of your queued resource.

gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
    --project ${PROJECT_ID} \
    --zone ${ZONE}
 

For more information about queued resource request statuses, see Manage queued resources.
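
If you prefer to wait for the TPU from the shell instead of re-running the describe command by hand, a small polling loop like the following works. This is a minimal sketch: it simply greps the describe output for the ACTIVE state rather than assuming a particular --format field path.

# Minimal sketch: poll the queued resource until it reports ACTIVE.
while ! gcloud alpha compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
    --project ${PROJECT_ID} --zone ${ZONE} | grep -q "ACTIVE"; do
  # Queued resources move through states such as ACCEPTED and PROVISIONING before ACTIVE.
  echo "Queued resource not ACTIVE yet; checking again in 60 seconds..."
  sleep 60
done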

Connect to the TPU using SSH

  
gcloud compute tpus tpu-vm ssh ${TPU_NAME}

Once you have connected to the TPU, you can run the inference benchmark.

Set up your TPU VM environment

  1. Create a directory for running the inference benchmark:

     export MAIN_DIR=your-main-directory
     mkdir -p ${MAIN_DIR}
    
  2. Set up a Python virtual environment:

     cd ${MAIN_DIR}
     sudo apt update
     sudo apt install python3.10 python3.10-venv
     python3.10 -m venv venv
     source venv/bin/activate
  3. Install Git Large File Storage (LFS) (for OpenOrca data):

     sudo apt-get install git-lfs
     git lfs install
  4. Clone and install JetStream:

     cd $MAIN_DIR
     git clone https://github.com/google/JetStream.git
     cd JetStream
     git checkout main
     pip install -e .
     cd benchmarks
     pip install -r requirements.in
  5. Set up MaxText:

     cd $MAIN_DIR
     git clone https://github.com/google/maxtext.git
     cd maxtext
     git checkout main
     bash setup.sh
     pip install torch --index-url https://download.pytorch.org/whl/cpu
  6. Request Access to Llama Models to get a download key from Meta for the Llama 2 model.

  7. Clone the Llama repository:

     cd $MAIN_DIR
     git clone https://github.com/meta-llama/llama
     cd llama
  8. Run bash download.sh. When prompted, provide your download key. This script creates a llama-2-7b directory inside your llama directory.

     bash download.sh
  9. Create storage buckets and copy the Llama 2 checkpoint into the checkpoint bucket (an optional verification command follows this list):

     export CHKPT_BUCKET=gs://your-checkpoint-bucket
     export BASE_OUTPUT_DIRECTORY=gs://your-output-dir
     export CONVERTED_CHECKPOINT_PATH=gs://bucket-to-store-converted-checkpoints
     export MAXTEXT_BUCKET_UNSCANNED=gs://bucket-to-store-unscanned-data
     gcloud storage buckets create ${CHKPT_BUCKET}
     gcloud storage buckets create ${BASE_OUTPUT_DIRECTORY}
     gcloud storage buckets create ${CONVERTED_CHECKPOINT_PATH}
     gcloud storage buckets create ${MAXTEXT_BUCKET_UNSCANNED}
     gcloud storage cp --recursive llama-2-7b/* ${CHKPT_BUCKET}
    
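Optionally, confirm that the Llama 2 weights landed in the checkpoint bucket before converting them. This is a minimal check and assumes the CHKPT_BUCKET variable from the previous step is still set.

# Optional: list the copied Llama 2 checkpoint files.
gcloud storage ls ${CHKPT_BUCKET}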

Perform checkpoint conversion

  1. Perform conversion to scanned checkpoints:

     cd $MAIN_DIR/maxtext
     python3 -m MaxText.llama_or_mistral_ckpt \
         --base-model-path $MAIN_DIR/llama/llama-2-7b \
         --model-size llama2-7b \
         --maxtext-model-path ${CONVERTED_CHECKPOINT_PATH}
    
  2. Convert to unscanned checkpoints (an optional verification command follows this list):

     export CONVERTED_CHECKPOINT=${CONVERTED_CHECKPOINT_PATH}/0/items
     export DIRECT_PARAMETER_CHECKPOINT_RUN=direct_generate_param_only_checkpoint
     python3 -m MaxText.generate_param_only_checkpoint \
         MaxText/configs/base.yml \
         base_output_directory=${MAXTEXT_BUCKET_UNSCANNED} \
         load_parameters_path=${CONVERTED_CHECKPOINT} \
         run_name=${DIRECT_PARAMETER_CHECKPOINT_RUN} \
         model_name='llama2-7b' \
         force_unroll=true
    
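Before moving on to inference, you can optionally confirm that the unscanned checkpoint was written. This is a minimal check against the same path layout that the next section uses for UNSCANNED_CKPT_PATH.

# Optional: the decode and server steps below load the checkpoint from this path.
gcloud storage ls ${MAXTEXT_BUCKET_UNSCANNED}/${DIRECT_PARAMETER_CHECKPOINT_RUN}/checkpoints/0/items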

Perform inference

  1. Run a validation test:

     export UNSCANNED_CKPT_PATH=${MAXTEXT_BUCKET_UNSCANNED}/${DIRECT_PARAMETER_CHECKPOINT_RUN}/checkpoints/0/items
     python3 -m MaxText.decode \
         MaxText/configs/base.yml \
         load_parameters_path=${UNSCANNED_CKPT_PATH} \
         run_name=runner_decode_unscanned_${idx} \
         base_output_directory=${BASE_OUTPUT_DIRECTORY} \
         per_device_batch_size=1 \
         model_name='llama2-7b' \
         ici_autoregressive_parallelism=4 \
         max_prefill_predict_length=4 \
         max_target_length=16 \
         prompt="I love to" \
         attention=dot_product \
         scan_layers=false
    
  2. Run the server in your current terminal:

     export TOKENIZER_PATH=assets/tokenizer.llama2
     export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
     export MAX_PREFILL_PREDICT_LENGTH=1024
     export MAX_TARGET_LENGTH=2048
     export MODEL_NAME=llama2-7b
     export ICI_FSDP_PARALLELISM=1
     export ICI_AUTOREGRESSIVE_PARALLELISM=1
     export ICI_TENSOR_PARALLELISM=-1
     export SCAN_LAYERS=false
     export WEIGHT_DTYPE=bfloat16
     export PER_DEVICE_BATCH_SIZE=11
     cd $MAIN_DIR/maxtext
     python3 -m MaxText.maxengine_server \
         MaxText/configs/base.yml \
         tokenizer_path=${TOKENIZER_PATH} \
         load_parameters_path=${LOAD_PARAMETERS_PATH} \
         max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
         max_target_length=${MAX_TARGET_LENGTH} \
         model_name=${MODEL_NAME} \
         ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
         ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
         ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
         scan_layers=${SCAN_LAYERS} \
         weight_dtype=${WEIGHT_DTYPE} \
         per_device_batch_size=${PER_DEVICE_BATCH_SIZE}
    
  3. Open a new terminal window, connect to the TPU, change to your main directory, and activate the same virtual environment you used in the first terminal window:

     export MAIN_DIR=your-main-directory
     cd ${MAIN_DIR}
     source venv/bin/activate
    
  4. Run the JetStream benchmark. Because --save-request-outputs is set, the per-request outputs are written to outputs.json; a quick way to inspect that file is shown after this list.

     cd $MAIN_DIR
     python JetStream/benchmarks/benchmark_serving.py \
         --tokenizer $MAIN_DIR/maxtext/assets/tokenizer.llama2 \
         --warmup-mode sampled \
         --save-result \
         --save-request-outputs \
         --request-outputs-file-path outputs.json \
         --num-prompts 1000 \
         --max-output-length 1024 \
         --dataset openorca \
         --dataset-path $MAIN_DIR/JetStream/benchmarks/open_orca_gpt4_tokenized_llama.calibration_1000.pkl
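
Once the run finishes, you can take a quick look at outputs.json from the directory you ran the benchmark in. The exact record layout is defined by benchmark_serving.py, so this sketch only prints the top-level shape and the field names of the first record, assuming the file is a single JSON document.

# Minimal sketch: inspect the saved request outputs without assuming their schema.
python3 - <<'EOF'
import json

with open("outputs.json") as f:
    data = json.load(f)

print("top-level type:", type(data).__name__)
if isinstance(data, list):
    print("records:", len(data))
    if data and isinstance(data[0], dict):
        print("fields of first record:", sorted(data[0].keys()))
EOF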

Results

The following output was generated when running the benchmark using v6e-8. Results will vary based on hardware, software, model, and networking.

 Mean output size: 929.5959798994975
Median output size: 1026.0
P99 output size: 1026.0
Successful requests: 995
Benchmark duration: 195.533269 s
Total input tokens: 217011
Total generated tokens: 924948
Request throughput: 5.09 requests/s
Input token throughput: 1109.84 tokens/s
Output token throughput: 4730.39 tokens/s
Overall token throughput: 5840.23 tokens/s
Mean ttft: 538.49 ms
Median ttft: 95.66 ms
P99 ttft: 13937.86 ms
Mean ttst: 1218.72 ms
Median ttst: 152.57 ms
P99 ttst: 14241.30 ms
Mean TPOT: 91.83 ms
Median TPOT: 16.63 ms
P99 TPOT: 363.37 ms 
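
The throughput figures are derived from the token counts and the benchmark duration reported in the same run, which makes a handy consistency check when you compare your own results. A minimal sketch of that arithmetic, using the numbers above:

# Reproduce the reported throughputs from the raw counts above.
python3 - <<'EOF'
duration_s = 195.533269
requests = 995
input_tokens = 217011
generated_tokens = 924948

print(f"request throughput: {requests / duration_s:.2f} requests/s")            # ~5.09
print(f"input token throughput: {input_tokens / duration_s:.2f} tokens/s")      # ~1109.84
print(f"output token throughput: {generated_tokens / duration_s:.2f} tokens/s") # ~4730.39
print(f"overall token throughput: "
      f"{(input_tokens + generated_tokens) / duration_s:.2f} tokens/s")          # ~5840.23
EOF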

Clean up

  1. Disconnect from the TPU:

     (vm)$ exit
    
  2. Delete the TPU:

     gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
         --project ${PROJECT_ID} \
         --zone ${ZONE} \
         --force \
         --async
  3. Delete the buckets and their contents (an optional confirmation check follows this list):

     export CHKPT_BUCKET=gs://your-checkpoint-bucket
     export BASE_OUTPUT_DIRECTORY=gs://your-output-dir
     export CONVERTED_CHECKPOINT_PATH=gs://bucket-to-store-converted-checkpoints
     export MAXTEXT_BUCKET_UNSCANNED=gs://bucket-to-store-unscanned-data
     gcloud storage rm -r ${CHKPT_BUCKET}
     gcloud storage rm -r ${BASE_OUTPUT_DIRECTORY}
     gcloud storage rm -r ${CONVERTED_CHECKPOINT_PATH}
     gcloud storage rm -r ${MAXTEXT_BUCKET_UNSCANNED}
    
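To confirm that nothing is left running or stored, you can list any remaining queued resources and buckets in the project. This is an optional, minimal check.

# Optional: verify that the queued resource and the tutorial buckets are gone.
gcloud alpha compute tpus queued-resources list --project ${PROJECT_ID} --zone ${ZONE}
gcloud storage ls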