Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -254,9 +254,11 @@ def check_url_safety(cls, v: Any) -> Any: # noqa: N805
return v

from oss.src.dbs.postgres.git.mappings import map_dto_to_dbe
from oss.src.dbs.postgres.shared.engine import engine as db_engine
from oss.src.dbs.postgres.shared.engine import get_transactions_engine
from datetime import datetime, timezone

db_engine = get_transactions_engine()

workflow_create = WorkflowCreate(
**application_create.model_dump(mode="json"),
)
Expand All @@ -267,7 +269,7 @@ def check_url_safety(cls, v: Any) -> Any: # noqa: N805

# Avoid slug collision with existing workflow artifacts (e.g. evaluators)
artifact_slug = git_artifact_create.slug
async with db_engine.core_session() as session:
async with db_engine.session() as session:
existing = (
await session.execute(
select(WorkflowArtifactDBE).filter(
Expand Down Expand Up @@ -298,7 +300,7 @@ def check_url_safety(cls, v: Any) -> Any: # noqa: N805
dto=artifact_dto,
)

async with db_engine.core_session() as session:
async with db_engine.session() as session:
session.add(artifact_dbe)
await session.commit()

Expand Down Expand Up @@ -364,7 +366,7 @@ def check_url_safety(cls, v: Any) -> Any: # noqa: N805
dto=variant_dto,
)

async with db_engine.core_session() as session:
async with db_engine.session() as session:
session.add(variant_dbe)
await session.commit()

Expand Down Expand Up @@ -415,7 +417,7 @@ def check_url_safety(cls, v: Any) -> Any: # noqa: N805
dto=revision_dto,
)

async with db_engine.core_session() as session:
async with db_engine.session() as session:
session.add(revision_dbe)
await session.commit()

Expand Down
4 changes: 2 additions & 2 deletions api/ee/databases/postgres/migrations/core/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from alembic import context

from oss.src.dbs.postgres.shared.engine import engine
from oss.src.utils.env import env
from oss.src.dbs.postgres.shared.base import Base

# Side-effect imports: register SQLAlchemy models with Base.metadata
Expand All @@ -29,7 +29,7 @@
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
config.set_main_option("sqlalchemy.url", engine.postgres_uri_core) # type: ignore
config.set_main_option("sqlalchemy.url", env.postgres.uri_core)


# Interpret the config file for Python logging.
Expand Down
4 changes: 2 additions & 2 deletions api/ee/databases/postgres/migrations/tracing/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from alembic import context

from oss.src.dbs.postgres.shared.engine import engine
from oss.src.utils.env import env
from oss.src.dbs.postgres.shared.base import Base

# Side-effect import: register SQLAlchemy model with Base.metadata
Expand All @@ -19,7 +19,7 @@
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
config.set_main_option("sqlalchemy.url", engine.postgres_uri_tracing) # type: ignore
config.set_main_option("sqlalchemy.url", env.postgres.uri_tracing)


# Interpret the config file for Python logging.
Expand Down
15 changes: 6 additions & 9 deletions api/ee/src/core/meters/service.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from typing import Awaitable, Tuple, Callable, List, Optional
from uuid import uuid4

import stripe

from oss.src.utils.logging import get_module_logger
from oss.src.utils.env import env
from oss.src.utils.lazy import _load_stripe

from ee.src.core.entitlements.types import Quota
from ee.src.core.entitlements.types import Counter, Gauge, REPORTS
Expand All @@ -13,13 +12,6 @@

log = get_module_logger(__name__)

# Initialize Stripe only if enabled
if env.stripe.enabled:
stripe.api_key = env.stripe.api_key
log.info("✓ Stripe enabled:", target=env.stripe.webhook_target)
else:
log.info("✗ Stripe disabled")


class MetersService:
def __init__(
Expand Down Expand Up @@ -82,6 +74,11 @@ async def report(
log.warn("✗ Stripe disabled")
return

stripe = _load_stripe()
if stripe is None:
log.error("[report] Failed to load Stripe module")
return

log.info("[report] ============================================")
log.info("[report] Starting meter report job")
log.info("[report] ============================================")
Expand Down
20 changes: 10 additions & 10 deletions api/ee/src/core/subscriptions/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@
from uuid import getnode
from datetime import datetime, timezone, timedelta


import stripe

from oss.src.utils.logging import get_module_logger
from oss.src.utils.env import env
from oss.src.utils.caching import invalidate_cache
from oss.src.utils.lazy import _load_stripe

from ee.src.core.subscriptions.types import (
SubscriptionDTO,
Expand All @@ -24,13 +22,6 @@

log = get_module_logger(__name__)

# Initialize Stripe only if enabled
if env.stripe.enabled:
stripe.api_key = env.stripe.api_key
log.info("✓ Stripe enabled:", target=env.stripe.webhook_target)
else:
log.info("✗ Stripe disabled")

MAC_ADDRESS = ":".join(f"{(getnode() >> ele) & 0xFF:02x}" for ele in range(40, -1, -8))


Expand Down Expand Up @@ -83,6 +74,10 @@ async def start_reverse_trial(
if not env.stripe.enabled:
raise EventException("Reverse trial requires Stripe to be enabled")

stripe = _load_stripe()
if stripe is None:
raise EventException("Failed to load Stripe module")

now = datetime.now(tz=timezone.utc)
anchor = now + timedelta(days=REVERSE_TRIAL_DAYS)

Expand Down Expand Up @@ -262,6 +257,11 @@ async def process_event(
log.warn("✗ Stripe disabled")
return None

stripe = _load_stripe()
if stripe is None:
log.error("Failed to load Stripe module")
raise EventException("Stripe is not available for plan switching")

if subscription.plan == plan:
log.warn("Subscription already on the plan: %s", plan)

Expand Down
21 changes: 13 additions & 8 deletions api/ee/src/dbs/postgres/meters/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@


from oss.src.utils.logging import get_module_logger
from oss.src.dbs.postgres.shared.engine import engine
from oss.src.dbs.postgres.shared.engine import (
TransactionsEngine,
get_transactions_engine,
)

from ee.src.core.entitlements.types import Quota
from ee.src.core.meters.types import MeterDTO
Expand All @@ -22,16 +25,18 @@


class MetersDAO(MetersDAOInterface):
def __init__(self):
pass
def __init__(self, engine: TransactionsEngine = None):
if engine is None:
engine = get_transactions_engine()
self.engine = engine

async def dump(
self,
limit: Optional[int] = None,
) -> list[MeterDTO]:
log.info(f"[report] [dump] Starting (limit={limit or 'none'})")

async with engine.core_session() as session:
async with self.engine.session() as session:
try:
stmt = (
select(MeterDBE)
Expand Down Expand Up @@ -203,7 +208,7 @@ async def _bump_commit_chunk(
missing_count = 0
missing_samples: list[str] = []

async with engine.core_session() as session:
async with self.engine.session() as session:
for meter in meters:
stmt = (
update(MeterDBE)
Expand Down Expand Up @@ -249,7 +254,7 @@ async def fetch(
year: Optional[int] = None,
month: Optional[int] = None,
) -> list[MeterDTO]:
async with engine.core_session() as session:
async with self.engine.session() as session:
stmt = select(MeterDBE).filter_by(
organization_id=organization_id,
) # NO RISK OF DEADLOCK
Expand Down Expand Up @@ -288,7 +293,7 @@ async def check(
year, month = compute_billing_period(anchor=anchor)
meter.year, meter.month = year, month

async with engine.core_session() as session:
async with self.engine.session() as session:
stmt = select(MeterDBE).filter_by(
organization_id=meter.organization_id,
key=meter.key,
Expand Down Expand Up @@ -376,7 +381,7 @@ async def adjust(
where = where | where_clause

# 4. Build SQL statement (atomic upsert with RETURNING)
async with engine.core_session() as session:
async with self.engine.session() as session:
stmt = (
insert(MeterDBE)
.values(
Expand Down
Loading
Loading