diff --git a/mlx_lm/server.py b/mlx_lm/server.py
index ce8d95817..1aa19faf2 100644
--- a/mlx_lm/server.py
+++ b/mlx_lm/server.py
@@ -1,6 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.
 
 import argparse
+import copy
 import json
 import logging
 import pickle
@@ -40,7 +41,10 @@
 )
 from .models.cache import (
     LRUPromptCache,
+    load_prompt_cache,
     make_prompt_cache,
+    save_prompt_cache,
+    trim_prompt_cache,
 )
 from .sample_utils import make_logits_processors, make_sampler
 from .utils import _parse_size, load, sharded_load
@@ -192,6 +196,8 @@ class GenerationArguments:
     top_logprobs: int
     seed: Optional[int]
     chat_template_kwargs: Optional[Dict[str, Any]]
+    prompt_cache_file: Optional[str]
+    disable_prompt_cache: bool = False
 
 
 @dataclass
@@ -203,6 +209,7 @@ class CompletionRequest:
     messages: List[Any]
     tools: Optional[List[Any]]
     role_mapping: Optional[Dict[str, Any]]
+    prompt_cache_file: Optional[str] = None
 
 
 @dataclass
@@ -683,7 +690,12 @@ def _make_state_machine(
         return sm, sequences
 
     def _is_batchable(self, args):
-        return self.model_provider.is_batchable and args.seed is None
+        return (
+            self.model_provider.is_batchable
+            and args.seed is None
+            and args.prompt_cache_file is None
+            and not args.disable_prompt_cache
+        )
 
     def _generate(self):
         # Local thread stream that we 'll pass to the BatchGenerator to make
@@ -962,7 +974,19 @@ def progress(tokens_processed, tokens_total):
 
             # Load the KV cache
             self._log_cache_stats()
-            cache, rest = self.prompt_cache.fetch_nearest_cache(
+            request_prompt_cache = self.prompt_cache
+            if args.prompt_cache_file is not None or args.disable_prompt_cache:
+                request_prompt_cache = LRUPromptCache(
+                    self.prompt_cache.max_size,
+                    self.prompt_cache.max_bytes,
+                )
+                self._load_disk_prompt_cache(
+                    request_prompt_cache,
+                    self.model_provider.model_key,
+                    prompt,
+                    args.prompt_cache_file if not args.disable_prompt_cache else None,
+                )
+            cache, rest = request_prompt_cache.fetch_nearest_cache(
                 self.model_provider.model_key, prompt
             )
             ctx.prompt_cache_count = len(prompt) - len(rest)
@@ -973,6 +997,7 @@ def progress(tokens_processed, tokens_total):
                 cache += make_prompt_cache(self.model_provider.draft_model)
 
             # Process the prompt and generate tokens
+            saved_disk_prompt_cache = False
             for gen in stream_generate(
                 model=model,
                 tokenizer=tokenizer,
@@ -1003,6 +1028,25 @@ def progress(tokens_processed, tokens_total):
                         ),
                     )
                 )
+                if (
+                    args.prompt_cache_file is not None
+                    and not args.disable_prompt_cache
+                    and not saved_disk_prompt_cache
+                ):
+                    disk_prompt_cache = copy.deepcopy(cache)
+                    disk_cache_key = prompt[:]
+                    if len(disk_cache_key) > 1:
+                        # Shorten the key only by what was actually trimmed;
+                        # caches that cannot be trimmed report 0.
+                        trimmed = trim_prompt_cache(disk_prompt_cache, 1)
+                        disk_cache_key = disk_cache_key[: len(disk_cache_key) - trimmed]
+                    self._save_disk_prompt_cache(
+                        args.prompt_cache_file,
+                        disk_cache_key,
+                        disk_prompt_cache,
+                        "assistant",
+                    )
+                    saved_disk_prompt_cache = True
                 cache_key.append(gen.token)
 
                 if ctx._should_stop:
@@ -1023,6 +1067,74 @@ def progress(tokens_processed, tokens_total):
         except Exception as e:
             rqueue.put(e)
 
+    def _load_disk_prompt_cache(self, prompt_cache, model_key, prompt, prompt_cache_file):
+        """Seed ``prompt_cache`` from a prompt cache file when it matches ``prompt``.
+
+        The file is ignored (with a log message) when it cannot be read, when
+        its token metadata is missing or malformed, or when the stored tokens
+        are not a prefix of ``prompt``.
+        """
+        if prompt_cache_file is None:
+            return
+        # NOTE(security): ``prompt_cache_file`` comes from the request body, so a
+        # client chooses which server-side path is read here.
+        try:
+            loaded_cache, metadata = load_prompt_cache(
+                prompt_cache_file,
+                return_metadata=True,
+            )
+        except Exception as e:
+            logging.warning(
+                f"Failed to load prompt cache file {prompt_cache_file}: {type(e).__name__}: {e}"
+            )
+            return
+
+        cached_tokens = metadata.get("tokens")
+        cache_type = metadata.get("cache_type", "assistant")
+        if isinstance(cached_tokens, str):
+            # Tokens are stored JSON-encoded because prompt cache metadata
+            # values must be strings.
+            try:
+                cached_tokens = json.loads(cached_tokens)
+            except ValueError:
+                cached_tokens = None
+        if not isinstance(cached_tokens, list):
+            logging.warning(f"Prompt cache file {prompt_cache_file} has invalid token metadata")
+            return
+        if prompt[: len(cached_tokens)] != cached_tokens:
+            logging.info(f"Prompt cache file {prompt_cache_file} does not match request prefix")
+            return
+        prompt_cache.insert_cache(
+            model_key,
+            cached_tokens,
+            loaded_cache,
+            cache_type=cache_type,
+        )
+
+    def _save_disk_prompt_cache(self, prompt_cache_file, cache_key, prompt_cache, cache_type):
+        """Write ``prompt_cache`` and the tokens it covers to ``prompt_cache_file``.
+
+        Failures are logged and otherwise ignored so that a bad path never
+        aborts generation.
+        """
+        if prompt_cache_file is None:
+            return
+        # Prompt cache metadata values must be strings, so the token ids are
+        # stored as a JSON-encoded list.
+        metadata = {
+            "tokens": json.dumps(cache_key),
+            "cache_type": cache_type,
+        }
+        # NOTE(security): the path is client supplied; directories are created
+        # and the file is overwritten at whatever location the request names.
+        try:
+            Path(prompt_cache_file).parent.mkdir(parents=True, exist_ok=True)
+            save_prompt_cache(prompt_cache_file, prompt_cache, metadata)
+        except Exception as e:
+            logging.warning(
+                f"Failed to save prompt cache file {prompt_cache_file}: {type(e).__name__}: {e}"
+            )
+
     def generate(
         self,
         request: CompletionRequest,
@@ -1190,6 +1302,10 @@ def do_POST(self):
         self.top_logprobs = self.body.get("top_logprobs", -1)
         self.seed = self.body.get("seed", None)
         self.chat_template_kwargs = self.body.get("chat_template_kwargs")
+        self.prompt_cache_file = self.body.get("prompt_cache_file")
+        if self.prompt_cache_file is None:
+            self.prompt_cache_file = self.body.get("prompt-cache-file")
+        self.disable_prompt_cache = self.body.get("disable_prompt_cache", False)
         self.validate_model_parameters()
 
         # Get stop sequences
@@ -1249,6 +1365,8 @@ def validate_model_parameters(self):
         self._validate("adapter", str, optional=True)
         self._validate("seed", int, optional=True)
         self._validate("logit_bias", dict, optional=True)
+        self._validate("prompt_cache_file", str, optional=True)
+        self._validate("disable_prompt_cache", bool)
 
         if self.logit_bias is not None:
             try:
@@ -1403,6 +1521,8 @@ def handle_completion(self, request: CompletionRequest, stop_words: List[str]):
             top_logprobs=self.top_logprobs,
             seed=self.seed,
             chat_template_kwargs=self.chat_template_kwargs,
+            prompt_cache_file=self.prompt_cache_file,
+            disable_prompt_cache=self.disable_prompt_cache,
         )
 
         # Keep connection allive during long prompt processing (and also log
@@ -1596,6 +1716,7 @@ def handle_chat_completions(self) -> CompletionRequest:
             body["messages"],
             body.get("tools") or None,
             body.get("role_mapping"),
+            self.prompt_cache_file,
         )
 
     def handle_text_completions(self) -> CompletionRequest:
@@ -1615,6 +1736,7 @@ def handle_text_completions(self) -> CompletionRequest:
             [],
             None,
             None,
+            self.prompt_cache_file,
         )
 
     def do_GET(self):