diff --git a/CHANGELOG.md b/CHANGELOG.md index ba4f4ce6e2..2bab0150a7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ All notable changes to this project will be documented in this file. From versio ### Added +- Use `pg_basetype()` for domain type resolution on PostgreSQL 17+ by @joelonsql in #XXXX - Log error when `db-schemas` config contains schema `pg_catalog` or `information_schema` by @taimoorzaeem in #4359 ### Fixed diff --git a/src/PostgREST/App.hs b/src/PostgREST/App.hs index 5387c69285..7103725d7f 100644 --- a/src/PostgREST/App.hs +++ b/src/PostgREST/App.hs @@ -149,7 +149,8 @@ postgrestResponse appState conf@AppConfig{..} maybeSchemaCache authResult@AuthRe (parseTime, apiReq@ApiRequest{..}) <- withTiming $ liftEither . mapLeft Error.ApiRequestError $ ApiRequest.userApiRequest conf prefs req body (planTime, plan) <- withTiming $ liftEither $ Plan.actionPlan iAction conf apiReq sCache - let mainQ = Query.mainQuery plan conf apiReq authResult configDbPreRequest + pgVer <- lift $ AppState.getPgVersion appState + let mainQ = Query.mainQuery pgVer plan conf apiReq authResult configDbPreRequest tx = MainTx.mainTx mainQ conf authResult apiReq plan sCache obsQuery s = when configLogQuery $ observer $ QueryObs mainQ s diff --git a/src/PostgREST/AppState.hs b/src/PostgREST/AppState.hs index f231a4e9a4..1e749e7a4f 100644 --- a/src/PostgREST/AppState.hs +++ b/src/PostgREST/AppState.hs @@ -401,9 +401,10 @@ retryingSchemaCacheLoad appState@AppState{stateObserver=observer, stateMainThrea qSchemaCache :: IO (Maybe SchemaCache) qSchemaCache = do conf@AppConfig{..} <- getConfig appState + pgVer <- getPgVersion appState (resultTime, result) <- let transaction = if configDbPreparedStatements then SQL.transaction else SQL.unpreparedTransaction in - timeItT $ usePool appState (transaction SQL.ReadCommitted SQL.Read $ querySchemaCache conf) + timeItT $ usePool appState (transaction SQL.ReadCommitted SQL.Read $ querySchemaCache pgVer conf) case result of Left e -> do putSCacheStatus appState SCPending diff --git a/src/PostgREST/CLI.hs b/src/PostgREST/CLI.hs index a635aa0814..80abc8c0ba 100644 --- a/src/PostgREST/CLI.hs +++ b/src/PostgREST/CLI.hs @@ -61,10 +61,11 @@ runAppCommand conf@AppConfig{..} runCmd = do dumpSchema :: AppState -> IO LBS.ByteString dumpSchema appState = do conf@AppConfig{..} <- AppState.getConfig appState + pgVer <- AppState.getPgVersion appState result <- let transaction = if configDbPreparedStatements then SQL.transaction else SQL.unpreparedTransaction in AppState.usePool appState - (transaction SQL.ReadCommitted SQL.Read $ querySchemaCache conf) + (transaction SQL.ReadCommitted SQL.Read $ querySchemaCache pgVer conf) case result of Left e -> do let observer = AppState.getObserver appState diff --git a/src/PostgREST/Query.hs b/src/PostgREST/Query.hs index 69dcbe5604..6d607c6b65 100644 --- a/src/PostgREST/Query.hs +++ b/src/PostgREST/Query.hs @@ -23,6 +23,7 @@ import PostgREST.ApiRequest.Preferences (Preferences (..), shouldExplainCount) import PostgREST.Auth.Types (AuthResult (..)) import PostgREST.Config (AppConfig (..)) +import PostgREST.Config.PgVersion (PgVersion) import PostgREST.Plan (ActionPlan (..), CrudPlan (..), DbActionPlan (..), @@ -41,9 +42,9 @@ data MainQuery = MainQuery , mqExplain :: Maybe SQL.Snippet -- ^ the explain query that gets generated for the "Prefer: count=estimated" case } -mainQuery :: ActionPlan -> AppConfig -> ApiRequest -> AuthResult -> Maybe QualifiedIdentifier -> MainQuery -mainQuery (NoDb _) _ _ _ _ = MainQuery mempty Nothing mempty (mempty, mempty, mempty) mempty -mainQuery (Db plan) conf@AppConfig{..} apiReq@ApiRequest{iPreferences=Preferences{..}} authRes preReq = +mainQuery :: PgVersion -> ActionPlan -> AppConfig -> ApiRequest -> AuthResult -> Maybe QualifiedIdentifier -> MainQuery +mainQuery _ (NoDb _) _ _ _ _ = MainQuery mempty Nothing mempty (mempty, mempty, mempty) mempty +mainQuery pgVer (Db plan) conf@AppConfig{..} apiReq@ApiRequest{iPreferences=Preferences{..}} authRes preReq = let genQ = MainQuery (PreQuery.txVarQuery plan conf authRes apiReq) (PreQuery.preReqQuery <$> preReq) in case plan of DbCrud _ WrappedReadPlan{..} -> @@ -55,4 +56,4 @@ mainQuery (Db plan) conf@AppConfig{..} apiReq@ApiRequest{iPreferences=Preference DbCrud _ CallReadPlan{..} -> genQ (Statements.mainCall crProc crCallPlan crReadPlan preferCount pMedia crHandler) (mempty, mempty, mempty) mempty MayUseDb InspectPlan{ipSchema=tSchema} -> - genQ mempty (SqlFragment.accessibleTables tSchema, SqlFragment.accessibleFuncs tSchema, SqlFragment.schemaDescription tSchema) mempty + genQ mempty (SqlFragment.accessibleTables tSchema, SqlFragment.accessibleFuncs pgVer tSchema, SqlFragment.schemaDescription tSchema) mempty diff --git a/src/PostgREST/Query/SqlFragment.hs b/src/PostgREST/Query/SqlFragment.hs index 461d4e8274..a9639d9676 100644 --- a/src/PostgREST/Query/SqlFragment.hs +++ b/src/PostgREST/Query/SqlFragment.hs @@ -74,6 +74,7 @@ import PostgREST.ApiRequest.Types (AggregateFunction (..), OrderNulls (..), QuantOperator (..), SimpleOperator (..)) +import PostgREST.Config.PgVersion (PgVersion, pgVersion170) import PostgREST.MediaType (MTVndPlanFormat (..), MTVndPlanOption (..)) import PostgREST.Plan.ReadPlan (JoinCondition (..)) @@ -614,39 +615,60 @@ accessibleTables schema = SQL.sql (encodeUtf8 [trimming| where encodedSchema = SQL.encoderAndParam (HE.nonNullable HE.text) schema -accessibleFuncs :: Text -> SQL.Snippet -accessibleFuncs schema = baseFuncSqlQuery <> "AND p.pronamespace = " <> encodedSchema <> "::regnamespace" +accessibleFuncs :: PgVersion -> Text -> SQL.Snippet +accessibleFuncs pgVer schema = baseFuncSqlQuery pgVer <> "AND p.pronamespace = " <> encodedSchema <> "::regnamespace" where encodedSchema = SQL.encoderAndParam (HE.nonNullable HE.text) schema -baseFuncSqlQuery :: SQL.Snippet -baseFuncSqlQuery = SQL.sql $ encodeUtf8 [trimming| +baseTypesCte :: PgVersion -> Text +baseTypesCte pgVer + | pgVer >= pgVersion170 = [trimming| + -- Get base types using pg_basetype() (PG 17+) + base_types AS ( + SELECT + t.oid, + bt.typnamespace AS base_namespace, + bt.oid AS base_type + FROM pg_type t + JOIN pg_type bt ON bt.oid = pg_basetype(t.oid) + ) + |] + | otherwise = [trimming| + -- Recursively get the base types of domains (PG < 17) + base_types AS ( + WITH RECURSIVE + recurse AS ( + SELECT + oid, + typbasetype, + typnamespace AS base_namespace, + COALESCE(NULLIF(typbasetype, 0), oid) AS base_type + FROM pg_type + UNION + SELECT + t.oid, + b.typbasetype, + b.typnamespace AS base_namespace, + COALESCE(NULLIF(b.typbasetype, 0), b.oid) AS base_type + FROM recurse t + JOIN pg_type b ON t.typbasetype = b.oid + ) + SELECT + oid, + base_namespace, + base_type + FROM recurse + WHERE typbasetype = 0 + ) + |] + +-- | SQL query to get accessible functions for OpenAPI. +baseFuncSqlQuery :: PgVersion -> SQL.Snippet +baseFuncSqlQuery pgVer = + let baseCte = baseTypesCte pgVer + in SQL.sql $ encodeUtf8 [trimming| WITH - base_types AS ( - WITH RECURSIVE - recurse AS ( - SELECT - oid, - typbasetype, - typnamespace AS base_namespace, - COALESCE(NULLIF(typbasetype, 0), oid) AS base_type - FROM pg_type - UNION - SELECT - t.oid, - b.typbasetype, - b.typnamespace AS base_namespace, - COALESCE(NULLIF(b.typbasetype, 0), b.oid) AS base_type - FROM recurse t - JOIN pg_type b ON t.typbasetype = b.oid - ) - SELECT - oid, - base_namespace, - base_type - FROM recurse - WHERE typbasetype = 0 - ), + $baseCte, arguments AS ( SELECT oid, diff --git a/src/PostgREST/SchemaCache.hs b/src/PostgREST/SchemaCache.hs index 528d7905f4..c490225c67 100644 --- a/src/PostgREST/SchemaCache.hs +++ b/src/PostgREST/SchemaCache.hs @@ -42,6 +42,7 @@ import NeatInterpolation (trimming) import PostgREST.Config (AppConfig (..)) import PostgREST.Config.Database (TimezoneNames, toIsolationLevel) +import PostgREST.Config.PgVersion (PgVersion, pgVersion170) import PostgREST.SchemaCache.Identifiers (FieldName, QualifiedIdentifier (..), RelIdentifier (..), @@ -139,13 +140,13 @@ data KeyDep type SqlQuery = ByteString -querySchemaCache :: AppConfig -> SQL.Transaction SchemaCache -querySchemaCache conf@AppConfig{..} = do +querySchemaCache :: PgVersion -> AppConfig -> SQL.Transaction SchemaCache +querySchemaCache pgVer conf@AppConfig{..} = do SQL.sql "set local schema ''" -- This voids the search path. The following queries need this for getting the fully qualified name(schema.name) of every db object - tabs <- SQL.statement conf $ allTables prepared + tabs <- SQL.statement conf $ allTables pgVer prepared keyDeps <- SQL.statement conf $ allViewsKeyDependencies prepared m2oRels <- SQL.statement mempty $ allM2OandO2ORels prepared - funcs <- SQL.statement conf $ allFunctions prepared + funcs <- SQL.statement conf $ allFunctions pgVer prepared cRels <- SQL.statement mempty $ allComputedRels prepared reps <- SQL.statement conf $ dataRepresentations prepared mHdlers <- SQL.statement conf $ mediaHandlers prepared @@ -353,47 +354,61 @@ dataRepresentations = SQL.Statement sql mempty decodeRepresentations OR (dst_t.typtype = 'd' AND c.castsource IN ('json'::regtype::oid , 'text'::regtype::oid))) |] -allFunctions :: Bool -> SQL.Statement AppConfig RoutineMap -allFunctions = SQL.Statement funcsSqlQuery params decodeFuncs +allFunctions :: PgVersion -> Bool -> SQL.Statement AppConfig RoutineMap +allFunctions pgVer = SQL.Statement (funcsSqlQuery pgVer) params decodeFuncs where params = (map escapeIdent . toList . configDbSchemas >$< arrayParam HE.text) <> (configDbHoistedTxSettings >$< arrayParam HE.text) -baseTypesCte :: Text -baseTypesCte = [trimming| - -- Recursively get the base types of domains - base_types AS ( - WITH RECURSIVE - recurse AS ( - SELECT - oid, - typbasetype, - typnamespace AS base_namespace, - COALESCE(NULLIF(typbasetype, 0), oid) AS base_type - FROM pg_type - UNION - SELECT - t.oid, - b.typbasetype, - b.typnamespace AS base_namespace, - COALESCE(NULLIF(b.typbasetype, 0), b.oid) AS base_type - FROM recurse t - JOIN pg_type b ON t.typbasetype = b.oid - ) - SELECT - oid, - base_namespace, - base_type - FROM recurse - WHERE typbasetype = 0 - ) -|] +baseTypesCte :: PgVersion -> Text +baseTypesCte pgVer + | pgVer >= pgVersion170 = [trimming| + -- Get base types using pg_basetype() (PG 17+) + base_types AS ( + SELECT + t.oid, + bt.typnamespace AS base_namespace, + bt.oid AS base_type + FROM pg_type t + JOIN pg_type bt ON bt.oid = pg_basetype(t.oid) + ) + |] + | otherwise = [trimming| + -- Recursively get the base types of domains (PG < 17) + base_types AS ( + WITH RECURSIVE + recurse AS ( + SELECT + oid, + typbasetype, + typnamespace AS base_namespace, + COALESCE(NULLIF(typbasetype, 0), oid) AS base_type + FROM pg_type + UNION + SELECT + t.oid, + b.typbasetype, + b.typnamespace AS base_namespace, + COALESCE(NULLIF(b.typbasetype, 0), b.oid) AS base_type + FROM recurse t + JOIN pg_type b ON t.typbasetype = b.oid + ) + SELECT + oid, + base_namespace, + base_type + FROM recurse + WHERE typbasetype = 0 + ) + |] -funcsSqlQuery :: SqlQuery -funcsSqlQuery = encodeUtf8 [trimming| +funcsSqlQuery :: PgVersion -> SqlQuery +funcsSqlQuery pgVer = + let baseCte = baseTypesCte pgVer + in encodeUtf8 [trimming| WITH - $baseTypesCte, + $baseCte, arguments AS ( SELECT oid, @@ -566,22 +581,23 @@ addViewPrimaryKeys tabs keyDeps = takeFirstPK = mapMaybe (head . snd) indexedDeps = HM.fromListWith (++) $ fmap ((keyDepType &&& keyDepView) &&& pure) keyDeps -allTables :: Bool -> SQL.Statement AppConfig TablesMap -allTables = SQL.Statement tablesSqlQuery params decodeTables +allTables :: PgVersion -> Bool -> SQL.Statement AppConfig TablesMap +allTables pgVer = SQL.Statement (tablesSqlQuery pgVer) params decodeTables where params = map escapeIdent . toList . configDbSchemas >$< arrayParam HE.text -- | Gets tables with their PK cols -tablesSqlQuery :: SqlQuery -tablesSqlQuery = +tablesSqlQuery :: PgVersion -> SqlQuery +tablesSqlQuery pgVer = -- the tbl_constraints/key_col_usage CTEs are based on the standard "information_schema.table_constraints"/"information_schema.key_column_usage" views, -- we cannot use those directly as they include the following privilege filter: -- (pg_has_role(ss.relowner, 'USAGE'::text) OR has_column_privilege(ss.roid, a.attnum, 'SELECT, INSERT, UPDATE, REFERENCES'::text)); -- on the "columns" CTE, left joining on pg_depend and pg_class is used to obtain the sequence name as a column default in case there are GENERATED .. AS IDENTITY, -- generated columns are only available from pg >= 10 but the query is agnostic to versions. dep.deptype = 'i' is done because there are other 'a' dependencies on PKs - encodeUtf8 [trimming| + let baseCte = baseTypesCte pgVer + in encodeUtf8 [trimming| WITH - $baseTypesCte, + $baseCte, columns AS ( SELECT c.oid AS relid, diff --git a/test/io/test_io.py b/test/io/test_io.py index 0e3648a548..1f3ce8a7ea 100644 --- a/test/io/test_io.py +++ b/test/io/test_io.py @@ -759,7 +759,7 @@ def drain_stdout(proc): ) infinite_recursion_5xx_regx = r'.+: WITH pgrst_source AS.+SELECT "public"\."infinite_recursion"\.\* FROM "public"\."infinite_recursion".+_postgrest_t' root_tables_regx = r".+: SELECT n.nspname AS table_schema, .+ FROM pg_class c .+ ORDER BY table_schema, table_name" - root_procs_regx = r".+: WITH base_types AS \(.+\) SELECT pn.nspname AS proc_schema, .+ FROM pg_proc p.+AND p.pronamespace = \$1::regnamespace" + root_procs_regx = r".+: WITH.+base_types AS.+pn\.nspname AS proc_schema.+FROM pg_proc p.+p\.pronamespace = \$1::regnamespace" root_descr_regx = r".+: SELECT pg_catalog\.obj_description\(\$1::regnamespace, 'pg_namespace'\)" set_config_regx = ( r".+: select set_config\('search_path', \$1, true\), set_config\(" diff --git a/test/spec/Main.hs b/test/spec/Main.hs index e847926b6f..f0f0ca4305 100644 --- a/test/spec/Main.hs +++ b/test/spec/Main.hs @@ -83,7 +83,7 @@ main = do actualPgVersion <- either (panic . show) id <$> P.use pool (queryPgVersion False) -- cached schema cache so most tests run fast - baseSchemaCache <- loadSCache pool testCfg + baseSchemaCache <- loadSCache pool actualPgVersion testCfg sockets <- AppState.initSockets testCfg loggerState <- Logger.init metricsState <- Metrics.init (configDbPoolSize testCfg) @@ -100,7 +100,7 @@ main = do -- For tests that run with a different SchemaCache (depends on configSchemas) appDbs config = do - customSchemaCache <- loadSCache pool config + customSchemaCache <- loadSCache pool actualPgVersion config initApp customSchemaCache () config let withApp = app testCfg @@ -279,5 +279,5 @@ main = do describe "Feature.Auth.JwtCacheSpec" Feature.Auth.JwtCacheSpec.spec where - loadSCache pool conf = - either (panic.show) id <$> P.use pool (HT.transaction HT.ReadCommitted HT.Read $ querySchemaCache conf) + loadSCache pool pgVer conf = + either (panic.show) id <$> P.use pool (HT.transaction HT.ReadCommitted HT.Read $ querySchemaCache pgVer conf)