PyTorch to LiteRT quickstart

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
 

This Colab demonstrates how to convert a PyTorch model to the LiteRT format using the LiteRT Torch package. In this example, we will convert ResNet18, a popular image recognition model, into a LiteRT model that can later be used in a LiteRT or MediaPipe app.

pip install litert-torch-nightly torchvision

Import packages

The PyTorch converter is available in the LiteRT Torch GitHub repository. We also need the PyTorch library, as well as NumPy and torchvision, which includes the ResNet18 model.

import litert_torch
import numpy
import torch
import torchvision
 

Instantiate the PyTorch model

Let's instantiate resnet18 as a sample model from PyTorch's torchvision package. We'll also provide it with a sample input and execute the model through PyTorch.

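# Load ResNet18 with ImageNet-pretrained weights and switch it to inference mode.
# Passing weights by keyword avoids torchvision's positional-argument deprecation warning.
resnet18 = torchvision.models.resnet18(
    weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
).eval()
# One random 3x224x224 image in NCHW layout, wrapped in a tuple of positional args.
sample_inputs = (torch.randn(1, 3, 224, 224),)
torch_output = resnet18(*sample_inputs)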
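
The sample inputs are kept as a tuple of positional arguments and unpacked with * when the model is called, which is why the single tensor is written as (tensor,) with a trailing comma. The minimal sketch below illustrates that convention in plain Python; fake_model is a hypothetical stand-in for a model's forward call and is not part of PyTorch or LiteRT Torch.

```python
# Hypothetical stand-in for a model's forward pass: it accepts exactly one
# positional argument and reports what it received.
def fake_model(x):
    return f"received {type(x).__name__} of length {len(x)}"

image = [0.0] * 4  # stands in for a single input tensor

# Correct: a one-element tuple; * unpacks it into one positional argument.
sample_inputs = (image,)
print(fake_model(*sample_inputs))  # received list of length 4

# Pitfall: without the trailing comma, (image) is just image itself, so
# *not_a_tuple spreads its four elements into four separate arguments.
not_a_tuple = (image)
try:
    fake_model(*not_a_tuple)
except TypeError as err:
    print("TypeError:", err)
```

The same tuple is later passed unchanged to the converter and to the converted model, so keeping it as a tuple keeps all three calls consistent.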
 

Convert the model to LiteRT

Use the convert function from the litert_torch package, which converts PyTorch models to the LiteRT format. This turns the PyTorch model into an on-device model, ready to use with LiteRT and MediaPipe. The conversion process requires sample inputs for the model, which are used for tracing and shape inference.

edge_model = litert_torch.convert(resnet18.eval(), sample_inputs)
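
Shape inference means propagating the sample input's shape through every layer. As a worked illustration, the sketch below applies the standard convolution output-size formula, out = floor((in + 2*padding - kernel) / stride) + 1, to the stages of torchvision's ResNet18 for a 1x3x224x224 input. The stage table is hand-written from the ResNet18 architecture as an assumption; it is not produced or used by litert_torch.convert.

```python
# Output size of a convolution or pooling window along one spatial dimension.
def conv_out(size, kernel, stride, padding):
    return (size + 2 * padding - kernel) // stride + 1

# Hand-written summary of ResNet18's spatial downsampling (assumed, not read
# from torchvision). Each layerN entry uses the first 3x3 conv of that stage.
# Columns: (name, kernel, stride, padding, out_channels).
STAGES = [
    ("conv1", 7, 2, 3, 64),
    ("maxpool", 3, 2, 1, 64),
    ("layer1", 3, 1, 1, 64),
    ("layer2", 3, 2, 1, 128),
    ("layer3", 3, 2, 1, 256),
    ("layer4", 3, 2, 1, 512),
]

def infer_shapes(batch=1, channels=3, height=224, width=224):
    shapes = [("input", (batch, channels, height, width))]
    for name, kernel, stride, padding, out_ch in STAGES:
        height = conv_out(height, kernel, stride, padding)
        width = conv_out(width, kernel, stride, padding)
        channels = out_ch
        shapes.append((name, (batch, channels, height, width)))
    # Adaptive average pooling to 1x1, then flatten and a 512 -> 1000 linear layer.
    shapes.append(("avgpool", (batch, channels, 1, 1)))
    shapes.append(("fc", (batch, 1000)))
    return shapes

for name, shape in infer_shapes():
    print(f"{name:8s} {shape}")
```

The final (1, 1000) shape corresponds to the 1000 ImageNet classes, which is the shape expected for both torch_output and edge_output.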
 

Inference

Run inference with the TFLite runtime by calling edge_model directly with the sample inputs. This API abstracts away many of the details of TFLite inference in Python.

edge_output = edge_model(*sample_inputs)
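
The outputs are raw ImageNet logits, one score per class. Assuming edge_output is a (1, 1000) array, a common next step is to turn one row of logits into probabilities and pick the top-k classes. The sketch below does this in plain Python on a short hypothetical logit list; softmax and top_k are illustrative helpers, not part of LiteRT. With the random sample input used here, the predicted classes carry no real meaning.

```python
import math

def softmax(logits):
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k(logits, k=5):
    """Return (class_index, probability) pairs for the k highest logits."""
    probs = softmax(logits)
    # sorted() is stable, so tied classes keep their original index order.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(i, probs[i]) for i in order[:k]]

# Hypothetical logits for a 6-class toy problem; with the real model you could
# pass something like list(edge_output[0]) for the single batch element.
logits = [1.0, 3.0, 0.5, 2.0, -1.0, 3.0]
for index, prob in top_k(logits, k=3):
    print(index, round(prob, 3))
```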
 

Validate the model

Make sure that the output generated by the converted model matches the output generated by the original PyTorch model.

if numpy.allclose(
    torch_output.detach().numpy(),
    edge_output,
    atol=1e-5,
    rtol=1e-5,
):
    print("Inference result with Pytorch and TfLite was within tolerance")
else:
    print("Something wrong with Pytorch --> TfLite")
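
numpy.allclose returns True only when every element satisfies |a - b| <= atol + rtol * |b|, where b is the second argument (here the LiteRT output). The sketch below restates that rule in plain Python and also reports the largest absolute difference, which is often more informative than a single pass/fail flag when the check fails. all_close and max_abs_diff are illustrative names, not NumPy APIs, and the numbers are made up.

```python
def all_close(a, b, atol=1e-5, rtol=1e-5):
    """Element-wise |a - b| <= atol + rtol * |b|, the rule numpy.allclose uses.

    Simplified: unlike NumPy, this does not treat matching infinities as equal
    and does not broadcast; a and b must be flat sequences of equal length.
    """
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))

def max_abs_diff(a, b):
    return max(abs(x - y) for x, y in zip(a, b))

reference = [1.0, -2.5, 10.0]                # stands in for the PyTorch output
converted = [1.000004, -2.500001, 10.00009]  # stands in for the LiteRT output

print(all_close(reference, converted))  # True
print(max_abs_diff(reference, converted))

# Near zero, the rtol term vanishes and only atol is left to absorb the error.
print(all_close([0.0], [2e-5]))  # False
```

Because the tolerance grows with |b|, larger logits are allowed proportionally larger deviations, while values close to zero are judged almost entirely by atol.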
 

Serialization

The converted model includes an export method, which you can use to serialize the model. This exports the model as a TFLite FlatBuffer (.tflite) file.

from google.colab import files

edge_model.export('resnet.tflite')
# Download the tflite flatbuffer which can be used with the existing TfLite APIs.
# files.download('resnet.tflite')
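
TFLite model files are FlatBuffers whose schema declares the file identifier "TFL3", which FlatBuffers stores in bytes 4 to 7 of the buffer, right after the 4-byte root offset. A quick sanity check on an exported file can therefore read those bytes. The sketch below relies only on that layout; looks_like_tflite is a hypothetical helper, and the demo writes a tiny fake header rather than a real model.

```python
import os
import tempfile

TFLITE_FILE_IDENTIFIER = b"TFL3"  # file_identifier declared in the TFLite schema

def looks_like_tflite(path):
    """Return True if bytes 4..7 of the file hold the TFLite identifier."""
    with open(path, "rb") as f:
        header = f.read(8)
    return len(header) == 8 and header[4:8] == TFLITE_FILE_IDENTIFIER

# In the notebook you would call: looks_like_tflite('resnet.tflite')
# Self-contained demo with a fake 8-byte header instead of a real model:
with tempfile.TemporaryDirectory() as tmp:
    fake_path = os.path.join(tmp, "fake.tflite")
    with open(fake_path, "wb") as f:
        f.write(b"\x1c\x00\x00\x00" + TFLITE_FILE_IDENTIFIER)
    demo_result = looks_like_tflite(fake_path)

print(demo_result)  # True
```

This only checks the header; it does not validate the rest of the FlatBuffer, so a full load through the LiteRT interpreter remains the definitive test.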
 

Visualization

The export function creates a TFLite file, which can be visualized with the Google AI Edge Model Explorer.

pip install ai-edge-model-explorer

import model_explorer
model_explorer.visualize('resnet.tflite')
 