Retrain a speech recognition model with TensorFlow Lite Model Maker
Copyright 2024 The AI Edge Authors.
Licensed under the Apache License, Version 2.0 (the "License");
In this Colab notebook, you'll learn how to use the TensorFlow Lite Model Maker to train a speech recognition model that can classify spoken words or short phrases using one-second sound samples. The Model Maker library uses transfer learning to retrain an existing TensorFlow model with a new dataset, which reduces the amount of sample data and time required for training.
By default, this notebook retrains the model (BrowserFft, from the TFJS Speech Command Recognizer) using a subset of words from the speech commands dataset (such as "up," "down," "left," and "right"). Then it exports a TFLite model that you can run on a mobile device or embedded system (such as a Raspberry Pi). It also exports the trained model as a TensorFlow SavedModel.
This notebook is also designed to accept a custom dataset of WAV files, uploaded to Colab in a ZIP file. The more samples you have for each class, the better your accuracy will be, but because the transfer learning process uses feature embeddings from the pre-trained model, you can still get a fairly accurate model with only a few dozen samples in each of your classes.
If you want to run the notebook with the default speech dataset, you can run the whole thing now by clicking Runtime > Run all in the Colab toolbar. However, if you want to use your own dataset, continue down to Prepare the dataset and follow the instructions there.
Import the required packages
You'll need TensorFlow, TFLite Model Maker, and some modules for audio manipulation, playback, and visualizations.
To train with the default speech dataset, just run all the code below as-is.
But if you want to train with your own speech dataset, follow these steps:
Be sure each sample in your dataset is in WAV file format and about one second long. Then create a ZIP file with all your WAV files, organized into separate subfolders for each classification. For example, each sample for a speech command "yes" should be in a subfolder named "yes". Even if you have only one class, the samples must be saved in a subdirectory with the class name as the directory name. (This script assumes your dataset is not split into train/validation/test sets and performs that split for you.)
Click the Files tab in the left panel and drag and drop your ZIP file there to upload it.
Use the following drop-down option to set use_custom_dataset to True.
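Before zipping a custom dataset, it can help to confirm the folder layout described above. The following is a minimal standard-library sketch. `summarize_wav_layout` is a hypothetical helper, not part of Model Maker. It counts `.wav` files in each class subfolder and lists any WAV files that sit at the top level or deeper than one folder. The suffix match is case-sensitive, like the `*.wav` glob patterns this notebook uses.

```python
import os
from collections import Counter


def summarize_wav_layout(root):
    """Count WAV files per class folder and list WAV files outside that layout.

    Hypothetical helper. Expected layout: root/<class_name>/<sample>.wav
    """
    per_class = Counter()
    misplaced = []
    for dirpath, _dirnames, filenames in os.walk(root):
        rel = os.path.relpath(dirpath, root)
        depth = 0 if rel == os.curdir else len(rel.split(os.sep))
        for name in filenames:
            if not name.endswith('.wav'):
                continue  # ignore non-WAV files such as notes or metadata
            if depth == 1:
                per_class[rel] += 1  # directly inside a class folder
            else:
                misplaced.append(os.path.normpath(os.path.join(rel, name)))
    return dict(per_class), sorted(misplaced)


# Example usage (the path is a placeholder):
# counts, misplaced = summarize_wav_layout('./my-dataset')
# print(counts, misplaced)  # misplaced should be an empty list
```

An empty `misplaced` list means every sample sits directly inside a class-named folder. That is the layout the instructions above ask for.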
Whether you're using the default speech dataset or a custom dataset, you should have a good set of background noises so your model can distinguish speech from other noises (including silence).
Because the following background samples are provided in WAV files that are a minute long or longer, we need to split them up into smaller one-second samples so we can reserve some for our test dataset. We'll also combine a couple of different sample sources to build a comprehensive set of background noises and silence:
import glob
import os

import librosa
import soundfile as sf

# Create a list of all the background wav files
files = glob.glob(os.path.join('./dataset-speech/_background_noise_', '*.wav'))
files = files + glob.glob(os.path.join('./dataset-background', '*.wav'))

background_dir = './background'
os.makedirs(background_dir, exist_ok=True)

# Loop through all files and split each into several one-second wav files
for file in files:
  filename = os.path.basename(os.path.normpath(file))
  print('Splitting', filename)
  name = os.path.splitext(filename)[0]
  rate = librosa.get_samplerate(file)
  length = round(librosa.get_duration(filename=file))
  for i in range(length - 1):
    start = i * rate
    stop = (i * rate) + rate
    data, _ = sf.read(file, start=start, stop=stop)
    sf.write(os.path.join(background_dir, name + str(i) + '.wav'), data, rate)
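The splitting loop derives the chunk count from the rounded duration and skips the last chunk (`range(length - 1)`). This appears to keep every chunk a full second even when the duration was rounded up, at the cost of sometimes discarding a complete final second.

As an alternative (not the notebook's own approach), you can compute the chunk boundaries directly from the file's frame count. This keeps every complete one-second window and drops only a partial tail. `one_second_windows` is a hypothetical helper. `soundfile.info(file).frames` is one way to get the frame count, and each `(start, stop)` pair can be passed to `sf.read(file, start=start, stop=stop)`.

```python
def one_second_windows(num_frames, rate):
    """Return (start, stop) frame indices for each complete one-second chunk.

    Hypothetical helper: num_frames is the file's total frame count and
    rate its sample rate; a trailing chunk shorter than one second is dropped.
    """
    return [(i * rate, (i + 1) * rate) for i in range(num_frames // rate)]


# A 3.5-second file at 16 kHz yields three full chunks:
print(one_second_windows(56000, 16000))
# [(0, 16000), (16000, 32000), (32000, 48000)]
```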
Prepare the speech commands dataset
We already downloaded the speech commands dataset, so now we just need to prune the number of classes for our model.
This dataset includes over 30 speech command classifications, and most of them have over 2,000 samples. But because we're using transfer learning, we don't need that many samples. So the following code does a few things:
Specify which classifications we want to use, and delete the rest.
Keep only 150 samples of each class for training (to prove that transfer learning works well with smaller datasets and simply to reduce the training time).
Create a separate directory for a test dataset so we can easily run inference with them later.
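The three steps above can be sketched with the standard library alone. This is a hedged illustration, not the notebook's own code. The following are assumptions:

* The helper name `prune_and_split` and its parameters.
* The choice to take test samples from the files beyond the first 150, so that 150 remain for training.
* The 20% test ratio, borrowed from the custom-dataset split later in this notebook.

If you want to keep a `background` class, move the split background samples into the dataset folder before pruning.

```python
import glob
import os
import random
import shutil


def prune_and_split(dataset_dir, test_dir, keep_classes,
                    sample_count=150, test_ratio=0.2, seed=42):
    """Hypothetical helper: prune classes, cap samples, and carve out a test set.

    For each class folder in dataset_dir:
      * folders not named in keep_classes are deleted;
      * the first sample_count shuffled files stay for training/validation;
      * the next round(sample_count * test_ratio) files move to test_dir;
      * any remaining files are deleted.
    """
    test_count = round(sample_count * test_ratio)
    for class_path in glob.glob(os.path.join(dataset_dir, '*/')):
        class_name = os.path.basename(os.path.normpath(class_path))
        if class_name not in keep_classes:
            shutil.rmtree(class_path)
            continue
        files = sorted(glob.glob(os.path.join(class_path, '*.wav')))
        random.Random(seed).shuffle(files)  # reproducible shuffle per class
        test_files = files[sample_count:sample_count + test_count]
        if test_files:
            os.makedirs(os.path.join(test_dir, class_name), exist_ok=True)
        for path in test_files:
            shutil.move(path, os.path.join(test_dir, class_name, os.path.basename(path)))
        for path in files[sample_count + test_count:]:
            os.remove(path)


# Example usage (the class list is illustrative):
# prune_and_split('./dataset-speech', './dataset-test',
#                 keep_classes={'up', 'down', 'left', 'right', 'background'})
```

Classes with fewer than 150 files keep all of them and contribute no test samples under this scheme.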
If you want to train the model with your own speech dataset, you need to upload your samples as WAV files in a ZIP (as described above) and modify the following variables to specify your dataset:
if use_custom_dataset:
  # Specify the ZIP file you uploaded:
  !unzip YOUR-FILENAME.zip
  # Specify the unzipped path to your custom dataset
  # (this path contains all the subfolders with classification names):
  dataset_dir = './YOUR-DIRNAME'
After changing the filename and path name above, you're ready to train the model with your custom dataset. In the Colab toolbar, select Runtime > Run all to run the whole notebook.
The following code integrates our new background noise samples into your dataset and then separates a portion of all samples to create a test set.
import glob
import os
import random
import shutil

def move_background_dataset(dataset_dir):
  dest_dir = os.path.join(dataset_dir, 'background')
  if os.path.exists(dest_dir):
    files = glob.glob(os.path.join(background_dir, '*.wav'))
    for file in files:
      shutil.move(file, dest_dir)
  else:
    shutil.move(background_dir, dest_dir)

if use_custom_dataset:
  # Move background samples into custom dataset
  move_background_dataset(dataset_dir)

  # Now we separate some of the files that we'll use for testing:
  test_dir = './dataset-test'
  test_data_ratio = 0.2
  dirs = glob.glob(os.path.join(dataset_dir, '*/'))
  for dir in dirs:
    # Compute the class name up front so the print below works even if
    # no files are moved for this class.
    class_dir = os.path.basename(os.path.normpath(dir))
    files = glob.glob(os.path.join(dir, '*.wav'))
    test_count = round(len(files) * test_data_ratio)
    random.seed(42)
    random.shuffle(files)
    # Move test samples:
    for file in files[:test_count]:
      os.makedirs(os.path.join(test_dir, class_dir), exist_ok=True)
      os.rename(file, os.path.join(test_dir, class_dir, os.path.basename(file)))
    print('Moved', test_count, 'samples from', class_dir)
Play a sample
To be sure the dataset looks correct, let's play a random sample from the test set:
When using Model Maker to retrain any model, you have to start by defining a model spec. The spec defines the base model from which your new model will extract feature embeddings to begin learning new classes. The spec for this speech recognizer is based on the pre-trained BrowserFft model from TFJS.
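One way to run this check is sketched below. It picks a random WAV file from a `<test_dir>/<class>/` layout and reads its header with Python's built-in `wave` module, which handles uncompressed PCM WAV files only. `describe_random_sample` is a hypothetical helper. In Colab, you could then listen to the chosen file with `IPython.display.Audio(filename=path)`.

```python
import glob
import os
import random
import wave


def describe_random_sample(samples_dir, seed=None):
    """Pick a random WAV file under samples_dir/<class>/ and read its header.

    Hypothetical helper; the stdlib `wave` module handles uncompressed PCM WAV only.
    """
    files = glob.glob(os.path.join(samples_dir, '*', '*.wav'))
    if not files:
        raise FileNotFoundError(f'No WAV files found under {samples_dir}')
    path = random.Random(seed).choice(sorted(files))
    with wave.open(path, 'rb') as wav:
        rate = wav.getframerate()
        frames = wav.getnframes()
    return {
        'class': os.path.basename(os.path.dirname(path)),
        'path': path,
        'sample_rate': rate,
        'num_frames': frames,
        'duration_s': frames / rate,
    }


# Example usage (test_dir is the test-set folder created earlier):
# info = describe_random_sample(test_dir)
# print(info)
# In Colab: from IPython.display import Audio; Audio(filename=info['path'])
```

A duration close to one second and a class name that matches what you hear are quick signs that the split worked.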
The model expects input as an audio sample that's 44.1 kHz, and just under a second long: the exact sample length must be 44034 frames.
You don't need to do any resampling with your training dataset. Model Maker takes care of that for you. But when you later run inference, you must be sure that your input matches that expected format.
All you need to do here is instantiate the BrowserFftSpec:
spec = audio_classifier.BrowserFftSpec()
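The model expects exactly 44,034 frames (about 0.9985 seconds at 44.1 kHz), so audio used for inference has to be fitted to that length. The sketch below zero-pads short clips, truncates long ones, and adds a leading batch axis. It makes these assumptions:

* The audio has already been resampled to 44.1 kHz, for example with `librosa.load(path, sr=44100)`.
* The model takes a float32 input of shape `(1, 44034)`.

Check the interpreter's input details for your exported model. `fit_to_model_input` is a hypothetical helper.

```python
import numpy as np

# Input length stated for the BrowserFft spec: 44,034 frames at 44.1 kHz.
EXPECTED_FRAMES = 44034


def fit_to_model_input(audio, input_size=EXPECTED_FRAMES):
    """Zero-pad or truncate a mono waveform and add a batch axis.

    Assumes `audio` is already sampled at 44.1 kHz; no resampling happens here.
    Returns a float32 array of shape (1, input_size).
    """
    audio = np.asarray(audio, dtype=np.float32).reshape(-1)
    if audio.shape[0] < input_size:
        audio = np.pad(audio, (0, input_size - audio.shape[0]))  # zeros at the end
    else:
        audio = audio[:input_size]  # keep the first input_size frames
    return np.expand_dims(audio, axis=0)


print(fit_to_model_input(np.ones(30000)).shape)  # (1, 44034)
```

Padding at the end and truncating from the front are simple defaults. For clips where the spoken word is off-center, you may prefer to choose the kept window differently.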
Load your dataset
Now you need to load your dataset according to the model specifications. Model Maker includes the DataLoader API, which will load your dataset from a folder and ensure it's in the expected format for the model spec.
We already reserved some test files by moving them to a separate directory, which makes it easier to run inference with them later. Now we'll create a DataLoader for each split: the training set, the validation set, and the test set.
Now we'll use the Model Maker create() function to create a model based on our model spec and training dataset, and begin training.
If you're using a custom dataset, you might want to change the batch size as appropriate for the number of samples in your train set.
# If your dataset has fewer than 100 samples per class,
# you might want to try a smaller batch size
batch_size = 25
epochs = 25
model = audio_classifier.create(train_data, spec, validation_data, batch_size=batch_size, epochs=epochs)
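Batch size interacts with dataset size through the number of weight updates per epoch. With few training samples, a large batch leaves only a handful of updates each epoch.

The following plain-Python sketch counts those updates. It assumes the last partial batch is still used, which is not covered here for Model Maker. `steps_per_epoch` and the sample counts are illustrative.

```python
import math


def steps_per_epoch(num_train_samples, batch_size):
    """Batches (weight updates) per epoch, assuming a final partial batch still counts."""
    if batch_size <= 0:
        raise ValueError('batch_size must be positive')
    return math.ceil(num_train_samples / batch_size)


# Hypothetical custom dataset: 4 classes x 40 training samples = 160 samples.
print(steps_per_epoch(160, 25))  # 7
print(steps_per_epoch(160, 8))   # 20
```

Over 25 epochs, that is roughly 175 versus 500 updates, which is one reason a smaller batch size can help small datasets.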
Review the model performance
Even if the accuracy/loss looks good from the training output above, it's important to also run the model using test data that the model has not seen yet, which is what the evaluate() method does here:
model.evaluate(test_data)
View the confusion matrix
When training a classification model such as this one, it's also useful to inspect the confusion matrix. The confusion matrix gives you a detailed view of how well your classifier performs on each class in your test data.
import matplotlib.pyplot as plt
import seaborn as sns

def show_confusion_matrix(confusion, test_labels):
  """Normalize the confusion matrix by row (true label) and plot it as a heatmap."""
  # keepdims=True divides each row by its own total; without it, NumPy
  # broadcasting would divide each column by a row total instead.
  confusion_normalized = confusion.astype("float") / confusion.sum(axis=1, keepdims=True)
  sns.set(rc={'figure.figsize': (6, 6)})
  sns.heatmap(
      confusion_normalized, xticklabels=test_labels, yticklabels=test_labels,
      cmap='Blues', annot=True, fmt='.2f', square=True, cbar=False)
  plt.title("Confusion matrix")
  plt.ylabel("True label")
  plt.xlabel("Predicted label")

confusion_matrix = model.confusion_matrix(test_data)
show_confusion_matrix(confusion_matrix.numpy(), test_data.index_to_label)
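Row normalization divides each row by the number of test samples with that true label. Each row then sums to 1, and each diagonal entry reads as that class's recall.

In NumPy this requires keeping the row totals as a column (`keepdims=True`). Dividing by `confusion.sum(axis=1)` alone broadcasts across columns, so each column is divided by a row total instead. The following minimal NumPy sketch uses `normalize_rows`, a hypothetical helper that also leaves all-zero rows at zero.

```python
import numpy as np


def normalize_rows(confusion):
    """Divide each row of a confusion matrix by its total; all-zero rows stay zero."""
    confusion = np.asarray(confusion, dtype=float)
    row_sums = confusion.sum(axis=1, keepdims=True)  # shape (n, 1)
    return np.divide(confusion, row_sums,
                     out=np.zeros_like(confusion), where=row_sums != 0)


cm = np.array([[3, 1],
               [2, 6]])
print(normalize_rows(cm))
# [[0.75 0.25]
#  [0.25 0.75]]

# For comparison: cm / cm.sum(axis=1) divides column j by the total of row j,
# so here the rows sum to 0.875 and 1.25 instead of 1.0.
wrong = cm / cm.sum(axis=1)
```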
Export the model
The last step is exporting your model into the TensorFlow Lite format for execution on mobile/embedded devices and into the SavedModel format for execution elsewhere.
When exporting a .tflite file from Model Maker, it includes model metadata that describes various details that can later help during inference. It even includes a copy of the classification labels file, so you don't need a separate labels.txt file. (In the next section, we show how to use this metadata to run inference.)
import tflite_model_maker as mm

# SAVE_PATH (output directory) and TFLITE_FILENAME must already be defined.
print(f'Exporting the model to {SAVE_PATH}')
model.export(SAVE_PATH, tflite_filename=TFLITE_FILENAME)
model.export(SAVE_PATH, export_format=[mm.ExportFormat.SAVED_MODEL, mm.ExportFormat.LABEL])
Run inference with TF Lite model
Now your TFLite model can be deployed and run using any of the supported inferencing libraries or with the new TFLite AudioClassifier Task API. The following code shows how you can run inference with the .tflite model in Python.
# This library provides the TFLite metadata API
!pip install -q tflite_support
from tflite_support import metadata
import json

def get_labels(model):
  """Returns a list of labels, extracted from the model metadata."""
  displayer = metadata.MetadataDisplayer.with_model_file(model)
  labels_file = displayer.get_packed_associated_file_list()[0]
  labels = displayer.get_associated_file_buffer(labels_file).decode()
  return [line for line in labels.split('\n')]

def get_input_sample_rate(model):
  """Returns the model's expected sample rate, from the model metadata."""
  displayer = metadata.MetadataDisplayer.with_model_file(model)
  metadata_json = json.loads(displayer.get_metadata_json())
  input_tensor_metadata = metadata_json['subgraph_metadata'][0]['input_tensor_metadata'][0]
  input_content_props = input_tensor_metadata['content']['content_properties']
  return input_content_props['sample_rate']
To observe how well the model performs with real samples, run the following code block over and over. Each time, it will fetch a new test sample and run inference with it, and you can listen to the audio sample below.
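After the interpreter runs, the prediction is the index with the highest score in the output tensor, mapped through the label list from `get_labels()`. The following minimal NumPy sketch covers that post-processing step and makes these assumptions:

* The output for one clip has shape `(1, num_classes)`, as returned by `tf.lite.Interpreter.get_tensor()`.
* The label order matches the output order.

`top_prediction` is a hypothetical helper.

```python
import numpy as np


def top_prediction(scores, labels):
    """Return (label, score) for the highest-scoring class of one clip's output."""
    scores = np.asarray(scores).reshape(-1)  # (1, num_classes) -> (num_classes,)
    if len(labels) < scores.shape[0]:
        raise ValueError('Fewer labels than output scores')
    top_index = int(np.argmax(scores))
    return labels[top_index], float(scores[top_index])


print(top_prediction(np.array([[0.1, 0.7, 0.2]]), ['down', 'up', 'left']))
# ('up', 0.7)
```

The label list is only required to be at least as long as the score vector. A trailing empty entry (which `labels.split('\n')` produces when the labels file ends with a newline) does not cause an error.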
Now you can deploy the TF Lite model to your mobile or embedded device. You don't need to download the labels file because you can instead retrieve the labels from the .tflite file metadata, as shown in the previous inferencing example.
Last updated 2026-05-28 UTC.