Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Decode, Encode and Type for Box, Arc, Cow and Rc #3674

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
78 changes: 78 additions & 0 deletions sqlx-core/src/decode.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
//! Provides [`Decode`] for decoding values from the database.

use std::borrow::Cow;
use std::rc::Rc;
use std::sync::Arc;

use crate::database::Database;
use crate::error::BoxDynError;

Expand Down Expand Up @@ -77,3 +81,77 @@ where
}
}
}

macro_rules! impl_decode_for_smartpointer {
($smart_pointer:tt) => {
impl<'r, DB, T> Decode<'r, DB> for $smart_pointer<T>
where
DB: Database,
T: Decode<'r, DB>,
{
fn decode(value: <DB as Database>::ValueRef<'r>) -> Result<Self, BoxDynError> {
Ok(Self::new(T::decode(value)?))
}
}

impl<'r, DB> Decode<'r, DB> for $smart_pointer<str>
where
DB: Database,
&'r str: Decode<'r, DB>,
{
fn decode(value: <DB as Database>::ValueRef<'r>) -> Result<Self, BoxDynError> {
let ref_str = <&str as Decode<DB>>::decode(value)?;
Ok(ref_str.into())
}
}

impl<'r, DB> Decode<'r, DB> for $smart_pointer<[u8]>
where
DB: Database,
Vec<u8>: Decode<'r, DB>,
{
fn decode(value: <DB as Database>::ValueRef<'r>) -> Result<Self, BoxDynError> {
// The `Postgres` implementation requires this to be decoded as an owned value because
// bytes can be sent in text format.
let bytes = <Vec<u8> as Decode<DB>>::decode(value)?;
Ok(bytes.into())
}
}
};
}

impl_decode_for_smartpointer!(Arc);
impl_decode_for_smartpointer!(Box);
impl_decode_for_smartpointer!(Rc);

// implement `Decode` for Cow<T> for all SQL types
impl<'r, 'a, DB, T> Decode<'r, DB> for Cow<'a, T>
where
DB: Database,
T: ToOwned,
<T as ToOwned>::Owned: Decode<'r, DB>,
{
fn decode(value: <DB as Database>::ValueRef<'r>) -> Result<Self, BoxDynError> {
<<T as ToOwned>::Owned as Decode<DB>>::decode(value).map(Cow::Owned)
}
}

impl<'r, 'a, DB> Decode<'r, DB> for Cow<'a, str>
where
DB: Database,
String: Decode<'r, DB>,
{
fn decode(value: <DB as Database>::ValueRef<'r>) -> Result<Self, BoxDynError> {
<String as Decode<DB>>::decode(value).map(Cow::Owned)
}
}

impl<'r, 'a, DB> Decode<'r, DB> for Cow<'a, [u8]>
where
DB: Database,
Vec<u8>: Decode<'r, DB>,
{
fn decode(value: <DB as Database>::ValueRef<'r>) -> Result<Self, BoxDynError> {
<Vec<u8> as Decode<DB>>::decode(value).map(Cow::Owned)
}
}
71 changes: 71 additions & 0 deletions sqlx-core/src/encode.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
//! Provides [`Encode`] for encoding values for the database.

use std::borrow::Cow;
use std::mem;
use std::rc::Rc;
use std::sync::Arc;

use crate::database::Database;
use crate::error::BoxDynError;
Expand Down Expand Up @@ -129,3 +132,71 @@ macro_rules! impl_encode_for_option {
}
};
}

macro_rules! impl_encode_for_smartpointer {
($smart_pointer:ty) => {
impl<'q, T, DB: Database> Encode<'q, DB> for $smart_pointer
where
T: Encode<'q, DB>,
{
#[inline]
fn encode(
self,
buf: &mut <DB as Database>::ArgumentBuffer<'q>,
) -> Result<IsNull, BoxDynError> {
<T as Encode<DB>>::encode_by_ref(self.as_ref(), buf)
}

#[inline]
fn encode_by_ref(
&self,
buf: &mut <DB as Database>::ArgumentBuffer<'q>,
) -> Result<IsNull, BoxDynError> {
<&T as Encode<DB>>::encode(self, buf)
}

#[inline]
fn produces(&self) -> Option<DB::TypeInfo> {
(**self).produces()
}

#[inline]
fn size_hint(&self) -> usize {
(**self).size_hint()
}
}
};
}

impl_encode_for_smartpointer!(Arc<T>);
impl_encode_for_smartpointer!(Box<T>);
impl_encode_for_smartpointer!(Rc<T>);

impl<'q, T, DB: Database> Encode<'q, DB> for Cow<'q, T>
where
T: Encode<'q, DB>,
T: ToOwned,
{
#[inline]
fn encode(self, buf: &mut <DB as Database>::ArgumentBuffer<'q>) -> Result<IsNull, BoxDynError> {
<&T as Encode<DB>>::encode_by_ref(&self.as_ref(), buf)
}

#[inline]
fn encode_by_ref(
&self,
buf: &mut <DB as Database>::ArgumentBuffer<'q>,
) -> Result<IsNull, BoxDynError> {
<&T as Encode<DB>>::encode_by_ref(&self.as_ref(), buf)
}

#[inline]
fn produces(&self) -> Option<DB::TypeInfo> {
<&T as Encode<DB>>::produces(&self.as_ref())
}

#[inline]
fn size_hint(&self) -> usize {
<&T as Encode<DB>>::size_hint(&self.as_ref())
}
}
37 changes: 37 additions & 0 deletions sqlx-core/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
//! To represent nullable SQL types, `Option<T>` is supported where `T` implements `Type`.
//! An `Option<T>` represents a potentially `NULL` value from SQL.

use std::{borrow::Cow, rc::Rc, sync::Arc};

use crate::database::Database;
use crate::type_info::TypeInfo;

Expand Down Expand Up @@ -248,3 +250,38 @@ impl<T: Type<DB>, DB: Database> Type<DB> for Option<T> {
ty.is_null() || <T as Type<DB>>::compatible(ty)
}
}

macro_rules! impl_type_for_smartpointer {
($smart_pointer:ty) => {
impl<T, DB: Database> Type<DB> for $smart_pointer
where
T: Type<DB> + ?Sized,
{
fn type_info() -> DB::TypeInfo {
<T as Type<DB>>::type_info()
}

fn compatible(ty: &DB::TypeInfo) -> bool {
<T as Type<DB>>::compatible(ty)
}
}
};
}

impl_type_for_smartpointer!(Arc<T>);
impl_type_for_smartpointer!(Box<T>);
impl_type_for_smartpointer!(Rc<T>);

impl<T, DB: Database> Type<DB> for Cow<'_, T>
where
// `ToOwned` is required here to satisfy `Cow`
T: Type<DB> + ToOwned + ?Sized,
{
fn type_info() -> DB::TypeInfo {
<T as Type<DB>>::type_info()
}

fn compatible(ty: &DB::TypeInfo) -> bool {
<T as Type<DB>>::compatible(ty)
}
}
24 changes: 8 additions & 16 deletions sqlx-mysql/src/types/bytes.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::borrow::Cow;

use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
Expand Down Expand Up @@ -40,28 +42,12 @@ impl<'r> Decode<'r, MySql> for &'r [u8] {
}
}

impl Type<MySql> for Box<[u8]> {
fn type_info() -> MySqlTypeInfo {
<&[u8] as Type<MySql>>::type_info()
}

fn compatible(ty: &MySqlTypeInfo) -> bool {
<&[u8] as Type<MySql>>::compatible(ty)
}
}

impl Encode<'_, MySql> for Box<[u8]> {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
<&[u8] as Encode<MySql>>::encode(self.as_ref(), buf)
}
}

impl<'r> Decode<'r, MySql> for Box<[u8]> {
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
<&[u8] as Decode<MySql>>::decode(value).map(Box::from)
}
}

impl Type<MySql> for Vec<u8> {
fn type_info() -> MySqlTypeInfo {
<[u8] as Type<MySql>>::type_info()
Expand All @@ -83,3 +69,9 @@ impl Decode<'_, MySql> for Vec<u8> {
<&[u8] as Decode<MySql>>::decode(value).map(ToOwned::to_owned)
}
}

impl Encode<'_, MySql> for Cow<'_, [u8]> {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
<&[u8] as Encode<MySql>>::encode(self.as_ref(), buf)
}
}
35 changes: 2 additions & 33 deletions sqlx-mysql/src/types/str.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use std::borrow::Cow;

use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::io::MySqlBufMutExt;
use crate::protocol::text::{ColumnFlags, ColumnType};
use crate::types::Type;
use crate::{MySql, MySqlTypeInfo, MySqlValueRef};
use std::borrow::Cow;

impl Type<MySql> for str {
fn type_info() -> MySqlTypeInfo {
Expand Down Expand Up @@ -46,28 +47,12 @@ impl<'r> Decode<'r, MySql> for &'r str {
}
}

impl Type<MySql> for Box<str> {
fn type_info() -> MySqlTypeInfo {
<&str as Type<MySql>>::type_info()
}

fn compatible(ty: &MySqlTypeInfo) -> bool {
<&str as Type<MySql>>::compatible(ty)
}
}

impl Encode<'_, MySql> for Box<str> {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
<&str as Encode<MySql>>::encode(&**self, buf)
}
}

impl<'r> Decode<'r, MySql> for Box<str> {
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
<&str as Decode<MySql>>::decode(value).map(Box::from)
}
}

impl Type<MySql> for String {
fn type_info() -> MySqlTypeInfo {
<str as Type<MySql>>::type_info()
Expand All @@ -90,16 +75,6 @@ impl Decode<'_, MySql> for String {
}
}

impl Type<MySql> for Cow<'_, str> {
fn type_info() -> MySqlTypeInfo {
<&str as Type<MySql>>::type_info()
}

fn compatible(ty: &MySqlTypeInfo) -> bool {
<&str as Type<MySql>>::compatible(ty)
}
}

impl Encode<'_, MySql> for Cow<'_, str> {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
match self {
Expand All @@ -108,9 +83,3 @@ impl Encode<'_, MySql> for Cow<'_, str> {
}
}
}

impl<'r> Decode<'r, MySql> for Cow<'r, str> {
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
value.as_str().map(Cow::Borrowed)
}
}
17 changes: 8 additions & 9 deletions sqlx-postgres/src/types/bytes.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::borrow::Cow;

use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
Expand Down Expand Up @@ -80,15 +82,6 @@ fn text_hex_decode_input(value: PgValueRef<'_>) -> Result<&[u8], BoxDynError> {
.map_err(Into::into)
}

impl Decode<'_, Postgres> for Box<[u8]> {
fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(match value.format() {
PgValueFormat::Binary => Box::from(value.as_bytes()?),
PgValueFormat::Text => Box::from(hex::decode(text_hex_decode_input(value)?)?),
})
}
}

impl Decode<'_, Postgres> for Vec<u8> {
fn decode(value: PgValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(match value.format() {
Expand All @@ -110,3 +103,9 @@ impl<const N: usize> Decode<'_, Postgres> for [u8; N] {
Ok(bytes)
}
}

impl Encode<'_, Postgres> for Cow<'_, [u8]> {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
<&[u8] as Encode<Postgres>>::encode(self.as_ref(), buf)
}
}
Loading