Apache Beam RunInference with TensorFlow and TensorFlow Hub

This notebook shows how to use the Apache Beam RunInference transform for TensorFlow with a trained model from TensorFlow Hub. Apache Beam includes built-in support for two TensorFlow model handlers: TFModelHandlerNumpy and TFModelHandlerTensor.

  • Use TFModelHandlerNumpy to run inference on models that expect a NumPy array as input.
  • Use TFModelHandlerTensor to run inference on models that expect a tensor as input.

For more information about using RunInference, see Get started with AI/ML pipelines in the Apache Beam documentation.

Before you begin

First, install tensorflow and tensorflow_hub, which the example imports. To use RunInference with the TensorFlow model handlers, install Apache Beam version 2.46 or later.

```shell
pip install tensorflow
pip install tensorflow_hub
pip install apache_beam==2.46.0
```
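Because the TensorFlow model handlers require Apache Beam 2.46 or later, it can help to fail fast on an older installation. The following is a minimal sketch using only the standard library; the helper names parse_major_minor and supports_tf_model_handler are hypothetical and not part of Apache Beam.

```python
from typing import Tuple

# Minimum (major, minor) version stated in this notebook.
MIN_BEAM_VERSION: Tuple[int, int] = (2, 46)


def parse_major_minor(version: str) -> Tuple[int, int]:
    """Return (major, minor) from a version string such as '2.46.0' or '2.47.0rc1'."""
    parts = version.split(".")
    major = int(parts[0])
    # Keep only the leading digits of the minor component, e.g. '46rc1' -> 46.
    minor_digits = ""
    for ch in parts[1]:
        if not ch.isdigit():
            break
        minor_digits += ch
    return major, int(minor_digits)


def supports_tf_model_handler(version: str) -> bool:
    """True if the version is 2.46 or later (hypothetical helper)."""
    return parse_major_minor(version) >= MIN_BEAM_VERSION


print(supports_tf_model_handler("2.46.0"))  # True
print(supports_tf_model_handler("2.45.0"))  # False
```

In the notebook, you could pass apache_beam.__version__ to supports_tf_model_handler after installing Apache Beam. Comparing (major, minor) tuples avoids the pitfall of comparing version strings lexically, where "2.100" would sort before "2.46".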

Use TensorFlow Hub's trained model URL

To use a trained model from TensorFlow Hub, pass the model URL to the model_uri parameter of the TensorFlow model handler (TFModelHandlerNumpy or TFModelHandlerTensor).

```python
import tensorflow as tf
import tensorflow_hub as hub
import apache_beam as beam

# URL of the trained model from TensorFlow Hub
CLASSIFIER_URL = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4"
```
 
```python
import numpy as np
import PIL.Image as Image

IMAGE_RES = 224

# Download the sample image and resize it to the 224x224 input size of the model.
img = tf.keras.utils.get_file(
    origin='https://storage.googleapis.com/apache-beam-samples/image_captioning/Cat-with-beanie.jpg')
img = Image.open(img).resize((IMAGE_RES, IMAGE_RES))
img
```
 
Downloading data from https://storage.googleapis.com/apache-beam-samples/image_captioning/Cat-with-beanie.jpg
1812110/1812110 [==============================] - 0s 0us/step


```python
# Convert the input image to the type and dimensions required by the model.
img = np.array(img) / 255.0
img_tensor = tf.cast(tf.convert_to_tensor(img[...]), dtype=tf.float32)
```
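The conversion step scales 8-bit pixel values into the [0, 1] range and casts them to 32-bit floats. The following NumPy-only sketch shows that arithmetic in isolation; the random array fake_image is a made-up stand-in for the resized 224x224 RGB image, and np.float32 plays the role of tf.float32.

```python
import numpy as np

IMAGE_RES = 224

# Random uint8 pixels stand in for the resized RGB image (height, width, channels).
rng = np.random.default_rng(0)
fake_image = rng.integers(0, 256, size=(IMAGE_RES, IMAGE_RES, 3), dtype=np.uint8)

# Dividing by 255.0 maps [0, 255] to [0.0, 1.0]; NumPy promotes the result to float64.
scaled = fake_image / 255.0

# Cast down to float32, matching the dtype used for img_tensor in the notebook.
scaled32 = scaled.astype(np.float32)

print(scaled32.shape, scaled32.dtype)  # (224, 224, 3) float32
print(scaled32.min() >= 0.0, scaled32.max() <= 1.0)  # True True
```

The shape is unchanged by scaling; only the value range and dtype differ.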
 
```python
from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerTensor
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from typing import Iterable

model_handler = TFModelHandlerTensor(model_uri=CLASSIFIER_URL)
```
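img_tensor has shape (224, 224, 3) with no batch dimension. This works because RunInference groups elements into batches and, assuming the default inference function of TFModelHandlerTensor in this Beam version, stacks each batch along a new leading axis with tf.stack before calling the model. The following sketch mirrors that stacking step with np.stack as a stand-in for tf.stack; the list of zero-filled images is made up for illustration.

```python
import numpy as np

IMAGE_RES = 224

# Three single images without a batch dimension, like the element passed to beam.Create.
images = [np.zeros((IMAGE_RES, IMAGE_RES, 3), dtype=np.float32) for _ in range(3)]

# Stacking adds the leading batch axis: (batch, height, width, channels).
batch = np.stack(images, axis=0)
print(batch.shape)  # (3, 224, 224, 3)

# With a single element, as in this notebook, the batch has size 1.
single = np.stack(images[:1], axis=0)
print(single.shape)  # (1, 224, 224, 3)
```

Each element's result is then handed back individually, which is why PostProcessor can call np.argmax on element.inference without specifying an axis.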
```python
class PostProcessor(beam.DoFn):
    """Process the PredictionResult to get the predicted label.
    Returns predicted label.
    """
    def setup(self):
        # Download the ImageNet label file once per DoFn instance.
        labels_path = tf.keras.utils.get_file(
            'ImageNetLabels.txt',
            'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')
        with open(labels_path) as f:
            self._imagenet_labels = np.array(f.read().splitlines())

    def process(self, element: PredictionResult) -> Iterable[str]:
        # The index of the highest score is the predicted class.
        predicted_class = np.argmax(element.inference)
        predicted_class_name = self._imagenet_labels[predicted_class]
        yield "Predicted Label: {}".format(predicted_class_name.title())
```
```python
with beam.Pipeline() as p:
    _ = (p
         | "Create PCollection" >> beam.Create([img_tensor])
         | "Perform inference" >> RunInference(model_handler)
         | "Post Processing" >> beam.ParDo(PostProcessor())
         | "Print" >> beam.Map(print))
```
 
Predicted Label: Tiger Cat
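The post-processing logic can be tested without Beam or TensorFlow. The following is a minimal sketch: PredictionResultStub is a hypothetical stand-in for apache_beam.ml.inference.base.PredictionResult (only the example and inference fields are mimicked), and the four labels and scores are made up rather than taken from the real ImageNet label file.

```python
from typing import Any, Iterable, NamedTuple, Sequence

import numpy as np


class PredictionResultStub(NamedTuple):
    """Hypothetical stand-in for Beam's PredictionResult."""
    example: Any
    inference: Any


def to_label(element: PredictionResultStub, labels: Sequence[str]) -> Iterable[str]:
    # argmax picks the index of the highest score; if the scores are logits,
    # this is also the class with the highest softmax probability.
    predicted_class = int(np.argmax(element.inference))
    yield "Predicted Label: {}".format(labels[predicted_class].title())


labels = ["background", "tabby", "tiger cat", "egyptian cat"]  # hypothetical subset
scores = np.array([0.1, 2.3, 5.7, 1.2], dtype=np.float32)      # hypothetical scores

for line in to_label(PredictionResultStub(example=None, inference=scores), labels):
    print(line)  # Predicted Label: Tiger Cat
```

Keeping the label lookup in a plain function like this makes it easy to unit-test separately from the DoFn's setup, which downloads the real label file.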