Generate text embeddings by using the EmbeddingGemma model from Hugging Face

Use text embeddings to represent text as numerical vectors. This process lets computers understand and process text data, which is essential for many natural language processing (NLP) tasks.

The following NLP tasks use embeddings:

  • Semantic search: Find documents or passages that are relevant to a query when the query doesn't use the exact same words as the documents.
  • Text classification: Categorize text data into different classes, such as spam and not spam, or positive sentiment and negative sentiment.
  • Machine translation: Translate text from one language to another and preserve the meaning.
  • Text summarization: Create shorter summaries of text.
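As a rough illustration of the semantic search task, the following sketch ranks documents by cosine similarity to a query vector. The three-dimensional vectors and the names `cosine_similarity` and `rank_by_similarity` are invented for this example; real embeddings have many more dimensions.

```python
import math


def cosine_similarity(a, b):
    """Return the cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def rank_by_similarity(query_vec, doc_vecs):
    """Return document keys sorted from most to least similar to the query."""
    return sorted(
        doc_vecs,
        key=lambda name: cosine_similarity(query_vec, doc_vecs[name]),
        reverse=True)


# Toy vectors, invented for illustration only.
toy_docs = {
    'medicare_card': [0.9, 0.1, 0.0],
    'dog_park': [0.0, 0.2, 0.9],
    'part_b_premium': [0.7, 0.6, 0.1],
}
toy_query = [1.0, 0.2, 0.0]

ranking = rank_by_similarity(toy_query, toy_docs)
print(ranking)
```

Documents whose vectors point in nearly the same direction as the query rank first, even when the underlying texts share no exact words.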

This notebook uses Apache Beam's MLTransform to generate embeddings from text data.

Using a small, highly efficient open model like EmbeddingGemma at the core of your pipeline makes the embedding step self-contained: it needs no external network calls to other services, which can simplify management. Because it's an open model, it can be hosted entirely within Dataflow, which lets you process large-scale, private datasets without sending the data to an external embedding service. For more information about the model, see the model card.

Hugging Face's SentenceTransformers framework is a Python library for generating sentence, text, and image embeddings.

To generate text embeddings that use Hugging Face models and MLTransform, use the SentenceTransformerEmbeddings module to specify the model configuration.

Install dependencies

Install Apache Beam and the dependencies needed to work with Hugging Face embeddings. The dependencies include the sentence-transformers package, which is required to use the SentenceTransformerEmbeddings module.

   
pip install "apache_beam>=2.53.0" --quiet
pip install sentence-transformers --quiet

import tempfile

import apache_beam as beam
from apache_beam.ml.transforms.base import MLTransform
from apache_beam.ml.transforms.embeddings.huggingface import SentenceTransformerEmbeddings
 

Authenticate with Hugging Face

To ensure that you can pull the correct model, authenticate with Hugging Face by following the prompts in the cell.

hf auth login

Process the data

MLTransform is a PTransform that you can use for data preparation, including generating text embeddings.

Use MLTransform in write mode

In write mode, MLTransform saves the transforms and their attributes to an artifact location. Then, when you run MLTransform in read mode, these transforms are used. This process ensures that you're applying the same preprocessing steps when you train your model and when you serve the model in production or test its accuracy.

For more information about using MLTransform, see Preprocess data with MLTransform in the Apache Beam documentation.

Get the data

The following text inputs come from the Hugging Face blog Getting Started With Embeddings.

MLTransform operates on dictionaries of data. To generate embeddings for specific columns, pass the column names to the columns argument of the SentenceTransformerEmbeddings transform.

content = [
    {'x': 'How do I get a replacement Medicare card?'},
    {'x': 'What is the monthly premium for Medicare Part B?'},
    {'x': 'How do I terminate my Medicare Part B (medical insurance)?'},
    {'x': 'How do I sign up for Medicare?'},
    {'x': 'Can I sign up for Medicare Part B if I am working and have health insurance through an employer?'},
    {'x': 'How do I sign up for Medicare Part B if I already have Part A?'},
    {'x': 'What are Medicare late enrollment penalties?'},
    {'x': 'What is Medicare and who can get it?'},
    {'x': 'How can I get help with my Medicare Part A and Part B premiums?'},
    {'x': 'What are the different parts of Medicare?'},
    {'x': 'Will my Medicare premiums be higher because of my higher income?'},
    {'x': 'What is TRICARE ?'},
    {'x': "Should I sign up for Medicare Part B if I have Veterans' Benefits?"}
]

text_embedding_model_name = 'google/embeddinggemma-300m'


# Helper function that returns a dict containing only the first
# ten elements of the generated embeddings.
def truncate_embeddings(d):
    for key in d.keys():
        d[key] = d[key][:10]
    return d
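If your input starts out as a plain list of strings rather than dictionaries, you can wrap each string under a column name first. The following is a minimal sketch; the helper name `to_records` is hypothetical and not part of Apache Beam.

```python
def to_records(texts, column='x'):
    """Wrap each raw string in a single-key dict keyed by the column name."""
    return [{column: text} for text in texts]


questions = ['How do I sign up for Medicare?', 'What is TRICARE ?']
records = to_records(questions)
print(records)
```

The resulting list has the same shape as `content`, so it can be passed to `beam.Create` in the same way.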
 

Generate text embeddings

This example uses the model google/embeddinggemma-300m to generate text embeddings. For more information about the model, see the model card.

artifact_location_t5 = tempfile.mkdtemp(prefix='huggingface_')
embedding_transform = SentenceTransformerEmbeddings(
    model_name=text_embedding_model_name, columns=['x'])

with beam.Pipeline() as pipeline:
    data_pcoll = (
        pipeline
        | "CreateData" >> beam.Create(content))
    transformed_pcoll = (
        data_pcoll
        | "MLTransform" >> MLTransform(
            write_artifact_location=artifact_location_t5).with_transform(embedding_transform))

    transformed_pcoll | beam.Map(truncate_embeddings) | 'LogOutput' >> beam.Map(print)

    transformed_pcoll | "PrintEmbeddingShape" >> beam.Map(
        lambda x: print(f"Embedding shape: {len(x['x'])}"))
 
{'x': [-0.0317193828523159, -0.005265399813652039, -0.012499183416366577, 0.00018130357784684747, -0.005592408124357462, 0.06207558885216713, -0.01656288281083107, 0.0167048592120409, -0.01239298190921545, 0.03041897714138031]}
Embedding shape: 10
{'x': [-0.015295305289328098, 0.005405726842582226, -0.015631258487701416, 0.022797023877501488, -0.027843449264764786, 0.03968179598450661, -0.004387892782688141, 0.022909151390194893, 0.01015392318367958, 0.04723235219717026]}
Embedding shape: 10
{'x': [-0.03450256213545799, -0.002632762538269162, -0.022460950538516045, -0.011689935810863972, -0.027329981327056885, 0.07293087989091873, -0.03069353476166725, 0.05429817736148834, -0.01308195199817419, 0.017668722197413445]}
Embedding shape: 10
{'x': [-0.02869587577879429, -0.0002648509689606726, -0.007186499424278736, -0.0003750955802388489, 0.012458174489438534, 0.06721009314060211, -0.013404129073023796, 0.03204648941755295, -0.021021844819188118, 0.04968355968594551]}
Embedding shape: 10
{'x': [-0.03241290897130966, 0.006845517549663782, 0.02001815102994442, -0.0057969288900494576, 0.008191823959350586, 0.08160955458879471, -0.009215254336595535, 0.023534387350082397, -0.02034241147339344, 0.0357462577521801]}
Embedding shape: 10
{'x': [-0.04592451825737953, -0.0025395643897354603, -0.01178023498505354, 0.011568977497518063, -0.0029014083556830883, 0.06971456110477448, -0.021167151629924774, 0.015902182087302208, -0.015007994137704372, 0.026213033124804497]}
Embedding shape: 10
{'x': [0.005221465136855841, -0.002127869985997677, -0.002369001042097807, -0.019337018951773643, 0.023243796080350876, 0.05599674955010414, -0.022721167653799057, 0.024813007563352585, -0.010685156099498272, 0.03624529018998146]}
Embedding shape: 10
{'x': [-0.035339221358299255, 0.010706206783652306, -0.001701260800473392, -0.00862252525985241, 0.006445988081395626, 0.08198338001966476, -0.022678885608911514, 0.01434261817485094, -0.008092232048511505, 0.03345781937241554]}
Embedding shape: 10
{'x': [-0.030748076736927032, 0.009340512566268444, -0.013637945055961609, 0.011183148249983788, -0.013879665173590183, 0.046350326389074326, -0.024090109393000603, 0.02885228954255581, -0.01699884608387947, 0.01672385260462761]}
Embedding shape: 10
{'x': [-0.040792081505060196, -0.00872269831597805, -0.015838179737329483, -0.03141209855675697, -7.104632823029533e-05, 0.08301416039466858, -0.034691162407398224, 0.0026397297624498606, 0.009255227632820606, 0.05415954813361168]}
Embedding shape: 10
{'x': [-0.02156883291900158, 0.003969342447817326, -0.030446071177721024, 0.008231461979448795, -0.01271845493465662, 0.03793857619166374, -0.013524272479116917, -0.0385628417134285, -0.0058258213102817535, 0.03505263477563858]}
Embedding shape: 10
{'x': [-0.027544165030121803, -0.01773364469408989, -0.013286487199366093, -0.008328652940690517, -0.011047529056668282, 0.05237515643239021, -0.016948163509368896, 0.02806701697409153, -0.0018120920285582542, 0.027241172268986702]}
Embedding shape: 10
{'x': [-0.03464886546134949, -0.003521248232573271, -0.010239562019705772, -0.018618224188685417, 0.004094886127859354, 0.062059685587882996, -0.013881963677704334, -0.0008639032603241503, -0.029874088242650032, 0.033531222492456436]}
Embedding shape: 10
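The `Embedding shape: 10` lines appear to be a side effect of the `truncate_embeddings` helper, which shortens each dict in place before the shape step reads the same element. If you want the shape step to see the full vector, one option is a non-mutating variant. The following sketch uses the hypothetical name `truncated_copy` and a made-up 768-element list in place of a real embedding.

```python
def truncated_copy(d, n=10):
    """Return a new dict holding only the first n values of each entry."""
    return {key: value[:n] for key, value in d.items()}


# Stand-in for one pipeline element; the values are placeholders.
element = {'x': [0.1 * i for i in range(768)]}
preview = truncated_copy(element)

# The preview is shortened, while the original element keeps its full length.
print(len(preview['x']), len(element['x']))
```

Returning a new dict instead of modifying the input also follows the general Beam guidance that a function passed to `beam.Map` should not mutate its input element.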

You can pass additional arguments that are supported by sentence-transformer models, such as convert_to_numpy=False. Pass these arguments as a dict to the SentenceTransformerEmbeddings transform by using the inference_args parameter.

When you pass convert_to_numpy=False, the output embeddings are torch.Tensor objects.

artifact_location_t5_with_inference_args = tempfile.mkdtemp(prefix='huggingface_')

embedding_transform = SentenceTransformerEmbeddings(
    model_name=text_embedding_model_name,
    columns=['x'],
    inference_args={'convert_to_numpy': False}
)

with beam.Pipeline() as pipeline:
    data_pcoll = (
        pipeline
        | "CreateData" >> beam.Create(content))
    transformed_pcoll = (
        data_pcoll
        | "MLTransform" >> MLTransform(
            write_artifact_location=artifact_location_t5_with_inference_args).with_transform(embedding_transform))

    # The outputs are in the PyTorch tensor type.
    transformed_pcoll | 'LogOutput' >> beam.Map(lambda x: print(type(x['x'])))

    transformed_pcoll | "PrintEmbeddingShape" >> beam.Map(
        lambda x: print(f"Embedding shape: {len(x['x'])}"))
 
<class 'torch.Tensor'>
Embedding shape: 768
<class 'torch.Tensor'>
Embedding shape: 768
<class 'torch.Tensor'>
Embedding shape: 768
<class 'torch.Tensor'>
Embedding shape: 768
<class 'torch.Tensor'>
Embedding shape: 768
<class 'torch.Tensor'>
Embedding shape: 768
<class 'torch.Tensor'>
Embedding shape: 768
<class 'torch.Tensor'>
Embedding shape: 768
<class 'torch.Tensor'>
Embedding shape: 768
<class 'torch.Tensor'>
Embedding shape: 768
<class 'torch.Tensor'>
Embedding shape: 768
<class 'torch.Tensor'>
Embedding shape: 768
<class 'torch.Tensor'>
Embedding shape: 768

Use MLTransform in read mode

In read mode, MLTransform uses the artifacts generated during write mode. In this case, the SentenceTransformerEmbeddings transform and its attributes are loaded from the saved artifacts. You don't need to specify the transform again during read mode.

In this way, MLTransform provides consistent preprocessing steps for training and inference workloads.

test_content = [
    {'x': 'This is a test sentence'},
    {'x': 'The park is full of dogs'},
    {'x': "Should I sign up for Medicare Part B if I have Veterans' Benefits?"}
]

# Uses the embedding model saved in the write-mode artifacts to generate text embeddings.
with beam.Pipeline() as pipeline:
    data_pcoll = (
        pipeline
        | "CreateData" >> beam.Create(test_content))
    transformed_pcoll = (
        data_pcoll
        | "MLTransform" >> MLTransform(read_artifact_location=artifact_location_t5))

    transformed_pcoll | beam.Map(truncate_embeddings) | 'LogOutput' >> beam.Map(print)
 
{'x': [0.00036313451710157096, -0.03929319977760315, -0.03574873134493828, 0.05015222355723381, 0.04295048117637634, 0.04800170287489891, 0.006883862894028425, -0.02567591704428196, -0.048067063093185425, 0.036534328013658524]}
{'x': [-0.053793832659721375, 0.006730600260198116, -0.025130020454525948, 0.04363932088017464, 0.03323192894458771, 0.008803879842162132, -0.015412433072924614, 0.008926985785365105, -0.061175212264060974, 0.04573329910635948]}
{'x': [-0.03464885801076889, -0.003521254053339362, -0.010239563882350922, -0.018618224188685417, 0.004094892647117376, 0.062059689313173294, -0.013881963677704334, -0.000863900815602392, -0.029874078929424286, 0.03353121876716614]}
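The third test sentence is also the last entry of the `content` list used in write mode, so its read-mode embedding should agree with the write-mode one. The following sketch compares the two truncated vectors printed in this notebook. The `1e-6` tolerance is an assumption, chosen only to absorb small floating-point differences between runs.

```python
import math

# First ten values for the Veterans' Benefits sentence, copied from the
# write-mode output and the read-mode output in this notebook.
write_mode = [
    -0.03464886546134949, -0.003521248232573271, -0.010239562019705772,
    -0.018618224188685417, 0.004094886127859354, 0.062059685587882996,
    -0.013881963677704334, -0.0008639032603241503, -0.029874088242650032,
    0.033531222492456436,
]
read_mode = [
    -0.03464885801076889, -0.003521254053339362, -0.010239563882350922,
    -0.018618224188685417, 0.004094892647117376, 0.062059689313173294,
    -0.013881963677704334, -0.000863900815602392, -0.029874078929424286,
    0.03353121876716614,
]

# Assumed tolerance; adjust it to suit your own consistency requirements.
matches = all(
    math.isclose(w, r, abs_tol=1e-6) for w, r in zip(write_mode, read_mode))
print(matches)
```

A check like this gives a quick sanity test that read mode applies the same transform that was saved in write mode.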