```python
# Copyright 2023 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```
The MediaPipe Model Maker package is a simple, low-code solution for customizing on-device machine learning (ML) models. This notebook shows the end-to-end process of customizing a text classification model for the specific use case of performing sentiment analysis on movie reviews.
The following code block downloads the SST-2 (Stanford Sentiment Treebank) dataset, which contains 67,349 movie reviews for training and 872 movie reviews for testing. The dataset has two classes: positive and negative movie reviews. Positive reviews are labelled with 1 and negative reviews with 0. We will use this dataset to train the two text classifiers featured in this tutorial.
Disclaimer: The dataset linked in this Colab is not owned or distributed by Google, and is made available by third parties. Please review the terms and conditions made available by the third parties before using the data.
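```python
import os

import tensorflow as tf

# Download the SST-2 archive and extract it next to the cached zip file.
data_path = tf.keras.utils.get_file(
    fname='SST-2.zip',
    origin='https://dl.fbaipublicfiles.com/glue/data/SST-2.zip',
    extract=True)
data_dir = os.path.join(os.path.dirname(data_path), 'SST-2')  # folder name
```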
The SST-2 dataset is stored as TSV files. The main difference between the TSV and CSV formats is the delimiter: TSV uses a tab character (\t) and CSV uses a comma (,).
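The delimiter matters because review text often contains commas. The small sketch below uses Python's standard csv module, not Model Maker, to parse the same hypothetical SST-2-style row both ways. Read as TSV, the review stays in a single field. Read as CSV, the commas inside the text split it.

```python
import csv
import io

# A hypothetical SST-2-style row: review text, a tab, then the label.
row = "a gripping , well-acted thriller\t1\n"

# Parsed as TSV: only the tab separates the text from the label.
tsv_fields = next(csv.reader(io.StringIO(row), delimiter="\t"))

# Parsed as CSV: the comma inside the review text is treated as a delimiter.
csv_fields = next(csv.reader(io.StringIO(row), delimiter=","))

print(tsv_fields)  # ['a gripping , well-acted thriller', '1']
print(csv_fields)  # ['a gripping ', ' well-acted thriller\t1']
```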
The following code block extracts the training and validation data from their TSV files using the Dataset.from_csv method.
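Dataset.from_csv handles the parsing for you. To illustrate roughly what that step reads, here is a plain-Python sketch that loads an SST-2-style TSV file into (sentence, label) pairs. The helper name load_sst2_tsv is hypothetical. The sentence and label column names are assumed from the SST-2 file header. This is an illustration only and does not replace Dataset.from_csv.

```python
import csv
import os
import tempfile
from typing import List, Tuple


def load_sst2_tsv(path: str) -> List[Tuple[str, int]]:
    """Hypothetical helper: read an SST-2-style TSV into (sentence, label) pairs.

    Assumes a header row with 'sentence' and 'label' columns, where
    1 marks a positive review and 0 a negative one.
    """
    examples = []
    with open(path, newline="", encoding="utf-8") as f:
        # QUOTE_NONE keeps quote characters inside reviews as literal text
        # instead of treating them as CSV quoting.
        reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
        for row in reader:
            examples.append((row["sentence"], int(row["label"])))
    return examples


# Usage with a tiny, made-up file standing in for dev.tsv.
with tempfile.TemporaryDirectory() as tmp_dir:
    sample_path = os.path.join(tmp_dir, "dev.tsv")
    with open(sample_path, "w", encoding="utf-8") as f:
        f.write("sentence\tlabel\n")
        f.write('a "charming" , affecting journey\t1\n')
        f.write("unflinchingly bleak\t0\n")
    examples = load_sst2_tsv(sample_path)

print(examples)
```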
Model Maker's Text Classifier supports two classifiers with distinct model architectures: an average word embedding model and a BERT model. The first demo classifier will use an average word embedding architecture.
To create and train a text classifier, we need to set some TextClassifierOptions. These options require us to specify a supported_model, which can take the value AVERAGE_WORD_EMBEDDING_CLASSIFIER or MOBILEBERT_CLASSIFIER. We'll use AVERAGE_WORD_EMBEDDING_CLASSIFIER for now.
For more information on the TextClassifierOptions class and its fields, see the TextClassifierOptions section below.
Now we can use the training and validation data with the TextClassifierOptions we've defined to create and train a text classifier. To do so, we use the TextClassifier.create function.
Evaluate the BERT-based model. Note the improved performance compared to the average word embedding-based classifier.
```python
# Evaluate the BERT-based classifier on the validation split.
metrics = bert_model.evaluate(validation_data)
print(f'Test loss: {metrics[0]}, Test accuracy: {metrics[1]}')
```
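The printout above reads the loss from metrics[0] and the accuracy from metrics[1]. The sketch below shows what those two numbers measure. It computes mean categorical cross-entropy and accuracy from made-up two-class probabilities in plain Python. It is an illustration of the metrics only. That Model Maker uses exactly this loss is an assumption.

```python
import math
from typing import Sequence, Tuple


def loss_and_accuracy(probs: Sequence[Sequence[float]],
                      labels: Sequence[int]) -> Tuple[float, float]:
    """Mean categorical cross-entropy and accuracy for class predictions.

    probs[i][c] is the predicted probability that example i has class c
    (0 = negative, 1 = positive); labels[i] is the true class.
    The choice of cross-entropy here is an assumption of this sketch.
    """
    total_loss = 0.0
    correct = 0
    for p, y in zip(probs, labels):
        # Cross-entropy for one example: -log of the probability of the true class.
        total_loss += -math.log(p[y])
        # The predicted class is the one with the highest probability.
        predicted = max(range(len(p)), key=lambda c: p[c])
        correct += int(predicted == y)
    n = len(labels)
    return total_loss / n, correct / n


# Made-up predictions for three reviews; the third one is misclassified.
loss, accuracy = loss_and_accuracy([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]], [0, 1, 1])
print(f'Test loss: {loss:.4f}, Test accuracy: {accuracy:.4f}')
# Test loss: 0.4149, Test accuracy: 0.6667
```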
The MobileBERT model is over 100MB, so when we export the BERT-based classifier as a TFLite model, it helps to apply quantization, which can bring the TFLite model size down to 28MB.
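A back-of-the-envelope calculation shows why quantization shrinks the model so much. This sketch makes two assumptions that are not stated in this tutorial. First, MobileBERT has roughly 25 million parameters. Second, quantization stores each weight in 8 bits instead of 32.

```python
def estimated_size_mb(num_params: int, bits_per_weight: int) -> float:
    """Rough model size in megabytes (10**6 bytes), counting weights only."""
    return num_params * bits_per_weight / 8 / 1e6


# Assumption: MobileBERT has roughly 25 million parameters.
NUM_PARAMS = 25_000_000

float32_mb = estimated_size_mb(NUM_PARAMS, 32)  # 32-bit float weights
int8_mb = estimated_size_mb(NUM_PARAMS, 8)      # assumed 8-bit quantized weights
print(f'float32: {float32_mb:.1f} MB, int8: {int8_mb:.1f} MB, '
      f'ratio: {float32_mb / int8_mb:.0f}x')
# float32: 100.0 MB, int8: 25.0 MB, ratio: 4x
```

The estimate counts weights only, so it will not match the exported file size exactly.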
We can configure text classifier training with TextClassifierOptions, which takes one required parameter:
supported_model, which describes the model architecture that the text classifier is based on. It can be either AVERAGE_WORD_EMBEDDING_CLASSIFIER or MOBILEBERT_CLASSIFIER.
TextClassifierOptions can also take two optional parameters:
hparams, which describes hyperparameters used during model training. This takes an HParams object.
model_options, which describes configurable parameters related to the model architecture or data preprocessing. For an average word-embedding classifier, this field takes an AverageWordEmbeddingModelOptions object. For a BERT-based classifier, this field takes a BertModelOptions object.
If these fields aren't set, model creation and training will be run with predefined default values.
HParams has the following customizable parameters, which affect model accuracy:
learning_rate: The learning rate to use for gradient descent-based optimizers. Defaults to 3e-5 for the BERT-based classifier and 0 for the average word-embedding classifier because it does not need such an optimizer.
batch_size: Batch size for training. Defaults to 32 for the average word-embedding classifier and 48 for the BERT-based classifier.
epochs: Number of training iterations over the dataset. Defaults to 10 for the average word-embedding classifier and 3 for the BERT-based classifier.
steps_per_epoch: An optional integer that indicates the number of training steps per epoch. If not set, the training pipeline calculates the default steps per epoch as the training dataset size divided by batch size.
shuffle: True if the dataset is shuffled before training. Defaults to False.
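Taken together, the defaults above and the steps_per_epoch rule determine how many training steps each classifier runs on the 67,349 SST-2 training reviews. The plain-Python sketch below works that out. The default values are copied from the list above. Rounding steps down with integer division is an assumption, because the text only says the dataset size is divided by the batch size.

```python
from typing import Optional, Tuple

# Default batch_size and epochs, copied from the HParams list above.
DEFAULTS = {
    "AVERAGE_WORD_EMBEDDING_CLASSIFIER": {"batch_size": 32, "epochs": 10},
    "MOBILEBERT_CLASSIFIER": {"batch_size": 48, "epochs": 3},
}

TRAIN_SIZE = 67_349  # SST-2 training reviews


def training_steps(model: str, steps_per_epoch: Optional[int] = None) -> Tuple[int, int]:
    """Return (steps_per_epoch, total_steps) for a model's default HParams.

    If steps_per_epoch is None, it is derived as dataset size // batch size
    (rounding down is an assumption of this sketch).
    """
    hp = DEFAULTS[model]
    if steps_per_epoch is None:
        steps_per_epoch = TRAIN_SIZE // hp["batch_size"]
    return steps_per_epoch, steps_per_epoch * hp["epochs"]


for name in DEFAULTS:
    per_epoch, total = training_steps(name)
    print(f"{name}: {per_epoch} steps/epoch, {total} steps total")
# AVERAGE_WORD_EMBEDDING_CLASSIFIER: 2104 steps/epoch, 21040 steps total
# MOBILEBERT_CLASSIFIER: 1403 steps/epoch, 4209 steps total
```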
Additional HParams parameters that do not affect model accuracy:
export_dir: The location of the model checkpoint files and exported model files.
AverageWordEmbeddingModelOptions has the following customizable parameters related to the model architecture:
seq_len: The length of the input sequence for the model. Defaults to 256.
wordvec_dim: The dimension of the word embeddings. Defaults to 16.
dropout_rate: The rate used in the model's dropout layer. Defaults to 0.2.
It also has the following customizable parameters related to data preprocessing:
do_lower_case: Whether text input is converted to lower case before training or inference. Defaults to True.
vocab_size: The maximum size of the vocabulary generated from the set of text data.
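The toy, framework-free sketch below shows how seq_len, wordvec_dim, do_lower_case and vocab_size fit together in an average word-embedding model. It lowercases the text and splits it on whitespace. It then keeps the vocab_size most frequent words and maps a review to ids that are padded or truncated to seq_len. Finally, it averages the embeddings of the non-padding tokens. Several choices here are assumptions of the sketch, not details taken from Model Maker: the whitespace tokenizer, the reserved padding and unknown ids, the random embedding table, and skipping padding when averaging. Dropout is omitted because it only applies during training.

```python
import random
from collections import Counter
from typing import Dict, List

PAD_ID, UNK_ID = 0, 1  # Reserved ids (an assumption of this sketch).


def build_vocab(texts: List[str], vocab_size: int, do_lower_case: bool = True) -> Dict[str, int]:
    """Keep the vocab_size most frequent words; ids 0 and 1 are reserved."""
    counts = Counter()
    for text in texts:
        counts.update((text.lower() if do_lower_case else text).split())
    words = [w for w, _ in counts.most_common(vocab_size)]
    return {w: i + 2 for i, w in enumerate(words)}


def encode(text: str, vocab: Dict[str, int], seq_len: int, do_lower_case: bool = True) -> List[int]:
    """Map words to ids (unknown words -> UNK_ID), then truncate or pad to seq_len."""
    words = (text.lower() if do_lower_case else text).split()
    ids = [vocab.get(w, UNK_ID) for w in words][:seq_len]
    return ids + [PAD_ID] * (seq_len - len(ids))


def average_embedding(ids: List[int], table: List[List[float]]) -> List[float]:
    """Average the embedding vectors of the non-padding tokens."""
    real = [table[i] for i in ids if i != PAD_ID]
    dim = len(table[0])
    if not real:
        return [0.0] * dim
    return [sum(vec[d] for vec in real) / len(real) for d in range(dim)]


texts = ["A fine film", "a dull film", "fine acting"]
vocab = build_vocab(texts, vocab_size=10)
ids = encode("A fine new film", vocab, seq_len=6)  # "new" is out of vocabulary

wordvec_dim = 4
rng = random.Random(0)
# One randomly initialised row per id: PAD, UNK, then each vocabulary word.
table = [[rng.uniform(-1, 1) for _ in range(wordvec_dim)] for _ in range(len(vocab) + 2)]
features = average_embedding(ids, table)  # Input to a dense classification layer.
print(ids, len(features))  # [2, 3, 1, 4, 0, 0] 4
```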
BertModelOptions has the following customizable parameters related to the model architecture:
seq_len: The length of the input sequence for the BERT encoder. Defaults to 128.
dropout_rate: The rate used in the classifier's dropout layer. Defaults to 0.1.
do_fine_tuning: Whether the BERT encoder is unfrozen and trained along with the classifier layers. Defaults to True.
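The BERT seq_len applies to the whole packed input, not just the review text. Standard BERT-style preprocessing wraps a sentence in a leading [CLS] token and a trailing [SEP] token. With seq_len=128, that leaves room for at most 126 text tokens. The sketch below shows this generic packing. The helper name pack_for_bert and the token strings are hypothetical, and it is an assumption that Model Maker follows exactly this layout.

```python
from typing import Dict, List


def pack_for_bert(tokens: List[str], seq_len: int = 128) -> Dict[str, List]:
    """Hypothetical helper: pack one tokenized sentence into BERT-style inputs.

    Layout (assumed): [CLS] tokens... [SEP], then [PAD] up to seq_len.
    """
    body = tokens[: seq_len - 2]  # Reserve two slots for [CLS] and [SEP].
    packed = ["[CLS]"] + body + ["[SEP]"]
    # input_mask is 1 for real tokens and 0 for padding.
    input_mask = [1] * len(packed) + [0] * (seq_len - len(packed))
    packed += ["[PAD]"] * (seq_len - len(packed))
    # Single-sentence input: every position belongs to segment 0.
    segment_ids = [0] * seq_len
    return {"tokens": packed, "input_mask": input_mask, "segment_ids": segment_ids}


short = pack_for_bert(["a", "charming", "film"], seq_len=8)
print(short["tokens"])      # ['[CLS]', 'a', 'charming', 'film', '[SEP]', '[PAD]', '[PAD]', '[PAD]']
print(short["input_mask"])  # [1, 1, 1, 1, 1, 0, 0, 0]

long_input = pack_for_bert(["word"] * 300)  # Default seq_len=128 truncates the text.
print(sum(long_input["input_mask"]) - 2)    # 126
```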
Benchmarks
Below is a summary of our benchmarking results for the average word-embedding and BERT-based classifiers featured in this tutorial. To optimize model performance for your use case, it's worthwhile to experiment with different model and training parameters in order to obtain the highest test accuracy. Refer to the TextClassifierOptions section for more information on customizing these parameters.
Last updated 2026-06-05 UTC.