Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 164 additions & 0 deletions hishel/_core/_storages/_async_redis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
from __future__ import annotations

import contextlib
from collections.abc import AsyncIterator, Callable
from dataclasses import replace
from time import time
from typing import cast
from uuid import UUID, uuid4

from redis import RedisError
from redis.asyncio import Redis

from hishel._core._storages._async_base import AsyncBaseStorage
from hishel._core._storages._packing import pack, unpack
from hishel._core.models import Entry, EntryMeta, Request, Response


class AsyncRedisStorage(AsyncBaseStorage):
    """Asynchronous cache storage backed by Redis.

    Key layout (all namespaced by ``key_prefix``):

    - ``{prefix}:entry:{id}``       packed request/response pair
    - ``{prefix}:stream:{id}``      list of body chunks, terminated by an
      empty-bytes sentinel element
    - ``{prefix}:stream_done:{id}`` flag set once the body is fully written
    - ``{prefix}:idx:{key}``        set of entry ids stored under a cache key
    """

    def __init__(self, client: Redis, ttl: int | float | None = None, key_prefix: str = "hishel") -> None:
        """
        Args:
            client: Configured ``redis.asyncio.Redis`` client.
            ttl: Default entry lifetime in seconds; ``None`` disables expiry.
            key_prefix: Prefix used for every Redis key this storage creates.
        """
        self._client = client
        self._default_ttl = ttl
        self._key_prefix = key_prefix

    def _resolve_ttl(self, request: Request) -> int | float | None:
        # Per-request "hishel_ttl" metadata overrides the storage default,
        # mirroring the sqlite storage and _is_pair_expired below.
        return request.metadata["hishel_ttl"] if "hishel_ttl" in request.metadata else self._default_ttl

    async def create_entry(self, request: Request, response: Response, key: str, id_: UUID | None = None) -> Entry:
        """Store a request/response pair under *key* and return the entry.

        The returned entry's response stream is a wrapper that tees chunks
        into Redis as the caller consumes them; the entry is only served by
        ``get_entries`` after the stream has been read to completion.
        """
        pair_id = id_ or uuid4()
        key_bytes = key.encode()
        ttl = self._resolve_ttl(request)

        response_with_stream = replace(
            response,
            stream=self._save_stream(cast(AsyncIterator[bytes], response.stream), pair_id, ttl),
        )

        entry = Entry(
            id=pair_id,
            request=request,
            response=response_with_stream,
            meta=EntryMeta(created_at=time()),
            cache_key=key_bytes,
        )

        packed = pack(entry, kind="pair")
        entry_key = f"{self._key_prefix}:entry:{pair_id.hex}"
        idx_key = f"{self._key_prefix}:idx:{key}"

        if ttl is not None:
            await self._client.set(entry_key, packed, ex=int(ttl))
        else:
            await self._client.set(entry_key, packed)

        await self._client.sadd(idx_key, pair_id.hex)  # type: ignore[misc]
        if ttl is not None:
            await self._client.expire(idx_key, int(ttl))

        return entry

    async def _save_stream(
        self, stream: AsyncIterator[bytes], pair_id: UUID, ttl: int | float | None
    ) -> AsyncIterator[bytes]:
        """Tee *stream* into Redis while re-yielding every chunk.

        The done-flag (and *ttl*, when set) is only applied after the source
        is exhausted, so partially-consumed bodies are never served.
        """
        stream_key = f"{self._key_prefix}:stream:{pair_id.hex}"
        done_key = f"{self._key_prefix}:stream_done:{pair_id.hex}"

        async for chunk in stream:
            await self._client.rpush(stream_key, chunk)  # type: ignore[misc]
            yield chunk

        # sentinel to mark end of stream
        await self._client.rpush(stream_key, b"")  # type: ignore[misc]
        await self._client.set(done_key, b"1")
        if ttl is not None:
            await self._client.expire(stream_key, int(ttl))
            await self._client.expire(done_key, int(ttl))

    async def _is_stream_complete(self, entry_id: UUID) -> bool:
        """Return True once the body for *entry_id* was written in full."""
        result = await self._client.exists(f"{self._key_prefix}:stream_done:{entry_id.hex}")
        return bool(result)

    def _is_pair_expired(self, pair: Entry) -> bool:
        """Return True when the entry outlived its TTL (per-request
        ``hishel_ttl`` metadata takes precedence over the default)."""
        ttl = pair.request.metadata.get("hishel_ttl", self._default_ttl)
        return ttl is not None and pair.meta.created_at + ttl < time()

    async def _stream_from_cache(self, entry_id: UUID) -> AsyncIterator[bytes]:
        """Replay a cached response body, excluding the trailing sentinel."""
        stream_key = f"{self._key_prefix}:stream:{entry_id.hex}"
        # Single LRANGE round trip instead of one LINDEX per chunk (each
        # LINDEX is O(n) server-side); stop index -2 drops the sentinel.
        chunks = cast("list[bytes | str]", await self._client.lrange(stream_key, 0, -2))  # type: ignore[misc]
        for chunk in chunks:
            yield chunk.encode() if isinstance(chunk, str) else chunk

    async def get_entries(self, key: str) -> list[Entry]:
        """Return all live, fully-written entries stored under *key*.

        Index members whose entry key has expired are pruned from the index
        set as a side effect.
        """
        idx_key = f"{self._key_prefix}:idx:{key}"
        members = cast(set[bytes], await self._client.smembers(idx_key))  # type: ignore[misc]

        result: list[Entry] = []
        for member in members:
            hex_str = member.decode() if isinstance(member, bytes) else member
            entry_key = f"{self._key_prefix}:entry:{hex_str}"

            data = await self._client.get(entry_key)
            if data is None:
                # Entry expired or was deleted: drop the dangling index member.
                await self._client.srem(idx_key, member)  # type: ignore[misc]
                continue

            entry = unpack(cast(bytes, data), kind="pair")
            if entry is None:
                continue

            if not await self._is_stream_complete(entry.id):
                continue

            if self._is_pair_expired(entry):
                continue

            if self.is_soft_deleted(entry):
                continue

            result.append(
                replace(
                    entry,
                    response=replace(entry.response, stream=self._stream_from_cache(entry.id)),
                )
            )

        return result

    async def update_entry(
        self,
        id: UUID,  # noqa: A002
        new_entry: Entry | Callable[[Entry], Entry],
    ) -> Entry | None:
        """Replace the entry stored under *id*.

        *new_entry* may be a ready-made ``Entry`` or a callable applied to
        the stored one. The remaining Redis TTL is preserved. Returns the
        updated entry, or ``None`` when no decodable entry exists.

        Raises:
            ValueError: If the replacement entry carries a different id.
        """
        entry_key = f"{self._key_prefix}:entry:{id.hex}"
        data = await self._client.get(entry_key)
        if data is None:
            return None

        existing = unpack(cast(bytes, data), kind="pair")
        if existing is None:
            # Stored payload is not decodable: treat it like a missing entry,
            # consistent with how get_entries skips such payloads.
            return None

        updated = new_entry(existing) if callable(new_entry) else new_entry

        if existing.id != updated.id:
            raise ValueError("Entry ID mismatch")

        # PTTL returns the remaining lifetime in ms, or a negative value when
        # the key has no expiry; keep whatever expiry was in effect.
        pttl = cast(int, await self._client.pttl(entry_key))
        packed = pack(updated, kind="pair")
        if pttl > 0:
            await self._client.set(entry_key, packed, px=pttl)
        else:
            await self._client.set(entry_key, packed)

        if existing.cache_key != updated.cache_key:
            # The entry moved to another cache key: re-home its index member.
            old_key = existing.cache_key.decode() if isinstance(existing.cache_key, bytes) else existing.cache_key
            new_key = updated.cache_key.decode() if isinstance(updated.cache_key, bytes) else updated.cache_key
            await self._client.srem(f"{self._key_prefix}:idx:{old_key}", id.hex)  # type: ignore[misc]
            await self._client.sadd(f"{self._key_prefix}:idx:{new_key}", id.hex)  # type: ignore[misc]

        return updated

    async def remove_entry(self, id: UUID) -> None:  # noqa: A002
        """Delete the entry blob, its body stream, and its done-flag."""
        entry_key = f"{self._key_prefix}:entry:{id.hex}"
        stream_key = f"{self._key_prefix}:stream:{id.hex}"
        done_key = f"{self._key_prefix}:stream_done:{id.hex}"
        with contextlib.suppress(RedisError):
            # don't let deletes prevent reads; failures are non-fatal
            await self._client.delete(entry_key, stream_key, done_key)

    async def close(self) -> None:
        """Release the underlying Redis connection."""
        await self._client.aclose()
164 changes: 164 additions & 0 deletions hishel/_core/_storages/_sync_redis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
from __future__ import annotations

import contextlib
from collections.abc import Iterator, Callable
from dataclasses import replace
from time import time
from typing import cast
from uuid import UUID, uuid4

from redis import RedisError
from redis import Redis

from hishel._core._storages._sync_base import SyncBaseStorage
from hishel._core._storages._packing import pack, unpack
from hishel._core.models import Entry, EntryMeta, Request, Response


class RedisStorage(SyncBaseStorage):
    """Synchronous cache storage backed by Redis.

    Key layout (all namespaced by ``key_prefix``):

    - ``{prefix}:entry:{id}``       packed request/response pair
    - ``{prefix}:stream:{id}``      list of body chunks, terminated by an
      empty-bytes sentinel element
    - ``{prefix}:stream_done:{id}`` flag set once the body is fully written
    - ``{prefix}:idx:{key}``        set of entry ids stored under a cache key
    """

    def __init__(self, client: Redis, ttl: int | float | None = None, key_prefix: str = "hishel") -> None:
        """
        Args:
            client: Configured ``redis.Redis`` client.
            ttl: Default entry lifetime in seconds; ``None`` disables expiry.
            key_prefix: Prefix used for every Redis key this storage creates.
        """
        self._client = client
        self._default_ttl = ttl
        self._key_prefix = key_prefix

    def _resolve_ttl(self, request: Request) -> int | float | None:
        # Per-request "hishel_ttl" metadata overrides the storage default,
        # mirroring the sqlite storage and _is_pair_expired below.
        return request.metadata["hishel_ttl"] if "hishel_ttl" in request.metadata else self._default_ttl

    def create_entry(self, request: Request, response: Response, key: str, id_: UUID | None = None) -> Entry:
        """Store a request/response pair under *key* and return the entry.

        The returned entry's response stream is a wrapper that tees chunks
        into Redis as the caller consumes them; the entry is only served by
        ``get_entries`` after the stream has been read to completion.
        """
        pair_id = id_ or uuid4()
        key_bytes = key.encode()
        ttl = self._resolve_ttl(request)

        response_with_stream = replace(
            response,
            stream=self._save_stream(cast(Iterator[bytes], response.stream), pair_id, ttl),
        )

        entry = Entry(
            id=pair_id,
            request=request,
            response=response_with_stream,
            meta=EntryMeta(created_at=time()),
            cache_key=key_bytes,
        )

        packed = pack(entry, kind="pair")
        entry_key = f"{self._key_prefix}:entry:{pair_id.hex}"
        idx_key = f"{self._key_prefix}:idx:{key}"

        if ttl is not None:
            self._client.set(entry_key, packed, ex=int(ttl))
        else:
            self._client.set(entry_key, packed)

        self._client.sadd(idx_key, pair_id.hex)  # type: ignore[misc]
        if ttl is not None:
            self._client.expire(idx_key, int(ttl))

        return entry

    def _save_stream(
        self, stream: Iterator[bytes], pair_id: UUID, ttl: int | float | None
    ) -> Iterator[bytes]:
        """Tee *stream* into Redis while re-yielding every chunk.

        The done-flag (and *ttl*, when set) is only applied after the source
        is exhausted, so partially-consumed bodies are never served.
        """
        stream_key = f"{self._key_prefix}:stream:{pair_id.hex}"
        done_key = f"{self._key_prefix}:stream_done:{pair_id.hex}"

        for chunk in stream:
            self._client.rpush(stream_key, chunk)  # type: ignore[misc]
            yield chunk

        # sentinel to mark end of stream
        self._client.rpush(stream_key, b"")  # type: ignore[misc]
        self._client.set(done_key, b"1")
        if ttl is not None:
            self._client.expire(stream_key, int(ttl))
            self._client.expire(done_key, int(ttl))

    def _is_stream_complete(self, entry_id: UUID) -> bool:
        """Return True once the body for *entry_id* was written in full."""
        result = self._client.exists(f"{self._key_prefix}:stream_done:{entry_id.hex}")
        return bool(result)

    def _is_pair_expired(self, pair: Entry) -> bool:
        """Return True when the entry outlived its TTL (per-request
        ``hishel_ttl`` metadata takes precedence over the default)."""
        ttl = pair.request.metadata.get("hishel_ttl", self._default_ttl)
        return ttl is not None and pair.meta.created_at + ttl < time()

    def _stream_from_cache(self, entry_id: UUID) -> Iterator[bytes]:
        """Replay a cached response body, excluding the trailing sentinel."""
        stream_key = f"{self._key_prefix}:stream:{entry_id.hex}"
        # Single LRANGE round trip instead of one LINDEX per chunk (each
        # LINDEX is O(n) server-side); stop index -2 drops the sentinel.
        chunks = cast("list[bytes | str]", self._client.lrange(stream_key, 0, -2))  # type: ignore[misc]
        for chunk in chunks:
            yield chunk.encode() if isinstance(chunk, str) else chunk

    def get_entries(self, key: str) -> list[Entry]:
        """Return all live, fully-written entries stored under *key*.

        Index members whose entry key has expired are pruned from the index
        set as a side effect.
        """
        idx_key = f"{self._key_prefix}:idx:{key}"
        members = cast(set[bytes], self._client.smembers(idx_key))  # type: ignore[misc]

        result: list[Entry] = []
        for member in members:
            hex_str = member.decode() if isinstance(member, bytes) else member
            entry_key = f"{self._key_prefix}:entry:{hex_str}"

            data = self._client.get(entry_key)
            if data is None:
                # Entry expired or was deleted: drop the dangling index member.
                self._client.srem(idx_key, member)  # type: ignore[misc]
                continue

            entry = unpack(cast(bytes, data), kind="pair")
            if entry is None:
                continue

            if not self._is_stream_complete(entry.id):
                continue

            if self._is_pair_expired(entry):
                continue

            if self.is_soft_deleted(entry):
                continue

            result.append(
                replace(
                    entry,
                    response=replace(entry.response, stream=self._stream_from_cache(entry.id)),
                )
            )

        return result

    def update_entry(
        self,
        id: UUID,  # noqa: A002
        new_entry: Entry | Callable[[Entry], Entry],
    ) -> Entry | None:
        """Replace the entry stored under *id*.

        *new_entry* may be a ready-made ``Entry`` or a callable applied to
        the stored one. The remaining Redis TTL is preserved. Returns the
        updated entry, or ``None`` when no decodable entry exists.

        Raises:
            ValueError: If the replacement entry carries a different id.
        """
        entry_key = f"{self._key_prefix}:entry:{id.hex}"
        data = self._client.get(entry_key)
        if data is None:
            return None

        existing = unpack(cast(bytes, data), kind="pair")
        if existing is None:
            # Stored payload is not decodable: treat it like a missing entry,
            # consistent with how get_entries skips such payloads.
            return None

        updated = new_entry(existing) if callable(new_entry) else new_entry

        if existing.id != updated.id:
            raise ValueError("Entry ID mismatch")

        # PTTL returns the remaining lifetime in ms, or a negative value when
        # the key has no expiry; keep whatever expiry was in effect.
        pttl = cast(int, self._client.pttl(entry_key))
        packed = pack(updated, kind="pair")
        if pttl > 0:
            self._client.set(entry_key, packed, px=pttl)
        else:
            self._client.set(entry_key, packed)

        if existing.cache_key != updated.cache_key:
            # The entry moved to another cache key: re-home its index member.
            old_key = existing.cache_key.decode() if isinstance(existing.cache_key, bytes) else existing.cache_key
            new_key = updated.cache_key.decode() if isinstance(updated.cache_key, bytes) else updated.cache_key
            self._client.srem(f"{self._key_prefix}:idx:{old_key}", id.hex)  # type: ignore[misc]
            self._client.sadd(f"{self._key_prefix}:idx:{new_key}", id.hex)  # type: ignore[misc]

        return updated

    def remove_entry(self, id: UUID) -> None:  # noqa: A002
        """Delete the entry blob, its body stream, and its done-flag."""
        entry_key = f"{self._key_prefix}:entry:{id.hex}"
        stream_key = f"{self._key_prefix}:stream:{id.hex}"
        done_key = f"{self._key_prefix}:stream_done:{id.hex}"
        with contextlib.suppress(RedisError):
            # don't let deletes prevent reads; failures are non-fatal
            self._client.delete(entry_key, stream_key, done_key)

    def close(self) -> None:
        """Release the underlying Redis connection."""
        self._client.close()
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ httpx = [
fastapi = [
"fastapi>=0.119.1",
]
redis = [
"redis>=6.2.0",
]

[project.urls]
Homepage = "https://hishel.com"
Expand Down Expand Up @@ -103,6 +106,7 @@ exclude = [
"hishel/_sync_cache.py",
"tests/test_sync_httpx.py",
"hishel/_core/_storages/_sync_sqlite.py",
"hishel/_core/_storages/_sync_redis.py",
"hishel/_core/_storages/_sync_base.py",
"tests/test_sync_httpx.py",
"hishel/_sync_httpx.py"
Expand Down Expand Up @@ -140,5 +144,7 @@ dev = [
"types-boto3==1.42.39",
"types-pyyaml==6.0.12.20250915",
"types-requests>=2.31.0.6",
"fakeredis>=2.0",
"redis>=6.2.0",
"zipp>=3.19.1",
]
9 changes: 9 additions & 0 deletions scripts/unasync
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,14 @@ SUBS = [
("AsyncBaseStorage", "SyncBaseStorage"),
("AsyncCacheClient", "SyncCacheClient"),
("AsyncSqliteStorage", "SyncSqliteStorage"),
("AsyncRedisStorage", "RedisStorage"),
("anysqlite", "sqlite3"),
("redis.asyncio", "redis"),
("fakeredis.aioredis", "fakeredis"),
(
"hishel._core._storages._async_redis",
"hishel._core._storages._sync_redis",
),
("@pytest.mark.anyio", ""),
("from anyio import Lock", "from threading import RLock"),
("self._lock = Lock", "self._lock = RLock"),
Expand Down Expand Up @@ -111,6 +118,8 @@ def main():
("hishel/_async_cache.py", "hishel/_sync_cache.py"),
("hishel/_core/_storages/_async_base.py", "hishel/_core/_storages/_sync_base.py"),
("hishel/_core/_storages/_async_sqlite.py", "hishel/_core/_storages/_sync_sqlite.py"),
("hishel/_core/_storages/_async_redis.py", "hishel/_core/_storages/_sync_redis.py"),
("tests/_core/_async/test_redis_storage.py", "tests/_core/_sync/test_redis_storage.py"),
("hishel/_async_httpx.py", "hishel/_sync_httpx.py"),
]

Expand Down
Loading
Loading