def parse_logit_bias(string):
    """Parse a ``--logit-bias`` command-line value into a token-bias mapping.

    Written to be used as an ``argparse`` ``type=`` callable: every failure is
    reported as ``argparse.ArgumentTypeError`` so the parser prints a usage
    error instead of a traceback.

    Args:
        string: JSON text encoding an object whose keys are token ids and
            whose values are numeric biases, e.g.
            ``'{"100": 3.0, "200": -1.8}'``.

    Returns:
        A ``dict`` mapping each token id (``int``) to its bias (``float``).
        An empty JSON object yields an empty ``dict``.

    Raises:
        argparse.ArgumentTypeError: If ``string`` is not valid JSON, is not a
            JSON object, or holds a key/value that cannot be converted to
            ``int``/``float`` (including NaN and integers too large for a
            float).
    """
    try:
        decoded = json.loads(string)
        # Lists, numbers and strings are valid JSON but not a mapping; reject
        # them explicitly so the user sees a meaningful reason.
        if not isinstance(decoded, dict):
            raise TypeError(
                f"expected a JSON object, got {type(decoded).__name__}"
            )

        biases = {}
        for token_id, raw_bias in decoded.items():
            bias = float(raw_bias)
            # json.loads accepts the bare literal NaN; NaN is the only float
            # that is not equal to itself. +/-Infinity is still allowed.
            if bias != bias:
                raise ValueError(f"bias for token {token_id!r} is NaN")
            biases[int(token_id)] = bias
        return biases
    except (ValueError, TypeError, OverflowError) as e:
        # json.JSONDecodeError is a ValueError subclass; OverflowError comes
        # from float() on an integer too large to represent.
        raise argparse.ArgumentTypeError(
            "must be a JSON object mapping token ids (int) to biases (float), "
            'e.g. \'{"100": 3.0, "200": -1.8}\'. '
            f"({e})"
        ) from e
presence_context_size=args.presence_context_size, + frequency_penalty=args.frequency_penalty, + frequency_context_size=args.frequency_context_size, + ) response = generate( model, tokenizer, @@ -2076,6 +2147,7 @@ def main(): max_tokens=args.max_tokens, verbose=args.verbose, sampler=sampler, + logits_processors=logits_processors, max_kv_size=args.max_kv_size, prompt_cache=prompt_cache if using_cache else None, kv_bits=args.kv_bits,