Classify Text With Embeddings

Copyright 2025 Google LLC.

⚠️ This notebook requires paid-tier rate limits to run properly (see pricing for more details).

Overview

In this notebook, you'll learn how to use embeddings produced by the Gemini API to train a model that classifies newsgroup posts by topic.


Grab an API Key

Before you can use the Gemini API, you must first obtain an API key. If you don't already have one, create a key with one click in Google AI Studio.


In Colab, add the key to the secrets manager under the "🔑" icon in the left panel and name it GEMINI_API_KEY.

Once you have the API key, pass it to the SDK. You can do this in two ways:

  • Put the key in the GEMINI_API_KEY environment variable (the SDK will automatically pick it up from there).
  • Pass the key to genai.Client(api_key=...).
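A minimal sketch that supports both options, assuming you want the same code to run in Colab and locally. The helper name get_gemini_api_key is hypothetical: it tries Colab's secrets manager (google.colab.userdata.get) first and falls back to the GEMINI_API_KEY environment variable. The client construction is shown only as a comment because it needs the google-genai package and a real key.

```python
import os


def get_gemini_api_key(name: str = "GEMINI_API_KEY") -> str:
    """Hypothetical helper: return the Gemini API key.

    Tries Colab's secrets manager first, then the environment variable.
    """
    try:
        # Only importable inside Colab; elsewhere this raises ImportError.
        from google.colab import userdata
        key = userdata.get(name)
        if key:
            return key
    except Exception:
        # Not in Colab, or the secret is missing / not shared with this
        # notebook: fall back to the environment variable.
        pass

    key = os.environ.get(name)
    if not key:
        raise RuntimeError(f"{name} is not set in Colab secrets or the environment.")
    return key


# Usage (assumes google-genai is installed):
# from google import genai
# client = genai.Client(api_key=get_gemini_api_key())
```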

Key Point: Next, you will choose a model. Any embedding model will work for this tutorial, but for real applications it's important to choose a specific model and stick with it: embeddings produced by different models are not compatible with each other, so vectors from one model cannot be meaningfully compared with vectors from another.

[2]
models/embedding-001
models/text-embedding-004
models/gemini-embedding-exp-03-07
models/gemini-embedding-exp
models/gemini-embedding-001
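The list above names the models that support embedding. A minimal sketch of how such a list can be produced, assuming a google-genai Client named client whose client.models.list() yields model objects with name and supported_actions attributes. The filter is factored into a hypothetical helper, embedding_model_names, so it can be checked offline with stand-in objects.

```python
from types import SimpleNamespace
from typing import Iterable, List


def embedding_model_names(models: Iterable) -> List[str]:
    """Return the names of models whose supported_actions include embedContent."""
    return [
        m.name
        for m in models
        if "embedContent" in (getattr(m, "supported_actions", None) or [])
    ]


# With a real client (assumption: needs network access and an API key):
#   for name in embedding_model_names(client.models.list()):
#       print(name)

# Offline check with stand-in model objects (not real API responses):
fake_models = [
    SimpleNamespace(name="models/example-text-model", supported_actions=["generateContent"]),
    SimpleNamespace(name="models/gemini-embedding-001", supported_actions=["embedContent"]),
]
print(embedding_model_names(fake_models))  # ['models/gemini-embedding-001']
```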

Select the model to use

Store the chosen model's name in a variable named MODEL_ID.

Prepare the dataset

The 20 Newsgroups Text Dataset contains about 18,000 newsgroup posts on 20 topics, divided into training and test sets. The split between the training and test sets is based on whether a message was posted before or after a specific date. For this tutorial, you will use subsets of the training and test sets, which you will preprocess and organize into pandas DataFrames.

[4]
['alt.atheism',
 'comp.graphics',
 'comp.os.ms-windows.misc',
 'comp.sys.ibm.pc.hardware',
 'comp.sys.mac.hardware',
 'comp.windows.x',
 'misc.forsale',
 'rec.autos',
 'rec.motorcycles',
 'rec.sport.baseball',
 'rec.sport.hockey',
 'sci.crypt',
 'sci.electronics',
 'sci.med',
 'sci.space',
 'soc.religion.christian',
 'talk.politics.guns',
 'talk.politics.mideast',
 'talk.politics.misc',
 'talk.religion.misc']

Here is an example of what a data point from the training set looks like.

[5]
Lines: 15

 I was wondering if anyone out there could enlighten me on this car I saw
the other day. It was a 2-door sports car, looked to be from the late 60s/
early 70s. It was called a Bricklin. The doors were really small. In addition,
the front bumper was separate from the rest of the body. This is 
all I know. If anyone can tellme a model name, engine specs, years
of production, where this car is made, history, or whatever info you
have on this funky looking car, please e-mail.

Thanks,
- IL
   ---- brought to you by your neighborhood Lerxst ----





Now preprocess the data for this tutorial. Remove sensitive information such as names and email addresses, as well as redundant parts of the text such as "From: " and "\nSubject: ". Then organize the information into a pandas DataFrame so it is easier to read.
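The preprocessing code itself is not shown here. A minimal sketch, assuming the raw posts, integer targets, and target names come from scikit-learn's fetch_20newsgroups (whose result exposes data, target, and target_names). The helper name preprocess_newsgroup_data and the 5,000-character cap are illustrative choices, and the parenthesis regex is a crude heuristic: it drops any parenthesized text, which is where names usually appear in From: headers.

```python
import re
from typing import List, Sequence

import pandas as pd


def preprocess_newsgroup_data(
    texts: Sequence[str],
    labels: Sequence[int],
    target_names: Sequence[str],
    max_chars: int = 5000,  # assumption: illustrative length cap
) -> pd.DataFrame:
    """Hypothetical helper: clean raw posts and return Text / Label / Class Name columns."""
    cleaned: List[str] = []
    for text in texts:
        text = re.sub(r"[\w.-]+@[\w.-]+", "", text)  # drop email addresses
        text = re.sub(r"\([^()]*\)", "", text)        # drop "(Full Name)"-style asides (heuristic)
        text = text.replace("From: ", "")             # drop redundant header prefixes
        text = text.replace("\nSubject: ", "")
        cleaned.append(text[:max_chars])              # cap very long posts

    df = pd.DataFrame({"Text": cleaned, "Label": list(labels)})
    df["Class Name"] = [target_names[i] for i in df["Label"]]
    return df


# Usage, assuming scikit-learn is installed (downloads the dataset):
# from sklearn.datasets import fetch_20newsgroups
# raw = fetch_20newsgroups(subset="train")
# df_train = preprocess_newsgroup_data(raw.data, raw.target, raw.target_names)
```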


Next, sample the data: take 100 data points per category from the training set (and 25 per category from the test set), and drop the categories you don't need for this tutorial. Here, keep the science categories for comparison.
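A minimal sketch of this sampling step, assuming a DataFrame with Label and Class Name columns like the one built during preprocessing. The helper name sample_data and the seed parameter are hypothetical. It uses pandas' GroupBy.sample (pandas 1.1+) rather than groupby(...).apply(lambda x: x.sample(...)), which avoids the FutureWarning that recent pandas versions raise when apply operates on the grouping columns. Note that GroupBy.sample raises a ValueError if any group has fewer than num_samples rows.

```python
import pandas as pd


def sample_data(df: pd.DataFrame, num_samples: int, classes_to_keep: str,
                seed: int = 0) -> pd.DataFrame:
    """Hypothetical helper: sample rows per label, keep matching classes, re-encode labels."""
    # Draw num_samples rows from every Label group (no DataFrameGroupBy.apply).
    df = df.groupby("Label").sample(n=num_samples, random_state=seed)
    df = df.reset_index(drop=True)

    # Keep only classes whose name contains the substring, e.g. "sci".
    df = df[df["Class Name"].str.contains(classes_to_keep, regex=False)].copy()

    # Re-number the remaining classes 0..k-1 (categories are sorted by name).
    df["Class Name"] = df["Class Name"].astype("category")
    df["Encoded Label"] = df["Class Name"].cat.codes
    return df


# Usage, assuming df_train / df_test come from the preprocessing step:
# df_train = sample_data(df_train, 100, "sci")
# df_test = sample_data(df_test, 25, "sci")
```

Because the categories are sorted alphabetically, the four science groups would be encoded as sci.crypt 0, sci.electronics 1, sci.med 2, sci.space 3.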

[9]
/tmp/ipykernel_190557/1286960996.py:2: FutureWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.
  df = df.groupby('Label', as_index = False).apply(lambda x: x.sample(num_samples)).reset_index(drop=True)
[10]
Class Name
sci.crypt          100
sci.electronics    100
sci.med            100
sci.space          100
Name: count, dtype: int64
[11]
Class Name
sci.crypt          25
sci.electronics    25
sci.med            25
sci.space          25
Name: count, dtype: int64

Generate the embeddings

In this section, you will generate an embedding for each text in the DataFrame using the Gemini API.

The Gemini embedding model supports several task types, each tailored for a specific goal. Here’s a general overview of the available types and their applications:

Task type             Description
RETRIEVAL_QUERY       Specifies the given text is a query in a search/retrieval setting.
RETRIEVAL_DOCUMENT    Specifies the given text is a document in a search/retrieval setting.
SEMANTIC_SIMILARITY   Specifies the given text will be used for Semantic Textual Similarity (STS).
CLASSIFICATION        Specifies that the embeddings will be used for classification.
CLUSTERING            Specifies that the embeddings will be used for clustering.
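A minimal sketch of an embedding helper, assuming the google-genai SDK's client.models.embed_content(model=..., contents=..., config=...) call, with the config passed in its dict form and the CLASSIFICATION task type (which matches this tutorial's goal). The helper name embed_texts is hypothetical, and batch_size=100 is an assumed per-request limit; check the current API documentation before relying on it. The client is passed in as a parameter so the batching logic can be tested with a stand-in object.

```python
from typing import List, Sequence


def embed_texts(client, texts: Sequence[str], model_id: str,
                task_type: str = "CLASSIFICATION",
                batch_size: int = 100) -> List[List[float]]:
    """Hypothetical helper: embed texts in batches, one vector per text, in order."""
    vectors: List[List[float]] = []
    for start in range(0, len(texts), batch_size):
        batch = list(texts[start:start + batch_size])
        result = client.models.embed_content(
            model=model_id,
            contents=batch,
            config={"task_type": task_type},  # dict form of EmbedContentConfig
        )
        # Each returned embedding exposes its vector as .values.
        vectors.extend(list(e.values) for e in result.embeddings)
    return vectors


# Usage, assuming a real client and the MODEL_ID chosen earlier:
# df_train["Embeddings"] = embed_texts(client, df_train["Text"].tolist(), MODEL_ID)
```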

Embed the training set (400 texts) and the test set (100 texts):

[24]
100%|██████████| 400/400 [03:35<00:00,  1.86it/s]
[25]
100%|██████████| 100/100 [00:50<00:00,  1.99it/s]

Build a simple classification model

Here you will define a simple model with one hidden layer and an output layer that produces one probability score per class. The highest score indicates the news category a piece of text most likely belongs to. You don't need to shuffle the data yourself: Keras's Model.fit shuffles the training data before each epoch by default (shuffle=True).
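To make the architecture concrete, here is what such a model computes, written in plain NumPy rather than Keras (a swapped-in illustration, not the notebook's code). Assumptions: a ReLU hidden layer, a softmax output (one common choice; the notebook's exact output activation isn't shown), an arbitrary hidden size of 64, and random weights standing in for trained ones. The input size 3072 matches the embedding dimension printed below.

```python
import numpy as np


def forward(x: np.ndarray, w1: np.ndarray, b1: np.ndarray,
            w2: np.ndarray, b2: np.ndarray) -> np.ndarray:
    """One hidden ReLU layer followed by a softmax over classes."""
    hidden = np.maximum(0.0, x @ w1 + b1)                 # (batch, hidden_units)
    logits = hidden @ w2 + b2                             # (batch, num_classes)
    logits = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    exp = np.exp(logits)
    return exp / exp.sum(axis=1, keepdims=True)           # each row sums to 1


rng = np.random.default_rng(0)
# Assumption: hidden size 64 is arbitrary; 4 classes = the four sci.* groups.
embedding_dim, hidden_units, num_classes = 3072, 64, 4
x = rng.normal(size=(2, embedding_dim))                   # stand-in for two embeddings
probs = forward(
    x,
    rng.normal(scale=0.01, size=(embedding_dim, hidden_units)), np.zeros(hidden_units),
    rng.normal(scale=0.01, size=(hidden_units, num_classes)), np.zeros(num_classes),
)
print(probs.shape)           # (2, 4)
print(probs.argmax(axis=1))  # predicted class index per text
```

In Keras, the same shape of model would be two keras.layers.Dense layers: one with activation="relu" and one with num_classes units and activation="softmax".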

[30]
3072

Train the model to classify newsgroups

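Finally, train the model. Keep the number of epochs small to avoid overfitting. In the run below, training was configured for up to 25 epochs but stopped after epoch 9. The first epoch takes slightly longer than the rest (2s versus about 1s in the log), which is typical of one-time setup work such as graph tracing; the embeddings themselves were already computed in the previous section.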

[31]
Epoch 1/25
13/13 ━━━━━━━━━━━━━━━━━━━━ 2s 54ms/step - accuracy: 0.6282 - loss: 1.1392 - val_accuracy: 0.8800 - val_loss: 0.5001
Epoch 2/25
13/13 ━━━━━━━━━━━━━━━━━━━━ 1s 50ms/step - accuracy: 0.9463 - loss: 0.3431 - val_accuracy: 0.9100 - val_loss: 0.2622
Epoch 3/25
13/13 ━━━━━━━━━━━━━━━━━━━━ 1s 53ms/step - accuracy: 0.9749 - loss: 0.1185 - val_accuracy: 0.9200 - val_loss: 0.1709
Epoch 4/25
13/13 ━━━━━━━━━━━━━━━━━━━━ 1s 54ms/step - accuracy: 0.9909 - loss: 0.0580 - val_accuracy: 0.9200 - val_loss: 0.2019
Epoch 5/25
13/13 ━━━━━━━━━━━━━━━━━━━━ 1s 57ms/step - accuracy: 0.9927 - loss: 0.0396 - val_accuracy: 0.9300 - val_loss: 0.1498
Epoch 6/25
13/13 ━━━━━━━━━━━━━━━━━━━━ 1s 61ms/step - accuracy: 1.0000 - loss: 0.0220 - val_accuracy: 0.9300 - val_loss: 0.1596
Epoch 7/25
13/13 ━━━━━━━━━━━━━━━━━━━━ 1s 65ms/step - accuracy: 1.0000 - loss: 0.0158 - val_accuracy: 0.9400 - val_loss: 0.1307
Epoch 8/25
13/13 ━━━━━━━━━━━━━━━━━━━━ 1s 73ms/step - accuracy: 1.0000 - loss: 0.0106 - val_accuracy: 0.9300 - val_loss: 0.1309
Epoch 9/25
13/13 ━━━━━━━━━━━━━━━━━━━━ 1s 67ms/step - accuracy: 1.0000 - loss: 0.0089 - val_accuracy: 0.9300 - val_loss: 0.1293
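The log ends at epoch 9 even though 25 epochs were requested, which points to an early-stopping callback. The callback configuration isn't shown, so the sketch below assumes keras.callbacks.EarlyStopping(monitor="accuracy", patience=3) and re-creates its patience rule in plain Python (the helper early_stop_epoch is hypothetical) to check that assumption against the logged training accuracy. Accuracy first reaches 1.0 at epoch 6 and then fails to improve for three epochs, so stopping fires at epoch 9; monitoring val_loss with the same patience would not have stopped by then.

```python
from typing import Optional, Sequence


def early_stop_epoch(values: Sequence[float], patience: int,
                     mode: str = "max", min_delta: float = 0.0) -> Optional[int]:
    """Return the 1-based epoch at which patience-based early stopping fires.

    Intended to follow the same rule as keras.callbacks.EarlyStopping: a value
    counts as an improvement only if it beats the best so far by more than
    min_delta; training stops after `patience` consecutive non-improving epochs.
    Returns None if training would run to completion.
    """
    best = None
    wait = 0
    for epoch, value in enumerate(values, start=1):
        if best is None:
            improved = True
        elif mode == "max":
            improved = value - min_delta > best
        else:
            improved = value + min_delta < best
        if improved:
            best, wait = value, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Per-epoch training accuracy copied from the training log.
# Assumption: monitor="accuracy", patience=3.
train_accuracy = [0.6282, 0.9463, 0.9749, 0.9909, 0.9927, 1.0, 1.0, 1.0, 1.0]
print(early_stop_epoch(train_accuracy, patience=3, mode="max"))  # 9
```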

Evaluate model performance

Use Keras Model.evaluate to get the loss and accuracy on the test dataset.

[32]
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step - accuracy: 0.9355 - loss: 0.1274
{'accuracy': 0.9300000071525574, 'loss': 0.12925000488758087}

One way to evaluate your model is to visualize its training history. Use plot_history to see how the loss and accuracy changed over the epochs.


Another way to view model performance is a confusion matrix. Beyond overall accuracy, it shows which class each misclassified example was assigned to. To build the confusion matrix for this multi-class classification problem, you need the actual labels of the test set and the predicted labels.

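Start by generating the predicted class for each example in the test set (which also served as the validation set during training) using Model.predict(). Model.predict() returns one score per class, so the predicted class is the index of the highest score.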

[34]
4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step
[35]
{'sci.crypt': 0, 'sci.electronics': 1, 'sci.med': 2, 'sci.space': 3}
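A minimal sketch of building the confusion matrix from predicted scores, written in NumPy. The label mapping is the one printed above; the probability rows and true labels are made-up stand-ins for Model.predict output on three test examples, and the helper name confusion_matrix is hypothetical. Rows are true classes and columns are predicted classes, so an off-diagonal entry shows which class an example was mistaken for.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, num_classes: int) -> np.ndarray:
    """Hypothetical helper: count (true, predicted) pairs; rows = true, cols = predicted."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm


class_to_id = {"sci.crypt": 0, "sci.electronics": 1, "sci.med": 2, "sci.space": 3}

# Made-up stand-ins for Model.predict output and the true encoded labels.
probs = np.array([
    [0.90, 0.05, 0.03, 0.02],
    [0.10, 0.20, 0.60, 0.10],
    [0.05, 0.05, 0.10, 0.80],
])
y_true = np.array([0, 1, 3])
y_pred = probs.argmax(axis=1)  # most probable class per example -> [0, 2, 3]

cm = confusion_matrix(y_true, y_pred, num_classes=len(class_to_id))
print(cm)  # the sci.electronics example lands in the sci.med column
```

If scikit-learn is available, sklearn.metrics.confusion_matrix and ConfusionMatrixDisplay compute and plot the same matrix.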

Next steps

You've now built your own text classifier using Gemini API embeddings! To learn how to use other services in the Gemini API, see the Get started guide.