Skip to content

Adds support for zero-copy FromSql derive #1070

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions postgres-derive-test/src/composites.rs
Original file line number Diff line number Diff line change
Expand Up @@ -346,3 +346,32 @@ fn generics() {
},
);
}

#[test]
fn struct_with_borrowed_fields() {
#[derive(FromSql, ToSql, Debug, PartialEq)]
#[postgres(name = "item")]
struct Item<'a, 'b: 'a> {
name: &'a str,
data: &'b [u8],
}

let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
conn.batch_execute(
"CREATE TYPE pg_temp.item AS (
name TEXT,
data BYTEA
);",
)
.unwrap();

let item = Item {
name: "foobar",
data: b"12345",
};

let row = conn.query_one("SELECT $1::item", &[&item]).unwrap();
let result: Item<'_, '_> = row.get(0);
assert_eq!(item.name, result.name);
assert_eq!(item.data, result.data);
}
35 changes: 35 additions & 0 deletions postgres-derive-test/src/transparent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,38 @@ fn round_trip() {
UserId(123)
);
}

#[test]
fn struct_with_reference() {
#[derive(FromSql, ToSql, Debug, PartialEq)]
#[postgres(transparent)]
struct UserName<'a>(&'a str);

let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();

let user_name = "tester";
let row = conn
.query_one("SELECT $1", &[&UserName(user_name)])
.unwrap();
let result: UserName<'_> = row.get(0);
assert_eq!(user_name, result.0);
}

#[test]
fn nested_struct_with_reference() {
#[derive(FromSql, ToSql, Debug, PartialEq)]
#[postgres(transparent)]
struct Inner<'a>(&'a str);

#[derive(FromSql, ToSql, Debug, PartialEq)]
#[postgres(transparent)]
struct UserName<'a>(#[postgres(borrow)] Inner<'a>);

let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();

let user_name = "tester";
let inner = Inner(user_name);
let row = conn.query_one("SELECT $1", &[&UserName(inner)]).unwrap();
let result: UserName<'_> = row.get(0);
assert_eq!(user_name, result.0 .0);
}
4 changes: 2 additions & 2 deletions postgres-derive/src/accepts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use quote::quote;
use std::iter;
use syn::Ident;

use crate::composites::Field;
use crate::composites::NamedField;
use crate::enums::Variant;

pub fn transparent_body(field: &syn::Field) -> TokenStream {
Expand Down Expand Up @@ -66,7 +66,7 @@ pub fn enum_body(name: &str, variants: &[Variant], allow_mismatch: bool) -> Toke
}
}

pub fn composite_body(name: &str, trait_: &str, fields: &[Field]) -> TokenStream {
pub fn composite_body(name: &str, trait_: &str, fields: &[NamedField]) -> TokenStream {
let num_fields = fields.len();
let trait_ = Ident::new(trait_, Span::call_site());
let traits = iter::repeat(&trait_);
Expand Down
22 changes: 14 additions & 8 deletions postgres-derive/src/composites.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,28 @@
use proc_macro2::Span;
use std::collections::HashSet;
use syn::{
punctuated::Punctuated, Error, GenericParam, Generics, Ident, Path, PathSegment, Type,
TypeParamBound,
Error, GenericParam, Generics,
Ident, Lifetime, Path, PathSegment, punctuated::Punctuated, Type, TypeParamBound,
};
use lifetimes::extract_borrowed_lifetimes;

use crate::{case::RenameRule, overrides::Overrides};
use crate::{case::RenameRule, lifetimes, overrides::Overrides};

pub struct Field {
pub struct NamedField {
pub name: String,
pub ident: Ident,
pub type_: Type,
pub borrowed_lifetimes: HashSet<Lifetime>,
}

impl Field {
pub fn parse(raw: &syn::Field, rename_all: Option<RenameRule>) -> Result<Field, Error> {
impl NamedField {
pub fn parse(raw: &syn::Field, rename_all: Option<RenameRule>) -> Result<NamedField, Error> {
let overrides = Overrides::extract(&raw.attrs, false)?;
let ident = raw.ident.as_ref().unwrap().clone();

// field level name override takes precendence over container level rename_all override
let borrowed_lifetimes = extract_borrowed_lifetimes(raw, &overrides);

// field level name override takes precedence over container level rename_all override
let name = match overrides.name {
Some(n) => n,
None => {
Expand All @@ -31,10 +36,11 @@ impl Field {
}
};

Ok(Field {
Ok(NamedField {
name,
ident,
type_: raw.ty.clone(),
borrowed_lifetimes,
})
}
}
Expand Down
47 changes: 30 additions & 17 deletions postgres-derive/src/fromsql.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
use std::collections::{BTreeSet, HashSet};
use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote};
use std::iter;
use std::iter::FromIterator;
use syn::{
punctuated::Punctuated, token, AngleBracketedGenericArguments, Data, DataStruct, DeriveInput,
Error, Fields, GenericArgument, GenericParam, Generics, Ident, Lifetime, PathArguments,
PathSegment,
AngleBracketedGenericArguments, Data, DataStruct, DeriveInput, Error, Fields,
GenericArgument, GenericParam, Generics, Ident, Lifetime, PathArguments, PathSegment, punctuated::Punctuated,
token,
};
use syn::{LifetimeParam, TraitBound, TraitBoundModifier, TypeParamBound};

use crate::accepts;
use crate::composites::Field;
use crate::composites::{append_generic_bound, new_derive_path};
use crate::composites::NamedField;
use crate::enums::Variant;
use crate::overrides::Overrides;
use crate::transparent::UnnamedField;

pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
let overrides = Overrides::extract(&input.attrs, true)?;
Expand All @@ -29,16 +32,18 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
.clone()
.unwrap_or_else(|| input.ident.to_string());

let (accepts_body, to_sql_body) = if overrides.transparent {
let (accepts_body, to_sql_body, borrowed_lifetimes) = if overrides.transparent {
match input.data {
Data::Struct(DataStruct {
fields: Fields::Unnamed(ref fields),
..
}) if fields.unnamed.len() == 1 => {
let field = fields.unnamed.first().unwrap();
let parsed_field = UnnamedField::parse(field)?;
(
accepts::transparent_body(field),
transparent_body(&input.ident, field),
parsed_field.borrowed_lifetimes,
)
}
_ => {
Expand All @@ -59,6 +64,7 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
(
accepts::enum_body(&name, &variants, overrides.allow_mismatch),
enum_body(&input.ident, &variants),
HashSet::new(),
)
}
_ => {
Expand All @@ -79,16 +85,19 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
(
accepts::enum_body(&name, &variants, overrides.allow_mismatch),
enum_body(&input.ident, &variants),
HashSet::new(),
)
}
Data::Struct(DataStruct {
fields: Fields::Unnamed(ref fields),
..
}) if fields.unnamed.len() == 1 => {
let field = fields.unnamed.first().unwrap();
let parsed_field = UnnamedField::parse(field)?;
(
domain_accepts_body(&name, field),
domain_body(&input.ident, field),
parsed_field.borrowed_lifetimes,
)
}
Data::Struct(DataStruct {
Expand All @@ -98,11 +107,16 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
let fields = fields
.named
.iter()
.map(|field| Field::parse(field, overrides.rename_all))
.map(|field| NamedField::parse(field, overrides.rename_all))
.collect::<Result<Vec<_>, _>>()?;
let borrowed_lifetimes: HashSet<_> = fields
.iter()
.flat_map(|f| f.borrowed_lifetimes.to_owned())
.collect();
(
accepts::composite_body(&name, "FromSql", &fields),
composite_body(&input.ident, &fields),
borrowed_lifetimes
)
}
_ => {
Expand All @@ -115,7 +129,7 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
};

let ident = &input.ident;
let (generics, lifetime) = build_generics(&input.generics);
let (generics, lifetime) = build_generics(&input.generics, borrowed_lifetimes);
let (impl_generics, _, _) = generics.split_for_impl();
let (_, ty_generics, where_clause) = input.generics.split_for_impl();
let out = quote! {
Expand Down Expand Up @@ -183,7 +197,7 @@ fn domain_body(ident: &Ident, field: &syn::Field) -> TokenStream {
}
}

fn composite_body(ident: &Ident, fields: &[Field]) -> TokenStream {
fn composite_body(ident: &Ident, fields: &[NamedField]) -> TokenStream {
let temp_vars = &fields
.iter()
.map(|f| format_ident!("__{}", f.ident))
Expand Down Expand Up @@ -233,16 +247,15 @@ fn composite_body(ident: &Ident, fields: &[Field]) -> TokenStream {
}
}

fn build_generics(source: &Generics) -> (Generics, Lifetime) {
// don't worry about lifetime name collisions, it doesn't make sense to derive FromSql on a struct with a lifetime
let lifetime = Lifetime::new("'a", Span::call_site());

fn build_generics(source: &Generics, borrowed_lifetimes: HashSet<Lifetime>) -> (Generics, Lifetime) {
// This is the same parent lifetime name serde uses
let lifetime = Lifetime::new("'de", Span::call_site());
// Sort lifetimes for deterministic code-gen
let sorted_lifetimes = BTreeSet::from_iter(borrowed_lifetimes);
let mut lifetime_param = LifetimeParam::new(lifetime.to_owned());
lifetime_param.bounds.extend(sorted_lifetimes);
let mut out = append_generic_bound(source.to_owned(), &new_fromsql_bound(&lifetime));
out.params.insert(
0,
GenericParam::Lifetime(LifetimeParam::new(lifetime.to_owned())),
);

out.params.insert(0, GenericParam::Lifetime(lifetime_param));
(out, lifetime)
}

Expand Down
2 changes: 2 additions & 0 deletions postgres-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ mod enums;
mod fromsql;
mod overrides;
mod tosql;
mod transparent;
mod lifetimes;

#[proc_macro_derive(ToSql, attributes(postgres))]
pub fn derive_tosql(input: TokenStream) -> TokenStream {
Expand Down
36 changes: 36 additions & 0 deletions postgres-derive/src/lifetimes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
use std::collections::HashSet;
use syn::{AngleBracketedGenericArguments, GenericArgument, Lifetime, PathArguments, Type};
use crate::overrides::Overrides;

pub(crate) fn extract_borrowed_lifetimes(
raw: &syn::Field,
overrides: &Overrides,
) -> HashSet<Lifetime> {
let mut borrowed_lifetimes = HashSet::new();

// If the field is a reference, it's lifetime should be implicitly borrowed. Serde does
// the same thing
if let Type::Reference(ref_type) = &raw.ty {
borrowed_lifetimes.insert(ref_type.lifetime.to_owned().unwrap());
}

// Borrow all generic lifetimes of fields marked with #[postgres(borrow)]
if overrides.borrows {
if let Type::Path(type_path) = &raw.ty {
for segment in &type_path.path.segments {
if let PathArguments::AngleBracketed(AngleBracketedGenericArguments {
args, ..
}) = &segment.arguments
{
let lifetimes = args.iter().filter_map(|a| match a {
GenericArgument::Lifetime(lifetime) => Some(lifetime.to_owned()),
_ => None,
});
borrowed_lifetimes.extend(lifetimes);
}
}
}
}

borrowed_lifetimes
}
10 changes: 10 additions & 0 deletions postgres-derive/src/overrides.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub struct Overrides {
pub rename_all: Option<RenameRule>,
pub transparent: bool,
pub allow_mismatch: bool,
pub borrows: bool,
}

impl Overrides {
Expand All @@ -17,6 +18,7 @@ impl Overrides {
rename_all: None,
transparent: false,
allow_mismatch: false,
borrows: false,
};

for attr in attrs {
Expand Down Expand Up @@ -92,6 +94,14 @@ impl Overrides {
));
}
overrides.allow_mismatch = true;
} else if path.is_ident("borrow") {
if container_attr {
return Err(Error::new_spanned(
path,
"#[postgres(borrow)] is a field attribute",
));
}
overrides.borrows = true;
} else {
return Err(Error::new_spanned(path, "unknown override"));
}
Expand Down
6 changes: 3 additions & 3 deletions postgres-derive/src/tosql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use syn::{
};

use crate::accepts;
use crate::composites::Field;
use crate::composites::NamedField;
use crate::composites::{append_generic_bound, new_derive_path};
use crate::enums::Variant;
use crate::overrides::Overrides;
Expand Down Expand Up @@ -92,7 +92,7 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result<TokenStream, Error> {
let fields = fields
.named
.iter()
.map(|field| Field::parse(field, overrides.rename_all))
.map(|field| NamedField::parse(field, overrides.rename_all))
.collect::<Result<Vec<_>, _>>()?;
(
accepts::composite_body(&name, "ToSql", &fields),
Expand Down Expand Up @@ -168,7 +168,7 @@ fn domain_body() -> TokenStream {
}
}

fn composite_body(fields: &[Field]) -> TokenStream {
fn composite_body(fields: &[NamedField]) -> TokenStream {
let field_names = fields.iter().map(|f| &f.name);
let field_idents = fields.iter().map(|f| &f.ident);

Expand Down
17 changes: 17 additions & 0 deletions postgres-derive/src/transparent.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
use syn::{Error, Lifetime};
use std::collections::HashSet;
use lifetimes::extract_borrowed_lifetimes;
use crate::lifetimes;
use crate::overrides::Overrides;

pub struct UnnamedField {
pub borrowed_lifetimes: HashSet<Lifetime>,
}

impl UnnamedField {
pub fn parse(raw: &syn::Field) -> Result<UnnamedField, Error> {
let overrides = Overrides::extract(&raw.attrs, false)?;
let borrowed_lifetimes = extract_borrowed_lifetimes(raw, &overrides);
Ok(UnnamedField { borrowed_lifetimes })
}
}