diff --git a/willow/src/shell/ahe.rs b/willow/src/shell/ahe.rs index 0bc212b..d8e3b27 100644 --- a/willow/src/shell/ahe.rs +++ b/willow/src/shell/ahe.rs @@ -391,26 +391,24 @@ pub struct PartialDecryptionMetadata { macro_rules! impl_proto_traits_single_poly { ($type:ty, $proto:ty) => { - impl ToProto for $type { + impl> ToProto for $type { type Proto = $proto; - type Context = ShellAhe; - fn to_proto(&self, ctx: &Self::Context) -> Result { - let moduli = ahe::get_moduli(&ctx.public_ahe_parameters); + fn to_proto(&self, ctx: Context) -> Result { + let moduli = ahe::get_moduli(&ctx.as_ref().public_ahe_parameters); let poly_proto = rns_polynomial_to_proto(&self.0, &moduli)?; Ok(proto!($proto { poly: poly_proto })) } } - impl FromProto for $type { + impl> FromProto for $type { type Proto = $proto; - type Context = ShellAhe; fn from_proto( proto: impl protobuf::AsView, - ctx: &Self::Context, + ctx: Context, ) -> Result { - let moduli = ahe::get_moduli(&ctx.public_ahe_parameters); + let moduli = ahe::get_moduli(&ctx.as_ref().public_ahe_parameters); let poly = rns_polynomial_from_proto(proto.as_view().poly(), &moduli)?; Ok(Self(poly)) } @@ -424,12 +422,11 @@ impl_proto_traits_single_poly!(PublicKey, ShellAhePublicKey); macro_rules! impl_proto_traits_vec_poly { ($type:ty, $proto:ty) => { - impl ToProto for $type { + impl> ToProto for $type { type Proto = $proto; - type Context = ShellAhe; - fn to_proto(&self, ctx: &Self::Context) -> Result { - let moduli = ahe::get_moduli(&ctx.public_ahe_parameters); + fn to_proto(&self, ctx: Context) -> Result { + let moduli = ahe::get_moduli(&ctx.as_ref().public_ahe_parameters); let mut result = proto!($proto {}); for poly in &self.0 { result.poly_mut().push(rns_polynomial_to_proto(&poly, &moduli)?); @@ -438,15 +435,14 @@ macro_rules! impl_proto_traits_vec_poly { } } - impl FromProto for $type { + impl> FromProto for $type { type Proto = $proto; - type Context = ShellAhe; fn from_proto( proto: impl protobuf::AsView, - ctx: &Self::Context, + ctx: Context, ) -> Result { - let moduli = ahe::get_moduli(&ctx.public_ahe_parameters); + let moduli = ahe::get_moduli(&ctx.as_ref().public_ahe_parameters); let polys: Result, _> = proto .as_view() .poly() @@ -463,34 +459,38 @@ impl_proto_traits_vec_poly!(PartialDecryption, ShellAhePartialDecryption); impl_proto_traits_vec_poly!(PartialDecCiphertext, ShellAhePartialDecCiphertext); impl_proto_traits_vec_poly!(RecoverCiphertext, ShellAheRecoverCiphertext); -impl ToProto for Ciphertext { +impl> ToProto for Ciphertext { type Proto = ShellAheCiphertext; - type Context = ShellAhe; - fn to_proto(&self, ctx: &Self::Context) -> Result { + fn to_proto(&self, ctx: Context) -> Result { Ok(proto!(ShellAheCiphertext { - component_b: self.component_b.to_proto(ctx)?, - component_a: self.component_a.to_proto(ctx)?, + component_b: self.component_b.to_proto(&ctx)?, + component_a: self.component_a.to_proto(&ctx)?, })) } } -impl FromProto for Ciphertext { +impl> FromProto for Ciphertext { type Proto = ShellAheCiphertext; - type Context = ShellAhe; fn from_proto( proto: impl protobuf::AsView, - ctx: &Self::Context, + ctx: Context, ) -> Result { let proto_view = proto.as_view(); Ok(Ciphertext { - component_b: RecoverCiphertext::from_proto(proto_view.component_b(), ctx)?, - component_a: PartialDecCiphertext::from_proto(proto_view.component_a(), ctx)?, + component_b: RecoverCiphertext::from_proto(proto_view.component_b(), &ctx)?, + component_a: PartialDecCiphertext::from_proto(proto_view.component_a(), &ctx)?, }) } } +impl AsRef for ShellAhe { + fn as_ref(&self) -> &ShellAhe { + self + } +} + impl AheBase for ShellAhe { type SecretKeyShare = SecretKeyShare; type PublicKeyShare = PublicKeyShare; diff --git a/willow/src/shell/kahe.rs b/willow/src/shell/kahe.rs index ad3022f..8ff84a3 100644 --- a/willow/src/shell/kahe.rs +++ b/willow/src/shell/kahe.rs @@ -110,37 +110,34 @@ pub struct SecretKey(pub RnsPolynomial); #[derive(Clone)] pub struct Ciphertext(pub RnsPolynomialVec); -impl ToProto for SecretKey { +impl> ToProto for SecretKey { type Proto = ShellKaheSecretKey; - type Context = ShellKahe; - fn to_proto(&self, ctx: &Self::Context) -> Result { - let moduli = kahe::get_moduli(&ctx.public_kahe_parameters); + fn to_proto(&self, ctx: Context) -> Result { + let moduli = kahe::get_moduli(&ctx.as_ref().public_kahe_parameters); let poly_proto = rns_polynomial_to_proto(&self.0, &moduli)?; Ok(proto!(ShellKaheSecretKey { poly: poly_proto })) } } -impl FromProto for SecretKey { +impl> FromProto for SecretKey { type Proto = ShellKaheSecretKey; - type Context = ShellKahe; fn from_proto( proto: impl protobuf::AsView, - ctx: &Self::Context, + ctx: Context, ) -> Result { - let moduli = kahe::get_moduli(&ctx.public_kahe_parameters); + let moduli = kahe::get_moduli(&ctx.as_ref().public_kahe_parameters); let poly = rns_polynomial_from_proto(proto.as_view().poly(), &moduli)?; Ok(Self(poly)) } } -impl ToProto for Ciphertext { +impl> ToProto for Ciphertext { type Proto = ShellKaheCiphertext; - type Context = ShellKahe; - fn to_proto(&self, ctx: &Self::Context) -> Result { - let moduli = kahe::get_moduli(&ctx.public_kahe_parameters); + fn to_proto(&self, ctx: Context) -> Result { + let moduli = kahe::get_moduli(&ctx.as_ref().public_kahe_parameters); let mut result = proto!(ShellKaheCiphertext {}); for poly in self.0.iter() { result.poly_mut().push(rns_polynomial_to_proto(&poly, &moduli)?); @@ -149,21 +146,26 @@ impl ToProto for Ciphertext { } } -impl FromProto for Ciphertext { +impl> FromProto for Ciphertext { type Proto = ShellKaheCiphertext; - type Context = ShellKahe; fn from_proto( proto: impl protobuf::AsView, - ctx: &Self::Context, + ctx: Context, ) -> Result { - let moduli = kahe::get_moduli(&ctx.public_kahe_parameters); + let moduli = kahe::get_moduli(&ctx.as_ref().public_kahe_parameters); let polys: Result, _> = proto.as_view().poly().iter().map(|p| rns_polynomial_from_proto(p, &moduli)).collect(); Ok(Ciphertext(rust_vec_to_rns_polynomial_vec(polys?))) } } +impl AsRef for ShellKahe { + fn as_ref(&self) -> &ShellKahe { + self + } +} + impl KaheBase for ShellKahe { type SecretKey = SecretKey; diff --git a/willow/src/traits/proto_serialization.rs b/willow/src/traits/proto_serialization.rs index a145958..a2bc00f 100644 --- a/willow/src/traits/proto_serialization.rs +++ b/willow/src/traits/proto_serialization.rs @@ -15,20 +15,18 @@ use status::StatusError; /// Trait for converting a struct to a Protobuf message. -pub trait ToProto { +pub trait ToProto { type Proto; - type Context; - fn to_proto(&self, ctx: &Self::Context) -> Result; + fn to_proto(&self, ctx: Context) -> Result; } /// Trait for converting a Protobuf message view to a struct. -pub trait FromProto: Sized { - type Proto: protobuf::AsView; - type Context; +pub trait FromProto: Sized { + type Proto; fn from_proto( proto: impl protobuf::AsView, - ctx: &Self::Context, + ctx: Context, ) -> Result; }