diff --git a/lmdeploy/pytorch/engine/inputs_maker.py b/lmdeploy/pytorch/engine/inputs_maker.py index 72759d3cd6..aa4fccab7e 100644 --- a/lmdeploy/pytorch/engine/inputs_maker.py +++ b/lmdeploy/pytorch/engine/inputs_maker.py @@ -55,11 +55,19 @@ class InputsMakerConfig: spec_decoding: bool = False enable_chunked_prefill: bool = False use_mrope: bool = False + prefill_interval: int = 16 @staticmethod def from_engine(engine: 'Engine'): cache_config = engine.cache_config model_config = engine.model_config + prefill_interval = engine.engine_config.prefill_interval + kwargs = dict() + if prefill_interval is not None: + if not isinstance(prefill_interval, int) or prefill_interval <= 0: + raise ValueError('engine.engine_config.prefill_interval must be a positive int ' + f'or None, but got {prefill_interval!r}') + kwargs['prefill_interval'] = prefill_interval return InputsMakerConfig( spec_decoding=engine.specdecode_config is not None, max_batches=cache_config.max_batches, @@ -69,6 +77,7 @@ def from_engine(engine: 'Engine'): dp=engine.dist_config.dp, enable_chunked_prefill=engine.misc_config.enable_chunked_prefill, use_mrope=model_config.use_mrope, + **kwargs, ) @@ -225,6 +234,9 @@ def __init__( self._init_do_prefill(config) + # consecutive decode counter for prefill starvation prevention + self._decode_count = 0 + # record for next forward. self.next_is_prefill = True self.forward_inputs = None @@ -693,6 +705,10 @@ def __create_inputs_prefill(): swap_out_map, ) = __create_inputs_prefill() + # reset decode count when non-decoding inputs are produced + if inputs is not None and not inputs.is_decoding: + self._decode_count = 0 + # try decoding if inputs is None and len(self.running_seqs) > 0 and self.config.role != EngineRole.Prefill: prefill = False @@ -735,8 +751,13 @@ def do_prefill_default(self): # do decoding if not waiting if not scheduler.has_waiting(): + self._decode_count = 0 return False + # force prefill if too many consecutive decode rounds + if self._decode_count >= self.config.prefill_interval: + return True + # do prefill if too much tokens waiting = scheduler.waiting token_count = 0 @@ -753,6 +774,7 @@ def do_prefill_default(self): return True # decoding + self._decode_count += 1 return False def do_prefill_chunked(self):