From 938755148e5105bfba405845508361f39c4d914c Mon Sep 17 00:00:00 2001 From: Michael Gorven Date: Fri, 9 Jun 2023 14:40:54 -0700 Subject: [PATCH] [view] Support CREATE OR REPLACE --- sqlalchemy_utils/view.py | 21 ++++++++--- tests/test_views.py | 75 +++++++++++++++++++++++++++++++++++----- 2 files changed, 83 insertions(+), 13 deletions(-) diff --git a/sqlalchemy_utils/view.py b/sqlalchemy_utils/view.py index 96cbe36c..fae06bfc 100644 --- a/sqlalchemy_utils/view.py +++ b/sqlalchemy_utils/view.py @@ -7,15 +7,19 @@ class CreateView(DDLElement): - def __init__(self, name, selectable, materialized=False): + def __init__(self, name, selectable, materialized=False, replace=False): + if materialized and replace: + raise ValueError("Cannot use CREATE OR REPLACE with materialized views") self.name = name self.selectable = selectable self.materialized = materialized + self.replace = replace @compiler.compiles(CreateView) def compile_create_materialized_view(element, compiler, **kw): - return 'CREATE {}VIEW {} AS {}'.format( + return 'CREATE {}{}VIEW {} AS {}'.format( + 'OR REPLACE ' if element.replace else '', 'MATERIALIZED ' if element.materialized else '', compiler.dialect.identifier_preparer.quote(element.name), compiler.sql_compiler.process(element.selectable, literal_binds=True), @@ -124,7 +128,8 @@ def create_view( name, selectable, metadata, - cascade_on_drop=True + cascade_on_drop=True, + replace=False, ): """ Create a view on a given metadata @@ -133,6 +138,10 @@ def create_view( :param metadata: An SQLAlchemy Metadata instance that stores the features of the database being described. + :param cascade_on_drop: If ``True`` the view will be dropped with + ``CASCADE``, deleting all dependent objects as well. + :param replace: If ``True`` the view will be created with ``OR REPLACE``, + replacing an existing view with the same name. The process for creating a view is similar to the standard way that a table is constructed, except that a selectable is provided instead of @@ -164,7 +173,11 @@ def create_view( metadata=None ) - sa.event.listen(metadata, 'after_create', CreateView(name, selectable)) + sa.event.listen( + metadata, + 'after_create', + CreateView(name, selectable, replace=replace), + ) @sa.event.listens_for(metadata, 'after_create') def create_indexes(target, connection, **kw): diff --git a/tests/test_views.py b/tests/test_views.py index b3c0e2c2..2ec3ef96 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -8,6 +8,7 @@ refresh_materialized_view ) from sqlalchemy_utils.compat import _select_args +from sqlalchemy_utils.view import CreateView @pytest.fixture @@ -15,7 +16,7 @@ def Article(Base, User): class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.String) + name = sa.Column(sa.String(128)) author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id)) author = sa.orm.relationship(User) return Article @@ -26,7 +27,7 @@ def User(Base): class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.String) + name = sa.Column(sa.String(128)) return User @@ -121,16 +122,18 @@ def life_cycle( engine, metadata, column, - cascade_on_drop + cascade_on_drop, + replace=False, ): - __table__ = create_view( + create_view( name='trivial_view', selectable=sa.select(*_select_args(column)), metadata=metadata, - cascade_on_drop=cascade_on_drop + cascade_on_drop=cascade_on_drop, + replace=replace, ) - __table__.create(engine) - __table__.drop(engine) + metadata.create_all(engine) + metadata.drop_all(engine) class SupportsCascade(TrivialViewTestCases): @@ -164,13 +167,67 @@ def test_life_cycle_no_cascade( self.life_cycle(engine, Base.metadata, User.id, cascade_on_drop=False) +class SupportsReplace(TrivialViewTestCases): + def test_life_cycle_replace( + self, + connection, + engine, + Base, + User + ): + self.life_cycle( + engine, + Base.metadata, + User.id, + cascade_on_drop=False, + replace=True, + ) + + def test_life_cycle_replace_existing( + self, + connection, + engine, + Base, + User + ): + create_view( + name='trivial_view', + selectable=sa.select(*_select_args(User.id)), + metadata=Base.metadata, + ) + Base.metadata.create_all(engine) + view = CreateView( + name='trivial_view', + selectable=sa.select(*_select_args(User.id)), + replace=True, + ) + with connection.begin(): + connection.execute(view) + Base.metadata.drop_all(engine) + + def test_replace_materialized( + self, + connection, + engine, + Base, + User + ): + with pytest.raises(ValueError): + CreateView( + name='trivial_view', + selectable=sa.select(*_select_args(User.id)), + materialized=True, + replace=True, + ) + + @pytest.mark.usefixtures('postgresql_dsn') -class TestPostgresTrivialView(SupportsCascade, SupportsNoCascade): +class TestPostgresTrivialView(SupportsCascade, SupportsNoCascade, SupportsReplace): pass @pytest.mark.usefixtures('mysql_dsn') -class TestMySqlTrivialView(SupportsCascade, SupportsNoCascade): +class TestMySqlTrivialView(SupportsCascade, SupportsNoCascade, SupportsReplace): pass