From f27488a8cb1f5fc88b84094d41f1de08ac3cd1a4 Mon Sep 17 00:00:00 2001 From: Maksim Koltsov Date: Wed, 27 May 2026 07:44:14 +0000 Subject: [PATCH] Export authenticateOIDC and decouple Permit from OIDCConfig - Extract OIDC authentication pipeline into a public `authenticateOIDC` function so custom combinators can reuse it without duplicating logic. - Add `rolesVaultKey` to vault infrastructure; `authenticateOIDC` writes roles there and `Permit` reads from it, removing the `HasContextEntry context OIDCConfig` constraint on `Permit`. - Add `oidcRoles` field to `OIDCUser` carrying the authenticated principal's roles. - Expose `Web.Template.Log` as a public module so downstream combinators can import vault keys directly. Co-Authored-By: Claude Sonnet 4.6 --- src/Web/Template/Log.hs | 7 + src/Web/Template/Servant/Auth.hs | 334 ++++++++++++++++++------------- web-template.cabal | 2 +- 3 files changed, 203 insertions(+), 140 deletions(-) diff --git a/src/Web/Template/Log.hs b/src/Web/Template/Log.hs index 3af4c68..87d757e 100644 --- a/src/Web/Template/Log.hs +++ b/src/Web/Template/Log.hs @@ -15,6 +15,7 @@ module Web.Template.Log , userIdVaultKey , tokenVaultKey , pTokenVaultKey + , rolesVaultKey ) where import Control.Monad (forM_, when) @@ -54,6 +55,10 @@ pTokenVaultKey :: Key (IORef (Maybe ClaimsSet)) pTokenVaultKey = unsafePerformIO newKey {-# NOINLINE pTokenVaultKey #-} +rolesVaultKey :: Key (IORef (Maybe [Text])) +rolesVaultKey = unsafePerformIO newKey +{-# NOINLINE rolesVaultKey #-} + data AccessLogRecord = AccessLogRecord { alStart :: !POSIXTime @@ -92,8 +97,10 @@ logMiddlewareCustom log400 mLogAction app request respond = do userIdRef <- newIORef Nothing tokenRef <- newIORef Nothing ptokenRef <- newIORef Nothing + rolesRef <- newIORef Nothing let vaultWithEverything = + insert rolesVaultKey rolesRef $ insert tokenVaultKey tokenRef $ insert pTokenVaultKey ptokenRef $ insert userIdVaultKey userIdRef $ diff --git a/src/Web/Template/Servant/Auth.hs b/src/Web/Template/Servant/Auth.hs index a86c293..13e5892 100644 --- a/src/Web/Template/Servant/Auth.hs +++ b/src/Web/Template/Servant/Auth.hs @@ -16,6 +16,7 @@ module Web.Template.Servant.Auth , defaultOIDCCfg , oidcCfgWithManager , OIDCUser (..) + , authenticateOIDC , Permit , swaggerSchemaUIBCDServer ) where @@ -25,7 +26,7 @@ module Web.Template.Servant.Auth import Control.Applicative ((<|>)) import Control.Lens (Iso', at, coerced, ix, (&), (.~), (<&>), (?~), (^..), (^?)) import Control.Monad (unless) -import Control.Monad.Except (runExceptT) +import Control.Monad.Except (ExceptT, runExceptT, throwError) import Control.Monad.IO.Class (MonadIO, liftIO) import Data.IORef (readIORef, writeIORef) import Data.Maybe (catMaybes) @@ -81,7 +82,7 @@ import System.Clock (TimeSpec (..)) import Web.Cookie (parseCookiesText) import System.BCD.Log (Level (..), log') -import Web.Template.Log (pTokenVaultKey, tokenVaultKey, userIdVaultKey) +import Web.Template.Log (pTokenVaultKey, rolesVaultKey, tokenVaultKey, userIdVaultKey) -- | Adds authenthication via @id@ Cookie. -- @@ -130,17 +131,16 @@ instance HasOpenApi api => HasOpenApi (CbdAuth :> api) where (SecuritySchemeApiKey (ApiKeyParams "id" ApiKeyCookie)) (Just "`id` cookie") --- | Adds authenthication via jwt +-- | Adds authenthication via JWT. -- -- Usage: -- -- > type API = OIDCAuth :> (....) -- --- Takes token from 'Authorization' header. +-- Takes token from the @Authorization@ header. +-- Handlers will get an 'OIDCUser' argument. -- --- Handlers will get an 'UserId' argument. --- --- Stores token and claims in vault. +-- See 'authenticateOIDC' for the full pipeline description and vault contract. data OIDCAuth @@ -149,6 +149,7 @@ data OIDCUser { oidcUserId :: UserId , oidcAccessToken :: Text , oidcParsedToken :: ClaimsSet + , oidcRoles :: [Text] } deriving (Eq, Show, Generic) @@ -188,6 +189,186 @@ oidcCfgWithManager mgr = do , oidcAllowServiceToken = False } +-- | Run the full OIDC authentication pipeline against a WAI 'Request'. +-- +-- This is the core logic of 'OIDCAuth', exposed so that applications can +-- build custom Servant combinators that incorporate OIDC validation — for +-- example, a combinator that tries application-specific auth first and falls +-- back to OIDC — without duplicating this implementation. +-- +-- === Pipeline +-- +-- 1. Extract a @Bearer@ token from the @Authorization@ request header. +-- 2. Decode the raw bytes as a signed JWT. +-- 3. Fetch (and cache in 'oidcDiscoCache') the OIDC discovery document at +-- @\\/.well-known\/openid-configuration@. +-- 4. Fetch (and cache in 'oidcKeyCache') the JWK key set referenced by +-- the discovery document. +-- 5. Verify the JWT: check signature, @iss@ == 'oidcIssuer', +-- @aud@ == 'oidcClientId'. +-- 6. Extract the user identity from the @object_guid@ unregistered claim. +-- If absent and 'oidcAllowServiceToken' is 'True', falls back to +-- @preferred_username@. +-- 7. Populate the vault (see /Vault contract/ below). +-- +-- === Vault contract +-- +-- The WAI vault holds 'IORef' cells pre-inserted by the logging middleware. +-- Auth combinators signal identity by writing into these cells; they must +-- never insert new keys. The relevant keys are defined in "Web.Template.Log". +-- +-- __When delegating to 'authenticateOIDC':__ on 'Right' this function writes +-- all three keys listed below; the caller must not write them again. +-- On 'Left' the vault is left untouched. +-- +-- __When taking a path that bypasses 'authenticateOIDC':__ the combinator is +-- responsible for populating the vault itself. +-- +-- * 'rolesVaultKey' — __required for 'Permit'__. Write @Just roles@ where +-- @roles@ is the list of role strings the authenticated principal holds. +-- 'Permit' returns 401 if the key is absent from the vault (logging +-- middleware not set up) or if the 'IORef' is still 'Nothing', and 403 if +-- none of the required roles are present. Because 'Permit' reads only this +-- key, any combinator — OIDC or otherwise — can make 'Permit' work simply +-- by writing the right role list here. +-- * 'pTokenVaultKey' — write @Just claims@ to the 'IORef' for logging. +-- * 'userIdVaultKey' — write @Just uid@ so the logging middleware records the +-- authenticated user ID. +-- * 'tokenVaultKey' — write @Just token@ so the logging middleware records +-- the raw credential. +-- +-- === Errors +-- +-- Returns @'Left' 'ServerError'@ (never throws) in the following cases: +-- +-- [@401@] No @Authorization: Bearer \@ header present. +-- [@401@] The @Authorization@ value is not a valid signed JWT. +-- [@401@] JWT verification fails (signature, issuer, audience, or expiry). +-- [@401@] No @object_guid@ claim and service tokens are not allowed. +-- [@500@] OIDC discovery document or JWK set could not be fetched. +authenticateOIDC :: OIDCConfig -> Request -> IO (Either ServerError OIDCUser) +authenticateOIDC cfg req = runExceptT $ do + token <- maybe (throwError err401') return $ getToken req + jwt <- getJWT token + disco <- getDisco cfg + jwkSet <- getJWKSet cfg disco + claims <- getClaims cfg jwt jwkSet + + let guid = claims ^? unregisteredClaims . ix "object_guid" . _String + username = claims ^? unregisteredClaims . ix "preferred_username" . _String + + uid <- maybe + (die ERROR (throwError err401') ("No object_guid found" :: Text)) + return + (guid <|> (if oidcAllowServiceToken cfg then username else Nothing)) + + let roles = oidcRoles cfg claims + + liftIO $ sequence_ $ catMaybes + [ userIdVaultKey req <&> flip writeIORef (Just uid) + , tokenVaultKey req <&> flip writeIORef (Just $ decodeUtf8 token) + , pTokenVaultKey req <&> flip writeIORef (Just claims) + , rolesVaultKey req <&> flip writeIORef (Just roles) + ] + + return OIDCUser + { oidcUserId = UserId uid + , oidcAccessToken = decodeUtf8 token + , oidcParsedToken = claims + , oidcRoles = roles + } + where + https mgr = (`httpLbs` mgr) + + err401' :: ServerError + err401' = err401 + { errBody = "{\"error\": \"Authorization failed\"}" + , errHeaders = [(hContentType, "application/json")] + } + + err500' :: ServerError + err500' = err500 + { errBody = "{\"error\": \"Internal server error\"}" + , errHeaders = [(hContentType, "application/json")] + } + + die :: (MonadIO m, Show err) => Level -> ExceptT ServerError m b -> err -> ExceptT ServerError m b + die lvl fin err = liftIO (log' lvl ("web-template" :: Text) $ show err) >> fin + + getToken :: Request -> Maybe ByteString + getToken r = lookup "Authorization" (requestHeaders r) >>= stripPrefix "Bearer " + + oidcRoles :: OIDCConfig -> ClaimsSet -> [Text] + oidcRoles OIDCConfig {..} claims = claims + ^.. unregisteredClaims + . ix "resource_access" +#if MIN_VERSION_aeson(2, 0, 0) + . key (fromText oidcClientId) +#else + . key oidcClientId +#endif + . key "roles" + . values . _String + + expiration :: UTCTime -> Maybe UTCTime -> NominalDiffTime -> Maybe TimeSpec + expiration now ex defaultExp = diffTime + <$> (ex <|> pure (addUTCTime defaultExp now)) + -- If expiration is not set by OIDC provider, cache data for some + -- default amount of time, to avoid too many requests. + <*> pure now + where + tTreshold = 60 -- consider token expired 'tTreshold' seconds earlier + + diffTime :: UTCTime -> UTCTime -> TimeSpec + diffTime from to = let + diff = diffUTCTime from to - tTreshold + in max + TimeSpec {sec = 0, nsec = 0} + TimeSpec {sec = floor $ nominalDiffTimeToSeconds diff, nsec = 0} + + getJWT :: ByteString -> ExceptT ServerError IO SignedJWT + getJWT = either (die WARNING (throwError err401')) return + . decodeCompact @_ @JWTError + . LB.fromStrict + + getDisco :: OIDCConfig -> ExceptT ServerError IO Discovery + getDisco OIDCConfig {..} = liftIO (Cache.lookup oidcDiscoCache ()) + >>= maybe fetchDisco return + where + fetchDisco = liftIO (discovery (https oidcManager) (appWellKnown oidcIssuer)) + >>= either + (die ERROR (throwError err500')) + (uncurry discoSuccess) + where + discoSuccess disco mbDiscoExp = liftIO $ do + now <- getCurrentTime + Cache.insert' oidcDiscoCache (expiration now mbDiscoExp oidcDefaultExpiration) () disco + return disco + + getJWKSet :: OIDCConfig -> Discovery -> ExceptT ServerError IO JWKSet + getJWKSet OIDCConfig {..} disco = liftIO (Cache.lookup oidcKeyCache ()) + >>= maybe fetchKeys return + where + fetchKeys = liftIO (keysFromDiscovery (https oidcManager) disco) + >>= either + (die ERROR (throwError err500')) + (uncurry keysSuccess) + where + keysSuccess jwkSet mbKeysExp = liftIO $ do + now <- getCurrentTime + Cache.insert' oidcKeyCache (expiration now mbKeysExp oidcDefaultExpiration) () jwkSet + return jwkSet + + getClaims :: OIDCConfig -> SignedJWT -> JWKSet -> ExceptT ServerError IO ClaimsSet + getClaims OIDCConfig {..} jwt jwkSet = + liftIO (runExceptT $ verifyClaims @_ @_ @JWTError (jwtValidation oidcIssuer oidcClientId) jwkSet jwt) + >>= either (die ERROR (throwError err401')) return + where + jwtValidation :: URI -> Text -> JWTValidationSettings + jwtValidation issuer audience = defaultJWTValidationSettings (const True) + & issuerPredicate .~ (\iss -> iss ^? uri == Just issuer) + & audiencePredicate .~ (\aud -> aud ^? string == Just audience) + instance ( HasServer api context , HasContextEntry context OIDCConfig ) => HasServer (OIDCAuth :> api) context where @@ -199,113 +380,10 @@ instance ( HasServer api context route _ context sub = route @api Proxy context $ addAuthCheck sub - $ withRequest $ \req -> do - - token <- maybe unauth401 return $ getToken req - - jwt <- getJWT token - - let cfg = getContextEntry context - - disco <- getDisco cfg - - jwkSet <- getJWKSet cfg disco - - claims <- getClaims cfg jwt jwkSet - - let guid = claims ^? unregisteredClaims . ix "object_guid" . _String - let username = claims ^? unregisteredClaims . ix "preferred_username" . _String - - uid <- maybe - (die ERROR unauth401 ("No object_guid found" :: Text)) - return - (guid <|> (if oidcAllowServiceToken cfg then username else Nothing)) - - liftIO $ sequence_ $ catMaybes - [ userIdVaultKey req <&> flip writeIORef (Just uid) - , tokenVaultKey req <&> flip writeIORef (Just $ decodeUtf8 token) - , pTokenVaultKey req <&> flip writeIORef (Just claims) - ] - - return OIDCUser - { oidcUserId = UserId uid - , oidcAccessToken = decodeUtf8 token - , oidcParsedToken = claims - } + $ withRequest $ \req -> + liftIO (authenticateOIDC cfg req) >>= either delayedFailFatal return where - https mgr = (`httpLbs` mgr) - - die :: Show err => Level -> DelayedIO b -> err -> DelayedIO b - die lvl fin err = liftIO (log' lvl ("web-template" :: Text) $ show err) >> fin - - getToken :: Request -> Maybe ByteString - getToken r = lookup "Authorization" (requestHeaders r) >>= stripPrefix "Bearer " - - expiration :: UTCTime -> Maybe UTCTime -> NominalDiffTime -> Maybe TimeSpec - expiration now ex defaultExp = diffTime - <$> (ex <|> pure (addUTCTime defaultExp now)) - -- If expiration is not set by OIDC provider, cache data for some - -- default amount of time, to avoid too many requests. - <*> pure now - where - tTreshold = 60 -- consider token expired 'tTreshold' seconds earlier - - diffTime :: UTCTime -> UTCTime -> TimeSpec - diffTime from to = let - diff = diffUTCTime from to - tTreshold - in max - TimeSpec {sec = 0, nsec = 0} - TimeSpec {sec = floor $ nominalDiffTimeToSeconds diff, nsec = 0} - - getJWT :: ByteString -> DelayedIO SignedJWT - getJWT = either (die WARNING unauth401) return . decodeToken - where - decodeToken = decodeCompact @_ @JWTError . LB.fromStrict - - getDisco :: OIDCConfig -> DelayedIO Discovery - getDisco OIDCConfig {..} = liftIO (Cache.lookup oidcDiscoCache ()) - >>= maybe - fetchDisco - return - where - fetchDisco = liftIO (discovery (https oidcManager) (appWellKnown oidcIssuer)) - >>= either - (die ERROR unauth500) - (uncurry discoSuccess) - where - discoSuccess disco mbDiscoExp = liftIO $ do - now <- getCurrentTime - Cache.insert' oidcDiscoCache (expiration now mbDiscoExp oidcDefaultExpiration) () disco - return disco - - getJWKSet :: OIDCConfig -> Discovery -> DelayedIO JWKSet - getJWKSet OIDCConfig {..} disco = liftIO (Cache.lookup oidcKeyCache ()) - >>= maybe - fetchKeys - return - where - fetchKeys = liftIO (keysFromDiscovery (https oidcManager) disco) - >>= either - (die ERROR unauth500) - (uncurry keysSuccess) - where - keysSuccess jwkSet mbKeysExp = liftIO $ do - now <- getCurrentTime - Cache.insert' oidcKeyCache (expiration now mbKeysExp oidcDefaultExpiration) () jwkSet - return jwkSet - - getClaims :: OIDCConfig -> SignedJWT -> JWKSet -> DelayedIO ClaimsSet - getClaims OIDCConfig {..} jwt jwkSet = liftIO - (runExceptT $ - verifyClaims @_ @_ @JWTError (jwtValidation oidcIssuer oidcClientId) jwkSet jwt - ) >>= either - (die ERROR unauth401) - return - where - jwtValidation :: URI -> Text -> JWTValidationSettings - jwtValidation issuer audience = defaultJWTValidationSettings (const True) - & issuerPredicate .~ (\iss -> iss ^? uri == Just issuer) - & audiencePredicate .~ (\aud -> aud ^? string == Just audience) + cfg = getContextEntry context instance HasOpenApi api => HasOpenApi (OIDCAuth :> api) where toOpenApi _ = toOpenApi @api Proxy @@ -337,7 +415,6 @@ data Permit (rs :: [Symbol]) instance ( HasServer api context , KnownSymbols roles - , HasContextEntry context OIDCConfig ) => HasServer (Permit roles :> api) context where type ServerT (Permit roles :> api) m = ServerT api m @@ -348,25 +425,10 @@ instance ( HasServer api context route @api Proxy context $ addAuthCheck (const <$> sub) $ withRequest $ \req -> do - pTokenRef <- maybe unauth401 return $ pTokenVaultKey req - - claims <- liftIO (readIORef pTokenRef) >>= maybe unauth401 return - - let OIDCConfig {..} = getContextEntry context - - let haveRoles = claims - ^.. unregisteredClaims - . ix "resource_access" -#if MIN_VERSION_aeson(2, 0, 0) - . key (fromText oidcClientId) -#else - . key oidcClientId -#endif - . key "roles" - . values . _String - - unless (any (`elem` rolesNeeded) haveRoles) - unauth403 + rolesRef <- maybe unauth401 return $ rolesVaultKey req + haveRoles <- liftIO (readIORef rolesRef) >>= maybe unauth401 return + unless (any (`elem` rolesNeeded) haveRoles) + unauth403 where rolesNeeded = symbolsVal (Proxy @roles) @@ -406,12 +468,6 @@ unauth403 = delayedFailFatal $ err403 , errHeaders = [(hContentType, "application/json")] } -unauth500 :: DelayedIO a -unauth500 = delayedFailFatal $ err500 - { errBody = "{\"error\": \"Internal server error\"}" - , errHeaders = [(hContentType, "application/json")] - } - swaggerUiIndexBCDTemplate :: Text #if MIN_VERSION_file_embed_lzma(0,1,0) swaggerUiIndexBCDTemplate = $$(embedText "index.html.tmpl") diff --git a/web-template.cabal b/web-template.cabal index 8ba5b5a..609661a 100644 --- a/web-template.cabal +++ b/web-template.cabal @@ -40,10 +40,10 @@ library , Web.Template.Servant.Error , Web.Template.Servant.Error.Instance , Web.Template.Servant.Swagger + , Web.Template.Log other-modules: Web.Template.Except , Web.Template.Server , Web.Template.Types - , Web.Template.Log build-depends: base >= 4.7 && < 5 , aeson