diff --git a/tensorflow_federated/python/core/impl/executors/BUILD b/tensorflow_federated/python/core/impl/executors/BUILD index e34a59ca00..53986858e4 100644 --- a/tensorflow_federated/python/core/impl/executors/BUILD +++ b/tensorflow_federated/python/core/impl/executors/BUILD @@ -422,6 +422,7 @@ py_library( srcs = ["federating_executor.py"], srcs_version = "PY3", deps = [ + ":channel_base", ":executor_base", ":executor_utils", ":executor_value_base", @@ -439,6 +440,12 @@ py_library( ], ) +py_library( + name = "channel_base", + srcs = ["channel_base.py"], + srcs_version = "PY3", +) + py_test( name = "federating_executor_test", size = "small", diff --git a/tensorflow_federated/python/core/impl/executors/channel_base.py b/tensorflow_federated/python/core/impl/executors/channel_base.py new file mode 100644 index 0000000000..d4a3abac22 --- /dev/null +++ b/tensorflow_federated/python/core/impl/executors/channel_base.py @@ -0,0 +1,35 @@ +import abc +from dataclasses import dataclass +from typing import Tuple, Dict + +from tensorflow_federated.python.common_libs import py_typecheck +from tensorflow_federated.python.core.impl.compiler import placement_literals + +PlacementPair = Tuple[placement_literals.PlacementLiteral, + placement_literals.PlacementLiteral] + + +class Channel(metaclass=abc.ABCMeta): + + @abc.abstractmethod + async def send(self, value, sender=None, receiver=None): + pass + + @abc.abstractmethod + async def receive(self, value, sender=None, receiver=None): + pass + + @abc.abstractmethod + async def setup(self): + pass + + +@dataclass +class ChannelGrid: + channel_dict: Dict[PlacementPair, Channel] + + def __getitem__(self, placements: PlacementPair): + py_typecheck.check_type(placements, tuple) + py_typecheck.check_len(placements, 2) + sorted_placements = sorted(placements, key=lambda p: p.uri) + return self.channel_dict.get(tuple(sorted_placements)) diff --git a/tensorflow_federated/python/core/impl/executors/federating_executor.py b/tensorflow_federated/python/core/impl/executors/federating_executor.py index 8a43585def..717a4fc33a 100644 --- a/tensorflow_federated/python/core/impl/executors/federating_executor.py +++ b/tensorflow_federated/python/core/impl/executors/federating_executor.py @@ -36,6 +36,7 @@ from tensorflow_federated.python.core.impl.executors import executor_utils from tensorflow_federated.python.core.impl.executors import executor_value_base +from tensorflow_federated.python.core.impl.executors import channel_base from tf_encrypted.primitives.sodium import easy_box @@ -461,6 +462,14 @@ class TrustedAggregatorIntrinsicStrategy(IntrinsicStrategy): def __init__(self, federating_executor): super().__init__(federating_executor) + self.channel_grid = channel_base.ChannelGrid({ + (placement_literals.AGGREGATORS, placement_literals.CLIENTS): + EasyBoxChannel( + parent_executor=self, + sender_placement=placement_literals.CLIENTS, + receiver_placement=placement_literals.AGGREGATORS) + }) + @classmethod def validate_executor_placements(cls, executor_placements): singleton_placements = [ @@ -492,128 +501,29 @@ def validate_executor_placements(cls, executor_placements): 'Unsupported cardinality for placement "{}": {}.'.format( pl, pl_cardinality)) - async def _trusted_aggregator_generate_keys(self): + async def _move(self, arg, source_placement, target_placement): - @computations.tf_computation() - def generate_keys(): - pk, sk = easy_box.gen_keypair() - return pk.raw, sk.raw + val_type = arg.type_signature[0] + val = arg.internal_representation[0] + py_typecheck.check_type(val, list) + py_typecheck.check_type(val_type, computation_types.FederatedType) - fn_type = generate_keys.type_signature - fn = generate_keys._computation_proto + target_executor = self._get_child_executors(target_placement, index=0) + channel = self.channel_grid[(source_placement, target_placement)] - aggregator_ex = self._get_child_executors( - placement_literals.AGGREGATORS, index=0) + await channel.setup() - key_generator = await aggregator_ex.create_call(await - aggregator_ex.create_value( - fn, fn_type)) + val = await channel.send(value=val) + item_type = val[0].type_signature - keys = await asyncio.gather(*[ - aggregator_ex.create_selection(key_generator, i) - for i in range(len(key_generator.type_signature)) + val = await asyncio.gather(*[ + target_executor.create_value(await v.compute(), item_type) for v in val ]) - pk_fed_val, sk_fed_val = await asyncio.gather( - *[self._place(k, placement_literals.AGGREGATORS) for k in keys]) - - pk_fed_val = await self.federated_broadcast(pk_fed_val) - - return pk_fed_val, sk_fed_val - - async def _zip_val_key(self, vals, key, placement): - - if isinstance(vals, list): - val_type = computation_types.FederatedType( - vals[0].type_signature, placement, all_equal=False) - else: - val_type = computation_types.FederatedType( - vals.type_signature, placement, all_equal=False) - vals = [vals] - - vals_key = FederatingExecutorValue( - anonymous_tuple.AnonymousTuple([(None, vals), - (None, key.internal_representation)]), - computation_types.NamedTupleType((val_type, key.type_signature))) - - vals_key_zipped = await self._zip(vals_key, placement, all_equal=False) - - return vals_key_zipped.internal_representation - - async def _encrypt_client_tensors(self, arg, pk_a): - - nb_clients = len(self._get_child_executors(placement_literals.CLIENTS)) - - if nb_clients == 1: - input_tensor_type = arg.type_signature.member - else: - input_tensor_type = arg.type_signature[0].member - pk_a_tensor_type = pk_a.type_signature.member - - @computations.tf_computation(input_tensor_type, pk_a_tensor_type) - def encrypt_tensor(plaintext, pk_a): - - pk_a = easy_box.PublicKey(pk_a) - pk_c, sk_c = easy_box.gen_keypair() - nonce = easy_box.gen_nonce() - ciphertext, mac = easy_box.seal_detached(plaintext, nonce, pk_a, sk_c) - - return ciphertext.raw, mac.raw, pk_c.raw, nonce.raw - - fn_type = encrypt_tensor.type_signature - fn = encrypt_tensor._computation_proto - if nb_clients == 1: - val_type = arg.type_signature - val = arg.internal_representation - else: - val_type = arg.type_signature[0] - val = arg.internal_representation[0] - - val_key_zipped = await self._zip_val_key(val, pk_a, - placement_literals.CLIENTS) - - fed_ex = self.federating_executor - - return await fed_ex._compute_intrinsic_federated_map( - FederatingExecutorValue( - anonymous_tuple.AnonymousTuple([(None, fn), - (None, val_key_zipped)]), - computation_types.NamedTupleType((fn_type, val_type)))) - - async def _decrypt_tensors_on_aggregator(self, val, clients_dtype): - - client_output_type_signature = val[0].type_signature[0] - aggregator_public_key_type_signature = val[0].type_signature[1] - - @computations.tf_computation(client_output_type_signature, - aggregator_public_key_type_signature) - def decrypt_tensor(client_outputs, sk_a): + received_vals = await asyncio.gather( + *[channel.receive(value=v, sender=i) for (i, v) in enumerate(val)]) - ciphertext = easy_box.Ciphertext(client_outputs[0]) - mac = easy_box.Mac(client_outputs[1]) - pk_c = easy_box.PublicKey(client_outputs[2]) - nonce = easy_box.Nonce(client_outputs[3]) - sk_a = easy_box.SecretKey(sk_a) - - plaintext_recovered = easy_box.open_detached(ciphertext, mac, nonce, pk_c, - sk_a, clients_dtype) - - return plaintext_recovered - - val_type = computation_types.FederatedType( - computation_types.TensorType(clients_dtype), - placement_literals.AGGREGATORS, - all_equal=False) - - fn_type = decrypt_tensor.type_signature - fn = decrypt_tensor._computation_proto - - fed_ex = self.federating_executor - - return await fed_ex._compute_intrinsic_federated_map( - FederatingExecutorValue( - anonymous_tuple.AnonymousTuple([(None, fn), (None, val)]), - computation_types.NamedTupleType((fn_type, val_type)))) + return received_vals async def federated_value_at_server(self, arg): return await self._place(arg, placement_literals.SERVER) @@ -664,36 +574,13 @@ async def federated_reduce(self, arg): 'Expected 3 elements in the `federated_reduce()` argument tuple, ' 'found {}.'.format(len(arg.internal_representation))) - # Encrypt client tensors - pk_a, sk_a = await self._trusted_aggregator_generate_keys() - enc_clients_arg = await self._encrypt_client_tensors(arg, pk_a) - val_type = enc_clients_arg.type_signature - val = enc_clients_arg.internal_representation - # Store original client dtype before encryption for - # future decryption - orig_clients_dtype = arg.type_signature[0].member.dtype + aggr = self._get_child_executors(placement_literals.AGGREGATORS, index=0) + aggregands = await self._move(arg, placement_literals.CLIENTS, + placement_literals.AGGREGATORS) - py_typecheck.check_type(val_type, computation_types.FederatedType) - item_type = val_type.member zero_type = arg.type_signature[1] op_type = arg.type_signature[2] - py_typecheck.check_type(val, list) - aggr = self._get_child_executors(placement_literals.AGGREGATORS, index=0) - - aggregands = await asyncio.gather( - *[self._move(v, item_type, aggr) for v in val]) - - # Decrypt tensors moved to the server before applying reduce - # In the future should be decrypted by the Trusted aggregator - aggregands_decrypted = [] - for item in aggregands: - item_key_zipped = await self._zip_val_key(item, sk_a, - placement_literals.AGGREGATORS) - decrypted_tensor = await self._decrypt_tensors_on_aggregator( - item_key_zipped, orig_clients_dtype) - aggregands_decrypted.append(decrypted_tensor.internal_representation[0]) - zero = await aggr.create_value( await (await self.federating_executor.create_selection(arg, @@ -701,12 +588,12 @@ async def federated_reduce(self, arg): zero_type) op = await aggr.create_value(arg.internal_representation[2], op_type) - for item in aggregands_decrypted: + for item in aggregands: type_utils.check_equivalent_types( op_type, type_factory.reduction_op(zero_type, item.type_signature)) result = zero - for item in aggregands_decrypted: + for item in aggregands: result = await aggr.create_call( op, await aggr.create_tuple( anonymous_tuple.AnonymousTuple([(None, result), (None, item)]))) @@ -1222,3 +1109,315 @@ async def _compute_intrinsic_federated_collect(self, arg): @tracing.trace async def _compute_intrinsic_federated_secure_sum(self, arg): return await self.intrinsic_strategy.federated_secure_sum(arg) + + +class EasyBoxChannel(channel_base.Channel): + + def __init__(self, parent_executor, sender_placement, receiver_placement): + + self.parent_executor = parent_executor + self.sender_placement = sender_placement + self.receiver_placement = receiver_placement + + self.key_references = KeyStore() + self._is_channel_setup = False + + self._encrypt_tensor_fn = None + self._decrypt_tensor_fn = None + + async def setup(self): + + if not self._is_channel_setup: + await asyncio.gather(*[ + self._generate_keys(self.sender_placement), + self._generate_keys(self.receiver_placement) + ]) + await asyncio.gather(*[ + self._share_public_keys(self.sender_placement, + self.receiver_placement), + self._share_public_keys(self.receiver_placement, + self.sender_placement) + ]) + + self._is_channel_setup = True + + async def send(self, value, sender=None, receiver=None): + + return await self._encrypt_values_on_sender(value, sender, receiver) + + async def receive(self, value, receiver=None, sender=None): + + return await self._decrypt_values_on_receiver(value, sender, receiver) + + async def _generate_keys(self, key_owner): + + @computations.tf_computation() + def generate_keys(): + pk, sk = easy_box.gen_keypair() + return pk.raw, sk.raw + + fn_type = generate_keys.type_signature + fn = generate_keys._computation_proto + + executors = self.parent_executor._get_child_executors(key_owner) + + nb_executors = len(executors) + sk_vals = [] + pk_vals = [] + + for executor in executors: + key_generator = await executor.create_call(await executor.create_value( + fn, fn_type)) + + pk = await executor.create_selection(key_generator, 0) + sk = await executor.create_selection(key_generator, 1) + + pk_vals.append(pk) + sk_vals.append(sk) + + # Store list of EagerValue created by executor.create_call + # in a FederatingExecutorValue with the key onwer placement + sk_fed_vals = await self._place_keys(sk_vals, key_owner) + + self.key_references.add_keys(key_owner.name, pk_vals, sk_fed_vals) + + async def _share_public_keys(self, key_owner, send_pks_to): + + pk = self.key_references.get_public_key(key_owner.name) + + pk_fed_vals = await self._place_keys(pk, send_pks_to) + + self.key_references.update_keys(key_owner.name, pk_fed_vals) + + async def _encrypt_values_on_sender(self, val, sender=None, receiver=None): + + nb_senders = len( + self.parent_executor._get_child_executors(self.sender_placement)) + + if nb_senders == 1: + input_tensor_type = val.type_signature + self.orig_sender_tensor_dtype = input_tensor_type.dtype + else: + input_tensor_type = val[0].type_signature + self.orig_sender_tensor_dtype = input_tensor_type.dtype + + pk_receiver = self.key_references.get_public_key( + self.receiver_placement.name) + sk_sender = self.key_references.get_secret_key(self.sender_placement.name) + pk_rcv_type = pk_receiver.type_signature.member + sk_snd_type = sk_sender.type_signature.member + + if not self._encrypt_tensor_fn: + self._encrypt_tensor_fn = self._encrypt_tensor(input_tensor_type, + pk_rcv_type, sk_snd_type) + + fn_type = self._encrypt_tensor_fn.type_signature + fn = self._encrypt_tensor_fn._computation_proto + + if nb_senders == 1: + tensor_type = val.type_signature + else: + tensor_type = val[0].type_signature + + val_type = computation_types.FederatedType( + tensor_type, self.sender_placement, all_equal=False) + + val_key_zipped = await self._zip_val_key( + self.sender_placement, + val, + pk_receiver, + sk_sender, + pk_index=receiver, + sk_index=sender) + + # NOTE probably won't always be fed_ex in future design + fed_ex = self.parent_executor.federating_executor + + val_encrypted = await fed_ex._compute_intrinsic_federated_map( + FederatingExecutorValue( + anonymous_tuple.AnonymousTuple([(None, fn), + (None, val_key_zipped)]), + computation_types.NamedTupleType((fn_type, val_type)))) + + if sender != None or receiver != None: + return val_encrypted.internal_representation[0] + else: + return val_encrypted.internal_representation + + async def _decrypt_values_on_receiver(self, val, sender=None, receiver=None): + + pk_sender = self.key_references.get_public_key(self.sender_placement.name) + sk_receiver = self.key_references.get_secret_key( + self.receiver_placement.name) + + val = await self._zip_val_key( + self.receiver_placement, + val, + pk_sender, + sk_receiver, + pk_index=sender, + sk_index=receiver) + + sender_values_type = val[0].type_signature[0] + pk_snd_type = val[0].type_signature[1] + sk_snd_type = val[0].type_signature[2] + + if not self._decrypt_tensor_fn: + self._decrypt_tensor_fn = self._decrypt_tensor( + sender_values_type, pk_snd_type, sk_snd_type, + self.orig_sender_tensor_dtype) + + fn_type = self._decrypt_tensor_fn.type_signature + fn = self._decrypt_tensor_fn._computation_proto + + val_type = computation_types.FederatedType( + computation_types.TensorType(self.orig_sender_tensor_dtype), + self.receiver_placement, + all_equal=False) + + # NOTE probably won't always be fed_ex in future design + fed_ex = self.parent_executor.federating_executor + + val_decrypted = await fed_ex._compute_intrinsic_federated_map( + FederatingExecutorValue( + anonymous_tuple.AnonymousTuple([(None, fn), (None, val)]), + computation_types.NamedTupleType((fn_type, val_type)))) + + if sender != None or receiver != None: + return val_decrypted.internal_representation[0] + else: + return val_decrypted.internal_representation + + def _encrypt_tensor(self, plaintext_type, pk_rcv_type, sk_snd_type): + + @computations.tf_computation(plaintext_type, pk_rcv_type, sk_snd_type) + def encrypt_tensor(plaintext, pk_rcv, sk_snd): + + pk_rcv = easy_box.PublicKey(pk_rcv) + sk_snd = easy_box.PublicKey(sk_snd) + + nonce = easy_box.gen_nonce() + ciphertext, mac = easy_box.seal_detached(plaintext, nonce, pk_rcv, sk_snd) + + return ciphertext.raw, mac.raw, nonce.raw + + return encrypt_tensor + + def _decrypt_tensor(self, sender_values_type, pk_snd_type, sk_rcv_snd, + orig_sender_tensor_dtype): + + @computations.tf_computation(sender_values_type, pk_snd_type, sk_rcv_snd) + def decrypt_tensor(sender_values, pk_snd, sk_rcv): + + ciphertext = easy_box.Ciphertext(sender_values[0]) + mac = easy_box.Mac(sender_values[1]) + nonce = easy_box.Nonce(sender_values[2]) + sk_rcv = easy_box.SecretKey(sk_rcv) + pk_snd = easy_box.PublicKey(pk_snd) + + plaintext_recovered = easy_box.open_detached(ciphertext, mac, nonce, + pk_snd, sk_rcv, + orig_sender_tensor_dtype) + + return plaintext_recovered + + return decrypt_tensor + + async def _zip_val_key(self, + placement, + vals, + pk_key, + sk_key, + pk_index=None, + sk_index=None): + + if isinstance(vals, list): + val_type = computation_types.FederatedType( + vals[0].type_signature, placement, all_equal=False) + else: + val_type = computation_types.FederatedType( + vals.type_signature, placement, all_equal=False) + vals = [vals] + + pk_key_vals = pk_key.internal_representation + sk_key_vals = sk_key.internal_representation + + if pk_index != None: + pk_key_vals = [pk_key_vals[pk_index]] + + if sk_index != None: + sk_key_vals = [sk_key_vals[sk_index]] + + vals_key = FederatingExecutorValue( + anonymous_tuple.AnonymousTuple([(None, vals), (None, pk_key_vals), + (None, sk_key_vals)]), + computation_types.NamedTupleType( + (val_type, pk_key.type_signature, sk_key.type_signature))) + + vals_key_zipped = await self.parent_executor._zip( + vals_key, placement, all_equal=False) + + return vals_key_zipped.internal_representation + + async def _place_keys(self, keys, placement): + + py_typecheck.check_type(placement, placement_literals.PlacementLiteral) + children = self.parent_executor._get_child_executors(placement) + + # Scenario: there are as many keys as exectutors. For example + # there are 3 clients and each should have a secret key + if len(keys) == len(children): + keys_type_signature = keys[0].type_signature + return FederatingExecutorValue( + await asyncio.gather(*[ + c.create_value(await keys[i].compute(), keys_type_signature) + for (i, c) in enumerate(children) + ]), + computation_types.FederatedType( + keys_type_signature, placement, all_equal=False)) + # Scenario: there are more keys than exectutors. For example + # there are 3 clients and each have a public key. Each client wants + # to share its key to the same aggregator. + elif (len(children) == 1) & (len(children) < len(keys)): + keys_type_signature = keys[0].type_signature + child = children[0] + return FederatingExecutorValue( + await asyncio.gather(*[ + child.create_value(await k.compute(), keys_type_signature) + for k in keys + ]), + computation_types.FederatedType( + keys_type_signature, placement, all_equal=False)) + # Scenario: there are more exectutors than keys. For example + # there is an aggregator with one public key. The aggregator + # wants to share the samer public key to 3 different clients. + elif (len(keys) == 1) & (len(children) > len(keys)): + keys_type_signature = keys[0].type_signature + return FederatingExecutorValue( + await asyncio.gather(*[ + c.create_value(await keys[0].compute(), keys_type_signature) + for c in children + ]), + computation_types.FederatedType( + keys_type_signature, placement, all_equal=True)) + + +class KeyStore: + + def __init__(self): + self.key_store = {} + + def add_keys(self, key_owner, pk, sk): + self.key_store[key_owner] = {'pk': pk, 'sk': sk} + + def get_public_key(self, key_owner): + return self.key_store[key_owner]['pk'] + + def get_secret_key(self, key_owner): + return self.key_store[key_owner]['sk'] + + def update_keys(self, key_owner, pk=None, sk=None): + if pk: + self.key_store[key_owner]['pk'] = pk + if sk: + self.key_store[key_owner]['sk'] = sk diff --git a/tensorflow_federated/python/core/impl/executors/federating_executor_test.py b/tensorflow_federated/python/core/impl/executors/federating_executor_test.py index 97c50c6797..e16d2b2a92 100644 --- a/tensorflow_federated/python/core/impl/executors/federating_executor_test.py +++ b/tensorflow_federated/python/core/impl/executors/federating_executor_test.py @@ -40,6 +40,8 @@ from tensorflow_federated.python.core.impl.executors import federating_executor from tensorflow_federated.python.core.impl.executors import reference_resolving_executor +from tensorflow_federated.python.core.impl.executors import channel_base + tf.compat.v1.enable_v2_behavior() @@ -1277,11 +1279,30 @@ class EncryptionTest(parameterized.TestCase): def test_generate_aggregator_keys(self): strategy = federating_executor.TrustedAggregatorIntrinsicStrategy loop, ex = _make_test_runtime(intrinsic_strategy_fn=strategy) - generate_keys = ex.intrinsic_strategy._trusted_aggregator_generate_keys() - pk, sk = loop.run_until_complete(generate_keys) + strat_ex = ex.intrinsic_strategy + + channel_grid = channel_base.ChannelGrid({ + (placement_literals.AGGREGATORS, placement_literals.CLIENTS): + federating_executor.EasyBoxChannel( + parent_executor=strat_ex, + sender_placement=placement_literals.CLIENTS, + receiver_placement=placement_literals.AGGREGATORS) + }) + + channel = channel_grid[(placement_literals.CLIENTS, + placement_literals.AGGREGATORS)] + loop.run_until_complete(channel.setup()) + key_references = channel.key_references + + pk_c = key_references.get_public_key(placement_literals.CLIENTS.name) + sk_c = key_references.get_secret_key(placement_literals.CLIENTS.name) + pk_a = key_references.get_public_key(placement_literals.AGGREGATORS.name) + sk_a = key_references.get_secret_key(placement_literals.AGGREGATORS.name) - self.assertEqual(str(pk.type_signature), 'uint8[32]@CLIENTS') - self.assertEqual(str(sk.type_signature), 'uint8[32]@AGGREGATORS') + self.assertEqual(str(pk_c.type_signature), '{uint8[32]}@AGGREGATORS') + self.assertEqual(str(sk_c.type_signature), '{uint8[32]}@CLIENTS') + self.assertEqual(str(pk_a.type_signature), '{uint8[32]}@CLIENTS') + self.assertEqual(str(sk_a.type_signature), '{uint8[32]}@AGGREGATORS') def test_encryption_decryption(self): @@ -1289,30 +1310,28 @@ def test_encryption_decryption(self): loop, ex = _make_test_runtime(intrinsic_strategy_fn=strategy) strat_ex = ex.intrinsic_strategy - pk_a, sk_a = loop.run_until_complete( - strat_ex._trusted_aggregator_generate_keys()) + channel_grid = channel_base.ChannelGrid({ + (placement_literals.AGGREGATORS, placement_literals.CLIENTS): + federating_executor.EasyBoxChannel( + parent_executor=strat_ex, + sender_placement=placement_literals.CLIENTS, + receiver_placement=placement_literals.AGGREGATORS) + }) + + channel = channel_grid[(placement_literals.CLIENTS, + placement_literals.AGGREGATORS)] + loop.run_until_complete(channel.setup()) val = loop.run_until_complete( ex.create_value([2.0], type_factory.at_clients(tf.float32))) val_enc = loop.run_until_complete( - strat_ex._encrypt_client_tensors(val, pk_a)) - - aggr = strat_ex._get_child_executors( - placement_literals.AGGREGATORS, index=0) - - enc_val_on_aggr = loop.run_until_complete( - strat_ex._move(val_enc.internal_representation[0], - val_enc.type_signature.member, aggr)) - - val_key_zipped = loop.run_until_complete( - strat_ex._zip_val_key([enc_val_on_aggr], sk_a, - placement_literals.AGGREGATORS)) + channel.send(val.internal_representation[0])) val_dec = loop.run_until_complete( - strat_ex._decrypt_tensors_on_aggregator(val_key_zipped, tf.float32)) + channel.receive(val_enc)) - dec_tf_tensor = val_dec.internal_representation[0].internal_representation + dec_tf_tensor = val_dec[0].internal_representation self.assertEqual(dec_tf_tensor, tf.constant(2.0, dtype=tf.float32))