diff --git a/mlx_lm/server.py b/mlx_lm/server.py index ce8d95817..6744a2995 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -404,10 +404,7 @@ def _make_sampler(args, tokenizer): min_p=args.sampling.min_p, xtc_probability=args.sampling.xtc_probability, xtc_threshold=args.sampling.xtc_threshold, - xtc_special_tokens=[ - tokenizer.eos_token_id, - tokenizer.encode("\n"), - ], + xtc_special_tokens=tokenizer.encode("\n") + list(tokenizer.eos_token_ids), ) @@ -1243,8 +1240,8 @@ def validate_model_parameters(self): self._validate("frequency_context_size", int, min_val=0) self._validate("logprobs", bool) self._validate("top_logprobs", int, min_val=0, max_val=11, whitelist=[-1]) - self._validate("xtc_probability", float, min_val=0, max_val=1) - self._validate("xtc_threshold", float, min_val=0, max_val=1) + self._validate("xtc_probability", (float, int), min_val=0, max_val=1) + self._validate("xtc_threshold", (float, int), min_val=0, max_val=1) self._validate("requested_model", str) self._validate("adapter", str, optional=True) self._validate("seed", int, optional=True)