Skip to content

Commit b7c1db9

Browse files
committed
[view] Support CREATE OR REPLACE
1 parent 4b05d05 commit b7c1db9

File tree

2 files changed

+76
-10
lines changed

2 files changed

+76
-10
lines changed

sqlalchemy_utils/view.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,19 @@
77

88

99
class CreateView(DDLElement):
10-
def __init__(self, name, selectable, materialized=False):
10+
def __init__(self, name, selectable, materialized=False, replace=False):
11+
if materialized and replace:
12+
raise ValueError("Cannot use CREATE OR REPLACE with materialized views")
1113
self.name = name
1214
self.selectable = selectable
1315
self.materialized = materialized
16+
self.replace = replace
1417

1518

1619
@compiler.compiles(CreateView)
1720
def compile_create_materialized_view(element, compiler, **kw):
18-
return 'CREATE {}VIEW {} AS {}'.format(
21+
return 'CREATE {}{}VIEW {} AS {}'.format(
22+
'OR REPLACE ' if element.replace else '',
1923
'MATERIALIZED ' if element.materialized else '',
2024
compiler.dialect.identifier_preparer.quote(element.name),
2125
compiler.sql_compiler.process(element.selectable, literal_binds=True),
@@ -124,7 +128,8 @@ def create_view(
124128
name,
125129
selectable,
126130
metadata,
127-
cascade_on_drop=True
131+
cascade_on_drop=True,
132+
replace=False,
128133
):
129134
""" Create a view on a given metadata
130135
@@ -164,7 +169,11 @@ def create_view(
164169
metadata=None
165170
)
166171

167-
sa.event.listen(metadata, 'after_create', CreateView(name, selectable))
172+
sa.event.listen(
173+
metadata,
174+
'after_create',
175+
CreateView(name, selectable, replace=replace),
176+
)
168177

169178
@sa.event.listens_for(metadata, 'after_create')
170179
def create_indexes(target, connection, **kw):

tests/test_views.py

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
create_view,
88
refresh_materialized_view
99
)
10+
from sqlalchemy_utils.view import CreateView
1011
from sqlalchemy_utils.compat import _select_args
1112

1213

@@ -121,16 +122,18 @@ def life_cycle(
121122
engine,
122123
metadata,
123124
column,
124-
cascade_on_drop
125+
cascade_on_drop,
126+
replace=False,
125127
):
126128
__table__ = create_view(
127129
name='trivial_view',
128130
selectable=sa.select(*_select_args(column)),
129131
metadata=metadata,
130-
cascade_on_drop=cascade_on_drop
132+
cascade_on_drop=cascade_on_drop,
133+
replace=replace,
131134
)
132-
__table__.create(engine)
133-
__table__.drop(engine)
135+
metadata.create_all(engine)
136+
metadata.drop_all(engine)
134137

135138

136139
class SupportsCascade(TrivialViewTestCases):
@@ -164,13 +167,67 @@ def test_life_cycle_no_cascade(
164167
self.life_cycle(engine, Base.metadata, User.id, cascade_on_drop=False)
165168

166169

170+
class SupportsReplace(TrivialViewTestCases):
171+
def test_life_cycle_replace(
172+
self,
173+
connection,
174+
engine,
175+
Base,
176+
User
177+
):
178+
self.life_cycle(
179+
engine,
180+
Base.metadata,
181+
User.id,
182+
cascade_on_drop=False,
183+
replace=True,
184+
)
185+
186+
def test_life_cycle_replace_existing(
187+
self,
188+
connection,
189+
engine,
190+
Base,
191+
User
192+
):
193+
__table__ = create_view(
194+
name='trivial_view',
195+
selectable=sa.select(*_select_args(User.id)),
196+
metadata=Base.metadata,
197+
)
198+
Base.metadata.create_all(engine)
199+
view = CreateView(
200+
name='trivial_view',
201+
selectable=sa.select(*_select_args(User.id)),
202+
replace=True,
203+
)
204+
connection.execute(view)
205+
connection.commit()
206+
Base.metadata.drop_all(engine)
207+
208+
def test_replace_materialized(
209+
self,
210+
connection,
211+
engine,
212+
Base,
213+
User
214+
):
215+
with pytest.raises(ValueError):
216+
CreateView(
217+
name='trivial_view',
218+
selectable=sa.select(*_select_args(User.id)),
219+
materialized=True,
220+
replace=True,
221+
)
222+
223+
167224
@pytest.mark.usefixtures('postgresql_dsn')
168-
class TestPostgresTrivialView(SupportsCascade, SupportsNoCascade):
225+
class TestPostgresTrivialView(SupportsCascade, SupportsNoCascade, SupportsReplace):
169226
pass
170227

171228

172229
@pytest.mark.usefixtures('mysql_dsn')
173-
class TestMySqlTrivialView(SupportsCascade, SupportsNoCascade):
230+
class TestMySqlTrivialView(SupportsCascade, SupportsNoCascade, SupportsReplace):
174231
pass
175232

176233

0 commit comments

Comments
 (0)