Skip to content

Commit 89bca60

Browse files
Aleksandr Sterkhovmikicz
authored andcommitted
Updated view creation to not assume default DB schema is "public"
1 parent c809ae1 commit 89bca60

File tree

2 files changed

+63
-13
lines changed

2 files changed

+63
-13
lines changed

django_pgviews/view.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def _schema_and_name(connection, view_name):
7171
try:
7272
schema_name = connection.schema_name
7373
except AttributeError:
74-
schema_name = "public"
74+
schema_name = None
7575

7676
return schema_name, view_name
7777

@@ -119,6 +119,26 @@ def _create_index_sql(self, *args, **kwargs):
119119
return statement
120120

121121

122+
def _make_where(**kwargs):
123+
where_fragments = []
124+
params = []
125+
126+
for key, value in kwargs.items():
127+
if value is None:
128+
# skip key if value is not specified
129+
continue
130+
131+
if isinstance(value, (list, tuple)):
132+
in_fragment = ", ".join("%s" for _ in range(len(value)))
133+
where_fragments.append(f"{key} IN ({in_fragment})")
134+
params.extend(list(value))
135+
else:
136+
where_fragments.append(f"{key} = %s")
137+
params.append(value)
138+
where_fragment = " AND ".join(where_fragments)
139+
return where_fragment, params
140+
141+
122142
def _ensure_indexes(connection, cursor, view_cls, schema_name_log):
123143
"""
124144
This function gets called when a materialized view is deemed not needing a re-create. That is however only a part
@@ -131,7 +151,8 @@ def _ensure_indexes(connection, cursor, view_cls, schema_name_log):
131151
indexes = view_cls._meta.indexes
132152
vschema, vname = _schema_and_name(connection, view_name)
133153

134-
cursor.execute("SELECT indexname FROM pg_indexes WHERE tablename = %s AND schemaname = %s", [vname, vschema])
154+
where_fragment, params = _make_where(schemaname=vschema, tablename=vname)
155+
cursor.execute(f"SELECT indexname FROM pg_indexes WHERE {where_fragment}", params)
135156

136157
existing_indexes = {x[0] for x in cursor.fetchall()}
137158
required_indexes = {x.name for x in indexes}
@@ -143,7 +164,11 @@ def _ensure_indexes(connection, cursor, view_cls, schema_name_log):
143164
concurrent_index_name = None
144165

145166
for index_name in existing_indexes - required_indexes:
146-
cursor.execute(f"DROP INDEX {vschema}.{index_name}")
167+
if vschema:
168+
full_index_name = f"{vschema}.{index_name}"
169+
else:
170+
full_index_name = index_name
171+
cursor.execute(f"DROP INDEX {full_index_name}")
147172
logger.info("pgview dropped index %s on view %s (%s)", index_name, view_name, schema_name_log)
148173

149174
schema_editor: DatabaseSchemaEditor = CustomSchemaEditor(connection)
@@ -188,10 +213,12 @@ def create_materialized_view(connection, view_cls, check_sql_changed=False):
188213
cursor_wrapper = connection.cursor()
189214
cursor = cursor_wrapper.cursor
190215

216+
where_fragment, params = _make_where(schemaname=vschema, matviewname=vname)
217+
191218
try:
192219
cursor.execute(
193-
"SELECT COUNT(*) FROM pg_matviews WHERE schemaname = %s and matviewname = %s;",
194-
[vschema, vname],
220+
f"SELECT COUNT(*) FROM pg_matviews WHERE {where_fragment};",
221+
params,
195222
)
196223
view_exists = cursor.fetchone()[0] > 0
197224

@@ -206,9 +233,10 @@ def create_materialized_view(connection, view_cls, check_sql_changed=False):
206233
_drop_mat_view(cursor, temp_viewname)
207234
_create_mat_view(cursor, temp_viewname, query, view_query.params, with_data=False)
208235

236+
definitions_where, definitions_params = _make_where(schemaname=vschema, matviewname=[vname, temp_vname])
209237
cursor.execute(
210-
"SELECT definition FROM pg_matviews WHERE schemaname = %s and matviewname IN (%s, %s);",
211-
[vschema, vname, temp_vname],
238+
f"SELECT definition FROM pg_matviews WHERE {definitions_where};",
239+
definitions_params,
212240
)
213241
definitions = cursor.fetchall()
214242

@@ -265,9 +293,10 @@ def create_view(connection, view_name, view_query: ViewSQL, update=True, force=F
265293
try:
266294
force_required = False
267295
# Determine if view already exists.
296+
view_exists_where, view_exists_params = _make_where(table_schema=vschema, table_name=vname)
268297
cursor.execute(
269-
"SELECT COUNT(*) FROM information_schema.views WHERE table_schema = %s and table_name = %s;",
270-
[vschema, vname],
298+
f"SELECT COUNT(*) FROM information_schema.views WHERE {view_exists_where};",
299+
view_exists_params,
271300
)
272301
view_exists = cursor.fetchone()[0] > 0
273302
if view_exists and not update:

tests/test_project/viewtest/tests.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from django.utils import timezone
1515

1616
from django_pgviews.signals import all_views_synced, view_synced
17-
from django_pgviews.view import _schema_and_name
17+
from django_pgviews.view import _make_where, _schema_and_name
1818

1919
from . import models
2020
from .models import LatestSuperusers
@@ -27,8 +27,8 @@
2727

2828
def get_list_of_indexes(cursor, cls):
2929
schema, table = _schema_and_name(cursor.connection, cls._meta.db_table)
30-
31-
cursor.execute("SELECT indexname FROM pg_indexes WHERE tablename = %s AND schemaname = %s", [table, schema])
30+
where_fragment, params = _make_where(tablename=table, schemaname=schema)
31+
cursor.execute(f"SELECT indexname FROM pg_indexes WHERE {where_fragment}", params)
3232
return {x[0] for x in cursor.fetchall()}
3333

3434

@@ -150,7 +150,6 @@ def test_refresh_missing(self):
150150
def test_materialized_view_indexes(self):
151151
with connection.cursor() as cursor:
152152
orig_indexes = get_list_of_indexes(cursor, models.MaterializedRelatedViewWithIndex)
153-
154153
self.assertIn("viewtest_materializedrelatedviewwithindex_id_index", orig_indexes)
155154
self.assertEqual(len(orig_indexes), 2)
156155

@@ -417,3 +416,25 @@ def test_sync_depending_materialized_views(self):
417416

418417
with self.assertRaises(DatabaseError):
419418
cur.execute("""SELECT name from viewtest_dependantmaterializedview;""")
419+
420+
421+
class MakeWhereTestCase(TestCase):
422+
def test_with_schema(self):
423+
where_fragment, params = _make_where(schemaname="test_schema", tablename="test_tablename")
424+
self.assertEqual(where_fragment, "schemaname = %s AND tablename = %s")
425+
self.assertEqual(params, ["test_schema", "test_tablename"])
426+
427+
def test_no_schema(self):
428+
where_fragment, params = _make_where(schemaname=None, tablename="test_tablename")
429+
self.assertEqual(where_fragment, "tablename = %s")
430+
self.assertEqual(params, ["test_tablename"])
431+
432+
def test_with_schema_list(self):
433+
where_fragment, params = _make_where(schemaname="test_schema", tablename=["test_tablename1", "test_tablename2"])
434+
self.assertEqual(where_fragment, "schemaname = %s AND tablename IN (%s, %s)")
435+
self.assertEqual(params, ["test_schema", "test_tablename1", "test_tablename2"])
436+
437+
def test_no_schema_list(self):
438+
where_fragment, params = _make_where(schemaname=None, tablename=["test_tablename1", "test_tablename2"])
439+
self.assertEqual(where_fragment, "tablename IN (%s, %s)")
440+
self.assertEqual(params, ["test_tablename1", "test_tablename2"])

0 commit comments

Comments
 (0)