diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index eeb2f7a056..04b393517b 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -577,6 +577,10 @@ def extend(self, other: 'Response') -> 'Response': self.logprobs = self.logprobs or [] self.logprobs += other.logprobs self.routed_experts = other.routed_experts + if other.logits is not None: + self.logits = other.logits + if other.last_hidden_state is not None: + self.last_hidden_state = other.last_hidden_state return self diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index e2a7495624..87a801bb2c 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -44,6 +44,7 @@ class InferOutput: meta: Any = None finish: bool = False logits: torch.Tensor = None + last_hidden_state: torch.Tensor = None logprobs: torch.Tensor = None # send cache blocks back for migration in Disaggregated LLM Serving diff --git a/lmdeploy/pytorch/engine/engine_instance.py b/lmdeploy/pytorch/engine/engine_instance.py index 217e1d4609..0062bd711a 100644 --- a/lmdeploy/pytorch/engine/engine_instance.py +++ b/lmdeploy/pytorch/engine/engine_instance.py @@ -230,14 +230,17 @@ async def async_stream_infer(self, # request might be cancelled before any output token_ids = [] logits = None + last_hidden_state = None else: token_ids = resp_data['token_ids'][output_offset:].tolist() logits = resp_data.get('logits', None) + last_hidden_state = resp_data.get('last_hidden_state', None) num_ids = len(token_ids) - output_offset logger.debug(f'session[{session_id}] finish: num_out_ids={num_ids}.') yield EngineOutput(resp.type, token_ids, logits=logits, + last_hidden_state=last_hidden_state, cache_block_ids=cache_block_ids, req_metrics=req_metrics, routed_experts=routed_experts, diff --git a/lmdeploy/pytorch/engine/engine_loop.py b/lmdeploy/pytorch/engine/engine_loop.py index 2584b18e00..107d138f4f 100644 --- a/lmdeploy/pytorch/engine/engine_loop.py +++ 
b/lmdeploy/pytorch/engine/engine_loop.py @@ -160,6 +160,7 @@ def _send_resp(self, out: InferOutput): resp_type, data=dict(token_ids=out.token_ids, logits=out.logits, + last_hidden_state=out.last_hidden_state, cache_block_ids=out.cache_block_ids, req_metrics=out.req_metrics, routed_experts=out.routed_experts, @@ -225,6 +226,16 @@ def __get_logit(msg, logits: torch.Tensor, seq_length: list[int], idx: int): return logit + def __get_hidden_state(msg, hidden_states: torch.Tensor, seq_length: list[int], idx: int): + hs = hidden_states.split(seq_length)[idx] + if len(msg.all_hidden_states) > 0: + # for chunked long context + msg.append_hidden_states(hs) + hs = msg.hidden_states + msg.all_hidden_states.resize(0) + + return hs + def __get_logprobs(batched_outputs: 'BatchedOutputs'): """Get valid logprobs.""" batch_size = batched_outputs.stop_pos.size(0) @@ -249,13 +260,18 @@ def __get_logprobs(batched_outputs: 'BatchedOutputs'): return results logits = batched_outputs.logits + hidden_states = batched_outputs.hidden_states all_routed_experts = batched_outputs.all_routed_experts if model_inputs is not None and (model_inputs.is_chunk and not model_inputs.is_last_chunk): # chunk long context does not need to update seqs and outputs seq = running[0] seq.append_routed_experts(all_routed_experts) - seq.append_logits(logits) + # For 'all' mode, accumulate chunk logits/hidden_states; for 'generation' mode skip + if seq.return_logits and not seq.logits_generation_mode: + seq.append_logits(logits) + if seq.return_hidden_states and not seq.hidden_states_generation_mode: + seq.append_hidden_states(hidden_states) return dict() new_token_timestamp = batched_outputs.new_token_timestamp @@ -263,7 +279,12 @@ def __get_logprobs(batched_outputs: 'BatchedOutputs'): all_logprobs = __get_logprobs(batched_outputs) - seq_length = [seq.num_token_ids for seq in running] + if model_inputs is not None: + seq_length = model_inputs.seq_length.tolist() + elif delta is not None: + seq_length = 
delta.seq_length.tolist() + else: + seq_length = [seq.num_token_ids for seq in running] is_run = [seq.status == MessageStatus.RUNNING for seq in running] self.seq_strategy.update_running(running=running, batched_outputs=batched_outputs, @@ -314,10 +335,37 @@ def __get_logprobs(batched_outputs: 'BatchedOutputs'): outputs[session_id] = out if msg.return_logits: - logit = __get_logit(msg, logits, seq_length, idx) - outputs[session_id].logits = logit + if msg.logits_generation_mode: + # Accumulate last-position logit for each generation step + if logits is not None: + last_logit = logits.split(seq_length)[idx][-1:].detach().cpu() + msg.append_logits(last_logit) + if finish: + outputs[session_id].logits = msg.logits + msg.all_logits.resize(0) + else: + # 'all' mode: return full sequence logits (existing behavior) + logit = __get_logit(msg, logits, seq_length, idx) + outputs[session_id].logits = logit + + if msg.return_hidden_states: + if msg.hidden_states_generation_mode: + # 'generation' mode: accumulate last-position hidden state at each step + if hidden_states is not None: + last_hs = hidden_states[idx:idx + 1].detach().cpu() + msg.append_hidden_states(last_hs) + if finish: + outputs[session_id].last_hidden_state = msg.hidden_states + msg.all_hidden_states.resize(0) + else: + # 'all' mode: return full sequence hidden states + if hidden_states is not None: + hs = __get_hidden_state(msg, hidden_states, seq_length, idx) + outputs[session_id].last_hidden_state = hs + return outputs + async def _main_loop_try_send_next_inputs(self): """Try send next inputs.""" scheduler = self.scheduler diff --git a/lmdeploy/pytorch/engine/inputs_maker.py b/lmdeploy/pytorch/engine/inputs_maker.py index 72759d3cd6..0815846a27 100644 --- a/lmdeploy/pytorch/engine/inputs_maker.py +++ b/lmdeploy/pytorch/engine/inputs_maker.py @@ -613,6 +613,14 @@ def __need_logits(seqs: 'SeqList'): return True return any(seq.return_logits for seq in seqs) + def __need_hidden_states(seqs: 'SeqList'): + """Need 
hidden states.""" + return any(seq.return_hidden_states for seq in seqs) + + def __hidden_states_all_mode(seqs: 'SeqList'): + """Check if any sequence uses hidden states 'all' mode.""" + return any(seq.return_hidden_states and not seq.hidden_states_generation_mode for seq in seqs) + def __need_routed_experts(seqs: 'SeqList'): """Need routed experts.""" return any(seq.return_routed_experts for seq in seqs) @@ -711,6 +719,8 @@ def __create_inputs_prefill(): stopping_criteria = None return_logits = __need_logits(running) + return_hidden_states = __need_hidden_states(running) + hidden_states_all_mode = __hidden_states_all_mode(running) return_routed_experts = __need_routed_experts(running) return dict( @@ -722,6 +732,8 @@ def __create_inputs_prefill(): sampling_inputs=sampling_inputs, stopping_criteria=stopping_criteria, return_logits=return_logits, + return_hidden_states=return_hidden_states, + hidden_states_all_mode=hidden_states_all_mode, extra_inputs=extra_inputs, return_routed_experts=return_routed_experts, ) diff --git a/lmdeploy/pytorch/engine/model_agent/agent.py b/lmdeploy/pytorch/engine/model_agent/agent.py index 55fb6e807b..28a51ce699 100644 --- a/lmdeploy/pytorch/engine/model_agent/agent.py +++ b/lmdeploy/pytorch/engine/model_agent/agent.py @@ -77,6 +77,7 @@ class BatchedOutputs: stopped: torch.Tensor stop_pos: torch.Tensor | None = None logits: torch.Tensor | None = None + hidden_states: torch.Tensor | None = None model_metas: list[dict[str, Any]] = None logprobs: BatchedLogProbs | None = None new_token_timestamp: int = 0 @@ -439,16 +440,47 @@ async def _async_model_forward( self, inputs: ModelInputs, return_logits: bool, + return_hidden_states: bool = False, + hidden_states_all_mode: bool = False, ): """Model forward.""" origin_inputs = inputs ret = await self.async_forward(inputs) + # For 'all' mode hidden states without return_logits, save the full hidden + # states before _postprocess_forward_output slices them to last position. 
+ pre_postprocess_full_hs = None + if return_hidden_states and hidden_states_all_mode and not return_logits: + pre_postprocess_full_hs = ret['hidden_states'][0] # [total_tokens, hidden_dim] + if not return_logits: ret = self._postprocess_forward_output(ret, origin_inputs) hidden_states, ret = self.spec_agent.update_main_model_outputs(ret, origin_inputs) + if return_hidden_states: + # Extract hidden states to return to the user + hs = hidden_states + seq_length = ret.get('seq_length', inputs.seq_length) + if hidden_states_all_mode: + if return_logits: + # Full hidden states still available (postprocessing was skipped) + full_hs = hs[0] # [total_tokens, hidden_dim] + else: + # Use the saved full hidden states from before postprocessing + full_hs = pre_postprocess_full_hs + ret['last_hidden_states'] = full_hs + ret['hidden_states_seq_length'] = seq_length + else: + # 'generation' mode: last-position hidden state per sequence + if return_logits: + # hidden_states is full sequence, need to slice to last position + last_hs = self._slice_outs(hs[0], seq_length) + else: + # _postprocess_forward_output already sliced to last position + last_hs = hs[0] + ret['last_hidden_states'] = last_hs + logits = self.get_logits(hidden_states) ret['logits'] = logits return ret @@ -601,6 +633,8 @@ async def _step_postprocess_with_output(self, model_metas: Any, need_broadcast_next: bool, return_logits: bool = False, + return_hidden_states: bool = False, + last_hidden_states: torch.Tensor = None, all_routed_experts: Any = None, extra_inputs: ExtraInputs = None): """Step postprocess with output.""" @@ -639,6 +673,7 @@ async def _step_postprocess_with_output(self, self._push_output( BatchedOutputs(next_token_ids=output_token_ids, logits=logits if return_logits else None, + hidden_states=last_hidden_states if return_hidden_states else None, stopped=stopped, stop_pos=stop_pos, model_metas=model_metas, @@ -677,6 +712,8 @@ async def _async_step( sampling_inputs: SamplingInputs = None, 
stopping_criteria: StoppingCriteria = None, return_logits: bool = False, + return_hidden_states: bool = False, + hidden_states_all_mode: bool = False, return_routed_experts: bool = False, extra_inputs: ExtraInputs = None, ): @@ -738,6 +775,8 @@ async def _async_step( output = await self._async_model_forward( inputs, return_logits=return_logits, + return_hidden_states=return_hidden_states, + hidden_states_all_mode=hidden_states_all_mode, ) # recovery is_decoding inputs.is_decoding = is_decoding @@ -751,6 +790,7 @@ async def _async_step( last_logits = self._slice_outs(logits, seq_length) # [bs, 1, prob] -> [bs, prob] extra_inputs = self.agent_strategy.slice_extra_inputs(extra_inputs, inputs, output) model_metas = output.get('model_metas') + last_hidden_states = output.get('last_hidden_states', None) if self.need_output: logger.debug(f' rank[{rank}]: Sampling.') @@ -776,6 +816,8 @@ async def _async_step( model_metas, need_broadcast_next, return_logits=return_logits, + return_hidden_states=return_hidden_states, + last_hidden_states=last_hidden_states, all_routed_experts=all_routed_experts, extra_inputs=extra_inputs, )) diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index 1ef5caba83..dd7f3757cf 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -59,7 +59,9 @@ class SamplingParam: response_format: None | str = None logits_processors: None | list[LogitsProcessor] = None out_logits: bool = False + out_logits_mode: None | str = None out_last_hidden_states: bool = False + out_last_hidden_states_mode: None | str = None num_logprobs: int = -1 return_routed_experts: bool = False @@ -87,13 +89,15 @@ def from_gen_config(cls, gen_config: GenerationConfig): response_format = gen_config.response_format output_logits = gen_config.output_logits - if output_logits: - if (output_logits != 'all' or gen_config.max_new_tokens > 0): - output_logits = None - logger.warning('Pytorch Engine only support output_logits="all"' - ' with max_new_tokens=0') - 
if gen_config.output_last_hidden_state is not None: - logger.warning('Pytorch Engine does not support output last hidden states.') + if output_logits == 'all' and gen_config.max_new_tokens > 0: + output_logits = None + logger.warning('Pytorch Engine only support output_logits="all"' + ' with max_new_tokens=0') + output_last_hidden_state = gen_config.output_last_hidden_state + if output_last_hidden_state == 'all' and gen_config.max_new_tokens > 0: + output_last_hidden_state = None + logger.warning('Pytorch Engine only support output_last_hidden_state="all"' + ' with max_new_tokens=0') if top_p < 0 or top_p > 1.0: logger.warning('`top_p` has to be a float > 0 and < 1' f' but is {top_p}') @@ -156,6 +160,9 @@ def from_gen_config(cls, gen_config: GenerationConfig): min_new_tokens=min_new_tokens, logits_processors=gen_config.logits_processors, out_logits=(output_logits is not None), + out_logits_mode=output_logits, + out_last_hidden_states=(output_last_hidden_state is not None), + out_last_hidden_states_mode=output_last_hidden_state, num_logprobs=logprobs, return_routed_experts=gen_config.return_routed_experts, repetition_ngram_size=repetition_ngram_size, @@ -549,6 +556,50 @@ def clone(self): return ret +class HistoryHiddenStates(_HistoryDataBase): + """History hidden states. + + Hidden states are stored as int16 numpy arrays (same bit-level storage as HistoryLogits), reinterpreting + float16/bfloat16 tensors byte-for-byte. _create_empty_array returns None so that the shape (hidden_dim) is inferred + dynamically from the first append call, matching the HistoryLogits pattern. 
+ """ + ALLOC_SIZE = 64 + COPY_ON_RESIZE = True + + def __init__(self, hidden_states: np.ndarray = None, dtype: np.dtype = np.int16): + super().__init__(hidden_states, dtype) + self._torch_dtype = None + + def _create_empty_array(self, dtype): + """Return None; shape is determined on first append (see + HistoryLogits).""" + return None + + def _get_pad_width(self, reserve_size: int): + """Get pad width for multi-dimensional array.""" + return ((0, reserve_size), (0, 0)) + + def set_torch_dtype(self, torch_dtype): + """Set torch dtype.""" + self._torch_dtype = torch_dtype + + def get_hidden_states(self): + """Get hidden states as torch tensor.""" + if self._data is None: + return None + if self._torch_dtype is None: + return None + + hs_np = self.get_real() + return torch.frombuffer(hs_np, dtype=self._torch_dtype).view(hs_np.shape) + + def clone(self): + """clone.""" + ret = super().clone() + ret.set_torch_dtype(self._torch_dtype) + return ret + + class HistoryMropePosIds(_HistoryDataBase): """History mrope position ids.""" ALLOC_SIZE = 64 @@ -653,6 +704,9 @@ class SchedulerSequence: # logits all_logits: HistoryLogits = field(default_factory=HistoryLogits) + # hidden states + all_hidden_states: HistoryHiddenStates = field(default_factory=HistoryHiddenStates) + # mrope history_mrope_pos_ids: HistoryMropePosIds = field(default_factory=HistoryMropePosIds) @@ -790,11 +844,30 @@ def status(self): def return_logits(self): return self.sampling_param.out_logits + @property + def logits_generation_mode(self): + """Check if logits are in generation mode.""" + return self.sampling_param.out_logits_mode == 'generation' + @property def logits(self): """Get logits.""" return self.all_logits.get_logits() + @property + def return_hidden_states(self): + return self.sampling_param.out_last_hidden_states + + @property + def hidden_states_generation_mode(self): + """Check if hidden states are in generation mode.""" + return self.sampling_param.out_last_hidden_states_mode == 
'generation' + + @property + def hidden_states(self): + """Get hidden states.""" + return self.all_hidden_states.get_hidden_states() + @property def mrope_pos_ids(self): """Get mrope pos ids.""" @@ -813,6 +886,17 @@ def append_logits(self, logits: Tensor | np.ndarray): logits = logits.view(torch.int16).numpy() self.all_logits.append(logits) + def append_hidden_states(self, hidden_states: Tensor | np.ndarray): + """Append hidden states.""" + if not self.return_hidden_states: + return + if hidden_states is None: + return + if isinstance(hidden_states, Tensor): + self.all_hidden_states.set_torch_dtype(hidden_states.dtype) + hidden_states = hidden_states.view(torch.int16).numpy() + self.all_hidden_states.append(hidden_states) + def get_input_multimodals(self): """Get input multimodals.""" start = self.num_history_ids