Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions postgrest.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ test-suite observability
main-is: Main.hs
other-modules: ObsHelper
Observation.JwtCache
Observation.MetricsSpec
build-depends: base >= 4.9 && < 4.20
, base64-bytestring >= 1 && < 1.3
, bytestring >= 0.10.8 && < 0.13
Expand All @@ -321,6 +322,7 @@ test-suite observability
, postgrest
, prometheus-client >= 1.1.1 && < 1.2.0
, protolude >= 0.3.1 && < 0.4
, text >= 1.2.2 && < 2.2
, wai >= 3.2.1 && < 3.3
ghc-options: -threaded -O0 -Werror -Wall -fwarn-identities
-fno-spec-constr -optP-Wno-nonportable-include-path
Expand Down
1 change: 1 addition & 0 deletions src/PostgREST/AppState.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ module PostgREST.AppState
, getJwtCacheState
, init
, initWithPool
, putConfig -- For tests TODO refactoring
, putNextListenerDelay
, putSchemaCache
, putPgVersion
Expand Down
2 changes: 2 additions & 0 deletions src/PostgREST/Observation.hs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
{-# LANGUAGE DeriveGeneric #-}
{-|
Module : PostgREST.Observation
Description : This module holds an Observation type which is the core of Observability for PostgREST.
Expand Down Expand Up @@ -56,6 +57,7 @@ data Observation
| JwtCacheEviction
| TerminationUnixSignalObs Text
| WarpErrorObs Text
deriving (Generic)

data ObsFatalError = ServerAuthError | ServerPgrstBug | ServerError42P05 | ServerError08P01

Expand Down
27 changes: 21 additions & 6 deletions test/observability/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,39 +15,54 @@ import qualified PostgREST.Metrics as Metrics
import PostgREST.SchemaCache (querySchemaCache)

import qualified Observation.JwtCache
import qualified Observation.MetricsSpec

import ObsHelper
import Protolude hiding (toList, toS)
import PostgREST.Observation (Observation (HasqlPoolObs))
import Protolude hiding (toList, toS)
import Test.Hspec

main :: IO ()
main = do
poolChan <- newChan
-- make sure poolChan is not growing indefinitely
-- start a thread that drains the channel
-- this is necessary because test cases operate on
-- copies so poolChan is never read from
void $ forkIO $ forever $ readChan poolChan
metricsState <- Metrics.init (configDbPoolSize testCfg)
pool <- P.acquire $ P.settings
[ P.size 3
, P.acquisitionTimeout 10
, P.agingTimeout 60
, P.idlenessTimeout 60
, P.staticConnectionSettings (toUtf8 $ configDbUri testCfg)
-- make sure metrics are updated and pool observations published to poolChan
, P.observationHandler $ (writeChan poolChan <> Metrics.observationMetrics metricsState) . HasqlPoolObs
]

actualPgVersion <- either (panic . show) id <$> P.use pool (queryPgVersion False)

-- cached schema cache so most tests run fast
baseSchemaCache <- loadSCache pool testCfg
loggerState <- Logger.init
metricsState <- Metrics.init (configDbPoolSize testCfg)

let
initApp sCache st config = do
appState <- AppState.initWithPool pool config loggerState metricsState (Metrics.observationMetrics metricsState)
initApp sCache config = do
-- duplicate poolChan as a starting point
obsChan <- dupChan poolChan
stateObsChan <- newObsChan obsChan
appState <- AppState.initWithPool pool config loggerState metricsState (Metrics.observationMetrics metricsState <> writeChan obsChan)
AppState.putPgVersion appState actualPgVersion
AppState.putSchemaCache appState (Just sCache)
return (st, postgrest (configLogLevel config) appState (pure ()))
return (SpecState appState metricsState stateObsChan, postgrest (configLogLevel config) appState (pure ()))

-- Run all test modules
hspec $ do
before (initApp baseSchemaCache metricsState testCfgJwtCache) $
before (initApp baseSchemaCache testCfgJwtCache) $
describe "Observation.JwtCacheObs" Observation.JwtCache.spec
before (initApp baseSchemaCache testCfg) $
describe "Feature.MetricsSpec" Observation.MetricsSpec.spec

where
loadSCache pool conf =
Expand Down
123 changes: 100 additions & 23 deletions test/observability/ObsHelper.hs
Original file line number Diff line number Diff line change
@@ -1,32 +1,70 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
module ObsHelper where

import qualified Data.ByteString.Base64 as B64 (decodeLenient)
import qualified Data.ByteString.Char8 as BS
import qualified Data.ByteString.Lazy as BL
import qualified Jose.Jwa as JWT
import qualified Jose.Jws as JWT
import qualified Jose.Jwt as JWT
import qualified Data.ByteString as BS
import qualified Data.ByteString.Base64 as B64
import qualified Data.ByteString.Lazy as BL
import qualified Data.List as DL
import Data.List.NonEmpty (fromList)
import Data.String (String)
import qualified Data.Text as T
import qualified Jose.Jwa as JWT
import qualified Jose.Jws as JWT
import qualified Jose.Jwt as JWT
import Network.HTTP.Types
import qualified PostgREST.AppState as AppState
import PostgREST.Config (AppConfig (..),
JSPathExp (..),
LogLevel (..),
OpenAPIMode (..),
Verbosity (..),
parseSecret)
import qualified PostgREST.Metrics as Metrics
import PostgREST.Observation (Observation (..))
import Prometheus (Counter, getCounter)
import Protolude hiding (get, toS)
import System.Timeout (timeout)
import Test.Hspec
import Test.Hspec.Expectations.Contrib (annotate)

-- helpers used to produce observation diagnostics in waitForObs
class HasConstructor f where
genericConstrName :: f x -> Text

instance HasConstructor f => HasConstructor (D1 c f) where
genericConstrName (M1 x) = genericConstrName x

instance (HasConstructor x, HasConstructor y) => HasConstructor (x :+: y) where
genericConstrName (L1 l) = genericConstrName l
genericConstrName (R1 r) = genericConstrName r

instance Constructor c => HasConstructor (C1 c f) where
genericConstrName = T.pack . conName

data SpecState = SpecState {
specAppState :: AppState.AppState,
specMetrics :: Metrics.MetricsState,
specObsChan :: ObsChan
}

import PostgREST.Config (AppConfig (..), JSPathExp (..),
LogLevel (..), OpenAPIMode (..),
Verbosity (..), parseSecret)
data StateCheck st m = forall a. StateCheck (st -> (String, m a)) (a -> a -> Expectation)

import Data.List.NonEmpty (fromList)
import Data.String (String)
import Prometheus (Counter, getCounter)
import Test.Hspec.Expectations.Contrib (annotate)
data TimeoutException = TimeoutException deriving (Show, Exception)

import Network.HTTP.Types
import Protolude
import Test.Hspec
import Test.Hspec.Wai
data ObsChan = ObsChan (Chan Observation) (Chan Observation)

constrName :: (HasConstructor (Rep a), Generic a)=> a -> Text
constrName = genericConstrName . from

baseCfg :: AppConfig
baseCfg = let secret = encodeUtf8 "reallyreallyreallyreallyverysafe" in
Expand Down Expand Up @@ -109,18 +147,12 @@ generateJWT claims =
either mempty JWT.unJwt $ JWT.hmacEncode JWT.HS256 generateSecret (BL.toStrict claims)

-- state check helpers

data StateCheck st m = forall a. StateCheck (st -> (String, m a)) (a -> a -> Expectation)

stateCheck :: (Show a, Eq a) => (c -> m a) -> (st -> (String, c)) -> (a -> a) -> StateCheck st m
stateCheck extractValue extractComponent expect = StateCheck (second extractValue . extractComponent) (flip shouldBe . expect)

expectField :: forall s st a c m. (KnownSymbol s, Show a, Eq a, HasField s st c) => (c -> m a) -> (a -> a) -> StateCheck st m
expectField extractValue = stateCheck extractValue ((symbolVal (Proxy @s),) . getField @s)

checkState :: (Traversable t) => t (StateCheck st (WaiSession st)) -> WaiSession st b -> WaiSession st ()
checkState checks act = getState >>= flip (`checkState'` checks) act

checkState' :: (Traversable t, MonadIO m) => st -> t (StateCheck st m) -> m b -> m ()
checkState' initialState checks act = do
expectations <- traverse (\(StateCheck g expect) -> let (msg, m) = g initialState in m >>= createExpectation msg m . expect) checks
Expand All @@ -133,3 +165,48 @@ expectCounter :: forall s st m. (KnownSymbol s, HasField s st Counter, MonadIO m
expectCounter = expectField @s intCounter
where
intCounter = ((round @Double @Int) <$>) . getCounter

accumulateUntilTimeout :: Int -> (s -> a -> s) -> s -> IO a -> IO s
accumulateUntilTimeout t f start act = do
tid <- myThreadId
-- mask to make sure TimeoutException is not thrown before starting the loop
mask $ \unmask -> do
-- start timeout thread unmasking exceptions
ttid <- forkIOWithUnmask ($ (threadDelay t *> throwTo tid TimeoutException))
-- unmask effect
unmask (fix (\loop accum -> (act >>= loop . f accum) `onTimeout` pure accum) start)
-- make sure we catch timeout if happens before entering the loop
`onTimeout` pure start
-- make sure timer thread is killed on other exceptions
-- so that it won't throw TimeoutException later
`onException` killThread ttid
where
onTimeout m a = m `catch` \TimeoutException -> a

newObsChan :: Chan Observation -> IO ObsChan
newObsChan = fmap <$> ObsChan <*> dupChan

-- read messages from copy chan and once condition is met drain original to the same point
-- upon timeout report error and messages remaining in the original chan
-- that way we report messages since last successful read
waitForObs :: HasCallStack => ObsChan -> Int -> Text -> (Observation -> Maybe a) -> IO ()
waitForObs (ObsChan orig copy) t msg f =
timeout t (readUntil copy *> readUntil orig) >>= maybe failTimeout mempty
where
failTimeout = takeUntilTimeout decisecond (readChan orig)
>>= expectationFailure . DL.unlines . fmap show . (failureMessageHeader :) . fmap obsDiagMessage
failureMessageHeader = "Timeout waiting for " <> msg <> " at " <> loc <> ". Remaining observations:"
readUntil = void . untilM (pure . not . null . f) . readChan
loc = fromMaybe "(unknown)" . head $ (T.pack . prettySrcLoc . snd <$> getCallStack callStack)
-- execute effectful computation until result meets provided condition
untilM cond m = fix $ \loop -> m >>= \a -> ifM (cond a) (pure a) loop
-- duplicate the provided channel and construct wairFor function binding both channels
-- accumulate effecful computation results into a list for specified time
takeUntilTimeout t' = fmap reverse . accumulateUntilTimeout t' (flip (:)) []
decisecond = 100000

obsDiagMessage :: Observation -> Text
obsDiagMessage = \case
(HasqlPoolObs o) -> show o
o@(DBListenStart host port name channel) -> constrName o <> show (host, port, name, channel)
o -> constrName o
18 changes: 16 additions & 2 deletions test/observability/Observation/JwtCache.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE TypeApplications #-}
module Observation.JwtCache where

Expand All @@ -13,9 +14,11 @@ import PostgREST.Metrics (MetricsState (..))
import Protolude
import Test.Hspec.Wai.JSON (json)

spec :: SpecWith (MetricsState, Application)
spec :: SpecWith (SpecState, Application)
spec = describe "Server started with JWT and metrics enabled" $ do
it "Should not have JWT in cache" $ do
expectCounters <- checkState' . specMetrics <$> getState

let auth = genToken [json|{"exp": 9999999999, "role": "postgrest_test_author", "id": "jdoe1"}|]

expectCounters
Expand All @@ -27,6 +30,8 @@ spec = describe "Server started with JWT and metrics enabled" $ do
request methodGet "/authors_only" [auth] "" `shouldRespondWith` 200

it "Should have JWT in cache" $ do
expectCounters <- checkState' . specMetrics <$> getState

let auth = genToken [json|{"exp": 9999999999, "role": "postgrest_test_author", "id": "jdoe2"}|]

expectCounters
Expand All @@ -39,6 +44,8 @@ spec = describe "Server started with JWT and metrics enabled" $ do
*> request methodGet "/authors_only" [auth] "" `shouldRespondWith` 200

it "Should not cache invalid JWTs" $ do
expectCounters <- checkState' . specMetrics <$> getState

let auth = authHeaderJWT "some random bytes"

expectCounters
Expand All @@ -51,6 +58,8 @@ spec = describe "Server started with JWT and metrics enabled" $ do
*> request methodGet "/authors_only" [auth] "" `shouldRespondWith` 401

it "Should cache expired JWTs" $ do
expectCounters <- checkState' . specMetrics <$> getState

let auth = genToken [json|{"exp": 1, "role": "postgrest_test_author", "id": "jdoe2"}|]

expectCounters
Expand All @@ -63,6 +72,8 @@ spec = describe "Server started with JWT and metrics enabled" $ do
*> request methodGet "/authors_only" [auth] "" `shouldRespondWith` 401

it "Should evict entries from the JWT cache (jwt cache max is 2)" $ do
expectCounters <- checkState' . specMetrics <$> getState

let jwt1 = genToken [json|{"exp": 9999999999, "role": "postgrest_test_author", "id": "jdoe3"}|]
jwt2 = genToken [json|{"exp": 9999999999, "role": "postgrest_test_author", "id": "jdoe4"}|]
jwt3 = genToken [json|{"exp": 9999999999, "role": "postgrest_test_author", "id": "jdoe5"}|]
Expand All @@ -82,6 +93,8 @@ spec = describe "Server started with JWT and metrics enabled" $ do
*> request methodGet "/authors_only" [jwt3] ""

it "Should not evict entries from the JWT cache in FIFO order" $ do
expectCounters <- checkState' . specMetrics <$> getState

let jwt1 = genToken [json|{"exp": 9999999999, "role": "postgrest_test_author", "id": "jdoe6"}|]
jwt2 = genToken [json|{"exp": 9999999999, "role": "postgrest_test_author", "id": "jdoe7"}|]
jwt3 = genToken [json|{"exp": 9999999999, "role": "postgrest_test_author", "id": "jdoe8"}|]
Expand All @@ -108,6 +121,8 @@ spec = describe "Server started with JWT and metrics enabled" $ do
-- The test case was added based on coverage report
-- showing this scenario was not covered by previous tests
it "Should evict entries even though all were hit" $ do
expectCounters <- checkState' . specMetrics <$> getState

let jwt1 = genToken [json|{"exp": 9999999999, "role": "postgrest_test_author", "id": "jdoe9"}|]
jwt2 = genToken [json|{"exp": 9999999999, "role": "postgrest_test_author", "id": "jdoe10"}|]
jwt3 = genToken [json|{"exp": 9999999999, "role": "postgrest_test_author", "id": "jdoe11"}|]
Expand Down Expand Up @@ -135,4 +150,3 @@ spec = describe "Server started with JWT and metrics enabled" $ do
requests = expectCounter @"jwtCacheRequests"
hits = expectCounter @"jwtCacheHits"
evictions = expectCounter @"jwtCacheEvictions"
expectCounters = checkState
54 changes: 54 additions & 0 deletions test/observability/Observation/MetricsSpec.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MonadComprehensions #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE TypeApplications #-}

module Observation.MetricsSpec where

import Data.List (lookup)
import Network.Wai (Application)
import ObsHelper
import qualified PostgREST.AppState as AppState
import PostgREST.Config (AppConfig (configDbSchemas))
import qualified PostgREST.Metrics as Metrics
import PostgREST.Observation
import Prometheus (getCounter, getVectorWith)
import Protolude
import Test.Hspec (SpecWith, describe, it)
import Test.Hspec.Wai (getState)

spec :: SpecWith (SpecState, Application)
spec = describe "Server started with metrics enabled" $ do
it "Should update pgrst_schema_cache_loads_total[SUCCESS]" $ do
SpecState{specAppState = appState, specMetrics = metrics, specObsChan} <- getState
let waitFor = waitForObs specObsChan

liftIO $ checkState' metrics [
schemaCacheLoads "SUCCESS" (+1)
] $ do
AppState.schemaCacheLoader appState
waitFor (1 * sec) "SchemaCacheLoadedObs" $ \x -> [ o | o@(SchemaCacheLoadedObs{}) <- pure x]

it "Should update pgrst_schema_cache_loads_total[ERROR]" $ do
SpecState{specAppState = appState, specMetrics = metrics, specObsChan} <- getState
let waitFor = waitForObs specObsChan

liftIO $ checkState' metrics [
schemaCacheLoads "FAIL" (+1),
schemaCacheLoads "SUCCESS" (+1)
] $ do
AppState.getConfig appState >>= \prev -> do
AppState.putConfig appState $ prev { configDbSchemas = pure "bad_schema" }
AppState.schemaCacheLoader appState
waitFor (1 * sec) "SchemaCacheErrorObs" $ \x -> [ o | o@(SchemaCacheErrorObs{}) <- pure x]
AppState.putConfig appState prev

-- wait up to 2 secs so that retry can happen
waitFor (2 * sec) "SchemaCacheLoadedObs" $ \x -> [ o | o@(SchemaCacheLoadedObs{}) <- pure x]

where
-- prometheus-client api to handle vectors is convoluted
schemaCacheLoads label = expectField @"schemaCacheLoads" $
fmap (maybe (0::Int) round . lookup label) . (`getVectorWith` getCounter)
sec = 1000000