Add logits processor arguments to mlx_lm.generate#1273
Open
realyxl wants to merge 1 commit into
Open
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Background
mlx_lm.generatedoes not currently expose seven logits-processorparameters that are available in the Python API:
logit_bias,{repetition,presence,frequency}_penalty, and their three_context_sizecompanions.
These parameters are accepted by
mlx_lm.sample_utils.make_logits_processors(), which produces the list thatgenerate()/stream_generate()already accept via thelogits_processorskeyword argument.mlx_lm.serveralready exposes thesame seven parameters over HTTP.
Goal
Add the same seven parameters to the
mlx_lm.generateCLI.Motivation
Common need in terminal testing.
mlx_lm.generateis convenient forquickly testing local checkpoints, LoRA adapters, and quantization levels
from the terminal. These parameters are routinely useful, and sometimes
necessary, for suppressing the repetition loops that small or
low-bit-quantized local models often fall into.
Consistency.
generate()already supportslogits_processors, andmlx_lm.sample_utils.make_logits_processorsisthe standard factory; the
mlx_lm.generateCLI does not yet expose them.mlx_lm.generatealready exposes theparallel
make_samplerparameters (--temp,--top-p,--top-k,--min-p,--xtc-*); the logits-processor side was asymmetric.Minimal and additive. No algorithmic changes. This only wires the
existing factory into argparse.
Why only
generate, notchat.mlx_lm.chatmight be intentionallyminimal — it doesn't expose even
--top-k. This PR keeps that scopeunchanged and focuses on
mlx_lm.generate, the primary terminal generationtool and the one where these knobs are most useful.
Implementation
A single file changed (
mlx_lm/generate.py).make_logits_processors.DEFAULT_*constants.Noneforlogit_bias,0.0for the threepenalties,
20for the three context sizes. This matches the0.0 == disabledconvention used by the sampler defaults in this file and bymlx_lm/server.py.parse_logit_biasargparse type that accepts a JSON object andconverts keys to
int/ values tofloat.add_argumentcalls insetup_arg_parser, ordered to mirror themake_logits_processorssignature.main(), buildlogits_processorsviamake_logits_processors(...)(kwarg form) and pass it to
generate(...)as the existinglogits_processors=kwarg.Tests
No new tests added.
pre-commitand the full existing suite(
tests.test_sample_utils+tests.test_generate, 31 tests) pass locally.