diff --git a/mlx_lm/tokenizer_utils.py b/mlx_lm/tokenizer_utils.py index c7e50fbe7..d8260379a 100644 --- a/mlx_lm/tokenizer_utils.py +++ b/mlx_lm/tokenizer_utils.py @@ -561,6 +561,8 @@ def _infer_tool_parser(chat_template): return "glm47" elif "<|tool_list_start|>" in chat_template: return "pythonic" + elif "" in chat_template: + return "olmo3" elif ( "\\n\n +function_name(arg1="value1", arg2=2) + + +Multiple tool calls are newline-separated within the tags. Argument values +are JSON literals (null, true, false) instead of Python literals. +""" + + +_tool_call_regex = re.compile(r"^\s*(\w+)\((.*)\)\s*$", re.MULTILINE) +_tool_args_regex = re.compile(r'(\w+)=(?:"([^"]*)"|([^,]+))(?:,\s*|$)', re.DOTALL) + +_JSON_LITERALS = {"null": None, "true": True, "false": False} + + +def _coerce(value: str): + value = value.strip() + if value in _JSON_LITERALS: + return _JSON_LITERALS[value] + try: + return ast.literal_eval(value) + except (ValueError, SyntaxError): + return value + + +def parse_tool_call(text: str, tools: Any | None = None): + calls = [] + for match in _tool_call_regex.finditer(text): + func_name = match.group(1) + args_str = match.group(2) + arguments = {} + if args_str: + for key, quoted, raw in _tool_args_regex.findall(args_str): + arguments[key.strip()] = quoted if quoted else _coerce(raw) + calls.append({"name": func_name, "arguments": arguments}) + + if not calls: + raise ValueError("No function provided.") + + return calls if len(calls) > 1 else calls[0] + + +tool_call_start = "" +tool_call_end = "" diff --git a/tests/test_tool_parsing.py b/tests/test_tool_parsing.py index 52892b7ff..7f8d0be10 100644 --- a/tests/test_tool_parsing.py +++ b/tests/test_tool_parsing.py @@ -10,6 +10,7 @@ longcat, minimax_m2, mistral, + olmo3, pythonic, qwen3_coder, ) @@ -53,6 +54,10 @@ def test_parsers(self): "[multiply(a=12234585, b=48838483920)]", pythonic, ), + ( + "multiply(a=12234585, b=48838483920)", + olmo3, + ), ( 'multiply[ARGS]{"a": 12234585, "b": 48838483920}', mistral, @@ -123,6 +128,10 @@ def test_parsers(self): '[get_current_temperature(location="London")]', pythonic, ), + ( + 'get_current_temperature(location="London")', + olmo3, + ), ( 'get_current_temperature[ARGS]{"location": "London"}', mistral, @@ -313,6 +322,30 @@ def test_kimi_k2(self): ] self.assertEqual(tool_calls, expected) + def test_olmo3(self): + # Multiple tool calls + test_case = ( + 'search(query="weather")\n' + 'read_file(path="/tmp/test.txt")' + ) + tool_calls = olmo3.parse_tool_call(test_case, None) + self.assertIsInstance(tool_calls, list) + self.assertEqual(len(tool_calls), 2) + self.assertEqual(tool_calls[0], {"name": "search", "arguments": {"query": "weather"}}) + self.assertEqual( + tool_calls[1], + {"name": "read_file", "arguments": {"path": "/tmp/test.txt"}}, + ) + + # JSON literals are accepted (true/false/null instead of Python literals) + test_case = 'configure(enabled=true, name="x", missing=null)' + tool_call = olmo3.parse_tool_call(test_case, None) + self.assertEqual(tool_call["name"], "configure") + self.assertEqual( + tool_call["arguments"], + {"enabled": True, "name": "x", "missing": None}, + ) + def test_minimax_m2(self): test_case = ( '\n'