Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ All notable changes to this project will be documented in this file. From versio
### Fixed

- Fix leaking table and function names when calculating error hint by @taimoorzaeem in #4675
- Limit concurrent schema cache loads by @mkleczek in #4643

## [14.5] - 2026-02-12

Expand Down
1 change: 1 addition & 0 deletions postgrest.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ library
, stm-hamt >= 1.2 && < 2
, focus >= 1.0 && < 2
, some >= 1.0.4.1 && < 2
, uuid >= 1.3 && < 2
-- -fno-spec-constr may help keep compile time memory use in check,
-- see https://gitlab.haskell.org/ghc/ghc/issues/16017#note_219304
-- -optP-Wno-nonportable-include-path
Expand Down
74 changes: 67 additions & 7 deletions src/PostgREST/AppState.hs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE RecursiveDo #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE RecursiveDo #-}
{-# LANGUAGE TypeApplications #-}

module PostgREST.AppState
( AppState
Expand Down Expand Up @@ -33,7 +35,8 @@ import qualified Data.ByteString.Char8 as BS
import Data.Either.Combinators (whenLeft)
import qualified Hasql.Pool as SQL
import qualified Hasql.Pool.Config as SQL
import qualified Hasql.Session as SQL
import qualified Hasql.Session as SQL hiding (statement)
import qualified Hasql.Transaction as SQL hiding (sql)
import qualified Hasql.Transaction.Sessions as SQL
import qualified Network.HTTP.Types.Status as HTTP
import qualified PostgREST.Auth.JwtCache as JwtCache
Expand Down Expand Up @@ -62,11 +65,17 @@ import PostgREST.Config.Database (queryDbSettings,
queryRoleSettings)
import PostgREST.Config.PgVersion (PgVersion (..),
minimumPgVersion)
import PostgREST.Metrics (MetricsState (connTrack))
import PostgREST.SchemaCache (SchemaCache (..),
querySchemaCache,
showSummary)
import PostgREST.SchemaCache.Identifiers (quoteQi)

import qualified Hasql.Decoders as HD
import qualified Hasql.Encoders as HE
import qualified Hasql.Statement as SQL
import NeatInterpolation (trimming)

import Protolude

data AppState = AppState
Expand Down Expand Up @@ -309,7 +318,7 @@ getObserver = stateObserver
-- + Because connections cache the pg catalog(see #2620)
-- + For rapid recovery. Otherwise, the pool idle or lifetime timeout would have to be reached for new healthy connections to be acquired.
retryingSchemaCacheLoad :: AppState -> IO ()
retryingSchemaCacheLoad appState@AppState{stateObserver=observer, stateMainThreadId=mainThreadId} =
retryingSchemaCacheLoad appState@AppState{stateObserver=observer, stateMainThreadId=mainThreadId, stateMetrics} =
void $ retrying retryPolicy shouldRetry (\RetryStatus{rsIterNumber, rsPreviousDelay} -> do
when (rsIterNumber > 0) $ do
let delay = fromMaybe 0 rsPreviousDelay `div` oneSecondInUs
Expand Down Expand Up @@ -350,9 +359,23 @@ retryingSchemaCacheLoad appState@AppState{stateObserver=observer, stateMainThrea
qSchemaCache :: IO (Maybe SchemaCache)
qSchemaCache = do
conf@AppConfig{..} <- getConfig appState
-- Throttle concurrent schema cache loads, guarded by advisory locks.
-- This is to prevent thundering herd problem on startup or when many PostgREST
-- instances receive "reload schema" notifications at the same time
-- See get_lock_sql for details of the algorithm.
-- Here we calculate the number of open connections passed to the query.
Metrics.ConnStats connected inUse <- Metrics.connectionCounts $ connTrack stateMetrics
-- Determine whether schema cache loading will create a new session
let
-- if all connections in use but pool not full - schema cache loading will create session
scLoadingSessions = if connected <= inUse && inUse < configDbPoolSize then 1 else 0
withTxLock = SQL.statement
(fromIntegral $ connected + scLoadingSessions)
(SQL.Statement get_lock_sql get_lock_params HD.noResult configDbPreparedStatements)

(resultTime, result) <-
let transaction = if configDbPreparedStatements then SQL.transaction else SQL.unpreparedTransaction in
timeItT $ usePool appState (transaction SQL.ReadCommitted SQL.Read $ querySchemaCache conf)
timeItT $ usePool appState (transaction SQL.ReadCommitted SQL.Read $ withTxLock *> querySchemaCache conf)
case result of
Left e -> do
markSchemaCachePending appState
Expand All @@ -369,6 +392,43 @@ retryingSchemaCacheLoad appState@AppState{stateObserver=observer, stateMainThrea
observer . uncurry SchemaCacheLoadedObs =<< timeItT (evaluate $ showSummary sCache)
markSchemaCacheLoaded appState
return $ Just sCache
where
-- Recursive query that tries acquiring locks in order
-- and waits for randomly selected lock if no attempt succeeded.
-- It has a single parameter: this node open connection count.
-- It is used to estimate the number of nodes
-- by counting the number of active sessions for current session_user
-- and dividing it by this node open connections.
-- Assuming load is uniform among cluster nodes, all should have
-- statistically the same number of open connections.
-- Once the number of nodes is known we calculate the number
-- of locks as ceil(log(2, number_of_nodes))
get_lock_sql = encodeUtf8 [trimming|
WITH RECURSIVE attempts AS (
SELECT 1 AS lock_number, pg_try_advisory_xact_lock(lock_id, 1) AS success FROM parameters
UNION ALL
SELECT next_lock_number AS lock_number, pg_try_advisory_xact_lock(lock_id, next_lock_number) AS success
FROM
parameters CROSS JOIN LATERAL (
SELECT lock_number + 1 AS next_lock_number FROM attempts
WHERE NOT success AND lock_number < locks_count
ORDER BY lock_number DESC
LIMIT 1
) AS previous_attempt
),
counts AS (
SELECT round(log(2, round(count(*)::double precision/$$1)::numeric))::int AS locks_count
FROM
pg_stat_activity WHERE usename = SESSION_USER
),
parameters AS (
SELECT locks_count, 50168275 AS lock_id FROM counts WHERE locks_count > 0
)
SELECT pg_advisory_xact_lock(lock_id, floor(random() * locks_count)::int + 1)
FROM
parameters WHERE NOT EXISTS (SELECT 1 FROM attempts WHERE success) |]

get_lock_params = HE.param (HE.nonNullable HE.int4)

shouldRetry :: RetryStatus -> (Maybe PgVersion, Maybe SchemaCache) -> IO Bool
shouldRetry _ (pgVer, sCache) = do
Expand Down
55 changes: 44 additions & 11 deletions src/PostgREST/Metrics.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ Description : Metrics based on the Observation module. See Observation.hs.
-}
module PostgREST.Metrics
( init
, ConnTrack
, ConnStats (..)
, MetricsState (..)
, connectionCounts
, observationMetrics
, metricsToText
) where
Expand All @@ -17,12 +20,18 @@ import Prometheus

import PostgREST.Observation

import Protolude
import Control.Arrow ((&&&))
import Data.Bitraversable (bisequenceA)
import Data.Tuple.Extra (both)
import Data.UUID (UUID)
import qualified Focus
import Protolude
import qualified StmHamt.SizedHamt as SH

data MetricsState =
MetricsState {
poolTimeouts :: Counter,
poolAvailable :: Gauge,
connTrack :: ConnTrack,
poolWaiting :: Gauge,
poolMaxSize :: Gauge,
schemaCacheLoads :: Vector Label1 Counter,
Expand All @@ -36,7 +45,7 @@ init :: Int -> IO MetricsState
init configDbPoolSize = do
metricState <- MetricsState <$>
register (counter (Info "pgrst_db_pool_timeouts_total" "The total number of pool connection timeouts")) <*>
register (gauge (Info "pgrst_db_pool_available" "Available connections in the pool")) <*>
register (Metric ((identity &&& dbPoolAvailable) <$> connectionTracker)) <*>
register (gauge (Info "pgrst_db_pool_waiting" "Requests waiting to acquire a pool connection")) <*>
register (gauge (Info "pgrst_db_pool_max" "Max pool connections")) <*>
register (vector "status" $ counter (Info "pgrst_schema_cache_loads_total" "The total number of times the schema cache was loaded")) <*>
Expand All @@ -46,20 +55,19 @@ init configDbPoolSize = do
register (counter (Info "pgrst_jwt_cache_evictions_total" "The total number of JWT cache evictions"))
setGauge (poolMaxSize metricState) (fromIntegral configDbPoolSize)
pure metricState
where
dbPoolAvailable = (pure . noLabelsGroup (Info "pgrst_db_pool_available" "Available connections in the pool") GaugeType . calcAvailable <$>) . connectionCounts
where
calcAvailable = (configDbPoolSize -) . inUse
toSample name labels = Sample name labels . encodeUtf8 . show
noLabelsGroup info sampleType = SampleGroup info sampleType . pure . toSample (metricName info) mempty

-- Only some observations are used as metrics
observationMetrics :: MetricsState -> ObservationHandler
observationMetrics MetricsState{..} obs = case obs of
PoolAcqTimeoutObs -> do
incCounter poolTimeouts
(HasqlPoolObs (SQL.ConnectionObservation _ status)) -> case status of
SQL.ReadyForUseConnectionStatus -> do
incGauge poolAvailable
SQL.InUseConnectionStatus -> do
decGauge poolAvailable
SQL.TerminatedConnectionStatus _ -> do
decGauge poolAvailable
SQL.ConnectingConnectionStatus -> pure ()
(HasqlPoolObs sqlObs) -> trackConnections connTrack sqlObs
PoolRequest ->
incGauge poolWaiting
PoolRequestFullfilled ->
Expand All @@ -77,3 +85,28 @@ observationMetrics MetricsState{..} obs = case obs of

metricsToText :: IO LBS.ByteString
metricsToText = exportMetricsAsText

-- | Point-in-time snapshot of tracked pool connection counts,
-- as produced by 'connectionCounts' from a 'ConnTrack'.
data ConnStats = ConnStats {
  connected :: Int, -- ^ connections seen ready-for-use and not yet terminated
  inUse :: Int      -- ^ connections currently checked out of the pool
} deriving (Eq, Show)

data ConnTrack = ConnTrack { connTrackConnected :: SH.SizedHamt UUID, connTrackInUse :: SH.SizedHamt UUID }

-- | Allocate a fresh tracker with two empty STM sets.
connectionTracker :: IO ConnTrack
connectionTracker = do
  connectedSet <- SH.newIO
  inUseSet <- SH.newIO
  pure (ConnTrack connectedSet inUseSet)

-- | Fold a pool connection lifecycle observation into the tracker.
-- Each state transition updates both sets inside a single STM
-- transaction, so readers ('connectionCounts') never observe a
-- half-applied transition.
trackConnections :: ConnTrack -> SQL.Observation -> IO ()
trackConnections ConnTrack{..} (SQL.ConnectionObservation uuid status) = case status of
  -- Connection became available: mark it connected and clear any
  -- stale in-use entry (insert is idempotent on repeated readiness).
  SQL.ReadyForUseConnectionStatus -> atomically $
    SH.insert identity uuid connTrackConnected *>
    SH.focus Focus.delete identity uuid connTrackInUse
  -- Connection closed (reason ignored): drop it from both sets.
  SQL.TerminatedConnectionStatus _ -> atomically $
    SH.focus Focus.delete identity uuid connTrackConnected *>
    SH.focus Focus.delete identity uuid connTrackInUse
  -- Checked out of the pool; it remains in the connected set too,
  -- so connected >= inUse at all times.
  SQL.InUseConnectionStatus -> atomically $
    SH.insert identity uuid connTrackInUse
  -- Any other status (e.g. still connecting) leaves the sets
  -- untouched; 'mempty' is the no-op 'IO ()' via its Monoid instance.
  _ -> mempty

-- | Read a consistent snapshot of the tracker: both set sizes are
-- taken inside one STM transaction.
connectionCounts :: ConnTrack -> IO ConnStats
connectionCounts track = atomically $
  ConnStats
    <$> SH.size (connTrackConnected track)
    <*> SH.size (connTrackInUse track)
89 changes: 88 additions & 1 deletion test/io/test_io.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"Unit tests for Input/Ouput of PostgREST seen as a black box."

import contextlib
import os
import re
import signal
Expand All @@ -18,6 +19,7 @@
sleep_until_postgrest_full_reload,
sleep_until_postgrest_scache_reload,
wait_until_exit,
wait_until_status_code,
)


Expand Down Expand Up @@ -1219,6 +1221,91 @@ def test_schema_cache_concurrent_notifications(slow_schema_cache_env):
assert response.status_code == 200


@pytest.mark.parametrize(
    "instance_count, expected_concurrency", [(2, 2), (4, 3), (6, 4), (8, 4), (16, 5)]
)
def test_schema_cache_reload_throttled_with_advisory_locks(
    instance_count, expected_concurrency, slow_schema_cache_env
):
    """Concurrent schema cache reloads are throttled across instances via
    advisory locks: only roughly ceil(log2(instance_count)) + 1 instances
    (the parametrized expected_concurrency) reload immediately, the rest
    wait on a lock and so log a slower reload time."""

    # The artificial delay injected into each schema cache query; a reload
    # that had to wait on the advisory lock takes at least one extra delay.
    internal_sleep_ms = int(
        slow_schema_cache_env["PGRST_INTERNAL_SCHEMA_CACHE_QUERY_SLEEP"]
    )
    # A reload slower than twice the injected delay is assumed lock-throttled.
    lock_wait_threshold_ms = internal_sleep_ms * 2
    query_log_pattern = re.compile(r"Schema cache queried in ([\d.]+) milliseconds")

    def read_available_output_lines(postgrest):
        # Drain whatever stdout is currently available without blocking;
        # assumes the process pipe is in non-blocking mode — TODO confirm
        # against the `run` fixture.
        try:
            output = postgrest.process.stdout.read()
        except BlockingIOError:
            return []

        if not output:
            return []
        return output.decode().splitlines()

    with contextlib.ExitStack() as stack:
        # Start all instances against the same database so they contend
        # for the same advisory locks.
        instances = [
            stack.enter_context(
                run(
                    env=slow_schema_cache_env,
                    wait_for_readiness=False,
                    wait_max_seconds=10,
                )
            )
            for _ in range(instance_count)
        ]

        for postgrest in instances:
            wait_until_status_code(
                postgrest.admin.baseurl + "/ready", max_seconds=10, status_code=200
            )

        # Drop startup logs so only reload logs are parsed.
        for postgrest in instances:
            read_available_output_lines(postgrest)

        # Trigger a simultaneous schema cache reload on every instance.
        response = instances[0].session.get("/rpc/notify_pgrst")
        assert response.status_code == 204

        # Wait long enough for the lock-throttled cache reloads to finish.
        time.sleep((internal_sleep_ms / 1000) * 2)

        reload_durations_ms = []
        for postgrest in instances:
            # Poll each instance's log until a reload-duration line shows up
            # (bounded number of attempts so a broken instance fails fast).
            output_lines = []
            for _ in range(instance_count * 2):
                output_lines.extend(read_available_output_lines(postgrest))
                if any(query_log_pattern.search(line) for line in output_lines):
                    break
                time.sleep(0.2)

            durations = []
            for line in output_lines:
                match = query_log_pattern.search(line)
                if match:
                    durations.append(float(match.group(1)))

            assert durations
            # Use the slowest observed reload per instance.
            reload_durations_ms.append(max(durations))

        assert len(reload_durations_ms) == instance_count

        # Exactly expected_concurrency instances should have reloaded fast;
        # the remaining ones were throttled past the lock-wait threshold.
        assert (
            instance_count
            - len(
                [
                    duration
                    for duration in reload_durations_ms
                    if duration > lock_wait_threshold_ms
                ]
            )
            == expected_concurrency
        )


def test_schema_cache_query_sleep_logs(defaultenv):
"""Schema cache sleep should be reflected in the logged query duration."""

Expand Down Expand Up @@ -1856,7 +1943,7 @@ def test_requests_with_resource_embedding_wait_for_schema_cache_reload(defaulten
env = {
**defaultenv,
"PGRST_DB_POOL": "2",
"PGRST_INTERNAL_SCHEMA_CACHE_RELATIONSHIP_LOAD_SLEEP": "5100",
"PGRST_INTERNAL_SCHEMA_CACHE_RELATIONSHIP_LOAD_SLEEP": "5200",
}

with run(env=env, wait_max_seconds=30) as postgrest:
Expand Down
Loading