Run inference with a Gemma open model

Gemma is a family of lightweight, state-of-the-art open models built from the research and technology used to create the Gemini models. You can use Gemma models in your Apache Beam inference pipelines with the RunInference transform.

This notebook demonstrates how to load the preconfigured Gemma 2B model and then use it in your Apache Beam inference pipeline. The pipeline runs an example prompt through RunInference by using a custom model handler that loads the model from a Kaggle preset.

For more information about using RunInference, see Get started with AI/ML pipelines in the Apache Beam documentation.

Requirements

Serving and using Gemma models requires a substantial amount of RAM. To run this example, we recommend that you use a notebook instance with GPUs. At a minimum, use a machine that has the T4 GPU type. This configuration provides sufficient memory for running inference with a saved model.
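Before you load the model, you can confirm that a GPU is visible to the runtime. The following check is a minimal sketch that assumes a TensorFlow backend; on a T4 instance it should list at least one device:

    import tensorflow as tf

    # List the GPUs TensorFlow can see; expect at least one entry on a T4 instance.
    print(tf.config.list_physical_devices("GPU"))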

Before you begin

  • To use a fine-tuned version of the model, follow the steps in Gemma fine-tuning.
  • For testing this workflow, we recommend using the instruction-tuned model in your Apache Beam workflow. For example, if you use the Gemma 2B model in your pipeline, when you load the model, change the GemmaCausalLM.from_preset() argument from gemma_2b_en to gemma_instruct_2b_en, as shown in the sketch after this list. For more information, see Create a model in "Get started with Gemma using KerasNLP". For a list of models, see Gemma models.
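Switching presets is a one-line change when you load the model. The following is a minimal sketch, assuming keras_nlp is installed and Kaggle credentials are configured:

    import keras_nlp

    # Base English weights:
    # model = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")

    # Instruction-tuned weights, recommended for this workflow:
    model = keras_nlp.models.GemmaCausalLM.from_preset("gemma_instruct_2b_en")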

Install dependencies

To use the RunInference transform with the built-in TensorFlow model handler, install Apache Beam version 2.46.0 or later. The GemmaCausalLM model class is available in the Keras natural language processing (NLP) package, keras_nlp, version 0.8.0 and later.

    # Quote the version specifiers so that the shell does not treat ">" as redirection.
    !pip install -q -U protobuf
    !pip install -q -U apache_beam[gcp]
    !pip install -q -U "keras_nlp>=0.8.0"
    !pip install -q -U "keras>3"

    # To use the newly installed versions, restart the runtime.
    exit()
 

Authenticate with Kaggle

The pipeline defined here automatically pulls the model weights from Kaggle. First, accept the terms of use for Gemma models on the Keras Gemma page. Next, generate an API token by following the instructions in How to use Kaggle. Then provide your username and token when prompted.

    import kagglehub

    kagglehub.login()
 
Kaggle credentials set.
Kaggle credentials successfully validated.
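For non-interactive environments, kagglehub can also read credentials from environment variables instead of prompting for them. The following sketch assumes you replace the placeholder values with your own username and API token:

    import os

    # kagglehub picks these up in place of the interactive login() prompt.
    os.environ["KAGGLE_USERNAME"] = "your-kaggle-username"  # placeholder
    os.environ["KAGGLE_KEY"] = "your-kaggle-api-token"  # placeholder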

Import dependencies and provide a model preset

Use the following code to import dependencies.

Replace the value of the model_preset variable with the name of the Gemma preset to use. For example, to use the default English weights, set it to gemma_2b_en. This example uses the instruction-tuned preset gemma_instruct_2b_en. Optionally, to run the model at half precision and reduce GPU memory usage, set the Keras global floating-point type to bfloat16, as shown in the following code.

    import numpy as np

    import apache_beam as beam
    import keras_nlp
    import keras
    from apache_beam.ml.inference import utils
    from apache_beam.ml.inference.base import RunInference
    from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerNumpy
    from apache_beam.options.pipeline_options import PipelineOptions

    model_preset = "gemma_instruct_2b_en"

    # Optionally set the model to run at half precision
    # (recommended for smaller GPUs).
    keras.config.set_floatx("bfloat16")
 

Run the pipeline

To run the pipeline, use a custom model handler.

Provide a custom model handler

To simplify model loading, this notebook defines a custom model handler that loads the model by pulling the model weights directly from Kaggle presets. To customize the behavior of the handler, implement the load_model, validate_inference_args, and share_model_across_processes methods. The Keras implementation of the Gemma models has a generate method that generates text based on a prompt. To route the prompts properly, call this method in the run_inference method.

    # To load the model and perform the inference, define `GemmaModelHandler`.
    from apache_beam.ml.inference.base import ModelHandler
    from apache_beam.ml.inference.base import PredictionResult
    from typing import Any
    from typing import Dict
    from typing import Iterable
    from typing import Optional
    from typing import Sequence

    from keras_nlp.src.models.gemma.gemma_causal_lm import GemmaCausalLM


    class GemmaModelHandler(ModelHandler[str, PredictionResult, GemmaCausalLM]):
        def __init__(
            self,
            model_name: str = "gemma_2b_en",
        ):
            """Implementation of the ModelHandler interface for Gemma using text as input.

            Example Usage::

              pcoll | RunInference(GemmaModelHandler())

            Args:
              model_name: The Gemma model preset. Default is gemma_2b_en.
            """
            self._model_name = model_name
            self._env_vars = {}

        def share_model_across_processes(self) -> bool:
            # Gemma is large, so load a single copy per worker and share it
            # across processes instead of loading one copy per process.
            return True

        def load_model(self) -> GemmaCausalLM:
            """Loads and initializes a model for processing."""
            return keras_nlp.models.GemmaCausalLM.from_preset(self._model_name)

        def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
            """Validates the inference arguments. Only `max_length` is supported."""
            for key in (inference_args or {}):
                if key != "max_length":
                    raise ValueError(f"Invalid inference argument: {key}")

        def run_inference(
            self,
            batch: Sequence[str],
            model: GemmaCausalLM,
            inference_args: Optional[Dict[str, Any]] = None
        ) -> Iterable[PredictionResult]:
            """Runs inference on a batch of text strings.

            Args:
              batch: A sequence of examples as text strings.
              model: The Gemma model used to generate the responses.
              inference_args: Any additional arguments for an inference.

            Returns:
              An Iterable of type PredictionResult.
            """
            # Loop over each text string, and use a list to store the inference results.
            predictions = []
            for one_text in batch:
                result = model.generate(one_text, **(inference_args or {}))
                predictions.append(result)
            return utils._convert_to_result(batch, predictions, self._model_name)
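Before wiring the handler into a pipeline, you can exercise it directly. The following sanity check is a sketch, not part of the original notebook; the first from_preset() call downloads the weights, which can take several minutes:

    # Load the model once, then run a single prompt through the handler.
    handler = GemmaModelHandler(model_preset)
    gemma_model = handler.load_model()
    for prediction in handler.run_inference(["Hello, Gemma: "], gemma_model, {"max_length": 32}):
        print(prediction.inference)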
 

Execute the pipeline

Use the following code to run the pipeline. The code passes the Gemma preset name to the custom model handler. This cell can take a few minutes to run, because the model is downloaded and then loaded onto the worker. This delay is a one-time cost per worker.

The max_length argument caps the total length of the response from Gemma. Because the generated text echoes your input, the limit covers both your prompt and the output. For longer prompts, use a larger maximum length. Longer lengths require more time to generate.

    class FormatOutput(beam.DoFn):
        def process(self, element, *args, **kwargs):
            yield "Input: {input}, Output: {output}".format(
                input=element.example, output=element.inference)


    # Instantiate a NumPy array of string prompts for the model.
    examples = np.array(["Tell me the sentiment of the phrase 'I like pizza': "])

    # Specify the model handler, providing the Gemma preset name.
    model_handler = GemmaModelHandler(model_preset)

    with beam.Pipeline() as p:
        _ = (p | beam.Create(examples)  # Create a PCollection of the prompts.
               | RunInference(model_handler, inference_args={'max_length': 32})  # Send the prompts to the model and get responses.
               | beam.ParDo(FormatOutput())  # Format the output.
               | beam.Map(print)  # Print the formatted output.
            )
 
Input: Tell me the sentiment of the phrase 'I like pizza': , Output: Tell me the sentiment of the phrase 'I like pizza': 

The sentiment of the phrase "I like pizza" is positive. It expresses a personal
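As the output shows, the generated text echoes the prompt. If you want only the completion, you can strip the prompt in the formatting step. The following FormatOutputStripPrompt variant is a hypothetical sketch, not part of the original notebook:

    class FormatOutputStripPrompt(beam.DoFn):
        # Remove the echoed prompt from the generated text before printing.
        def process(self, element, *args, **kwargs):
            prompt = element.example
            generated = element.inference
            completion = generated[len(prompt):] if generated.startswith(prompt) else generated
            yield "Input: {input}, Output: {output}".format(input=prompt, output=completion)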