Skip to content

Commit

Permalink
feat(derive): add #[postgres(allow_mismatch)]
Browse files Browse the repository at this point in the history
  • Loading branch information
viniciusth committed Jun 19, 2023
1 parent 790af54 commit 258fe68
Show file tree
Hide file tree
Showing 8 changed files with 250 additions and 27 deletions.
31 changes: 31 additions & 0 deletions postgres-derive-test/src/compile-fail/invalid-allow-mismatch.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use postgres_types::{FromSql, ToSql};

#[derive(ToSql, Debug)]
#[postgres(allow_mismatch)]
struct ToSqlAllowMismatchStruct {
a: i32,
}

#[derive(FromSql, Debug)]
#[postgres(allow_mismatch)]
struct FromSqlAllowMismatchStruct {
a: i32,
}

#[derive(ToSql, Debug)]
#[postgres(allow_mismatch)]
struct ToSqlAllowMismatchTupleStruct(i32, i32);

#[derive(FromSql, Debug)]
#[postgres(allow_mismatch)]
struct FromSqlAllowMismatchTupleStruct(i32, i32);

#[derive(FromSql, Debug)]
#[postgres(transparent, allow_mismatch)]
struct TransparentFromSqlAllowMismatchStruct(i32);

#[derive(FromSql, Debug)]
#[postgres(allow_mismatch, transparent)]
struct AllowMismatchFromSqlTransparentStruct(i32);

fn main() {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
error: #[postgres(allow_mismatch)] may only be applied to enums
--> src/compile-fail/invalid-allow-mismatch.rs:4:1
|
4 | / #[postgres(allow_mismatch)]
5 | | struct ToSqlAllowMismatchStruct {
6 | | a: i32,
7 | | }
| |_^

error: #[postgres(allow_mismatch)] may only be applied to enums
--> src/compile-fail/invalid-allow-mismatch.rs:10:1
|
10 | / #[postgres(allow_mismatch)]
11 | | struct FromSqlAllowMismatchStruct {
12 | | a: i32,
13 | | }
| |_^

error: #[postgres(allow_mismatch)] may only be applied to enums
--> src/compile-fail/invalid-allow-mismatch.rs:16:1
|
16 | / #[postgres(allow_mismatch)]
17 | | struct ToSqlAllowMismatchTupleStruct(i32, i32);
| |_______________________________________________^

error: #[postgres(allow_mismatch)] may only be applied to enums
--> src/compile-fail/invalid-allow-mismatch.rs:20:1
|
20 | / #[postgres(allow_mismatch)]
21 | | struct FromSqlAllowMismatchTupleStruct(i32, i32);
| |_________________________________________________^

error: #[postgres(transparent)] is not allowed with #[postgres(allow_mismatch)]
--> src/compile-fail/invalid-allow-mismatch.rs:24:25
|
24 | #[postgres(transparent, allow_mismatch)]
| ^^^^^^^^^^^^^^

error: #[postgres(allow_mismatch)] is not allowed with #[postgres(transparent)]
--> src/compile-fail/invalid-allow-mismatch.rs:28:28
|
28 | #[postgres(allow_mismatch, transparent)]
| ^^^^^^^^^^^
72 changes: 71 additions & 1 deletion postgres-derive-test/src/enums.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::test_type;
use postgres::{Client, NoTls};
use postgres::{error::DbError, Client, NoTls};
use postgres_types::{FromSql, ToSql, WrongType};
use std::error::Error;

Expand Down Expand Up @@ -131,3 +131,73 @@ fn missing_variant() {
let err = conn.execute("SELECT $1::foo", &[&Foo::Bar]).unwrap_err();
assert!(err.source().unwrap().is::<WrongType>());
}

#[test]
fn allow_mismatch_enums() {
#[derive(Debug, ToSql, FromSql, PartialEq)]
#[postgres(allow_mismatch)]
enum Foo {
Bar,
}

let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
conn.execute("CREATE TYPE pg_temp.\"Foo\" AS ENUM ('Bar', 'Baz')", &[])
.unwrap();

let row = conn.query_one("SELECT $1::\"Foo\"", &[&Foo::Bar]).unwrap();
assert_eq!(row.get::<_, Foo>(0), Foo::Bar);
}

#[test]
fn missing_enum_variant() {
#[derive(Debug, ToSql, FromSql, PartialEq)]
#[postgres(allow_mismatch)]
enum Foo {
Bar,
Buz,
}

let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
conn.execute("CREATE TYPE pg_temp.\"Foo\" AS ENUM ('Bar', 'Baz')", &[])
.unwrap();

let err = conn
.query_one("SELECT $1::\"Foo\"", &[&Foo::Buz])
.unwrap_err();
assert!(err.source().unwrap().is::<DbError>());
}

#[test]
fn allow_mismatch_and_renaming() {
#[derive(Debug, ToSql, FromSql, PartialEq)]
#[postgres(name = "foo", allow_mismatch)]
enum Foo {
#[postgres(name = "bar")]
Bar,
#[postgres(name = "buz")]
Buz,
}

let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
conn.execute("CREATE TYPE pg_temp.foo AS ENUM ('bar', 'baz', 'buz')", &[])
.unwrap();

let row = conn.query_one("SELECT $1::foo", &[&Foo::Buz]).unwrap();
assert_eq!(row.get::<_, Foo>(0), Foo::Buz);
}

#[test]
fn wrong_name_and_allow_mismatch() {
#[derive(Debug, ToSql, FromSql, PartialEq)]
#[postgres(allow_mismatch)]
enum Foo {
Bar,
}

let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
conn.execute("CREATE TYPE pg_temp.foo AS ENUM ('Bar', 'Baz')", &[])
.unwrap();

let err = conn.query_one("SELECT $1::foo", &[&Foo::Bar]).unwrap_err();
assert!(err.source().unwrap().is::<WrongType>());
}
42 changes: 24 additions & 18 deletions postgres-derive/src/accepts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,31 +31,37 @@ pub fn domain_body(name: &str, field: &syn::Field) -> TokenStream {
}
}

pub fn enum_body(name: &str, variants: &[Variant]) -> TokenStream {
pub fn enum_body(name: &str, variants: &[Variant], allow_mismatch: bool) -> TokenStream {
let num_variants = variants.len();
let variant_names = variants.iter().map(|v| &v.name);

quote! {
if type_.name() != #name {
return false;
if allow_mismatch {
quote! {
type_.name() == #name
}
} else {
quote! {
if type_.name() != #name {
return false;
}

match *type_.kind() {
::postgres_types::Kind::Enum(ref variants) => {
if variants.len() != #num_variants {
return false;
}

variants.iter().all(|v| {
match &**v {
#(
#variant_names => true,
)*
_ => false,
match *type_.kind() {
::postgres_types::Kind::Enum(ref variants) => {
if variants.len() != #num_variants {
return false;
}
})

variants.iter().all(|v| {
match &**v {
#(
#variant_names => true,
)*
_ => false,
}
})
}
_ => false,
}
_ => false,
}
}
}
Expand Down
22 changes: 21 additions & 1 deletion postgres-derive/src/fromsql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,26 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
))
}
}
} else if overrides.allow_mismatch {
match input.data {
Data::Enum(ref data) => {
let variants = data
.variants
.iter()
.map(|variant| Variant::parse(variant, overrides.rename_all))
.collect::<Result<Vec<_>, _>>()?;
(
accepts::enum_body(&name, &variants, overrides.allow_mismatch),
enum_body(&input.ident, &variants),
)
}
_ => {
return Err(Error::new_spanned(
input,
"#[postgres(allow_mismatch)] may only be applied to enums",
));
}
}
} else {
match input.data {
Data::Enum(ref data) => {
Expand All @@ -57,7 +77,7 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
.map(|variant| Variant::parse(variant, overrides.rename_all))
.collect::<Result<Vec<_>, _>>()?;
(
accepts::enum_body(&name, &variants),
accepts::enum_body(&name, &variants, overrides.allow_mismatch),
enum_body(&input.ident, &variants),
)
}
Expand Down
22 changes: 19 additions & 3 deletions postgres-derive/src/overrides.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pub struct Overrides {
pub name: Option<String>,
pub rename_all: Option<RenameRule>,
pub transparent: bool,
pub allow_mismatch: bool,
}

impl Overrides {
Expand All @@ -15,6 +16,7 @@ impl Overrides {
name: None,
rename_all: None,
transparent: false,
allow_mismatch: false,
};

for attr in attrs {
Expand Down Expand Up @@ -74,11 +76,25 @@ impl Overrides {
}
}
Meta::Path(path) => {
if !path.is_ident("transparent") {
if path.is_ident("transparent") {
if overrides.allow_mismatch {
return Err(Error::new_spanned(
path,
"#[postgres(allow_mismatch)] is not allowed with #[postgres(transparent)]",
));
}
overrides.transparent = true;
} else if path.is_ident("allow_mismatch") {
if overrides.transparent {
return Err(Error::new_spanned(
path,
"#[postgres(transparent)] is not allowed with #[postgres(allow_mismatch)]",
));
}
overrides.allow_mismatch = true;
} else {
return Err(Error::new_spanned(path, "unknown override"));
}

overrides.transparent = true;
}
bad => return Err(Error::new_spanned(bad, "unknown attribute")),
}
Expand Down
22 changes: 21 additions & 1 deletion postgres-derive/src/tosql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,26 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result<TokenStream, Error> {
));
}
}
} else if overrides.allow_mismatch {
match input.data {
Data::Enum(ref data) => {
let variants = data
.variants
.iter()
.map(|variant| Variant::parse(variant, overrides.rename_all))
.collect::<Result<Vec<_>, _>>()?;
(
accepts::enum_body(&name, &variants, overrides.allow_mismatch),
enum_body(&input.ident, &variants),
)
}
_ => {
return Err(Error::new_spanned(
input,
"#[postgres(allow_mismatch)] may only be applied to enums",
));
}
}
} else {
match input.data {
Data::Enum(ref data) => {
Expand All @@ -53,7 +73,7 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result<TokenStream, Error> {
.map(|variant| Variant::parse(variant, overrides.rename_all))
.collect::<Result<Vec<_>, _>>()?;
(
accepts::enum_body(&name, &variants),
accepts::enum_body(&name, &variants, overrides.allow_mismatch),
enum_body(&input.ident, &variants),
)
}
Expand Down
23 changes: 20 additions & 3 deletions postgres-types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@
//! #[derive(Debug, ToSql, FromSql)]
//! #[postgres(name = "mood", rename_all = "snake_case")]
//! enum Mood {
//! VerySad, // very_sad
//! #[postgres(name = "ok")]
//! Ok, // ok
//! VeryHappy, // very_happy
Expand All @@ -155,10 +154,28 @@
//! - `"kebab-case"`
//! - `"SCREAMING-KEBAB-CASE"`
//! - `"Train-Case"`
//!
//! ## Allowing Enum Mismatches
//!
//! By default the generated implementation of [`ToSql`] & [`FromSql`] for enums will require an exact match of the enum
//! variants between the Rust and Postgres types.
//! To allow mismatches, the `#[postgres(allow_mismatch)]` attribute can be used on the enum definition:
//!
//! ```sql
//! CREATE TYPE mood AS ENUM (
//! 'Sad',
//! 'Ok',
//! 'Happy'
//! );
//! ```
//! #[postgres(allow_mismatch)]
//! enum Mood {
//! Happy,
//! Meh,
//! }
//! ```
#![doc(html_root_url = "https://docs.rs/postgres-types/0.2")]
#![warn(clippy::all, rust_2018_idioms, missing_docs)]

use fallible_iterator::FallibleIterator;
use postgres_protocol::types::{self, ArrayDimension};
use std::any::type_name;
Expand Down

0 comments on commit 258fe68

Please sign in to comment.