diff --git a/CHANGELOG.md b/CHANGELOG.md index b5534293b2..a722e2b13e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,6 +47,7 @@ All notable changes to this project will be documented in this file. From versio ### Fixed - Fix leaking table and function names when calculating error hint by @taimoorzaeem in #4675 +- Limit concurrent schema cache loads by @mkleczek in #4643 ## [14.5] - 2026-02-12 diff --git a/postgrest.cabal b/postgrest.cabal index 188653a410..3612efe996 100644 --- a/postgrest.cabal +++ b/postgrest.cabal @@ -159,6 +159,7 @@ library , stm-hamt >= 1.2 && < 2 , focus >= 1.0 && < 2 , some >= 1.0.4.1 && < 2 + , uuid >= 1.3 && < 2 -- -fno-spec-constr may help keep compile time memory use in check, -- see https://gitlab.haskell.org/ghc/ghc/issues/16017#note_219304 -- -optP-Wno-nonportable-include-path diff --git a/src/PostgREST/AppState.hs b/src/PostgREST/AppState.hs index 3f60ef04a6..65112b6e81 100644 --- a/src/PostgREST/AppState.hs +++ b/src/PostgREST/AppState.hs @@ -1,7 +1,9 @@ -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE NamedFieldPuns #-} -{-# LANGUAGE RecordWildCards #-} -{-# LANGUAGE RecursiveDo #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE RecursiveDo #-} +{-# LANGUAGE TypeApplications #-} module PostgREST.AppState ( AppState @@ -33,7 +35,8 @@ import qualified Data.ByteString.Char8 as BS import Data.Either.Combinators (whenLeft) import qualified Hasql.Pool as SQL import qualified Hasql.Pool.Config as SQL -import qualified Hasql.Session as SQL +import qualified Hasql.Session as SQL hiding (statement) +import qualified Hasql.Transaction as SQL hiding (sql) import qualified Hasql.Transaction.Sessions as SQL import qualified Network.HTTP.Types.Status as HTTP import qualified PostgREST.Auth.JwtCache as JwtCache @@ -62,11 +65,17 @@ import PostgREST.Config.Database (queryDbSettings, queryRoleSettings) import PostgREST.Config.PgVersion (PgVersion (..), 
minimumPgVersion) +import PostgREST.Metrics (MetricsState (connTrack)) import PostgREST.SchemaCache (SchemaCache (..), querySchemaCache, showSummary) import PostgREST.SchemaCache.Identifiers (quoteQi) +import qualified Hasql.Decoders as HD +import qualified Hasql.Encoders as HE +import qualified Hasql.Statement as SQL +import NeatInterpolation (trimming) + import Protolude data AppState = AppState @@ -309,7 +318,7 @@ getObserver = stateObserver -- + Because connections cache the pg catalog(see #2620) -- + For rapid recovery. Otherwise, the pool idle or lifetime timeout would have to be reached for new healthy connections to be acquired. retryingSchemaCacheLoad :: AppState -> IO () -retryingSchemaCacheLoad appState@AppState{stateObserver=observer, stateMainThreadId=mainThreadId} = +retryingSchemaCacheLoad appState@AppState{stateObserver=observer, stateMainThreadId=mainThreadId, stateMetrics} = void $ retrying retryPolicy shouldRetry (\RetryStatus{rsIterNumber, rsPreviousDelay} -> do when (rsIterNumber > 0) $ do let delay = fromMaybe 0 rsPreviousDelay `div` oneSecondInUs @@ -350,9 +359,23 @@ retryingSchemaCacheLoad appState@AppState{stateObserver=observer, stateMainThrea qSchemaCache :: IO (Maybe SchemaCache) qSchemaCache = do conf@AppConfig{..} <- getConfig appState + -- Throttle concurrent schema cache loads, guarded by advisory locks. + -- This is to prevent thundering herd problem on startup or when many PostgREST + -- instances receive "reload schema" notifications at the same time + -- See get_lock_sql for details of the algorithm. + -- Here we calculate the number of open connections passed to the query. 
+ Metrics.ConnStats connected inUse <- Metrics.connectionCounts $ connTrack stateMetrics + -- Determine whether schema cache loading will create a new session + let + -- if all connections in use but pool not full - schema cache loading will create session + scLoadingSessions = if connected <= inUse && inUse < configDbPoolSize then 1 else 0 + withTxLock = SQL.statement + (fromIntegral $ connected + scLoadingSessions) + (SQL.Statement get_lock_sql get_lock_params HD.noResult configDbPreparedStatements) + (resultTime, result) <- let transaction = if configDbPreparedStatements then SQL.transaction else SQL.unpreparedTransaction in - timeItT $ usePool appState (transaction SQL.ReadCommitted SQL.Read $ querySchemaCache conf) + timeItT $ usePool appState (transaction SQL.ReadCommitted SQL.Read $ withTxLock *> querySchemaCache conf) case result of Left e -> do markSchemaCachePending appState @@ -369,6 +392,43 @@ retryingSchemaCacheLoad appState@AppState{stateObserver=observer, stateMainThrea observer . uncurry SchemaCacheLoadedObs =<< timeItT (evaluate $ showSummary sCache) markSchemaCacheLoaded appState return $ Just sCache + where + -- Recursive query that tries acquiring locks in order + -- and waits for randomly selected lock if no attempt succeeded. + -- It has a single parameter: this node open connection count. + -- It is used to estimate the number of nodes + -- by counting the number of active sessions for current session_user + -- and dividing it by this node open connections. + -- Assuming load is uniform among cluster nodes, all should have + -- statistically the same number of open connections. 
+ -- Once the number of nodes is known we calculate the number + -- of locks as round(log(2, number_of_nodes)) + get_lock_sql = encodeUtf8 [trimming| + WITH RECURSIVE attempts AS ( + SELECT 1 AS lock_number, pg_try_advisory_xact_lock(lock_id, 1) AS success FROM parameters + UNION ALL + SELECT next_lock_number AS lock_number, pg_try_advisory_xact_lock(lock_id, next_lock_number) AS success + FROM + parameters CROSS JOIN LATERAL ( + SELECT lock_number + 1 AS next_lock_number FROM attempts + WHERE NOT success AND lock_number < locks_count + ORDER BY lock_number DESC + LIMIT 1 + ) AS previous_attempt + ), + counts AS ( + SELECT round(log(2, round(count(*)::double precision/$$1)::numeric))::int AS locks_count + FROM + pg_stat_activity WHERE usename = SESSION_USER + ), + parameters AS ( + SELECT locks_count, 50168275 AS lock_id FROM counts WHERE locks_count > 0 + ) + SELECT pg_advisory_xact_lock(lock_id, floor(random() * locks_count)::int + 1) + FROM + parameters WHERE NOT EXISTS (SELECT 1 FROM attempts WHERE success) |] + + get_lock_params = HE.param (HE.nonNullable HE.int4) shouldRetry :: RetryStatus -> (Maybe PgVersion, Maybe SchemaCache) -> IO Bool shouldRetry _ (pgVer, sCache) = do diff --git a/src/PostgREST/Metrics.hs b/src/PostgREST/Metrics.hs index 6261cbb1a7..f008f2edc7 100644 --- a/src/PostgREST/Metrics.hs +++ b/src/PostgREST/Metrics.hs @@ -5,7 +5,10 @@ Description : Metrics based on the Observation module. See Observation.hs. -} module PostgREST.Metrics ( init + , ConnTrack + , ConnStats (..) , MetricsState (..)
+ , connectionCounts , observationMetrics , metricsToText ) where @@ -17,12 +20,18 @@ import Prometheus import PostgREST.Observation -import Protolude +import Control.Arrow ((&&&)) +import Data.Bitraversable (bisequenceA) +import Data.Tuple.Extra (both) +import Data.UUID (UUID) +import qualified Focus +import Protolude +import qualified StmHamt.SizedHamt as SH data MetricsState = MetricsState { poolTimeouts :: Counter, - poolAvailable :: Gauge, + connTrack :: ConnTrack, poolWaiting :: Gauge, poolMaxSize :: Gauge, schemaCacheLoads :: Vector Label1 Counter, @@ -36,7 +45,7 @@ init :: Int -> IO MetricsState init configDbPoolSize = do metricState <- MetricsState <$> register (counter (Info "pgrst_db_pool_timeouts_total" "The total number of pool connection timeouts")) <*> - register (gauge (Info "pgrst_db_pool_available" "Available connections in the pool")) <*> + register (Metric ((identity &&& dbPoolAvailable) <$> connectionTracker)) <*> register (gauge (Info "pgrst_db_pool_waiting" "Requests waiting to acquire a pool connection")) <*> register (gauge (Info "pgrst_db_pool_max" "Max pool connections")) <*> register (vector "status" $ counter (Info "pgrst_schema_cache_loads_total" "The total number of times the schema cache was loaded")) <*> @@ -46,20 +55,19 @@ init configDbPoolSize = do register (counter (Info "pgrst_jwt_cache_evictions_total" "The total number of JWT cache evictions")) setGauge (poolMaxSize metricState) (fromIntegral configDbPoolSize) pure metricState + where + dbPoolAvailable = (pure . noLabelsGroup (Info "pgrst_db_pool_available" "Available connections in the pool") GaugeType . calcAvailable <$>) . connectionCounts + where + calcAvailable = (configDbPoolSize -) . inUse + toSample name labels = Sample name labels . encodeUtf8 . show + noLabelsGroup info sampleType = SampleGroup info sampleType . pure . 
toSample (metricName info) mempty -- Only some observations are used as metrics observationMetrics :: MetricsState -> ObservationHandler observationMetrics MetricsState{..} obs = case obs of PoolAcqTimeoutObs -> do incCounter poolTimeouts - (HasqlPoolObs (SQL.ConnectionObservation _ status)) -> case status of - SQL.ReadyForUseConnectionStatus -> do - incGauge poolAvailable - SQL.InUseConnectionStatus -> do - decGauge poolAvailable - SQL.TerminatedConnectionStatus _ -> do - decGauge poolAvailable - SQL.ConnectingConnectionStatus -> pure () + (HasqlPoolObs sqlObs) -> trackConnections connTrack sqlObs PoolRequest -> incGauge poolWaiting PoolRequestFullfilled -> @@ -77,3 +85,28 @@ observationMetrics MetricsState{..} obs = case obs of metricsToText :: IO LBS.ByteString metricsToText = exportMetricsAsText + +data ConnStats = ConnStats { + connected :: Int, + inUse :: Int +} deriving (Eq, Show) + +data ConnTrack = ConnTrack { connTrackConnected :: SH.SizedHamt UUID, connTrackInUse :: SH.SizedHamt UUID } + +connectionTracker :: IO ConnTrack +connectionTracker = ConnTrack <$> SH.newIO <*> SH.newIO + +trackConnections :: ConnTrack -> SQL.Observation -> IO () +trackConnections ConnTrack{..} (SQL.ConnectionObservation uuid status) = case status of + SQL.ReadyForUseConnectionStatus -> atomically $ + SH.insert identity uuid connTrackConnected *> + SH.focus Focus.delete identity uuid connTrackInUse + SQL.TerminatedConnectionStatus _ -> atomically $ + SH.focus Focus.delete identity uuid connTrackConnected *> + SH.focus Focus.delete identity uuid connTrackInUse + SQL.InUseConnectionStatus -> atomically $ + SH.insert identity uuid connTrackInUse + _ -> mempty + +connectionCounts :: ConnTrack -> IO ConnStats +connectionCounts = atomically . fmap (uncurry ConnStats) . bisequenceA . both SH.size . 
(connTrackConnected &&& connTrackInUse) diff --git a/test/io/test_io.py b/test/io/test_io.py index 86310cd4be..3293bad56e 100644 --- a/test/io/test_io.py +++ b/test/io/test_io.py @@ -1,5 +1,6 @@ "Unit tests for Input/Ouput of PostgREST seen as a black box." +import contextlib import os import re import signal @@ -18,6 +19,7 @@ sleep_until_postgrest_full_reload, sleep_until_postgrest_scache_reload, wait_until_exit, + wait_until_status_code, ) @@ -1219,6 +1221,91 @@ def test_schema_cache_concurrent_notifications(slow_schema_cache_env): assert response.status_code == 200 +@pytest.mark.parametrize( + "instance_count, expected_concurrency", [(2, 2), (4, 3), (6, 4), (8, 4), (16, 5)] +) +def test_schema_cache_reload_throttled_with_advisory_locks( + instance_count, expected_concurrency, slow_schema_cache_env +): + "schema cache reloads should be throttled across instances using advisory locks" + + internal_sleep_ms = int( + slow_schema_cache_env["PGRST_INTERNAL_SCHEMA_CACHE_QUERY_SLEEP"] + ) + lock_wait_threshold_ms = internal_sleep_ms * 2 + query_log_pattern = re.compile(r"Schema cache queried in ([\d.]+) milliseconds") + + def read_available_output_lines(postgrest): + try: + output = postgrest.process.stdout.read() + except BlockingIOError: + return [] + + if not output: + return [] + return output.decode().splitlines() + + with contextlib.ExitStack() as stack: + instances = [ + stack.enter_context( + run( + env=slow_schema_cache_env, + wait_for_readiness=False, + wait_max_seconds=10, + ) + ) + for _ in range(instance_count) + ] + + for postgrest in instances: + wait_until_status_code( + postgrest.admin.baseurl + "/ready", max_seconds=10, status_code=200 + ) + + # Drop startup logs so only reload logs are parsed. + for postgrest in instances: + read_available_output_lines(postgrest) + + response = instances[0].session.get("/rpc/notify_pgrst") + assert response.status_code == 204 + + # Wait long enough for the lock-throttled cache reloads to finish.
+ time.sleep((internal_sleep_ms / 1000) * 2) + + reload_durations_ms = [] + for postgrest in instances: + output_lines = [] + for _ in range(instance_count * 2): + output_lines.extend(read_available_output_lines(postgrest)) + if any(query_log_pattern.search(line) for line in output_lines): + break + time.sleep(0.2) + + durations = [] + for line in output_lines: + match = query_log_pattern.search(line) + if match: + durations.append(float(match.group(1))) + + assert durations + reload_durations_ms.append(max(durations)) + + assert len(reload_durations_ms) == instance_count + + # expected_concurrency instances should be fast, remaining instances should be slow + assert ( + instance_count + - len( + [ + duration + for duration in reload_durations_ms + if duration > lock_wait_threshold_ms + ] + ) + == expected_concurrency + ) + + def test_schema_cache_query_sleep_logs(defaultenv): """Schema cache sleep should be reflected in the logged query duration.""" @@ -1856,7 +1943,7 @@ def test_requests_with_resource_embedding_wait_for_schema_cache_reload(defaulten env = { **defaultenv, "PGRST_DB_POOL": "2", - "PGRST_INTERNAL_SCHEMA_CACHE_RELATIONSHIP_LOAD_SLEEP": "5100", + "PGRST_INTERNAL_SCHEMA_CACHE_RELATIONSHIP_LOAD_SLEEP": "5200", } with run(env=env, wait_max_seconds=30) as postgrest: diff --git a/test/observability/Observation/MetricsSpec.hs b/test/observability/Observation/MetricsSpec.hs index 524e0c1018..70bee440c2 100644 --- a/test/observability/Observation/MetricsSpec.hs +++ b/test/observability/Observation/MetricsSpec.hs @@ -6,17 +6,20 @@ module Observation.MetricsSpec where -import Data.List (lookup) -import Network.Wai (Application) +import Data.List (lookup) +import qualified Hasql.Pool.Observation as SQL +import Network.Wai (Application) import ObsHelper -import qualified PostgREST.AppState as AppState -import PostgREST.Config (AppConfig (configDbSchemas)) -import qualified PostgREST.Metrics as Metrics +import qualified PostgREST.AppState as AppState +import PostgREST.Config
(AppConfig (configDbSchemas)) +import PostgREST.Metrics (ConnStats (..), + MetricsState (..), + connectionCounts) import PostgREST.Observation -import Prometheus (getCounter, getVectorWith) +import Prometheus (getCounter, getVectorWith) import Protolude -import Test.Hspec (SpecWith, describe, it) -import Test.Hspec.Wai (getState) +import Test.Hspec (SpecWith, describe, it) +import Test.Hspec.Wai (getState) spec :: SpecWith (SpecState, Application) spec = describe "Server started with metrics enabled" $ do @@ -71,9 +74,33 @@ spec = describe "Server started with metrics enabled" $ do -- (there should be none but we need to verify that) threadDelay $ 1 * sec + it "Should track in use connections" $ do + SpecState{specAppState = appState, specMetrics = metrics, specObsChan} <- getState + let waitFor = waitForObs specObsChan + + liftIO $ checkState' metrics [ + -- we expect in use connections to be the same once finished + inUseConnections (+ 0) + ] $ do + signal <- newEmptyMVar + -- make sure waiting thread is signaled + bracket_ (pure ()) (putMVar signal ()) $ + -- expecting one more connection in use + checkState' metrics [ + inUseConnections (+ 1) + ] $ do + -- start a thread hanging on a single connection until signaled + void $ forkIO $ void $ AppState.usePool appState $ liftIO (readMVar signal) + -- main thread waits for ConnectionObservation with InUseConnectionStatus + -- after which used connections count should be incremented + waitFor (1 * sec) "InUseConnectionStatus" $ \x -> [ o | o@(HasqlPoolObs (SQL.ConnectionObservation _ SQL.InUseConnectionStatus)) <- pure x] + + -- hanging thread was signaled and should return the connection + waitFor (1 * sec) "ReadyForUseConnectionStatus" $ \x -> [ o | o@(HasqlPoolObs (SQL.ConnectionObservation _ SQL.ReadyForUseConnectionStatus)) <- pure x] where -- prometheus-client api to handle vectors is convoluted schemaCacheLoads label = expectField @"schemaCacheLoads" $ fmap (maybe (0::Int) round . lookup label) . 
(`getVectorWith` getCounter) + inUseConnections = expectField @"connTrack" ((inUse <$>) . connectionCounts) sec = 1000000