This notebook shows how to use the Apache Beam RunInference transform for TensorFlow with a trained model from TensorFlow Hub. Apache Beam includes built-in support for two TensorFlow model handlers: TFModelHandlerNumpy and TFModelHandlerTensor.
- Use TFModelHandlerNumpy to run inference on models that expect a NumPy array as an input.
- Use TFModelHandlerTensor to run inference on models that expect a tensor as an input.
For more information about using RunInference, see Get started with AI/ML pipelines in the Apache Beam documentation.
Before you begin
First, install TensorFlow and TensorFlow Hub. To use RunInference with the TensorFlow model handler, install Apache Beam version 2.46 or later.
pip install tensorflow tensorflow_hub
pip install "apache_beam[interactive]>=2.46.0"
Use TensorFlow Hub's trained model URL
To use a trained model from TensorFlow Hub, pass the model URL to the model_uri parameter of the TensorFlow model handler class (this example uses TFModelHandlerTensor).
import
tensorflow
as
tf
import
tensorflow_hub
as
hub
import
apache_beam
as
beam
# TensorFlow Hub location of the trained MobileNet V2 classification model;
# it is handed to the model handler as its model_uri.
CLASSIFIER_URL = (
    "https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4"
)
import
numpy
as
np
import
PIL.Image
as
Image
# Side length, in pixels, that the input image is resized to (square).
IMAGE_RES = 224

# Download the sample image; get_file returns the local path of the download.
image_path = tf.keras.utils.get_file(
    origin='https://storage.googleapis.com/apache-beam-samples/image_captioning/Cat-with-beanie.jpg'
)

# Open the image in a context manager so the underlying file handle is closed
# once the resized copy exists (Image.open otherwise keeps it open). Converting
# to RGB guarantees three channels even for grayscale or RGBA source images.
with Image.open(image_path) as raw_img:
    img = raw_img.convert('RGB').resize((IMAGE_RES, IMAGE_RES))

# Display the resized image as the notebook cell output.
img
Downloading data from https://storage.googleapis.com/apache-beam-samples/image_captioning/Cat-with-beanie.jpg 1812110/1812110 [==============================] - 0s 0us/step

# Convert the input image to the type required by the model: a float32 array
# whose pixel values are scaled into [0, 1] (assumes 8-bit pixels, 0-255).
# Building the array directly as float32 avoids a float64 intermediate and a
# second cast; no batch dimension is added here because RunInference batches
# the elements it receives.
img = np.asarray(img, dtype=np.float32) / 255.0
img_tensor = tf.convert_to_tensor(img)
from
apache_beam.ml.inference.tensorflow_inference
import
TFModelHandlerTensor
from
apache_beam.ml.inference.base
import
PredictionResult
from
apache_beam.ml.inference.base
import
RunInference
from
typing
import
Iterable
# RunInference model handler that loads the model found at CLASSIFIER_URL and
# accepts tf.Tensor inputs.
model_handler = TFModelHandlerTensor(
    model_uri=CLASSIFIER_URL,
)
class PostProcessor(beam.DoFn):
    """Process the PredictionResult to get the predicted label.

    ``setup`` fetches the ImageNet label list with ``tf.keras.utils.get_file``
    and stores it as a NumPy array of strings. ``process`` takes the index of
    the highest score in each ``PredictionResult.inference`` and yields a
    string of the form ``"Predicted Label: <Label>"`` (label title-cased).
    """

    def setup(self):
        labels_path = tf.keras.utils.get_file(
            'ImageNetLabels.txt',
            'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt'
        )
        # Read the labels inside a context manager so the file handle is
        # closed deterministically rather than left to garbage collection.
        with open(labels_path, encoding='utf-8') as labels_file:
            self._imagenet_labels = np.array(labels_file.read().splitlines())

    def process(self, element: PredictionResult) -> Iterable[str]:
        """Yield a human-readable label for one prediction.

        Args:
            element: A RunInference result; ``element.inference`` holds the
                model's scores for this input.

        Yields:
            ``"Predicted Label: <Label>"`` for the highest-scoring class.
        """
        # NOTE(review): assumes element.inference is a 1-D score vector aligned
        # with the label file; np.argmax flattens higher-rank input.
        predicted_class = np.argmax(element.inference)
        predicted_class_name = self._imagenet_labels[predicted_class]
        yield "Predicted Label: {}".format(predicted_class_name.title())
with beam.Pipeline() as p:
    # Feed the single preprocessed image tensor through the Hub model.
    predictions = (
        p
        | "Create PCollection" >> beam.Create([img_tensor])
        | "Perform inference" >> RunInference(model_handler)
    )
    # Map each PredictionResult to a label string and print it.
    _ = (
        predictions
        | "Post Processing" >> beam.ParDo(PostProcessor())
        | "Print" >> beam.Map(print)
    )
Predicted Label: Tiger Cat



