Raghuboi/pocket-gpt

pocket-gpt

From-scratch reproductions of foundational ML papers. Pure PyTorch, no frameworks.

Each reproduction lives in a single model.py you can read top to bottom. Training loops are similarly minimal. The goal: if you cat model.py, you understand every line and can trace it back to the paper.

Reproductions

| Paper | Model | Params | Status |
|---|---|---|---|
| Attention Is All You Need (2017) | model.py | ~21M (base) | Trained |

Quick Start

# Setup
python -m venv .venv && source .venv/bin/activate
pip install -r requirements.txt

# Prepare data (Multi30k EN-DE, ~30K sentence pairs)
python data/multi30k/prepare.py

# Train tiny model (~30 min on GPU, sanity check)
python train.py config/tiny.py

# Train paper-faithful baseline (~4 hours on GPU)
python train.py config/baseline.py

# Translate
python sample.py --ckpt=out-baseline-multi30k/best.pt --src="Hello world"

Attention Is All You Need

Faithful reproduction of the Transformer from Vaswani et al. (2017).

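Architecture: Encoder-decoder with multi-head self-attention, sinusoidal positional encoding (PE), and a position-wise feed-forward network (FFN) with ReLU. Post-LayerNorm placement, as in the paper. Three-way weight tying: the source embedding, target embedding, and pre-softmax projection share one weight matrix.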
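As an illustration of the sinusoidal PE named above, here is a minimal sketch of the paper's formula, PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). It is written in plain Python so it runs without any dependencies. It is not taken from model.py, and the function name `sinusoidal_pe` is an assumption made for this example.

```python
import math


def sinusoidal_pe(max_len, d_model):
    """Return a max_len x d_model table (list of lists) of position encodings."""
    if d_model % 2 != 0:
        raise ValueError("d_model must be even")
    table = []
    for pos in range(max_len):
        row = [0.0] * d_model
        for i in range(0, d_model, 2):
            # i is the even dimension index (2k in the paper's notation),
            # so each (sin, cos) pair shares one frequency.
            angle = pos / (10000 ** (i / d_model))
            row[i] = math.sin(angle)
            row[i + 1] = math.cos(angle)
        table.append(row)
    return table


pe = sinusoidal_pe(max_len=4, d_model=8)
print(pe[0])  # position 0: every sin entry is 0.0, every cos entry is 1.0
```

In the Transformer, this table is added to the scaled token embeddings before the first layer. A PyTorch version would typically build the same table once and store it as a non-trainable buffer.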

Configs:

  • config/tiny.py -- reduced model for rapid iteration (256d, 4 heads, 3 layers)
  • config/baseline.py -- paper-faithful base model (512d, 8 heads, 6 layers)
  • config/baseline_wmt14.py -- full WMT14 EN-DE (if you have time)
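The two sizes listed above share a per-head dimension of 64 (256 / 4 and 512 / 8), which matches the paper's base model. The sketch below checks that arithmetic. The `ModelConfig` class and its field names are hypothetical and are not the keys used in the repository's config files.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ModelConfig:
    # Hypothetical field names, chosen for illustration only.
    d_model: int
    n_heads: int
    n_layers: int

    def __post_init__(self):
        # Multi-head attention splits d_model evenly across heads.
        if self.d_model % self.n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")

    @property
    def head_dim(self):
        return self.d_model // self.n_heads


tiny = ModelConfig(d_model=256, n_heads=4, n_layers=3)
baseline = ModelConfig(d_model=512, n_heads=8, n_layers=6)
print(tiny.head_dim, baseline.head_dim)  # 64 64
```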

Expected results (Multi30k EN-DE):

| Metric | Paper (base) | Expected |
|---|---|---|
| BLEU | -- | 25-28 |
| Training time (1x RTX 4090) | -- | ~4 hours |
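For readers unfamiliar with the BLEU metric reported above, the following is a simplified sketch of corpus-level BLEU with a single reference per sentence: clipped n-gram precision for n = 1 to 4, a geometric mean, and a brevity penalty. It assumes pre-tokenized input and applies no smoothing. It is not the repository's evaluation code, and scores from standard tools can differ because of tokenization. The names `ngrams` and `corpus_bleu` are chosen for this example.

```python
import math
from collections import Counter


def ngrams(tokens, n):
    """Count the n-grams in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def corpus_bleu(hyps, refs, max_n=4):
    """Simplified BLEU (0-100) over parallel lists of token lists."""
    matched = [0] * max_n
    total = [0] * max_n
    hyp_len = ref_len = 0
    for hyp, ref in zip(hyps, refs):
        hyp_len += len(hyp)
        ref_len += len(ref)
        for n in range(1, max_n + 1):
            hyp_counts = ngrams(hyp, n)
            ref_counts = ngrams(ref, n)
            # Clip each hypothesis n-gram count by its count in the reference.
            matched[n - 1] += sum(min(c, ref_counts[g]) for g, c in hyp_counts.items())
            total[n - 1] += sum(hyp_counts.values())
    if hyp_len == 0 or min(matched) == 0:
        return 0.0  # no smoothing: any zero precision gives a zero score
    log_precision = sum(math.log(m / t) for m, t in zip(matched, total)) / max_n
    # Brevity penalty: penalize hypotheses shorter than the references.
    bp = 1.0 if hyp_len > ref_len else math.exp(1.0 - ref_len / hyp_len)
    return 100.0 * bp * math.exp(log_precision)


hyp = "the cat sat on the mat".split()
ref = "the cat sat on a mat".split()
score = corpus_bleu([hyp], [ref])
print(round(score, 1))
```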

Modernization plan -- each swap is a minimal diff, trained and evaluated independently:

| # | Swap | Expected Impact |
|---|---|---|
| 1 | Post-LN to Pre-LN | Faster convergence |
| 2 | LayerNorm to RMSNorm | Same loss, faster throughput |
| 3 | ReLU to SwiGLU | Lower val loss |
| 4 | Sinusoidal PE to RoPE | Better long-range |
| 5 | Standard attention to FlashAttention-2 | 2-3x speedup |
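To make swap 2 (LayerNorm to RMSNorm) concrete, here is a dependency-free sketch of both normalizations applied to a single vector. RMSNorm drops the mean subtraction and the bias term, which is where the throughput gain is expected to come from. The function names and the `eps` default are assumptions for this example, not the repository's code.

```python
import math


def layer_norm(x, gain, bias, eps=1e-5):
    """LayerNorm: subtract the mean, then divide by the standard deviation."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    inv_std = 1.0 / math.sqrt(var + eps)
    return [g * (v - mean) * inv_std + b for v, g, b in zip(x, gain, bias)]


def rms_norm(x, gain, eps=1e-5):
    """RMSNorm: divide by the root mean square; no mean subtraction, no bias."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [g * v * inv_rms for v, g in zip(x, gain)]


x = [1.0, -2.0, 3.0, -2.0]  # zero-mean input
ones, zeros = [1.0] * 4, [0.0] * 4
ln_out = layer_norm(x, ones, zeros)
rms_out = rms_norm(x, ones)
# On zero-mean input with unit gain and zero bias, the two outputs coincide.
```

The two differ as soon as the input has a non-zero mean, so RMSNorm is a different operation rather than a faster way to compute LayerNorm. That is why the swap is trained and evaluated separately.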

Structure

model.py              # Full Transformer (~480 lines)
train.py              # Training loop (~350 lines)
sample.py             # Inference: beam search + greedy decode
configurator.py       # Config override system (Karpathy pattern)
config/               # Python config files per experiment
data/                 # Dataset preparation scripts
tests/                # Unit tests

Flat root. No src/ directory. One file per responsibility.
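The config override system listed for configurator.py can be sketched roughly as below: defaults live in plain Python values, and `--key=value` command-line arguments replace them. This is an approximation of the general pattern, not the repository's configurator.py. The function name `apply_overrides` and the `beam_size` key are hypothetical.

```python
from ast import literal_eval


def apply_overrides(config, args):
    """Return a copy of config with --key=value arguments applied."""
    result = dict(config)
    for arg in args:
        if not arg.startswith("--") or "=" not in arg:
            raise ValueError(f"expected --key=value, got {arg!r}")
        key, raw = arg[2:].split("=", 1)
        if key not in result:
            raise KeyError(f"unknown config key: {key}")
        try:
            value = literal_eval(raw)  # numbers, booleans, quoted strings
        except (ValueError, SyntaxError):
            value = raw  # not a Python literal, so keep it as a plain string
        if type(value) is not type(result[key]):
            raise TypeError(f"{key} expects {type(result[key]).__name__}")
        result[key] = value
    return result


# 'beam_size' is a made-up key for illustration.
defaults = {"ckpt": "out/best.pt", "src": "", "beam_size": 4}
cfg = apply_overrides(
    defaults,
    ["--ckpt=out-baseline-multi30k/best.pt", "--src=Hello world", "--beam_size=1"],
)
print(cfg)
```

Rejecting unknown keys and mismatched types means a mistyped flag fails immediately, before a multi-hour training run starts.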

Design Decisions

Why Multi30k over WMT14? Multi30k trains in hours, not weeks, and still produces meaningful BLEU differences between architectural variants. That makes for a better signal-to-noise ratio when measuring component-level impact.

Why PyTorch only? Communication and reproducibility. No frameworks, no abstractions. Every line of model.py maps to a section of the paper.

Why this structure? Following Karpathy's pattern (nanoGPT: 57K stars). Minimal nesting, maximum readability.

Hardware

Developed on: 2x RTX 3090 + 1x RTX 4090, 80 GB RAM, Arch Linux.

License

MIT

About

Faithful reproduction of 'Attention Is All You Need' (Vaswani et al. 2017) in pure PyTorch. ~700 lines. Paper-faithful baseline + systematic modernization swaps.
