Use Gemma to gauge sentiment and summarize conversations

Gemma is a family of lightweight, state-of-the-art open models built from the research and technology used to create the Gemini models. You can use Gemma models in your Apache Beam inference pipelines.

Because large language models (LLMs) like Gemma are versatile, you can integrate them into business processes. The example in this notebook demonstrates how to use Gemma to gauge the sentiment of a conversation, summarize that conversation's content, and draft a reply for a difficult conversation. A person can review the reply before it's sent to the customer. For more information, see the blog post Gemma for Streaming ML with Dataflow.

A requirement of this work is that customers who express a negative sentiment receive a reply in near real-time. As a result, the workflow needs to use a streaming data pipeline with an LLM that has minimal latency.

Use case

An example use case is a bustling food chain grappling with analyzing and storing a high volume of customer support requests. Customer interactions include both chats generated by automated chatbots and nuanced conversations that require the attention of live support staff.

Requirements

To address both types of interactions, the workflow has the following requirements.

  • It needs to efficiently manage and store chat data by summarizing positive interactions for easy reference and future analysis.

  • It must detect and resolve issues in real time.

  • Sentiment analysis must identify dissatisfied customers and generate tailored responses to address their concerns.

Workflow

To meet these requirements, the pipeline processes completed chat messages in near real time. First, the pipeline uses Gemma to monitor the sentiment of the customer chats. All chats are then summarized, with positive or neutral sentiment chats sent directly to a data platform, BigQuery, by using the available Dataflow I/Os.

For chats that have a negative sentiment, the Gemma model crafts a contextually appropriate response for the customer. This response is sent to a human for review so that they can refine the message before it reaches the customer.

This example addresses important complexities of using an LLM within a pipeline. For example, processing the model's responses in code is challenging because the generated text is non-deterministic. In this example, the workflow asks the LLM to generate JSON responses, which is not its default output format. The workflow then needs to parse and validate each response, much as it would when processing data from sources that don't always produce correctly structured data.
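As a minimal sketch of what "parse and validate" can mean here, the following pure-Python function extracts the first JSON object from a free-text reply. It then checks the fields and value ranges that the prompt asks for: a sentiment of 1, 0, or -1, and a summary of at most 512 characters. The function name `parse_and_validate` and the specific checks are assumptions for illustration. They are not part of the pipeline code in this notebook.

```python
import json
import re

# Matches the first {...} span, non-greedily, across newlines.
_JSON_OBJECT = re.compile(r"\{[\s\S]*?\}")
ALLOWED_SENTIMENTS = {-1, 0, 1}
MAX_SUMMARY_CHARS = 512  # Mirrors the limit requested in the prompt.


def parse_and_validate(reply: str) -> dict:
    """Extracts the first JSON object from an LLM reply and validates it.

    Raises ValueError if no object is found, the JSON is invalid, a field is
    missing, or a value is out of range.
    """
    match = _JSON_OBJECT.search(reply)
    if match is None:
        raise ValueError("no JSON object in reply")
    try:
        data = json.loads(match.group(0))
    except json.JSONDecodeError as err:
        raise ValueError(f"invalid JSON: {err}") from err
    missing = {"sentiment", "summary"} - set(data)
    if missing:
        raise ValueError(f"missing fields: {sorted(missing)}")
    # Models sometimes emit numbers as strings, such as "-1"; normalize them.
    try:
        sentiment = int(data["sentiment"])
    except (TypeError, ValueError) as err:
        raise ValueError(f"non-numeric sentiment: {data['sentiment']!r}") from err
    if sentiment not in ALLOWED_SENTIMENTS:
        raise ValueError(f"sentiment out of range: {sentiment}")
    summary = str(data["summary"])
    if len(summary) > MAX_SUMMARY_CHARS:
        raise ValueError("summary too long")
    return {"sentiment": sentiment, "summary": summary}
```

A non-greedy pattern like this one stops at the first closing brace. A reply whose summary contains a `}` character therefore fails to parse and is rejected rather than silently misread.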

This workflow allows businesses to respond to customers faster and to provide personalized responses when needed.

  • The automation of positive chat summarization allows support staff to focus on more complex interactions.
  • The scalability of the system makes it possible to adapt to increasing chat volumes without compromising response quality.

You can also analyze the stored chat data in depth to inform business decisions.

The data processing pipeline


At a high level, the pipeline has the following steps:

  1. Read the chat data from Pub/Sub, the event messaging source. Each message is a JSON payload that contains the chat ID, the user ID, and the chat history.
  2. Pass the text from the messages to Gemma with a prompt.
  3. The pipeline requests that the model complete the following two tasks:
    • Attach a sentiment score to the message by using one of the following three values: 1 for a positive chat, 0 for a neutral chat, and -1 for a negative chat.
    • Provide a one-sentence summary of the chat.
  4. The pipeline branches, depending on the sentiment score:
    • If the score is 1 or 0, the chat and its summary are sent to a data analytics system for storage and future analysis.
    • If the score is -1, the Gemma model drafts a response. This response and the chat information are sent to an event messaging system that connects the pipeline and other applications, so that a person can review the response.
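The four steps can be sketched as plain Python for a single message. This minimal sketch makes several simplifications:

- It assumes a hypothetical `model` callable in place of Gemma.
- It returns a branch name instead of writing to BigQuery or Pub/Sub.
- The prompt wording, the branch names `analytics` and `human_review`, and the function names are illustrative only.

```python
import json
from typing import Callable, Dict, Tuple

# Hypothetical stand-in for a call to Gemma; the real pipeline uses RunInference.
ModelFn = Callable[[str], str]


def process_chat(payload: bytes, model: ModelFn) -> Tuple[str, Dict]:
    """Applies steps 1-4 to one Pub/Sub payload and returns (branch, record)."""
    # Step 1: decode the JSON payload (chat ID, user ID, and chat history).
    chat = json.loads(payload.decode("utf-8"))
    # Step 2: pass the chat text to the model with a prompt.
    prompt = ("Return JSON with fields sentiment (1, 0, or -1) and summary "
              "(one sentence) for this chat:\n" + chat["chat_message"])
    # Step 3: the model returns a sentiment score and a one-sentence summary.
    reply = json.loads(model(prompt))
    record = {"id": chat["id"], "sentiment": reply["sentiment"],
              "summary": reply["summary"]}
    # Step 4: branch on the score.
    if reply["sentiment"] == -1:
        # Negative: draft a reply and send everything for human review.
        record["draft_reply"] = model(
            "Draft an apology for this chat:\n" + chat["chat_message"])
        return "human_review", record
    # Positive or neutral: store the chat and summary for analysis.
    return "analytics", record


def fake_model(prompt: str) -> str:
    """Hypothetical stub that returns canned answers instead of calling Gemma."""
    if prompt.startswith("Draft"):
        return "We're sorry to hear that. We'd like to make it right."
    return json.dumps({"sentiment": -1,
                       "summary": "Customer dislikes pineapple pizza."})


if __name__ == "__main__":
    example = json.dumps({"id": 42, "user_id": 221,
                          "chat_message": "id 221: I am really annoyed!"})
    print(process_chat(example.encode("utf-8"), fake_model))
```

Only negative chats pay for a second model call. That second call is what makes the negative branch more expensive than the analytics branch.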

Build the pipeline

This section provides the code needed to run the pipeline.

Before you begin

Although you can use CPUs for testing and development, we recommend GPUs for a production Dataflow ML system. When you use GPUs with Dataflow, we recommend that you use a custom container. For more information about configuring GPUs and custom containers with Dataflow, see Best practices for working with Dataflow GPUs. To facilitate rapid testing of the pipeline, follow the guide Run a pipeline with GPUs, which includes the steps needed to build the container image.

After you configure your environment, download the model gemma2_instruct_2b_en into a folder. In this example, the folder is named gemma2.

Build the base image

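Add the following Dockerfile to your folder, and then build the base image. You can build this image while you write the pipeline.py file. The build is split into two images to facilitate testing and development:

- This base image contains the model and system packages.
- A second image adds the Apache Beam SDK, the Python dependencies, and the pipeline code.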

```dockerfile
ARG SERVING_BUILD_IMAGE=tensorflow/tensorflow:2.16.1-gpu

FROM ${SERVING_BUILD_IMAGE}
WORKDIR /workspace

COPY gemma2 gemma2
RUN apt-get update -y && apt-get install -y cmake && apt-get install -y vim
```

When you test the pipeline code and when you launch the job on Dataflow, do so from inside the container. This prevents dependency mismatches when the pipeline runs on Dataflow.

The requirements.txt file contains the following dependencies:

```
apache_beam[gcp]==2.54.0
keras_nlp==0.14.3
keras==3.4.1
jax[cuda12]
```
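Version pins can drift between environments. For example, the notebook's `%pip install` command pins keras 3.5.0, while requirements.txt pins keras 3.4.1. Because of this, it can help to check an environment against the file before launching. The following sketch assumes Python 3.8 or later and uses only the standard library. It parses `==` pins and compares them with installed versions. `parse_pins` and `find_mismatches` are hypothetical helper names, not part of Beam or Keras.

```python
import re
from importlib import metadata
from typing import Callable, Dict, List

# Matches pins such as "keras==3.4.1" or "apache_beam[gcp]==2.54.0".
_PIN = re.compile(r"^\s*([A-Za-z0-9_.\-]+)(?:\[[^\]]*\])?\s*==\s*([^\s;#]+)")


def parse_pins(requirements_text: str) -> Dict[str, str]:
    """Returns {package: pinned_version} for every '==' pin in the text."""
    pins = {}
    for line in requirements_text.splitlines():
        match = _PIN.match(line)
        if match:
            pins[match.group(1)] = match.group(2)
    return pins


def find_mismatches(requirements_text: str,
                    installed: Callable[[str], str] = metadata.version) -> List[str]:
    """Lists pins whose installed version differs or that are not installed."""
    problems = []
    for name, wanted in parse_pins(requirements_text).items():
        try:
            have = installed(name)
        except metadata.PackageNotFoundError:
            problems.append(f"{name}: not installed (want {wanted})")
            continue
        if have != wanted:
            problems.append(f"{name}: installed {have}, want {wanted}")
    return problems
```

Unpinned entries such as `jax[cuda12]` are skipped, because the sketch only checks exact `==` pins.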
 

The second Dockerfile adds the files needed to construct the pipeline. The contents of the pipeline.py file are provided in a later section of this notebook.

Replace DOCKERFILE_IMAGE with the image that you built using the first Dockerfile.

```dockerfile
FROM <DOCKERFILE_IMAGE>

WORKDIR /workspace

# Copy files from the official SDK image, including the script and dependencies.
COPY --from=apache/beam_python3.11_sdk:2.54.0 /opt/apache/beam /opt/apache/beam

COPY requirements.txt requirements.txt
RUN pip install --upgrade --no-cache-dir pip \
    && pip install --no-cache-dir -r requirements.txt

# Copy the pipeline code. The model directory is already in the base image.
COPY pipeline.py pipeline.py

# The notebook was tested with the JAX backend. Setting these variables in the
# image ensures that the Dataflow workers use the same environment.
ENV KERAS_BACKEND="jax"
ENV XLA_PYTHON_CLIENT_MEM_FRACTION="0.9"

# Set the entrypoint to the Apache Beam SDK launcher.
ENTRYPOINT ["/opt/apache/beam/boot"]
```
 

Run the pipeline

The following code creates and runs the pipeline.

  • The pip install steps are needed to run the code in the notebook, but aren't needed when running the code in your container.

  • Without a GPU, the inference takes a long time to complete.

```
%pip install apache_beam[gcp]=="2.54.0" keras_nlp=="0.14.3" keras=="3.5.0" jax[cuda12]
```
 
```python
import os

# Set the Keras backend before importing keras; it can't be changed afterward.
os.environ["KERAS_BACKEND"] = "jax"  # Or "tensorflow" or "torch".
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"

import keras
import keras_nlp
import numpy as np
import json
import ast
import re
import logging

import apache_beam as beam
from apache_beam.ml.inference import utils
from apache_beam.ml.inference.base import (
    KeyedModelHandler,
    ModelHandler,
    PredictionResult,
    RunInference,
)
from apache_beam.options import pipeline_options
from apache_beam.options.pipeline_options import (
    GoogleCloudOptions,
    PipelineOptions,
    SetupOptions,
    StandardOptions,
    WorkerOptions,
)
from keras_nlp.models import GemmaCausalLM
from typing import Any, Dict, Iterable, Optional, Sequence
```
 
 

Set pipeline options and provide the input Pub/Sub topic. The options that are commented out enable running the pipeline on Dataflow.

```python
options = PipelineOptions()
options.view_as(StandardOptions).streaming = True
options.view_as(SetupOptions).save_main_session = True

# options.view_as(StandardOptions).runner = "dataflowrunner"
# options.view_as(GoogleCloudOptions).project = <PROJECT>
# options.view_as(GoogleCloudOptions).temp_location = <TMP LOCATION>
# options.view_as(GoogleCloudOptions).region = "us-west1"
# options.view_as(WorkerOptions).machine_type = "g2-standard-4"
# options.view_as(WorkerOptions).worker_harness_container_image = <IMAGE YOU BUILT>
# options.view_as(WorkerOptions).disk_size_gb = 200
# options.view_as(GoogleCloudOptions).dataflow_service_options = ["worker_accelerator=type:nvidia-l4;count:1;install-nvidia-driver"]

topic_reviews = "<PubSub Topic>"  # @param {type:'string'}
```
 
 

Define a custom model handler that loads the Gemma model and handles inference calls.

```python
class GemmaModelHandler(ModelHandler[str, PredictionResult, GemmaCausalLM]):
    def __init__(
        self,
        model_name: str = "",
    ):
        """Implementation of the ModelHandler interface for Gemma using text as input.

        Example Usage::

          pcoll | RunInference(GemmaModelHandler("gemma2"))

        Args:
          model_name: The Gemma model URI.
        """
        self._model_name = model_name
        self._env_vars = {}

    def share_model_across_processes(self) -> bool:
        """Returns whether to share a single model in memory across processes.

        This is useful when the loaded model is large, preventing potential
        out-of-memory issues when running the pipeline.

        Returns:
          bool
        """
        return True

    def load_model(self) -> GemmaCausalLM:
        """Loads and initializes a model for processing."""
        return keras_nlp.models.GemmaCausalLM.from_preset(self._model_name)

    def run_inference(
        self,
        batch: Sequence[str],
        model: GemmaCausalLM,
        inference_args: Optional[Dict[str, Any]] = None,
    ) -> Iterable[PredictionResult]:
        """Runs inferences on a batch of text strings.

        Args:
          batch: A sequence of examples as text strings.
          model: The GemmaCausalLM instance returned by load_model().
          inference_args: Any additional arguments for an inference.

        Returns:
          An Iterable of type PredictionResult.
        """
        # Loop over each text string and collect the generated text.
        predictions = []
        for one_text in batch:
            result = model.generate(one_text, max_length=1024)
            predictions.append(result)
        # Pair each input with its output, in order, as PredictionResult objects.
        return utils._convert_to_result(batch, predictions, self._model_name)
```
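Inference without a GPU is slow, so it can be useful to test the handler's contract with a stub model. The following sketch is pure Python and does not import Beam. It defines two hypothetical stand-ins:

- `FakePredictionResult` mimics the `example`, `inference`, and `model_id` fields of Beam's `PredictionResult`.
- `EchoModel` imitates how `GemmaCausalLM.generate()` returns the prompt followed by the completion.

The `run_inference` function here is an illustrative copy of the handler's loop, not the Beam API.

```python
from typing import Any, List, NamedTuple, Optional, Sequence


class FakePredictionResult(NamedTuple):
    """Hypothetical stand-in with the same fields as Beam's PredictionResult."""
    example: Any
    inference: Any
    model_id: Optional[str] = None


class EchoModel:
    """Hypothetical stub for GemmaCausalLM.

    Like GemmaCausalLM.generate(), it returns the prompt followed by the
    completion. max_length is accepted for signature compatibility but ignored.
    """

    def generate(self, prompt: str, max_length: int = 1024) -> str:
        return prompt + ' {"sentiment": 0, "summary": "stub"}'


def run_inference(batch: Sequence[str], model: EchoModel,
                  model_id: str) -> List[FakePredictionResult]:
    """Mirrors GemmaModelHandler.run_inference: one generate() call per
    element, then each input is paired with its output, in order."""
    predictions = [model.generate(text, max_length=1024) for text in batch]
    return [FakePredictionResult(x, y, model_id)
            for x, y in zip(batch, predictions)]
```

The positional pairing mirrors what the handler delegates to `utils._convert_to_result`. This is why `run_inference` must return exactly one prediction per input, in the same order.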
 
 

Define a prompt template. The template formats inputs and instructs the model on the task. This block also contains an example input to the model.

```python
prompt_template = """
<prompt>
Provide the results of doing these two tasks on the chat history provided below for the user {}
task 1 : assess if the tone is happy = 1 , neutral = 0 or unhappy = -1
task 2 : summarize the text with a maximum of 512 characters
Return the answer as a JSON string with fields [sentiment, summary] do NOT explain your answer
@@@
{}
@@@
<answer>
"""

chat_text = """
id 221: Hay I am really annoyed that your menu includes a pizza with pineapple on it!
id 331: Sorry to hear that , but pineapple is nice on pizza
id 221: What a terriable thing to say! Its never ok, so unhappy right now!
"""

# Example input
chat = json.dumps({"id": 42, "user_id": 221, "chat_message": chat_text})
print(chat)
```
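The template relies on two conventions: positional `{}` placeholders filled by `str.format`, and `@@@` markers around the chat so that the chat can be recovered later. The following minimal sketch shows both. It uses a hypothetical, shortened template. `build_prompt` and `recover_chat` are illustrative names, and only the regular expression is taken from this notebook's `SentimentAnalysis` code.

```python
import re

# Hypothetical, shortened template; the notebook's template also carries the
# task instructions. Literal braces in a template would need to be doubled
# ({{ and }}) because str.format treats single braces as placeholders.
TEMPLATE = "<prompt>Summarize the chat for user {}\n@@@\n{}\n@@@\n<answer>"

# The same non-greedy pattern that SentimentAnalysis uses.
CHAT_BETWEEN_MARKERS = re.compile(r"@@@([\s\S]*?)@@@")


def build_prompt(user_id: int, chat_text: str) -> str:
    """Fills in the user ID and the chat history, in that order."""
    return TEMPLATE.format(user_id, chat_text)


def recover_chat(prompt: str) -> str:
    """Returns the chat text between the first pair of @@@ markers."""
    match = CHAT_BETWEEN_MARKERS.search(prompt)
    if match is None:
        raise ValueError("prompt has no @@@-delimited chat")
    return match.group(1).strip()
```

Braces inside the chat text itself are safe, because `str.format` only interprets the template. A chat that contains `@@@`, however, would be split at the wrong place.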
 
 

Define pre- and post-processing functions. The CreatePrompt class creates a key-value pair that contains the chat ID and the formatted prompt. The extract_model_reply function parses the response and extracts the JSON string requested from the model. Because the LLM is not guaranteed to return a JSON-formatted object, the function raises an exception if the reply is malformed. The SentimentAnalysis DoFn uses this helper to split out the sentiment score and the summary of the text, and it routes any reply that raises an exception to an error output. The sentiment score is then used to tag the output, allowing the pipeline to handle different sentiments in different ways.

```python
# The model name is the local folder that contains the downloaded Gemma model.
keyed_model_handler = KeyedModelHandler(GemmaModelHandler("gemma2"))


# Create the prompt by using the information from the chat.
class CreatePrompt(beam.DoFn):
    def process(self, element, *args, **kwargs):
        user_chat = json.loads(element)
        chat_id = user_chat['id']
        user_id = user_chat['user_id']
        messages = user_chat['chat_message']
        # Emit a (key, prompt) pair so that KeyedModelHandler keeps the chat ID.
        yield (chat_id, prompt_template.format(user_id, messages))


def extract_model_reply(model_inference):
    """Extracts and validates the first JSON object in the model's reply."""
    print(model_inference)
    match = re.search(r"(\{[\s\S]*?\})", model_inference)
    if match is None:
        raise ValueError('No JSON object found in model reply')
    json_str = match.group(1)
    print(json_str)
    result = json.loads(json_str)
    if all(key in result for key in ['sentiment', 'summary']):
        return result
    raise ValueError('Malformed model reply')


class SentimentAnalysis(beam.DoFn):
    def process(self, element):
        key = element[0]
        # element[1] is a PredictionResult; recover the chat between the @@@ markers.
        match = re.search(r"@@@([\s\S]*?)@@@", element[1].example)
        chats = match.group(1)

        try:
            # The result contains the prompt. Replace the prompt with "".
            result = extract_model_reply(
                element[1].inference.replace(element[1].example, ""))
            # The model can return the score as a string, so normalize it.
            sentiment = int(result['sentiment'])
            processed_result = (key, chats, sentiment, result['summary'])

            if sentiment == -1:
                output = beam.TaggedOutput('negative', processed_result)
            else:
                output = beam.TaggedOutput('main', processed_result)

        except Exception as err:
            print("ERROR!" + str(err))
            output = beam.TaggedOutput('error', element)

        yield output


gemma_inference = RunInference(keyed_model_handler)
```
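The following sketch isolates the post-processing logic so that it can be unit-tested without Beam or a model. `tag_reply` is a hypothetical pure-Python mirror of `extract_model_reply` combined with the tagging in `SentimentAnalysis`. It returns a tag string instead of a `beam.TaggedOutput`. It also shows why the echoed prompt is removed before searching: a brace in the customer's own chat text would otherwise be matched first.

```python
import json
import re
from typing import Tuple

FIRST_JSON_OBJECT = re.compile(r"(\{[\s\S]*?\})")


def tag_reply(key: int, prompt: str, generated: str) -> Tuple[str, tuple]:
    """Hypothetical pure-Python mirror of the SentimentAnalysis logic.

    Returns (tag, value) where tag is 'negative', 'main', or 'error'.
    """
    # generate() echoes the prompt, so remove it before searching; otherwise a
    # brace in the customer's chat could be matched instead of the model's JSON.
    completion = generated.replace(prompt, "")
    try:
        match = FIRST_JSON_OBJECT.search(completion)
        if match is None:
            raise ValueError("no JSON object in reply")
        result = json.loads(match.group(1))
        sentiment = int(result["sentiment"])
        value = (key, sentiment, result["summary"])
    except (ValueError, KeyError, TypeError):
        # json.JSONDecodeError is a subclass of ValueError.
        return "error", (key, generated)
    return ("negative" if sentiment == -1 else "main"), value
```

For example, with the prompt `chat: I said {this is not JSON}`, a search over the full generated text would hit `{this is not JSON}` and fail. Searching only the completion finds the model's object.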
 
 

To run the pipeline, use the following code. To test with the example chat input created previously instead of a Pub/Sub source, replace the Pub/Sub read with chats = p | beam.Create([chat]).

```python
with beam.Pipeline(options=options) as p:
    chats = (
        p
        | "Read Topic" >> beam.io.ReadFromPubSub(topic=topic_reviews)
        | "Parse" >> beam.Map(lambda x: x.decode("utf-8")))

    prompts = chats | "Create Prompt" >> beam.ParDo(CreatePrompt())
    results = prompts | "RunInference-Gemma" >> gemma_inference

    filtered_results = (
        results
        | "Process Results" >> beam.ParDo(SentimentAnalysis()).with_outputs(
            'main', 'negative', 'error'))

    # For negative chats, key the request by (chat ID, summary) and ask Gemma
    # to draft an apology.
    generated_responses = (
        filtered_results.negative
        | "Generate Response" >> beam.Map(
            lambda x: ((x[0], x[3]),
                       "<prompt>Generate an apology response for the user in this chat text: "
                       + x[1] + "<answer>"))
        | "Gemma-Response" >> gemma_inference)

    generated_responses | "Print Response" >> beam.Map(lambda x: logging.info(x))
    filtered_results.main | "Print Main" >> beam.Map(lambda x: logging.info(x))
    filtered_results.error | "Print Errors" >> beam.Map(lambda x: logging.info(x))
```
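Replacing the `Generate Response` lambda with a named function can make the shape of the negative branch easier to read and test. The following is a hypothetical sketch, and `to_response_request` and `APOLOGY_PREFIX` are illustrative names. In the pipeline, you could apply the function with `beam.Map(to_response_request)`.

```python
from typing import Tuple

# Illustrative prompt prefix for the apology request.
APOLOGY_PREFIX = ("<prompt>Generate an apology response for the user "
                  "in this chat text: ")


def to_response_request(
        processed: Tuple[int, str, int, str]) -> Tuple[Tuple[int, str], str]:
    """Reshapes a (chat_id, chats, sentiment, summary) tuple from the
    'negative' output into a keyed prompt for the second RunInference call.

    The key (chat_id, summary) travels with the prompt through
    KeyedModelHandler, so the drafted reply can be matched to its chat.
    """
    chat_id, chats, _sentiment, summary = processed
    return (chat_id, summary), APOLOGY_PREFIX + chats + "<answer>"
```

Keeping the summary in the key means that the reviewer receives the chat ID, the summary, and the drafted reply together, without a join.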
 
 