From 258fe68f193b7951e20f244ecbbf664d7629f0eb Mon Sep 17 00:00:00 2001 From: Vinicius Hirschle Date: Sat, 29 Apr 2023 21:52:01 -0300 Subject: [PATCH] feat(derive): add `#[postgres(allow_mismatch)]` --- .../compile-fail/invalid-allow-mismatch.rs | 31 ++++++++ .../invalid-allow-mismatch.stderr | 43 +++++++++++ postgres-derive-test/src/enums.rs | 72 ++++++++++++++++++- postgres-derive/src/accepts.rs | 42 ++++++----- postgres-derive/src/fromsql.rs | 22 +++++- postgres-derive/src/overrides.rs | 22 +++++- postgres-derive/src/tosql.rs | 22 +++++- postgres-types/src/lib.rs | 23 +++++- 8 files changed, 250 insertions(+), 27 deletions(-) create mode 100644 postgres-derive-test/src/compile-fail/invalid-allow-mismatch.rs create mode 100644 postgres-derive-test/src/compile-fail/invalid-allow-mismatch.stderr diff --git a/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.rs b/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.rs new file mode 100644 index 000000000..52d0ba8f6 --- /dev/null +++ b/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.rs @@ -0,0 +1,31 @@ +use postgres_types::{FromSql, ToSql}; + +#[derive(ToSql, Debug)] +#[postgres(allow_mismatch)] +struct ToSqlAllowMismatchStruct { + a: i32, +} + +#[derive(FromSql, Debug)] +#[postgres(allow_mismatch)] +struct FromSqlAllowMismatchStruct { + a: i32, +} + +#[derive(ToSql, Debug)] +#[postgres(allow_mismatch)] +struct ToSqlAllowMismatchTupleStruct(i32, i32); + +#[derive(FromSql, Debug)] +#[postgres(allow_mismatch)] +struct FromSqlAllowMismatchTupleStruct(i32, i32); + +#[derive(FromSql, Debug)] +#[postgres(transparent, allow_mismatch)] +struct TransparentFromSqlAllowMismatchStruct(i32); + +#[derive(FromSql, Debug)] +#[postgres(allow_mismatch, transparent)] +struct AllowMismatchFromSqlTransparentStruct(i32); + +fn main() {} diff --git a/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.stderr b/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.stderr new file mode 100644 index 000000000..a8e573248 --- /dev/null +++ b/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.stderr @@ -0,0 +1,43 @@ +error: #[postgres(allow_mismatch)] may only be applied to enums + --> src/compile-fail/invalid-allow-mismatch.rs:4:1 + | +4 | / #[postgres(allow_mismatch)] +5 | | struct ToSqlAllowMismatchStruct { +6 | | a: i32, +7 | | } + | |_^ + +error: #[postgres(allow_mismatch)] may only be applied to enums + --> src/compile-fail/invalid-allow-mismatch.rs:10:1 + | +10 | / #[postgres(allow_mismatch)] +11 | | struct FromSqlAllowMismatchStruct { +12 | | a: i32, +13 | | } + | |_^ + +error: #[postgres(allow_mismatch)] may only be applied to enums + --> src/compile-fail/invalid-allow-mismatch.rs:16:1 + | +16 | / #[postgres(allow_mismatch)] +17 | | struct ToSqlAllowMismatchTupleStruct(i32, i32); + | |_______________________________________________^ + +error: #[postgres(allow_mismatch)] may only be applied to enums + --> src/compile-fail/invalid-allow-mismatch.rs:20:1 + | +20 | / #[postgres(allow_mismatch)] +21 | | struct FromSqlAllowMismatchTupleStruct(i32, i32); + | |_________________________________________________^ + +error: #[postgres(transparent)] is not allowed with #[postgres(allow_mismatch)] + --> src/compile-fail/invalid-allow-mismatch.rs:24:25 + | +24 | #[postgres(transparent, allow_mismatch)] + | ^^^^^^^^^^^^^^ + +error: #[postgres(allow_mismatch)] is not allowed with #[postgres(transparent)] + --> src/compile-fail/invalid-allow-mismatch.rs:28:28 + | +28 | #[postgres(allow_mismatch, transparent)] + | ^^^^^^^^^^^ diff --git a/postgres-derive-test/src/enums.rs b/postgres-derive-test/src/enums.rs index 36d428437..f3e6c488c 100644 --- a/postgres-derive-test/src/enums.rs +++ b/postgres-derive-test/src/enums.rs @@ -1,5 +1,5 @@ use crate::test_type; -use postgres::{Client, NoTls}; +use postgres::{error::DbError, Client, NoTls}; use postgres_types::{FromSql, ToSql, WrongType}; use std::error::Error; @@ -131,3 +131,73 @@ fn missing_variant() { let err = conn.execute("SELECT $1::foo", &[&Foo::Bar]).unwrap_err(); assert!(err.source().unwrap().is::()); } + +#[test] +fn allow_mismatch_enums() { + #[derive(Debug, ToSql, FromSql, PartialEq)] + #[postgres(allow_mismatch)] + enum Foo { + Bar, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute("CREATE TYPE pg_temp.\"Foo\" AS ENUM ('Bar', 'Baz')", &[]) + .unwrap(); + + let row = conn.query_one("SELECT $1::\"Foo\"", &[&Foo::Bar]).unwrap(); + assert_eq!(row.get::<_, Foo>(0), Foo::Bar); +} + +#[test] +fn missing_enum_variant() { + #[derive(Debug, ToSql, FromSql, PartialEq)] + #[postgres(allow_mismatch)] + enum Foo { + Bar, + Buz, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute("CREATE TYPE pg_temp.\"Foo\" AS ENUM ('Bar', 'Baz')", &[]) + .unwrap(); + + let err = conn + .query_one("SELECT $1::\"Foo\"", &[&Foo::Buz]) + .unwrap_err(); + assert!(err.source().unwrap().is::()); +} + +#[test] +fn allow_mismatch_and_renaming() { + #[derive(Debug, ToSql, FromSql, PartialEq)] + #[postgres(name = "foo", allow_mismatch)] + enum Foo { + #[postgres(name = "bar")] + Bar, + #[postgres(name = "buz")] + Buz, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute("CREATE TYPE pg_temp.foo AS ENUM ('bar', 'baz', 'buz')", &[]) + .unwrap(); + + let row = conn.query_one("SELECT $1::foo", &[&Foo::Buz]).unwrap(); + assert_eq!(row.get::<_, Foo>(0), Foo::Buz); +} + +#[test] +fn wrong_name_and_allow_mismatch() { + #[derive(Debug, ToSql, FromSql, PartialEq)] + #[postgres(allow_mismatch)] + enum Foo { + Bar, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute("CREATE TYPE pg_temp.foo AS ENUM ('Bar', 'Baz')", &[]) + .unwrap(); + + let err = conn.query_one("SELECT $1::foo", &[&Foo::Bar]).unwrap_err(); + assert!(err.source().unwrap().is::()); +} diff --git a/postgres-derive/src/accepts.rs b/postgres-derive/src/accepts.rs index 63473863a..a68538dcc 100644 --- a/postgres-derive/src/accepts.rs +++ b/postgres-derive/src/accepts.rs @@ -31,31 +31,37 @@ pub fn domain_body(name: &str, field: &syn::Field) -> TokenStream { } } -pub fn enum_body(name: &str, variants: &[Variant]) -> TokenStream { +pub fn enum_body(name: &str, variants: &[Variant], allow_mismatch: bool) -> TokenStream { let num_variants = variants.len(); let variant_names = variants.iter().map(|v| &v.name); - quote! { - if type_.name() != #name { - return false; + if allow_mismatch { + quote! { + type_.name() == #name } + } else { + quote! { + if type_.name() != #name { + return false; + } - match *type_.kind() { - ::postgres_types::Kind::Enum(ref variants) => { - if variants.len() != #num_variants { - return false; - } - - variants.iter().all(|v| { - match &**v { - #( - #variant_names => true, - )* - _ => false, + match *type_.kind() { + ::postgres_types::Kind::Enum(ref variants) => { + if variants.len() != #num_variants { + return false; } - }) + + variants.iter().all(|v| { + match &**v { + #( + #variant_names => true, + )* + _ => false, + } + }) + } + _ => false, } - _ => false, } } } diff --git a/postgres-derive/src/fromsql.rs b/postgres-derive/src/fromsql.rs index a9150411a..d3ac47f4f 100644 --- a/postgres-derive/src/fromsql.rs +++ b/postgres-derive/src/fromsql.rs @@ -48,6 +48,26 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result { )) } } + } else if overrides.allow_mismatch { + match input.data { + Data::Enum(ref data) => { + let variants = data + .variants + .iter() + .map(|variant| Variant::parse(variant, overrides.rename_all)) + .collect::, _>>()?; + ( + accepts::enum_body(&name, &variants, overrides.allow_mismatch), + enum_body(&input.ident, &variants), + ) + } + _ => { + return Err(Error::new_spanned( + input, + "#[postgres(allow_mismatch)] may only be applied to enums", + )); + } + } } else { match input.data { Data::Enum(ref data) => { @@ -57,7 +77,7 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result { .map(|variant| Variant::parse(variant, overrides.rename_all)) .collect::, _>>()?; ( - accepts::enum_body(&name, &variants), + accepts::enum_body(&name, &variants, overrides.allow_mismatch), enum_body(&input.ident, &variants), ) } diff --git a/postgres-derive/src/overrides.rs b/postgres-derive/src/overrides.rs index 99faeebb7..d50550bee 100644 --- a/postgres-derive/src/overrides.rs +++ b/postgres-derive/src/overrides.rs @@ -7,6 +7,7 @@ pub struct Overrides { pub name: Option, pub rename_all: Option, pub transparent: bool, + pub allow_mismatch: bool, } impl Overrides { @@ -15,6 +16,7 @@ impl Overrides { name: None, rename_all: None, transparent: false, + allow_mismatch: false, }; for attr in attrs { @@ -74,11 +76,25 @@ impl Overrides { } } Meta::Path(path) => { - if !path.is_ident("transparent") { + if path.is_ident("transparent") { + if overrides.allow_mismatch { + return Err(Error::new_spanned( + path, + "#[postgres(allow_mismatch)] is not allowed with #[postgres(transparent)]", + )); + } + overrides.transparent = true; + } else if path.is_ident("allow_mismatch") { + if overrides.transparent { + return Err(Error::new_spanned( + path, + "#[postgres(transparent)] is not allowed with #[postgres(allow_mismatch)]", + )); + } + overrides.allow_mismatch = true; + } else { return Err(Error::new_spanned(path, "unknown override")); } - - overrides.transparent = true; } bad => return Err(Error::new_spanned(bad, "unknown attribute")), } diff --git a/postgres-derive/src/tosql.rs b/postgres-derive/src/tosql.rs index ec7602312..81d4834bf 100644 --- a/postgres-derive/src/tosql.rs +++ b/postgres-derive/src/tosql.rs @@ -44,6 +44,26 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result { )); } } + } else if overrides.allow_mismatch { + match input.data { + Data::Enum(ref data) => { + let variants = data + .variants + .iter() + .map(|variant| Variant::parse(variant, overrides.rename_all)) + .collect::, _>>()?; + ( + accepts::enum_body(&name, &variants, overrides.allow_mismatch), + enum_body(&input.ident, &variants), + ) + } + _ => { + return Err(Error::new_spanned( + input, + "#[postgres(allow_mismatch)] may only be applied to enums", + )); + } + } } else { match input.data { Data::Enum(ref data) => { @@ -53,7 +73,7 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result { .map(|variant| Variant::parse(variant, overrides.rename_all)) .collect::, _>>()?; ( - accepts::enum_body(&name, &variants), + accepts::enum_body(&name, &variants, overrides.allow_mismatch), enum_body(&input.ident, &variants), ) } diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index edd723977..cb82e2f93 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -138,7 +138,6 @@ //! #[derive(Debug, ToSql, FromSql)] //! #[postgres(name = "mood", rename_all = "snake_case")] //! enum Mood { -//! VerySad, // very_sad //! #[postgres(name = "ok")] //! Ok, // ok //! VeryHappy, // very_happy @@ -155,10 +154,28 @@ //! - `"kebab-case"` //! - `"SCREAMING-KEBAB-CASE"` //! - `"Train-Case"` - +//! +//! ## Allowing Enum Mismatches +//! +//! By default the generated implementation of [`ToSql`] & [`FromSql`] for enums will require an exact match of the enum +//! variants between the Rust and Postgres types. +//! To allow mismatches, the `#[postgres(allow_mismatch)]` attribute can be used on the enum definition: +//! +//! ```sql +//! CREATE TYPE mood AS ENUM ( +//! 'Sad', +//! 'Ok', +//! 'Happy' +//! ); +//! ``` +//! #[postgres(allow_mismatch)] +//! enum Mood { +//! Happy, +//! Meh, +//! } +//! ``` #![doc(html_root_url = "https://docs.rs/postgres-types/0.2")] #![warn(clippy::all, rust_2018_idioms, missing_docs)] - use fallible_iterator::FallibleIterator; use postgres_protocol::types::{self, ArrayDimension}; use std::any::type_name;