Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions lmdeploy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,10 @@ def extend(self, other: 'Response') -> 'Response':
self.logprobs = self.logprobs or []
self.logprobs += other.logprobs
self.routed_experts = other.routed_experts
if other.logits is not None:
self.logits = other.logits
if other.last_hidden_state is not None:
self.last_hidden_state = other.last_hidden_state
return self


Expand Down
1 change: 1 addition & 0 deletions lmdeploy/pytorch/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class InferOutput:
meta: Any = None
finish: bool = False
logits: torch.Tensor = None
last_hidden_state: torch.Tensor = None
logprobs: torch.Tensor = None

# send cache blocks back for migration in Disaggregated LLM Serving
Expand Down
3 changes: 3 additions & 0 deletions lmdeploy/pytorch/engine/engine_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,14 +230,17 @@ async def async_stream_infer(self,
# request might be cancelled before any output
token_ids = []
logits = None
last_hidden_state = None
else:
token_ids = resp_data['token_ids'][output_offset:].tolist()
logits = resp_data.get('logits', None)
last_hidden_state = resp_data.get('last_hidden_state', None)
num_ids = len(token_ids) - output_offset
logger.debug(f'session[{session_id}] finish: num_out_ids={num_ids}.')
yield EngineOutput(resp.type,
token_ids,
logits=logits,
last_hidden_state=last_hidden_state,
cache_block_ids=cache_block_ids,
req_metrics=req_metrics,
routed_experts=routed_experts,
Expand Down
56 changes: 52 additions & 4 deletions lmdeploy/pytorch/engine/engine_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def _send_resp(self, out: InferOutput):
resp_type,
data=dict(token_ids=out.token_ids,
logits=out.logits,
last_hidden_state=out.last_hidden_state,
cache_block_ids=out.cache_block_ids,
req_metrics=out.req_metrics,
routed_experts=out.routed_experts,
Expand Down Expand Up @@ -225,6 +226,16 @@ def __get_logit(msg, logits: torch.Tensor, seq_length: list[int], idx: int):

return logit

def __get_hidden_state(msg, hidden_states: torch.Tensor, seq_length: list[int], idx: int):
    """Extract the hidden states belonging to sequence ``idx``.

    ``hidden_states`` is the batch-concatenated tensor; ``seq_length`` gives the
    per-sequence token counts used to split it. When earlier chunks were
    accumulated on ``msg`` (chunked long context), the current slice is appended
    to them and the merged tensor is returned, after which the accumulator is
    cleared.
    """
    per_seq = hidden_states.split(seq_length)
    hs = per_seq[idx]
    if not len(msg.all_hidden_states):
        # no previously accumulated chunks: return this step's slice directly
        return hs
    # for chunked long context: merge this chunk with the accumulated ones
    msg.append_hidden_states(hs)
    merged = msg.hidden_states
    msg.all_hidden_states.resize(0)
    return merged

def __get_logprobs(batched_outputs: 'BatchedOutputs'):
"""Get valid logprobs."""
batch_size = batched_outputs.stop_pos.size(0)
Expand All @@ -249,21 +260,31 @@ def __get_logprobs(batched_outputs: 'BatchedOutputs'):
return results

logits = batched_outputs.logits
hidden_states = batched_outputs.hidden_states
all_routed_experts = batched_outputs.all_routed_experts

if model_inputs is not None and (model_inputs.is_chunk and not model_inputs.is_last_chunk):
# chunk long context does not need to update seqs and outputs
seq = running[0]
seq.append_routed_experts(all_routed_experts)
seq.append_logits(logits)
# For 'all' mode, accumulate chunk logits/hidden_states; for 'generation' mode skip
if seq.return_logits and not seq.logits_generation_mode:
seq.append_logits(logits)
if seq.return_hidden_states and not seq.hidden_states_generation_mode:
seq.append_hidden_states(hidden_states)
return dict()

new_token_timestamp = batched_outputs.new_token_timestamp
logprobs = batched_outputs.logprobs

all_logprobs = __get_logprobs(batched_outputs)

seq_length = [seq.num_token_ids for seq in running]
if model_inputs is not None:
seq_length = model_inputs.seq_length.tolist()
elif delta is not None:
seq_length = delta.seq_length.tolist()
else:
seq_length = [seq.num_token_ids for seq in running]
is_run = [seq.status == MessageStatus.RUNNING for seq in running]
self.seq_strategy.update_running(running=running,
batched_outputs=batched_outputs,
Expand Down Expand Up @@ -314,10 +335,37 @@ def __get_logprobs(batched_outputs: 'BatchedOutputs'):
outputs[session_id] = out

if msg.return_logits:
logit = __get_logit(msg, logits, seq_length, idx)
outputs[session_id].logits = logit
if msg.logits_generation_mode:
# Accumulate last-position logit for each generation step
if logits is not None:
last_logit = logits.split(seq_length)[idx][-1:].detach().cpu()
msg.append_logits(last_logit)
if finish:
outputs[session_id].logits = msg.logits
msg.all_logits.resize(0)
else:
# 'all' mode: return full sequence logits (existing behavior)
logit = __get_logit(msg, logits, seq_length, idx)
outputs[session_id].logits = logit

if msg.return_hidden_states:
if msg.hidden_states_generation_mode:
# 'generation' mode: accumulate last-position hidden state at each step
if hidden_states is not None:
last_hs = hidden_states[idx:idx + 1].detach().cpu()
msg.append_hidden_states(last_hs)
if finish:
outputs[session_id].last_hidden_state = msg.hidden_states
msg.all_hidden_states.resize(0)
else:
# 'all' mode: return full sequence hidden states
if hidden_states is not None:
hs = __get_hidden_state(msg, hidden_states, seq_length, idx)
outputs[session_id].last_hidden_state = hs

return outputs


async def _main_loop_try_send_next_inputs(self):
"""Try send next inputs."""
scheduler = self.scheduler
Expand Down
12 changes: 12 additions & 0 deletions lmdeploy/pytorch/engine/inputs_maker.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,14 @@ def __need_logits(seqs: 'SeqList'):
return True
return any(seq.return_logits for seq in seqs)

def __need_hidden_states(seqs: 'SeqList'):
    """Return True when at least one sequence requested hidden states."""
    for seq in seqs:
        if seq.return_hidden_states:
            return True
    return False

def __hidden_states_all_mode(seqs: 'SeqList'):
    """Return True when some sequence wants hidden states in 'all' mode.

    'all' mode means the sequence requested hidden states but did not opt
    into generation-only mode, so full-sequence hidden states are needed.
    """
    for seq in seqs:
        if seq.return_hidden_states and not seq.hidden_states_generation_mode:
            return True
    return False

def __need_routed_experts(seqs: 'SeqList'):
"""Need routed experts."""
return any(seq.return_routed_experts for seq in seqs)
Expand Down Expand Up @@ -711,6 +719,8 @@ def __create_inputs_prefill():
stopping_criteria = None

return_logits = __need_logits(running)
return_hidden_states = __need_hidden_states(running)
hidden_states_all_mode = __hidden_states_all_mode(running)
return_routed_experts = __need_routed_experts(running)

return dict(
Expand All @@ -722,6 +732,8 @@ def __create_inputs_prefill():
sampling_inputs=sampling_inputs,
stopping_criteria=stopping_criteria,
return_logits=return_logits,
return_hidden_states=return_hidden_states,
hidden_states_all_mode=hidden_states_all_mode,
extra_inputs=extra_inputs,
return_routed_experts=return_routed_experts,
)
Expand Down
42 changes: 42 additions & 0 deletions lmdeploy/pytorch/engine/model_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class BatchedOutputs:
stopped: torch.Tensor
stop_pos: torch.Tensor | None = None
logits: torch.Tensor | None = None
hidden_states: torch.Tensor | None = None
model_metas: list[dict[str, Any]] = None
logprobs: BatchedLogProbs | None = None
new_token_timestamp: int = 0
Expand Down Expand Up @@ -439,16 +440,47 @@ async def _async_model_forward(
self,
inputs: ModelInputs,
return_logits: bool,
return_hidden_states: bool = False,
hidden_states_all_mode: bool = False,
):
"""Model forward."""
origin_inputs = inputs
ret = await self.async_forward(inputs)

# For 'all' mode hidden states without return_logits, save the full hidden
# states before _postprocess_forward_output slices them to last position.
pre_postprocess_full_hs = None
if return_hidden_states and hidden_states_all_mode and not return_logits:
pre_postprocess_full_hs = ret['hidden_states'][0] # [total_tokens, hidden_dim]

if not return_logits:
ret = self._postprocess_forward_output(ret, origin_inputs)

hidden_states, ret = self.spec_agent.update_main_model_outputs(ret, origin_inputs)

if return_hidden_states:
# Extract hidden states to return to the user
hs = hidden_states
seq_length = ret.get('seq_length', inputs.seq_length)
if hidden_states_all_mode:
if return_logits:
# Full hidden states still available (postprocessing was skipped)
full_hs = hs[0] # [total_tokens, hidden_dim]
else:
# Use the saved full hidden states from before postprocessing
full_hs = pre_postprocess_full_hs
ret['last_hidden_states'] = full_hs
ret['hidden_states_seq_length'] = seq_length
else:
# 'generation' mode: last-position hidden state per sequence
if return_logits:
# hidden_states is full sequence, need to slice to last position
last_hs = self._slice_outs(hs[0], seq_length)
else:
# _postprocess_forward_output already sliced to last position
last_hs = hs[0]
ret['last_hidden_states'] = last_hs

logits = self.get_logits(hidden_states)
ret['logits'] = logits
return ret
Expand Down Expand Up @@ -601,6 +633,8 @@ async def _step_postprocess_with_output(self,
model_metas: Any,
need_broadcast_next: bool,
return_logits: bool = False,
return_hidden_states: bool = False,
last_hidden_states: torch.Tensor = None,
all_routed_experts: Any = None,
extra_inputs: ExtraInputs = None):
"""Step postprocess with output."""
Expand Down Expand Up @@ -639,6 +673,7 @@ async def _step_postprocess_with_output(self,
self._push_output(
BatchedOutputs(next_token_ids=output_token_ids,
logits=logits if return_logits else None,
hidden_states=last_hidden_states if return_hidden_states else None,
stopped=stopped,
stop_pos=stop_pos,
model_metas=model_metas,
Expand Down Expand Up @@ -677,6 +712,8 @@ async def _async_step(
sampling_inputs: SamplingInputs = None,
stopping_criteria: StoppingCriteria = None,
return_logits: bool = False,
return_hidden_states: bool = False,
hidden_states_all_mode: bool = False,
return_routed_experts: bool = False,
extra_inputs: ExtraInputs = None,
):
Expand Down Expand Up @@ -738,6 +775,8 @@ async def _async_step(
output = await self._async_model_forward(
inputs,
return_logits=return_logits,
return_hidden_states=return_hidden_states,
hidden_states_all_mode=hidden_states_all_mode,
)
# recovery is_decoding
inputs.is_decoding = is_decoding
Expand All @@ -751,6 +790,7 @@ async def _async_step(
last_logits = self._slice_outs(logits, seq_length) # [bs, 1, prob] -> [bs, prob]
extra_inputs = self.agent_strategy.slice_extra_inputs(extra_inputs, inputs, output)
model_metas = output.get('model_metas')
last_hidden_states = output.get('last_hidden_states', None)

if self.need_output:
logger.debug(f'<ForwardTask> rank[{rank}]: Sampling.')
Expand All @@ -776,6 +816,8 @@ async def _async_step(
model_metas,
need_broadcast_next,
return_logits=return_logits,
return_hidden_states=return_hidden_states,
last_hidden_states=last_hidden_states,
all_routed_experts=all_routed_experts,
extra_inputs=extra_inputs,
))
Expand Down
Loading
Loading