Skip to content

Commit 6c868e1

Browse files
authored
feat: add initial code (#2)
* Ignore .idea/ * feat: add async PostgreSQL watcher (#1) * test: add tests for AsyncPostgresWatcher * docs: add examples * docs: update README.md * feat: add Apache header * feat: add requirements.txt * docs: update README.md
1 parent 41c21f1 commit 6c868e1

File tree

8 files changed

+244
-2
lines changed

8 files changed

+244
-2
lines changed

.gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -157,4 +157,4 @@ cython_debug/
157157
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158158
# and can be added to the global gitignore or merged into this file. For a more nuclear
159159
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
160-
#.idea/
160+
.idea/

README.md

+53-1
Original file line numberDiff line numberDiff line change
@@ -1 +1,53 @@
1-
# async-postgres-watcher
1+
# async-postgres-watcher
2+
3+
[![Discord](https://img.shields.io/discord/1022748306096537660?logo=discord&label=discord&color=5865F2)](https://discord.gg/S5UjpzGZjN)
4+
5+
Async casbin role watcher to be used for monitoring updates to casbin policies
6+
7+
## Basic Usage Example
8+
9+
### With Flask-authz
10+
11+
```python
12+
from flask_authz import CasbinEnforcer
13+
from async_postgres_watcher import AsyncPostgresWatcher
14+
from flask import Flask
15+
from casbin.persist.adapters import FileAdapter
16+
17+
casbin_enforcer = CasbinEnforcer(app, adapter)
18+
19+
watcher = AsyncPostgresWatcher(host=HOST, port=PORT, user=USER, password=PASSWORD, dbname=DBNAME)
20+
watcher.set_update_callback(casbin_enforcer.e.load_policy)
21+
22+
casbin_enforcer.set_watcher(watcher)
23+
```
24+
25+
## Basic Usage Example With SSL Enabled
26+
27+
See [asyncpg documentation](https://magicstack.github.io/asyncpg/current/api/index.html#connection) for full details of SSL parameters.
28+
29+
### With Flask-authz
30+
31+
```python
32+
from flask_authz import CasbinEnforcer
33+
from async_postgres_watcher import AsyncPostgresWatcher
34+
from flask import Flask
35+
from casbin.persist.adapters import FileAdapter
36+
37+
casbin_enforcer = CasbinEnforcer(app, adapter)
38+
39+
# If check_hostname is True, the SSL context is created with sslmode=verify-full.
40+
# If check_hostname is False, the SSL context is created with sslmode=verify-ca.
41+
watcher = AsyncPostgresWatcher(host=HOST, port=PORT, user=USER, password=PASSWORD, dbname=DBNAME, sslrootcert=SSLROOTCERT, check_hostname = True, sslcert=SSLCERT, sslkey=SSLKEY)
42+
43+
watcher.set_update_callback(casbin_enforcer.e.load_policy)
44+
casbin_enforcer.set_watcher(watcher)
45+
```
46+
47+
## Getting Help
48+
49+
- [PyCasbin](https://github.com/casbin/pycasbin)
50+
51+
## License
52+
53+
This project is under Apache 2.0 License. See the [LICENSE](LICENSE) file for the full license text.

async_postgres_watcher/__init__.py

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright 2024 The casbin Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from .watcher import AsyncPostgresWatcher

async_postgres_watcher/watcher.py

+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright 2024 The casbin Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Optional, Callable
16+
import asyncio
17+
import asyncpg
18+
import ssl
19+
import time
20+
21+
POSTGRESQL_CHANNEL_NAME = "casbin_role_watcher"
22+
23+
24+
class AsyncPostgresWatcher:
25+
def __init__(
26+
self,
27+
host: str,
28+
user: str,
29+
password: str,
30+
port: Optional[int] = 5432,
31+
dbname: Optional[str] = "postgres",
32+
channel_name: Optional[str] = POSTGRESQL_CHANNEL_NAME,
33+
sslrootcert: Optional[str] = None,
34+
# If True, equivalent to sslmode=verify-full, if False: sslmode=verify-ca.
35+
check_hostname: Optional[bool] = True,
36+
sslcert: Optional[str] = None,
37+
sslkey: Optional[str] = None
38+
):
39+
self.loop = asyncio.get_event_loop()
40+
self.running = True
41+
self.callback = None
42+
self.host = host
43+
self.port = port
44+
self.user = user
45+
self.password = password
46+
self.dbname = dbname
47+
self.channel_name = channel_name
48+
49+
if sslrootcert is not None and sslcert is not None and sslkey is not None:
50+
self.sslctx = ssl.create_default_context(
51+
ssl.Purpose.SERVER_AUTH,
52+
cafile=sslrootcert
53+
)
54+
self.sslctx.check_hostname = check_hostname
55+
self.sslctx.load_cert_chain(sslcert, keyfile=sslkey)
56+
else:
57+
self.sslctx = False
58+
59+
self.loop.create_task(self.subscriber())
60+
61+
async def notify(self, pid, channel, payload):
62+
print(f"Notify: {payload}")
63+
if self.callback is not None:
64+
if asyncio.iscoroutinefunction(self.callback):
65+
await self.callback(payload)
66+
else:
67+
self.callback(payload)
68+
69+
async def subscriber(self):
70+
conn = await asyncpg.connect(
71+
host=self.host,
72+
port=self.port,
73+
user=self.user,
74+
password=self.password,
75+
database=self.dbname,
76+
ssl=self.sslctx
77+
)
78+
await conn.add_listener(self.channel_name, self.notify)
79+
while self.running:
80+
await asyncio.sleep(1) # keep the coroutine alive
81+
82+
async def set_update_callback(self, fn_name: Callable):
83+
print("runtime is set update callback", fn_name)
84+
self.callback = fn_name
85+
86+
async def update(self):
87+
conn = await asyncpg.connect(
88+
host=self.host,
89+
port=self.port,
90+
user=self.user,
91+
password=self.password,
92+
database=self.dbname,
93+
ssl=self.sslctx
94+
)
95+
async with conn.transaction():
96+
await conn.execute(
97+
f"NOTIFY {self.channel_name},'casbin policy update at {time.time()}'"
98+
)
99+
await conn.close()
100+
return True

examples/rbac_model.conf

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
[request_definition]
2+
r = sub, obj, act
3+
4+
[policy_definition]
5+
p = sub, obj, act
6+
7+
[role_definition]
8+
g = _, _
9+
10+
[policy_effect]
11+
e = some(where (p.eft == allow))
12+
13+
[matchers]
14+
m = (p.sub == "*" || g(r.sub, p.sub)) && r.obj == p.obj && (p.act == "*" || r.act == p.act)

examples/rbac_policy.csv

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
p,alice,data1,read
2+
p,bob,data2,write
3+
p,data2_admin,data2,read
4+
p,data2_admin,data2,write
5+
g,alice,data2_admin

requirements.txt

152 Bytes
Binary file not shown.

test/test_async_postgres_watcher.py

+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright 2024 The casbin Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
from async_postgres_watcher import AsyncPostgresWatcher
18+
19+
# Please set up yourself config
20+
HOST = "127.0.0.1"
21+
PORT = 5432
22+
USER = "postgres"
23+
PASSWORD = "123456"
24+
DBNAME = "postgres"
25+
26+
27+
class TestConfig(unittest.IsolatedAsyncioTestCase):
28+
29+
async def asyncSetUp(self):
30+
self.pg_watcher = AsyncPostgresWatcher(
31+
host=HOST,
32+
port=PORT,
33+
user=USER,
34+
password=PASSWORD,
35+
dbname=DBNAME
36+
)
37+
38+
async def test_update_pg_watcher(self):
39+
assert await self.pg_watcher.update() is True
40+
41+
def test_default_update_callback(self):
42+
assert self.pg_watcher.callback is None
43+
44+
async def test_add_update_callback(self):
45+
def _test_callback():
46+
pass
47+
48+
await self.pg_watcher.set_update_callback(_test_callback)
49+
assert self.pg_watcher.callback == _test_callback
50+
51+
async def asyncTearDown(self):
52+
self.pg_watcher.running = False
53+
54+
55+
if __name__ == "__main__":
56+
unittest.main()

0 commit comments

Comments
 (0)