# Copyright 2024 The LiteRT Torch Authors.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
This Colab demonstrates how to convert a PyTorch model to the LiteRT format using the LiteRT Torch package. In this example, we will convert ResNet18, a popular image recognition model, into a LiteRT model that can later be used in a LiteRT or MediaPipe app.
```python
!pip install litert-torch-nightly torchvision
```
Import packages
The PyTorch converter is available in the LiteRT Torch GitHub repository. We also require the PyTorch library, as well as `numpy` and `torchvision`, which includes the ResNet18 model.
```python
import litert_torch
import numpy
import torch
import torchvision
```
Instantiate the PyTorch model
Let's instantiate `resnet18` as a sample model from PyTorch's `torchvision` package. We'll also provide it with a sample input and run the model through PyTorch.
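```python
# Load ResNet18 with ImageNet-pretrained weights and switch to inference mode.
resnet18 = torchvision.models.resnet18(
    weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
).eval()

# One random image batch in NCHW layout: (batch, channels, height, width).
sample_inputs = (torch.randn(1, 3, 224, 224),)

# Reference output from PyTorch, used later to validate the converted model.
torch_output = resnet18(*sample_inputs)
```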
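The random tensor above only exercises the model's input shape. To feed a real photo instead, you would typically convert it from an HWC `uint8` array into a normalized NCHW `float32` batch. The sketch below is a minimal example under three assumptions: the image is RGB, it has already been resized to 224×224, and it should be normalized with the standard ImageNet mean/std values used by torchvision's pretrained classifiers. `to_nchw_batch` is a hypothetical helper and is not part of LiteRT Torch.

```python
import numpy as np

# Standard ImageNet channel statistics (RGB order), as used by torchvision's
# pretrained classification weights. Assumption: the input image is RGB.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_nchw_batch(image_hwc: np.ndarray) -> np.ndarray:
  """Hypothetical helper: HWC uint8 RGB image -> (1, 3, H, W) float32 batch."""
  if image_hwc.ndim != 3 or image_hwc.shape[2] != 3:
    raise ValueError(f"expected an HxWx3 image, got shape {image_hwc.shape}")
  # Scale pixel values from [0, 255] to [0, 1].
  scaled = image_hwc.astype(np.float32) / 255.0
  # Normalize each channel; broadcasting applies the per-channel statistics.
  normalized = (scaled - IMAGENET_MEAN) / IMAGENET_STD
  # Reorder axes from HWC to CHW, then add a leading batch dimension.
  return np.transpose(normalized, (2, 0, 1))[np.newaxis, ...]


# Usage with a synthetic 224x224 RGB image standing in for a real photo.
fake_image = np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8)
batch = to_nchw_batch(fake_image)
print(batch.shape, batch.dtype)  # (1, 3, 224, 224) float32
```

In the notebook, such a batch could be wrapped as `(torch.from_numpy(batch),)` and used in place of `sample_inputs`.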
Convert the model to LiteRT
Use the `convert` function from the `litert_torch` package, which converts PyTorch models to the LiteRT format. This turns the PyTorch model into an on-device model, ready to use with LiteRT and MediaPipe. The conversion process requires sample inputs for the model, which are used for tracing and shape inference.
```python
edge_model = litert_torch.convert(resnet18.eval(), sample_inputs)
```
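The model is converted by tracing it with `sample_inputs`. A reasonable working assumption is therefore that the converted model keeps the input shapes and dtypes seen during tracing. Under that assumption, a cheap guard before inference is to compare new inputs against the sample inputs. The sketch below is a hypothetical helper, not a LiteRT Torch API. It works on any objects that expose `.shape` and `.dtype` (NumPy arrays here, or PyTorch tensors compared with PyTorch tensors).

```python
import numpy as np


def find_input_mismatches(sample_inputs, inputs):
  """Hypothetical helper: list differences in count, shape, or dtype between
  `inputs` and the `sample_inputs` used for conversion (empty list = match)."""
  problems = []
  if len(inputs) != len(sample_inputs):
    problems.append(f"expected {len(sample_inputs)} inputs, got {len(inputs)}")
    return problems
  for i, (ref, new) in enumerate(zip(sample_inputs, inputs)):
    if tuple(new.shape) != tuple(ref.shape):
      problems.append(
          f"input {i}: shape {tuple(new.shape)}, traced with {tuple(ref.shape)}")
    if new.dtype != ref.dtype:
      problems.append(f"input {i}: dtype {new.dtype}, traced with {ref.dtype}")
  return problems


# Usage with NumPy stand-ins for the traced sample inputs.
sample = (np.zeros((1, 3, 224, 224), dtype=np.float32),)
print(find_input_mismatches(sample, (np.ones((1, 3, 224, 224), dtype=np.float32),)))  # []
print(find_input_mismatches(sample, (np.ones((2, 3, 224, 224), dtype=np.float32),)))
```

Returning a list instead of raising lets a caller report every mismatch at once.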
Inference
Run inference with the TFLite runtime by calling `edge_model` directly on the inputs. This API abstracts away many of the details of TFLite inference in Python.
```python
edge_output = edge_model(*sample_inputs)
```
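Both `torch_output` and `edge_output` hold ResNet18's raw logits: one score per ImageNet class, with shape `(1, 1000)`. Turning those logits into top-k class indices with a softmax makes them readable and gives a second sanity check on the conversion. The sketch below uses only NumPy. `top_k` is a hypothetical helper, and mapping indices to human-readable labels is left out.

```python
import numpy as np


def top_k(logits: np.ndarray, k: int = 5):
  """Hypothetical helper: return (indices, probabilities) of the k highest-scoring
  classes for each row of a (batch, num_classes) logits array."""
  # Subtract the row max before exponentiating for numerical stability.
  shifted = logits - logits.max(axis=-1, keepdims=True)
  probs = np.exp(shifted)
  probs /= probs.sum(axis=-1, keepdims=True)
  # argsort is ascending, so reverse each row and keep the first k columns.
  indices = np.argsort(probs, axis=-1)[:, ::-1][:, :k]
  return indices, np.take_along_axis(probs, indices, axis=-1)


# Usage with made-up logits for a 4-class problem.
logits = np.array([[2.0, 0.5, 3.0, -1.0]], dtype=np.float32)
idx, p = top_k(logits, k=2)
print(idx)  # [[2 0]]
```

In the notebook, comparing `top_k(torch_output.detach().numpy())[0]` with `top_k(edge_output)[0]` would normally show both runtimes ranking the same classes first.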
Validate the model
Make sure that the output generated by the converted model matches the output generated by the original PyTorch model.
```python
if numpy.allclose(
    torch_output.detach().numpy(),
    edge_output,
    atol=1e-5,
    rtol=1e-5,
):
  print("Inference result with PyTorch and TFLite was within tolerance")
else:
  print("Something wrong with PyTorch --> TFLite")
```
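`numpy.allclose` returns only a boolean. Elementwise, it tests `|a - b| <= atol + rtol * |b|`, where `b` is its second argument. When the check fails, it helps to see how far apart the outputs are. The sketch below is a hypothetical helper, not part of LiteRT Torch. It applies the same bound as `numpy.allclose(a, b)` and also reports the largest absolute difference and the fraction of elements within tolerance.

```python
import numpy as np


def compare_outputs(a, b, atol=1e-5, rtol=1e-5):
  """Hypothetical helper: summarize how closely `a` matches `b`, using the
  same elementwise bound as numpy.allclose(a, b)."""
  a = np.asarray(a, dtype=np.float64)
  b = np.asarray(b, dtype=np.float64)
  abs_diff = np.abs(a - b)
  # numpy.allclose scales the relative tolerance by |b|, the second argument.
  within = abs_diff <= atol + rtol * np.abs(b)
  return {
      "max_abs_diff": float(abs_diff.max()),
      "fraction_within": float(within.mean()),
      "allclose": bool(np.allclose(a, b, atol=atol, rtol=rtol)),
  }


# Usage with small made-up arrays; the last element is deliberately off.
report = compare_outputs(np.array([1.0, 2.0, 3.0]), np.array([1.0, 2.0, 3.1]))
print(report)
```

In the notebook, this would be called as `compare_outputs(torch_output.detach().numpy(), edge_output)`, mirroring the argument order of the check above.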
Serialization
The converted model includes an `export` method, which you can use to serialize the model. It writes the model to disk as a TFLite FlatBuffer file.
```python
from google.colab import files

edge_model.export('resnet.tflite')

# Download the TFLite FlatBuffer, which can be used with the existing TFLite APIs.
# files.download('resnet.tflite')
```
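TFLite models are FlatBuffers whose schema declares the file identifier `TFL3`. FlatBuffers stores that identifier in bytes 4 to 7 of the file, right after the 4-byte root table offset. Checking those bytes is a cheap way to confirm that `export` wrote a TFLite model before you download or ship it. The sketch below uses only the standard library. `looks_like_tflite` is a hypothetical helper, and passing the check does not prove the model is otherwise valid.

```python
import tempfile
from pathlib import Path

TFLITE_FILE_IDENTIFIER = b"TFL3"


def looks_like_tflite(path) -> bool:
  """Hypothetical helper: True if the file carries the TFLite FlatBuffer identifier."""
  with open(path, "rb") as f:
    header = f.read(8)
  # Bytes 0-3 hold the root table offset; bytes 4-7 hold the file identifier.
  return len(header) == 8 and header[4:8] == TFLITE_FILE_IDENTIFIER


# Usage with a synthetic 8-byte header standing in for a real export.
with tempfile.TemporaryDirectory() as tmp_dir:
  demo = Path(tmp_dir) / "demo.tflite"
  demo.write_bytes(b"\x1c\x00\x00\x00" + TFLITE_FILE_IDENTIFIER)
  print(looks_like_tflite(demo))  # True
```

In the notebook, `looks_like_tflite('resnet.tflite')` would run the same check on the exported file.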
Visualization
The `export` function creates a TFLite file, which can be visualized with the Google AI Edge Model Explorer.
```python
!pip install ai-edge-model-explorer
```

```python
import model_explorer

model_explorer.visualize('resnet.tflite')
```

