TPU Monitoring Library

Unlock deep insights into your Cloud TPU hardware's performance and behavior with advanced TPU monitoring capabilities, built directly on the foundational software layer, LibTPU. While LibTPU encompasses drivers, networking libraries, the XLA compiler, and the TPU runtime for interacting with TPUs, the focus of this document is the TPU Monitoring Library.

The TPU Monitoring Library provides:

  • Comprehensive observability: Access to a telemetry API and metrics suite that gives you detailed insights into the operational performance and specific behaviors of your TPUs.

  • Diagnostic toolkits: An SDK and command-line interface (CLI) designed to enable debugging and in-depth performance analysis of your TPU resources.

These monitoring features are designed to be a top-level, customer-facing solution, providing you with the essential tools to optimize your TPU workloads effectively.

The TPU Monitoring Library gives you detailed information about how machine learning workloads are performing on TPU hardware. It's designed to help you understand your TPU utilization, identify bottlenecks, and debug performance issues, and it provides more detailed information than interruption metrics, goodput metrics, and other high-level metrics.

Get started with the TPU Monitoring Library

Accessing these powerful insights is straightforward. The TPU monitoring functionality is integrated with the LibTPU SDK, so it's included when you install LibTPU.

Install LibTPU

pip install libtpu
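
To confirm that the installed LibTPU exposes the monitoring functionality, a quick import check works. This is a minimal sketch; an ImportError typically means a LibTPU build that predates the monitoring SDK.

# Quick post-install check for the TPU Monitoring Library.
try:
    from libtpu.sdk import tpumonitoring
    print("TPU monitoring available")
except ImportError as err:
    print(f"TPU monitoring not available: {err}")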

Alternatively, LibTPU updates are coordinated with JAX releases: when you install the latest JAX release (published monthly), it typically pins the latest compatible LibTPU version and its features.

Install JAX

pip install -U "jax[tpu]"
 
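To see which LibTPU version a JAX install pulled in, you can query the package metadata. This is a minimal sketch using the standard importlib.metadata module; the exact distribution names may differ across release channels.

# Print the installed versions of the JAX stack and LibTPU.
from importlib import metadata

for dist in ("jax", "jaxlib", "libtpu"):
    try:
        print(dist, metadata.version(dist))
    except metadata.PackageNotFoundError:
        print(dist, "not installed")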

For PyTorch users, installing PyTorch/XLA provides the latest LibTPU and TPU monitoring functionality.

Install PyTorch/XLA

pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
  -f https://storage.googleapis.com/libtpu-releases/index.html \
  -f https://storage.googleapis.com/libtpu-wheels/index.html

# Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
  -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
  -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

For more information about installing PyTorch/XLA, see Installation in the PyTorch/XLA GitHub repository.
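
As a quick sanity check after installation, you can confirm that PyTorch/XLA sees a TPU device. This is a minimal sketch assuming the torch_xla 2.x API; the exact device string varies by TPU topology.

# Ask the XLA runtime for its default device (for example, xla:0).
import torch_xla.core.xla_model as xm

device = xm.xla_device()
print("XLA device:", device)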

Import the library in Python

To start using the TPU Monitoring Library, you need to import the libtpu module in your Python code.

from libtpu.sdk import tpumonitoring

List all supported functionality

The following code sample shows how to list the library's supported functionality:

from libtpu.sdk import tpumonitoring

tpumonitoring.help()
" libtpu.sdk.monitoring.help():
   List all supported functionality.
 libtpu.sdk.monitoring.list_support_metrics()
   List support metric names in the list of str format.
 libtpu.sdk.monitoring.get_metric(metric_name:str)
   Get metric data with metric name. It represents the snapshot mode.
   The metric data is a object with `description()` and `data()` methods,
   where the `description()` returns a string describe the format of data
   and data unit, `data()` returns the metric data in the list in str format.
"
 

Supported metrics

The following code sample shows how to list all supported metric names:

from libtpu.sdk import tpumonitoring

tpumonitoring.list_supported_metrics()
["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]
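
Because the set of supported metrics can vary across LibTPU versions, a defensive pattern is to check list_supported_metrics() before requesting a metric. This is a minimal sketch using the API shown above.

from libtpu.sdk import tpumonitoring

# Only query a metric if this LibTPU build supports it.
wanted = "duty_cycle_pct"
if wanted in tpumonitoring.list_supported_metrics():
    print(tpumonitoring.get_metric(wanted).data())
else:
    print(f"{wanted} is not supported by this LibTPU version")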

The following list shows all metrics and their corresponding definitions:

Tensor Core Utilization
  Metric name for API: tensorcore_util
  Definition: Measures the percentage of your TensorCore usage, calculated as the percentage of operations that are part of TensorCore operations. Sampled for 10 microseconds every 1 second; you cannot modify the sampling rate. This metric lets you monitor the efficiency of your workloads on TPU devices.
  Example values: ['1.11', '2.22', '3.33', '4.44']
  # Utilization percentage for accelerator IDs 0-3

Duty Cycle Percentage
  Metric name for API: duty_cycle_pct
  Definition: Percentage of time over the past sample period (every 5 seconds; can be tuned by setting the LIBTPU_INIT_ARG flag) during which the accelerator was actively processing, recorded as the cycles used to execute HLO programs over the last sampling period. This metric represents how busy a TPU is and is emitted per chip.
  Example values: ['10.00', '20.00', '30.00', '40.00']
  # Duty cycle percentage for accelerator IDs 0-3

HBM Capacity Total
  Metric name for API: hbm_capacity_total
  Definition: Reports the total HBM capacity in bytes.
  Example values: ['30000000000', '30000000000', '30000000000', '30000000000']
  # Total HBM capacity in bytes attached to accelerator IDs 0-3

HBM Capacity Usage
  Metric name for API: hbm_capacity_usage
  Definition: Reports the HBM capacity used, in bytes, over the past sample period (every 5 seconds; can be tuned by setting the LIBTPU_INIT_ARG flag).
  Example values: ['100', '200', '300', '400']
  # HBM capacity usage in bytes for accelerator IDs 0-3

Buffer Transfer Latency
  Metric name for API: buffer_transfer_latency
  Definition: Network transfer latencies for megascale multi-slice traffic. This visualization lets you understand the overall network performance environment.
  Example values: ["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]
  # Buffer size, then mean, p50, p90, p99, p99.9 of the network transfer latency distribution

High Level Operation Execution Time Distribution Metrics
  Metric name for API: hlo_exec_timing
  Definition: Provides granular performance insights into the execution of HLO compiled binaries, enabling regression detection and model-level debugging.
  Example values: ["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]
  # HLO execution time distribution for CoreType-CoreID, with mean, p50, p90, p95, p999

High Level Optimizer Queue Size
  Metric name for API: hlo_queue_size
  Definition: Tracks the number of compiled HLO programs waiting for or undergoing execution. This metric reveals execution pipeline congestion, helping you identify performance bottlenecks in hardware execution, driver overhead, or resource allocation.
  Example values: ["tensorcore-0: 1", "tensorcore-1: 2"]
  # Queue size for CoreType-CoreID

Collective End-to-End Latency
  Metric name for API: collective_e2e_latency
  Definition: Measures the end-to-end collective latency over DCN in microseconds, from the host initiating the operation to all peers receiving the output. It includes host-side data reduction and sending the output to the TPU. Results are strings detailing buffer size, collective type, and mean, p50, p90, p95, and p99.9 latencies.
  Example values: ["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]
  # Transfer size-collective op, then mean, p50, p90, p95, p999 of the collective end-to-end latency distribution
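
To take a one-off snapshot of every metric in the list above, you can iterate over the supported set. This is a minimal sketch using the API shown earlier; which metrics appear depends on your LibTPU version, TPU generation, and workload.

from libtpu.sdk import tpumonitoring

# Print a snapshot of every metric this LibTPU build supports.
for name in tpumonitoring.list_supported_metrics():
    metric = tpumonitoring.get_metric(name)
    print(f"{name}: {metric.data()}")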

Read metric data - snapshot mode

To enable snapshot mode, specify the metric name when you call the tpumonitoring.get_metric function. Snapshot mode lets you insert ad hoc metric checks into low-performance code to identify whether performance issues stem from software or hardware.

The following code sample shows how to use snapshot mode to read the duty_cycle_pct metric:

from libtpu.sdk import tpumonitoring

metric = tpumonitoring.get_metric("duty_cycle_pct")

metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."

metric.data()
["0.00", "0.00", "0.00", "0.00"]  # accelerator_0-3
 
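Because data() returns percentages as strings, a small amount of parsing is needed before you can act on them. The following sketch flags underutilized accelerators, assuming the string-list format shown above; the 50% threshold is hypothetical and should be tuned for your workload.

from libtpu.sdk import tpumonitoring

# data() returns one string per accelerator, for example ["0.00", ...].
values = [float(v) for v in tpumonitoring.get_metric("duty_cycle_pct").data()]
for accel_id, pct in enumerate(values):
    if pct < 50.0:  # Hypothetical threshold
        print(f"accelerator_{accel_id} duty cycle is low: {pct:.2f}%")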

Access metrics using the CLI

The following steps show how to interact with LibTPU metrics using the CLI:

  1. Install tpu-info:

     pip install tpu-info

     # Access help information of tpu-info
     tpu-info --help  # or use -h

  2. Run the default version of tpu-info:

     tpu-info
    

    The output is similar to the following:

   
TPU Chips
┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━┓
┃ Chip        ┃ Type        ┃ Devices ┃ PID    ┃
┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━┩
│ /dev/accel0 │ TPU v4 chip │ 1       │ 130007 │
│ /dev/accel1 │ TPU v4 chip │ 1       │ 130007 │
│ /dev/accel2 │ TPU v4 chip │ 1       │ 130007 │
│ /dev/accel3 │ TPU v4 chip │ 1       │ 130007 │
└─────────────┴─────────────┴─────────┴────────┘
TPU Runtime Utilization
┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Device ┃ Memory usage         ┃ Duty cycle ┃
┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ 0      │ 0.00 GiB / 31.75 GiB │ 0.00%      │
│ 1      │ 0.00 GiB / 31.75 GiB │ 0.00%      │
│ 2      │ 0.00 GiB / 31.75 GiB │ 0.00%      │
│ 3      │ 0.00 GiB / 31.75 GiB │ 0.00%      │
└────────┴──────────────────────┴────────────┘
TensorCore Utilization
┏━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Chip ID ┃ TensorCore Utilization ┃
┡━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
│ 0       │ 0.00%                  │
│ 1       │ 0.00%                  │
│ 3       │ 0.00%                  │
│ 2       │ 0.00%                  │
└─────────┴────────────────────────┘
Buffer Transfer Latency
┏━━━━━━━━━━━━━┳━━━━━┳━━━━━┳━━━━━┳━━━━━━┓
┃ Buffer Size ┃ P50 ┃ P90 ┃ P95 ┃ P999 ┃
┡━━━━━━━━━━━━━╇━━━━━╇━━━━━╇━━━━━╇━━━━━━┩
│ 8MB+        │ 0us │ 0us │ 0us │ 0us  │
└─────────────┴─────┴─────┴─────┴──────┘

Use metrics to check TPU utilization

The following examples show how to use metrics from the TPU Monitoring Library to track TPU utilization.

Monitor TPU duty cycle during JAX training

Scenario: You are running a JAX training script and want to monitor the TPU's duty_cycle_pct metric throughout the training process to confirm that your TPUs are being effectively utilized. You can log this metric periodically during training to track TPU utilization.

The following code sample shows how to monitor TPU Duty Cycle during JAX training:

import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time

# --- Your JAX model and training setup would go here ---

# --- Example placeholder model and data (replace with your actual setup) ---
def simple_model(x):
    return jnp.sum(x)

def loss_fn(params, x, y):
    preds = simple_model(x)
    return jnp.mean((preds - y)**2)

def train_step(params, x, y, optimizer):
    grads = jax.grad(loss_fn)(params, x, y)
    return optimizer.update(grads, params)

key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0])  # Example params
optimizer = ...  # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))

num_epochs = 10
log_interval_steps = 2  # Log duty cycle every 2 steps

for epoch in range(num_epochs):
    for step in range(5):  # Example steps per epoch
        params = train_step(params, data_x, data_y, optimizer)

        if (step + 1) % log_interval_steps == 0:
            # --- Integrate TPU Monitoring Library here to get duty_cycle ---
            duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
            duty_cycle_data = duty_cycle_metric.data()
            print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
            print(f"  Description: {duty_cycle_metric.description()}")
            print(f"  Data: {duty_cycle_data}")
            # --- End TPU Monitoring Library Integration ---

        # --- Rest of your training loop logic ---
        time.sleep(0.1)  # Simulate some computation

print("Training complete.")
 
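If you prefer not to interleave metric reads with your training steps, an alternative is to poll from a background thread. The following is a minimal sketch assuming the snapshot-mode API above; sample_duty_cycle and the 1-second interval are illustrative choices, not part of the library.

import threading
import time

from libtpu.sdk import tpumonitoring

def sample_duty_cycle(stop_event, interval_s=1.0):
    # Poll once per second, matching the 1 Hz export frequency
    # described later in this document.
    while not stop_event.is_set():
        data = tpumonitoring.get_metric("duty_cycle_pct").data()
        print(f"duty_cycle_pct: {data}")
        time.sleep(interval_s)

stop = threading.Event()
sampler = threading.Thread(target=sample_duty_cycle, args=(stop,), daemon=True)
sampler.start()
# ... run your training loop here ...
stop.set()
sampler.join()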

Check HBM utilization before running JAX inference

Scenario: Before running inference with your JAX model, check the current HBM (High Bandwidth Memory) utilization on the TPU to confirm that you have enough memory available and to get a baseline measurement before inference starts.

The following code sample shows how to check HBM utilization before running JAX inference:

import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring

# --- Your JAX model and inference setup would go here ---

# --- Example placeholder model (replace with your actual model loading/setup) ---
def simple_model(x):
    return jnp.sum(x)

key = jax.random.PRNGKey(0)
params = ...  # Load your trained parameters

# Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data()
print("HBM Utilization Before Inference:")
print(f"  Description: {hbm_util_metric.description()}")
print(f"  Data: {hbm_util_data}")
# End TPU Monitoring Library Integration

# --- Your inference logic ---
input_data = jnp.ones((1, 10))  # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)
print("Inference complete.")
 
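A related pattern is to read hbm_capacity_usage before and after materializing your parameters to estimate their on-device footprint. This is a minimal sketch; hbm_usage_bytes is a hypothetical helper, and because usage is sampled over a roughly 5-second period, back-to-back reads may lag the actual allocation.

import jax.numpy as jnp
from libtpu.sdk import tpumonitoring

def hbm_usage_bytes():
    # Hypothetical helper: one value per accelerator, reported as byte strings.
    return [int(v) for v in tpumonitoring.get_metric("hbm_capacity_usage").data()]

before = hbm_usage_bytes()
params = jnp.ones((1024, 1024))  # Stand-in for loading real parameters
params.block_until_ready()       # Ensure the buffer is materialized on device
after = hbm_usage_bytes()
print("Per-accelerator HBM delta (bytes):",
      [b - a for a, b in zip(before, after)])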

Export frequency of TPU metrics

The refresh frequency of TPU metrics is constrained to a minimum of one second. Host metric data is exported at a fixed frequency of 1 Hz, and the latency introduced by this export process is negligible. Runtime metrics from LibTPU are not subject to the same frequency constraint; for consistency, however, they are also sampled at 1 Hz (one sample per second).
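
In practice, this means polling faster than once per second only re-reads the same sample. The following is a minimal sketch of a rate-limited read loop; the interval constant and iteration count are illustrative.

import time
from libtpu.sdk import tpumonitoring

POLL_INTERVAL_S = 1.0  # Matches the 1 Hz export frequency; faster polling
                       # returns the same sample.
for _ in range(5):
    print(tpumonitoring.get_metric("duty_cycle_pct").data())
    time.sleep(POLL_INTERVAL_S)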
