1
+ import json
1
2
import logging
2
3
import time
3
4
import threading
5
+ import uuid
4
6
5
7
import casbin
6
8
import pika
@@ -18,12 +20,14 @@ def __init__(
18
20
username = "guest" ,
19
21
password = "guest" ,
20
22
key = "casbin-policy-updated" ,
21
- ** kwargs
23
+ local_id = None ,
24
+ ** kwargs ,
22
25
):
23
26
self .connection = None
24
27
self .pub_channel = None
25
28
self .key = key
26
29
self .callback = None
30
+ self .local_id = local_id if local_id is not None else str (uuid .uuid4 ())
27
31
self .mutex = threading .Lock ()
28
32
self .subscribe_event = threading .Event ()
29
33
self .subscribe_thread = threading .Thread (target = self .start_watch , daemon = True )
@@ -33,7 +37,7 @@ def __init__(
33
37
port = port ,
34
38
virtual_host = virtual_host ,
35
39
credentials = credentials ,
36
- ** kwargs
40
+ ** kwargs ,
37
41
)
38
42
39
43
def create_client (self ):
@@ -62,8 +66,9 @@ def update(self):
62
66
"""
63
67
update the policy
64
68
"""
69
+ msg = MSG ("Update" , self .local_id , "" , "" , "" )
65
70
self .pub_channel .basic_publish (
66
- exchange = self .key , routing_key = "" , body = str ( time . time () )
71
+ exchange = self .key , routing_key = "" , body = msg . marshal_binary ( )
67
72
)
68
73
return True
69
74
@@ -75,9 +80,14 @@ def update_for_add_policy(self, section, ptype, *params):
75
80
:param params: other params
76
81
:return: True if updated
77
82
"""
78
- message = "Update for add policy: " + section + " " + ptype + " " + str (params )
79
- LOGGER .info (message )
80
- return self .update ()
83
+
84
+ def func ():
85
+ msg = MSG ("UpdateForAddPolicy" , self .local_id , section , ptype , params )
86
+ return self .pub_channel .basic_publish (
87
+ exchange = self .key , routing_key = "" , body = msg .marshal_binary ()
88
+ )
89
+
90
+ return self .log_record (func )
81
91
82
92
def update_for_remove_policy (self , section , ptype , * params ):
83
93
"""
@@ -87,11 +97,14 @@ def update_for_remove_policy(self, section, ptype, *params):
87
97
:param params: other params
88
98
:return: True if updated
89
99
"""
90
- message = (
91
- "Update for remove policy: " + section + " " + ptype + " " + str (params )
92
- )
93
- LOGGER .info (message )
94
- return self .update ()
100
+
101
+ def func ():
102
+ msg = MSG ("UpdateForRemovePolicy" , self .local_id , section , ptype , params )
103
+ return self .pub_channel .basic_publish (
104
+ exchange = self .key , routing_key = "" , body = msg .marshal_binary ()
105
+ )
106
+
107
+ return self .log_record (func )
95
108
96
109
def update_for_remove_filtered_policy (self , section , ptype , field_index , * params ):
97
110
"""
@@ -102,28 +115,41 @@ def update_for_remove_filtered_policy(self, section, ptype, field_index, *params
102
115
:param params: other params
103
116
:return:
104
117
"""
105
- message = (
106
- "Update for remove filtered policy: "
107
- + section
108
- + " "
109
- + ptype
110
- + " "
111
- + str (field_index )
112
- + " "
113
- + str (params )
114
- )
115
- LOGGER .info (message )
116
- return self .update ()
118
+
119
+ def func ():
120
+ msg = MSG (
121
+ "UpdateForRemoveFilteredPolicy" ,
122
+ self .local_id ,
123
+ section ,
124
+ ptype ,
125
+ f"{ field_index } { ' ' .join (params )} " ,
126
+ )
127
+ return self .pub_channel .basic_publish (
128
+ exchange = self .key , routing_key = "" , body = msg .marshal_binary ()
129
+ )
130
+
131
+ return self .log_record (func )
117
132
118
133
def update_for_save_policy (self , model : casbin .Model ):
119
134
"""
120
135
update for save policy
121
136
:param model: casbin model
122
137
:return:
123
138
"""
124
- message = "Update for save policy: " + model .to_text ()
125
- LOGGER .info (message )
126
- return self .update ()
139
+
140
+ def func ():
141
+ msg = MSG (
142
+ "UpdateForSavePolicy" ,
143
+ self .local_id ,
144
+ "" ,
145
+ "" ,
146
+ model .to_text (),
147
+ )
148
+ return self .pub_channel .basic_publish (
149
+ exchange = self .key , routing_key = "" , body = msg .marshal_binary ()
150
+ )
151
+
152
+ return self .log_record (func )
127
153
128
154
def update_for_add_policies (self , section , ptype , * params ):
129
155
"""
@@ -133,11 +159,14 @@ def update_for_add_policies(self, section, ptype, *params):
133
159
:param params: other params
134
160
:return:
135
161
"""
136
- message = (
137
- "Update for add policies: " + section + " " + ptype + " " + str (params )
138
- )
139
- LOGGER .info (message )
140
- return self .update ()
162
+
163
+ def func ():
164
+ msg = MSG ("UpdateForAddPolicies" , self .local_id , section , ptype , params )
165
+ return self .pub_channel .basic_publish (
166
+ exchange = self .key , routing_key = "" , body = msg .marshal_binary ()
167
+ )
168
+
169
+ return self .log_record (func )
141
170
142
171
def update_for_remove_policies (self , section , ptype , * params ):
143
172
"""
@@ -147,11 +176,23 @@ def update_for_remove_policies(self, section, ptype, *params):
147
176
:param params: other params
148
177
:return:
149
178
"""
150
- message = (
151
- "Update for remove policies: " + section + " " + ptype + " " + str (params )
152
- )
153
- LOGGER .info (message )
154
- return self .update ()
179
+
180
+ def func ():
181
+ msg = MSG ("UpdateForRemovePolicies" , self .local_id , section , ptype , params )
182
+ return self .pub_channel .basic_publish (
183
+ exchange = self .key , routing_key = "" , body = msg .marshal_binary ()
184
+ )
185
+
186
+ return self .log_record (func )
187
+
188
+ @staticmethod
189
+ def log_record (f : callable ):
190
+ try :
191
+ result = f ()
192
+ except Exception as e :
193
+ print (f"Casbin Redis Watcher error: { e } " )
194
+ else :
195
+ return result
155
196
156
197
def start_watch (self ):
157
198
"""
@@ -204,14 +245,31 @@ def _watch_callback(ch, method, properties, body):
204
245
continue
205
246
206
247
248
+ class MSG :
249
+ def __init__ (self , method = "" , id = "" , sec = "" , ptype = "" , * params ):
250
+ self .method : str = method
251
+ self .id : str = id
252
+ self .sec : str = sec
253
+ self .ptype : str = ptype
254
+ self .params = params
255
+
256
+ def marshal_binary (self ):
257
+ return json .dumps (self .__dict__ )
258
+
259
+ @staticmethod
260
+ def unmarshal_binary (data : bytes ):
261
+ loaded = json .loads (data )
262
+ return MSG (** loaded )
263
+
264
+
207
265
def new_watcher (
208
266
host = "localhost" ,
209
267
port = 5672 ,
210
268
virtual_host = "/" ,
211
269
username = "guest" ,
212
270
password = "guest" ,
213
271
key = "casbin-policy-updated" ,
214
- ** kwargs
272
+ ** kwargs ,
215
273
):
216
274
"""
217
275
creates a new watcher
@@ -230,7 +288,7 @@ def new_watcher(
230
288
username = username ,
231
289
password = password ,
232
290
key = key ,
233
- ** kwargs
291
+ ** kwargs ,
234
292
)
235
293
rabbit .subscribe_thread .start ()
236
294
rabbit .subscribe_event .wait (timeout = 5 )
0 commit comments