Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions src/PostgREST/Admin.hs
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,23 @@ import PostgREST.Observation (Observation (..))

import qualified PostgREST.AppState as AppState

import Protolude
import qualified Network.Socket as NS
import Protolude

runAdmin :: AppState -> Warp.Settings -> IO ()
runAdmin appState settings = do
whenJust (AppState.getSocketAdmin appState) $ \adminSocket -> do
runAdmin :: AppState -> Maybe NS.Socket -> NS.Socket -> Warp.Settings -> IO ()
runAdmin appState maybeAdminSocket socketREST settings = do
whenJust maybeAdminSocket $ \adminSocket -> do
address <- resolveSocketToAddress adminSocket
observer $ AdminStartObs address
void . forkIO $ Warp.runSettingsSocket settings adminSocket adminApp
where
adminApp = admin appState
adminApp = admin appState socketREST
observer = AppState.getObserver appState

-- | PostgREST admin application
admin :: AppState.AppState -> Wai.Application
admin appState req respond = do
isMainAppReachable <- isRight <$> reachMainApp (AppState.getSocketREST appState)
admin :: AppState.AppState -> NS.Socket -> Wai.Application
admin appState socketREST req respond = do
isMainAppReachable <- isRight <$> reachMainApp socketREST
isLoaded <- AppState.isLoaded appState
isPending <- AppState.isPending appState

Expand Down
58 changes: 51 additions & 7 deletions src/PostgREST/App.hs
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,15 @@ import PostgREST.SchemaCache (SchemaCache (..))
import PostgREST.TimeIt (timeItT)
import PostgREST.Version (docsVersion, prettyVersion)

import qualified Data.ByteString.Char8 as BS
import qualified Data.List as L
import qualified Network.HTTP.Types as HTTP
import Protolude hiding (Handler)
import qualified Data.ByteString.Char8 as BS
import qualified Data.List as L
import Data.Streaming.Network (bindPortTCP,
bindRandomPortTCP)
import qualified Data.Text as T
import qualified Network.HTTP.Types as HTTP
import qualified Network.Socket as NS
import PostgREST.Unix (createAndBindDomainSocket)
import Protolude hiding (Handler)

type Handler = ExceptT Error

Expand All @@ -72,19 +77,21 @@ run appState = do
conf@AppConfig{..} <- AppState.getConfig appState

AppState.schemaCacheLoader appState -- Loads the initial SchemaCache
(mainSocket, adminSocket) <- initSockets conf

Unix.installSignalHandlers (AppState.getMainThreadId appState) (AppState.schemaCacheLoader appState) (AppState.readInDbConfig False appState)

Listener.runListener appState

Admin.runAdmin appState (serverSettings conf)
Admin.runAdmin appState adminSocket mainSocket (serverSettings conf)

let app = postgrest configLogLevel appState (AppState.schemaCacheLoader appState)

do
address <- resolveSocketToAddress (AppState.getSocketREST appState)
address <- resolveSocketToAddress mainSocket
observer $ AppServerAddressObs address

Warp.runSettingsSocket (serverSettings conf & setOnException onWarpException) (AppState.getSocketREST appState) app
Warp.runSettingsSocket (serverSettings conf & setOnException onWarpException) mainSocket app
where
observer = AppState.getObserver appState

Expand Down Expand Up @@ -229,3 +236,40 @@ addRetryHint delay response = do

isServiceUnavailable :: Wai.Response -> Bool
isServiceUnavailable response = Wai.responseStatus response == HTTP.status503

type AppSockets = (NS.Socket, Maybe NS.Socket)

initSockets :: AppConfig -> IO AppSockets
initSockets AppConfig{..} = do
let
cfg'usp = configServerUnixSocket
cfg'uspm = configServerUnixSocketMode
cfg'host = configServerHost
cfg'port = configServerPort
cfg'adminHost = configAdminServerHost
cfg'adminPort = configAdminServerPort

sock <- case cfg'usp of
-- I'm not using `streaming-commons`' bindPath function here because it's not defined for Windows,
-- but we need to have runtime error if we try to use it in Windows, not compile time error
Just path -> createAndBindDomainSocket path cfg'uspm
Nothing -> do
(_, sock) <-
if cfg'port /= 0
then do
sock <- bindPortTCP cfg'port (fromString $ T.unpack cfg'host)
pure (cfg'port, sock)
else do
-- explicitly bind to a random port, returning bound port number
(num, sock) <- bindRandomPortTCP (fromString $ T.unpack cfg'host)
pure (num, sock)
pure sock

adminSock <- case cfg'adminPort of
Just adminPort -> do
adminSock <- bindPortTCP adminPort (fromString $ T.unpack cfg'adminHost)
pure $ Just adminSock
Nothing -> pure Nothing

pure (sock, adminSock)

64 changes: 3 additions & 61 deletions src/PostgREST/AppState.hs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,7 @@ module PostgREST.AppState
, getNextListenerDelay
, getTime
, getJwtCacheState
, getSocketREST
, getSocketAdmin
, init
, initSockets
, initWithPool
, putNextListenerDelay
, putSchemaCache
Expand All @@ -32,13 +29,11 @@ module PostgREST.AppState

import qualified Data.ByteString.Char8 as BS
import Data.Either.Combinators (whenLeft)
import qualified Data.Text as T (unpack)
import qualified Hasql.Pool as SQL
import qualified Hasql.Pool.Config as SQL
import qualified Hasql.Session as SQL
import qualified Hasql.Transaction.Sessions as SQL
import qualified Network.HTTP.Types.Status as HTTP
import qualified Network.Socket as NS
import qualified PostgREST.Auth.JwtCache as JwtCache
import qualified PostgREST.Error as Error
import qualified PostgREST.Logger as Logger
Expand Down Expand Up @@ -70,10 +65,7 @@ import PostgREST.SchemaCache (SchemaCache (..),
querySchemaCache,
showSummary)
import PostgREST.SchemaCache.Identifiers (quoteQi)
import PostgREST.Unix (createAndBindDomainSocket)

import Data.Streaming.Network (bindPortTCP, bindRandomPortTCP)
import Data.String (IsString (..))
import Protolude

data AppState = AppState
Expand All @@ -99,10 +91,6 @@ data AppState = AppState
, stateNextDelay :: IORef Int
-- | Keeps track of the next delay for the listener
, stateNextListenerDelay :: IORef Int
-- | Network socket for REST API
, stateSocketREST :: NS.Socket
-- | Network socket for the admin UI
, stateSocketAdmin :: Maybe NS.Socket
-- | Observation handler
, stateObserver :: ObservationHandler
-- | JWT Cache
Expand All @@ -117,8 +105,6 @@ data SchemaCacheStatus
| SCPending
deriving Eq

type AppSockets = (NS.Socket, Maybe NS.Socket)

init :: AppConfig -> IO AppState
init conf@AppConfig{configLogLevel, configDbPoolSize} = do
loggerState <- Logger.init
Expand All @@ -128,12 +114,10 @@ init conf@AppConfig{configLogLevel, configDbPoolSize} = do
observer $ AppStartObs prettyVersion

pool <- initPool conf observer
(sock, adminSock) <- initSockets conf
state' <- initWithPool (sock, adminSock) pool conf loggerState metricsState observer
pure state' { stateSocketREST = sock, stateSocketAdmin = adminSock}
initWithPool pool conf loggerState metricsState observer --{ stateSocketREST = sock, stateSocketAdmin = adminSock}

initWithPool :: AppSockets -> SQL.Pool -> AppConfig -> Logger.LoggerState -> Metrics.MetricsState -> ObservationHandler -> IO AppState
initWithPool (sock, adminSock) pool conf loggerState metricsState observer = do
initWithPool :: SQL.Pool -> AppConfig -> Logger.LoggerState -> Metrics.MetricsState -> ObservationHandler -> IO AppState
initWithPool pool conf loggerState metricsState observer = do

appState <- AppState pool
<$> newIORef minimumPgVersion -- assume we're in a supported version when starting, this will be corrected on a later step
Expand All @@ -146,8 +130,6 @@ initWithPool (sock, adminSock) pool conf loggerState metricsState observer = do
<*> myThreadId
<*> newIORef 0
<*> newIORef 1
<*> pure sock
<*> pure adminSock
<*> pure observer
<*> JwtCache.init conf observer
<*> pure loggerState
Expand All @@ -166,40 +148,6 @@ initWithPool (sock, adminSock) pool conf loggerState metricsState observer = do
destroy :: AppState -> IO ()
destroy = destroyPool

initSockets :: AppConfig -> IO AppSockets
initSockets AppConfig{..} = do
let
cfg'usp = configServerUnixSocket
cfg'uspm = configServerUnixSocketMode
cfg'host = configServerHost
cfg'port = configServerPort
cfg'adminHost = configAdminServerHost
cfg'adminPort = configAdminServerPort

sock <- case cfg'usp of
-- I'm not using `streaming-commons`' bindPath function here because it's not defined for Windows,
-- but we need to have runtime error if we try to use it in Windows, not compile time error
Just path -> createAndBindDomainSocket path cfg'uspm
Nothing -> do
(_, sock) <-
if cfg'port /= 0
then do
sock <- bindPortTCP cfg'port (fromString $ T.unpack cfg'host)
pure (cfg'port, sock)
else do
-- explicitly bind to a random port, returning bound port number
(num, sock) <- bindRandomPortTCP (fromString $ T.unpack cfg'host)
pure (num, sock)
pure sock

adminSock <- case cfg'adminPort of
Just adminPort -> do
adminSock <- bindPortTCP adminPort (fromString $ T.unpack cfg'adminHost)
pure $ Just adminSock
Nothing -> pure Nothing

pure (sock, adminSock)

initPool :: AppConfig -> ObservationHandler -> IO SQL.Pool
initPool AppConfig{..} observer = do
SQL.acquire $ SQL.settings
Expand Down Expand Up @@ -313,12 +261,6 @@ getTime = stateGetTime
getJwtCacheState :: AppState -> JwtCacheState
getJwtCacheState = stateJwtCache

getSocketREST :: AppState -> NS.Socket
getSocketREST = stateSocketREST

getSocketAdmin :: AppState -> Maybe NS.Socket
getSocketAdmin = stateSocketAdmin

getMainThreadId :: AppState -> ThreadId
getMainThreadId = stateMainThreadId

Expand Down
3 changes: 1 addition & 2 deletions test/spec/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,12 @@ main = do

-- cached schema cache so most tests run fast
baseSchemaCache <- loadSCache pool testCfg
sockets <- AppState.initSockets testCfg
loggerState <- Logger.init
metricsState <- Metrics.init (configDbPoolSize testCfg)

let
initApp sCache st config = do
appState <- AppState.initWithPool sockets pool config loggerState metricsState (Metrics.observationMetrics metricsState)
appState <- AppState.initWithPool pool config loggerState metricsState (Metrics.observationMetrics metricsState)
AppState.putPgVersion appState actualPgVersion
AppState.putSchemaCache appState (Just sCache)
return (st, postgrest (configLogLevel config) appState (pure ()))
Expand Down
Loading