9 changes: 5 additions & 4 deletions lmdeploy/pytorch/engine/engine_loop.py
@@ -151,10 +151,11 @@ def _log_resps(outputs: list[InferOutput]):
     def _send_resp(self, out: InferOutput):
         """Send response."""
         # skip cancelled response
-        if out.resp.is_done:
-            return
-        resp_type = (ResponseType.FINISH if out.finish else ResponseType.SUCCESS)
         logprobs = None if out.resp.data is None else out.resp.data.get('logprobs', None)
+        if out.resp.is_done:
+            resp_type = ResponseType.CANCEL
Collaborator comment on the `resp_type = ResponseType.CANCEL` line: if we don't return here, the response will activate the response event; I am not sure whether this behavior is safe.
+        else:
+            resp_type = (ResponseType.FINISH if out.finish else ResponseType.SUCCESS)
         response_reqs(self.req_manager,
                       out.resp,
                       resp_type,
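Read as straight Python, the new branch reduces to the sketch below. `ResponseType` is stubbed with a minimal Enum, so this is an illustration of the control flow, not the engine's actual definitions.

```python
from enum import Enum, auto


class ResponseType(Enum):  # minimal stand-in for lmdeploy's ResponseType
    SUCCESS = auto()
    FINISH = auto()
    CANCEL = auto()


def pick_resp_type(is_done: bool, finish: bool) -> ResponseType:
    """Map request state to a response type; cancelled requests now get an
    explicit CANCEL instead of being dropped by an early return."""
    if is_done:
        return ResponseType.CANCEL
    return ResponseType.FINISH if finish else ResponseType.SUCCESS
```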
@@ -242,7 +243,7 @@ def __get_logprobs(batched_outputs: 'BatchedOutputs'):
             stop_pos = batched_outputs.stop_pos[idx]
             # only apply when stopped
             if stop_pos > -1:
-                mask = mask & (stop_pos >= range_tensor)
+                mask = torch.logical_and(mask, stop_pos >= range_tensor)
             indices = logprobs.indices[start:end][mask].tolist()
             vals = logprobs.vals[start:end][mask].tolist()
             results[idx] = list(zip(vals, indices))
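A small, self-contained illustration of the masking (values are made up): `torch.logical_and` accepts a non-boolean mask and always returns a boolean tensor, which is safer for the masking that follows than the bitwise `&` it replaces.

```python
import torch

range_tensor = torch.arange(6)            # positions 0..5 inside the slice
mask = torch.tensor([1, 1, 0, 1, 1, 1])   # e.g. an integer validity mask
stop_pos = torch.tensor(3)                # sequence stopped at position 3
mask = torch.logical_and(mask, stop_pos >= range_tensor)
print(mask)  # tensor([ True,  True, False,  True, False, False])
```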
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/messages.py
@@ -731,7 +731,7 @@ def routed_experts(self) -> np.ndarray:
         if (not self.return_routed_experts) or self.all_routed_experts is None:
             return None

-        end = max(0, self.num_all_ids - 1)
+        end = max(0, self.num_valid_ids - 1)
         if 0 < end <= len(self.all_routed_experts):
             return self.all_routed_experts.get_real()[:end]
         else:
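Presumably the point of the fix: with speculative decoding the sequence can hold drafted ids beyond the last verified token, so `num_valid_ids`, not `num_all_ids`, bounds how many tokens actually have recorded experts. A toy check of the bound, with made-up shapes:

```python
import numpy as np

num_valid_ids = 5                                # verified tokens so far
all_routed_experts = np.arange(8).reshape(4, 2)  # one expert row per forwarded token
end = max(0, num_valid_ids - 1)                  # experts exist up to the last input token
if 0 < end <= len(all_routed_experts):
    print(all_routed_experts[:end])              # all 4 recorded rows
```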
3 changes: 2 additions & 1 deletion lmdeploy/pytorch/spec_decode/proposers/deepseek_mtp.py
@@ -23,8 +23,9 @@ def get_outputs(self,
         model_metas = model_outputs['model_metas']
         if extra_inputs is not None:
             last_token_loc = extra_inputs.last_token_indices
-            target_hidden_states = model_inputs.target_hidden_states[:, last_token_loc]
             hidden_states = hidden_states[:, last_token_loc]
+            # use hidden states for draft prefill forward for next step
+            target_hidden_states = hidden_states
         else:
             target_hidden_states = hidden_states
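In tensor terms the gather looks like the following toy example (shapes are made up); the same gathered states are then reused as the target hidden states for the next draft step instead of re-indexing `model_inputs.target_hidden_states`.

```python
import torch

hidden_states = torch.randn(1, 10, 16)            # [1, num_tokens, hidden_size]
last_token_loc = torch.tensor([3, 7, 9])          # last-token index of each sequence
hidden_states = hidden_states[:, last_token_loc]  # [1, 3, 16]
target_hidden_states = hidden_states              # feed the same states to the next step
print(target_hidden_states.shape)                 # torch.Size([1, 3, 16])
```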

14 changes: 10 additions & 4 deletions lmdeploy/pytorch/spec_decode/spec_agent.py
@@ -353,10 +353,16 @@ def __compute_logprobs(raw_logprobs: torch.Tensor, token_ids: torch.LongTensor,
             # update last token indices
             last_token_indices = last_token_indices - num_rejected_tokens
         else:
-            bonus_logits, raw_logprobs = await logits_processor(target_logits)
-            # Sample next token from bonus position
-            next_token_ids = logits_processor.sampling(bonus_logits)  # [batch_size]
-            output_token_ids = next_token_ids.unsqueeze(-1)
+            if model_inputs.is_chunk and not model_inputs.is_last_chunk:
+                # dummy output; no need to sample or compute logprobs for a non-last chunk
+                next_token_ids = num_rejected_tokens
+                output_token_ids = num_rejected_tokens.unsqueeze(-1)
+                raw_logprobs = None
+            else:
+                bonus_logits, raw_logprobs = await logits_processor(target_logits)
+                # Sample next token from bonus position
+                next_token_ids = logits_processor.sampling(bonus_logits)  # [batch_size]
+                output_token_ids = next_token_ids.unsqueeze(-1)

         logprobs = __compute_logprobs(raw_logprobs, output_token_ids, sampling_inputs.max_num_logprobs)
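A minimal sketch of the guard, with `argmax` standing in for the real `logits_processor.sampling` and plain booleans for the chunk flags: non-last chunks of a chunked prefill return placeholder ids and skip sampling entirely.

```python
import torch

def next_tokens(is_chunk: bool, is_last_chunk: bool,
                num_rejected_tokens: torch.Tensor,
                target_logits: torch.Tensor) -> torch.Tensor:
    """Only the final chunk of a chunked prefill produces a real bonus token."""
    if is_chunk and not is_last_chunk:
        return num_rejected_tokens           # dummy output, never decoded
    return target_logits.argmax(dim=-1)      # stand-in for the real sampler
```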

8 changes: 6 additions & 2 deletions lmdeploy/pytorch/strategies/ar_spec/model_agent.py
@@ -66,7 +66,11 @@ def update(self, delta: 'ModelInputsDelta'):
         indices = delta.indices
         output_draft_token_ids = self.output_draft_token_ids[indices]
         num_rejected_tokens = self.num_rejected_tokens[indices]
-        return ARSpecExtraInputs(output_draft_token_ids=output_draft_token_ids, num_rejected_tokens=num_rejected_tokens)
+        output_token_ids = self.output_token_ids[indices] if self.output_token_ids is not None else None
+        return ARSpecExtraInputs(output_draft_token_ids=output_draft_token_ids,
+                                 num_rejected_tokens=num_rejected_tokens,
+                                 output_token_ids=output_token_ids,
+                                 )


@dataclass
@@ -119,7 +123,7 @@ def step(self,
             stop_words_rsp = stop_words.reshape(1, 1, -1)
             assert stop_words_rsp.ndim == token_ids_rsp.ndim == 3
             stop_mask = (token_ids_rsp == stop_words_rsp).any(-1)
-            mask = mask ^ stop_mask
+            mask = torch.logical_or(mask, stop_mask)
             # find the index of first `1`, if not found, would be 0
             index = torch.argmax(mask.int(), dim=-1, keepdim=True)
             # update index of 0 to -1 if not found
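Why `^` was a bug, on made-up masks: when a position is already masked and is also a stop word, XOR flips it back to False, while the intended semantics is "masked if either condition holds".

```python
import torch

mask = torch.tensor([True, False, True])
stop_mask = torch.tensor([True, True, False])
print(mask ^ stop_mask)                   # tensor([False,  True,  True])  <- wrong
print(torch.logical_or(mask, stop_mask))  # tensor([True, True, True])     <- intended
```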
126 changes: 68 additions & 58 deletions lmdeploy/pytorch/strategies/ar_spec/sequence.py
@@ -30,8 +30,6 @@ class SchedulerSequenceARSpec(SchedulerSequenceDefault):
     def __post_init__(self):
         """Post init."""
         super().__post_init__()
-        self._num_spec_ids: int = 0
-        self._num_new_valid: int = 0
         self._num_valid_ids: int = len(self.history_cache)
         self._strategy: ARSpecSequenceStrategy = self._seq_meta.strategy

@@ -40,79 +38,73 @@ def num_valid_ids(self):
         return self._num_valid_ids

     @property
-    def num_spec_ids(self):
-        return self._num_spec_ids
+    def routed_experts(self) -> np.ndarray:
+        if (not self.return_routed_experts) or self.all_routed_experts is None:
+            return None
+
+        end = max(0, self.num_valid_ids - 1)
+        if 0 < end <= len(self.all_routed_experts):
+            return self.all_routed_experts.get_real()[:end]
+        else:
+            return None

     @property
     def generated_ids(self) -> np.ndarray:
         end = self.num_valid_ids
         start = end - self.num_new_tokens
         return self.history_cache[start:end]

     def set_stop_pos(self, pos: int):
-        val = self._num_new_valid - pos - 1
-        self._num_valid_ids -= val
-        self.num_new_tokens -= val
-        self._num_token_ids = 1
-        self._num_history_ids -= val
-
-        self._num_spec_ids = 0
-        self._num_new_valid = 0
+        self.history_cache.resize(self.num_valid_ids)

     def _update_token_ids_inputs(self, token_ids: np.ndarray):
         """Append tokens."""
         num_tokens = len(token_ids)
         self.output_start_pos = self.num_valid_ids + num_tokens
-        self._num_valid_ids = self.num_history_ids + num_tokens
+        self._num_valid_ids = self._num_valid_ids + num_tokens
         self._num_token_ids = num_tokens
         self.num_new_tokens = 0
-        self._num_spec_ids = 0
-        self._num_new_valid = 0
         self.history_cache.append(token_ids)

-    def _update_token_ids_prefill(self, token_ids: np.ndarray, draft_token_ids: np.ndarray):
+    def _update_token_ids_prefill(self, token_ids: np.ndarray, draft_token_ids: np.ndarray,
+                                  stop_pos: int = -1, routed_experts: np.ndarray = None):
         """Update token ids for prefill."""
         num_valid = len(token_ids)
-        self._num_spec_ids = len(draft_token_ids)
-        token_ids = np.concatenate([token_ids, draft_token_ids])
-        num_tokens = len(token_ids)
+        self.history_cache.append(token_ids)
+        self.append_routed_experts(routed_experts)
         self._num_history_ids += self._num_token_ids
-        self._num_token_ids = num_tokens
         self.num_new_tokens += num_valid
-        self._num_new_valid = num_valid
         self._num_valid_ids = self.num_history_ids + num_valid
-        self.history_cache.append(token_ids)
+        self._num_token_ids = num_valid
+        if stop_pos == -1:
+            # not stopping, add drafted tokens
+            self._num_token_ids += len(draft_token_ids)
+            self.history_cache.append(draft_token_ids)

-    def _update_token_ids_decode(self, token_ids: np.ndarray, draft_token_ids: np.ndarray = None):
+    def _update_token_ids_decode(self, token_ids: np.ndarray, draft_token_ids: np.ndarray,
+                                 stop_pos: int = -1, routed_experts: np.ndarray = None):
         """Update token ids for decode."""
+        # back to last valid position
+        self.history_cache.resize(self.num_valid_ids)

         valid_ids = token_ids[token_ids > -1]
-        num_valid = len(valid_ids)
-        self.num_new_tokens = self.num_new_tokens + num_valid
-        self._num_new_valid = num_valid
+        if stop_pos > -1:
+            valid_ids = valid_ids[:stop_pos + 1]
+
+        num_valid = len(valid_ids)
+        self.num_new_tokens += num_valid
         self._num_valid_ids += num_valid
+        self._num_history_ids = self.num_valid_ids - 1
+        # append the last accepted tokens
+        self.history_cache.append(valid_ids)
+        # append valid experts
+        if routed_experts is not None:
+            routed_experts = routed_experts[:num_valid]
+            self.append_routed_experts(routed_experts)

-        # last step has spec ids
-        if self.num_spec_ids > 0:
-            token_ids = valid_ids[-1:]
-        else:
-            token_ids = valid_ids
-
-        num_tokens = len(token_ids)
-
-        if draft_token_ids is not None:
-            num_tokens = 1 + len(draft_token_ids)
-            token_ids = np.concatenate([token_ids, draft_token_ids])
-            self._num_spec_ids = len(draft_token_ids)
-        else:
-            self._num_spec_ids = 0
-
-        self._num_token_ids = num_tokens
-        if self.num_history_ids < len(self.history_cache):
-            self.history_cache.resize(self.num_history_ids)
-        self.history_cache.append(token_ids)
+        if stop_pos > -1:
+            self._num_token_ids = 1
+        else:
+            # add new draft tokens if not stopped
+            self.history_cache.append(draft_token_ids)
+            self._num_token_ids = 1 + len(draft_token_ids)

     def update_token_ids(self,
                          token_ids: Tensor,
@@ -122,6 +114,7 @@ def update_token_ids(self,
                          draft_token_ids: Tensor = None,
                          mode: UpdateTokenMode = UpdateTokenMode.INPUTS,
                          routed_experts: np.ndarray = None,
+                         stop_pos: int = -1,
                          **kwargs):
         """Update token ids, old token ids will be added to history."""
         # update history image nums
@@ -134,25 +127,42 @@

         token_ids: np.ndarray = _to_ndarray(token_ids)

-        # record cached expert ids
-        if routed_experts is not None:
-            num_reject_tokens = (token_ids == -1).sum().item()
-            routed_experts = routed_experts[:routed_experts.shape[0] - num_reject_tokens]
-            self.append_routed_experts(routed_experts)
-
         if draft_token_ids is not None:
             draft_token_ids = _to_ndarray(draft_token_ids)
         if mode == UpdateTokenMode.INPUTS:
             self._update_token_ids_inputs(token_ids)
         elif mode == UpdateTokenMode.PREFILL:
-            self._update_token_ids_prefill(token_ids, draft_token_ids)
+            self._update_token_ids_prefill(token_ids, draft_token_ids,
+                                           stop_pos=stop_pos, routed_experts=routed_experts)
         else:
-            self._update_token_ids_decode(token_ids, draft_token_ids)
+            self._update_token_ids_decode(token_ids, draft_token_ids,
+                                          stop_pos=stop_pos, routed_experts=routed_experts)
         if model_meta is not None:
             self.model_meta = model_meta

         self._update_mrope_pos_ids()

+    def set_step(self, step: int):
+        """Set step."""
+        num_valid_ids = self.num_valid_ids
+        # update step for vlm
+        if len(self.history_embeddings) > 0:
+            new_step, self._num_history_images, self._num_images = \
+                self.history_embeddings.get_step(step)
+            assert 0 <= new_step <= step
+            step = new_step
+        self._num_history_ids = step
+        self._num_token_ids = num_valid_ids - step
+        self.num_ignored_history = min(step, self.num_ignored_history)
+
+        self.history_cache.resize(num_valid_ids)
+        self.model_meta = None
+
+        if self.return_routed_experts:
+            # chunk long context might not have all routed experts
+            if len(self.all_routed_experts) > step:
+                self.all_routed_experts.resize(step)


class ARSpecSequenceStrategy(ARSequenceStrategy):

@@ -219,7 +229,7 @@ def update_running(self, running: SeqList, batched_outputs: BatchedOutputs, mode
                     draft_token_ids=cur_draft_tokens,
                     model_meta=model_meta,
                     mode=update_mode,
-                    routed_experts=routed_experts)
+                    routed_experts=routed_experts,
+                    stop_pos=stop_pos[idx])
                 if stop:
                     msg.set_stop_pos(stop_pos[idx])
                     msg.state.finish()
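A toy walk-through of the new decode-side update (values invented): rejected draft positions are marked -1, and an accepted stop word at `stop_pos` truncates the accepted tokens before any new drafts would be appended.

```python
import numpy as np

token_ids = np.array([11, 12, -1, 13, -1])  # verifier output, -1 = rejected
stop_pos = 1                                # stop word accepted at position 1
valid_ids = token_ids[token_ids > -1]       # [11, 12, 13]
if stop_pos > -1:
    valid_ids = valid_ids[:stop_pos + 1]    # [11, 12]; no drafts appended after a stop
print(valid_ids)
```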
3 changes: 0 additions & 3 deletions lmdeploy/pytorch/strategies/ar_spec/step_inputs.py
@@ -177,9 +177,6 @@ def step_decode(
    model_inputs.is_decoding = True
    model_inputs.model_metas = model_metas

-    # update extra inputs
-    extra_inputs.output_token_ids = extra_outputs.draft_token_ids
-
    # update inputs with rejected token adjustment
    step_seqlens = model_inputs.seq_length - extra_inputs.num_rejected_tokens
    batch_size = step_seqlens.size(0)
12 changes: 8 additions & 4 deletions lmdeploy/serve/openai/api_server.py
@@ -1080,12 +1080,16 @@ async def _inner_call():
                return create_error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected')
            text += res.response or ''
            output_ids.extend(res.token_ids or [])
-            if res.logprobs:
-                for tok, tok_logprobs in zip(res.token_ids, res.logprobs):
-                    logprobs.append((tok_logprobs[tok], tok))
+            logprobs.extend(res.logprobs or [])
+
+        output_token_logprobs = []
+        if len(logprobs) and len(output_ids):
+            for tok, tok_logprobs in zip(output_ids, logprobs):
+                output_token_logprobs.append((tok_logprobs[tok], tok))

        nonlocal response
        meta = GenerateReqMetaOutput(finish_reason=dict(type=res.finish_reason) if res.finish_reason else None,
-                                     output_token_logprobs=logprobs or None,
+                                     output_token_logprobs=output_token_logprobs or None,
                                     prompt_tokens=res.input_token_len,
                                     routed_experts=res.routed_experts,
                                     completion_tokens=res.generate_token_len)
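With made-up values, the accumulation now works like this: each element of `logprobs` maps candidate token ids to their log-probabilities, and pairing it with `output_ids` once the stream has ended yields the (logprob, token) tuples for the tokens actually emitted, rather than pairing within each streamed chunk.

```python
output_ids = [42, 7]                                   # tokens actually emitted
logprobs = [{42: -0.1, 5: -2.3}, {7: -0.5, 42: -1.9}]  # per-token candidate logprobs
output_token_logprobs = [(tok_logprobs[tok], tok)
                         for tok, tok_logprobs in zip(output_ids, logprobs)]
print(output_token_logprobs)  # [(-0.1, 42), (-0.5, 7)]
```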