Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,10 @@ __pycache__
.python-version
*.pt
wandb/
./Code 2. Cartpole/6. A3C/Cartpole_A3C.pgy
logs/
./Code 2. Cartpole/6. A3C/Cartpole_A3C.pgy
# Local scratch scripts
scripts/

# Local-only docs (not for github)
docs/
85 changes: 59 additions & 26 deletions 3-atari/1-dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
serious training.
"""
import random
import sys
from collections import deque

import numpy as np
Expand All @@ -23,17 +22,17 @@


SAVE_PATH = "atari_dqn.pt"
TOTAL_FRAMES = 1_000_000 # bump to ~10M for paper-quality results
BUFFER_CAPACITY = 100_000 # bump to 1M with enough RAM
BATCH_SIZE = 64
TOTAL_FRAMES = 10_000_000 # Nature uses 50M agent steps; 10M is laptop-friendly
BUFFER_CAPACITY = 500_000 # ~3.5GB RAM (uint8, single frames stacked at sample time); sized for 8GB Macs
BATCH_SIZE = 32
GAMMA = 0.99
LR = 1e-4
LEARN_START = 10_000 # frames of pure exploration before training begins
LEARN_START = 80_000 # frames of pure exploration before training begins
TRAIN_EVERY = 4
TARGET_UPDATE_EVERY = 1_000 # in training steps, not env steps
TARGET_UPDATE_EVERY = 250 # in training steps, not env steps (~1k env frames)
EPSILON_START = 1.0
EPSILON_END = 0.05
EPSILON_DECAY_FRAMES = 250_000 # linear decay from start to end over this many frames
EPSILON_END = 0.01
EPSILON_DECAY_FRAMES = 1_000_000 # linear decay from start to end over this many frames


# Standard Nature CNN.
Expand All @@ -57,34 +56,59 @@ def forward(self, x):


class ReplayBuffer:
"""Uint8 replay buffer — far more memory-efficient than storing floats."""
"""Single-frame uint8 buffer — stacks of 4 are reconstructed at sample time,
cutting RAM ~4x vs. storing the full stack per slot."""

def __init__(self, capacity, obs_shape):
def __init__(self, capacity, frame_shape=(84, 84), stack=4):
self.capacity = capacity
self.obs = np.zeros((capacity, *obs_shape), dtype=np.uint8)
self.next_obs = np.zeros((capacity, *obs_shape), dtype=np.uint8)
self.actions = np.zeros(capacity, dtype=np.int64)
self.rewards = np.zeros(capacity, dtype=np.float32)
self.dones = np.zeros(capacity, dtype=np.float32)
self.stack = stack
self.frames = np.zeros((capacity, *frame_shape), dtype=np.uint8)
self.actions = np.zeros(capacity, dtype=np.int64)
self.rewards = np.zeros(capacity, dtype=np.float32)
self.dones = np.zeros(capacity, dtype=np.float32)
self.idx = 0
self.size = 0

def push(self, obs, action, reward, next_obs, done):
self.obs[self.idx] = obs
def push(self, frame, action, reward, done):
self.frames[self.idx] = frame
self.actions[self.idx] = action
self.rewards[self.idx] = reward
self.next_obs[self.idx] = next_obs
self.dones[self.idx] = float(done)
self.idx = (self.idx + 1) % self.capacity
self.size = min(self.size + 1, self.capacity)

def _stack(self, idx):
# Gather frames[idx-stack+1 .. idx]; newest at last channel.
offsets = np.arange(self.stack)
gather = (idx[:, None] - (self.stack - 1) + offsets[None, :]) % self.capacity
out = self.frames[gather]
# Zero out frames sitting before an episode boundary inside the stack.
# dones at the (stack-1) older positions mark where a prior episode ended.
older = self.dones[gather[:, :-1]].astype(bool)
# Once we cross any done walking newest→oldest, everything older is invalid.
invalid = np.cumsum(older[:, ::-1], axis=1)[:, ::-1] > 0
mask = np.concatenate([~invalid, np.ones((idx.shape[0], 1), dtype=bool)], axis=1)
return out * mask[:, :, None, None]

def sample(self, batch_size, device):
idx = np.random.randint(0, self.size, size=batch_size)
# Reject indices whose stack would straddle the write head (stale frames).
while True:
if self.size < self.capacity:
if self.size < self.stack + 2:
raise RuntimeError("buffer too small to sample yet")
idx = np.random.randint(self.stack - 1, self.size - 1, size=batch_size)
break
idx = np.random.randint(0, self.capacity, size=batch_size)
dist = (self.idx - 1 - idx) % self.capacity
if np.all(dist >= self.stack):
break
states = self._stack(idx)
next_states = self._stack((idx + 1) % self.capacity)
return (
torch.as_tensor(self.obs[idx], device=device),
torch.as_tensor(states, device=device),
torch.as_tensor(self.actions[idx], device=device),
torch.as_tensor(self.rewards[idx], device=device),
torch.as_tensor(self.next_obs[idx], device=device),
torch.as_tensor(next_states, device=device),
torch.as_tensor(self.dones[idx], device=device),
)

Expand Down Expand Up @@ -130,10 +154,12 @@ def greedy_action(obs):

print(f"device: {device}, env: {args.env}, actions: {n_actions}")

buffer = ReplayBuffer(BUFFER_CAPACITY, env.observation_space.shape)
buffer = ReplayBuffer(BUFFER_CAPACITY)
obs, _ = env.reset()
ep_return = 0.0
ep_return = 0.0 # accumulates within one life (LifeLossTerminalEnv ends an "episode" per life)
game_return = 0.0 # accumulates across all 5 lives until real game-over
recent_returns = deque(maxlen=20)
recent_game_returns = deque(maxlen=20)
train_step = 0
last_loss = 0.0

Expand All @@ -146,18 +172,23 @@ def greedy_action(obs):
else:
action = greedy_action(obs)

next_obs, reward, terminated, truncated, _ = env.step(action)
next_obs, reward, terminated, truncated, info = env.step(action)
done = terminated or truncated
# Reward clipping (DeepMind standard) — keeps Q-values from blowing up
# when one game has rewards in tens and another in hundreds.
clipped = np.sign(reward)
buffer.push(np.asarray(obs), action, clipped, np.asarray(next_obs), done)
# FrameStack gives (4, 84, 84); store just the newest frame and stack at sample time.
buffer.push(np.asarray(obs)[-1], action, clipped, done)

ep_return += reward
game_return += reward
obs = next_obs
if done:
recent_returns.append(ep_return)
ep_return = 0.0
if info.get("game_over", True):
recent_game_returns.append(game_return)
game_return = 0.0
obs, _ = env.reset()

# Training.
Expand All @@ -182,12 +213,14 @@ def greedy_action(obs):
# Logging.
if frame % 10_000 == 0:
mean = float(np.mean(recent_returns)) if recent_returns else 0.0
game_mean = float(np.mean(recent_game_returns)) if recent_game_returns else 0.0
print(f"frame: {frame:>8} eps: {epsilon(frame):.3f} "
f"recent_mean_return: {mean:.1f} buffer: {buffer.size}")
f"per_life: {mean:.1f} per_game: {game_mean:.1f} buffer: {buffer.size}")
if args.wandb:
wandb.log({
"global_step": frame,
"recent_mean_return": mean,
"recent_mean_game_return": game_mean,
"epsilon": epsilon(frame),
"loss": last_loss,
"buffer_size": buffer.size,
Expand Down
47 changes: 38 additions & 9 deletions 3-atari/2-ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


SAVE_PATH = "atari_ppo.pt"
TOTAL_FRAMES = 5_000_000
TOTAL_FRAMES = 10_000_000
N_ENVS = 8
ROLLOUT_STEPS = 128 # batch = N_ENVS * ROLLOUT_STEPS = 1024
EPOCHS = 4
Expand Down Expand Up @@ -109,10 +109,17 @@ def policy_action(obs):
frames_per_update = batch_size
n_updates = TOTAL_FRAMES // frames_per_update
obs, _ = envs.reset()
ep_returns_per_env = np.zeros(N_ENVS, dtype=np.float32)
ep_returns_per_env = np.zeros(N_ENVS, dtype=np.float32) # per-life (resets every life loss)
game_returns_per_env = np.zeros(N_ENVS, dtype=np.float32) # per-game (resets only on real game-over)
ep_returns = []
game_returns = []

for update in range(1, n_updates + 1):
# Linear LR anneal from LR -> 0 over the run (CleanRL convention).
lr_now = LR * (1.0 - (update - 1) / n_updates)
for g in optimizer.param_groups:
g["lr"] = lr_now

obs_buf = np.zeros((ROLLOUT_STEPS, N_ENVS, *obs_shape), dtype=np.uint8)
act_buf = np.zeros((ROLLOUT_STEPS, N_ENVS), dtype=np.int64)
logp_buf = np.zeros((ROLLOUT_STEPS, N_ENVS), dtype=np.float32)
Expand All @@ -134,29 +141,35 @@ def policy_action(obs):
logp_buf[t] = logp.cpu().numpy()
val_buf[t] = value.cpu().numpy()

next_obs, reward, terminated, truncated, _ = envs.step(act_buf[t])
next_obs, reward, terminated, truncated, info = envs.step(act_buf[t])
done = np.logical_or(terminated, truncated)
ep_returns_per_env += reward
game_returns_per_env += reward
rew_buf[t] = np.sign(reward).astype(np.float32) # DeepMind reward clipping
done_buf[t] = done.astype(np.float32)

# LifeLossTerminalEnv tags each step's info with game_over (True only on real game-over).
game_over = info.get("game_over", done)
for i in range(N_ENVS):
if done[i]:
ep_returns.append(float(ep_returns_per_env[i]))
ep_returns_per_env[i] = 0.0
if bool(game_over[i]):
game_returns.append(float(game_returns_per_env[i]))
game_returns_per_env[i] = 0.0
obs = next_obs

# --- GAE ---
with torch.no_grad():
obs_t = torch.as_tensor(np.asarray(obs), device=device)
_, last_value = model(obs_t)
advantages, returns = compute_gae(rew_buf, val_buf, done_buf, last_value.cpu().numpy())
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

# Flatten (T, N_ENVS, ...) -> (T*N_ENVS, ...)
obs_t = torch.as_tensor(obs_buf.reshape(batch_size, *obs_shape), device=device)
act_t = torch.as_tensor(act_buf.reshape(batch_size), device=device)
old_logp_t = torch.as_tensor(logp_buf.reshape(batch_size), device=device)
old_val_t = torch.as_tensor(val_buf.reshape(batch_size), device=device)
adv_t = torch.as_tensor(advantages.reshape(batch_size), device=device)
ret_t = torch.as_tensor(returns.reshape(batch_size), device=device)

Expand All @@ -173,11 +186,22 @@ def policy_action(obs):
new_logp = dist.log_prob(act_t[mb])
entropy = dist.entropy().mean()

# Advantage normalization per minibatch (CleanRL convention).
mb_adv = adv_t[mb]
mb_adv = (mb_adv - mb_adv.mean()) / (mb_adv.std() + 1e-8)

ratio = (new_logp - old_logp_t[mb]).exp()
unclipped = ratio * adv_t[mb]
clipped = torch.clamp(ratio, 1 - CLIP_COEF, 1 + CLIP_COEF) * adv_t[mb]
unclipped = ratio * mb_adv
clipped = torch.clamp(ratio, 1 - CLIP_COEF, 1 + CLIP_COEF) * mb_adv
policy_loss = -torch.min(unclipped, clipped).mean()
value_loss = (values - ret_t[mb]).pow(2).mean()

# Value loss with clipping around the old value prediction.
v_clipped = old_val_t[mb] + torch.clamp(
values - old_val_t[mb], -CLIP_COEF, CLIP_COEF)
vl_unclipped = (values - ret_t[mb]).pow(2)
vl_clipped = (v_clipped - ret_t[mb]).pow(2)
value_loss = 0.5 * torch.max(vl_unclipped, vl_clipped).mean()

loss = policy_loss + VALUE_COEF * value_loss - ENTROPY_COEF * entropy

optimizer.zero_grad()
Expand All @@ -192,18 +216,23 @@ def policy_action(obs):

global_step = update * frames_per_update
if ep_returns:
recent = ep_returns[-20:]
life_mean = float(np.mean(ep_returns[-20:]))
game_mean = float(np.mean(game_returns[-20:])) if game_returns else 0.0
print(f"update: {update:>4} frames: {global_step:>8} "
f"recent_mean_return: {np.mean(recent):.1f} episodes: {len(ep_returns)}")
f"per_life: {life_mean:.1f} per_game: {game_mean:.1f} "
f"lives: {len(ep_returns)} games: {len(game_returns)}")
if args.wandb:
log = {
"global_step": global_step,
"policy_loss": pl_sum / n_mb,
"value_loss": vl_sum / n_mb,
"entropy": ent_sum / n_mb,
"lr": lr_now,
}
if ep_returns:
log["recent_mean_return"] = float(np.mean(ep_returns[-20:]))
if game_returns:
log["recent_mean_game_return"] = float(np.mean(game_returns[-20:]))
wandb.log(log, step=global_step)

torch.save(model.state_dict(), SAVE_PATH)
Expand Down
36 changes: 35 additions & 1 deletion 3-atari/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,39 @@ def reset(self, **kwargs):
obs, _ = self.env.reset(**kwargs)
return obs, {}


# Treats each life as its own episode for bootstrapping (so Q-targets / GAE don't
# value-chain across deaths) but only resets the real game when all lives are
# gone. Without this, every life loss triggers a full env.reset() — burning
# frames on noop_max + FIRE and breaking long-horizon credit assignment.
class LifeLossTerminalEnv(gym.Wrapper):
def __init__(self, env):
super().__init__(env)
self.lives = 0
self.game_over = True

def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
self.game_over = terminated or truncated
lives = info.get("lives", 0)
if 0 < lives < self.lives:
terminated = True
self.lives = lives
info["game_over"] = self.game_over
return obs, reward, terminated, truncated, info

def reset(self, **kwargs):
if self.game_over:
obs, info = self.env.reset(**kwargs)
else:
# Fake terminal from a life loss — advance one frame instead of
# resetting so the game keeps its remaining lives.
obs, _, terminated, truncated, info = self.env.step(0)
if terminated or truncated:
obs, info = self.env.reset(**kwargs)
self.lives = info.get("lives", 0)
return obs, info

ENV_IDS = {
"breakout": "ALE/Breakout-v5",
"pong": "ALE/Pong-v5",
Expand Down Expand Up @@ -61,12 +94,13 @@ def make_env(args):
noop_max=30,
frame_skip=4,
screen_size=84,
terminal_on_life_loss=True,
terminal_on_life_loss=False, # handled by LifeLossTerminalEnv below
grayscale_obs=True,
scale_obs=False, # keep uint8; we normalize in the model
)
if "FIRE" in env.unwrapped.get_action_meanings():
env = FireResetEnv(env)
env = LifeLossTerminalEnv(env)
env = gym.wrappers.FrameStackObservation(env, stack_size=4)
return env

Expand Down
20 changes: 20 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,26 @@ From the basics to deep reinforcement learning, this repo provides easy-to-read
8. A2C — [`2-a2c.py`](./2-cartpole/2-a2c.py)
9. PPO — [`3-ppo.py`](./2-cartpole/3-ppo.py)

**Atari** ([`3-atari/`](./3-atari))

10. DQN — [`1-dqn.py`](./3-atari/1-dqn.py)
11. PPO — [`2-ppo.py`](./3-atari/2-ppo.py)

## Benchmarks

Trained on a **MacBook Pro 14" (Apple M3, 8 GB unified memory)**, macOS 26.2, Python 3.11, PyTorch 2.11 with the MPS backend. CPU / GPU figures are read from Activity Monitor on the `python3.11` process after the run has stabilized (~5 min in); peak RAM is the process's real memory at its high-water mark. Final score is the mean per-game return over the last 20 episodes of training.

### Atari — Breakout (10M agent steps, `ALE/Breakout-v5` with sticky actions)

| Algorithm | Params | Train time | Final mean (per-game) | Peak RAM | CPU% | GPU% | W&B |
|-----------|--------|------------|-----------------------|----------|------|------|-----|
| DQN | 1.69M | ~9h | 93.5 ± 9.6 | 5.27 GB | ~60 | ~55 | [report](https://api.wandb.ai/links/rlcode/ljkn7ahp) |
| PPO | 1.69M | ~3.5h | _TBD_¹ | 1.98 GB | ~62 | ~55 | [report](https://api.wandb.ai/links/rlcode/jbdsbn6t) |

> Single seed per row, mean ± std over the final 20 logged episodes. `Params` counts only trainable network weights. `CPU%` is the single-process value reported by Activity Monitor (sum across cores, so >100% means multi-core use); `GPU%` is the same column for the Apple GPU. Sticky actions (`repeat_action_probability=0.25`) make absolute scores lower than the deterministic `*-v4` environments often cited in older papers.
>
> ¹ Most recent PPO run predates the `LifeLossTerminalEnv` fix and reports only per-life return (final 20: 27.2 ± 3.2). Per-game number will be filled in after the next training run.

## Setup

Requires Python 3.11 and [uv](https://docs.astral.sh/uv/).
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ dependencies = [
"pygame>=2.6,<3",
"opencv-python-headless>=4.13,<4.14",
"wandb>=0.27.0",
"moviepy>=2.2.1",
]
Loading