33
33
eps = rdp_accountant.get_privacy_spent(orders, rdp, target_delta)[0]
34
34
"""
35
35
36
- import attr
36
+ from typing import Any , NamedTuple
37
+
37
38
import dp_accounting
38
39
import tensorflow as tf
40
+
39
41
from tensorflow_privacy .privacy .dp_query import dp_query
40
42
from tensorflow_privacy .privacy .dp_query import tree_aggregation
41
43
@@ -84,8 +86,7 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
84
86
O(clip_norm*log(T)/eps) to guarantee eps-DP.
85
87
"""
86
88
87
- @attr .s (frozen = True )
88
- class GlobalState (object ):
89
+ class GlobalState (NamedTuple ):
89
90
"""Class defining global state for Tree sum queries.
90
91
91
92
Attributes:
@@ -94,9 +95,9 @@ class GlobalState(object):
94
95
clip_value: The clipping value to be passed to clip_fn.
95
96
samples_cumulative_sum: Noiseless cumulative sum of samples over time.
96
97
"""
97
- tree_state = attr . ib ()
98
- clip_value = attr . ib ()
99
- samples_cumulative_sum = attr . ib ()
98
+ tree_state : Any
99
+ clip_value : Any
100
+ samples_cumulative_sum : Any
100
101
101
102
def __init__ (self ,
102
103
record_specs ,
@@ -182,10 +183,11 @@ def get_noised_result(self, sample_state, global_state):
182
183
global_state .tree_state )
183
184
noised_cumulative_sum = tf .nest .map_structure (tf .add , new_cumulative_sum ,
184
185
cumulative_sum_noise )
185
- new_global_state = attr .evolve (
186
- global_state ,
186
+ new_global_state = TreeCumulativeSumQuery .GlobalState (
187
+ tree_state = new_tree_state ,
188
+ clip_value = global_state .clip_value ,
187
189
samples_cumulative_sum = new_cumulative_sum ,
188
- tree_state = new_tree_state )
190
+ )
189
191
event = dp_accounting .UnsupportedDpEvent ()
190
192
return noised_cumulative_sum , new_global_state , event
191
193
@@ -206,10 +208,11 @@ def reset_state(self, noised_results, global_state):
206
208
state for the next cumulative sum.
207
209
"""
208
210
new_tree_state = self ._tree_aggregator .reset_state (global_state .tree_state )
209
- return attr .evolve (
210
- global_state ,
211
+ return TreeCumulativeSumQuery .GlobalState (
212
+ tree_state = new_tree_state ,
213
+ clip_value = global_state .clip_value ,
211
214
samples_cumulative_sum = noised_results ,
212
- tree_state = new_tree_state )
215
+ )
213
216
214
217
@classmethod
215
218
def build_l2_gaussian_query (cls ,
@@ -312,8 +315,7 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
312
315
O(clip_norm*log(T)/eps) to guarantee eps-DP.
313
316
"""
314
317
315
- @attr .s (frozen = True )
316
- class GlobalState (object ):
318
+ class GlobalState (NamedTuple ):
317
319
"""Class defining global state for Tree sum queries.
318
320
319
321
Attributes:
@@ -323,9 +325,9 @@ class GlobalState(object):
323
325
previous_tree_noise: Cumulative noise by tree aggregation from the
324
326
previous time the query is called on a sample.
325
327
"""
326
- tree_state = attr . ib ()
327
- clip_value = attr . ib ()
328
- previous_tree_noise = attr . ib ()
328
+ tree_state : Any
329
+ clip_value : Any
330
+ previous_tree_noise : Any
329
331
330
332
def __init__ (self ,
331
333
record_specs ,
@@ -426,8 +428,11 @@ def get_noised_result(self, sample_state, global_state):
426
428
noised_sample = tf .nest .map_structure (lambda a , b , c : a + b - c ,
427
429
sample_state , tree_noise ,
428
430
global_state .previous_tree_noise )
429
- new_global_state = attr .evolve (
430
- global_state , previous_tree_noise = tree_noise , tree_state = new_tree_state )
431
+ new_global_state = TreeResidualSumQuery .GlobalState (
432
+ tree_state = new_tree_state ,
433
+ clip_value = global_state .clip_value ,
434
+ previous_tree_noise = tree_noise ,
435
+ )
431
436
event = dp_accounting .UnsupportedDpEvent ()
432
437
return noised_sample , new_global_state , event
433
438
@@ -448,21 +453,28 @@ def reset_state(self, noised_results, global_state):
448
453
"""
449
454
del noised_results
450
455
new_tree_state = self ._tree_aggregator .reset_state (global_state .tree_state )
451
- return attr .evolve (
452
- global_state ,
456
+ return TreeResidualSumQuery .GlobalState (
457
+ tree_state = new_tree_state ,
458
+ clip_value = global_state .clip_value ,
453
459
previous_tree_noise = self ._zero_initial_noise (),
454
- tree_state = new_tree_state )
460
+ )
455
461
456
462
def reset_l2_clip_gaussian_noise (self , global_state , clip_norm , stddev ):
457
463
noise_generator_state = global_state .tree_state .value_generator_state
458
464
assert isinstance (self ._tree_aggregator .value_generator ,
459
465
tree_aggregation .GaussianNoiseGenerator )
460
466
noise_generator_state = self ._tree_aggregator .value_generator .make_state (
461
467
noise_generator_state .seeds , stddev )
462
- new_tree_state = attr .evolve (
463
- global_state .tree_state , value_generator_state = noise_generator_state )
464
- return attr .evolve (
465
- global_state , clip_value = clip_norm , tree_state = new_tree_state )
468
+ new_tree_state = tree_aggregation .TreeState (
469
+ level_buffer = global_state .tree_state .level_buffer ,
470
+ level_buffer_idx = global_state .tree_state .level_buffer_idx ,
471
+ value_generator_state = noise_generator_state ,
472
+ )
473
+ return TreeResidualSumQuery .GlobalState (
474
+ tree_state = new_tree_state ,
475
+ clip_value = clip_norm ,
476
+ previous_tree_noise = global_state .previous_tree_noise ,
477
+ )
466
478
467
479
@classmethod
468
480
def build_l2_gaussian_query (cls ,
0 commit comments