# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
The MediaPipe object detection solution provides several models you can use immediately for machine learning (ML) in your application. However, if you need to detect objects not covered by the provided models, you can customize any of the provided models with your own data and MediaPipe Model Maker. This model modification tool rebuilds the model using data you provide. This method is faster than training a new model and can produce a model that is more useful for your specific application.
The following sections show you how to use Model Maker to retrain a pre-built model for object detection with your own data, which you can then use with the MediaPipe Object Detector. The example retrains a general purpose object detection model to detect android figurines in images.
Setup
This section describes key steps for setting up your development environment to retrain a model. These instructions describe how to update a model using Google Colab; you can also use Python in your own development environment. For general information on setting up your development environment for using MediaPipe, including platform version requirements, see the Setup guide for Python.
Attention: This MediaPipe Solutions Preview is an early release.
To install the libraries for customizing a model, run the following commands:
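The install commands are not reproduced on this page. The package is published on PyPI as `mediapipe-model-maker`, so running `pip install mediapipe-model-maker` (prefixed with `!` in a Colab cell) installs it. The short Python check below only confirms afterwards that `mediapipe_model_maker` is importable; it does not install anything, and the helper name `model_maker_available` is hypothetical.

```python
# Install first, from a shell (or a Colab cell prefixed with "!"):
#   pip install --upgrade pip
#   pip install mediapipe-model-maker
import importlib.util


def model_maker_available() -> bool:
    """Return True if the mediapipe_model_maker package can be imported."""
    # find_spec returns None when a top-level package is not installed.
    return importlib.util.find_spec("mediapipe_model_maker") is not None


if __name__ == "__main__":
    print("mediapipe_model_maker installed:", model_maker_available())
```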
Retraining a model for object detection requires a dataset that includes the items, or classes, that you want the completed model to be able to identify. You can build one by trimming down a public dataset to only the classes that are relevant to your use case, by compiling your own dataset, or by some combination of both. The dataset can be significantly smaller than what would be required to train a new model. For example, the COCO dataset used to train many reference models contains hundreds of thousands of images with 91 classes of objects. Transfer learning with Model Maker can retrain an existing model with a smaller dataset and still perform well, depending on your inference accuracy goals. These instructions use a smaller dataset containing 2 types of android figurines, or 2 classes, with 62 total training images.
To download the example dataset, use the following code:
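The download code is not included on this page. As a minimal standard-library sketch, the function below fetches the zip archive and unpacks it. The URL is the one the MediaPipe example notebook uses as far as we know; treat it as an assumption and verify it before use. The helper name `download_dataset` is hypothetical. It returns the `android_figurine` root, from which the `train_dataset_path` and `validation_dataset_path` variables used later in this guide can be built.

```python
import os
import urllib.request
import zipfile

# Archive URL as used by the MediaPipe example notebook (assumption: verify before use).
DATASET_URL = (
    "https://storage.googleapis.com/mediapipe-tasks/"
    "object_detector/android_figurine.zip"
)


def download_dataset(url: str = DATASET_URL, dest_dir: str = ".") -> str:
    """Download the zip archive at `url`, extract it, and return the dataset root."""
    os.makedirs(dest_dir, exist_ok=True)
    zip_path = os.path.join(dest_dir, "android_figurine.zip")
    urllib.request.urlretrieve(url, zip_path)
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(dest_dir)
    # The archive is expected to unpack into an android_figurine/ directory.
    return os.path.join(dest_dir, "android_figurine")


# Usage:
# dataset_root = download_dataset()
# train_dataset_path = os.path.join(dataset_root, "train")
# validation_dataset_path = os.path.join(dataset_root, "validation")
```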
This code stores the dataset in the android_figurine directory. The directory contains two subdirectories for the training and validation datasets, located in android_figurine/train and android_figurine/validation respectively. Each of the train and validation datasets follows the COCO dataset format described below.
Supported dataset formats
Model Maker Object Detection API supports reading the following dataset formats:
COCO format
The COCO dataset format has a data directory which stores all of the images and a single labels.json file which contains the object annotations for all images.
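As a concrete illustration of the annotation file, here is a minimal `labels.json` with the three categories this guide uses. The image file name and box coordinates are made up. `bbox` follows the COCO convention `[x, y, width, height]` in pixels, which is also how the visualization code in this guide draws boxes. Real COCO files usually carry more fields per annotation (for example `area` and `iscrowd`); which fields Model Maker strictly requires is not covered here. `EXAMPLE_LABELS` and `write_coco_labels` are hypothetical names.

```python
import json
import os

# Hypothetical annotations for a single image; the values are illustrative only.
EXAMPLE_LABELS = {
    "categories": [
        {"id": 0, "name": "background"},
        {"id": 1, "name": "android"},
        {"id": 2, "name": "pig_android"},
    ],
    "images": [{"id": 0, "file_name": "example_0.jpg"}],
    "annotations": [
        # bbox is [x_min, y_min, width, height] in pixels (COCO convention).
        {"id": 0, "image_id": 0, "category_id": 1,
         "bbox": [120.0, 80.0, 200.0, 310.0]},
    ],
}


def write_coco_labels(dataset_dir: str, labels: dict) -> str:
    """Write `labels` to <dataset_dir>/labels.json and return the file path."""
    os.makedirs(dataset_dir, exist_ok=True)
    path = os.path.join(dataset_dir, "labels.json")
    with open(path, "w") as f:
        json.dump(labels, f, indent=2)
    return path
```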
PASCAL VOC format
The PASCAL VOC dataset format also has a data directory which stores all of the images; however, the annotations are split up per image into corresponding XML files in the Annotations directory.
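To show how a per-image PASCAL VOC annotation differs from COCO, the sketch below parses one XML file's contents with the standard library. It uses the standard VOC tag names (`filename`, `object`, `name`, `bndbox`, `xmin`, `ymin`, `xmax`, `ymax`); the sample XML and the helper names are hypothetical. VOC stores box corners, whereas COCO stores the top-left corner plus width and height, so `voc_box_to_coco` converts between the two.

```python
import xml.etree.ElementTree as ET


def parse_voc_annotation(xml_text: str):
    """Return (filename, [(label, [x_min, y_min, x_max, y_max]), ...])."""
    root = ET.fromstring(xml_text.strip())
    filename = root.findtext("filename")
    objects = []
    for obj in root.iter("object"):
        box = obj.find("bndbox")
        coords = [float(box.findtext(tag)) for tag in ("xmin", "ymin", "xmax", "ymax")]
        objects.append((obj.findtext("name"), coords))
    return filename, objects


def voc_box_to_coco(box):
    """Convert [x_min, y_min, x_max, y_max] to COCO [x, y, width, height]."""
    x_min, y_min, x_max, y_max = box
    return [x_min, y_min, x_max - x_min, y_max - y_min]


# Hypothetical annotation file for one image.
EXAMPLE_VOC_XML = """
<annotation>
  <filename>example_0.jpg</filename>
  <object>
    <name>android</name>
    <bndbox><xmin>120</xmin><ymin>80</ymin><xmax>320</xmax><ymax>390</ymax></bndbox>
  </object>
</annotation>
"""
```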
Verify the dataset content by printing the categories from the labels.json file. There should be 3 total categories. Index 0 is always set to be the background class, which may be unused in the dataset. There should be two non-background categories: android and pig_android.
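A minimal way to do this check with the standard library is sketched below; `list_categories` is a hypothetical helper that prints each category and returns `(id, name)` pairs so the result can also be checked programmatically. It assumes `train_dataset_path` points at `android_figurine/train`.

```python
import json
import os


def list_categories(dataset_dir: str):
    """Print and return (id, name) pairs from <dataset_dir>/labels.json."""
    with open(os.path.join(dataset_dir, "labels.json"), "r") as f:
        labels_json = json.load(f)
    categories = [(item["id"], item["name"]) for item in labels_json["categories"]]
    for category_id, name in categories:
        print(f"{category_id}: {name}")
    return categories


# Usage (expect 3 categories, with id 0 being "background"):
# list_categories(train_dataset_path)
```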
To better understand the dataset, plot a couple of example images along with their bounding boxes.
Visualize the training dataset
import json
import math
import os
from collections import defaultdict

import matplotlib.pyplot as plt
from matplotlib import patches, patheffects

# Location of the training split (see the dataset download step above).
train_dataset_path = "android_figurine/train"


def draw_outline(obj):
  obj.set_path_effects([patheffects.Stroke(linewidth=4, foreground='black'), patheffects.Normal()])


def draw_box(ax, bb):
  # COCO boxes are [x_min, y_min, width, height].
  patch = ax.add_patch(patches.Rectangle((bb[0], bb[1]), bb[2], bb[3], fill=False, edgecolor='red', lw=2))
  draw_outline(patch)


def draw_text(ax, bb, txt, disp):
  label = ax.text(bb[0], (bb[1] - disp), txt, verticalalignment='top', color='white', fontsize=10, weight='bold')
  draw_outline(label)


def draw_bbox(ax, annotations_list, id_to_label, image_shape):
  for annotation in annotations_list:
    cat_id = annotation["category_id"]
    bbox = annotation["bbox"]
    draw_box(ax, bbox)
    draw_text(ax, bbox, id_to_label[cat_id], image_shape[0] * 0.05)


def visualize(dataset_folder, max_examples=None):
  with open(os.path.join(dataset_folder, "labels.json"), "r") as f:
    labels_json = json.load(f)
  # Look images up by their "id" field rather than by list position.
  id_to_file = {item["id"]: item["file_name"] for item in labels_json["images"]}
  cat_id_to_label = {item["id"]: item["name"] for item in labels_json["categories"]}
  image_annots = defaultdict(list)
  for annotation_obj in labels_json["annotations"]:
    image_annots[annotation_obj["image_id"]].append(annotation_obj)

  if max_examples is None:
    max_examples = len(image_annots)
  max_examples = min(max_examples, len(image_annots))
  n_rows = math.ceil(max_examples / 3)
  # 3 columns, 8x8 inches per image; squeeze=False keeps axs 2-D even for a single row.
  fig, axs = plt.subplots(n_rows, 3, figsize=(24, n_rows * 8), squeeze=False)
  for ind, (image_id, annotations_list) in enumerate(list(image_annots.items())[:max_examples]):
    ax = axs[ind // 3, ind % 3]
    img = plt.imread(os.path.join(dataset_folder, "images", id_to_file[image_id]))
    ax.imshow(img)
    draw_bbox(ax, annotations_list, cat_id_to_label, img.shape)
  plt.show()


visualize(train_dataset_path, 9)
Create dataset
The Dataset class has two methods for loading COCO or PASCAL VOC datasets:
Dataset.from_coco_folder
Dataset.from_pascal_voc_folder
Since the android_figurine dataset is in the COCO dataset format, use the from_coco_folder method to load the datasets located at train_dataset_path and validation_dataset_path. When a dataset is loaded, the data is parsed from the provided path and converted into a standardized TFRecord format, which is cached for later use. Create a cache_dir location and reuse it for all your training runs to avoid saving multiple caches of the same dataset.
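A minimal sketch of the loading step, assuming the `Dataset.from_coco_folder(path, cache_dir=...)` call shape used in the MediaPipe example notebook (check it against the API reference for your installed version). `load_datasets` and the `/tmp/od_data` cache root are hypothetical; each split gets its own fixed subdirectory so that repeated runs point at the same cache. The optional `object_detector` argument exists only so the function can be tried with a stand-in module; leave it unset in normal use.

```python
import os


def load_datasets(train_dataset_path, validation_dataset_path,
                  cache_root="/tmp/od_data", object_detector=None):
    """Load COCO-format train/validation folders with Model Maker."""
    if object_detector is None:
        # Imported lazily so the function can be exercised with a stub module.
        from mediapipe_model_maker import object_detector
    train_data = object_detector.Dataset.from_coco_folder(
        train_dataset_path, cache_dir=os.path.join(cache_root, "train"))
    validation_data = object_detector.Dataset.from_coco_folder(
        validation_dataset_path, cache_dir=os.path.join(cache_root, "validation"))
    return train_data, validation_data


# Usage:
# train_data, validation_data = load_datasets(train_dataset_path, validation_dataset_path)
```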
Once your data is prepared, you can begin retraining a model to recognize the new objects, or classes, defined by your training data. The instructions below use the data prepared in the previous section to retrain an object detection model to recognize the two types of android figurines.
Set retraining options
Aside from your training dataset, retraining requires two settings: the output directory for the model and the model architecture. Use HParams to specify the export_dir parameter for the output directory. Use the SupportedModels class to specify the model architecture. The object detector solution supports the following model architectures:
MobileNet-V2
MobileNet-MultiHW-AVG
For more advanced customization of training parameters, see the Hyperparameters section below.
To set the required parameters, use the following code:
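The setup code is not shown on this page. A minimal sketch, assuming the `SupportedModels.MOBILENET_MULTI_AVG`, `HParams(export_dir=...)`, and `ObjectDetectorOptions(supported_model=..., hparams=...)` names used in the MediaPipe example notebook: `make_options` is a hypothetical wrapper, and its optional `object_detector` argument only lets you try it with a stand-in module.

```python
def make_options(export_dir="exported_model", object_detector=None):
    """Build ObjectDetectorOptions with a model architecture and an output directory."""
    if object_detector is None:
        from mediapipe_model_maker import object_detector
    # MobileNet-MultiHW-AVG; SupportedModels.MOBILENET_V2 selects MobileNet-V2 instead.
    spec = object_detector.SupportedModels.MOBILENET_MULTI_AVG
    hparams = object_detector.HParams(export_dir=export_dir)
    return object_detector.ObjectDetectorOptions(supported_model=spec, hparams=hparams)


# Usage:
# options = make_options("exported_model")
```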
With your training dataset and retraining options prepared, you are ready to start the retraining process. This process is resource-intensive and can take anywhere from a few minutes to a few hours, depending on your available compute resources. In a Google Colab environment with a standard GPU runtime, the example retraining below takes about 2 to 4 minutes.
To begin the retraining process, use the create() method with the datasets and options you previously defined:
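A minimal sketch of the `create()` call, timed so you can compare against the rough estimate above. The keyword names `train_data`, `validation_data`, and `options` follow the MediaPipe example notebook as we understand it; `retrain` is a hypothetical wrapper, and its optional `object_detector` argument only exists so the function can be tried with a stand-in module.

```python
import time


def retrain(train_data, validation_data, options, object_detector=None):
    """Retrain the detector and report the wall-clock training time."""
    if object_detector is None:
        from mediapipe_model_maker import object_detector
    start = time.perf_counter()
    model = object_detector.ObjectDetector.create(
        train_data=train_data, validation_data=validation_data, options=options)
    print(f"Retraining took {time.perf_counter() - start:.1f} s")
    return model


# Usage:
# model = retrain(train_data, validation_data, options)
```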
After training the model, evaluate it on the validation dataset and print the loss and coco_metrics. The most important metric for evaluating model performance is typically the "AP" COCO metric, which stands for Average Precision.
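The evaluation step can be sketched as follows. The `model.evaluate(validation_data, batch_size=4)` call and its `(loss, coco_metrics)` return value mirror the MediaPipe example notebook as we understand it. Reading Average Precision from the key `"AP"` is an assumption about the metrics dictionary, so the sketch uses `.get` and prints `None` if the key is named differently. `report_metrics` is a hypothetical name.

```python
def report_metrics(model, validation_data, batch_size=4):
    """Evaluate on the validation split and print loss, all COCO metrics, and AP."""
    loss, coco_metrics = model.evaluate(validation_data, batch_size=batch_size)
    print(f"Validation loss: {loss}")
    print(f"Validation coco metrics: {coco_metrics}")
    # Assumption: Average Precision is stored under the "AP" key.
    print(f"AP: {coco_metrics.get('AP')}")
    return loss, coco_metrics
```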
After creating the model, convert and export it to the TensorFlow Lite model format for later use in an on-device application. The export also includes model metadata, which includes the label map.
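A minimal export sketch that also reports the file size, which makes it easy to compare against the quantized exports later. It assumes, as in the MediaPipe example notebook, that `export_model(model_name=...)` writes the file into the `export_dir` set in `HParams`; the `export_dir` argument here must match that value. `export_and_report` is a hypothetical helper. In Colab, `google.colab.files.download(path)` can then download the file.

```python
import os


def export_and_report(model, export_dir="exported_model", model_name="model.tflite"):
    """Export the model and print the size of the resulting .tflite file."""
    model.export_model(model_name=model_name)
    # Assumption: the file lands in the HParams export_dir.
    path = os.path.join(export_dir, model_name)
    if os.path.exists(path):
        print(f"{path}: {os.path.getsize(path) / 1024:.1f} KiB")
    return path
```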
Model quantization is a model modification technique that can reduce the model size and improve the speed of predictions with only a relatively minor decrease in accuracy.
This section of the guide explains how to apply quantization to your model. Model Maker supports two forms of quantization for the object detector:
Quantization Aware Training: 8 bit integer precision for CPU usage
Post-Training Quantization: 16 bit floating point precision for GPU usage
Quantization aware training (int8 quantization)
Quantization aware training (QAT) is a fine-tuning step that happens after your model is fully trained. It further tunes the model while emulating inference-time quantization, so that the model can account for the lower precision of 8-bit integer quantization. For on-device applications with a standard CPU, use int8 precision. For more information, see the TensorFlow Lite documentation.
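To make "emulating inference-time quantization" concrete, here is a generic illustration; it is not Model Maker's implementation. QAT inserts "fake quantization" (quantize to int8, then immediately dequantize) so that training sees the rounding error int8 inference will introduce. The helpers below use one common asymmetric (affine) int8 scheme; the function names are hypothetical.

```python
def int8_quant_params(x_min: float, x_max: float):
    """Affine int8 parameters (scale, zero_point) covering [x_min, x_max]."""
    # The range must contain 0.0 so that zero maps exactly onto an integer.
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    # 256 int8 levels (-128..127) give 255 steps; fall back to 1.0 for an empty range.
    scale = (x_max - x_min) / 255.0 or 1.0
    zero_point = round(-128 - x_min / scale)
    return scale, max(-128, min(127, zero_point))


def fake_quantize(x: float, scale: float, zero_point: int) -> float:
    """Quantize to int8 and dequantize again, returning the value inference would see."""
    q = max(-128, min(127, round(x / scale) + zero_point))
    return (q - zero_point) * scale


# Usage: values in [-1, 1] come back with an error of at most about scale / 2.
# s, z = int8_quant_params(-1.0, 1.0)
# fake_quantize(0.5, s, z)
```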
To apply quantization aware training and export to an int8 model, create a QATHParams configuration and run the quantization_aware_training method. See the Hyperparameters section below for detailed usage of QATHParams.
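A minimal sketch of a single QAT run. The defaults are the `QATHParams` values this guide reports for its int8 benchmarks (`learning_rate=0.1, batch_size=4, epochs=30, decay_rate=1`); the `quantization_aware_training(train_data, validation_data, qat_hparams=...)` call shape follows the MediaPipe example notebook as we understand it. `run_qat` is a hypothetical wrapper, and its optional `object_detector` argument only lets you try it with a stand-in module.

```python
def run_qat(model, train_data, validation_data, learning_rate=0.1, batch_size=4,
            epochs=30, decay_rate=1, object_detector=None):
    """Run quantization aware training once and return (loss, coco_metrics)."""
    if object_detector is None:
        from mediapipe_model_maker import object_detector
    qat_hparams = object_detector.QATHParams(
        learning_rate=learning_rate, batch_size=batch_size,
        epochs=epochs, decay_rate=decay_rate)
    model.quantization_aware_training(train_data, validation_data, qat_hparams=qat_hparams)
    # Evaluate the QAT model on the validation split.
    return model.evaluate(validation_data)
```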
Tuning the QAT parameters often takes several runs. To avoid rerunning model training with the create method each time, call the restore_float_ckpt method to restore the model to the fully trained float state it had after the create method, and then run QAT again.
Finally, use the export_model method to export an int8 quantized model. The export_model method automatically exports either a float32 or an int8 model, depending on whether quantization_aware_training was run.
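The restore-and-retry workflow, followed by the int8 export, might look like this sketch. `tune_qat` is a hypothetical helper; it assumes each candidate is a `QATHParams` object, that the metrics dictionary exposes AP under `"AP"`, and that `export_model` accepts a `model_name` keyword as in the MediaPipe example notebook. Because QAT results can vary between runs, re-running the winning configuration before export may not reproduce its AP exactly.

```python
def tune_qat(model, train_data, validation_data, candidate_hparams,
             model_name="model_int8_qat.tflite"):
    """Try several QAT configs from the float checkpoint; export the best as int8."""
    candidates = list(candidate_hparams)
    if not candidates:
        raise ValueError("candidate_hparams must not be empty")
    best_ap, best_hparams = float("-inf"), None
    for qat_hparams in candidates:
        model.restore_float_ckpt()  # every trial starts from the trained float model
        model.quantization_aware_training(train_data, validation_data, qat_hparams=qat_hparams)
        _, coco_metrics = model.evaluate(validation_data)
        if coco_metrics["AP"] > best_ap:
            best_ap, best_hparams = coco_metrics["AP"], qat_hparams
    # Re-run the winner so the model state matches it, then export (int8 after QAT).
    model.restore_float_ckpt()
    model.quantization_aware_training(train_data, validation_data, qat_hparams=best_hparams)
    model.export_model(model_name=model_name)
    return best_hparams, best_ap


# Usage:
# candidates = [
#     object_detector.QATHParams(learning_rate=lr, batch_size=4, epochs=30, decay_rate=1)
#     for lr in (0.1, 0.3)
# ]
# best_hparams, best_ap = tune_qat(model, train_data, validation_data, candidates)
```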
Post-training quantization (fp16 quantization)
Post-training model quantization reduces the model size and speeds up predictions by reducing the size of the data processed by the model, for example by transforming 32-bit floating point numbers to 16-bit floats. Float16 quantization is recommended for GPU usage. For more information, see the TensorFlow Lite documentation.
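The size and precision trade-off of float16 can be seen with a few numbers. This standard-library sketch packs the same illustrative values as 32-bit and 16-bit IEEE 754 floats (`struct` formats `"f"` and `"e"`): the float16 encoding takes half the bytes, and each value is rounded to roughly three significant decimal digits. The values and the helper name are hypothetical.

```python
import struct


def float16_roundtrip(x: float) -> float:
    """Round a Python float to IEEE 754 half precision and back."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


weights = [0.1, -1.2345, 3.14159]  # illustrative values, not real model weights
fp32_bytes = len(struct.pack(f"<{len(weights)}f", *weights))  # 4 bytes per value
fp16_bytes = len(struct.pack(f"<{len(weights)}e", *weights))  # 2 bytes per value

for w in weights:
    print(f"{w!r} -> {float16_roundtrip(w)!r}")
print(f"float32: {fp32_bytes} bytes, float16: {fp16_bytes} bytes")
```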
First, import the MediaPipe Model Maker quantization module:
from mediapipe_model_maker import quantization
Define a QuantizationConfig object using the for_float16() class method. This configuration modifies a trained model to use 16-bit floating point numbers instead of 32-bit floating point numbers. You can further customize the quantization process by setting additional parameters for the QuantizationConfig class.
Export the model using the additional quantization_config object to apply post-training quantization. Note that if you previously ran quantization_aware_training, you must first convert the model back to a float model by using restore_float_ckpt.
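Put together, the float16 export can be sketched as follows. It assumes `quantization.QuantizationConfig.for_float16()` and `export_model(model_name=..., quantization_config=...)` behave as in the MediaPipe example notebook. `export_fp16` and its `ran_qat` flag are hypothetical, and the optional `quantization` argument only lets you try the function with a stand-in module.

```python
def export_fp16(model, ran_qat=False, model_name="model_fp16.tflite", quantization=None):
    """Export `model` as a float16 TFLite file via post-training quantization."""
    if quantization is None:
        from mediapipe_model_maker import quantization
    quantization_config = quantization.QuantizationConfig.for_float16()
    if ran_qat:
        # QAT leaves the model in its int8-tuned state; return to the float model first.
        model.restore_float_ckpt()
    model.export_model(model_name=model_name, quantization_config=quantization_config)
    return model_name


# Usage:
# export_fp16(model, ran_qat=True)
```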
Hyperparameters
You can further customize the model using the ObjectDetectorOptions class, which has three parameters for SupportedModels, ModelOptions, and HParams.
Use the SupportedModels enum class to specify the model architecture to use for training. The following model architectures are supported:
MOBILENET_V2
MOBILENET_V2_I320
MOBILENET_MULTI_AVG
MOBILENET_MULTI_AVG_I384
Use the HParams class to customize other parameters related to training and saving the model:
learning_rate: Learning rate to use for gradient descent training. Defaults to 0.3.
batch_size: Batch size for training. Defaults to 8.
epochs: Number of training iterations over the dataset. Defaults to 30.
cosine_decay_epochs: The number of epochs for cosine decay learning rate. See tf.keras.optimizers.schedules.CosineDecay for more info. Defaults to None, which is equivalent to setting it to epochs.
cosine_decay_alpha: The alpha value for cosine decay learning rate. See tf.keras.optimizers.schedules.CosineDecay for more info. Defaults to 1.0, which means no cosine decay.
Use the ModelOptions class to customize parameters related to the model itself:
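To see how `learning_rate`, `epochs`, `cosine_decay_epochs`, and `cosine_decay_alpha` interact, this plain-Python sketch evaluates the schedule that `tf.keras.optimizers.schedules.CosineDecay` documents (without warmup), measured in epochs rather than optimizer steps. How Model Maker converts epochs to steps internally is an assumption not covered here, and `cosine_decay_lr` is a hypothetical name. The defaults match the values listed above, so with `cosine_decay_alpha=1.0` the rate stays constant.

```python
import math


def cosine_decay_lr(epoch: float, learning_rate: float = 0.3, epochs: int = 30,
                    cosine_decay_epochs=None, cosine_decay_alpha: float = 1.0) -> float:
    """Learning rate after `epoch` epochs under Keras-style cosine decay."""
    decay_epochs = epochs if cosine_decay_epochs is None else cosine_decay_epochs
    if decay_epochs <= 0:
        raise ValueError("decay length must be positive")
    # Progress is clipped at 1.0, so the rate stays at its floor after decay ends.
    progress = min(epoch, decay_epochs) / decay_epochs
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    # alpha is the final rate as a fraction of learning_rate; alpha = 1.0 disables decay.
    decayed = (1.0 - cosine_decay_alpha) * cosine + cosine_decay_alpha
    return learning_rate * decayed


# Usage: with alpha = 0.0 the rate falls from 0.3 to 0.15 at the midpoint and 0.0 at the end.
# [round(cosine_decay_lr(e, cosine_decay_alpha=0.0), 4) for e in (0, 15, 30)]
```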
Below is a summary of our benchmarking results for the supported model architectures. These models were trained and evaluated on the same android figurines dataset as this notebook. When considering the model benchmarking results, there are a few important caveats to keep in mind:
The android figurines dataset is a small and simple dataset with 62 training examples and 10 validation examples. Because the dataset is so small, metrics may vary drastically due to variance in the training process. This dataset is provided for demo purposes; collecting more data samples is recommended to get better-performing models.
The float32 models were trained with the default HParams, and the QAT step for the int8 models was run with QATHParams(learning_rate=0.1, batch_size=4, epochs=30, decay_rate=1).
For your own dataset, you will likely need to tune values for both HParams and QATHParams in order to achieve the best results. See the Hyperparameters section above for more information on configuring training parameters.
All latency numbers are benchmarked on the Pixel 6.
Last updated 2026-06-05 UTC.