diff --git a/sqlalchemy_utils/view.py b/sqlalchemy_utils/view.py index 96103db1..390b6d3e 100644 --- a/sqlalchemy_utils/view.py +++ b/sqlalchemy_utils/view.py @@ -1,50 +1,156 @@ +from typing import Any, Dict, List, Optional, TYPE_CHECKING + import sqlalchemy as sa from sqlalchemy.ext import compiler from sqlalchemy.schema import DDLElement, PrimaryKeyConstraint from sqlalchemy_utils.functions import get_columns +if TYPE_CHECKING: + from sqlalchemy.engine.default import DefaultDialect + from sqlalchemy.orm import Session + from sqlalchemy.sql import Selectable + from sqlalchemy.sql.compiler import SQLCompiler + + +def _prepare_view_identifier( + dialect: 'DefaultDialect', + view_name: str, + schema: Optional[str] = None, +) -> str: + quoted_view_name = dialect.identifier_preparer.quote(view_name) + if schema: + return dialect.identifier_preparer.quote_schema(schema) + '.' + quoted_view_name + + return quoted_view_name + class CreateView(DDLElement): - def __init__(self, name, selectable, materialized=False): + def __init__( + self, + name: str, + selectable: 'Selectable', + schema: Optional[str] = None, + ): self.name = name self.selectable = selectable - self.materialized = materialized + self.schema = schema @compiler.compiles(CreateView) -def compile_create_materialized_view(element, compiler, **kw): - return 'CREATE {}VIEW {} AS {}'.format( - 'MATERIALIZED ' if element.materialized else '', - compiler.dialect.identifier_preparer.quote(element.name), - compiler.sql_compiler.process(element.selectable, literal_binds=True), +def compile_create_view( + element: 'CreateView', + compiler: 'SQLCompiler', + **kw: Any, +) -> str: + view_identifier = _prepare_view_identifier( + compiler.dialect, element.name, element.schema + ) + compiled_selectable = compiler.sql_compiler.process( + element.selectable, literal_binds=True ) + return f'CREATE VIEW {view_identifier} AS {compiled_selectable}' class DropView(DDLElement): - def __init__(self, name, materialized=False, cascade=True): + def __init__( + self, + name: str, + schema: Optional[str] = None, + cascade: Optional[bool] = None, + ): self.name = name - self.materialized = materialized + self.schema = schema self.cascade = cascade @compiler.compiles(DropView) -def compile_drop_materialized_view(element, compiler, **kw): - return 'DROP {}VIEW IF EXISTS {} {}'.format( - 'MATERIALIZED ' if element.materialized else '', - compiler.dialect.identifier_preparer.quote(element.name), - 'CASCADE' if element.cascade else '' +def compile_drop_view(element: 'DropView', compiler: 'SQLCompiler', **kw: Any) -> str: + view_identifier = _prepare_view_identifier( + compiler.dialect, element.name, element.schema + ) + + stmt = f'DROP VIEW IF EXISTS {view_identifier}' + if element.cascade is True: + stmt += ' CASCADE' + elif element.cascade is False: + stmt += ' RESTRICT' + return stmt + + +class CreateMaterializedView(DDLElement): + def __init__( + self, + name: str, + selectable: 'Selectable', + schema: Optional[str] = None, + populate: Optional[bool] = None, + ): + self.name = name + self.selectable = selectable + self.schema = schema + self.populate = populate + + +@compiler.compiles(CreateMaterializedView) +def compile_create_materialized_view( + element: 'CreateMaterializedView', + compiler: 'SQLCompiler', + **kw: Any, +) -> str: + view_identifier = _prepare_view_identifier( + dialect=compiler.dialect, view_name=element.name, schema=element.schema + ) + compiled_selectable = compiler.sql_compiler.process( + element.selectable, literal_binds=True ) + stmt = f'CREATE MATERIALIZED VIEW {view_identifier} AS {compiled_selectable}' + if element.populate is True: + stmt += ' WITH DATA' + elif element.populate is False: + stmt += ' WITH NO DATA' + return stmt + + +class DropMaterializedView(DDLElement): + def __init__( + self, + name: str, + schema: Optional[str] = None, + cascade: Optional[bool] = None, + ): + self.name = name + self.schema = schema + self.cascade = cascade + + +@compiler.compiles(DropMaterializedView) +def compile_drop_materialized_view( + element: 'DropMaterializedView', + compiler: 'SQLCompiler', + **kw: Any, +) -> str: + view_identifier = _prepare_view_identifier( + dialect=compiler.dialect, view_name=element.name, schema=element.schema + ) + stmt = f'DROP MATERIALIZED VIEW IF EXISTS {view_identifier}' + if element.cascade is True: + stmt += ' CASCADE' + elif element.cascade is False: + stmt += ' RESTRICT' + return stmt + def create_table_from_selectable( - name, - selectable, - indexes=None, - metadata=None, - aliases=None, - **kwargs -): + name: str, + selectable: 'Selectable', + indexes: Optional[List[sa.Index]] = None, + metadata: Optional[sa.MetaData] = None, + aliases: Optional[Dict[str, str]] = None, + schema: Optional[str] = None, + **kwargs: Any, +) -> sa.Table: if indexes is None: indexes = [] if metadata is None: @@ -60,7 +166,7 @@ def create_table_from_selectable( ) for c in get_columns(selectable) ] + indexes - table = sa.Table(name, metadata, *args, **kwargs) + table = sa.Table(name, metadata, *args, schema=schema, **kwargs) if not any([c.primary_key for c in get_columns(selectable)]): table.append_constraint( @@ -70,12 +176,16 @@ def create_table_from_selectable( def create_materialized_view( - name, - selectable, - metadata, - indexes=None, - aliases=None -): + name: str, + selectable: 'Selectable', + metadata: sa.MetaData, + indexes: Optional[List[sa.Index]] = None, + aliases: Optional[Dict[str, str]] = None, + *, + schema: Optional[str] = None, + populate: Optional[bool] = None, + cascade_on_drop: Optional[bool] = None, +) -> sa.Table: """ Create a view on a given metadata :param name: The name of the view to create. @@ -87,6 +197,17 @@ def create_materialized_view( :param aliases: An optional dictionary containing with keys as column names and values as column aliases. + :param schema: The name of the schema where the view will be created (optional). + :param populate: + Set ``populate=True`` to create the view with ``WITH DATA``. + Set ``populate=False`` to create the view with ``WITH NO DATA``. + Default to ``None`` for no flags. + See also: https://www.postgresql.org/docs/current/sql-createview.html + :param cascade_on_drop: + Set ``cascade_on_drop=True`` to drop the view with ``CASCADE``. + Set ``cascade_on_drop=False`` to create the view with ``RESTRICT``. + Default to ``None`` for no flags. + See also: https://www.postgresql.org/docs/current/sql-dropmaterializedview.html Same as for ``create_view`` except that a ``CREATE MATERIALIZED VIEW`` statement is emitted instead of a ``CREATE VIEW``. @@ -97,13 +218,14 @@ def create_materialized_view( selectable=selectable, indexes=indexes, metadata=None, - aliases=aliases + aliases=aliases, + schema=schema, ) sa.event.listen( metadata, 'after_create', - CreateView(name, selectable, materialized=True) + CreateMaterializedView(name, selectable, schema=schema, populate=populate) ) @sa.event.listens_for(metadata, 'after_create') @@ -114,17 +236,19 @@ def create_indexes(target, connection, **kw): sa.event.listen( metadata, 'before_drop', - DropView(name, materialized=True) + DropMaterializedView(name, schema=schema, cascade=cascade_on_drop) ) return table def create_view( - name, - selectable, - metadata, - cascade_on_drop=True -): + name: str, + selectable: 'Selectable', + metadata: sa.MetaData, + *, + schema: Optional[str] = None, + cascade_on_drop: Optional[str] = None, +) -> sa.Table: """ Create a view on a given metadata :param name: The name of the view to create. @@ -132,6 +256,11 @@ def create_view( :param metadata: An SQLAlchemy Metadata instance that stores the features of the database being described. + :param schema: The name of the schema where the view will be created (optional). + :param cascade_on_drop: + Set ``cascade_on_drop=True`` to drop the view with ``CASCADE``. + Set ``cascade_on_drop=False`` to create the view with ``RESTRICT``. + Default to ``None`` for no flags. The process for creating a view is similar to the standard way that a table is constructed, except that a selectable is provided instead of @@ -160,10 +289,15 @@ def create_view( table = create_table_from_selectable( name=name, selectable=selectable, - metadata=None + metadata=None, + schema=schema, ) - sa.event.listen(metadata, 'after_create', CreateView(name, selectable)) + sa.event.listen( + metadata, + 'after_create', + CreateView(name, selectable, schema=schema), + ) @sa.event.listens_for(metadata, 'after_create') def create_indexes(target, connection, **kw): @@ -173,12 +307,18 @@ def create_indexes(target, connection, **kw): sa.event.listen( metadata, 'before_drop', - DropView(name, cascade=cascade_on_drop) + DropView(name, schema=schema, cascade=cascade_on_drop) ) return table -def refresh_materialized_view(session, name, concurrently=False): +def refresh_materialized_view( + session: 'Session', + name: str, + concurrently: bool = False, + *, + schema: Optional[str] = None, +) -> None: """ Refreshes an already existing materialized view :param session: An SQLAlchemy Session instance. @@ -186,6 +326,7 @@ def refresh_materialized_view(session, name, concurrently=False): :param concurrently: Optional flag that causes the ``CONCURRENTLY`` parameter to be specified when the materialized view is refreshed. + :param schema: The schema of the view to be refreshed (optional). """ # Since session.execute() bypasses autoflush, we must manually flush in # order to include newly-created/modified objects in the refresh. @@ -193,6 +334,6 @@ def refresh_materialized_view(session, name, concurrently=False): session.execute( sa.text('REFRESH MATERIALIZED VIEW {}{}'.format( 'CONCURRENTLY ' if concurrently else '', - session.bind.engine.dialect.identifier_preparer.quote(name) + _prepare_view_identifier(session.bind.engine.dialect, name, schema), )) ) diff --git a/tests/test_views.py b/tests/test_views.py index c4be7099..e76bff23 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -73,7 +73,41 @@ class ArticleView(Base): @pytest.fixture -def init_models(ArticleMV, ArticleView): +def view_schema(Base): + sa.event.listen( + Base.metadata, + "before_create", + sa.DDL("CREATE SCHEMA IF NOT EXISTS views"), + ) + + +@pytest.fixture +def UserMV(Base, User): + class UserMV(Base): + __table__ = create_materialized_view( + name='user-mv', + selectable=sa.select(*_select_args(User.id)), + metadata=Base.metadata, + schema='views', + populate=False, + ) + return UserMV + + +@pytest.fixture +def UserView(Base, User): + class UserView(Base): + __table__ = create_view( + name='user-view', + selectable=sa.select(*_select_args(User.id)), + metadata=Base.metadata, + schema='views', + ) + return UserView + + +@pytest.fixture +def init_models(view_schema, ArticleMV, ArticleView, UserMV, UserView): pass @@ -114,6 +148,21 @@ def test_querying_view( assert row.name == 'Some article' assert row.author_name == 'Some user' + def test_querying_view_in_schema(self, session, User, UserView): + user = User(name='Some user') + session.add(user) + session.commit() + assert session.query(User).first().id == session.query(UserView).first().id + assert 'views."user-view"' in str(session.query(UserView)) + + def test_querying_unpopulated_mv_in_schema(self, session, User, UserMV): + with pytest.raises(sa.exc.OperationalError): + session.query(UserMV).first() + session.rollback() + + refresh_materialized_view(session, 'user-mv', schema='views') + session.query(UserMV).all() + class TrivialViewTestCases: def life_cycle(