Vision-R1: Visual System 2 Reasoning via GRPO
Overview
Vision-R1 extends the DeepSeek-R1 paradigm — using reinforcement learning to induce explicit chain-of-thought reasoning — to the vision-language domain. Most VLMs operate in “System 1” mode (fast, reactive answering) and lack the ability to perform deliberate, step-by-step visual reasoning. The system teaches Qwen2.5-VL-7B to (1) parse visual information explicitly via <visual_parse> tags, (2) reason step-by-step via <think> tags, and (3) produce verifiable answers via \boxed{} notation — through a three-phase pipeline progressing from supervised imitation to reinforcement learning with decomposed rule-based process supervision (no neural reward model). The reward function decomposes into four interpretable components — correctness (0.5), spatial grounding (0.2), symbolic verification (0.2), and format compliance (0.1) — providing fine-grained credit assignment without reward model training overhead. On Geometry3K, the system achieves ~45% Pass@1 (vs. ~5% base model) with ~65% Pass@16 via majority voting, while blind baseline ablation (black-image inputs) confirms genuine visual grounding rather than textual shortcut exploitation.
Relationship to DeepSeek-R1
| Aspect | DeepSeek-R1 | Vision-R1 |
|---|---|---|
| Domain | Text-only reasoning (math, code) | Visual + spatial reasoning |
| Base model | DeepSeek-V3 | Qwen2.5-VL-7B |
| RL algorithm | GRPO | GRPO (same) |
| Reasoning format | <think> tags | <visual_parse> + <think> tags |
| Reward model | Trained neural RM | Rule-based decomposed rewards |
| Key addition | — | Visual grounding verification + spatial coordinate validation |
The fundamental contribution is adapting “reasoning via RL” for multimodal inputs where the model must ground its reasoning in visual perception, not just textual context — introducing <visual_parse> as an auditable “perception checkpoint” that separates visual processing from logical reasoning.
Three-Phase Training Pipeline
Phase 1: SFT Cold Start
Bootstraps the structured reasoning format through next-token prediction on annotated reasoning traces. Data is JSONL with (image, question, answer, completion) tuples where completions follow the <visual_parse> → <think> → \boxed{} format.
| Parameter | Value |
|---|---|
| Learning rate | 2e-5 |
| Batch size | 4 × 4 gradient accumulation = 16 effective |
| Epochs | 3 |
| Warmup ratio | 0.1 |
| Weight decay | 0.01 |
| Max sequence length | 2,048 tokens |
| Precision | BFloat16 |
Uses SFTTrainer from TRL with a custom qwen_sft_collator() that combines prompt + completion, processes images via Qwen processor, and sets labels = input_ids.clone() for standard causal LM training.
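The collator logic described above can be sketched in pure Python. Here `qwen_sft_sketch` and `process_fn` are hypothetical names: `process_fn` stands in for the Qwen2.5-VL processor (which also encodes the images), and the real collator returns tensors rather than lists.

```python
def qwen_sft_sketch(batch, process_fn):
    """Sketch of the collator: combine prompt + completion, tokenize, and
    set labels equal to input_ids for standard causal LM training.
    process_fn is a stand-in for the Qwen2.5-VL processor."""
    texts = [ex["prompt"] + ex["completion"] for ex in batch]
    features = process_fn(texts)                 # -> {"input_ids": [...]}
    # equivalent to labels = input_ids.clone() in the tensor version
    features["labels"] = [list(ids) for ids in features["input_ids"]]
    return features
```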
Phase 1.5: Expert Iteration (Optional)
Rejection sampling to narrow the policy distribution toward correct trajectories before RL — inspired by AlphaGeometry:
- Generate K=16 candidate solutions per problem using temperature sampling
- Extract answers via \boxed{} or <answer> tags with numeric tolerance
- Retain only correct solutions as curated training data
- Track coverage (problems with ≥1 correct), pass rate, and average correct samples per problem
This bridges SFT and GRPO by filtering the model’s own generations for correctness, improving GRPO stability.
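The filtering loop above can be sketched as follows. This is a minimal illustration, not the repo's implementation: `generate` and `extract_answer` are hypothetical stand-ins for temperature sampling and \boxed{}/<answer> parsing.

```python
def expert_iteration_filter(problems, generate, extract_answer, k=16, tol=1e-6):
    """Rejection sampling: keep only the model's own correct solutions.
    generate/extract_answer are illustrative stand-ins."""
    curated, solved = [], 0
    for prob in problems:
        correct = []
        for _ in range(k):
            solution = generate(prob["question"])     # temperature sampling
            answer = extract_answer(solution)
            if answer is not None and abs(answer - prob["answer"]) <= tol:
                correct.append(solution)
        if correct:                                   # coverage: >=1 correct
            solved += 1
            curated.extend({"question": prob["question"], "completion": s}
                           for s in correct)
    coverage = solved / len(problems) if problems else 0.0
    return curated, coverage
```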
Phase 2: GRPO Reinforcement Learning
Group Relative Policy Optimization refines reasoning quality beyond SFT — samples G=4 completions per prompt, computes group-relative advantages (no critic network needed), and updates the policy with KL penalty against the SFT reference.
| Parameter | Value |
|---|---|
| Learning rate | 1e-5 |
| Batch size | 2 × 8 gradient accumulation = 16 effective |
| Epochs | 3 |
| GRPO group size | 4 completions per prompt |
| KL penalty (β) | 0.1 |
| Temperature | 0.7 |
| Top-p | 0.95 |
| Repetition penalty | 1.1 |
| Max completion length | 1,024 tokens |
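The group-relative advantage that replaces a critic network can be sketched in a few lines (hypothetical helper name; normalization details may differ from the TRL implementation):

```python
import statistics

def group_relative_advantages(rewards, eps=1e-8):
    """GRPO advantage for one prompt's G completions: standardize each
    reward against the group's mean and std (no critic network needed)."""
    mean = statistics.mean(rewards)
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]
```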
Decomposed Rule-Based Reward System
Rather than training a neural reward model (expensive, prone to reward hacking), Vision-R1 uses four interpretable, orthogonal reward components providing fine-grained credit assignment:
Format Reward (weight: 0.1)
Regex-based structural validation:
- <think>...</think> tags present: +0.3
- \boxed{} or <answer> tags present: +0.3
- <visual_parse>...</visual_parse> tags present: +0.2
- All tags properly closed (balanced): +0.2
- Unclosed tags: penalty (score < 0.3)
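A minimal sketch of this rubric (the exact unclosed-tag penalty is an assumption; the real scorer may weight things differently):

```python
import re

def format_reward(text):
    """Regex-based structural check following the rubric above (sketch)."""
    score = 0.0
    if re.search(r"<think>.*?</think>", text, re.DOTALL):
        score += 0.3
    if re.search(r"\\boxed\{", text) or re.search(r"<answer>.*?</answer>", text, re.DOTALL):
        score += 0.3
    if re.search(r"<visual_parse>.*?</visual_parse>", text, re.DOTALL):
        score += 0.2
    balanced = all(text.count(f"<{tag}>") == text.count(f"</{tag}>")
                   for tag in ("think", "answer", "visual_parse"))
    if balanced:
        score += 0.2
    else:
        score = min(score, 0.25)   # unclosed tags: keep final score below 0.3
    return score
```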
Grounding Reward (weight: 0.2)
Validates spatial coordinate outputs in [x1, y1, x2, y2] format:
- Extracts coordinate arrays via regex
- Bounds check: all values in [0, 1000]
- Ordering check: x1 < x2 and y1 < y2 (valid bounding box geometry)
- Awards 0.2 per valid coordinate set, capped at 1.0
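The checks above can be sketched as a single scoring function (the regex and the assumption of integer coordinates are simplifications):

```python
import re

def grounding_reward(text, max_coord=1000):
    """Award 0.2 per well-formed [x1, y1, x2, y2] box (bounds + ordering),
    capped at 1.0. Sketch; the real extractor may accept floats."""
    score = 0.0
    pattern = r"\[\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\]"
    for match in re.finditer(pattern, text):
        x1, y1, x2, y2 = map(int, match.groups())
        in_bounds = all(0 <= v <= max_coord for v in (x1, y1, x2, y2))
        if in_bounds and x1 < x2 and y1 < y2:   # valid bounding box geometry
            score += 0.2
    return min(score, 1.0)
```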
Symbolic Verification Reward (weight: 0.2)
Validates mathematical equations in reasoning traces:
- Converts caret notation (^) to Python power syntax (**)
- Parses equations using SymPy’s sympify() for syntactic validity
- Awards points per valid equation, capped at 1.0
- Graceful fallback if SymPy unavailable
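A sketch of this component, assuming equations arrive as a list of strings and a per-equation weight of 0.2 (the weight is an assumption):

```python
def symbolic_reward(equations):
    """Award 0.2 per equation that SymPy can parse, capped at 1.0 (sketch)."""
    try:
        from sympy import sympify
        from sympy.core.sympify import SympifyError
    except ImportError:
        return 0.0                      # graceful fallback: SymPy unavailable
    score = 0.0
    for eq in equations:
        expr = eq.replace("^", "**")    # caret -> Python power syntax
        lhs, _, rhs = expr.partition("=")
        try:
            sympify(lhs)
            if rhs:
                sympify(rhs)
            score += 0.2
        except (SympifyError, SyntaxError, TypeError):
            continue
    return min(score, 1.0)
```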
Correctness Reward (weight: 0.5)
Binary (0 or 1) with cascading matchers:
- Extract predicted answer from \boxed{} using a depth-tracking brace parser (handles nested braces like \boxed{\frac{a^{2}}{b^{2}}})
- Symbolic verification via the math_verify library
- Numeric comparison with tolerance (1e-6)
- Case-insensitive string matching fallback
Length penalty (additional): Graduated penalty for responses exceeding 2,048 characters (max −0.5), discouraging verbose, unfocused reasoning.
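The brace parser and the numeric/string fallbacks can be sketched as follows (the math_verify step is omitted here; function names are illustrative):

```python
def extract_boxed(text):
    """Depth-tracking brace parser: return the content of the last
    \\boxed{...}, handling nested braces like \\boxed{\\frac{a^{2}}{b^{2}}}."""
    marker = r"\boxed{"
    start = text.rfind(marker)
    if start == -1:
        return None
    depth, chars = 1, []
    for ch in text[start + len(marker):]:
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:              # matching close brace found
                return "".join(chars)
        chars.append(ch)
    return None                         # unbalanced braces

def answers_match(pred, gold, tol=1e-6):
    """Numeric comparison with tolerance, falling back to case-insensitive
    string matching (the symbolic math_verify step is skipped here)."""
    try:
        return abs(float(pred) - float(gold)) <= tol
    except (TypeError, ValueError):
        return str(pred).strip().lower() == str(gold).strip().lower()
```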
Model Architecture
| Component | Specification |
|---|---|
| Base Model | Qwen2.5-VL-7B-Instruct (NaViT vision encoder + language decoder) |
| Dynamic Resolution | 256×28×28 to 1280×28×28 pixels |
| LoRA Rank / Alpha | 64 / 128 (alpha/r = 2.0) |
| LoRA Targets | All 7 projections: q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj |
| LoRA Dropout | 0.05 |
| Modules to save | embed_tokens, lm_head (fully trainable, not LoRA-adapted — critical for VLM fine-tuning stability) |
| Trainable parameters | ~3–5% of total |
| Precision | BFloat16 |
| Attention fallback | Flash Attention 2 → SDPA → Eager (auto-detected) |
The modules_to_save configuration ensures embedding and output layers are fully trainable rather than LoRA-adapted — noted as critical for stable VLM fine-tuning where the model must learn to map between visual tokens and reasoning output.
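In PEFT terms, the table above corresponds to a LoraConfig along these lines (a sketch assembled from the stated hyperparameters, not the repo's exact file):

```python
from peft import LoraConfig

lora_config = LoraConfig(
    r=64,
    lora_alpha=128,                     # alpha/r = 2.0
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    modules_to_save=["embed_tokens", "lm_head"],  # fully trainable, not LoRA-adapted
    task_type="CAUSAL_LM",
)
```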
Inference: Generation Priming + Best-of-N
During inference, generation is primed by appending <think> to the prompt, forcing the model into its trained reasoning mode from the first token. Best-of-N sampling generates N candidates with temperature > 0, extracts answers from each via \boxed{}, and applies majority voting using frequency-based selection for the final answer.
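The majority-voting step can be sketched as follows, with `extract` standing in for the \boxed{} answer parser (names are illustrative):

```python
from collections import Counter

def majority_vote(candidates, extract):
    """Frequency-based selection over N sampled completions;
    samples with no parseable answer are dropped."""
    answers = [a for a in (extract(c) for c in candidates) if a is not None]
    if not answers:
        return None
    return Counter(answers).most_common(1)[0][0]
```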
Evaluation and Ablation Studies
Results
| Configuration | Pass@1 | Pass@16 (Majority Vote) |
|---|---|---|
| Base model (Qwen2.5-VL-7B) | ~5% | ~10% |
| + SFT (Phase 1) | ~35% | ~50% |
| + SFT + GRPO (Phase 1 + 2) | ~45% | ~65% |
| Blind baseline (black images) | ~5% | ~5% |
Key findings: (1) SFT provides a +30pp improvement (5% → 35%), demonstrating the value of structured reasoning format. (2) GRPO adds +10pp beyond SFT (35% → 45%), demonstrating RL-based reasoning quality improvement. (3) Majority voting at test-time provides +20pp additional gains (45% → 65%), validating test-time compute scaling. (4) Blind baseline equals random chance (~5%), confirming genuine visual grounding — the model actually uses image content.
Three Ablation Experiments
Blind Baseline Comparison: Replaces all images with 224×224 black squares. A performance drop of more than 10 percentage points confirms genuine visual grounding rather than textual shortcut exploitation.
Thinking Length Scaling: Bins reasoning tokens into ranges (0–50, 50–100, …, 1000+) and computes per-bin accuracy with correlation coefficients and polynomial trend fitting. Demonstrates that more thinking leads to better answers — evidence for test-time compute scaling analogous to DeepSeek-R1 and OpenAI o1.
SFT vs. RL Comparison: Compares accuracy across three stages (base, +SFT, +SFT+GRPO) with delta annotations, isolating each phase’s contribution.
Pass@k Implementation
Uses the unbiased estimator: Pass@k = 1 − C(n−c, k) / C(n, k) where n = total samples, c = correct samples, with overflow prevention and edge-case handling.
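The estimator is a few lines in Python; math.comb keeps the arithmetic in exact integers, which handles the overflow concern directly:

```python
from math import comb

def pass_at_k(n, c, k):
    """Unbiased estimator: Pass@k = 1 - C(n-c, k) / C(n, k),
    where n = total samples and c = correct samples."""
    if n - c < k:
        return 1.0          # every size-k draw contains a correct sample
    return 1.0 - comb(n - c, k) / comb(n, k)
```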
Tech Stack
Python (3.10+), PyTorch (2.4+), Hugging Face Transformers (4.45+), TRL (GRPOTrainer, SFTTrainer), PEFT (LoRA), Qwen2.5-VL-7B, Flash Attention 2, SymPy (symbolic verification), math_verify, Weights & Biases, Matplotlib, Gradio (interactive demo)