Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions dflash/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,11 @@ def dflash_generate(
start = num_input_tokens
draft_prefill = True

stop_token_tensor = (
torch.as_tensor(stop_token_ids, device=output_ids.device, dtype=output_ids.dtype)
if stop_token_ids is not None else None
)

while start < max_length:
block_output_ids = output_ids[:, start : start + block_size].clone()
block_position_ids = position_ids[:, start : start + block_size]
Expand Down Expand Up @@ -142,15 +147,16 @@ def dflash_generate(
if block_size > 1:
target_hidden = extract_context_feature(output.hidden_states, model.target_layer_ids)[:, :acceptance_length + 1, :]

if stop_token_ids is not None and any(
stop_token_id in output_ids[:, num_input_tokens:] for stop_token_id in stop_token_ids
):
if stop_token_tensor is not None and torch.isin(
output_ids[0, num_input_tokens : start + 1], stop_token_tensor
).any().item():
break

output_ids = output_ids[:, :min(start + 1, max_length)]
if stop_token_ids is not None:
stop_token_ids = torch.tensor(stop_token_ids, device=output_ids.device)
stop_token_indices = torch.isin(output_ids[0][num_input_tokens:], stop_token_ids).nonzero(as_tuple=True)[0]
if stop_token_tensor is not None:
stop_token_indices = torch.isin(
output_ids[0][num_input_tokens:], stop_token_tensor
).nonzero(as_tuple=True)[0]
if stop_token_indices.numel() > 0:
output_ids = output_ids[:, : num_input_tokens + stop_token_indices[0] + 1]

Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,8 @@ mlx = [
"mlx==0.31.2",
"mlx-lm==0.31.3",
]
test = [
"pytest>=8",
"torch",
"transformers==4.57.1",
]
Empty file added tests/__init__.py
Empty file.
135 changes: 135 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
"""Unit tests for dflash.model helpers and the dflash_generate stop-token check.

These tests intentionally cover only the pure-Python / pure-tensor logic so they
run on CPU without needing weights or transformers Qwen3 wiring.
"""

import pytest

torch = pytest.importorskip("torch")

from dflash.model import build_target_layer_ids, extract_context_feature, sample


def test_build_target_layer_ids_single_draft_layer():
assert build_target_layer_ids(num_target_layers=24, num_draft_layers=1) == [12]


def test_build_target_layer_ids_endpoints_for_two_draft_layers():
layers = build_target_layer_ids(num_target_layers=64, num_draft_layers=2)
assert layers == [1, 64 - 3]


def test_build_target_layer_ids_evenly_interpolates():
layers = build_target_layer_ids(num_target_layers=64, num_draft_layers=4)
assert layers[0] == 1
assert layers[-1] == 64 - 3
assert len(layers) == 4
assert layers == sorted(layers)


def test_extract_context_feature_concatenates_offset_layers():
# extract_context_feature reads hidden_states[layer_id + 1] for each layer_id
# (the +1 offset skips the embedding output).
bsz, seq, hidden = 2, 5, 8
hidden_states = [torch.full((bsz, seq, hidden), float(i)) for i in range(6)]
out = extract_context_feature(hidden_states, layer_ids=[0, 2, 4])
assert out.shape == (bsz, seq, hidden * 3)
# hidden_states[1] is all 1.0, hidden_states[3] is all 3.0, hidden_states[5] is all 5.0
assert torch.equal(out[..., :hidden], torch.full((bsz, seq, hidden), 1.0))
assert torch.equal(out[..., hidden : 2 * hidden], torch.full((bsz, seq, hidden), 3.0))
assert torch.equal(out[..., 2 * hidden :], torch.full((bsz, seq, hidden), 5.0))


def test_sample_temperature_zero_is_argmax():
logits = torch.tensor([[[2.0, 1.5, 1.0, 0.5], [0.1, 0.9, 0.2, 0.0]]])
out = sample(logits, temperature=0.0)
assert out.shape == (1, 2)
assert out[0, 0].item() == 0
assert out[0, 1].item() == 1


def test_sample_with_temperature_returns_in_range():
torch.manual_seed(0)
logits = torch.randn(1, 4, 32)
out = sample(logits, temperature=1.0)
assert out.shape == (1, 4)
assert (out >= 0).all() and (out < 32).all()


def _legacy_in_loop_check(output_ids, num_input_tokens, stop_token_ids):
"""Reproduces the pre-fix in-loop early-exit check from dflash_generate."""
return any(
stop_token_id in output_ids[:, num_input_tokens:]
for stop_token_id in stop_token_ids
)


def _new_in_loop_check(output_ids, num_input_tokens, cursor, stop_token_tensor):
"""The new in-loop early-exit check, scoped to the actually written slice."""
return torch.isin(
output_ids[0, num_input_tokens : cursor + 1], stop_token_tensor
).any().item()


def test_stop_token_check_does_not_scan_uninitialized_buffer():
"""Regression test for the pre-fix in-loop check.

dflash_generate pre-allocates output_ids with mask_token_id past the
cursor. If mask_token_id is also a stop token (a model-config-dependent
edge case the maintainer already cares about — see PR #76 "Preserve
output tokens that equal mask_token_id"), the legacy
`stop_id in output_ids[:, num_input_tokens:]` check fires on every
iteration because the still-pristine tail of the buffer is full of
mask tokens equal to a stop token. Generation aborts after the first
block even though no real stop was emitted.

The new check restricts the scan to positions that have actually been
written (`[num_input_tokens, cursor + 1]`) and does not regress.
"""
mask_token_id = 99
stop_token_ids = [99, 1] # collides with mask_token_id
num_input_tokens = 10
max_length = 50
block_size = 8

output_ids = torch.full(
(1, max_length + block_size), mask_token_id, dtype=torch.long
)
output_ids[:, :num_input_tokens] = 5

# Simulate one block of generation: 4 accepted draft tokens + 1 bonus,
# none of which are stop tokens. After dflash_generate's
# `start += acceptance_length + 1`, the cursor lands on the bonus token.
written = torch.tensor([5, 6, 7, 8, 10]) # no stop tokens, no mask tokens
output_ids[0, num_input_tokens : num_input_tokens + written.numel()] = written
cursor = num_input_tokens + written.numel() - 1

# Pre-fix check spuriously fires because it scans the still-mask-filled
# tail past the cursor.
assert _legacy_in_loop_check(output_ids, num_input_tokens, stop_token_ids) is True

stop_tensor = torch.as_tensor(stop_token_ids, dtype=output_ids.dtype)
assert _new_in_loop_check(output_ids, num_input_tokens, cursor, stop_tensor) is False


def test_stop_token_check_detects_real_stop_after_cursor_advance():
"""When a real stop token is written and the cursor has advanced, the
new check fires (matching the legacy semantic for the common case)."""
mask_token_id = 0
stop_token_ids = [7, 11]
num_input_tokens = 4
max_length = 32
block_size = 4

output_ids = torch.full(
(1, max_length + block_size), mask_token_id, dtype=torch.long
)
output_ids[:, :num_input_tokens] = 5

# Simulate one block worth of generation; cursor advances by 4.
output_ids[0, num_input_tokens : num_input_tokens + 4] = torch.tensor([3, 7, 6, 8])
cursor = num_input_tokens + 3

stop_tensor = torch.as_tensor(stop_token_ids, dtype=output_ids.dtype)
assert _new_in_loop_check(output_ids, num_input_tokens, cursor, stop_tensor) is True