Tiny-Reason: Distilling Reasoning into 1.5B Parameter Models

Overview

Tiny-Reason demonstrates that explicit mathematical reasoning can be distilled into a 1.5B-parameter model using structured chain-of-thought supervision and extreme parameter efficiency: only ~1.3M parameters (0.087% of the total) are trained via QLoRA on just 2,000 GSM8K samples, in ~45 minutes on a free Google Colab T4 while peaking under ~6GB of VRAM. The approach is deliberately simple: pure supervised fine-tuning (no RL, no reward model, no teacher distillation) on human-written reasoning traces reformatted into structured <think> tags within ChatML templates. LoRA adapters are applied to all 7 projection matrices per transformer block, including the SwiGLU FFN components (gate_proj, up_proj, down_proj): a deliberate divergence from standard LoRA practice, based on the insight that mathematical computation relies heavily on feed-forward layers, not just attention. The fine-tuned model achieves 35–45% accuracy on held-out GSM8K test samples (vs. ~30% for the base model), demonstrating meaningful transfer of reasoning capability despite extreme resource constraints.

Structured CoT Data Pipeline

The GSM8KProcessor transforms raw GSM8K samples into a reasoning-friendly training format. It parses the #### delimiter that separates the step-by-step reasoning from the final answer, then wraps both parts in structured tags within Qwen’s ChatML template:

<|im_start|>system
You are a helpful assistant that solves math problems step-by-step.<|im_end|>
<|im_start|>user
[question]<|im_end|>
<|im_start|>assistant
<think>
[step-by-step reasoning from GSM8K]
</think>

**Final Answer:** [numerical answer]<|im_end|>

This creates three distinct gradient-receiving regions: (a) the reasoning trace inside <think> tags, (b) the format delimiter, and (c) the answer value — teaching the model both problem decomposition and result extraction simultaneously. The structured <think> format mirrors the approach used by DeepSeek-R1, but adapted for ChatML templates.
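As a sketch, the reformatting step can be written as follows (a hypothetical helper; the actual GSM8KProcessor API may differ, and GSM8K stores each answer as "reasoning #### final_answer"):

```python
from typing import Optional

SYSTEM = "You are a helpful assistant that solves math problems step-by-step."

def format_sample(question: str, answer_field: str) -> Optional[str]:
    """Reformat one GSM8K row into the ChatML + <think> training string."""
    if "####" not in answer_field:
        return None  # malformed entry: filtered out upstream
    reasoning, _, final = answer_field.partition("####")
    return (
        f"<|im_start|>system\n{SYSTEM}<|im_end|>\n"
        f"<|im_start|>user\n{question}<|im_end|>\n"
        f"<|im_start|>assistant\n<think>\n{reasoning.strip()}\n</think>\n\n"
        f"**Final Answer:** {final.strip()}<|im_end|>"
    )
```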

At inference time, generation is primed by appending <think>\n to the prompt, forcing the model into its trained reasoning mode from the first token.

Data processing: 2,000 samples subsampled from GSM8K train split (shuffled with seed 3407), split 90/10 into 1,800 training + 200 validation samples. Malformed entries (missing ####) are gracefully filtered while maintaining batch alignment.
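The shuffle-subsample-split logic reduces to plain index bookkeeping, sketched below (the actual code operates on a Hugging Face Dataset; 7,473 is GSM8K's published train-split size):

```python
import random

def split_indices(n_total: int, n_subset: int, val_frac: float, seed: int):
    """Shuffle all indices, take a subset, then carve off a validation slice."""
    rng = random.Random(seed)
    indices = list(range(n_total))
    rng.shuffle(indices)
    subset = indices[:n_subset]
    n_val = int(n_subset * val_frac)
    return subset[n_val:], subset[:n_val]  # (train, validation)

train_idx, val_idx = split_indices(7473, 2000, 0.10, seed=3407)
# 1,800 training and 200 validation indices, disjoint by construction
```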

Model Architecture and QLoRA Configuration

| Component | Specification |
| --- | --- |
| Base model | Qwen2.5-1.5B-Instruct (28 transformer layers) |
| Quantization | 4-bit NF4 via bitsandbytes (~0.8GB VRAM, down from ~3GB) |
| LoRA rank / alpha | 16 / 32 (alpha/r = 2.0) |
| LoRA dropout | 0 |
| LoRA targets | All 7 projections: q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj |
| Trainable parameters | ~1.3M (0.087% of 1.5B total) |
| Gradient checkpointing | Unsloth-optimized variant |
| Max sequence length | 2,048 tokens |
| Precision | BF16 on Ampere+ (A100, RTX 3090+), FP16 on older GPUs such as the Turing T4 (auto-detected) |
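The quantization and precision rows correspond roughly to the bitsandbytes configuration below (a sketch: the project loads the model through Unsloth's wrappers, and double quantization is an assumption on my part):

```python
import torch
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",            # NF4 data type, as in the table
    bnb_4bit_use_double_quant=True,       # assumption: double quantization on
    # Auto-detect compute dtype: BF16 where supported, else FP16 (e.g. T4)
    bnb_4bit_compute_dtype=(
        torch.bfloat16
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
        else torch.float16
    ),
)
```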

Why LoRA on FFN layers: Most LoRA implementations target only attention projections. Tiny-Reason deliberately includes all three SwiGLU FFN components (gate_proj, up_proj, down_proj) because mathematical computation relies heavily on the feed-forward pathway for numerical operations — attention captures which numbers to combine, but the FFN performs the computation itself. This adds minimal parameter overhead while significantly improving mathematical reasoning quality.
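In Hugging Face PEFT terms, the adapter setup above looks roughly like this (hyperparameters taken from the table; the project may configure them through Unsloth rather than a raw LoraConfig):

```python
from peft import LoraConfig

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,      # alpha/r = 2.0
    lora_dropout=0.0,
    bias="none",
    task_type="CAUSAL_LM",
    # All 7 projections per block: attention plus the SwiGLU feed-forward path
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",   # attention
        "gate_proj", "up_proj", "down_proj",      # SwiGLU FFN
    ],
)
```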

Training Configuration

| Parameter | Value |
| --- | --- |
| Per-device batch size | 2 |
| Gradient accumulation | 4 steps (effective batch size 8) |
| Max training steps | 250 |
| Learning rate | 2e-4 with linear decay |
| Warmup steps | 50 |
| Optimizer | adamw_8bit (8-bit AdamW) |
| Weight decay | 0.01 |
| Evaluation interval | Every 50 steps |
| Checkpoint saving | Every 50 steps (max 3 kept) |
| Best model selection | load_best_model_at_end=True on eval loss |
| Packing | Disabled (each sample occupies its own sequence) |

Pure SFT approach: No reinforcement learning, no reward model, no GRPO, no PPO. The reasoning capability comes entirely from structured CoT supervision and the standard causal language modeling loss (cross-entropy) applied to the formatted reasoning traces. This makes the approach maximally simple, reproducible, and accessible.

Validation-guided checkpointing: With only 250 training steps on 1,800 samples, overfitting is a real risk. The system evaluates every 50 steps on the 200-sample validation split and retains the best checkpoint, preventing degradation.
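The table and the checkpointing behavior map onto roughly the following arguments (field names follow transformers.TrainingArguments on recent versions; the project's actual entry point may be a TRL/Unsloth wrapper, and the output path is illustrative):

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="tiny-reason-ckpt",      # illustrative path
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,      # effective batch size 8
    max_steps=250,
    learning_rate=2e-4,
    lr_scheduler_type="linear",
    warmup_steps=50,
    optim="adamw_8bit",                 # 8-bit AdamW via bitsandbytes
    weight_decay=0.01,
    eval_strategy="steps",
    eval_steps=50,
    save_steps=50,
    save_total_limit=3,                 # keep at most 3 checkpoints
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)
```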

Evaluation

Evaluation uses greedy decoding (do_sample=False, max_new_tokens=512) on 50 held-out GSM8K test samples. Predicted answers are extracted via regex (\*\*Final Answer:\*\*\s*(.+)) and normalized through a semantically aware pipeline:

  • Strip whitespace, lowercase
  • Remove commas/spaces from numbers ("1,000" → "1000")
  • Strip floating-point artifacts (trailing .0)
  • Convert English word numbers to digits ("zero" through "ten" → 0–10)
  • Remove currency symbols ($) and unit suffixes ("dollars", "meters", "feet")
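A minimal sketch of that normalization pipeline (the function name and exact rule set are illustrative; the project's implementation may differ in detail):

```python
import re

# English word numbers "zero" through "ten" -> "0" through "10"
_WORDS = {w: str(i) for i, w in enumerate(
    ["zero", "one", "two", "three", "four", "five",
     "six", "seven", "eight", "nine", "ten"])}
_UNITS = ("dollars", "meters", "feet")

def normalize_answer(raw: str) -> str:
    s = raw.strip().lower()
    s = s.replace("$", "")                    # currency symbol
    for unit in _UNITS:
        s = s.replace(unit, "")               # unit suffixes
    s = s.strip()
    s = _WORDS.get(s, s)                      # "seven" -> "7"
    s = re.sub(r"(?<=\d)[,\s](?=\d)", "", s)  # "1,000" -> "1000"
    s = re.sub(r"\.0+$", "", s)               # "42.0" -> "42"
    return s
```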

Results

| Configuration | GSM8K Accuracy |
| --- | --- |
| Base Qwen2.5-1.5B-Instruct | ~30% |
| Tiny-Reason (2K samples) | 35–45% |
| Frontier models (GPT-4, etc.) | 80–90%+ |

This is a gain of 5–15 percentage points despite extreme constraints: 0.087% trainable parameters, 2,000 training samples, 250 steps, and ~45 minutes on free hardware. It demonstrates that structured chain-of-thought supervision transfers meaningful reasoning capability even to sub-2B models.

Inference

Interactive REPL with real-time token streaming via TextStreamer from Transformers. Uses low-temperature sampling (temperature=0.1, do_sample=True) for near-deterministic but slightly varied outputs with max_new_tokens=1024.
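The prompt-priming side of the REPL can be sketched as below, assembled by hand from the ChatML template shown earlier (the actual code builds the prompt with the tokenizer's chat template, then streams model.generate output through TextStreamer with temperature=0.1, do_sample=True, max_new_tokens=1024):

```python
SYSTEM = "You are a helpful assistant that solves math problems step-by-step."

def build_prompt(question: str) -> str:
    # The assistant turn is left open and primed with "<think>\n",
    # so the first generated token lands inside the reasoning block.
    return (
        f"<|im_start|>system\n{SYSTEM}<|im_end|>\n"
        f"<|im_start|>user\n{question}<|im_end|>\n"
        f"<|im_start|>assistant\n<think>\n"
    )
```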

Reproducibility

Full deterministic seeding across Python random, NumPy, PyTorch CPU/CUDA, and PYTHONHASHSEED. Every checkpoint saves the exact experiment config YAML alongside model weights and tokenizer. Seed 3407 — a reference to the paper “Torch.manual_seed(3407) is all you need” showing this seed often produces favorable training dynamics.
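A sketch of such a seeding utility (the function name is illustrative; note that setting PYTHONHASHSEED at runtime only affects subprocesses, since the running interpreter's hash randomization is fixed at startup):

```python
import os
import random

import numpy as np
import torch

def set_seed(seed: int = 3407) -> None:
    """Seed every RNG the training run touches."""
    os.environ["PYTHONHASHSEED"] = str(seed)  # inherited by subprocesses
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)            # CPU RNG
    torch.cuda.manual_seed_all(seed)   # all CUDA devices; no-op without CUDA
```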

Tech Stack

Python (3.9+), PyTorch (2.1+), Hugging Face Transformers (4.38+), TRL (SFTTrainer), PEFT (LoRA), bitsandbytes (NF4 + 8-bit AdamW), Unsloth (optimized training kernels), Datasets (GSM8K), xformers