Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 73 additions & 1 deletion mlx_lm/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
TokenBuffer,
load_prompt_cache,
)
from .sample_utils import make_sampler
from .sample_utils import make_logits_processors, make_sampler
from .tokenizer_utils import TokenizerWrapper
from .utils import does_model_support_input_embeddings, load

Expand All @@ -51,6 +51,13 @@
DEFAULT_XTC_PROBABILITY = 0.0
DEFAULT_XTC_THRESHOLD = 0.0
DEFAULT_MIN_TOKENS_TO_KEEP = 1
# Defaults for the logits-processor command-line options
# (--logit-bias and the repetition / presence / frequency penalties).
# NOTE(review): the 0.0 penalty defaults presumably mean "disabled";
# confirm against make_logits_processors.
DEFAULT_LOGIT_BIAS = None  # default for --logit-bias (no bias mapping supplied)
DEFAULT_REPETITION_PENALTY = 0.0  # default for --repetition-penalty
DEFAULT_REPETITION_CONTEXT_SIZE = 20  # previous tokens considered for repetition penalty
DEFAULT_PRESENCE_PENALTY = 0.0  # default for --presence-penalty
DEFAULT_PRESENCE_CONTEXT_SIZE = 20  # previous tokens considered for presence penalty
DEFAULT_FREQUENCY_PENALTY = 0.0  # default for --frequency-penalty
DEFAULT_FREQUENCY_CONTEXT_SIZE = 20  # previous tokens considered for frequency penalty
DEFAULT_SEED = None
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
DEFAULT_QUANTIZED_KV_START = 5000
Expand All @@ -60,6 +67,18 @@ def str2bool(string):
return string.lower() not in ["false", "f"]


def parse_logit_bias(string):
    """Parse the ``--logit-bias`` command-line value.

    Args:
        string: JSON text encoding an object whose keys are token ids and
            whose values are the biases, e.g. ``'{"100": 3.0, "200": -1.8}'``.

    Returns:
        A ``dict`` mapping each token id (``int``) to its bias (``float``).

    Raises:
        argparse.ArgumentTypeError: If ``string`` is not valid JSON, is not a
            JSON object, or contains a key/value that cannot be converted to
            ``int``/``float``. The original error is chained as the cause.
    """
    try:
        bias = json.loads(string)
        return {int(k): float(v) for k, v in bias.items()}
    # AttributeError: the JSON parsed fine but is not an object (no .items()).
    # OverflowError: float() on an integer too large to represent; argparse
    # does not translate this one itself, so it must be caught here to get a
    # clean usage error instead of a traceback.
    except (ValueError, TypeError, AttributeError, OverflowError) as e:
        raise argparse.ArgumentTypeError(
            "must be a JSON object mapping token ids (int) to biases (float), "
            'e.g. \'{"100": 3.0, "200": -1.8}\'. '
            f"({e})"
        ) from e


def setup_arg_parser():
"""Set up and return the argument parser."""
parser = argparse.ArgumentParser(description="LLM inference script")
Expand Down Expand Up @@ -142,6 +161,49 @@ def setup_arg_parser():
default=DEFAULT_MIN_TOKENS_TO_KEEP,
help="Minimum tokens to keep for min-p sampling.",
)
parser.add_argument(
"--logit-bias",
type=parse_logit_bias,
default=DEFAULT_LOGIT_BIAS,
help="Additive logit bias as a JSON object mapping token ids to biases, "
'e.g. \'{"100": 3.0, "200": -1.8}\'.',
)
parser.add_argument(
"--repetition-penalty",
type=float,
default=DEFAULT_REPETITION_PENALTY,
help="A (sign-aware) multiplicative penalty for repeating tokens.",
)
parser.add_argument(
"--repetition-context-size",
type=int,
default=DEFAULT_REPETITION_CONTEXT_SIZE,
help="Number of previous tokens to consider for repetition penalty.",
)
parser.add_argument(
"--presence-penalty",
type=float,
default=DEFAULT_PRESENCE_PENALTY,
help="An additive penalty for each token in the recent context.",
)
parser.add_argument(
"--presence-context-size",
type=int,
default=DEFAULT_PRESENCE_CONTEXT_SIZE,
help="Number of previous tokens to consider for presence penalty.",
)
parser.add_argument(
"--frequency-penalty",
type=float,
default=DEFAULT_FREQUENCY_PENALTY,
help="An additive penalty for each occurrence in the recent context.",
)
parser.add_argument(
"--frequency-context-size",
type=int,
default=DEFAULT_FREQUENCY_CONTEXT_SIZE,
help="Number of previous tokens to consider for frequency penalty.",
)
parser.add_argument(
"--seed",
type=int,
Expand Down Expand Up @@ -2069,13 +2131,23 @@ def main():
xtc_threshold=args.xtc_threshold,
xtc_special_tokens=tokenizer.encode("\n") + list(tokenizer.eos_token_ids),
)
logits_processors = make_logits_processors(
logit_bias=args.logit_bias,
repetition_penalty=args.repetition_penalty,
repetition_context_size=args.repetition_context_size,
presence_penalty=args.presence_penalty,
presence_context_size=args.presence_context_size,
frequency_penalty=args.frequency_penalty,
frequency_context_size=args.frequency_context_size,
)
response = generate(
model,
tokenizer,
prompt,
max_tokens=args.max_tokens,
verbose=args.verbose,
sampler=sampler,
logits_processors=logits_processors,
max_kv_size=args.max_kv_size,
prompt_cache=prompt_cache if using_cache else None,
kv_bits=args.kv_bits,
Expand Down