diff --git a/src/PostgREST/Admin.hs b/src/PostgREST/Admin.hs index c21065b045..99733a6995 100644 --- a/src/PostgREST/Admin.hs +++ b/src/PostgREST/Admin.hs @@ -19,22 +19,23 @@ import PostgREST.Observation (Observation (..)) import qualified PostgREST.AppState as AppState -import Protolude +import qualified Network.Socket as NS +import Protolude -runAdmin :: AppState -> Warp.Settings -> IO () -runAdmin appState settings = do - whenJust (AppState.getSocketAdmin appState) $ \adminSocket -> do +runAdmin :: AppState -> Maybe NS.Socket -> NS.Socket -> Warp.Settings -> IO () +runAdmin appState maybeAdminSocket socketREST settings = do + whenJust maybeAdminSocket $ \adminSocket -> do address <- resolveSocketToAddress adminSocket observer $ AdminStartObs address void . forkIO $ Warp.runSettingsSocket settings adminSocket adminApp where - adminApp = admin appState + adminApp = admin appState socketREST observer = AppState.getObserver appState -- | PostgREST admin application -admin :: AppState.AppState -> Wai.Application -admin appState req respond = do - isMainAppReachable <- isRight <$> reachMainApp (AppState.getSocketREST appState) +admin :: AppState.AppState -> NS.Socket -> Wai.Application +admin appState socketREST req respond = do + isMainAppReachable <- isRight <$> reachMainApp socketREST isLoaded <- AppState.isLoaded appState isPending <- AppState.isPending appState diff --git a/src/PostgREST/App.hs b/src/PostgREST/App.hs index 6cddba5ff9..72e9ba3489 100644 --- a/src/PostgREST/App.hs +++ b/src/PostgREST/App.hs @@ -60,10 +60,15 @@ import PostgREST.SchemaCache (SchemaCache (..)) import PostgREST.TimeIt (timeItT) import PostgREST.Version (docsVersion, prettyVersion) -import qualified Data.ByteString.Char8 as BS -import qualified Data.List as L -import qualified Network.HTTP.Types as HTTP -import Protolude hiding (Handler) +import qualified Data.ByteString.Char8 as BS +import qualified Data.List as L +import Data.Streaming.Network (bindPortTCP, + bindRandomPortTCP) +import qualified Data.Text as T +import qualified Network.HTTP.Types as HTTP +import qualified Network.Socket as NS +import PostgREST.Unix (createAndBindDomainSocket) +import Protolude hiding (Handler) type Handler = ExceptT Error @@ -72,19 +77,21 @@ run appState = do conf@AppConfig{..} <- AppState.getConfig appState AppState.schemaCacheLoader appState -- Loads the initial SchemaCache + (mainSocket, adminSocket) <- initSockets conf + Unix.installSignalHandlers (AppState.getMainThreadId appState) (AppState.schemaCacheLoader appState) (AppState.readInDbConfig False appState) Listener.runListener appState - Admin.runAdmin appState (serverSettings conf) + Admin.runAdmin appState adminSocket mainSocket (serverSettings conf) let app = postgrest configLogLevel appState (AppState.schemaCacheLoader appState) do - address <- resolveSocketToAddress (AppState.getSocketREST appState) + address <- resolveSocketToAddress mainSocket observer $ AppServerAddressObs address - Warp.runSettingsSocket (serverSettings conf & setOnException onWarpException) (AppState.getSocketREST appState) app + Warp.runSettingsSocket (serverSettings conf & setOnException onWarpException) mainSocket app where observer = AppState.getObserver appState @@ -229,3 +236,40 @@ addRetryHint delay response = do isServiceUnavailable :: Wai.Response -> Bool isServiceUnavailable response = Wai.responseStatus response == HTTP.status503 + +type AppSockets = (NS.Socket, Maybe NS.Socket) + +initSockets :: AppConfig -> IO AppSockets +initSockets AppConfig{..} = do + let + cfg'usp = configServerUnixSocket + cfg'uspm = configServerUnixSocketMode + cfg'host = configServerHost + cfg'port = configServerPort + cfg'adminHost = configAdminServerHost + cfg'adminPort = configAdminServerPort + + sock <- case cfg'usp of + -- I'm not using `streaming-commons`' bindPath function here because it's not defined for Windows, + -- but we need to have runtime error if we try to use it in Windows, not compile time error + Just path -> createAndBindDomainSocket path cfg'uspm + Nothing -> do + (_, sock) <- + if cfg'port /= 0 + then do + sock <- bindPortTCP cfg'port (fromString $ T.unpack cfg'host) + pure (cfg'port, sock) + else do + -- explicitly bind to a random port, returning bound port number + (num, sock) <- bindRandomPortTCP (fromString $ T.unpack cfg'host) + pure (num, sock) + pure sock + + adminSock <- case cfg'adminPort of + Just adminPort -> do + adminSock <- bindPortTCP adminPort (fromString $ T.unpack cfg'adminHost) + pure $ Just adminSock + Nothing -> pure Nothing + + pure (sock, adminSock) + diff --git a/src/PostgREST/AppState.hs b/src/PostgREST/AppState.hs index bf439f7dc3..f712982108 100644 --- a/src/PostgREST/AppState.hs +++ b/src/PostgREST/AppState.hs @@ -13,10 +13,7 @@ module PostgREST.AppState , getNextListenerDelay , getTime , getJwtCacheState - , getSocketREST - , getSocketAdmin , init - , initSockets , initWithPool , putNextListenerDelay , putSchemaCache @@ -32,13 +29,11 @@ module PostgREST.AppState import qualified Data.ByteString.Char8 as BS import Data.Either.Combinators (whenLeft) -import qualified Data.Text as T (unpack) import qualified Hasql.Pool as SQL import qualified Hasql.Pool.Config as SQL import qualified Hasql.Session as SQL import qualified Hasql.Transaction.Sessions as SQL import qualified Network.HTTP.Types.Status as HTTP -import qualified Network.Socket as NS import qualified PostgREST.Auth.JwtCache as JwtCache import qualified PostgREST.Error as Error import qualified PostgREST.Logger as Logger @@ -70,10 +65,7 @@ import PostgREST.SchemaCache (SchemaCache (..), querySchemaCache, showSummary) import PostgREST.SchemaCache.Identifiers (quoteQi) -import PostgREST.Unix (createAndBindDomainSocket) -import Data.Streaming.Network (bindPortTCP, bindRandomPortTCP) -import Data.String (IsString (..)) import Protolude data AppState = AppState @@ -99,10 +91,6 @@ data AppState = AppState , stateNextDelay :: IORef Int -- | Keeps track of the next delay for the listener , stateNextListenerDelay :: IORef Int - -- | Network socket for REST API - , stateSocketREST :: NS.Socket - -- | Network socket for the admin UI - , stateSocketAdmin :: Maybe NS.Socket -- | Observation handler , stateObserver :: ObservationHandler -- | JWT Cache @@ -117,8 +105,6 @@ data SchemaCacheStatus | SCPending deriving Eq -type AppSockets = (NS.Socket, Maybe NS.Socket) - init :: AppConfig -> IO AppState init conf@AppConfig{configLogLevel, configDbPoolSize} = do loggerState <- Logger.init @@ -128,12 +114,10 @@ init conf@AppConfig{configLogLevel, configDbPoolSize} = do observer $ AppStartObs prettyVersion pool <- initPool conf observer - (sock, adminSock) <- initSockets conf - state' <- initWithPool (sock, adminSock) pool conf loggerState metricsState observer - pure state' { stateSocketREST = sock, stateSocketAdmin = adminSock} + initWithPool pool conf loggerState metricsState observer --{ stateSocketREST = sock, stateSocketAdmin = adminSock} -initWithPool :: AppSockets -> SQL.Pool -> AppConfig -> Logger.LoggerState -> Metrics.MetricsState -> ObservationHandler -> IO AppState -initWithPool (sock, adminSock) pool conf loggerState metricsState observer = do +initWithPool :: SQL.Pool -> AppConfig -> Logger.LoggerState -> Metrics.MetricsState -> ObservationHandler -> IO AppState +initWithPool pool conf loggerState metricsState observer = do appState <- AppState pool <$> newIORef minimumPgVersion -- assume we're in a supported version when starting, this will be corrected on a later step @@ -146,8 +130,6 @@ initWithPool (sock, adminSock) pool conf loggerState metricsState observer = do <*> myThreadId <*> newIORef 0 <*> newIORef 1 - <*> pure sock - <*> pure adminSock <*> pure observer <*> JwtCache.init conf observer <*> pure loggerState @@ -166,40 +148,6 @@ initWithPool (sock, adminSock) pool conf loggerState metricsState observer = do destroy :: AppState -> IO () destroy = destroyPool -initSockets :: AppConfig -> IO AppSockets -initSockets AppConfig{..} = do - let - cfg'usp = configServerUnixSocket - cfg'uspm = configServerUnixSocketMode - cfg'host = configServerHost - cfg'port = configServerPort - cfg'adminHost = configAdminServerHost - cfg'adminPort = configAdminServerPort - - sock <- case cfg'usp of - -- I'm not using `streaming-commons`' bindPath function here because it's not defined for Windows, - -- but we need to have runtime error if we try to use it in Windows, not compile time error - Just path -> createAndBindDomainSocket path cfg'uspm - Nothing -> do - (_, sock) <- - if cfg'port /= 0 - then do - sock <- bindPortTCP cfg'port (fromString $ T.unpack cfg'host) - pure (cfg'port, sock) - else do - -- explicitly bind to a random port, returning bound port number - (num, sock) <- bindRandomPortTCP (fromString $ T.unpack cfg'host) - pure (num, sock) - pure sock - - adminSock <- case cfg'adminPort of - Just adminPort -> do - adminSock <- bindPortTCP adminPort (fromString $ T.unpack cfg'adminHost) - pure $ Just adminSock - Nothing -> pure Nothing - - pure (sock, adminSock) - initPool :: AppConfig -> ObservationHandler -> IO SQL.Pool initPool AppConfig{..} observer = do SQL.acquire $ SQL.settings @@ -313,12 +261,6 @@ getTime = stateGetTime getJwtCacheState :: AppState -> JwtCacheState getJwtCacheState = stateJwtCache -getSocketREST :: AppState -> NS.Socket -getSocketREST = stateSocketREST - -getSocketAdmin :: AppState -> Maybe NS.Socket -getSocketAdmin = stateSocketAdmin - getMainThreadId :: AppState -> ThreadId getMainThreadId = stateMainThreadId diff --git a/test/spec/Main.hs b/test/spec/Main.hs index e847926b6f..a3c638ffd8 100644 --- a/test/spec/Main.hs +++ b/test/spec/Main.hs @@ -84,13 +84,12 @@ main = do -- cached schema cache so most tests run fast baseSchemaCache <- loadSCache pool testCfg - sockets <- AppState.initSockets testCfg loggerState <- Logger.init metricsState <- Metrics.init (configDbPoolSize testCfg) let initApp sCache st config = do - appState <- AppState.initWithPool sockets pool config loggerState metricsState (Metrics.observationMetrics metricsState) + appState <- AppState.initWithPool pool config loggerState metricsState (Metrics.observationMetrics metricsState) AppState.putPgVersion appState actualPgVersion AppState.putSchemaCache appState (Just sCache) return (st, postgrest (configLogLevel config) appState (pure ()))