From eef4284d96b993bf36071239c3c1676dc22cca33 Mon Sep 17 00:00:00 2001 From: Joey de Waal Date: Thu, 9 Jan 2025 21:12:39 +0100 Subject: [PATCH 01/11] feat: implement Decode,Encode,Type for Box,Arc,Cow --- sqlx-core/src/decode.rs | 37 ++++++++++++++++ sqlx-core/src/encode.rs | 87 ++++++++++++++++++++++++++++++++++++++ sqlx-core/src/types/mod.rs | 40 ++++++++++++++++++ tests/postgres/types.rs | 42 ++++++++++++++++++ 4 files changed, 206 insertions(+) diff --git a/sqlx-core/src/decode.rs b/sqlx-core/src/decode.rs index 3249c349cc..7d36e24785 100644 --- a/sqlx-core/src/decode.rs +++ b/sqlx-core/src/decode.rs @@ -1,5 +1,8 @@ //! Provides [`Decode`] for decoding values from the database. +use std::borrow::Cow; +use std::sync::Arc; + use crate::database::Database; use crate::error::BoxDynError; @@ -77,3 +80,37 @@ where } } } + +// implement `Decode` for Arc for all SQL types +impl<'r, DB, T> Decode<'r, DB> for Arc +where + DB: Database, + T: Decode<'r, DB>, +{ + fn decode(value: ::ValueRef<'r>) -> Result { + Ok(Arc::new(T::decode(value)?)) + } +} + +// implement `Decode` for Cow for all SQL types +impl<'r, DB, T> Decode<'r, DB> for Cow<'_, T> +where + DB: Database, + T: Decode<'r, DB>, + T: ToOwned, +{ + fn decode(value: ::ValueRef<'r>) -> Result { + Ok(Cow::Owned(T::decode(value)?)) + } +} + +// implement `Decode` for Box for all SQL types +impl<'r, DB, T> Decode<'r, DB> for Box +where + DB: Database, + T: Decode<'r, DB>, +{ + fn decode(value: ::ValueRef<'r>) -> Result { + Ok(Box::new(T::decode(value)?)) + } +} diff --git a/sqlx-core/src/encode.rs b/sqlx-core/src/encode.rs index 2d28641f94..15f48a6086 100644 --- a/sqlx-core/src/encode.rs +++ b/sqlx-core/src/encode.rs @@ -1,6 +1,8 @@ //! Provides [`Encode`] for encoding values for the database. +use std::borrow::Cow; use std::mem; +use std::sync::Arc; use crate::database::Database; use crate::error::BoxDynError; @@ -129,3 +131,88 @@ macro_rules! impl_encode_for_option { } }; } + +impl<'q, T, DB: Database> Encode<'q, DB> for Arc +where + T: Encode<'q, DB>, +{ + #[inline] + fn encode(self, buf: &mut ::ArgumentBuffer<'q>) -> Result { + >::encode_by_ref(self.as_ref(), buf) + } + + #[inline] + fn encode_by_ref( + &self, + buf: &mut ::ArgumentBuffer<'q>, + ) -> Result { + <&T as Encode>::encode(self, buf) + } + + #[inline] + fn produces(&self) -> Option { + (**self).produces() + } + + #[inline] + fn size_hint(&self) -> usize { + (**self).size_hint() + } +} + +impl<'q, T, DB: Database> Encode<'q, DB> for Cow<'_, T> +where + T: Encode<'q, DB>, + T: ToOwned, +{ + #[inline] + fn encode(self, buf: &mut ::ArgumentBuffer<'q>) -> Result { + >::encode_by_ref(self.as_ref(), buf) + } + + #[inline] + fn encode_by_ref( + &self, + buf: &mut ::ArgumentBuffer<'q>, + ) -> Result { + <&T as Encode>::encode(self, buf) + } + + #[inline] + fn produces(&self) -> Option { + (**self).produces() + } + + #[inline] + fn size_hint(&self) -> usize { + (**self).size_hint() + } +} + +impl<'q, T, DB: Database> Encode<'q, DB> for Box +where + T: Encode<'q, DB>, +{ + #[inline] + fn encode(self, buf: &mut ::ArgumentBuffer<'q>) -> Result { + >::encode_by_ref(self.as_ref(), buf) + } + + #[inline] + fn encode_by_ref( + &self, + buf: &mut ::ArgumentBuffer<'q>, + ) -> Result { + <&T as Encode>::encode(self, buf) + } + + #[inline] + fn produces(&self) -> Option { + (**self).produces() + } + + #[inline] + fn size_hint(&self) -> usize { + (**self).size_hint() + } +} diff --git a/sqlx-core/src/types/mod.rs b/sqlx-core/src/types/mod.rs index b00427daae..f3e038b2af 100644 --- a/sqlx-core/src/types/mod.rs +++ b/sqlx-core/src/types/mod.rs @@ -17,6 +17,8 @@ //! To represent nullable SQL types, `Option` is supported where `T` implements `Type`. //! An `Option` represents a potentially `NULL` value from SQL. +use std::{borrow::Cow, sync::Arc}; + use crate::database::Database; use crate::type_info::TypeInfo; @@ -248,3 +250,41 @@ impl, DB: Database> Type for Option { ty.is_null() || >::compatible(ty) } } + +impl Type for Arc +where + T: Type, + T: ?Sized, +{ + fn type_info() -> DB::TypeInfo { + >::type_info() + } + + fn compatible(ty: &DB::TypeInfo) -> bool { + ty.is_null() || >::compatible(ty) + } +} + +impl Type for Cow<'_, T> +where + T: Type, + T: ToOwned, +{ + fn type_info() -> DB::TypeInfo { + >::type_info() + } + + fn compatible(ty: &DB::TypeInfo) -> bool { + ty.is_null() || >::compatible(ty) + } +} + +impl, DB: Database> Type for Box { + fn type_info() -> DB::TypeInfo { + >::type_info() + } + + fn compatible(ty: &DB::TypeInfo) -> bool { + ty.is_null() || >::compatible(ty) + } +} diff --git a/tests/postgres/types.rs b/tests/postgres/types.rs index d5d34bc1b3..dfea3e7d53 100644 --- a/tests/postgres/types.rs +++ b/tests/postgres/types.rs @@ -1,8 +1,10 @@ extern crate time_ as time; +use std::borrow::Cow; use std::net::SocketAddr; use std::ops::Bound; use std::str::FromStr; +use std::sync::Arc; use sqlx::postgres::types::{Oid, PgCiText, PgInterval, PgMoney, PgRange}; use sqlx::postgres::Postgres; @@ -736,3 +738,43 @@ CREATE TEMPORARY TABLE user_login ( Ok(()) } + +#[sqlx_macros::test] +async fn test_arc() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let user_age: Arc = sqlx::query_scalar("select $1 as age ") + .bind(Arc::new(1i32)) + .fetch_one(&mut conn) + .await?; + assert!(user_age.as_ref() == &1); + Ok(()) +} + +#[sqlx_macros::test] +async fn test_cow() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let age: Cow<'_, i32> = Cow::Owned(1i32); + + let user_age: Cow<'static, i32> = sqlx::query_scalar("select $1 as age ") + .bind(age) + .fetch_one(&mut conn) + .await?; + + assert!(user_age.as_ref() == &1); + Ok(()) +} + +#[sqlx_macros::test] +async fn test_box() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let user_age: Box = sqlx::query_scalar("select $1 as age ") + .bind(Box::new(1)) + .fetch_one(&mut conn) + .await?; + + assert!(user_age.as_ref() == &1); + Ok(()) +} From 23529b83a4b02ec80eea4c1d8b838fda1283fedc Mon Sep 17 00:00:00 2001 From: Joey de Waal Date: Thu, 9 Jan 2025 22:19:09 +0100 Subject: [PATCH 02/11] feat implement Encode,Type for Rc --- sqlx-core/src/encode.rs | 29 +++++++++++++++++++++++++++++ sqlx-core/src/types/mod.rs | 12 +++++++++++- tests/postgres/types.rs | 20 +++++++++++++++++--- 3 files changed, 57 insertions(+), 4 deletions(-) diff --git a/sqlx-core/src/encode.rs b/sqlx-core/src/encode.rs index 15f48a6086..51a1becfc3 100644 --- a/sqlx-core/src/encode.rs +++ b/sqlx-core/src/encode.rs @@ -2,6 +2,7 @@ use std::borrow::Cow; use std::mem; +use std::rc::Rc; use std::sync::Arc; use crate::database::Database; @@ -216,3 +217,31 @@ where (**self).size_hint() } } + +impl<'q, T, DB: Database> Encode<'q, DB> for Rc +where + T: Encode<'q, DB>, +{ + #[inline] + fn encode(self, buf: &mut ::ArgumentBuffer<'q>) -> Result { + >::encode_by_ref(self.as_ref(), buf) + } + + #[inline] + fn encode_by_ref( + &self, + buf: &mut ::ArgumentBuffer<'q>, + ) -> Result { + <&T as Encode>::encode(self, buf) + } + + #[inline] + fn produces(&self) -> Option { + (**self).produces() + } + + #[inline] + fn size_hint(&self) -> usize { + (**self).size_hint() + } +} diff --git a/sqlx-core/src/types/mod.rs b/sqlx-core/src/types/mod.rs index f3e038b2af..5c47375a5d 100644 --- a/sqlx-core/src/types/mod.rs +++ b/sqlx-core/src/types/mod.rs @@ -17,7 +17,7 @@ //! To represent nullable SQL types, `Option` is supported where `T` implements `Type`. //! An `Option` represents a potentially `NULL` value from SQL. -use std::{borrow::Cow, sync::Arc}; +use std::{borrow::Cow, rc::Rc, sync::Arc}; use crate::database::Database; use crate::type_info::TypeInfo; @@ -288,3 +288,13 @@ impl, DB: Database> Type for Box { ty.is_null() || >::compatible(ty) } } + +impl, DB: Database> Type for Rc { + fn type_info() -> DB::TypeInfo { + >::type_info() + } + + fn compatible(ty: &DB::TypeInfo) -> bool { + ty.is_null() || >::compatible(ty) + } +} diff --git a/tests/postgres/types.rs b/tests/postgres/types.rs index dfea3e7d53..52a7fd31eb 100644 --- a/tests/postgres/types.rs +++ b/tests/postgres/types.rs @@ -4,6 +4,7 @@ use std::borrow::Cow; use std::net::SocketAddr; use std::ops::Bound; use std::str::FromStr; +use std::rc::Rc; use std::sync::Arc; use sqlx::postgres::types::{Oid, PgCiText, PgInterval, PgMoney, PgRange}; @@ -743,7 +744,7 @@ CREATE TEMPORARY TABLE user_login ( async fn test_arc() -> anyhow::Result<()> { let mut conn = new::().await?; - let user_age: Arc = sqlx::query_scalar("select $1 as age ") + let user_age: Arc = sqlx::query_scalar("SELECT $1 AS age ") .bind(Arc::new(1i32)) .fetch_one(&mut conn) .await?; @@ -757,7 +758,7 @@ async fn test_cow() -> anyhow::Result<()> { let age: Cow<'_, i32> = Cow::Owned(1i32); - let user_age: Cow<'static, i32> = sqlx::query_scalar("select $1 as age ") + let user_age: Cow<'static, i32> = sqlx::query_scalar("SELECT $1 AS age ") .bind(age) .fetch_one(&mut conn) .await?; @@ -770,7 +771,7 @@ async fn test_cow() -> anyhow::Result<()> { async fn test_box() -> anyhow::Result<()> { let mut conn = new::().await?; - let user_age: Box = sqlx::query_scalar("select $1 as age ") + let user_age: Box = sqlx::query_scalar("SELECT $1 AS age ") .bind(Box::new(1)) .fetch_one(&mut conn) .await?; @@ -778,3 +779,16 @@ async fn test_box() -> anyhow::Result<()> { assert!(user_age.as_ref() == &1); Ok(()) } + +#[sqlx_macros::test] +async fn test_rc() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let user_age: i32 = sqlx::query_scalar("SELECT $1 AS age") + .bind(Rc::new(1i32)) + .fetch_one(&mut conn) + .await?; + + assert!(user_age == 1); + Ok(()) +} From a510645653e660e9c516e9ed5e7615561d587799 Mon Sep 17 00:00:00 2001 From: Joey de Waal Date: Sun, 2 Mar 2025 22:24:07 +0100 Subject: [PATCH 03/11] feat: implement Decode for Rc --- sqlx-core/src/decode.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/sqlx-core/src/decode.rs b/sqlx-core/src/decode.rs index 7d36e24785..1d22eec5a4 100644 --- a/sqlx-core/src/decode.rs +++ b/sqlx-core/src/decode.rs @@ -1,6 +1,7 @@ //! Provides [`Decode`] for decoding values from the database. use std::borrow::Cow; +use std::rc::Rc; use std::sync::Arc; use crate::database::Database; @@ -114,3 +115,14 @@ where Ok(Box::new(T::decode(value)?)) } } + +// implement `Decode` for Rc for all SQL types +impl<'r, DB, T> Decode<'r, DB> for Rc +where + DB: Database, + T: Decode<'r, DB>, +{ + fn decode(value: ::ValueRef<'r>) -> Result { + Ok(Rc::new(T::decode(value)?)) + } +} From 39d8e533729f6e6f7f85f7ff439d0fbac0268881 Mon Sep 17 00:00:00 2001 From: Joey de Waal Date: Sun, 2 Mar 2025 22:35:20 +0100 Subject: [PATCH 04/11] chore: make tests more concise --- tests/postgres/types.rs | 58 +++++++++-------------------------------- 1 file changed, 12 insertions(+), 46 deletions(-) diff --git a/tests/postgres/types.rs b/tests/postgres/types.rs index 52a7fd31eb..c4f2f48b79 100644 --- a/tests/postgres/types.rs +++ b/tests/postgres/types.rs @@ -744,51 +744,17 @@ CREATE TEMPORARY TABLE user_login ( async fn test_arc() -> anyhow::Result<()> { let mut conn = new::().await?; - let user_age: Arc = sqlx::query_scalar("SELECT $1 AS age ") - .bind(Arc::new(1i32)) - .fetch_one(&mut conn) - .await?; - assert!(user_age.as_ref() == &1); - Ok(()) -} - -#[sqlx_macros::test] -async fn test_cow() -> anyhow::Result<()> { - let mut conn = new::().await?; - - let age: Cow<'_, i32> = Cow::Owned(1i32); - - let user_age: Cow<'static, i32> = sqlx::query_scalar("SELECT $1 AS age ") - .bind(age) - .fetch_one(&mut conn) - .await?; - - assert!(user_age.as_ref() == &1); - Ok(()) -} - -#[sqlx_macros::test] -async fn test_box() -> anyhow::Result<()> { - let mut conn = new::().await?; - - let user_age: Box = sqlx::query_scalar("SELECT $1 AS age ") - .bind(Box::new(1)) - .fetch_one(&mut conn) - .await?; - - assert!(user_age.as_ref() == &1); - Ok(()) -} - -#[sqlx_macros::test] -async fn test_rc() -> anyhow::Result<()> { - let mut conn = new::().await?; - - let user_age: i32 = sqlx::query_scalar("SELECT $1 AS age") - .bind(Rc::new(1i32)) - .fetch_one(&mut conn) - .await?; - - assert!(user_age == 1); + let user_age: (Arc, Cow<'static, i32>, Box, i32) = + sqlx::query_as("SELECT $1, $2, $3, $4") + .bind(Arc::new(1i32)) + .bind(Cow::<'_, i32>::Owned(2i32)) + .bind(Box::new(3i32)) + .bind(Rc::new(4i32)) + .fetch_one(&mut conn) + .await?; + assert!(user_age.0.as_ref() == &1); + assert!(user_age.1.as_ref() == &2); + assert!(user_age.2.as_ref() == &3); + assert!(user_age.3 == 4); Ok(()) } From abc5adb3d48534e5ba70a402652efe7f412ec39a Mon Sep 17 00:00:00 2001 From: Joey de Waal Date: Sun, 2 Mar 2025 23:23:27 +0100 Subject: [PATCH 05/11] chore: use macro's --- sqlx-core/src/encode.rs | 113 ++++++++++----------------------- sqlx-core/src/types/mod.rs | 52 ++++++--------- sqlx-postgres/src/types/str.rs | 10 --- tests/postgres/types.rs | 3 +- 4 files changed, 56 insertions(+), 122 deletions(-) diff --git a/sqlx-core/src/encode.rs b/sqlx-core/src/encode.rs index 51a1becfc3..6f3d53eee5 100644 --- a/sqlx-core/src/encode.rs +++ b/sqlx-core/src/encode.rs @@ -133,34 +133,45 @@ macro_rules! impl_encode_for_option { }; } -impl<'q, T, DB: Database> Encode<'q, DB> for Arc -where - T: Encode<'q, DB>, -{ - #[inline] - fn encode(self, buf: &mut ::ArgumentBuffer<'q>) -> Result { - >::encode_by_ref(self.as_ref(), buf) - } +macro_rules! impl_encode_for_smartpointer { + ($smart_pointer:ty) => { + impl<'q, T, DB: Database> Encode<'q, DB> for $smart_pointer + where + T: Encode<'q, DB>, + { + #[inline] + fn encode( + self, + buf: &mut ::ArgumentBuffer<'q>, + ) -> Result { + >::encode_by_ref(self.as_ref(), buf) + } - #[inline] - fn encode_by_ref( - &self, - buf: &mut ::ArgumentBuffer<'q>, - ) -> Result { - <&T as Encode>::encode(self, buf) - } + #[inline] + fn encode_by_ref( + &self, + buf: &mut ::ArgumentBuffer<'q>, + ) -> Result { + <&T as Encode>::encode(self, buf) + } - #[inline] - fn produces(&self) -> Option { - (**self).produces() - } + #[inline] + fn produces(&self) -> Option { + (**self).produces() + } - #[inline] - fn size_hint(&self) -> usize { - (**self).size_hint() - } + #[inline] + fn size_hint(&self) -> usize { + (**self).size_hint() + } + } + }; } +impl_encode_for_smartpointer!(Arc); +impl_encode_for_smartpointer!(Box); +impl_encode_for_smartpointer!(Rc); + impl<'q, T, DB: Database> Encode<'q, DB> for Cow<'_, T> where T: Encode<'q, DB>, @@ -189,59 +200,3 @@ where (**self).size_hint() } } - -impl<'q, T, DB: Database> Encode<'q, DB> for Box -where - T: Encode<'q, DB>, -{ - #[inline] - fn encode(self, buf: &mut ::ArgumentBuffer<'q>) -> Result { - >::encode_by_ref(self.as_ref(), buf) - } - - #[inline] - fn encode_by_ref( - &self, - buf: &mut ::ArgumentBuffer<'q>, - ) -> Result { - <&T as Encode>::encode(self, buf) - } - - #[inline] - fn produces(&self) -> Option { - (**self).produces() - } - - #[inline] - fn size_hint(&self) -> usize { - (**self).size_hint() - } -} - -impl<'q, T, DB: Database> Encode<'q, DB> for Rc -where - T: Encode<'q, DB>, -{ - #[inline] - fn encode(self, buf: &mut ::ArgumentBuffer<'q>) -> Result { - >::encode_by_ref(self.as_ref(), buf) - } - - #[inline] - fn encode_by_ref( - &self, - buf: &mut ::ArgumentBuffer<'q>, - ) -> Result { - <&T as Encode>::encode(self, buf) - } - - #[inline] - fn produces(&self) -> Option { - (**self).produces() - } - - #[inline] - fn size_hint(&self) -> usize { - (**self).size_hint() - } -} diff --git a/sqlx-core/src/types/mod.rs b/sqlx-core/src/types/mod.rs index 5c47375a5d..aa618c55ee 100644 --- a/sqlx-core/src/types/mod.rs +++ b/sqlx-core/src/types/mod.rs @@ -251,20 +251,28 @@ impl, DB: Database> Type for Option { } } -impl Type for Arc -where - T: Type, - T: ?Sized, -{ - fn type_info() -> DB::TypeInfo { - >::type_info() - } +macro_rules! impl_type_for_smartpointer { + ($smart_pointer:ty) => { + impl Type for $smart_pointer + where + T: Type, + T: ?Sized, + { + fn type_info() -> DB::TypeInfo { + >::type_info() + } - fn compatible(ty: &DB::TypeInfo) -> bool { - ty.is_null() || >::compatible(ty) - } + fn compatible(ty: &DB::TypeInfo) -> bool { + >::compatible(ty) + } + } + }; } +impl_type_for_smartpointer!(Arc); +impl_type_for_smartpointer!(Box); +impl_type_for_smartpointer!(Rc); + impl Type for Cow<'_, T> where T: Type, @@ -275,26 +283,6 @@ where } fn compatible(ty: &DB::TypeInfo) -> bool { - ty.is_null() || >::compatible(ty) - } -} - -impl, DB: Database> Type for Box { - fn type_info() -> DB::TypeInfo { - >::type_info() - } - - fn compatible(ty: &DB::TypeInfo) -> bool { - ty.is_null() || >::compatible(ty) - } -} - -impl, DB: Database> Type for Rc { - fn type_info() -> DB::TypeInfo { - >::type_info() - } - - fn compatible(ty: &DB::TypeInfo) -> bool { - ty.is_null() || >::compatible(ty) + >::compatible(ty) } } diff --git a/sqlx-postgres/src/types/str.rs b/sqlx-postgres/src/types/str.rs index ca7e20a558..c4d309333a 100644 --- a/sqlx-postgres/src/types/str.rs +++ b/sqlx-postgres/src/types/str.rs @@ -34,16 +34,6 @@ impl Type for Cow<'_, str> { } } -impl Type for Box { - fn type_info() -> PgTypeInfo { - <&str as Type>::type_info() - } - - fn compatible(ty: &PgTypeInfo) -> bool { - <&str as Type>::compatible(ty) - } -} - impl Type for String { fn type_info() -> PgTypeInfo { <&str as Type>::type_info() diff --git a/tests/postgres/types.rs b/tests/postgres/types.rs index c4f2f48b79..51c5262daf 100644 --- a/tests/postgres/types.rs +++ b/tests/postgres/types.rs @@ -747,11 +747,12 @@ async fn test_arc() -> anyhow::Result<()> { let user_age: (Arc, Cow<'static, i32>, Box, i32) = sqlx::query_as("SELECT $1, $2, $3, $4") .bind(Arc::new(1i32)) - .bind(Cow::<'_, i32>::Owned(2i32)) + .bind(Cow::<'_, i32>::Borrowed(&2i32)) .bind(Box::new(3i32)) .bind(Rc::new(4i32)) .fetch_one(&mut conn) .await?; + assert!(user_age.0.as_ref() == &1); assert!(user_age.1.as_ref() == &2); assert!(user_age.2.as_ref() == &3); From 1d1a54692b702c53d5d612d8c8222ae5fbfb80ae Mon Sep 17 00:00:00 2001 From: Joey de Waal Date: Sun, 2 Mar 2025 23:29:58 +0100 Subject: [PATCH 06/11] chore: remove conflicting impls --- sqlx-core/src/types/mod.rs | 1 + sqlx-mysql/src/types/bytes.rs | 10 ---------- sqlx-mysql/src/types/str.rs | 20 -------------------- sqlx-postgres/src/types/str.rs | 10 ---------- sqlx-sqlite/src/types/bytes.rs | 10 ---------- sqlx-sqlite/src/types/str.rs | 16 ---------------- 6 files changed, 1 insertion(+), 66 deletions(-) diff --git a/sqlx-core/src/types/mod.rs b/sqlx-core/src/types/mod.rs index aa618c55ee..3613e4b6b6 100644 --- a/sqlx-core/src/types/mod.rs +++ b/sqlx-core/src/types/mod.rs @@ -277,6 +277,7 @@ impl Type for Cow<'_, T> where T: Type, T: ToOwned, + T: ?Sized, { fn type_info() -> DB::TypeInfo { >::type_info() diff --git a/sqlx-mysql/src/types/bytes.rs b/sqlx-mysql/src/types/bytes.rs index ade079ad4e..14cf16865a 100644 --- a/sqlx-mysql/src/types/bytes.rs +++ b/sqlx-mysql/src/types/bytes.rs @@ -40,16 +40,6 @@ impl<'r> Decode<'r, MySql> for &'r [u8] { } } -impl Type for Box<[u8]> { - fn type_info() -> MySqlTypeInfo { - <&[u8] as Type>::type_info() - } - - fn compatible(ty: &MySqlTypeInfo) -> bool { - <&[u8] as Type>::compatible(ty) - } -} - impl Encode<'_, MySql> for Box<[u8]> { fn encode_by_ref(&self, buf: &mut Vec) -> Result { <&[u8] as Encode>::encode(self.as_ref(), buf) diff --git a/sqlx-mysql/src/types/str.rs b/sqlx-mysql/src/types/str.rs index 8233e90893..b13d2d3a82 100644 --- a/sqlx-mysql/src/types/str.rs +++ b/sqlx-mysql/src/types/str.rs @@ -46,16 +46,6 @@ impl<'r> Decode<'r, MySql> for &'r str { } } -impl Type for Box { - fn type_info() -> MySqlTypeInfo { - <&str as Type>::type_info() - } - - fn compatible(ty: &MySqlTypeInfo) -> bool { - <&str as Type>::compatible(ty) - } -} - impl Encode<'_, MySql> for Box { fn encode_by_ref(&self, buf: &mut Vec) -> Result { <&str as Encode>::encode(&**self, buf) @@ -90,16 +80,6 @@ impl Decode<'_, MySql> for String { } } -impl Type for Cow<'_, str> { - fn type_info() -> MySqlTypeInfo { - <&str as Type>::type_info() - } - - fn compatible(ty: &MySqlTypeInfo) -> bool { - <&str as Type>::compatible(ty) - } -} - impl Encode<'_, MySql> for Cow<'_, str> { fn encode_by_ref(&self, buf: &mut Vec) -> Result { match self { diff --git a/sqlx-postgres/src/types/str.rs b/sqlx-postgres/src/types/str.rs index c4d309333a..182a295738 100644 --- a/sqlx-postgres/src/types/str.rs +++ b/sqlx-postgres/src/types/str.rs @@ -24,16 +24,6 @@ impl Type for str { } } -impl Type for Cow<'_, str> { - fn type_info() -> PgTypeInfo { - <&str as Type>::type_info() - } - - fn compatible(ty: &PgTypeInfo) -> bool { - <&str as Type>::compatible(ty) - } -} - impl Type for String { fn type_info() -> PgTypeInfo { <&str as Type>::type_info() diff --git a/sqlx-sqlite/src/types/bytes.rs b/sqlx-sqlite/src/types/bytes.rs index f854b911c5..0ab51aff2c 100644 --- a/sqlx-sqlite/src/types/bytes.rs +++ b/sqlx-sqlite/src/types/bytes.rs @@ -34,16 +34,6 @@ impl<'r> Decode<'r, Sqlite> for &'r [u8] { } } -impl Type for Box<[u8]> { - fn type_info() -> SqliteTypeInfo { - <&[u8] as Type>::type_info() - } - - fn compatible(ty: &SqliteTypeInfo) -> bool { - <&[u8] as Type>::compatible(ty) - } -} - impl Encode<'_, Sqlite> for Box<[u8]> { fn encode(self, args: &mut Vec>) -> Result { args.push(SqliteArgumentValue::Blob(Cow::Owned(self.into_vec()))); diff --git a/sqlx-sqlite/src/types/str.rs b/sqlx-sqlite/src/types/str.rs index bfaffae78e..1b39343f61 100644 --- a/sqlx-sqlite/src/types/str.rs +++ b/sqlx-sqlite/src/types/str.rs @@ -30,12 +30,6 @@ impl<'r> Decode<'r, Sqlite> for &'r str { } } -impl Type for Box { - fn type_info() -> SqliteTypeInfo { - <&str as Type>::type_info() - } -} - impl Encode<'_, Sqlite> for Box { fn encode(self, args: &mut Vec>) -> Result { args.push(SqliteArgumentValue::Text(Cow::Owned(self.into_string()))); @@ -90,16 +84,6 @@ impl<'r> Decode<'r, Sqlite> for String { } } -impl Type for Cow<'_, str> { - fn type_info() -> SqliteTypeInfo { - <&str as Type>::type_info() - } - - fn compatible(ty: &SqliteTypeInfo) -> bool { - <&str as Type>::compatible(ty) - } -} - impl<'q> Encode<'q, Sqlite> for Cow<'q, str> { fn encode(self, args: &mut Vec>) -> Result { args.push(SqliteArgumentValue::Text(self)); From feb6a8a48206bee83b409ce9e7704567ca9c48c4 Mon Sep 17 00:00:00 2001 From: Joey de Waal Date: Mon, 3 Mar 2025 10:43:55 +0100 Subject: [PATCH 07/11] chore: more macro's --- sqlx-core/src/decode.rs | 47 ++++++++++++++--------------------------- 1 file changed, 16 insertions(+), 31 deletions(-) diff --git a/sqlx-core/src/decode.rs b/sqlx-core/src/decode.rs index 1d22eec5a4..d0d6a1ee17 100644 --- a/sqlx-core/src/decode.rs +++ b/sqlx-core/src/decode.rs @@ -82,17 +82,24 @@ where } } -// implement `Decode` for Arc for all SQL types -impl<'r, DB, T> Decode<'r, DB> for Arc -where - DB: Database, - T: Decode<'r, DB>, -{ - fn decode(value: ::ValueRef<'r>) -> Result { - Ok(Arc::new(T::decode(value)?)) - } +macro_rules! impl_decode_for_smartpointer { + ($smart_pointer:ty) => { + impl<'r, DB, T> Decode<'r, DB> for $smart_pointer + where + DB: Database, + T: Decode<'r, DB>, + { + fn decode(value: ::ValueRef<'r>) -> Result { + Ok(Self::new(T::decode(value)?)) + } + } + }; } +impl_decode_for_smartpointer!(Arc); +impl_decode_for_smartpointer!(Box); +impl_decode_for_smartpointer!(Rc); + // implement `Decode` for Cow for all SQL types impl<'r, DB, T> Decode<'r, DB> for Cow<'_, T> where @@ -104,25 +111,3 @@ where Ok(Cow::Owned(T::decode(value)?)) } } - -// implement `Decode` for Box for all SQL types -impl<'r, DB, T> Decode<'r, DB> for Box -where - DB: Database, - T: Decode<'r, DB>, -{ - fn decode(value: ::ValueRef<'r>) -> Result { - Ok(Box::new(T::decode(value)?)) - } -} - -// implement `Decode` for Rc for all SQL types -impl<'r, DB, T> Decode<'r, DB> for Rc -where - DB: Database, - T: Decode<'r, DB>, -{ - fn decode(value: ::ValueRef<'r>) -> Result { - Ok(Rc::new(T::decode(value)?)) - } -} From 6669bd3adf0529dee19b598758cd10a347caaef8 Mon Sep 17 00:00:00 2001 From: Joey de Waal Date: Wed, 12 Mar 2025 09:53:41 +0100 Subject: [PATCH 08/11] Relax Sized bound for Decode, Encode --- sqlx-core/src/decode.rs | 61 +++++++++++++++++++++++++++----- sqlx-core/src/encode.rs | 10 +++--- sqlx-core/src/types/mod.rs | 2 +- sqlx-mysql/src/types/bytes.rs | 6 ---- sqlx-mysql/src/types/str.rs | 15 +++----- sqlx-postgres/src/types/bytes.rs | 9 ----- sqlx-postgres/src/types/str.rs | 30 +++++++--------- sqlx-sqlite/src/types/bytes.rs | 6 ---- sqlx-sqlite/src/types/str.rs | 23 +++++++----- tests/mysql/types.rs | 51 +++++++++++++++++++++++++- tests/postgres/types.rs | 29 ++++++++++++++- tests/sqlite/types.rs | 50 ++++++++++++++++++++++++++ 12 files changed, 218 insertions(+), 74 deletions(-) diff --git a/sqlx-core/src/decode.rs b/sqlx-core/src/decode.rs index d0d6a1ee17..081228e958 100644 --- a/sqlx-core/src/decode.rs +++ b/sqlx-core/src/decode.rs @@ -83,8 +83,8 @@ where } macro_rules! impl_decode_for_smartpointer { - ($smart_pointer:ty) => { - impl<'r, DB, T> Decode<'r, DB> for $smart_pointer + ($smart_pointer:tt) => { + impl<'r, DB, T> Decode<'r, DB> for $smart_pointer where DB: Database, T: Decode<'r, DB>, @@ -93,21 +93,66 @@ macro_rules! impl_decode_for_smartpointer { Ok(Self::new(T::decode(value)?)) } } + + impl<'r, DB> Decode<'r, DB> for $smart_pointer + where + DB: Database, + &'r str: Decode<'r, DB>, + { + fn decode(value: ::ValueRef<'r>) -> Result { + let ref_str = <&str as Decode>::decode(value)?; + Ok(ref_str.into()) + } + } + + impl<'r, DB> Decode<'r, DB> for $smart_pointer<[u8]> + where + DB: Database, + &'r [u8]: Decode<'r, DB>, + { + fn decode(value: ::ValueRef<'r>) -> Result { + let ref_str = <&[u8] as Decode>::decode(value)?; + Ok(ref_str.into()) + } + } }; } -impl_decode_for_smartpointer!(Arc); -impl_decode_for_smartpointer!(Box); -impl_decode_for_smartpointer!(Rc); +impl_decode_for_smartpointer!(Arc); +impl_decode_for_smartpointer!(Box); +impl_decode_for_smartpointer!(Rc); // implement `Decode` for Cow for all SQL types impl<'r, DB, T> Decode<'r, DB> for Cow<'_, T> where DB: Database, - T: Decode<'r, DB>, - T: ToOwned, + T: ToOwned, + ::Owned: Decode<'r, DB>, +{ + fn decode(value: ::ValueRef<'r>) -> Result { + let owned = <::Owned as Decode>::decode(value)?; + Ok(Cow::Owned(owned)) + } +} + +impl<'r, DB> Decode<'r, DB> for Cow<'r, str> +where + DB: Database, + &'r str: Decode<'r, DB>, +{ + fn decode(value: ::ValueRef<'r>) -> Result { + let borrowed = <&str as Decode>::decode(value)?; + Ok(Cow::Borrowed(borrowed)) + } +} + +impl<'r, DB> Decode<'r, DB> for Cow<'r, [u8]> +where + DB: Database, + &'r [u8]: Decode<'r, DB>, { fn decode(value: ::ValueRef<'r>) -> Result { - Ok(Cow::Owned(T::decode(value)?)) + let borrowed = <&[u8] as Decode>::decode(value)?; + Ok(Cow::Borrowed(borrowed)) } } diff --git a/sqlx-core/src/encode.rs b/sqlx-core/src/encode.rs index 6f3d53eee5..3273a20e62 100644 --- a/sqlx-core/src/encode.rs +++ b/sqlx-core/src/encode.rs @@ -175,11 +175,11 @@ impl_encode_for_smartpointer!(Rc); impl<'q, T, DB: Database> Encode<'q, DB> for Cow<'_, T> where T: Encode<'q, DB>, - T: ToOwned, + T: ToOwned, { #[inline] fn encode(self, buf: &mut ::ArgumentBuffer<'q>) -> Result { - >::encode_by_ref(self.as_ref(), buf) + <&T as Encode>::encode_by_ref(&self.as_ref(), buf) } #[inline] @@ -187,16 +187,16 @@ where &self, buf: &mut ::ArgumentBuffer<'q>, ) -> Result { - <&T as Encode>::encode(self, buf) + <&T as Encode>::encode_by_ref(&self.as_ref(), buf) } #[inline] fn produces(&self) -> Option { - (**self).produces() + <&T as Encode>::produces(&self.as_ref()) } #[inline] fn size_hint(&self) -> usize { - (**self).size_hint() + <&T as Encode>::size_hint(&self.as_ref()) } } diff --git a/sqlx-core/src/types/mod.rs b/sqlx-core/src/types/mod.rs index 3613e4b6b6..27dcf7c44b 100644 --- a/sqlx-core/src/types/mod.rs +++ b/sqlx-core/src/types/mod.rs @@ -276,7 +276,7 @@ impl_type_for_smartpointer!(Rc); impl Type for Cow<'_, T> where T: Type, - T: ToOwned, + T: ToOwned, T: ?Sized, { fn type_info() -> DB::TypeInfo { diff --git a/sqlx-mysql/src/types/bytes.rs b/sqlx-mysql/src/types/bytes.rs index 14cf16865a..3df0044b69 100644 --- a/sqlx-mysql/src/types/bytes.rs +++ b/sqlx-mysql/src/types/bytes.rs @@ -46,12 +46,6 @@ impl Encode<'_, MySql> for Box<[u8]> { } } -impl<'r> Decode<'r, MySql> for Box<[u8]> { - fn decode(value: MySqlValueRef<'r>) -> Result { - <&[u8] as Decode>::decode(value).map(Box::from) - } -} - impl Type for Vec { fn type_info() -> MySqlTypeInfo { <[u8] as Type>::type_info() diff --git a/sqlx-mysql/src/types/str.rs b/sqlx-mysql/src/types/str.rs index b13d2d3a82..b815ff79cf 100644 --- a/sqlx-mysql/src/types/str.rs +++ b/sqlx-mysql/src/types/str.rs @@ -1,3 +1,5 @@ +use std::borrow::Cow; + use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; @@ -5,7 +7,6 @@ use crate::io::MySqlBufMutExt; use crate::protocol::text::{ColumnFlags, ColumnType}; use crate::types::Type; use crate::{MySql, MySqlTypeInfo, MySqlValueRef}; -use std::borrow::Cow; impl Type for str { fn type_info() -> MySqlTypeInfo { @@ -52,12 +53,6 @@ impl Encode<'_, MySql> for Box { } } -impl<'r> Decode<'r, MySql> for Box { - fn decode(value: MySqlValueRef<'r>) -> Result { - <&str as Decode>::decode(value).map(Box::from) - } -} - impl Type for String { fn type_info() -> MySqlTypeInfo { >::type_info() @@ -89,8 +84,8 @@ impl Encode<'_, MySql> for Cow<'_, str> { } } -impl<'r> Decode<'r, MySql> for Cow<'r, str> { - fn decode(value: MySqlValueRef<'r>) -> Result { - value.as_str().map(Cow::Borrowed) +impl Encode<'_, MySql> for Cow<'_, [u8]> { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { + <&[u8] as Encode>::encode(self.as_ref(), buf) } } diff --git a/sqlx-postgres/src/types/bytes.rs b/sqlx-postgres/src/types/bytes.rs index 45968837af..32ae7aeefb 100644 --- a/sqlx-postgres/src/types/bytes.rs +++ b/sqlx-postgres/src/types/bytes.rs @@ -80,15 +80,6 @@ fn text_hex_decode_input(value: PgValueRef<'_>) -> Result<&[u8], BoxDynError> { .map_err(Into::into) } -impl Decode<'_, Postgres> for Box<[u8]> { - fn decode(value: PgValueRef<'_>) -> Result { - Ok(match value.format() { - PgValueFormat::Binary => Box::from(value.as_bytes()?), - PgValueFormat::Text => Box::from(hex::decode(text_hex_decode_input(value)?)?), - }) - } -} - impl Decode<'_, Postgres> for Vec { fn decode(value: PgValueRef<'_>) -> Result { Ok(match value.format() { diff --git a/sqlx-postgres/src/types/str.rs b/sqlx-postgres/src/types/str.rs index 182a295738..8c1214d161 100644 --- a/sqlx-postgres/src/types/str.rs +++ b/sqlx-postgres/src/types/str.rs @@ -82,15 +82,6 @@ impl Encode<'_, Postgres> for &'_ str { } } -impl Encode<'_, Postgres> for Cow<'_, str> { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { - match self { - Cow::Borrowed(str) => <&str as Encode>::encode(*str, buf), - Cow::Owned(str) => <&str as Encode>::encode(&**str, buf), - } - } -} - impl Encode<'_, Postgres> for Box { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { <&str as Encode>::encode(&**self, buf) @@ -109,20 +100,23 @@ impl<'r> Decode<'r, Postgres> for &'r str { } } -impl<'r> Decode<'r, Postgres> for Cow<'r, str> { - fn decode(value: PgValueRef<'r>) -> Result { - Ok(Cow::Borrowed(value.as_str()?)) +impl Decode<'_, Postgres> for String { + fn decode(value: PgValueRef<'_>) -> Result { + Ok(value.as_str()?.to_owned()) } } -impl<'r> Decode<'r, Postgres> for Box { - fn decode(value: PgValueRef<'r>) -> Result { - Ok(Box::from(value.as_str()?)) +impl Encode<'_, Postgres> for Cow<'_, str> { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + match self { + Cow::Borrowed(str) => <&str as Encode>::encode(*str, buf), + Cow::Owned(str) => <&str as Encode>::encode(&**str, buf), + } } } -impl Decode<'_, Postgres> for String { - fn decode(value: PgValueRef<'_>) -> Result { - Ok(value.as_str()?.to_owned()) +impl Encode<'_, Postgres> for Cow<'_, [u8]> { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + <&[u8] as Encode>::encode(self.as_ref(), buf) } } diff --git a/sqlx-sqlite/src/types/bytes.rs b/sqlx-sqlite/src/types/bytes.rs index 0ab51aff2c..ddba252ede 100644 --- a/sqlx-sqlite/src/types/bytes.rs +++ b/sqlx-sqlite/src/types/bytes.rs @@ -53,12 +53,6 @@ impl Encode<'_, Sqlite> for Box<[u8]> { } } -impl Decode<'_, Sqlite> for Box<[u8]> { - fn decode(value: SqliteValueRef<'_>) -> Result { - Ok(Box::from(value.blob())) - } -} - impl Type for Vec { fn type_info() -> SqliteTypeInfo { <&[u8] as Type>::type_info() diff --git a/sqlx-sqlite/src/types/str.rs b/sqlx-sqlite/src/types/str.rs index 1b39343f61..c54ec035b7 100644 --- a/sqlx-sqlite/src/types/str.rs +++ b/sqlx-sqlite/src/types/str.rs @@ -49,12 +49,6 @@ impl Encode<'_, Sqlite> for Box { } } -impl Decode<'_, Sqlite> for Box { - fn decode(value: SqliteValueRef<'_>) -> Result { - value.text().map(Box::from) - } -} - impl Type for String { fn type_info() -> SqliteTypeInfo { <&str as Type>::type_info() @@ -101,8 +95,19 @@ impl<'q> Encode<'q, Sqlite> for Cow<'q, str> { } } -impl<'r> Decode<'r, Sqlite> for Cow<'r, str> { - fn decode(value: SqliteValueRef<'r>) -> Result { - value.text().map(Cow::Borrowed) +impl<'q> Encode<'q, Sqlite> for Cow<'q, [u8]> { + fn encode(self, args: &mut Vec>) -> Result { + args.push(SqliteArgumentValue::Blob(self)); + + Ok(IsNull::No) + } + + fn encode_by_ref( + &self, + args: &mut Vec>, + ) -> Result { + args.push(SqliteArgumentValue::Blob(self.clone())); + + Ok(IsNull::No) } } diff --git a/tests/mysql/types.rs b/tests/mysql/types.rs index e837a53f75..3a2ecd962c 100644 --- a/tests/mysql/types.rs +++ b/tests/mysql/types.rs @@ -1,11 +1,14 @@ extern crate time_ as time; +use std::borrow::Cow; use std::net::SocketAddr; +use std::rc::Rc; #[cfg(feature = "rust_decimal")] use std::str::FromStr; +use std::sync::Arc; use sqlx::mysql::MySql; -use sqlx::{Executor, Row}; +use sqlx::{Executor, FromRow, Row}; use sqlx::types::Text; @@ -384,3 +387,49 @@ CREATE TEMPORARY TABLE user_login ( Ok(()) } + +#[sqlx_macros::test] +async fn test_smartpointers() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let user_age: (Arc, Cow<'static, i32>, Box, i32) = + sqlx::query_as("SELECT ?, ?, ?, ?") + .bind(Arc::new(1i32)) + .bind(Cow::<'_, i32>::Borrowed(&2i32)) + .bind(Box::new(3i32)) + .bind(Rc::new(4i32)) + .fetch_one(&mut conn) + .await?; + + assert!(user_age.0.as_ref() == &1); + assert!(user_age.1.as_ref() == &2); + assert!(user_age.2.as_ref() == &3); + assert!(user_age.3 == 4); + Ok(()) +} + +#[sqlx_macros::test] +async fn test_str_slice() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let box_str: Box = "John".into(); + let box_slice: Box<[u8]> = [1, 2, 3, 4].into(); + let cow_str: Cow<'static, str> = "Phil".into(); + let cow_slice: Cow<'static, [u8]> = Cow::Borrowed(&[1, 2, 3, 4]); + + let row = sqlx::query("SELECT ?, ?, ?, ?") + .bind(&box_str) + .bind(&box_slice) + .bind(&cow_str) + .bind(&cow_slice) + .fetch_one(&mut conn) + .await?; + + let data: (Box, Box<[u8]>, Cow<'_, str>, Cow<'_, [u8]>) = FromRow::from_row(&row)?; + + assert!(data.0 == box_str); + assert!(data.1 == box_slice); + assert!(data.2 == cow_str); + assert!(data.3 == cow_slice); + Ok(()) +} diff --git a/tests/postgres/types.rs b/tests/postgres/types.rs index 51c5262daf..6a7dc79c3d 100644 --- a/tests/postgres/types.rs +++ b/tests/postgres/types.rs @@ -9,6 +9,7 @@ use std::sync::Arc; use sqlx::postgres::types::{Oid, PgCiText, PgInterval, PgMoney, PgRange}; use sqlx::postgres::Postgres; +use sqlx::FromRow; use sqlx_test::{new, test_decode_type, test_prepared_type, test_type}; use sqlx_core::executor::Executor; @@ -741,7 +742,7 @@ CREATE TEMPORARY TABLE user_login ( } #[sqlx_macros::test] -async fn test_arc() -> anyhow::Result<()> { +async fn test_smartpointers() -> anyhow::Result<()> { let mut conn = new::().await?; let user_age: (Arc, Cow<'static, i32>, Box, i32) = @@ -759,3 +760,29 @@ async fn test_arc() -> anyhow::Result<()> { assert!(user_age.3 == 4); Ok(()) } + +#[sqlx_macros::test] +async fn test_str_slice() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let box_str: Box = "John".into(); + let box_slice: Box<[u8]> = [1, 2, 3, 4].into(); + let cow_str: Cow<'static, str> = "Phil".into(); + let cow_slice: Cow<'static, [u8]> = Cow::Borrowed(&[1, 2, 3, 4]); + + let row = sqlx::query("SELECT $1, $2, $3, $4") + .bind(&box_str) + .bind(&box_slice) + .bind(&cow_str) + .bind(&cow_slice) + .fetch_one(&mut conn) + .await?; + + let data: (Box, Box<[u8]>, Cow<'_, str>, Cow<'_, [u8]>) = FromRow::from_row(&row)?; + + assert!(data.0 == box_str); + assert!(data.1 == box_slice); + assert!(data.2 == cow_str); + assert!(data.3 == cow_slice); + Ok(()) +} diff --git a/tests/sqlite/types.rs b/tests/sqlite/types.rs index 2497e406cc..57cc4cd972 100644 --- a/tests/sqlite/types.rs +++ b/tests/sqlite/types.rs @@ -1,12 +1,16 @@ extern crate time_ as time; use sqlx::sqlite::{Sqlite, SqliteRow}; +use sqlx::FromRow; use sqlx_core::executor::Executor; use sqlx_core::row::Row; use sqlx_core::types::Text; use sqlx_test::new; use sqlx_test::test_type; +use std::borrow::Cow; use std::net::SocketAddr; +use std::rc::Rc; +use std::sync::Arc; test_type!(null>(Sqlite, "NULL" == None:: @@ -250,3 +254,49 @@ CREATE TEMPORARY TABLE user_login ( Ok(()) } + +#[sqlx_macros::test] +async fn test_smartpointers() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let user_age: (Arc, Cow<'static, i32>, Box, i32) = + sqlx::query_as("SELECT $1, $2, $3, $4") + .bind(Arc::new(1i32)) + .bind(Cow::<'_, i32>::Borrowed(&2i32)) + .bind(Box::new(3i32)) + .bind(Rc::new(4i32)) + .fetch_one(&mut conn) + .await?; + + assert!(user_age.0.as_ref() == &1); + assert!(user_age.1.as_ref() == &2); + assert!(user_age.2.as_ref() == &3); + assert!(user_age.3 == 4); + Ok(()) +} + +#[sqlx_macros::test] +async fn test_str_slice() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let box_str: Box = "John".into(); + let box_slice: Box<[u8]> = [1, 2, 3, 4].into(); + let cow_str: Cow<'static, str> = "Phil".into(); + let cow_slice: Cow<'static, [u8]> = Cow::Borrowed(&[1, 2, 3, 4]); + + let row = sqlx::query("SELECT $1, $2, $3, $4") + .bind(&box_str) + .bind(&box_slice) + .bind(&cow_str) + .bind(&cow_slice) + .fetch_one(&mut conn) + .await?; + + let data: (Box, Box<[u8]>, Cow<'_, str>, Cow<'_, [u8]>) = FromRow::from_row(&row)?; + + assert!(data.0 == box_str); + assert!(data.1 == box_slice); + assert!(data.2 == cow_str); + assert!(data.3 == cow_slice); + Ok(()) +} From c8ae9c8f097d5798ada345edda76bbaa38896d12 Mon Sep 17 00:00:00 2001 From: Joey de Waal Date: Wed, 19 Mar 2025 13:25:28 +0100 Subject: [PATCH 09/11] update unit tests --- tests/mysql/types.rs | 59 +++++++++-------------------------------- tests/postgres/types.rs | 58 ++++++++-------------------------------- tests/sqlite/types.rs | 57 ++++++++------------------------------- 3 files changed, 34 insertions(+), 140 deletions(-) diff --git a/tests/mysql/types.rs b/tests/mysql/types.rs index 3a2ecd962c..f323eb7cbc 100644 --- a/tests/mysql/types.rs +++ b/tests/mysql/types.rs @@ -15,7 +15,7 @@ use sqlx::types::Text; use sqlx::mysql::types::MySqlTime; use sqlx_mysql::types::MySqlTimeSign; -use sqlx_test::{new, test_type}; +use sqlx_test::{new, test_prepared_type, test_type}; test_type!(bool(MySql, "false" == false, "true" == true)); @@ -303,6 +303,17 @@ mod json_tests { )); } +test_type!(test_arc>(MySql, "1" == Arc::new(1i32))); +test_type!(test_cow>(MySql, "1" == Cow::::Owned(1i32))); +test_type!(test_box>(MySql, "1" == Box::new(1i32))); +test_type!(test_rc>(MySql, "1" == Rc::new(1i32))); + +test_type!(test_box_str>(MySql, "'John'" == Box::::from("John"))); +test_type!(test_cow_str>(MySql, "'Phil'" == Cow::<'static, str>::from("Phil"))); + +test_prepared_type!(test_box_slice>(MySql, "X'01020304'" == Box::<[u8]>::from([1,2,3,4]))); +test_prepared_type!(test_cow_slice>(MySql, "X'01020304'" == Cow::<'static, [u8]>::from(&[1,2,3,4]))); + #[sqlx_macros::test] async fn test_bits() -> anyhow::Result<()> { let mut conn = new::().await?; @@ -387,49 +398,3 @@ CREATE TEMPORARY TABLE user_login ( Ok(()) } - -#[sqlx_macros::test] -async fn test_smartpointers() -> anyhow::Result<()> { - let mut conn = new::().await?; - - let user_age: (Arc, Cow<'static, i32>, Box, i32) = - sqlx::query_as("SELECT ?, ?, ?, ?") - .bind(Arc::new(1i32)) - .bind(Cow::<'_, i32>::Borrowed(&2i32)) - .bind(Box::new(3i32)) - .bind(Rc::new(4i32)) - .fetch_one(&mut conn) - .await?; - - assert!(user_age.0.as_ref() == &1); - assert!(user_age.1.as_ref() == &2); - assert!(user_age.2.as_ref() == &3); - assert!(user_age.3 == 4); - Ok(()) -} - -#[sqlx_macros::test] -async fn test_str_slice() -> anyhow::Result<()> { - let mut conn = new::().await?; - - let box_str: Box = "John".into(); - let box_slice: Box<[u8]> = [1, 2, 3, 4].into(); - let cow_str: Cow<'static, str> = "Phil".into(); - let cow_slice: Cow<'static, [u8]> = Cow::Borrowed(&[1, 2, 3, 4]); - - let row = sqlx::query("SELECT ?, ?, ?, ?") - .bind(&box_str) - .bind(&box_slice) - .bind(&cow_str) - .bind(&cow_slice) - .fetch_one(&mut conn) - .await?; - - let data: (Box, Box<[u8]>, Cow<'_, str>, Cow<'_, [u8]>) = FromRow::from_row(&row)?; - - assert!(data.0 == box_str); - assert!(data.1 == box_slice); - assert!(data.2 == cow_str); - assert!(data.3 == cow_slice); - Ok(()) -} diff --git a/tests/postgres/types.rs b/tests/postgres/types.rs index 6a7dc79c3d..d084e021c9 100644 --- a/tests/postgres/types.rs +++ b/tests/postgres/types.rs @@ -9,7 +9,6 @@ use std::sync::Arc; use sqlx::postgres::types::{Oid, PgCiText, PgInterval, PgMoney, PgRange}; use sqlx::postgres::Postgres; -use sqlx::FromRow; use sqlx_test::{new, test_decode_type, test_prepared_type, test_type}; use sqlx_core::executor::Executor; @@ -698,6 +697,17 @@ test_type!(ltree_vec>(Postgres, ] )); +test_type!(test_arc>(Postgres, "1::INT4" == Arc::new(1i32))); +test_type!(test_cow>(Postgres, "1::INT4" == Cow::::Owned(1i32))); +test_type!(test_box>(Postgres, "1::INT4" == Box::new(1i32))); +test_type!(test_rc>(Postgres, "1::INT4" == Rc::new(1i32))); + +test_type!(test_box_str>(Postgres, "'John'::TEXT" == Box::::from("John"))); +test_type!(test_cow_str>(Postgres, "'Phil'::TEXT" == Cow::<'static, str>::from("Phil"))); + +test_prepared_type!(test_box_slice>(Postgres, "'\\x01020304'::BYTEA" == Box::<[u8]>::from([1,2,3,4]))); +test_prepared_type!(test_cow_slice>(Postgres, "'\\x01020304'::BYTEA" == Cow::<'static, [u8]>::from(&[1,2,3,4]))); + #[sqlx_macros::test] async fn test_text_adapter() -> anyhow::Result<()> { #[derive(sqlx::FromRow, Debug, PartialEq, Eq)] @@ -740,49 +750,3 @@ CREATE TEMPORARY TABLE user_login ( Ok(()) } - -#[sqlx_macros::test] -async fn test_smartpointers() -> anyhow::Result<()> { - let mut conn = new::().await?; - - let user_age: (Arc, Cow<'static, i32>, Box, i32) = - sqlx::query_as("SELECT $1, $2, $3, $4") - .bind(Arc::new(1i32)) - .bind(Cow::<'_, i32>::Borrowed(&2i32)) - .bind(Box::new(3i32)) - .bind(Rc::new(4i32)) - .fetch_one(&mut conn) - .await?; - - assert!(user_age.0.as_ref() == &1); - assert!(user_age.1.as_ref() == &2); - assert!(user_age.2.as_ref() == &3); - assert!(user_age.3 == 4); - Ok(()) -} - -#[sqlx_macros::test] -async fn test_str_slice() -> anyhow::Result<()> { - let mut conn = new::().await?; - - let box_str: Box = "John".into(); - let box_slice: Box<[u8]> = [1, 2, 3, 4].into(); - let cow_str: Cow<'static, str> = "Phil".into(); - let cow_slice: Cow<'static, [u8]> = Cow::Borrowed(&[1, 2, 3, 4]); - - let row = sqlx::query("SELECT $1, $2, $3, $4") - .bind(&box_str) - .bind(&box_slice) - .bind(&cow_str) - .bind(&cow_slice) - .fetch_one(&mut conn) - .await?; - - let data: (Box, Box<[u8]>, Cow<'_, str>, Cow<'_, [u8]>) = FromRow::from_row(&row)?; - - assert!(data.0 == box_str); - assert!(data.1 == box_slice); - assert!(data.2 == cow_str); - assert!(data.3 == cow_slice); - Ok(()) -} diff --git a/tests/sqlite/types.rs b/tests/sqlite/types.rs index 57cc4cd972..24232a03b8 100644 --- a/tests/sqlite/types.rs +++ b/tests/sqlite/types.rs @@ -212,6 +212,17 @@ test_type!(uuid_simple(Sqlite, == sqlx::types::Uuid::parse_str("00000000000000000000000000000000").unwrap().simple() )); +test_type!(test_arc>(Sqlite, "1" == Arc::new(1i32))); +test_type!(test_cow>(Sqlite, "1" == Cow::::Owned(1i32))); +test_type!(test_box>(Sqlite, "1" == Box::new(1i32))); +test_type!(test_rc>(Sqlite, "1" == Rc::new(1i32))); + +test_type!(test_box_str>(Sqlite, "'John'" == Box::::from("John"))); +test_type!(test_cow_str>(Sqlite, "'Phil'" == Cow::<'static, str>::from("Phil"))); + +test_type!(test_box_slice>(Sqlite, "X'01020304'" == Box::<[u8]>::from([1,2,3,4]))); +test_type!(test_cow_slice>(Sqlite, "X'01020304'" == Cow::<'static, [u8]>::from(&[1,2,3,4]))); + #[sqlx_macros::test] async fn test_text_adapter() -> anyhow::Result<()> { #[derive(sqlx::FromRow, Debug, PartialEq, Eq)] @@ -254,49 +265,3 @@ CREATE TEMPORARY TABLE user_login ( Ok(()) } - -#[sqlx_macros::test] -async fn test_smartpointers() -> anyhow::Result<()> { - let mut conn = new::().await?; - - let user_age: (Arc, Cow<'static, i32>, Box, i32) = - sqlx::query_as("SELECT $1, $2, $3, $4") - .bind(Arc::new(1i32)) - .bind(Cow::<'_, i32>::Borrowed(&2i32)) - .bind(Box::new(3i32)) - .bind(Rc::new(4i32)) - .fetch_one(&mut conn) - .await?; - - assert!(user_age.0.as_ref() == &1); - assert!(user_age.1.as_ref() == &2); - assert!(user_age.2.as_ref() == &3); - assert!(user_age.3 == 4); - Ok(()) -} - -#[sqlx_macros::test] -async fn test_str_slice() -> anyhow::Result<()> { - let mut conn = new::().await?; - - let box_str: Box = "John".into(); - let box_slice: Box<[u8]> = [1, 2, 3, 4].into(); - let cow_str: Cow<'static, str> = "Phil".into(); - let cow_slice: Cow<'static, [u8]> = Cow::Borrowed(&[1, 2, 3, 4]); - - let row = sqlx::query("SELECT $1, $2, $3, $4") - .bind(&box_str) - .bind(&box_slice) - .bind(&cow_str) - .bind(&cow_slice) - .fetch_one(&mut conn) - .await?; - - let data: (Box, Box<[u8]>, Cow<'_, str>, Cow<'_, [u8]>) = FromRow::from_row(&row)?; - - assert!(data.0 == box_str); - assert!(data.1 == box_slice); - assert!(data.2 == cow_str); - assert!(data.3 == cow_slice); - Ok(()) -} From 86d11ef4b7b6ff2799039c4867c9cc51535e8742 Mon Sep 17 00:00:00 2001 From: Joey de Waal Date: Sat, 5 Apr 2025 22:27:40 +0200 Subject: [PATCH 10/11] fixes after review --- sqlx-core/src/decode.rs | 23 ++++++++++------------- sqlx-core/src/encode.rs | 2 +- sqlx-core/src/types/mod.rs | 8 +++----- sqlx-mysql/src/types/bytes.rs | 8 ++++++++ sqlx-mysql/src/types/str.rs | 6 ------ sqlx-postgres/src/types/bytes.rs | 8 ++++++++ sqlx-postgres/src/types/str.rs | 24 +++++++++--------------- sqlx-sqlite/src/types/bytes.rs | 17 +++++++++++++++++ sqlx-sqlite/src/types/str.rs | 17 ----------------- tests/postgres/types.rs | 2 +- tests/sqlite/types.rs | 22 +++++++++++++++++++++- 11 files changed, 78 insertions(+), 59 deletions(-) diff --git a/sqlx-core/src/decode.rs b/sqlx-core/src/decode.rs index 081228e958..bbcf315a73 100644 --- a/sqlx-core/src/decode.rs +++ b/sqlx-core/src/decode.rs @@ -108,10 +108,10 @@ macro_rules! impl_decode_for_smartpointer { impl<'r, DB> Decode<'r, DB> for $smart_pointer<[u8]> where DB: Database, - &'r [u8]: Decode<'r, DB>, + Vec: Decode<'r, DB>, { fn decode(value: ::ValueRef<'r>) -> Result { - let ref_str = <&[u8] as Decode>::decode(value)?; + let ref_str = as Decode>::decode(value)?; Ok(ref_str.into()) } } @@ -123,36 +123,33 @@ impl_decode_for_smartpointer!(Box); impl_decode_for_smartpointer!(Rc); // implement `Decode` for Cow for all SQL types -impl<'r, DB, T> Decode<'r, DB> for Cow<'_, T> +impl<'r, 'a, DB, T> Decode<'r, DB> for Cow<'a, T> where DB: Database, T: ToOwned, ::Owned: Decode<'r, DB>, { fn decode(value: ::ValueRef<'r>) -> Result { - let owned = <::Owned as Decode>::decode(value)?; - Ok(Cow::Owned(owned)) + <::Owned as Decode>::decode(value).map(Cow::Owned) } } -impl<'r, DB> Decode<'r, DB> for Cow<'r, str> +impl<'r, 'a, DB> Decode<'r, DB> for Cow<'a, str> where DB: Database, - &'r str: Decode<'r, DB>, + String: Decode<'r, DB>, { fn decode(value: ::ValueRef<'r>) -> Result { - let borrowed = <&str as Decode>::decode(value)?; - Ok(Cow::Borrowed(borrowed)) + >::decode(value).map(Cow::Owned) } } -impl<'r, DB> Decode<'r, DB> for Cow<'r, [u8]> +impl<'r, 'a, DB> Decode<'r, DB> for Cow<'a, [u8]> where DB: Database, - &'r [u8]: Decode<'r, DB>, + Vec: Decode<'r, DB>, { fn decode(value: ::ValueRef<'r>) -> Result { - let borrowed = <&[u8] as Decode>::decode(value)?; - Ok(Cow::Borrowed(borrowed)) + as Decode>::decode(value).map(Cow::Owned) } } diff --git a/sqlx-core/src/encode.rs b/sqlx-core/src/encode.rs index 3273a20e62..1a149b6e87 100644 --- a/sqlx-core/src/encode.rs +++ b/sqlx-core/src/encode.rs @@ -172,7 +172,7 @@ impl_encode_for_smartpointer!(Arc); impl_encode_for_smartpointer!(Box); impl_encode_for_smartpointer!(Rc); -impl<'q, T, DB: Database> Encode<'q, DB> for Cow<'_, T> +impl<'q, T, DB: Database> Encode<'q, DB> for Cow<'q, T> where T: Encode<'q, DB>, T: ToOwned, diff --git a/sqlx-core/src/types/mod.rs b/sqlx-core/src/types/mod.rs index 27dcf7c44b..f6ac3ac74d 100644 --- a/sqlx-core/src/types/mod.rs +++ b/sqlx-core/src/types/mod.rs @@ -255,8 +255,7 @@ macro_rules! impl_type_for_smartpointer { ($smart_pointer:ty) => { impl Type for $smart_pointer where - T: Type, - T: ?Sized, + T: Type + ?Sized, { fn type_info() -> DB::TypeInfo { >::type_info() @@ -275,9 +274,8 @@ impl_type_for_smartpointer!(Rc); impl Type for Cow<'_, T> where - T: Type, - T: ToOwned, - T: ?Sized, + // `ToOwned` is required here to satisfy `Cow` + T: Type + ToOwned + ?Sized, { fn type_info() -> DB::TypeInfo { >::type_info() diff --git a/sqlx-mysql/src/types/bytes.rs b/sqlx-mysql/src/types/bytes.rs index 3df0044b69..023a8ee87a 100644 --- a/sqlx-mysql/src/types/bytes.rs +++ b/sqlx-mysql/src/types/bytes.rs @@ -1,3 +1,5 @@ +use std::borrow::Cow; + use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; @@ -67,3 +69,9 @@ impl Decode<'_, MySql> for Vec { <&[u8] as Decode>::decode(value).map(ToOwned::to_owned) } } + +impl Encode<'_, MySql> for Cow<'_, [u8]> { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { + <&[u8] as Encode>::encode(self.as_ref(), buf) + } +} diff --git a/sqlx-mysql/src/types/str.rs b/sqlx-mysql/src/types/str.rs index b815ff79cf..4e2730577a 100644 --- a/sqlx-mysql/src/types/str.rs +++ b/sqlx-mysql/src/types/str.rs @@ -83,9 +83,3 @@ impl Encode<'_, MySql> for Cow<'_, str> { } } } - -impl Encode<'_, MySql> for Cow<'_, [u8]> { - fn encode_by_ref(&self, buf: &mut Vec) -> Result { - <&[u8] as Encode>::encode(self.as_ref(), buf) - } -} diff --git a/sqlx-postgres/src/types/bytes.rs b/sqlx-postgres/src/types/bytes.rs index 32ae7aeefb..17b7ce9a3f 100644 --- a/sqlx-postgres/src/types/bytes.rs +++ b/sqlx-postgres/src/types/bytes.rs @@ -1,3 +1,5 @@ +use std::borrow::Cow; + use crate::decode::Decode; use crate::encode::{Encode, IsNull}; use crate::error::BoxDynError; @@ -101,3 +103,9 @@ impl Decode<'_, Postgres> for [u8; N] { Ok(bytes) } } + +impl Encode<'_, Postgres> for Cow<'_, [u8]> { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + <&[u8] as Encode>::encode(self.as_ref(), buf) + } +} diff --git a/sqlx-postgres/src/types/str.rs b/sqlx-postgres/src/types/str.rs index 8c1214d161..8b9c33ef47 100644 --- a/sqlx-postgres/src/types/str.rs +++ b/sqlx-postgres/src/types/str.rs @@ -82,6 +82,15 @@ impl Encode<'_, Postgres> for &'_ str { } } +impl Encode<'_, Postgres> for Cow<'_, str> { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { + match self { + Cow::Borrowed(str) => <&str as Encode>::encode(*str, buf), + Cow::Owned(str) => <&str as Encode>::encode(&**str, buf), + } + } +} + impl Encode<'_, Postgres> for Box { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { <&str as Encode>::encode(&**self, buf) @@ -105,18 +114,3 @@ impl Decode<'_, Postgres> for String { Ok(value.as_str()?.to_owned()) } } - -impl Encode<'_, Postgres> for Cow<'_, str> { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { - match self { - Cow::Borrowed(str) => <&str as Encode>::encode(*str, buf), - Cow::Owned(str) => <&str as Encode>::encode(&**str, buf), - } - } -} - -impl Encode<'_, Postgres> for Cow<'_, [u8]> { - fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result { - <&[u8] as Encode>::encode(self.as_ref(), buf) - } -} diff --git a/sqlx-sqlite/src/types/bytes.rs b/sqlx-sqlite/src/types/bytes.rs index ddba252ede..48dffe0ae5 100644 --- a/sqlx-sqlite/src/types/bytes.rs +++ b/sqlx-sqlite/src/types/bytes.rs @@ -85,3 +85,20 @@ impl<'r> Decode<'r, Sqlite> for Vec { Ok(value.blob().to_owned()) } } + +impl<'q> Encode<'q, Sqlite> for Cow<'q, [u8]> { + fn encode(self, args: &mut Vec>) -> Result { + args.push(SqliteArgumentValue::Blob(self)); + + Ok(IsNull::No) + } + + fn encode_by_ref( + &self, + args: &mut Vec>, + ) -> Result { + args.push(SqliteArgumentValue::Blob(self.clone())); + + Ok(IsNull::No) + } +} diff --git a/sqlx-sqlite/src/types/str.rs b/sqlx-sqlite/src/types/str.rs index c54ec035b7..6c51fa1aff 100644 --- a/sqlx-sqlite/src/types/str.rs +++ b/sqlx-sqlite/src/types/str.rs @@ -94,20 +94,3 @@ impl<'q> Encode<'q, Sqlite> for Cow<'q, str> { Ok(IsNull::No) } } - -impl<'q> Encode<'q, Sqlite> for Cow<'q, [u8]> { - fn encode(self, args: &mut Vec>) -> Result { - args.push(SqliteArgumentValue::Blob(self)); - - Ok(IsNull::No) - } - - fn encode_by_ref( - &self, - args: &mut Vec>, - ) -> Result { - args.push(SqliteArgumentValue::Blob(self.clone())); - - Ok(IsNull::No) - } -} diff --git a/tests/postgres/types.rs b/tests/postgres/types.rs index d084e021c9..569fe585e1 100644 --- a/tests/postgres/types.rs +++ b/tests/postgres/types.rs @@ -3,8 +3,8 @@ extern crate time_ as time; use std::borrow::Cow; use std::net::SocketAddr; use std::ops::Bound; -use std::str::FromStr; use std::rc::Rc; +use std::str::FromStr; use std::sync::Arc; use sqlx::postgres::types::{Oid, PgCiText, PgInterval, PgMoney, PgRange}; diff --git a/tests/sqlite/types.rs b/tests/sqlite/types.rs index 24232a03b8..ddc2969705 100644 --- a/tests/sqlite/types.rs +++ b/tests/sqlite/types.rs @@ -1,7 +1,7 @@ extern crate time_ as time; use sqlx::sqlite::{Sqlite, SqliteRow}; -use sqlx::FromRow; +use sqlx::{FromRow, Type}; use sqlx_core::executor::Executor; use sqlx_core::row::Row; use sqlx_core::types::Text; @@ -265,3 +265,23 @@ CREATE TEMPORARY TABLE user_login ( Ok(()) } + +#[sqlx_macros::test] +async fn it_binds_with_borrowed_data() -> anyhow::Result<()> { + #[derive(Debug, Type, Clone)] + #[sqlx(rename_all = "lowercase")] + enum Status { + New, + Open, + Closed, + } + + let owned = Status::New; + + let mut conn = new::().await?; + sqlx::query("select ?") + .bind(Cow::Borrowed(&owned)) + .fetch_one(&mut conn) + .await?; + Ok(()) +} From bf098815c4ffc20fbf02e4282c673bfaf57efcfd Mon Sep 17 00:00:00 2001 From: Joey de Waal Date: Sat, 5 Apr 2025 23:34:35 +0200 Subject: [PATCH 11/11] add comment in `Decode` impl --- sqlx-core/src/decode.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sqlx-core/src/decode.rs b/sqlx-core/src/decode.rs index bbcf315a73..c3abe6bf87 100644 --- a/sqlx-core/src/decode.rs +++ b/sqlx-core/src/decode.rs @@ -111,8 +111,10 @@ macro_rules! impl_decode_for_smartpointer { Vec: Decode<'r, DB>, { fn decode(value: ::ValueRef<'r>) -> Result { - let ref_str = as Decode>::decode(value)?; - Ok(ref_str.into()) + // The `Postgres` implementation requires this to be decoded as an owned value because + // bytes can be sent in text format. + let bytes = as Decode>::decode(value)?; + Ok(bytes.into()) } } };