diff --git a/postgrest.cabal b/postgrest.cabal index afe56c658d..d88a035db2 100644 --- a/postgrest.cabal +++ b/postgrest.cabal @@ -307,6 +307,7 @@ test-suite observability main-is: Main.hs other-modules: ObsHelper Observation.JwtCache + Observation.MetricsSpec build-depends: base >= 4.9 && < 4.20 , base64-bytestring >= 1 && < 1.3 , bytestring >= 0.10.8 && < 0.13 @@ -321,6 +322,7 @@ test-suite observability , postgrest , prometheus-client >= 1.1.1 && < 1.2.0 , protolude >= 0.3.1 && < 0.4 + , text >= 1.2.2 && < 2.2 , wai >= 3.2.1 && < 3.3 ghc-options: -threaded -O0 -Werror -Wall -fwarn-identities -fno-spec-constr -optP-Wno-nonportable-include-path diff --git a/src/PostgREST/AppState.hs b/src/PostgREST/AppState.hs index 867123ee50..bed8e37b3a 100644 --- a/src/PostgREST/AppState.hs +++ b/src/PostgREST/AppState.hs @@ -15,6 +15,7 @@ module PostgREST.AppState , getJwtCacheState , init , initWithPool + , putConfig -- For tests TODO refactoring , putNextListenerDelay , putSchemaCache , putPgVersion diff --git a/src/PostgREST/Observation.hs b/src/PostgREST/Observation.hs index e96c867480..1c0dd1da91 100644 --- a/src/PostgREST/Observation.hs +++ b/src/PostgREST/Observation.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE DeriveGeneric #-} {-| Module : PostgREST.Observation Description : This module holds an Observation type which is the core of Observability for PostgREST. 
@@ -56,6 +57,7 @@ data Observation | JwtCacheEviction | TerminationUnixSignalObs Text | WarpErrorObs Text + deriving (Generic) data ObsFatalError = ServerAuthError | ServerPgrstBug | ServerError42P05 | ServerError08P01 diff --git a/test/observability/Main.hs b/test/observability/Main.hs index 637b632d73..7a38b51c81 100644 --- a/test/observability/Main.hs +++ b/test/observability/Main.hs @@ -15,19 +15,31 @@ import qualified PostgREST.Metrics as Metrics import PostgREST.SchemaCache (querySchemaCache) import qualified Observation.JwtCache +import qualified Observation.MetricsSpec import ObsHelper -import Protolude hiding (toList, toS) +import PostgREST.Observation (Observation (HasqlPoolObs)) +import Protolude hiding (toList, toS) import Test.Hspec main :: IO () main = do + poolChan <- newChan + -- make sure poolChan is not growing indefinitely + -- start a thread that drains the channel + -- this is necessary because test cases operate on + -- copies so poolChan is never read from + -- this means we have another thread running for the entire duration of the spec but this shouldn't be a problem since Haskell green threads are lightweight + void $ forkIO $ forever $ readChan poolChan + metricsState <- Metrics.init (configDbPoolSize testCfg) pool <- P.acquire $ P.settings [ P.size 3 , P.acquisitionTimeout 10 , P.agingTimeout 60 , P.idlenessTimeout 60 , P.staticConnectionSettings (toUtf8 $ configDbUri testCfg) + -- make sure metrics are updated and pool observations published to poolChan + , P.observationHandler $ (writeChan poolChan <> Metrics.observationMetrics metricsState) . HasqlPoolObs ] actualPgVersion <- either (panic . 
show) id <$> P.use pool (queryPgVersion False) @@ -35,19 +47,23 @@ main = do -- cached schema cache so most tests run fast baseSchemaCache <- loadSCache pool testCfg loggerState <- Logger.init - metricsState <- Metrics.init (configDbPoolSize testCfg) let - initApp sCache st config = do - appState <- AppState.initWithPool pool config loggerState metricsState (Metrics.observationMetrics metricsState) + initApp sCache config = do + -- duplicate poolChan as a starting point + obsChan <- dupChan poolChan + stateObsChan <- newObsChan obsChan + appState <- AppState.initWithPool pool config loggerState metricsState (Metrics.observationMetrics metricsState <> writeChan obsChan) AppState.putPgVersion appState actualPgVersion AppState.putSchemaCache appState (Just sCache) - return (st, postgrest (configLogLevel config) appState (pure ())) + return (SpecState appState metricsState stateObsChan, postgrest (configLogLevel config) appState (pure ())) -- Run all test modules hspec $ do - before (initApp baseSchemaCache metricsState testCfgJwtCache) $ + before (initApp baseSchemaCache testCfgJwtCache) $ describe "Observation.JwtCacheObs" Observation.JwtCache.spec + before (initApp baseSchemaCache testCfg) $ + describe "Feature.MetricsSpec" Observation.MetricsSpec.spec where loadSCache pool conf = diff --git a/test/observability/ObsHelper.hs b/test/observability/ObsHelper.hs index 2f70c7e150..cabe8c42de 100644 --- a/test/observability/ObsHelper.hs +++ b/test/observability/ObsHelper.hs @@ -1,32 +1,71 @@ {-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DeriveAnyClass #-} {-# LANGUAGE ExistentialQuantification #-} {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} module ObsHelper where -import qualified Data.ByteString.Base64 as B64 (decodeLenient) -import qualified 
Data.ByteString.Char8 as BS -import qualified Data.ByteString.Lazy as BL -import qualified Jose.Jwa as JWT -import qualified Jose.Jws as JWT -import qualified Jose.Jwt as JWT +import qualified Data.ByteString as BS +import qualified Data.ByteString.Base64 as B64 +import qualified Data.ByteString.Lazy as BL +import qualified Data.List as DL +import Data.List.NonEmpty (fromList) +import Data.String (String) +import qualified Data.Text as T +import qualified Jose.Jwa as JWT +import qualified Jose.Jws as JWT +import qualified Jose.Jwt as JWT +import Network.HTTP.Types +import qualified PostgREST.AppState as AppState +import PostgREST.Config (AppConfig (..), + JSPathExp (..), + LogLevel (..), + OpenAPIMode (..), + Verbosity (..), + parseSecret) +import qualified PostgREST.Metrics as Metrics +import PostgREST.Observation (Observation (..)) +import Prometheus (Counter, getCounter) +import Protolude hiding (get, toS) +import System.Timeout (timeout) +import Test.Hspec +import Test.Hspec.Expectations.Contrib (annotate) + +-- helpers used to produce observation diagnostics in waitForObs +-- Implementing the Show instance for Observation is hard due to having many different parameters so instead we use generic programming (`conName`) to obtain the constructor name as `Text` +class HasConstructor f where + genericConstrName :: f x -> Text + +instance HasConstructor f => HasConstructor (D1 c f) where + genericConstrName (M1 x) = genericConstrName x + +instance (HasConstructor x, HasConstructor y) => HasConstructor (x :+: y) where + genericConstrName (L1 l) = genericConstrName l + genericConstrName (R1 r) = genericConstrName r + +instance Constructor c => HasConstructor (C1 c f) where + genericConstrName = T.pack . 
conName + +data SpecState = SpecState { + specAppState :: AppState.AppState, + specMetrics :: Metrics.MetricsState, + specObsChan :: ObsChan +} -import PostgREST.Config (AppConfig (..), JSPathExp (..), - LogLevel (..), OpenAPIMode (..), - Verbosity (..), parseSecret) +data StateCheck st m = forall a. StateCheck (st -> (String, m a)) (a -> a -> Expectation) -import Data.List.NonEmpty (fromList) -import Data.String (String) -import Prometheus (Counter, getCounter) -import Test.Hspec.Expectations.Contrib (annotate) +data TimeoutException = TimeoutException deriving (Show, Exception) -import Network.HTTP.Types -import Protolude -import Test.Hspec -import Test.Hspec.Wai +data ObsChan = ObsChan (Chan Observation) (Chan Observation) +constrName :: (HasConstructor (Rep a), Generic a)=> a -> Text +constrName = genericConstrName . from baseCfg :: AppConfig baseCfg = let secret = encodeUtf8 "reallyreallyreallyreallyverysafe" in @@ -109,18 +148,12 @@ generateJWT claims = either mempty JWT.unJwt $ JWT.hmacEncode JWT.HS256 generateSecret (BL.toStrict claims) -- state check helpers - -data StateCheck st m = forall a. StateCheck (st -> (String, m a)) (a -> a -> Expectation) - stateCheck :: (Show a, Eq a) => (c -> m a) -> (st -> (String, c)) -> (a -> a) -> StateCheck st m stateCheck extractValue extractComponent expect = StateCheck (second extractValue . extractComponent) (flip shouldBe . expect) expectField :: forall s st a c m. (KnownSymbol s, Show a, Eq a, HasField s st c) => (c -> m a) -> (a -> a) -> StateCheck st m expectField extractValue = stateCheck extractValue ((symbolVal (Proxy @s),) . 
getField @s) -checkState :: (Traversable t) => t (StateCheck st (WaiSession st)) -> WaiSession st b -> WaiSession st () -checkState checks act = getState >>= flip (`checkState'` checks) act - checkState' :: (Traversable t, MonadIO m) => st -> t (StateCheck st m) -> m b -> m () checkState' initialState checks act = do expectations <- traverse (\(StateCheck g expect) -> let (msg, m) = g initialState in m >>= createExpectation msg m . expect) checks @@ -133,3 +166,48 @@ expectCounter :: forall s st m. (KnownSymbol s, HasField s st Counter, MonadIO m expectCounter = expectField @s intCounter where intCounter = ((round @Double @Int) <$>) . getCounter + +accumulateUntilTimeout :: Int -> (s -> a -> s) -> s -> IO a -> IO s +accumulateUntilTimeout t f start act = do + tid <- myThreadId + -- mask to make sure TimeoutException is not thrown before starting the loop + mask $ \unmask -> do + -- start timeout thread unmasking exceptions + ttid <- forkIOWithUnmask ($ (threadDelay t *> throwTo tid TimeoutException)) + -- unmask effect + unmask (fix (\loop accum -> (act >>= loop . 
f accum) `onTimeout` pure accum) start) + -- make sure we catch timeout if it happens before entering the loop + `onTimeout` pure start + -- make sure timer thread is killed on other exceptions + -- so that it won't throw TimeoutException later + `onException` killThread ttid + where + onTimeout m a = m `catch` \TimeoutException -> a + +newObsChan :: Chan Observation -> IO ObsChan +newObsChan = fmap <$> ObsChan <*> dupChan + +-- read messages from copy chan and once condition is met drain original to the same point +-- upon timeout report error and messages remaining in the original chan +-- that way we report messages since last successful read +waitForObs :: HasCallStack => ObsChan -> Int -> Text -> (Observation -> Maybe a) -> IO () +waitForObs (ObsChan orig copy) t msg f = + timeout t (readUntil copy *> readUntil orig) >>= maybe failTimeout mempty + where + failTimeout = takeUntilTimeout decisecond (readChan orig) + >>= expectationFailure . DL.unlines . fmap show . (failureMessageHeader :) . fmap obsDiagMessage + failureMessageHeader = "Timeout waiting for " <> msg <> " at " <> loc <> ". Remaining observations:" + readUntil = void . untilM (pure . not . null . f) . readChan + loc = fromMaybe "(unknown)" . head $ (T.pack . prettySrcLoc . snd <$> getCallStack callStack) + -- execute effectful computation until result meets provided condition + untilM cond m = fix $ \loop -> m >>= \a -> ifM (cond a) (pure a) loop + -- duplicate the provided channel and construct waitFor function binding both channels + -- accumulate effectful computation results into a list for the specified time + takeUntilTimeout t' = fmap reverse . 
accumulateUntilTimeout t' (flip (:)) [] + decisecond = 100000 + +obsDiagMessage :: Observation -> Text +obsDiagMessage = \case + (HasqlPoolObs o) -> show o + o@(DBListenStart host port name channel) -> constrName o <> show (host, port, name, channel) + o -> constrName o diff --git a/test/observability/Observation/JwtCache.hs b/test/observability/Observation/JwtCache.hs index 56e83680a1..2631c82b05 100644 --- a/test/observability/Observation/JwtCache.hs +++ b/test/observability/Observation/JwtCache.hs @@ -1,4 +1,5 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE TypeApplications #-} module Observation.JwtCache where @@ -13,9 +14,11 @@ import PostgREST.Metrics (MetricsState (..)) import Protolude import Test.Hspec.Wai.JSON (json) -spec :: SpecWith (MetricsState, Application) +spec :: SpecWith (SpecState, Application) spec = describe "Server started with JWT and metrics enabled" $ do it "Should not have JWT in cache" $ do + expectCounters <- checkState' . specMetrics <$> getState + let auth = genToken [json|{"exp": 9999999999, "role": "postgrest_test_author", "id": "jdoe1"}|] expectCounters @@ -27,6 +30,8 @@ spec = describe "Server started with JWT and metrics enabled" $ do request methodGet "/authors_only" [auth] "" `shouldRespondWith` 200 it "Should have JWT in cache" $ do + expectCounters <- checkState' . specMetrics <$> getState + let auth = genToken [json|{"exp": 9999999999, "role": "postgrest_test_author", "id": "jdoe2"}|] expectCounters @@ -39,6 +44,8 @@ spec = describe "Server started with JWT and metrics enabled" $ do *> request methodGet "/authors_only" [auth] "" `shouldRespondWith` 200 it "Should not cache invalid JWTs" $ do + expectCounters <- checkState' . 
specMetrics <$> getState + let auth = authHeaderJWT "some random bytes" expectCounters @@ -51,6 +58,8 @@ spec = describe "Server started with JWT and metrics enabled" $ do *> request methodGet "/authors_only" [auth] "" `shouldRespondWith` 401 it "Should cache expired JWTs" $ do + expectCounters <- checkState' . specMetrics <$> getState + let auth = genToken [json|{"exp": 1, "role": "postgrest_test_author", "id": "jdoe2"}|] expectCounters @@ -63,6 +72,8 @@ spec = describe "Server started with JWT and metrics enabled" $ do *> request methodGet "/authors_only" [auth] "" `shouldRespondWith` 401 it "Should evict entries from the JWT cache (jwt cache max is 2)" $ do + expectCounters <- checkState' . specMetrics <$> getState + let jwt1 = genToken [json|{"exp": 9999999999, "role": "postgrest_test_author", "id": "jdoe3"}|] jwt2 = genToken [json|{"exp": 9999999999, "role": "postgrest_test_author", "id": "jdoe4"}|] jwt3 = genToken [json|{"exp": 9999999999, "role": "postgrest_test_author", "id": "jdoe5"}|] @@ -82,6 +93,8 @@ spec = describe "Server started with JWT and metrics enabled" $ do *> request methodGet "/authors_only" [jwt3] "" it "Should not evict entries from the JWT cache in FIFO order" $ do + expectCounters <- checkState' . specMetrics <$> getState + let jwt1 = genToken [json|{"exp": 9999999999, "role": "postgrest_test_author", "id": "jdoe6"}|] jwt2 = genToken [json|{"exp": 9999999999, "role": "postgrest_test_author", "id": "jdoe7"}|] jwt3 = genToken [json|{"exp": 9999999999, "role": "postgrest_test_author", "id": "jdoe8"}|] @@ -108,6 +121,8 @@ spec = describe "Server started with JWT and metrics enabled" $ do -- The test case was added based on coverage report -- showing this scenario was not covered by previous tests it "Should evict entries even though all were hit" $ do + expectCounters <- checkState' . 
specMetrics <$> getState + let jwt1 = genToken [json|{"exp": 9999999999, "role": "postgrest_test_author", "id": "jdoe9"}|] jwt2 = genToken [json|{"exp": 9999999999, "role": "postgrest_test_author", "id": "jdoe10"}|] jwt3 = genToken [json|{"exp": 9999999999, "role": "postgrest_test_author", "id": "jdoe11"}|] @@ -135,4 +150,3 @@ spec = describe "Server started with JWT and metrics enabled" $ do requests = expectCounter @"jwtCacheRequests" hits = expectCounter @"jwtCacheHits" evictions = expectCounter @"jwtCacheEvictions" - expectCounters = checkState diff --git a/test/observability/Observation/MetricsSpec.hs b/test/observability/Observation/MetricsSpec.hs new file mode 100644 index 0000000000..f21f711791 --- /dev/null +++ b/test/observability/Observation/MetricsSpec.hs @@ -0,0 +1,54 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE MonadComprehensions #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE TypeApplications #-} + +module Observation.MetricsSpec where + +import Data.List (lookup) +import Network.Wai (Application) +import ObsHelper +import qualified PostgREST.AppState as AppState +import PostgREST.Config (AppConfig (configDbSchemas)) +import qualified PostgREST.Metrics as Metrics +import PostgREST.Observation +import Prometheus (getCounter, getVectorWith) +import Protolude +import Test.Hspec (SpecWith, describe, it) +import Test.Hspec.Wai (getState) + +spec :: SpecWith (SpecState, Application) +spec = describe "Server started with metrics enabled" $ do + it "Should update pgrst_schema_cache_loads_total[SUCCESS]" $ do + SpecState{specAppState = appState, specMetrics = metrics, specObsChan} <- getState + let waitFor = waitForObs specObsChan + + liftIO $ checkState' metrics [ + schemaCacheLoads "SUCCESS" (+1) + ] $ do + AppState.schemaCacheLoader appState + waitFor (1 * sec) "SchemaCacheLoadedObs" $ \x -> [ o | o@(SchemaCacheLoadedObs{}) <- pure x] + + it "Should update pgrst_schema_cache_loads_total[ERROR]" $ do + 
SpecState{specAppState = appState, specMetrics = metrics, specObsChan} <- getState + let waitFor = waitForObs specObsChan + + liftIO $ checkState' metrics [ + schemaCacheLoads "FAIL" (+1), + schemaCacheLoads "SUCCESS" (+1) + ] $ do + AppState.getConfig appState >>= \prev -> do + AppState.putConfig appState $ prev { configDbSchemas = pure "bad_schema" } + AppState.schemaCacheLoader appState + waitFor (1 * sec) "SchemaCacheErrorObs" $ \x -> [ o | o@(SchemaCacheErrorObs{}) <- pure x] + AppState.putConfig appState prev + + -- wait up to 2 secs so that retry can happen + waitFor (2 * sec) "SchemaCacheLoadedObs" $ \x -> [ o | o@(SchemaCacheLoadedObs{}) <- pure x] + + where + -- prometheus-client api to handle vectors is convoluted + schemaCacheLoads label = expectField @"schemaCacheLoads" $ + fmap (maybe (0::Int) round . lookup label) . (`getVectorWith` getCounter) + sec = 1000000