From 17a32d4fd1e712c9080b4d9cc298a319cddb2b8f Mon Sep 17 00:00:00 2001 From: Paul Draper Date: Sun, 31 Jan 2021 20:45:56 -0700 Subject: [PATCH] Implement comments --- schemainspect/misc.py | 4 +- schemainspect/pg/obj.py | 62 +++++++++++++++++++++++++++++++ schemainspect/pg/sql/comments.sql | 54 +++++++++++++++++++++++++++ tests/test_all.py | 13 +++++++ 4 files changed, 132 insertions(+), 1 deletion(-) create mode 100644 schemainspect/pg/sql/comments.sql diff --git a/schemainspect/misc.py b/schemainspect/misc.py index c4549cf..51ee1b8 100644 --- a/schemainspect/misc.py +++ b/schemainspect/misc.py @@ -31,8 +31,10 @@ def __ne__(self, other): return not self == other -def quoted_identifier(identifier, schema=None, identity_arguments=None): +def quoted_identifier(identifier, schema=None, identity_arguments=None, table=None): s = '"{}"'.format(identifier.replace('"', '""')) + if table: + s = '"{}".{}'.format(table.replace('"', '""'), s) if schema: s = '"{}".{}'.format(schema.replace('"', '""'), s) if identity_arguments is not None: diff --git a/schemainspect/pg/obj.py b/schemainspect/pg/obj.py index 4f9ad46..1be7b04 100644 --- a/schemainspect/pg/obj.py +++ b/schemainspect/pg/obj.py @@ -39,6 +39,7 @@ COLLATIONS_QUERY = resource_text("sql/collations.sql") COLLATIONS_QUERY_9 = resource_text("sql/collations9.sql") RLSPOLICIES_QUERY = resource_text("sql/rlspolicies.sql") +COMMENTS_QUERY = resource_text("sql/comments.sql") class InspectedSelectable(BaseInspectedSelectable): @@ -943,6 +944,49 @@ def key(self): return self.object_type, self.quoted_full_name, self.target_user, self.privilege +class InspectedComment(Inspected): + def __init__(self, object_type, schema, table, name, args, comment): + self.object_type = object_type + self.schema = schema + self.table = table + self.name = name + self.args = args + self.comment = comment + + @property + def _identifier(self): + return quoted_identifier( + self.name, + schema=self.schema, + table=self.table, + identity_arguments=self.args, + ) + + @property + def drop_statement(self): + return "comment on {} {} is null;".format(self.object_type, self._identifier) + + @property + def create_statement(self): + return "comment on {} {} is '{}';".format( + self.object_type, self._identifier, self.comment + ) + + @property + def key(self): + return "{} {}".format(self.object_type, self._identifier) + + def __eq__(self, other): + return ( + self.object_type == other.object_type + and self.schema == other.schema + and self.table == other.table + and self.name == other.name + and self.args == other.args + and self.comment == other.comment + ) + + RLS_POLICY_CREATE = """create policy {name} on {table_name} as {permissiveness} @@ -1069,6 +1113,7 @@ def processed(q): self.SCHEMAS_QUERY = processed(SCHEMAS_QUERY) self.PRIVILEGES_QUERY = processed(PRIVILEGES_QUERY) self.TRIGGERS_QUERY = processed(TRIGGERS_QUERY) + self.COMMENTS_QUERY = processed(COMMENTS_QUERY) super(PostgreSQL, self).__init__(c, include_internal) @@ -1086,6 +1131,7 @@ def load_all(self): self.load_rlspolicies() self.load_types() self.load_domains() + self.load_comments() self.load_deps() self.load_deps_all() @@ -1582,6 +1628,21 @@ def col(defn): ] # type: list[InspectedType] self.domains = od((t.signature, t) for t in domains) + def load_comments(self): + q = self.c.execute(self.COMMENTS_QUERY) + comments = [ + InspectedComment( + i.object_type, + i.schema, + i.table, + i.name, + i.args, + i.comment, + ) + for i in q + ] # type: list[InspectedComment] + self.comments = od((t.key, t) for t in comments) + def filter_schema(self, schema=None, exclude_schema=None): if schema and exclude_schema: raise ValueError("Can only have schema or exclude schema, not both") @@ -1654,4 +1715,5 @@ def __eq__(self, other): and self.triggers == other.triggers and self.collations == other.collations and self.rlspolicies == other.rlspolicies + and self.comments == other.comments ) diff --git a/schemainspect/pg/sql/comments.sql b/schemainspect/pg/sql/comments.sql new file mode 100644 index 0000000..c622c02 --- /dev/null +++ b/schemainspect/pg/sql/comments.sql @@ -0,0 +1,54 @@ +select + 'function' object_type, + n.nspname "schema", + NULL "table", + p.proname "name", + pg_catalog.pg_get_function_identity_arguments(p.oid) args, + pg_catalog.obj_description(p.oid, 'pg_proc') "comment" +from + pg_catalog.pg_proc p + join pg_catalog.pg_namespace n on n.oid = p.pronamespace +where + n.nspname <> 'pg_catalog' + and n.nspname <> 'information_schema' + and pg_catalog.obj_description(p.oid, 'pg_proc') is not null +union all +select + case c.relkind + when 'I' then 'index' + when 'c' then 'type' + when 'i' then 'index' + when 'm' then 'materialized view' + when 'p' then 'table' + when 'r' then 'table' + when 's' then 'sequence' + when 'v' then 'view' + end, + n.nspname, + NULL, + c.relname, + NULL, + pg_catalog.obj_description(c.oid, 'pg_class') +from + pg_catalog.pg_class c + join pg_catalog.pg_namespace n on n.oid = c.relnamespace +where + n.nspname <> 'pg_catalog' + and n.nspname <> 'information_schema' + and pg_catalog.obj_description(c.oid, 'pg_class') is not null +union all +select + 'column', + n.nspname, + c.relname, + a.attname, + NULL, + pg_catalog.col_description(c.oid, a.attnum) +from + pg_catalog.pg_attribute a + join pg_catalog.pg_class c on c.oid = a.attrelid + join pg_catalog.pg_namespace n on n.oid = c.relnamespace +where + n.nspname <> 'pg_catalog' + and n.nspname <> 'information_schema' + and pg_catalog.col_description(c.oid, a.attnum) is not null; diff --git a/tests/test_all.py b/tests/test_all.py index 88b8f18..a06102c 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -485,6 +485,19 @@ def asserts_pg(i, has_timescale=False): with raises(ValueError): tid.change_string_to_enum_statement("t") + # comments + assert len(i.comments) == 2 + assert ( + i.comments[ + 'function "public"."films_f"(d date, def_t text, def_d date)' + ].create_statement + == 'comment on function "public"."films_f"(d date, def_t text, def_d date) is \'films_f comment\';' + ) + assert ( + i.comments['table "public"."emptytable"'].create_statement + == 'comment on table "public"."emptytable" is \'emptytable comment\';' + ) + def test_weird_names(db): with S(db) as s: