Text Searcher with TensorFlow Lite Model Maker

Licensed under the Apache License, Version 2.0 (the "License");

# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
 

In this Colab notebook, you can learn how to use the TensorFlow Lite Model Maker library to create a TFLite Searcher model. You can use a text Searcher model to build Semantic Search or Smart Reply for your app. This type of model lets you take a text query and search for the most related entries in a text dataset, such as a database of web pages. The model returns a list of the entries with the smallest distance scores in the dataset, including metadata you specify, such as URL, page title, or other text entry identifiers. After building the model, you can deploy it onto devices (e.g. Android) using the Task Library Searcher API to run inference with just a few lines of code.

This tutorial uses the CNN/DailyMail dataset as an example to create the TFLite Searcher model. You can try it with your own dataset, provided it is in a compatible comma-separated values (CSV) format.
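As a rough illustration of what a compatible CSV might look like for your own data, the sketch below writes a two-column file with a header row: one column for the text to embed and one for the metadata to return. The column names `text` and `metadata`, the file name `my_dataset.csv`, and the example rows are all hypothetical; any column names should work as long as you pass the same names when loading the file.

```python
import csv

# Hypothetical entries: each row pairs the text to embed with the metadata
# that should be returned when that text matches a query.
rows = [
    {"text": "How to reset a forgotten password.",
     "metadata": "https://example.com/help/password"},
    {"text": "Steps to update your billing address.",
     "metadata": "https://example.com/help/billing"},
]


def write_searcher_csv(path, rows, fieldnames=("text", "metadata")):
    """Writes the rows to a CSV file that starts with a header line."""
    with open(path, "w", newline="", encoding="utf-8") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=list(fieldnames))
        writer.writeheader()
        writer.writerows(rows)


write_searcher_csv("my_dataset.csv", rows)
```

Using `csv.DictWriter` rather than joining strings by hand means commas and quotes inside the text are escaped correctly.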

Text search using Scalable Nearest Neighbor

This tutorial uses the publicly available CNN/DailyMail non-anonymized summarization dataset, which was produced from the GitHub repo. This dataset contains over 300k news articles, which makes it a good dataset for building the Searcher model and returning a variety of related news articles for a text query during model inference.

The text Searcher model in this example uses a ScaNN (Scalable Nearest Neighbors) index file that can search for similar items from a predefined database. ScaNN achieves state-of-the-art performance for efficient vector similarity search at scale.
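To make the idea of similarity search concrete, here is a minimal brute-force sketch in plain Python. It is not ScaNN and does not use its index format; it only shows the exact search that an approximate index such as ScaNN is designed to speed up. The two-dimensional vectors and the example.com URLs are made up for illustration, and the distance is the negative dot product, so smaller values mean closer matches.

```python
def negative_dot(a, b):
    """Negative dot product, so that a smaller value means a closer match."""
    return -sum(x * y for x, y in zip(a, b))


def brute_force_search(query, database, k=2):
    """Returns the k (distance, metadata) pairs closest to the query."""
    scored = [(negative_dot(query, vector), metadata)
              for vector, metadata in database]
    return sorted(scored, key=lambda pair: pair[0])[:k]


# Hypothetical unit-length embeddings paired with the metadata to return.
database = [
    ([1.0, 0.0], "https://example.com/a"),
    ([0.0, 1.0], "https://example.com/b"),
    ([0.6, 0.8], "https://example.com/c"),
]

results = brute_force_search([0.8, 0.6], database, k=2)
for distance, url in results:
    print(f"{distance:.2f} {url}")
```

A brute-force scan compares the query against every entry, which is why large datasets benefit from an index that searches only part of the data.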

Highlights and urls in this dataset are used in this colab to create the model:

  1. Highlights are the text used to generate the embedding feature vectors, which are then used for search.
  2. Urls are the results shown to users after searching for the related highlights.

This tutorial saves this data into a CSV file and then uses the CSV file to build the model. Here are several examples from the dataset.

Highlights | Urls
Hawaiian Airlines again lands at No. 1 in on-time performance. The Airline Quality Rankings Report looks at the 14 largest U.S. airlines. ExpressJet and American Airlines had the worst on-time performance. Virgin America had the best baggage handling; Southwest had lowest complaint rate. | http://www.cnn.com/2013/04/08/travel/airline-quality-report
European football's governing body reveals list of countries bidding to host 2020 finals. The 60th anniversary edition of the finals will be hosted by 13 countries. Thirty-two countries are considering bids to host 2020 matches. UEFA will announce host cities on September 25. | http://edition.cnn.com:80/2013/09/20/sport/football/football-euro-2020-bid-countries/index.html?
Once octopus-hunter Dylan Mayer has now also signed a petition of 5,000 divers banning their hunt at Seacrest Park. Decision by Washington Department of Fish and Wildlife could take months. | http://www.dailymail.co.uk:80/news/article-2238423/Dylan-Mayer-Washington-considers-ban-Octopus-hunting-diver-caught-ate-Puget-Sound.html?
Galaxy was observed 420 million years after the Big Bang. found by NASA’s Hubble Space Telescope, Spitzer Space Telescope, and one of nature’s own natural 'zoom lenses' in space. | http://www.dailymail.co.uk/sciencetech/article-2233883/The-furthest-object-seen-Record-breaking-image-shows-galaxy-13-3-BILLION-light-years-Earth.html

Setup

Start by installing the required packages, including the Model Maker package from the GitHub repo.

    sudo apt -y install libportaudio2
    pip install -q tflite-model-maker
    pip install gdown

Import the required packages.

    from tflite_model_maker import searcher
 

Prepare the dataset

This tutorial uses the CNN / Daily Mail summarization dataset from the GitHub repo.

First, download the text and urls of CNN and Daily Mail and unzip them. If the download from Google Drive fails, wait a few minutes and try again, or download the files manually and upload them to the colab.

    gdown "https://drive.google.com/uc?id=0BwmD_VLjROrfTHk4NFg2SndKcjQ"
    gdown "https://drive.google.com/uc?id=0BwmD_VLjROrfM1BxdkxVaTY2bWs"

    wget -O all_train.txt https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt
    tar xzf cnn_stories.tgz
    tar xzf dailymail_stories.tgz

Then, save the data into a CSV file that can be loaded into the tflite_model_maker library. The code is based on the logic used to load this data in tensorflow_datasets. We can't use tensorflow_datasets directly since it doesn't contain the urls used in this colab.

Since it takes a long time to process the whole dataset into embedding feature vectors, only the first 5% of the CNN and Daily Mail stories are selected by default for demo purposes. You can adjust the fraction, or try searching with the pre-built TFLite model built from 50% of the CNN and Daily Mail stories.

Save the highlights and urls to the CSV file

    CNN_FRACTION = 0.05
    DAILYMAIL_FRACTION = 0.05

    import csv
    import hashlib
    import os
    import tensorflow as tf

    dm_single_close_quote = u"\u2019"  # unicode
    dm_double_close_quote = u"\u201d"
    END_TOKENS = [
        ".", "!", "?", "...", "'", "`", '"', dm_single_close_quote,
        dm_double_close_quote, ")"
    ]  # acceptable ways to end a sentence


    def read_file(file_path):
        """Reads lines in the file."""
        lines = []
        with tf.io.gfile.GFile(file_path, "r") as f:
            for line in f:
                lines.append(line.strip())
        return lines


    def url_hash(url):
        """Gets the hash value of the url."""
        h = hashlib.sha1()
        url = url.encode("utf-8")
        h.update(url)
        return h.hexdigest()


    def get_url_hashes_dict(urls_path):
        """Gets hashes dict that maps the hash value to the original url in file."""
        urls = read_file(urls_path)
        return {url_hash(url): url[url.find("id_/") + 4:] for url in urls}


    def find_files(folder, url_dict):
        """Finds files corresponding to the urls in the folder."""
        all_files = tf.io.gfile.listdir(folder)
        ret_files = []
        for file in all_files:
            # Gets the file name without extension.
            filename = os.path.splitext(os.path.basename(file))[0]
            if filename in url_dict:
                ret_files.append(os.path.join(folder, file))
        return ret_files


    def fix_missing_period(line):
        """Adds a period to a line that is missing a period."""
        if "@highlight" in line:
            return line
        if not line:
            return line
        if line[-1] in END_TOKENS:
            return line
        return line + "."


    def get_highlights(story_file):
        """Gets highlights from a story file path."""
        lines = read_file(story_file)

        # Put periods on the ends of lines that are missing them
        # (this is a problem in the dataset because many image captions don't end in
        # periods; consequently they end up in the body of the article as run-on
        # sentences)
        lines = [fix_missing_period(line) for line in lines]

        # Separate out article and abstract sentences
        highlight_list = []
        next_is_highlight = False
        for line in lines:
            if not line:
                continue  # empty line
            elif line.startswith("@highlight"):
                next_is_highlight = True
            elif next_is_highlight:
                highlight_list.append(line)

        # Make highlights into a single string.
        highlights = "\n".join(highlight_list)

        return highlights


    url_hashes_dict = get_url_hashes_dict("all_train.txt")
    cnn_files = find_files("cnn/stories", url_hashes_dict)
    dailymail_files = find_files("dailymail/stories", url_hashes_dict)

    # The size to be selected.
    cnn_size = int(CNN_FRACTION * len(cnn_files))
    dailymail_size = int(DAILYMAIL_FRACTION * len(dailymail_files))
    print("CNN size: %d" % cnn_size)
    print("Daily Mail size: %d" % dailymail_size)

    with open("cnn_dailymail.csv", "w") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=["highlights", "urls"])
        writer.writeheader()

        for file in cnn_files[:cnn_size] + dailymail_files[:dailymail_size]:
            highlights = get_highlights(file)
            # Gets the filename which is the hash value of the url.
            filename = os.path.splitext(os.path.basename(file))[0]
            url = url_hashes_dict[filename]
            writer.writerow({"highlights": highlights, "urls": url})
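The code above relies on each story file being named after the SHA-1 hash of its entry in all_train.txt, and on the original article URL following the "id_/" marker in that entry. The sketch below reproduces just that mapping in isolation. The archived URL is a hypothetical example of the expected format, not a line copied from all_train.txt.

```python
import hashlib


def url_hash(url):
    """Returns the SHA-1 hex digest of the url, used as the story file name."""
    return hashlib.sha1(url.encode("utf-8")).hexdigest()


def original_url(archived_url):
    """Strips everything up to and including the "id_/" marker."""
    return archived_url[archived_url.find("id_/") + 4:]


# Hypothetical entry in the style assumed for all_train.txt.
archived_url = ("http://web.archive.org/web/20130409000000id_/"
                "http://www.cnn.com/2013/04/08/travel/airline-quality-report")

story_name = url_hash(archived_url)  # file name without its extension
print(story_name)
print(original_url(archived_url))
```

Note that the hash is taken over the full archived entry, while only the part after "id_/" is stored as the metadata that the Searcher returns.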
 

Build the text Searcher model

Create a text Searcher model by loading a dataset, creating a model with the data and exporting the TFLite model.

Step 1. Load the dataset

Model Maker takes the text dataset and the corresponding metadata of each text string (such as urls in this example) in the CSV format. It embeds the text strings into feature vectors using the user-specified embedder model.

In this demo, we build the Searcher model using Universal Sentence Encoder, a state-of-the-art sentence embedding model which has already been retrained from colab. The model is optimized for on-device inference performance, and takes only 6 ms to embed a query string (measured on Pixel 6). Alternatively, you can use this quantized version, which is smaller but takes 38 ms for each embedding.

    wget -O universal_sentence_encoder.tflite https://storage.googleapis.com/download.tensorflow.org/models/tflite_support/searcher/text_to_image_blogpost/text_embedder.tflite

Create a searcher.TextDataLoader instance and use the data_loader.load_from_csv method to load the dataset. This step takes about 10 minutes since it generates the embedding feature vector for each text one by one. You can also upload your own CSV file and load it to build a customized model.

Specify the names of the text column and the metadata column in the CSV file.

  • Text is used to generate the embedding feature vectors.
  • Metadata is the content shown when a search matches the corresponding text.
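Because embedding the whole file takes several minutes, it can be worth confirming that the CSV actually contains the two columns you intend to name before starting. The helper below is a hypothetical convenience, not part of Model Maker; it reads only the header and counts rows using the standard csv module, and the sample data is made up.

```python
import csv
import io


def check_csv_columns(csvfile, text_column, metadata_column):
    """Returns the number of data rows, or raises if a column is missing."""
    reader = csv.DictReader(csvfile)
    header = reader.fieldnames or []
    missing = [c for c in (text_column, metadata_column) if c not in header]
    if missing:
        raise ValueError(f"Missing columns {missing}; found {header}")
    return sum(1 for _ in reader)


# A made-up in-memory sample; pass an opened file for a real CSV.
sample = io.StringIO(
    "highlights,urls\n"
    '"First summary. Second sentence.",http://example.com/a\n'
    '"Another summary.",http://example.com/b\n'
)
num_rows = check_csv_columns(sample, "highlights", "urls")
print(num_rows)
```

For a file on disk, open it with `open(path, newline="", encoding="utf-8")` and pass the file object in the same way.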

Here are the first 4 lines of the CNN-DailyMail CSV file generated above.

highlights | urls
Syrian official: Obama climbed to the top of the tree, doesn't know how to get down. Obama sends a letter to the heads of the House and Senate. Obama to seek congressional approval on military action against Syria. Aim is to determine whether CW were used, not by whom, says U.N. spokesman. | http://www.cnn.com/2013/08/31/world/meast/syria-civil-war/
Usain Bolt wins third gold of world championship. Anchors Jamaica to 4x100m relay victory. Eighth gold at the championships for Bolt. Jamaica double up in women's 4x100m relay. | http://edition.cnn.com/2013/08/18/sport/athletics-bolt-jamaica-gold
The employee in agency's Kansas City office is among hundreds of "virtual" workers. The employee's travel to and from the mainland U.S. last year cost more than $24,000. The telecommuting program, like all GSA practices, is under review. | http://www.cnn.com:80/2012/08/23/politics/gsa-hawaii-teleworking
NEW: A Canadian doctor says she was part of a team examining Harry Burkhart in 2010. NEW: Diagnosis: "autism, severe anxiety, post-traumatic stress disorder and depression" Burkhart is also suspected in a German arson probe, officials say. Prosecutors believe the German national set a string of fires in Los Angeles. | http://edition.cnn.com:80/2012/01/05/justice/california-arson/index.html?
    data_loader = searcher.TextDataLoader.create("universal_sentence_encoder.tflite", l2_normalize=True)
    data_loader.load_from_csv("cnn_dailymail.csv", text_column="highlights", metadata_column="urls")
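The tutorial does not spell out what l2_normalize=True does. Assuming it scales each embedding to unit length, which is the usual meaning of L2 normalization, the plain-Python sketch below shows why that pairs well with a dot-product distance: for unit-length vectors the dot product equals the cosine similarity. The vectors are made up, and the helper names are illustrative only.

```python
import math


def l2_normalize(vector):
    """Scales the vector to unit Euclidean length (zero vectors are unchanged)."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector] if norm else list(vector)


def dot(a, b):
    """Plain dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


a = l2_normalize([3.0, 4.0])  # becomes [0.6, 0.8]
b = l2_normalize([4.0, 3.0])  # becomes [0.8, 0.6]

# For unit vectors the dot product is the cosine of the angle between them.
similarity = dot(a, b)
print(round(similarity, 2))
```

With normalized embeddings, ranking by dot product no longer favors vectors simply because they have a larger magnitude.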
 

For image use cases, you can create a searcher.ImageDataLoader instance and then use data_loader.load_from_folder to load images from a folder. The searcher.ImageDataLoader instance needs to be created from a TFLite embedder model, because that model is used to encode queries into feature vectors and is exported with the TFLite Searcher model. For instance:

    data_loader = searcher.ImageDataLoader.create("mobilenet_v2_035_96_embedder_with_metadata.tflite")
    data_loader.load_from_folder("food/")
 

Step 2. Create the Searcher model

  • Configure ScaNN options. See the API doc for more details.
  • Create the Searcher model from the data and ScaNN options. See the in-depth examination to learn more about the ScaNN algorithm.
    scann_options = searcher.ScaNNOptions(
        distance_measure="dot_product",
        tree=searcher.Tree(num_leaves=140, num_leaves_to_search=4),
        score_ah=searcher.ScoreAH(dimensions_per_block=1, anisotropic_quantization_threshold=0.2))
    model = searcher.Searcher.create_from_data(data_loader, scann_options)
 

In the above example, we define the following options:

  • distance_measure : we use "dot_product" to measure the distance between two embedding vectors. Note that we actually compute the negative dot product value to preserve the notion that "smaller is closer".

  • tree : the dataset is divided into 140 partitions (roughly the square root of the data size), and 4 of them are searched during retrieval, which is roughly 3% of the dataset.

  • score_ah : we quantize the float embeddings to int8 values with the same dimension to save space.
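The arithmetic behind these options can be sketched in a few lines of plain Python. The partition helpers follow the rule of thumb stated above; the dataset size of 19,600 is a hypothetical figure chosen so that its square root is 140. The quantization function is a naive per-vector scalar scheme shown only to illustrate the float-to-int8 space saving. It is not ScaNN's anisotropic quantization, and all function names are illustrative.

```python
import math


def suggested_num_leaves(dataset_size):
    """Rule of thumb from the text: roughly the square root of the data size."""
    return round(math.sqrt(dataset_size))


def searched_fraction(num_leaves, num_leaves_to_search):
    """Fraction of partitions visited for each query."""
    return num_leaves_to_search / num_leaves


def quantize_int8(vector):
    """Naive scalar quantization of floats to integers in [-127, 127]."""
    scale = max(abs(x) for x in vector) / 127 or 1.0
    return [round(x / scale) for x in vector], scale


def dequantize(quantized, scale):
    """Approximately recovers the original floats."""
    return [q * scale for q in quantized]


print(suggested_num_leaves(19600))      # hypothetical dataset size
print(f"{searched_fraction(140, 4):.1%}")

embedding = [0.5, -1.0, 0.25]           # made-up embedding values
quantized, scale = quantize_int8(embedding)
restored = dequantize(quantized, scale)
```

Each int8 value takes one byte instead of the four bytes of a float32, at the cost of a rounding error of at most half the scale per dimension in this simple scheme.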

Step 3. Export the TFLite model

Then you can export the TFLite Searcher model.

    model.export(
        export_filename="searcher.tflite",
        userinfo="",
        export_format=searcher.ExportFormat.TFLITE)
 

Test the TFLite model on your query

You can test the exported TFLite model using custom query text. To query text using the Searcher model, initialize the model and run a search with a text phrase, as follows:

    from tflite_support.task import text

    # Initializes a TextSearcher object.
    searcher = text.TextSearcher.create_from_file("searcher.tflite")

    # Searches the input query.
    results = searcher.search("The Airline Quality Rankings Report looks at the 14 largest U.S. airlines.")
    print(results)
 

See the Task Library documentation for more information about how to integrate the model into various platforms.

Read more

For more information, please refer to:
