Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions docs/docs/integrations/bedrock.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,18 @@ Export it before running Mellea:
export AWS_BEARER_TOKEN_BEDROCK=your-bedrock-key
```

## Connecting with `create_bedrock_mantle_backend`
## Connecting with `create_bedrock_openai_backend`

```python
# Requires: mellea
# Returns: ModelOutputThunk
from mellea import MelleaSession
from mellea.backends import model_ids
from mellea.backends.bedrock import create_bedrock_mantle_backend
from mellea.backends.bedrock import create_bedrock_openai_backend
from mellea.stdlib.context import ChatContext

m = MelleaSession(
backend=create_bedrock_mantle_backend(model_id=model_ids.OPENAI_GPT_OSS_120B),
backend=create_bedrock_openai_backend(model_id=model_ids.OPENAI_GPT_OSS_120B),
ctx=ChatContext(),
)

Expand All @@ -54,10 +54,10 @@ The default region is `us-east-1`. Pass `region` to target a different region:
# Requires: mellea
# Returns: MelleaSession
from mellea import MelleaSession
from mellea.backends.bedrock import create_bedrock_mantle_backend
from mellea.backends.bedrock import create_bedrock_openai_backend

m = MelleaSession(
backend=create_bedrock_mantle_backend(
backend=create_bedrock_openai_backend(
model_id="amazon.nova-pro-v1:0",
region="eu-west-1",
)
Expand All @@ -73,10 +73,10 @@ model ID string directly:
# Requires: mellea
# Returns: MelleaSession
from mellea import MelleaSession
from mellea.backends.bedrock import create_bedrock_mantle_backend
from mellea.backends.bedrock import create_bedrock_litellm_backend

m = MelleaSession(
backend=create_bedrock_mantle_backend(
backend=create_bedrock_openai_backend(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be the litellm path given the import above.

model_id="anthropic.claude-3-haiku-20240307-v1:0"
)
)
Expand Down Expand Up @@ -144,7 +144,7 @@ Model X is not supported in region us-east-1.

Either enable model access for the requested model in your AWS account at
[Bedrock Model Access](https://us-east-1.console.aws.amazon.com/bedrock/home#/model-access),
or pass a different `region` to `create_bedrock_mantle_backend`.
or pass a different `region` to `create_bedrock_litellm_backend`.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This got caught in the rename sweep — the troubleshooting tip points to create_bedrock_litellm_backend, but the "Model X is not supported in region" error only comes from create_bedrock_openai_backend (the one that validates via list_mantle_models). The litellm path skips that check entirely.

## Vision support

Expand Down
8 changes: 4 additions & 4 deletions docs/examples/bedrock/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ uv pip install mellea[litellm]

```python
from mellea import MelleaSession
from mellea.backends.bedrock import create_bedrock_mantle_backend
from mellea.backends.bedrock import create_bedrock_openai_backend
from mellea.backends.model_ids import OPENAI_GPT_OSS_120B
from mellea.stdlib.context import ChatContext

bedrock_oai_backend = create_bedrock_mantle_backend(model_id=OPENAI_GPT_OSS_120B, region="us-east-1")
bedrock_oai_backend = create_bedrock_openai_backend(model_id=OPENAI_GPT_OSS_120B, region="us-east-1")

m = MelleaSession(backend=bedrock_oai_backend, ctx=ChatContext())

Expand All @@ -38,10 +38,10 @@ You can also use your own model IDs as strings, as long as they're accessible us

```python
from mellea import MelleaSession
from mellea.backends.bedrock import create_bedrock_mantle_backend
from mellea.backends.bedrock import create_bedrock_openai_backend
from mellea.stdlib.context import ChatContext

bedrock_oai_backend = create_bedrock_mantle_backend(
bedrock_oai_backend = create_bedrock_openai_backend(
model_id="qwen.qwen3-coder-480b-a35b-instruct",
region="us-east-1"
)
Expand Down
17 changes: 8 additions & 9 deletions docs/examples/bedrock/bedrock_litellm_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
import os

import mellea
from mellea.backends.bedrock import create_bedrock_litellm_backend
from mellea.backends.model_ids import MISTRALAI_DEVSTRAL_2_123B
from mellea.stdlib.context import SimpleContext

try:
import boto3
Expand All @@ -20,16 +23,12 @@
"Run `uv pip install mellea[litellm]`"
)

assert "AWS_BEARER_TOKEN_BEDROCK" in os.environ.keys(), (
"Using AWS Bedrock requires setting a AWS_BEARER_TOKEN_BEDROCK environment variable. "
"Generate a key from the AWS console at: https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/api-keys?tab=long-term "
"Then run `export AWS_BEARER_TOKEN_BEDROCK=<insert your key here>"
)
MODEL_ID = MISTRALAI_DEVSTRAL_2_123B

MODEL_ID = "bedrock/converse/us.amazon.nova-pro-v1:0"
backend = create_bedrock_litellm_backend(MODEL_ID)
ctx = SimpleContext()
m = mellea.MelleaSession(backend, ctx)

m = mellea.start_session(backend_name="litellm", model_id=MODEL_ID)

result = m.chat("Give me three facts about Amazon.")
result = m.chat("What model am I talking to rn?")

print(result.content)
7 changes: 4 additions & 3 deletions docs/examples/bedrock/bedrock_openai_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@

from mellea import MelleaSession
from mellea.backends import model_ids
from mellea.backends.bedrock import create_bedrock_mantle_backend
from mellea.backends.bedrock import create_bedrock_openai_backend
from mellea.backends.model_ids import OPENAI_GPT_OSS_120B
from mellea.backends.openai import OpenAIBackend
from mellea.stdlib.context import ChatContext

Expand All @@ -22,10 +23,10 @@
)

m = MelleaSession(
backend=create_bedrock_mantle_backend(model_id=model_ids.OPENAI_GPT_OSS_120B),
backend=create_bedrock_openai_backend(model_id=OPENAI_GPT_OSS_120B),
ctx=ChatContext(),
)

result = m.chat("Give me three facts about Amazon.")
result = m.chat("What model am I talking to rn?")

print(result.content)
85 changes: 83 additions & 2 deletions mellea/backends/bedrock.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,62 @@
"""Helpers for creating bedrock backends from openai/litellm."""

import logging
import os

import boto3
import botocore.exceptions
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

boto3 is in the litellm optional extras, not the base mellea deps. Making it a top-level import means from mellea.backends.bedrock import create_bedrock_openai_backend — which worked fine with a plain pip install mellea before — now raises ModuleNotFoundError on a base install. Scoping the import inside _assert_bedrock_auth (where it's actually needed) would fix the regression without losing anything.

from openai import OpenAI
from openai.pagination import SyncPage

from mellea.backends.litellm import LiteLLMBackend
from mellea.backends.model_ids import ModelIdentifier
from mellea.backends.openai import OpenAIBackend

# botocore logs a credential-resolution message on every boto3.Session() call. Suppress it.
logging.getLogger("botocore.credentials").setLevel(logging.WARNING)


def _assert_region(region: str | None) -> None:
resolved_region = (
region
or os.environ.get("AWS_REGION_NAME")
or os.environ.get("AWS_DEFAULT_REGION")
or os.environ.get("AWS_REGION")
)
assert resolved_region is not None, (
"you must specify a region: pass `region` explicitly or set AWS_REGION_NAME, AWS_DEFAULT_REGION, or AWS_REGION."
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assert resolved_region is not None is stripped silently under python -O. For config validation (as opposed to invariant checks) the idiomatic pattern is if resolved_region is None: raise SomeError(...). Same applies to the assert model_name != "" a few lines below.

)
Comment on lines +18 to +27
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do the OpenAI and LiteLLM SDKs check for these same env vars?



def _assert_bedrock_auth() -> None:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we be calling this on the OpenAI path as well? Does the OpenAI SDK work with the creds path?

"""Raises if no valid AWS credentials can be resolved.

Accepts any credential source that boto3 supports:
- Static env vars (AWS_ACCESS_KEY_ID + AWS_SECRET_ACCESS_KEY)
- Named profile (AWS_PROFILE or ~/.aws/credentials)
- ECS task role (AWS_CONTAINER_CREDENTIALS_RELATIVE_URI)
- EC2 / ECS instance profile (IMDSv2)
- LiteLLM-specific Bedrock API key (AWS_BEARER_TOKEN_BEDROCK)
"""
if "AWS_BEARER_TOKEN_BEDROCK" in os.environ:
return

try:
creds = boto3.Session().get_credentials()
if creds is None:
raise botocore.exceptions.NoCredentialsError()
# Resolve to catch expired/invalid assume-role chains early.
creds.get_frozen_credentials()
except botocore.exceptions.NoCredentialsError:
raise AssertionError(
"No AWS credentials found. Provide one of:\n"
" - AWS_BEARER_TOKEN_BEDROCK (Bedrock API key)\n"
" - AWS_ACCESS_KEY_ID + AWS_SECRET_ACCESS_KEY\n"
" - AWS_PROFILE pointing to a configured profile\n"
" - An IAM role attached to the instance/task (EC2, ECS, Lambda)"
)
except botocore.exceptions.NoRegionError:
pass # Credentials exist; region is validated separately.


def _make_region_for_uri(region: str | None):
if region is None:
Expand Down Expand Up @@ -53,7 +102,39 @@ def stringify_mantle_model_ids(region: str | None = None) -> str:
return f" * {model_names}"


def create_bedrock_mantle_backend(
def create_bedrock_litellm_backend(
model_id: ModelIdentifier | str, region: str | None = None, num_retries: int = 15
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

15 seems like a high default for retries. I did a quick search and it seems like Litellm uses exponential waits. I couldn't find a number, but that seems like this might resolve in long hangs.

) -> LiteLLMBackend:
"""Returns a LiteLLM backend that points to Bedrock for model `model_id`.

Use this instead of `create_bedrock_openai_backend` when you need auth with an AWS_ACCESS_KEY_ID.
"""
_assert_bedrock_auth()
_assert_region(region)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If region is passed in, this assertion could succeed but the region isn't passed into the actual LiteLLM Backend init.


model_name = ""
match model_id:
case ModelIdentifier():
if model_id.bedrock_litellm_name is None:
raise Exception(
f"We do not have a known bedrock model identifier for {model_id}. If Bedrock supports this model, please pass the model_id string directly and open an issue to add the model id: https://github.com/generative-computing/mellea/issues/new"
)
else:
model_name = model_id.bedrock_litellm_name
case str():
model_name = model_id
assert model_name != "", (
f"Model identifier {model_id} does not specify a bedrock_name."
)

backend = LiteLLMBackend(model_id=model_name, num_retries=num_retries)

# TODO litellm doesn't even appear to use this...?
backend._base_url = None # type: ignore
Comment on lines +132 to +133
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we remove this? or add a comment why setting the _base_url to None is necessary?

return backend


def create_bedrock_openai_backend(
model_id: ModelIdentifier | str, region: str | None = None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

create_bedrock_mantle_backend was the documented public entry point before this PR. Removing it without a deprecation alias is a breaking change for anyone importing it directly. A two-line shim with DeprecationWarning would cover existing users.

) -> OpenAIBackend:
"""Return an OpenAI backend that points to Bedrock mantle for the given model.
Expand Down
53 changes: 46 additions & 7 deletions mellea/backends/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
'Please install them with: pip install "mellea[litellm]"'
) from e


from ..backends import model_ids
from ..core import (
BaseModelSubclass,
Expand Down Expand Up @@ -83,6 +84,7 @@ def __init__(
formatter: ChatFormatter | None = None,
base_url: str | None = "http://localhost:11434",
model_options: dict | None = None,
num_retries: int = 0,
):
"""Initialize a LiteLLM-compatible backend for the given model ID and endpoint."""
super().__init__(
Expand All @@ -103,6 +105,8 @@ def __init__(
else:
self._base_url = base_url

self._num_retries = num_retries

# A mapping of common options for this backend mapped to their Mellea ModelOptions equivalent.
# These are usually values that must be extracted before hand or that are common among backend providers.
# OpenAI has some deprecated parameters. Those map to the same mellea parameter, but
Expand Down Expand Up @@ -163,6 +167,7 @@ async def _generate_from_context(
assert ctx.is_chat_context, NotImplementedError(
"The Openai backend only supports chat-like contexts."
)

span = start_generate_span(
backend=self, action=action, ctx=ctx, format=format, tool_calls=tool_calls
)
Expand Down Expand Up @@ -260,9 +265,17 @@ def _make_backend_specific_and_remove(
standard_openai_subset = litellm.get_standard_openai_params(backend_specific)
unknown_keys = [] # Keys that are unknown to litellm.
unsupported_openai_params = [] # OpenAI params that are known to litellm but not supported for this model/provider.
# Bedrock-specific pass-through params that LiteLLM accepts but doesn't list as supported OpenAI params.
known_provider_passthrough = {
"additional_model_request_fields",
"additional_model_response_field_paths",
}

for key in backend_specific.keys():
if key not in supported_params:
if key in standard_openai_subset:
if key in known_provider_passthrough:
pass # Expected provider-specific params; no warning needed.
elif key in standard_openai_subset:
# LiteLLM is pretty confident that this standard OpenAI parameter won't work.
unsupported_openai_params.append(key)
else:
Expand All @@ -287,8 +300,9 @@ async def _generate_from_chat_context_standard(
action: Component[C] | CBlock,
ctx: Context,
*,
_format: type[BaseModelSubclass]
| None = None, # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
_format: (
type[BaseModelSubclass] | None
) = None, # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
model_options: dict | None = None,
tool_calls: bool = False,
) -> ModelOutputThunk[C]:
Expand Down Expand Up @@ -362,6 +376,7 @@ async def _generate_from_chat_context_standard(
tools=formatted_tools,
reasoning_effort=thinking, # type: ignore
drop_params=True, # See note in `_make_backend_specific_and_remove`.
num_retries=self._num_retries,
**extra_params,
**model_specific_options,
)
Expand Down Expand Up @@ -442,6 +457,16 @@ async def processing(
if content_chunk is not None:
mot._underlying_value += content_chunk

if getattr(choice, "logprobs", None) is not None:
mot._meta["logprobs"] = choice.logprobs

# In some cases (converse API) Bedrock returns logprobs via additionalModelResponseFields.
additional_fields = getattr(chunk, "model_extra", {}) or {}
if "additionalModelResponseFields" in additional_fields:
mot._meta["additionalModelResponseFields"] = additional_fields[
"additionalModelResponseFields"
]

# Store the full response (includes usage) as a dict
mot._meta["litellm_full_response"] = chunk.model_dump()
# Also store just the choice for backward compatibility
Expand All @@ -460,6 +485,12 @@ async def processing(
if content_chunk is not None:
mot._underlying_value += content_chunk

stream_logprobs = getattr(chunk.choices[0], "logprobs", None)
if stream_logprobs is not None:
if "logprobs" not in mot._meta:
mot._meta["logprobs"] = []
mot._meta["logprobs"].append(stream_logprobs)

if mot._meta.get("litellm_chat_response_streamed", None) is None:
mot._meta["litellm_chat_response_streamed"] = []
mot._meta["litellm_chat_response_streamed"].append(
Expand Down Expand Up @@ -494,6 +525,7 @@ async def post_processing(
_format: The structured output format class used during generation, if any.
"""
# Reconstruct the chat_response from chunks if streamed.

streamed_chunks = mot._meta.get("litellm_chat_response_streamed", None)
if streamed_chunks is not None:
# Must handle ollama differently due to: https://github.com/BerriAI/litellm/issues/14579.
Expand All @@ -518,6 +550,7 @@ async def post_processing(
tool_chunk = extract_model_tool_requests(
tools, mot._meta["litellm_chat_response"]
)

if tool_chunk is not None:
if mot.tool_calls is None:
mot.tool_calls = {}
Expand All @@ -540,6 +573,7 @@ async def post_processing(
}
generate_log.action = mot._action
generate_log.result = mot

mot._generate_log = generate_log

# Extract token usage from full response dict or streaming usage
Expand Down Expand Up @@ -675,7 +709,10 @@ async def generate_from_raw(
prompts = [self.formatter.print(action) for action in actions]

completion_response = await litellm.atext_completion(
model=self._model_id, prompt=prompts, **model_specific_options
model=self._model_id,
prompt=prompts,
num_retries=self._num_retries,
**model_specific_options,
)

# Necessary for type checker.
Expand All @@ -696,9 +733,11 @@ async def generate_from_raw(
output._model_options = model_opts
output._meta = {
"litellm_chat_response": res.model_dump(),
"usage": completion_response.usage.model_dump()
if completion_response.usage
else None,
"usage": (
completion_response.usage.model_dump()
if completion_response.usage
else None
),
}

output.parsed_repr = (
Expand Down
Loading
Loading