GRPO Fine-tuning for Verifiable Tasks
Fine-tuning requires a GPU. If you don't have one locally, you can run this notebook for free on Google Colab using a free NVIDIA T4 GPU instance.
What's in this notebook?
In this notebook you will learn how to fine-tune a small text-to-text language model for verifiable tasks using Group Relative Policy Optimization (GRPO).
GRPO is a reinforcement learning algorithm widely used by AI labs and practitioners to fine-tune models for easily verifiable tasks, such as:
- Mathematical problem solving with numeric verification
- Code generation with unit test validation
- Structured output tasks (JSON, SQL) with schema validation
- Question answering with ground truth answers
We will cover:
- Environment setup
- Data preparation
- Model training
- Local inference with your new model
- Saving and exporting the model in the format you need for deployment
Deployment options
LFM2.5 models are small and efficient, enabling deployment across a wide range of platforms:
| Deployment Target | Use Case |
|---|---|
| 📱 Android | Mobile apps on Android devices |
| 📱 iOS | Mobile apps on iPhone/iPad |
| 🍎 Apple Silicon Mac | Local inference on Mac with MLX |
| 🦙 llama.cpp | Local deployments on any hardware |
| 🦙 Ollama | Local inference with easy setup |
| 🖥️ LM Studio | Desktop app for local inference |
| ⚡ vLLM | Cloud deployments with high throughput |
| ☁️ Modal | Serverless cloud deployment |
| 🏗️ Baseten | Production ML infrastructure |
| 🚀 Fal | Fast inference API |
📦 Installation & Setup
First, let's install all the required packages:
We'll install TRL with the PEFT extra, which ensures all main dependencies such as Transformers and PEFT (a package for parameter-efficient fine-tuning, e.g., LoRA/QLoRA) are included.
Additionally, we'll install:
- trackio to log and monitor our experiments
- bitsandbytes to enable quantization of LLMs, reducing memory consumption for both inference and training, and
- liger-kernel for more efficient training.
Load the dataset
In this step, we load the AI-MO/NuminaMath-TIR dataset from the Hugging Face Hub using the datasets library.
This dataset focuses on mathematical reasoning, featuring problems that require step-by-step logical solutions.
Fine-tuning on this data teaches a model that does not yet exhibit strong reasoning capabilities to generate structured reasoning steps, improving both its accuracy and its interpretability on math-related tasks.
For efficiency, we'll load only a small portion of the training split:
Transform the dataset
We will adapt our dataset to a conversational format using a custom system prompt, guiding the LLM to generate both step-by-step reasoning and the final answer.
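The mapping can be sketched as a small function applied over the dataset. The system prompt wording here is an assumption, not the notebook's exact text:

```python
# Assumed system prompt; adapt the wording to your task.
SYSTEM_PROMPT = (
    "You are a helpful assistant. Think through the problem step by step "
    "inside <think>...</think> tags, then give the final answer."
)

def make_conversation(example):
    # GRPOTrainer accepts prompts in chat format: a list of role/content dicts.
    return {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": example["problem"]},
        ]
    }

# train_dataset = train_dataset.map(make_conversation)
```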
Load the model
This notebook can be used with two fine-tuning methods. By default, it is set up for QLoRA, which includes quantization using BitsAndBytesConfig. If you prefer to use standard LoRA without quantization, simply comment out the BitsAndBytesConfig configuration (training without quantization consumes more memory).
Define LoRA adapters
The following cell defines LoRA (or QLoRA if needed). When training with LoRA/QLoRA, we use a base model (the one selected above) and, instead of modifying its original weights, we fine-tune a LoRA adapter, a lightweight layer that enables efficient and memory-friendly training. The target_modules specify which parts of the model (e.g., attention or projection layers) will be adapted by LoRA during fine-tuning.
Load reward functions or define your own
GRPO requires reward functions to guide the learning process. For convenience, we can directly load pre-defined rewards from trl.rewards, which already includes a collection of ready-to-use rewards.
You can also create your own custom reward functions: a reward function is simply a Python function that takes the generated completions and returns a list of floats. For example, the following function, which we use in this notebook, rewards completions that correctly follow the <think> format:
import re

def think_format_reward(completions: list[list[dict[str, str]]], **kwargs) -> list[float]:
    # Reward 1.0 if the completion wraps its reasoning in a single <think>...</think> block.
    pattern = r"^<think>(?!.*<think>)(.*?)</think>.*$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]
In this notebook, we will use both think_format_reward, which rewards completions that correctly follow the <think> format, and reasoning_accuracy_reward, which evaluates the correctness of the model's solution to the mathematical problem. Together, these rewards guide the model to generate structured reasoning while producing accurate answers.
Train model
(Optional) Save the fine-tuned model
In this step, we save the fine-tuned model locally.