Skip to content

Commit 3e68498

Browse files
Automated rollback of changelist 520433075
PiperOrigin-RevId: 522139828
1 parent e362f51 commit 3e68498

File tree

3 files changed

+48
-39
lines changed

3 files changed

+48
-39
lines changed

tensorflow_privacy/privacy/dp_query/tree_aggregation.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@
2121
as helper functions for `tree_aggregation_query`. This module and helper
2222
functions are publicly accessible.
2323
"""
24+
2425
import abc
2526
import collections
26-
from typing import Any, Callable, Collection, Optional, Tuple, Union
27+
from typing import Any, Callable, Collection, NamedTuple, Optional, Tuple, Union
2728

28-
import attr
2929
import tensorflow as tf
3030

3131
# TODO(b/192464750): find a proper place for the helper functions, privatize
@@ -170,8 +170,7 @@ def next(self, state):
170170
return self.value_fn(), state
171171

172172

173-
@attr.s(eq=False, frozen=True, slots=True)
174-
class TreeState(object):
173+
class TreeState(NamedTuple):
175174
"""Class defining state of the tree.
176175
177176
Attributes:
@@ -183,9 +182,9 @@ class TreeState(object):
183182
for the most recent leaf node.
184183
value_generator_state: State of a stateful `ValueGenerator` for tree node.
185184
"""
186-
level_buffer = attr.ib(type=tf.Tensor)
187-
level_buffer_idx = attr.ib(type=tf.Tensor)
188-
value_generator_state = attr.ib(type=Any)
185+
level_buffer: tf.Tensor
186+
level_buffer_idx: tf.Tensor
187+
value_generator_state: Any
189188

190189

191190
# TODO(b/192464750): move `get_step_idx` to be a property of `TreeState`.

tensorflow_privacy/privacy/dp_query/tree_aggregation_query.py

+38-26
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,11 @@
3333
eps = rdp_accountant.get_privacy_spent(orders, rdp, target_delta)[0]
3434
"""
3535

36-
import attr
36+
from typing import Any, NamedTuple
37+
3738
import dp_accounting
3839
import tensorflow as tf
40+
3941
from tensorflow_privacy.privacy.dp_query import dp_query
4042
from tensorflow_privacy.privacy.dp_query import tree_aggregation
4143

@@ -84,8 +86,7 @@ class TreeCumulativeSumQuery(dp_query.SumAggregationDPQuery):
8486
O(clip_norm*log(T)/eps) to guarantee eps-DP.
8587
"""
8688

87-
@attr.s(frozen=True)
88-
class GlobalState(object):
89+
class GlobalState(NamedTuple):
8990
"""Class defining global state for Tree sum queries.
9091
9192
Attributes:
@@ -94,9 +95,9 @@ class GlobalState(object):
9495
clip_value: The clipping value to be passed to clip_fn.
9596
samples_cumulative_sum: Noiseless cumulative sum of samples over time.
9697
"""
97-
tree_state = attr.ib()
98-
clip_value = attr.ib()
99-
samples_cumulative_sum = attr.ib()
98+
tree_state: Any
99+
clip_value: Any
100+
samples_cumulative_sum: Any
100101

101102
def __init__(self,
102103
record_specs,
@@ -182,10 +183,11 @@ def get_noised_result(self, sample_state, global_state):
182183
global_state.tree_state)
183184
noised_cumulative_sum = tf.nest.map_structure(tf.add, new_cumulative_sum,
184185
cumulative_sum_noise)
185-
new_global_state = attr.evolve(
186-
global_state,
186+
new_global_state = TreeCumulativeSumQuery.GlobalState(
187+
tree_state=new_tree_state,
188+
clip_value=global_state.clip_value,
187189
samples_cumulative_sum=new_cumulative_sum,
188-
tree_state=new_tree_state)
190+
)
189191
event = dp_accounting.UnsupportedDpEvent()
190192
return noised_cumulative_sum, new_global_state, event
191193

@@ -206,10 +208,11 @@ def reset_state(self, noised_results, global_state):
206208
state for the next cumulative sum.
207209
"""
208210
new_tree_state = self._tree_aggregator.reset_state(global_state.tree_state)
209-
return attr.evolve(
210-
global_state,
211+
return TreeCumulativeSumQuery.GlobalState(
212+
tree_state=new_tree_state,
213+
clip_value=global_state.clip_value,
211214
samples_cumulative_sum=noised_results,
212-
tree_state=new_tree_state)
215+
)
213216

214217
@classmethod
215218
def build_l2_gaussian_query(cls,
@@ -312,8 +315,7 @@ class TreeResidualSumQuery(dp_query.SumAggregationDPQuery):
312315
O(clip_norm*log(T)/eps) to guarantee eps-DP.
313316
"""
314317

315-
@attr.s(frozen=True)
316-
class GlobalState(object):
318+
class GlobalState(NamedTuple):
317319
"""Class defining global state for Tree sum queries.
318320
319321
Attributes:
@@ -323,9 +325,9 @@ class GlobalState(object):
323325
previous_tree_noise: Cumulative noise by tree aggregation from the
324326
previous time the query is called on a sample.
325327
"""
326-
tree_state = attr.ib()
327-
clip_value = attr.ib()
328-
previous_tree_noise = attr.ib()
328+
tree_state: Any
329+
clip_value: Any
330+
previous_tree_noise: Any
329331

330332
def __init__(self,
331333
record_specs,
@@ -426,8 +428,11 @@ def get_noised_result(self, sample_state, global_state):
426428
noised_sample = tf.nest.map_structure(lambda a, b, c: a + b - c,
427429
sample_state, tree_noise,
428430
global_state.previous_tree_noise)
429-
new_global_state = attr.evolve(
430-
global_state, previous_tree_noise=tree_noise, tree_state=new_tree_state)
431+
new_global_state = TreeResidualSumQuery.GlobalState(
432+
tree_state=new_tree_state,
433+
clip_value=global_state.clip_value,
434+
previous_tree_noise=tree_noise,
435+
)
431436
event = dp_accounting.UnsupportedDpEvent()
432437
return noised_sample, new_global_state, event
433438

@@ -448,21 +453,28 @@ def reset_state(self, noised_results, global_state):
448453
"""
449454
del noised_results
450455
new_tree_state = self._tree_aggregator.reset_state(global_state.tree_state)
451-
return attr.evolve(
452-
global_state,
456+
return TreeResidualSumQuery.GlobalState(
457+
tree_state=new_tree_state,
458+
clip_value=global_state.clip_value,
453459
previous_tree_noise=self._zero_initial_noise(),
454-
tree_state=new_tree_state)
460+
)
455461

456462
def reset_l2_clip_gaussian_noise(self, global_state, clip_norm, stddev):
457463
noise_generator_state = global_state.tree_state.value_generator_state
458464
assert isinstance(self._tree_aggregator.value_generator,
459465
tree_aggregation.GaussianNoiseGenerator)
460466
noise_generator_state = self._tree_aggregator.value_generator.make_state(
461467
noise_generator_state.seeds, stddev)
462-
new_tree_state = attr.evolve(
463-
global_state.tree_state, value_generator_state=noise_generator_state)
464-
return attr.evolve(
465-
global_state, clip_value=clip_norm, tree_state=new_tree_state)
468+
new_tree_state = tree_aggregation.TreeState(
469+
level_buffer=global_state.tree_state.level_buffer,
470+
level_buffer_idx=global_state.tree_state.level_buffer_idx,
471+
value_generator_state=noise_generator_state,
472+
)
473+
return TreeResidualSumQuery.GlobalState(
474+
tree_state=new_tree_state,
475+
clip_value=clip_norm,
476+
previous_tree_noise=global_state.previous_tree_noise,
477+
)
466478

467479
@classmethod
468480
def build_l2_gaussian_query(cls,

tensorflow_privacy/privacy/dp_query/tree_range_query.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,8 @@
1818

1919
import distutils
2020
import math
21-
from typing import Optional
21+
from typing import Any, NamedTuple, Optional
2222

23-
import attr
2423
import dp_accounting
2524
import tensorflow as tf
2625
from tensorflow_privacy.privacy.dp_query import distributed_discrete_gaussian_query
@@ -102,17 +101,16 @@ class TreeRangeSumQuery(dp_query.SumAggregationDPQuery):
102101
Improves efficiency and reduces noise scale.
103102
"""
104103

105-
@attr.s(frozen=True)
106-
class GlobalState(object):
104+
class GlobalState(NamedTuple):
107105
"""Class defining global state for TreeRangeSumQuery.
108106
109107
Attributes:
110108
arity: The branching factor of the tree (i.e. the number of children each
111109
internal node has).
112110
inner_query_state: The global state of the inner query.
113111
"""
114-
arity = attr.ib()
115-
inner_query_state = attr.ib()
112+
arity: Any
113+
inner_query_state: Any
116114

117115
def __init__(self,
118116
inner_query: dp_query.SumAggregationDPQuery,

0 commit comments

Comments
 (0)