diff --git a/tests/test_grant.py b/tests/test_grant.py index 07b20f2..e11d3f8 100644 --- a/tests/test_grant.py +++ b/tests/test_grant.py @@ -237,3 +237,29 @@ def test_grant_on_dynamic_tables(): assert future_grant._data.in_name == "SOMEDB.SOMESCHEMA" assert future_grant._data.in_type == ResourceType.SCHEMA assert future_grant._data.on_type == ResourceType.DYNAMIC_TABLE + + +def test_grant_database_role_to_database_role(): + database = res.Database(name="somedb") + parent = res.DatabaseRole(name="parent", database=database) + child = res.DatabaseRole(name="child", database=database) + grant = res.RoleGrant(role=child, to_role=parent) + assert grant.role.name == "child" + assert grant.to.name == "parent" + + +def test_grant_database_role_to_account_role(): + database = res.Database(name="somedb") + parent = res.Role(name="parent") + child = res.DatabaseRole(name="child", database=database) + grant = res.RoleGrant(role=child, to_role=parent) + assert grant.role.name == "child" + assert grant.to.name == "parent" + + +def test_grant_database_role_to_system_role(): + database = res.Database(name="somedb") + child = res.DatabaseRole(name="child", database=database) + grant = res.RoleGrant(role=child, to_role="SYSADMIN") + assert grant.role.name == "child" + assert grant.to.name == "SYSADMIN" diff --git a/titan/resources/grant.py b/titan/resources/grant.py index a7327f9..e8fd9d3 100644 --- a/titan/resources/grant.py +++ b/titan/resources/grant.py @@ -6,10 +6,11 @@ from ..enums import ParseableEnum, ResourceType from ..identifiers import FQN, parse_FQN, resource_label_for_type, resource_type_for_label -from ..parse import parse_grant, format_collection_string +from ..parse import format_collection_string, parse_grant from ..privs import all_privs_for_resource_type from ..props import FlagProp, IdentifierProp, Props from ..resource_name import ResourceName +from ..role_ref import RoleRef from ..scope import AccountScope from .resource import NamedResource, Resource, ResourcePointer, ResourceSpec from .role import Role @@ -23,7 +24,7 @@ class _Grant(ResourceSpec): priv: str on: str on_type: ResourceType - to: Role + to: RoleRef grant_option: bool = False owner: Role = field(default=None, metadata={"fetchable": False}) _privs: list[str] = field(default_factory=list, metadata={"triggers_create": True}) @@ -242,7 +243,7 @@ class _FutureGrant(ResourceSpec): on_type: ResourceType in_type: ResourceType in_name: ResourceName - to: Role + to: RoleRef grant_option: bool = False def __post_init__(self): @@ -425,7 +426,7 @@ class _GrantOnAll(ResourceSpec): on_type: ResourceType in_type: ResourceType in_name: ResourceName - to: Role + to: RoleRef grant_option: bool = False def __post_init__(self): @@ -587,8 +588,8 @@ def grant_on_all_fqn(data: _GrantOnAll): @dataclass(unsafe_hash=True) class _RoleGrant(ResourceSpec): - role: Role - to_role: Role = None + role: RoleRef + to_role: RoleRef = None to_user: User = None def __post_init__(self): diff --git a/titan/resources/resource.py b/titan/resources/resource.py index 3bd3998..0005ce3 100644 --- a/titan/resources/resource.py +++ b/titan/resources/resource.py @@ -224,7 +224,7 @@ def __post_init__(self): setattr(self, f.name, new_value) except TypeError as err: human_readable_classname = self.__class__.__name__[1:] - if issubclass(f.type, Enum): + if isclass(f.type) and issubclass(f.type, Enum): raise TypeError( f"Expected {human_readable_classname}.{f.name} to be one of ({', '.join(f.type.__members__.keys())}), got {repr(field_value)} instead" ) from err @@ -699,13 +699,15 @@ def convert_role_ref(role_ref: RoleRef) -> Resource: ResourceType.ROLE, ): return role_ref - elif isinstance(role_ref, str): + elif isinstance(role_ref, str) or isinstance(role_ref, ResourceName): return ResourcePointer(name=role_ref, resource_type=infer_role_type_from_name(role_ref)) else: raise TypeError -def infer_role_type_from_name(name: str) -> ResourceType: +def infer_role_type_from_name(name: Union[str, ResourceName]) -> ResourceType: + if isinstance(name, ResourceName): + name = str(name) if name == "": return ResourceType.ROLE identifier = parse_identifier(name, is_db_scoped=True)