diff --git a/docs/en/llm/api_server.md b/docs/en/llm/api_server.md
index cd322b45ea..382205c0f3 100644
--- a/docs/en/llm/api_server.md
+++ b/docs/en/llm/api_server.md
@@ -144,6 +144,10 @@ for item in api_client.completions_v1(model=model_name, prompt='hi'):
 
 May refer to [api_server_tools](./api_server_tools.md).
 
+### Anthropic-Compatible Endpoints
+
+May refer to [api_server_anthropic](./api_server_anthropic.md).
+
 ### Integrate with Java/Golang/Rust
 
 May use [openapi-generator-cli](https://github.com/OpenAPITools/openapi-generator-cli) to convert `http://{server_ip}:{server_port}/openapi.json` to java/rust/golang client.
diff --git a/docs/en/llm/api_server_anthropic.md b/docs/en/llm/api_server_anthropic.md
new file mode 100644
index 0000000000..fcd9befa30
--- /dev/null
+++ b/docs/en/llm/api_server_anthropic.md
@@ -0,0 +1,48 @@
+# Anthropic-Compatible Endpoints
+
+LMDeploy exposes a lightweight Anthropic-compatible API so that Anthropic-style clients and gateways can integrate with minimal changes.
+
+## Supported Endpoints
+
+- `POST /v1/messages`
+- `POST /v1/messages/count_tokens`
+- `GET /anthropic/v1/models`
+
+## Required Headers
+
+For Anthropic `POST` endpoints, include:
+
+- `content-type: application/json`
+- `anthropic-version: 2023-06-01` (or another accepted version string)
+
+## Notes and Current Limits
+
+- The tool-call fields `tools` and `tool_choice` are **temporarily unsupported** in this phase.
+- If either field is provided, LMDeploy returns an Anthropic-style error response.
+- `count_tokens` renders the request with the model's chat template and counts tokens with its tokenizer, so the result is a practical estimate rather than an exact Anthropic count.
+
+## Example: `/v1/messages`
+
+```bash
+curl http://{server_ip}:{server_port}/v1/messages \
+  -H "content-type: application/json" \
+  -H "anthropic-version: 2023-06-01" \
+  -d '{
+    "model": "internlm-chat-7b",
+    "max_tokens": 128,
+    "messages": [{"role": "user", "content": "Hello from Anthropic client"}]
+  }'
+```
+
+## Example: `/v1/messages/count_tokens`
+
+```bash
+curl http://{server_ip}:{server_port}/v1/messages/count_tokens \
+  -H "content-type: application/json" \
+  -H "anthropic-version: 2023-06-01" \
+  -d '{
+    "model": "internlm-chat-7b",
+    "system": "You are a helpful assistant.",
+    "messages": [{"role": "user", "content": "Count these tokens"}]
+  }'
+```
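+
+## Example: Anthropic Python SDK
+
+The official `anthropic` Python SDK posts to `/v1/messages` relative to its `base_url`, so it can talk to LMDeploy directly. A minimal sketch, assuming the SDK is installed and the server was started without API-key checking (the key below is a placeholder the server never validates):
+
+```python
+from anthropic import Anthropic
+
+# Point the SDK at the LMDeploy server; the SDK adds the
+# anthropic-version header automatically.
+client = Anthropic(base_url='http://{server_ip}:{server_port}', api_key='dummy')
+
+message = client.messages.create(
+    model='internlm-chat-7b',
+    max_tokens=128,
+    messages=[{'role': 'user', 'content': 'Hello from the Anthropic SDK'}],
+)
+print(message.content[0].text)
+```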
+"""Adapters between Anthropic requests and LMDeploy internals.""" + +from __future__ import annotations + +from typing import Any + +from lmdeploy.messages import GenerationConfig + +from .protocol import CountTokensRequest, MessagesRequest, TextContentBlockParam + + +def get_model_list(server_context) -> list[str]: + """Return available model names from the server context.""" + + model_names = [server_context.async_engine.model_name] + cfg = server_context.async_engine.backend_config + model_names += getattr(cfg, 'adapters', None) or [] + return model_names + + +def ensure_tools_not_requested(request: MessagesRequest | CountTokensRequest) -> None: + """Reject tool-related fields while parser refactor is in progress.""" + + if getattr(request, 'tools', None): + raise NotImplementedError('Anthropic tool fields are temporarily unsupported.') + if getattr(request, 'tool_choice', None) is not None: + raise NotImplementedError('Anthropic tool_choice is temporarily unsupported.') + + +def _text_from_blocks(blocks: list[TextContentBlockParam | dict[str, Any]], field_name: str) -> str: + out: list[str] = [] + for idx, block in enumerate(blocks): + if isinstance(block, dict): + block_type = block.get('type') + text = block.get('text') + else: + block_type = block.type + text = block.text + if block_type != 'text': + raise ValueError( + f'Only text content blocks are supported in `{field_name}`. ' + f'Got: {block_type!r} at index {idx}.') + if text is None: + raise ValueError(f'Missing `text` in `{field_name}` content block at index {idx}.') + out.append(text) + return ''.join(out) + + +def text_from_content(content: str | list[TextContentBlockParam], field_name: str) -> str: + """Normalize Anthropic content field to plain text.""" + + if isinstance(content, str): + return content + return _text_from_blocks(content, field_name=field_name) + + +def to_lmdeploy_messages(request: MessagesRequest | CountTokensRequest) -> list[dict[str, str]]: + """Convert Anthropic request messages into LMDeploy chat messages.""" + + lm_messages: list[dict[str, str]] = [] + if request.system is not None: + lm_messages.append( + dict(role='system', content=text_from_content(request.system, field_name='system'))) + for idx, message in enumerate(request.messages): + content = text_from_content(message.content, field_name=f'messages[{idx}].content') + lm_messages.append(dict(role=message.role, content=content)) + return lm_messages + + +def to_generation_config(request: MessagesRequest) -> GenerationConfig: + """Map Anthropic messages request to LMDeploy generation config.""" + + return GenerationConfig( + max_new_tokens=request.max_tokens, + do_sample=True, + top_k=40 if request.top_k is None else request.top_k, + top_p=1.0 if request.top_p is None else request.top_p, + temperature=1.0 if request.temperature is None else request.temperature, + stop_words=request.stop_sequences, + skip_special_tokens=True, + spaces_between_special_tokens=True, + ) + + +def count_input_tokens(async_engine, messages: list[dict[str, str]]) -> int: + """Approximate Anthropic token counting using LMDeploy + tokenizer/template.""" + + prompt = async_engine.chat_template.messages2prompt(messages, sequence_start=True) + token_ids = async_engine.tokenizer.encode(prompt, add_bos=True) + return len(token_ids) + + +def map_finish_reason(reason: str | None) -> str: + """Map LMDeploy/OpenAI-like finish reason to Anthropic stop reason.""" + + mapping = { + 'stop': 'end_turn', + 'length': 'max_tokens', + 'tool_calls': 'stop_sequence', + 'abort': 
+def count_input_tokens(async_engine, messages: list[dict[str, str]]) -> int:
+    """Approximate Anthropic token counting using the LMDeploy
+    tokenizer/chat template."""
+
+    prompt = async_engine.chat_template.messages2prompt(messages, sequence_start=True)
+    token_ids = async_engine.tokenizer.encode(prompt, add_bos=True)
+    return len(token_ids)
+
+
+def map_finish_reason(reason: str | None) -> str:
+    """Map an LMDeploy/OpenAI-like finish reason to an Anthropic stop reason."""
+
+    mapping = {
+        'stop': 'end_turn',
+        'length': 'max_tokens',
+        'tool_calls': 'stop_sequence',
+        'abort': 'stop_sequence',
+        'error': 'stop_sequence',
+    }
+    return mapping.get(reason, 'end_turn')
diff --git a/lmdeploy/serve/anthropic/endpoints/__init__.py b/lmdeploy/serve/anthropic/endpoints/__init__.py
new file mode 100644
index 0000000000..f77ef80e19
--- /dev/null
+++ b/lmdeploy/serve/anthropic/endpoints/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""Anthropic endpoint modules."""
diff --git a/lmdeploy/serve/anthropic/endpoints/messages.py b/lmdeploy/serve/anthropic/endpoints/messages.py
new file mode 100644
index 0000000000..c60a9b30ae
--- /dev/null
+++ b/lmdeploy/serve/anthropic/endpoints/messages.py
@@ -0,0 +1,101 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""Endpoint for ``POST /v1/messages``."""
+
+from __future__ import annotations
+
+from http import HTTPStatus
+
+import shortuuid
+from fastapi import APIRouter, Depends, Request
+from fastapi.responses import StreamingResponse
+
+from lmdeploy.serve.utils.server_utils import validate_json_request
+
+from ..adapter import (
+    ensure_tools_not_requested,
+    get_model_list,
+    map_finish_reason,
+    to_generation_config,
+    to_lmdeploy_messages,
+)
+from ..errors import create_error_response
+from ..protocol import MessagesRequest, MessagesResponse, MessageTextBlock, MessageUsage
+from ..streaming import stream_messages_response
+
+
+def _validate_headers(raw_request: Request):
+    anthropic_version = raw_request.headers.get('anthropic-version')
+    if not anthropic_version:
+        return create_error_response(HTTPStatus.BAD_REQUEST, 'Missing required header: anthropic-version')
+    return None
+
+
+def register(router: APIRouter, server_context) -> None:
+    """Register the endpoint onto the router."""
+
+    @router.post('/v1/messages', dependencies=[Depends(validate_json_request)])
+    async def create_message(request: MessagesRequest, raw_request: Request):
+        header_error = _validate_headers(raw_request)
+        if header_error is not None:
+            return header_error
+
+        if request.model not in get_model_list(server_context):
+            return create_error_response(
+                HTTPStatus.NOT_FOUND,
+                f'The model {request.model!r} does not exist.',
+                error_type='not_found_error',
+            )
+
+        try:
+            ensure_tools_not_requested(request)
+            messages = to_lmdeploy_messages(request)
+        except (NotImplementedError, ValueError) as err:
+            return create_error_response(HTTPStatus.BAD_REQUEST, str(err))
+
+        session = server_context.get_session(-1)
+        adapter_name = None if request.model == server_context.async_engine.model_name else request.model
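+        # `adapter_name` routes the request to a LoRA adapter when the client
+        # asks for a model other than the base model (see `get_model_list`).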
+        result_generator = server_context.async_engine.generate(
+            messages,
+            session,
+            gen_config=to_generation_config(request),
+            stream_response=True,
+            sequence_start=True,
+            sequence_end=True,
+            do_preprocess=True,
+            adapter_name=adapter_name,
+        )
+
+        request_id = f'msg_{shortuuid.random()}'
+
+        if request.stream:
+            return StreamingResponse(
+                stream_messages_response(result_generator, request_id=request_id, model=request.model),
+                media_type='text/event-stream',
+            )
+
+        text = ''
+        final_res = None
+        async for res in result_generator:
+            if await raw_request.is_disconnected():
+                await session.async_abort()
+                return create_error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected')
+            final_res = res
+            text += res.response or ''
+
+        if final_res is None:
+            return create_error_response(HTTPStatus.INTERNAL_SERVER_ERROR, 'No generation output from engine.')
+
+        response = MessagesResponse(
+            id=request_id,
+            model=request.model,
+            content=[MessageTextBlock(text=text)],
+            stop_reason=map_finish_reason(final_res.finish_reason),
+            stop_sequence=None,
+            usage=MessageUsage(
+                input_tokens=final_res.input_token_len,
+                output_tokens=final_res.generate_token_len,
+            ),
+        )
+        return response.model_dump()
diff --git a/lmdeploy/serve/anthropic/endpoints/messages_count_tokens.py b/lmdeploy/serve/anthropic/endpoints/messages_count_tokens.py
new file mode 100644
index 0000000000..d739dd3dde
--- /dev/null
+++ b/lmdeploy/serve/anthropic/endpoints/messages_count_tokens.py
@@ -0,0 +1,49 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""Endpoint for ``POST /v1/messages/count_tokens``."""
+
+from __future__ import annotations
+
+from http import HTTPStatus
+
+from fastapi import APIRouter, Depends, Request
+
+from lmdeploy.serve.utils.server_utils import validate_json_request
+
+from ..adapter import count_input_tokens, ensure_tools_not_requested, get_model_list, to_lmdeploy_messages
+from ..errors import create_error_response
+from ..protocol import CountTokensRequest, CountTokensResponse
+
+
+def _validate_headers(raw_request: Request):
+    anthropic_version = raw_request.headers.get('anthropic-version')
+    if not anthropic_version:
+        return create_error_response(HTTPStatus.BAD_REQUEST, 'Missing required header: anthropic-version')
+    return None
+
+
+def register(router: APIRouter, server_context) -> None:
+    """Register the endpoint onto the router."""
+
+    @router.post('/v1/messages/count_tokens', dependencies=[Depends(validate_json_request)])
+    async def count_tokens(request: CountTokensRequest, raw_request: Request):
+        header_error = _validate_headers(raw_request)
+        if header_error is not None:
+            return header_error
+
+        if request.model not in get_model_list(server_context):
+            return create_error_response(
+                HTTPStatus.NOT_FOUND,
+                f'The model {request.model!r} does not exist.',
+                error_type='not_found_error',
+            )
+
+        try:
+            ensure_tools_not_requested(request)
+            messages = to_lmdeploy_messages(request)
+            input_tokens = count_input_tokens(server_context.async_engine, messages)
+        except (NotImplementedError, ValueError) as err:
+            return create_error_response(HTTPStatus.BAD_REQUEST, str(err))
+
+        return CountTokensResponse(input_tokens=input_tokens).model_dump()
diff --git a/lmdeploy/serve/anthropic/endpoints/models.py b/lmdeploy/serve/anthropic/endpoints/models.py
new file mode 100644
index 0000000000..3ea3534f7f
--- /dev/null
+++ b/lmdeploy/serve/anthropic/endpoints/models.py
@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""Endpoint for Anthropic-scoped model listing."""
+
+from __future__ import annotations
+
+from fastapi import APIRouter
+
+from ..adapter import get_model_list
+from ..protocol import AnthropicModel, AnthropicModelList
+
+
+def register(router: APIRouter, server_context) -> None:
+    """Register the endpoint onto the router."""
+
+    @router.get('/anthropic/v1/models')
+    async def list_models():
+        models = [AnthropicModel(id=name, display_name=name) for name in get_model_list(server_context)]
+        first_id = models[0].id if models else None
+        last_id = models[-1].id if models else None
+        return AnthropicModelList(data=models, first_id=first_id, last_id=last_id).model_dump()
diff --git a/lmdeploy/serve/anthropic/errors.py b/lmdeploy/serve/anthropic/errors.py
new file mode 100644
index 0000000000..07f5f41213
--- /dev/null
+++ b/lmdeploy/serve/anthropic/errors.py
@@ -0,0 +1,17 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""Error helpers for Anthropic-compatible endpoints."""
+
+from __future__ import annotations
+
+from http import HTTPStatus
+
+from fastapi.responses import JSONResponse
+
+from .protocol import AnthropicError, AnthropicErrorResponse
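+
+
+# The error envelope mirrors Anthropic's wire format, e.g.:
+#   {"type": "error",
+#    "error": {"type": "invalid_request_error", "message": "..."}}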
+"""Error helpers for Anthropic-compatible endpoints.""" + +from __future__ import annotations + +from http import HTTPStatus + +from fastapi.responses import JSONResponse + +from .protocol import AnthropicError, AnthropicErrorResponse + + +def create_error_response(status: HTTPStatus, message: str, error_type: str = 'invalid_request_error') -> JSONResponse: + """Create Anthropic-style error response.""" + + payload = AnthropicErrorResponse(error=AnthropicError(type=error_type, message=message)).model_dump() + return JSONResponse(payload, status_code=status.value) diff --git a/lmdeploy/serve/anthropic/protocol.py b/lmdeploy/serve/anthropic/protocol.py new file mode 100644 index 0000000000..63ae294fa2 --- /dev/null +++ b/lmdeploy/serve/anthropic/protocol.py @@ -0,0 +1,116 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Anthropic-compatible protocol models.""" + +from __future__ import annotations + +import time +from typing import Any, Literal + +import shortuuid +from pydantic import BaseModel, Field + + +class AnthropicError(BaseModel): + """Anthropic-style error body.""" + + type: str + message: str + + +class AnthropicErrorResponse(BaseModel): + """Anthropic-style error response.""" + + type: Literal['error'] = 'error' + error: AnthropicError + + +class TextContentBlockParam(BaseModel): + """Input text content block.""" + + type: Literal['text'] = 'text' + text: str + + +class MessageParam(BaseModel): + """Anthropic input message.""" + + role: Literal['user', 'assistant'] + content: str | list[TextContentBlockParam] + + +class MessagesRequest(BaseModel): + """Request body for ``POST /v1/messages``.""" + + model: str + messages: list[MessageParam] + max_tokens: int = Field(gt=0) + system: str | list[TextContentBlockParam] | None = None + stop_sequences: list[str] | None = None + stream: bool = False + temperature: float | None = 1.0 + top_p: float | None = None + top_k: int | None = None + metadata: dict[str, Any] | None = None + tools: list[dict[str, Any]] | None = None + tool_choice: str | dict[str, Any] | None = None + service_tier: Literal['auto', 'standard_only'] | None = None + + +class MessageTextBlock(BaseModel): + """Output text content block.""" + + type: Literal['text'] = 'text' + text: str + + +class MessageUsage(BaseModel): + """Token usage in Anthropic style.""" + + input_tokens: int = 0 + output_tokens: int = 0 + + +class MessagesResponse(BaseModel): + """Response body for ``POST /v1/messages``.""" + + id: str = Field(default_factory=lambda: f'msg_{shortuuid.random()}') + type: Literal['message'] = 'message' + role: Literal['assistant'] = 'assistant' + content: list[MessageTextBlock] + model: str + stop_reason: Literal['end_turn', 'max_tokens', 'stop_sequence'] | None = None + stop_sequence: str | None = None + usage: MessageUsage + + +class CountTokensRequest(BaseModel): + """Request body for ``POST /v1/messages/count_tokens``.""" + + model: str + messages: list[MessageParam] + system: str | list[TextContentBlockParam] | None = None + tools: list[dict[str, Any]] | None = None + + +class CountTokensResponse(BaseModel): + """Response body for ``POST /v1/messages/count_tokens``.""" + + input_tokens: int + + +class AnthropicModel(BaseModel): + """Anthropic-like model metadata.""" + + id: str + type: Literal['model'] = 'model' + display_name: str + created_at: int = Field(default_factory=lambda: int(time.time())) + + +class AnthropicModelList(BaseModel): + """Anthropic-like model listing response.""" + + data: list[AnthropicModel] + has_more: bool = False + first_id: 
+class MessagesRequest(BaseModel):
+    """Request body for ``POST /v1/messages``."""
+
+    model: str
+    messages: list[MessageParam]
+    max_tokens: int = Field(gt=0)
+    system: str | list[TextContentBlockParam] | None = None
+    stop_sequences: list[str] | None = None
+    stream: bool = False
+    temperature: float | None = 1.0
+    top_p: float | None = None
+    top_k: int | None = None
+    metadata: dict[str, Any] | None = None
+    tools: list[dict[str, Any]] | None = None
+    tool_choice: str | dict[str, Any] | None = None
+    service_tier: Literal['auto', 'standard_only'] | None = None
+
+
+class MessageTextBlock(BaseModel):
+    """Output text content block."""
+
+    type: Literal['text'] = 'text'
+    text: str
+
+
+class MessageUsage(BaseModel):
+    """Token usage in Anthropic style."""
+
+    input_tokens: int = 0
+    output_tokens: int = 0
+
+
+class MessagesResponse(BaseModel):
+    """Response body for ``POST /v1/messages``."""
+
+    id: str = Field(default_factory=lambda: f'msg_{shortuuid.random()}')
+    type: Literal['message'] = 'message'
+    role: Literal['assistant'] = 'assistant'
+    content: list[MessageTextBlock]
+    model: str
+    stop_reason: Literal['end_turn', 'max_tokens', 'stop_sequence'] | None = None
+    stop_sequence: str | None = None
+    usage: MessageUsage
+
+
+class CountTokensRequest(BaseModel):
+    """Request body for ``POST /v1/messages/count_tokens``."""
+
+    model: str
+    messages: list[MessageParam]
+    system: str | list[TextContentBlockParam] | None = None
+    tools: list[dict[str, Any]] | None = None
+
+
+class CountTokensResponse(BaseModel):
+    """Response body for ``POST /v1/messages/count_tokens``."""
+
+    input_tokens: int
+
+
+class AnthropicModel(BaseModel):
+    """Anthropic-like model metadata."""
+
+    id: str
+    type: Literal['model'] = 'model'
+    display_name: str
+    created_at: int = Field(default_factory=lambda: int(time.time()))
+
+
+class AnthropicModelList(BaseModel):
+    """Anthropic-like model listing response."""
+
+    data: list[AnthropicModel]
+    has_more: bool = False
+    first_id: str | None = None
+    last_id: str | None = None
diff --git a/lmdeploy/serve/anthropic/router.py b/lmdeploy/serve/anthropic/router.py
new file mode 100644
index 0000000000..7fcfa08d9a
--- /dev/null
+++ b/lmdeploy/serve/anthropic/router.py
@@ -0,0 +1,18 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""Router assembly for Anthropic-compatible endpoints."""
+
+from __future__ import annotations
+
+from fastapi import APIRouter
+
+from .endpoints import messages, messages_count_tokens, models
+
+
+def create_anthropic_router(server_context) -> APIRouter:
+    """Create a router with all Anthropic endpoints."""
+
+    router = APIRouter(tags=['anthropic'])
+    messages.register(router, server_context)
+    messages_count_tokens.register(router, server_context)
+    models.register(router, server_context)
+    return router
diff --git a/lmdeploy/serve/anthropic/streaming.py b/lmdeploy/serve/anthropic/streaming.py
new file mode 100644
index 0000000000..c0894a513f
--- /dev/null
+++ b/lmdeploy/serve/anthropic/streaming.py
@@ -0,0 +1,90 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""Streaming helpers for Anthropic-compatible responses."""
+
+from __future__ import annotations
+
+import json
+from collections.abc import AsyncGenerator
+
+from .adapter import map_finish_reason
+
+
+def _format_sse(event: str, data: dict) -> str:
+    return f'event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n'
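+
+
+# Event order expected by Anthropic clients:
+#   message_start -> content_block_start -> content_block_delta (repeated)
+#   -> content_block_stop -> message_delta -> message_stop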
+async def stream_messages_response(result_generator, *, request_id: str, model: str) -> AsyncGenerator[str, None]:
+    """Convert an LMDeploy generation stream to Anthropic SSE events."""
+
+    yield _format_sse(
+        'message_start',
+        {
+            'type': 'message_start',
+            'message': {
+                'id': request_id,
+                'type': 'message',
+                'role': 'assistant',
+                'content': [],
+                'model': model,
+                'stop_reason': None,
+                'stop_sequence': None,
+                'usage': {
+                    'input_tokens': 0,
+                    'output_tokens': 0,
+                },
+            },
+        },
+    )
+    yield _format_sse(
+        'content_block_start',
+        {
+            'type': 'content_block_start',
+            'index': 0,
+            'content_block': {
+                'type': 'text',
+                'text': '',
+            },
+        },
+    )
+
+    final_res = None
+    input_tokens = 0
+    async for res in result_generator:
+        final_res = res
+        input_tokens = res.input_token_len
+        text = res.response or ''
+        if text:
+            yield _format_sse(
+                'content_block_delta',
+                {
+                    'type': 'content_block_delta',
+                    'index': 0,
+                    'delta': {
+                        'type': 'text_delta',
+                        'text': text,
+                    },
+                },
+            )
+
+    yield _format_sse('content_block_stop', {'type': 'content_block_stop', 'index': 0})
+
+    output_tokens = 0 if final_res is None else final_res.generate_token_len
+    stop_reason = map_finish_reason(None if final_res is None else final_res.finish_reason)
+    yield _format_sse(
+        'message_delta',
+        {
+            'type': 'message_delta',
+            'delta': {
+                'stop_reason': stop_reason,
+                'stop_sequence': None,
+            },
+            'usage': {
+                'input_tokens': input_tokens,
+                'output_tokens': output_tokens,
+            },
+        },
+    )
+    yield _format_sse('message_stop', {'type': 'message_stop'})
diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py
index ea1b594af9..95146a8216 100644
--- a/lmdeploy/serve/openai/api_server.py
+++ b/lmdeploy/serve/openai/api_server.py
@@ -41,6 +41,7 @@
     DistServeInitRequest,
     MigrationRequest,
 )
+from lmdeploy.serve.anthropic import create_anthropic_router
 from lmdeploy.serve.core import AsyncEngine
 from lmdeploy.serve.openai.harmony_utils import GptOssChatParser
 from lmdeploy.serve.openai.protocol import (
@@ -1542,6 +1543,7 @@ def serve(model_path: str,
     app = FastAPI(docs_url='/', lifespan=lifespan)
     app.include_router(router)
+    app.include_router(create_anthropic_router(VariableInterface))
     app.add_exception_handler(RequestValidationError, validation_exception_handler)
     mount_metrics(app, backend_config)
diff --git a/tests/test_lmdeploy/serve/anthropic/test_endpoints.py b/tests/test_lmdeploy/serve/anthropic/test_endpoints.py
new file mode 100644
index 0000000000..62b209fa58
--- /dev/null
+++ b/tests/test_lmdeploy/serve/anthropic/test_endpoints.py
@@ -0,0 +1,202 @@
+from __future__ import annotations
+
+from types import SimpleNamespace
+
+from fastapi import FastAPI
+from fastapi.testclient import TestClient
+
+from lmdeploy.serve.anthropic.endpoints import messages, messages_count_tokens, models
+from lmdeploy.serve.anthropic.router import create_anthropic_router
+
+
+class _FakeSession:
+
+    def __init__(self):
+        self.aborted = False
+
+    async def async_abort(self):
+        self.aborted = True
+
+
+class _FakeTokenizer:
+
+    def encode(self, text: str, add_bos: bool = True, **kwargs):
+        tokens = text.split()
+        if add_bos:
+            return [0] + list(range(1, len(tokens) + 1))
+        return list(range(len(tokens)))
+
+
+class _FakeChatTemplate:
+
+    def messages2prompt(self, messages, sequence_start: bool = True, **kwargs):
+        parts = [f"{item['role']}:{item['content']}" for item in messages]
+        return '\n'.join(parts)
+
+
+class _FakeEngine:
+
+    def __init__(self):
+        self.model_name = 'fake-model'
+        self.backend_config = SimpleNamespace(adapters=['adapter-model'])
+        self.tokenizer = _FakeTokenizer()
+        self.chat_template = _FakeChatTemplate()
+
+    def generate(self, *args, **kwargs):
+
+        async def _gen():
+            yield SimpleNamespace(
+                response='Hello ',
+                input_token_len=8,
+                generate_token_len=1,
+                finish_reason=None,
+            )
+            yield SimpleNamespace(
+                response='world!',
+                input_token_len=8,
+                generate_token_len=2,
+                finish_reason='stop',
+            )
+
+        return _gen()
+
+
+class _FakeServerContext:
+    async_engine = _FakeEngine()
+
+    @staticmethod
+    def get_session(_session_id: int):
+        return _FakeSession()
+
+
+def _make_client() -> TestClient:
+    app = FastAPI()
+    app.include_router(create_anthropic_router(_FakeServerContext))
+    return TestClient(app)
+
+
+def test_endpoint_modules_export_register():
+    assert callable(messages.register)
+    assert callable(messages_count_tokens.register)
+    assert callable(models.register)
+
+
+def test_messages_non_stream():
+    client = _make_client()
+    response = client.post(
+        '/v1/messages',
+        headers={'anthropic-version': '2023-06-01'},
+        json={
+            'model': 'fake-model',
+            'max_tokens': 16,
+            'messages': [{
+                'role': 'user',
+                'content': 'Hi there',
+            }],
+        },
+    )
+    assert response.status_code == 200
+    data = response.json()
+    assert data['type'] == 'message'
+    assert data['content'][0]['type'] == 'text'
+    assert data['content'][0]['text'] == 'Hello world!'
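+    # The engine's final finish_reason 'stop' maps to Anthropic's 'end_turn'
+    # via adapter.map_finish_reason.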
+    assert data['stop_reason'] == 'end_turn'
+    assert data['usage']['input_tokens'] == 8
+    assert data['usage']['output_tokens'] == 2
+
+
+def test_messages_requires_anthropic_version_header():
+    client = _make_client()
+    response = client.post(
+        '/v1/messages',
+        json={
+            'model': 'fake-model',
+            'max_tokens': 16,
+            'messages': [{
+                'role': 'user',
+                'content': 'Hi there',
+            }],
+        },
+    )
+    assert response.status_code == 400
+    assert response.json()['error']['message'] == 'Missing required header: anthropic-version'
+
+
+def test_messages_rejects_tools_temporarily():
+    client = _make_client()
+    response = client.post(
+        '/v1/messages',
+        headers={'anthropic-version': '2023-06-01'},
+        json={
+            'model': 'fake-model',
+            'max_tokens': 16,
+            'messages': [{
+                'role': 'user',
+                'content': 'Hi there',
+            }],
+            'tools': [{
+                'name': 'search',
+                'description': 'demo',
+                'input_schema': {
+                    'type': 'object',
+                    'properties': {},
+                },
+            }],
+        },
+    )
+    assert response.status_code == 400
+    assert 'temporarily unsupported' in response.json()['error']['message']
+
+
+def test_messages_streaming_sse_shape():
+    client = _make_client()
+    with client.stream(
+            'POST',
+            '/v1/messages',
+            headers={'anthropic-version': '2023-06-01'},
+            json={
+                'model': 'fake-model',
+                'max_tokens': 16,
+                'stream': True,
+                'messages': [{
+                    'role': 'user',
+                    'content': 'Hi there',
+                }],
+            },
+    ) as response:
+        body = '\n'.join(response.iter_lines())
+
+    assert response.status_code == 200
+    assert 'event: message_start' in body
+    assert 'event: content_block_start' in body
+    assert 'event: content_block_delta' in body
+    assert 'event: message_delta' in body
+    assert 'event: message_stop' in body
+
+
+def test_count_tokens():
+    client = _make_client()
+    response = client.post(
+        '/v1/messages/count_tokens',
+        headers={'anthropic-version': '2023-06-01'},
+        json={
+            'model': 'fake-model',
+            'messages': [{
+                'role': 'user',
+                'content': 'count these tokens',
+            }],
+            'system': 'You are helpful.',
+        },
+    )
+    assert response.status_code == 200
+    data = response.json()
+    assert isinstance(data['input_tokens'], int)
+    assert data['input_tokens'] > 0
+
+
+def test_anthropic_model_listing():
+    client = _make_client()
+    response = client.get('/anthropic/v1/models')
+    assert response.status_code == 200
+    data = response.json()
+    assert data['has_more'] is False
+    assert [item['id'] for item in data['data']] == ['fake-model', 'adapter-model']
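+
+
+def test_messages_unknown_model_returns_not_found():
+    # Suggested extra check (not in the original test set): an unknown model
+    # id should yield an Anthropic-style not_found_error with HTTP 404.
+    client = _make_client()
+    response = client.post(
+        '/v1/messages',
+        headers={'anthropic-version': '2023-06-01'},
+        json={
+            'model': 'missing-model',
+            'max_tokens': 16,
+            'messages': [{
+                'role': 'user',
+                'content': 'Hi there',
+            }],
+        },
+    )
+    assert response.status_code == 404
+    assert response.json()['error']['type'] == 'not_found_error'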