Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 98 additions & 2 deletions mlx_lm/server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright © 2023-2024 Apple Inc.

import argparse
import copy
import json
import logging
import pickle
Expand Down Expand Up @@ -40,7 +41,10 @@
)
from .models.cache import (
LRUPromptCache,
load_prompt_cache,
make_prompt_cache,
save_prompt_cache,
trim_prompt_cache,
)
from .sample_utils import make_logits_processors, make_sampler
from .utils import _parse_size, load, sharded_load
Expand Down Expand Up @@ -192,6 +196,8 @@ class GenerationArguments:
top_logprobs: int
seed: Optional[int]
chat_template_kwargs: Optional[Dict[str, Any]]
prompt_cache_file: Optional[str]
disable_prompt_cache: bool = False


@dataclass
Expand All @@ -203,6 +209,7 @@ class CompletionRequest:
messages: List[Any]
tools: Optional[List[Any]]
role_mapping: Optional[Dict[str, Any]]
prompt_cache_file: Optional[str] = None


@dataclass
Expand Down Expand Up @@ -683,7 +690,12 @@ def _make_state_machine(
return sm, sequences

def _is_batchable(self, args):
return self.model_provider.is_batchable and args.seed is None
return (
self.model_provider.is_batchable
and args.seed is None
and args.prompt_cache_file is None
and not args.disable_prompt_cache
)

def _generate(self):
# Local thread stream that we'll pass to the BatchGenerator to make
Expand Down Expand Up @@ -962,7 +974,19 @@ def progress(tokens_processed, tokens_total):

# Load the KV cache
self._log_cache_stats()
cache, rest = self.prompt_cache.fetch_nearest_cache(
request_prompt_cache = self.prompt_cache
if args.prompt_cache_file is not None or args.disable_prompt_cache:
request_prompt_cache = LRUPromptCache(
self.prompt_cache.max_size,
self.prompt_cache.max_bytes,
)
self._load_disk_prompt_cache(
request_prompt_cache,
self.model_provider.model_key,
prompt,
args.prompt_cache_file if not args.disable_prompt_cache else None,
)
cache, rest = request_prompt_cache.fetch_nearest_cache(
self.model_provider.model_key, prompt
)
ctx.prompt_cache_count = len(prompt) - len(rest)
Expand All @@ -973,6 +997,7 @@ def progress(tokens_processed, tokens_total):
cache += make_prompt_cache(self.model_provider.draft_model)

# Process the prompt and generate tokens
saved_disk_prompt_cache = False
for gen in stream_generate(
model=model,
tokenizer=tokenizer,
Expand Down Expand Up @@ -1003,6 +1028,23 @@ def progress(tokens_processed, tokens_total):
),
)
)
if (
args.prompt_cache_file is not None
and not args.disable_prompt_cache
and not saved_disk_prompt_cache
):
disk_prompt_cache = copy.deepcopy(cache)
disk_cache_key = prompt[:]
if len(disk_cache_key) > 1:
trim_prompt_cache(disk_prompt_cache, 1)
disk_cache_key = disk_cache_key[:-1]
self._save_disk_prompt_cache(
args.prompt_cache_file,
disk_cache_key,
disk_prompt_cache,
"assistant",
)
saved_disk_prompt_cache = True
cache_key.append(gen.token)

if ctx._should_stop:
Expand All @@ -1023,6 +1065,50 @@ def progress(tokens_processed, tokens_total):
except Exception as e:
rqueue.put(e)

def _load_disk_prompt_cache(self, prompt_cache, model_key, prompt, prompt_cache_file):
if prompt_cache_file is None:
return
try:
loaded_cache, metadata = load_prompt_cache(
prompt_cache_file,
return_metadata=True,
)
except Exception as e:
logging.warning(
f"Failed to load prompt cache file {prompt_cache_file}: {type(e).__name__}: {e}"
)
return

cached_tokens = metadata.get("tokens")
cache_type = metadata.get("cache_type", "assistant")
if not isinstance(cached_tokens, list):
logging.warning(f"Prompt cache file {prompt_cache_file} has invalid token metadata")
return
if prompt[: len(cached_tokens)] != cached_tokens:
logging.info(f"Prompt cache file {prompt_cache_file} does not match request prefix")
return
prompt_cache.insert_cache(
model_key,
cached_tokens,
loaded_cache,
cache_type=cache_type,
)

def _save_disk_prompt_cache(self, prompt_cache_file, cache_key, prompt_cache, cache_type):
if prompt_cache_file is None:
return
metadata = {
"tokens": cache_key,
"cache_type": cache_type,
}
try:
Path(prompt_cache_file).parent.mkdir(parents=True, exist_ok=True)
save_prompt_cache(prompt_cache_file, prompt_cache, metadata)
except Exception as e:
logging.warning(
f"Failed to save prompt cache file {prompt_cache_file}: {type(e).__name__}: {e}"
)

def generate(
self,
request: CompletionRequest,
Expand Down Expand Up @@ -1190,6 +1276,10 @@ def do_POST(self):
self.top_logprobs = self.body.get("top_logprobs", -1)
self.seed = self.body.get("seed", None)
self.chat_template_kwargs = self.body.get("chat_template_kwargs")
self.prompt_cache_file = self.body.get("prompt_cache_file")
if self.prompt_cache_file is None:
self.prompt_cache_file = self.body.get("prompt-cache-file")
self.disable_prompt_cache = self.body.get("disable_prompt_cache", False)
self.validate_model_parameters()

# Get stop sequences
Expand Down Expand Up @@ -1249,6 +1339,8 @@ def validate_model_parameters(self):
self._validate("adapter", str, optional=True)
self._validate("seed", int, optional=True)
self._validate("logit_bias", dict, optional=True)
self._validate("prompt_cache_file", str, optional=True)
self._validate("disable_prompt_cache", bool)

if self.logit_bias is not None:
try:
Expand Down Expand Up @@ -1403,6 +1495,8 @@ def handle_completion(self, request: CompletionRequest, stop_words: List[str]):
top_logprobs=self.top_logprobs,
seed=self.seed,
chat_template_kwargs=self.chat_template_kwargs,
prompt_cache_file=self.prompt_cache_file,
disable_prompt_cache=self.disable_prompt_cache,
)

# Keep connection alive during long prompt processing (and also log
Expand Down Expand Up @@ -1596,6 +1690,7 @@ def handle_chat_completions(self) -> CompletionRequest:
body["messages"],
body.get("tools") or None,
body.get("role_mapping"),
self.prompt_cache_file,
)

def handle_text_completions(self) -> CompletionRequest:
Expand All @@ -1615,6 +1710,7 @@ def handle_text_completions(self) -> CompletionRequest:
[],
None,
None,
self.prompt_cache_file,
)

def do_GET(self):
Expand Down