Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions lmdeploy/pytorch/engine/inputs_maker.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,16 @@ class InputsMakerConfig:
spec_decoding: bool = False
enable_chunked_prefill: bool = False
use_mrope: bool = False
prefill_interval: int = 16

@staticmethod
def from_engine(engine: 'Engine'):
cache_config = engine.cache_config
model_config = engine.model_config
prefill_interval = engine.engine_config.prefill_interval
kwargs = dict()
if prefill_interval is not None and prefill_interval > 0:
Comment thread
grimoire marked this conversation as resolved.
Outdated
kwargs['prefill_interval'] = prefill_interval
return InputsMakerConfig(
spec_decoding=engine.specdecode_config is not None,
max_batches=cache_config.max_batches,
Expand All @@ -69,6 +74,7 @@ def from_engine(engine: 'Engine'):
dp=engine.dist_config.dp,
enable_chunked_prefill=engine.misc_config.enable_chunked_prefill,
use_mrope=model_config.use_mrope,
**kwargs,
)


Expand Down Expand Up @@ -225,6 +231,9 @@ def __init__(

self._init_do_prefill(config)

# consecutive decode counter for prefill starvation prevention
self._decode_count = 0

# record for next forward.
self.next_is_prefill = True
self.forward_inputs = None
Expand Down Expand Up @@ -693,6 +702,10 @@ def __create_inputs_prefill():
swap_out_map,
) = __create_inputs_prefill()

# reset decode count when non-decoding inputs are produced
if inputs is not None and not inputs.is_decoding:
self._decode_count = 0

# try decoding
if inputs is None and len(self.running_seqs) > 0 and self.config.role != EngineRole.Prefill:
prefill = False
Expand Down Expand Up @@ -735,8 +748,13 @@ def do_prefill_default(self):

# do decoding if the scheduler has nothing waiting
if not scheduler.has_waiting():
self._decode_count = 0
return False

# force prefill if too many consecutive decode rounds
if self._decode_count >= self.config.prefill_interval:
Comment thread
grimoire marked this conversation as resolved.
return True

# do prefill if too many tokens
waiting = scheduler.waiting
token_count = 0
Expand All @@ -753,6 +771,7 @@ def do_prefill_default(self):
return True

# decoding
self._decode_count += 1
return False

def do_prefill_chunked(self):
Expand Down
Loading