diff --git a/postgres/src/client.rs b/postgres/src/client.rs index 29cac840d..a74267270 100644 --- a/postgres/src/client.rs +++ b/postgres/src/client.rs @@ -46,6 +46,24 @@ impl Client { Config::new() } + /// Return the result format of client + /// + /// true indicates that the client will receive the result in binary format + /// false indicates that the client will receive the result in text format + pub fn result_format(&self) -> bool { + self.client.result_format() + } + + /// Set the format of return result. + /// + /// format + /// true: binary format + /// false: text format + /// default format is binary format(result_format = true) + pub fn set_result_format(&mut self, format: bool) { + self.client.set_result_format(format); + } + /// Executes a statement, returning the number of rows modified. /// /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list diff --git a/postgres/src/test.rs b/postgres/src/test.rs index 0fd404574..f5857d337 100644 --- a/postgres/src/test.rs +++ b/postgres/src/test.rs @@ -508,3 +508,36 @@ fn check_send() { is_send::(); is_send::>(); } + +#[test] +fn query_text() { + let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + client.set_result_format(false); + + let rows = client.query("SELECT $1::TEXT", &[&"hello"]).unwrap(); + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].get_text(0).unwrap(), "hello"); + + let rows = client.query("SELECT 2,'2022-01-01'::date", &[]).unwrap(); + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].get_text(0).unwrap(), "2"); + assert_eq!(rows[0].get_text(1).unwrap(), "2022-01-01"); +} + +#[test] +fn transaction_text() { + let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + client.set_result_format(false); + + let mut transaction = client.transaction().unwrap(); + + let prepare_stmt = transaction.prepare("SELECT $1::INT8,$2::FLOAT4").unwrap(); + let portal = transaction + .bind(&prepare_stmt, &[&64_i64, &3.9999_f32]) + .unwrap(); + let rows = transaction.query_portal(&portal, 0).unwrap(); + + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].get_text(0).unwrap(), "64"); + assert_eq!(rows[0].get_text(1).unwrap(), "3.9999"); +} diff --git a/tokio-postgres/src/bind.rs b/tokio-postgres/src/bind.rs index 9c5c49218..e2e1eb437 100644 --- a/tokio-postgres/src/bind.rs +++ b/tokio-postgres/src/bind.rs @@ -14,6 +14,7 @@ pub async fn bind( client: &Arc, statement: Statement, params: I, + result_format: bool, ) -> Result where P: BorrowToSql, @@ -22,7 +23,7 @@ where { let name = format!("p{}", NEXT_ID.fetch_add(1, Ordering::SeqCst)); let buf = client.with_buf(|buf| { - query::encode_bind(&statement, params, &name, buf)?; + query::encode_bind(&statement, params, &name, buf, result_format)?; frontend::sync(buf); Ok(buf.split().freeze()) })?; diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index ad5aa2866..26a6ba36d 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -170,6 +170,7 @@ pub struct Client { ssl_mode: SslMode, process_id: i32, secret_key: i32, + result_format: bool, } impl Client { @@ -190,6 +191,7 @@ impl Client { ssl_mode, process_id, secret_key, + result_format: true, } } @@ -202,6 +204,24 @@ impl Client { self.socket_config = Some(socket_config); } + /// Return the result format of client + /// + /// true indicates that the client will receive the result in binary format + /// false indicates that the client will receive the result in text format + pub fn result_format(&self) -> bool { + self.result_format + } + + /// Set the format of return result. + /// + /// format + /// true: binary format + /// false: text format + /// default format is binary format(result_format = true) + pub fn set_result_format(&mut self, format: bool) { + self.result_format = format; + } + /// Creates a new prepared statement. /// /// Prepared statements can be executed repeatedly, and may contain query parameters (indicated by `$1`, `$2`, etc), @@ -369,7 +389,7 @@ impl Client { I::IntoIter: ExactSizeIterator, { let statement = statement.__convert().into_statement(self).await?; - query::query(&self.inner, statement, params).await + query::query(&self.inner, statement, params, self.result_format).await } /// Executes a statement, returning the number of rows modified. diff --git a/tokio-postgres/src/copy_in.rs b/tokio-postgres/src/copy_in.rs index de1da933b..f9c2e5bc5 100644 --- a/tokio-postgres/src/copy_in.rs +++ b/tokio-postgres/src/copy_in.rs @@ -1,7 +1,7 @@ use crate::client::{InnerClient, Responses}; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; -use crate::{query, slice_iter, Error, Statement}; +use crate::{query, slice_iter, Error, Statement, DEFAULT_RESULT_FORMAT}; use bytes::{Buf, BufMut, BytesMut}; use futures_channel::mpsc; use futures_util::{future, ready, Sink, SinkExt, Stream, StreamExt}; @@ -200,7 +200,7 @@ where { debug!("executing copy in statement {}", statement.name()); - let buf = query::encode(client, &statement, slice_iter(&[]))?; + let buf = query::encode(client, &statement, slice_iter(&[]), DEFAULT_RESULT_FORMAT)?; let (mut sender, receiver) = mpsc::channel(1); let receiver = CopyInReceiver::new(receiver); diff --git a/tokio-postgres/src/copy_out.rs b/tokio-postgres/src/copy_out.rs index 1e6949252..58124364a 100644 --- a/tokio-postgres/src/copy_out.rs +++ b/tokio-postgres/src/copy_out.rs @@ -1,7 +1,7 @@ use crate::client::{InnerClient, Responses}; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; -use crate::{query, slice_iter, Error, Statement}; +use crate::{query, slice_iter, Error, Statement, DEFAULT_RESULT_FORMAT}; use bytes::Bytes; use futures_util::{ready, Stream}; use log::debug; @@ -14,7 +14,7 @@ use std::task::{Context, Poll}; pub async fn copy_out(client: &InnerClient, statement: Statement) -> Result { debug!("executing copy out statement {}", statement.name()); - let buf = query::encode(client, &statement, slice_iter(&[]))?; + let buf = query::encode(client, &statement, slice_iter(&[]), DEFAULT_RESULT_FORMAT)?; let responses = start(client, buf).await?; Ok(CopyOutStream { responses, diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index bd4d7b8ce..ca40da5a0 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -179,6 +179,9 @@ mod transaction; mod transaction_builder; pub mod types; +// Default result format : binary(true) +const DEFAULT_RESULT_FORMAT: bool = true; + /// A convenience function which parses a connection string and connects to the database. /// /// See the documentation for [`Config`] for details on the connection string format. diff --git a/tokio-postgres/src/prepare.rs b/tokio-postgres/src/prepare.rs index e3f09a7c2..2af7bb701 100644 --- a/tokio-postgres/src/prepare.rs +++ b/tokio-postgres/src/prepare.rs @@ -3,7 +3,7 @@ use crate::codec::FrontendMessage; use crate::connection::RequestMessages; use crate::error::SqlState; use crate::types::{Field, Kind, Oid, Type}; -use crate::{query, slice_iter}; +use crate::{query, slice_iter, DEFAULT_RESULT_FORMAT}; use crate::{Column, Error, Statement}; use bytes::Bytes; use fallible_iterator::FallibleIterator; @@ -137,7 +137,7 @@ async fn get_type(client: &Arc, oid: Oid) -> Result { let stmt = typeinfo_statement(client).await?; - let rows = query::query(client, stmt, slice_iter(&[&oid])).await?; + let rows = query::query(client, stmt, slice_iter(&[&oid]), DEFAULT_RESULT_FORMAT).await?; pin_mut!(rows); let row = match rows.try_next().await? { @@ -207,7 +207,7 @@ async fn typeinfo_statement(client: &Arc) -> Result, oid: Oid) -> Result, Error> { let stmt = typeinfo_enum_statement(client).await?; - query::query(client, stmt, slice_iter(&[&oid])) + query::query(client, stmt, slice_iter(&[&oid]), DEFAULT_RESULT_FORMAT) .await? .and_then(|row| async move { row.try_get(0) }) .try_collect() @@ -234,7 +234,7 @@ async fn typeinfo_enum_statement(client: &Arc) -> Result, oid: Oid) -> Result, Error> { let stmt = typeinfo_composite_statement(client).await?; - let rows = query::query(client, stmt, slice_iter(&[&oid])) + let rows = query::query(client, stmt, slice_iter(&[&oid]), DEFAULT_RESULT_FORMAT) .await? .try_collect::>() .await?; diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index 71db8769a..ae6aeb65b 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -2,7 +2,7 @@ use crate::client::{InnerClient, Responses}; use crate::codec::FrontendMessage; use crate::connection::RequestMessages; use crate::types::{BorrowToSql, IsNull}; -use crate::{Error, Portal, Row, Statement}; +use crate::{Error, Portal, Row, Statement, DEFAULT_RESULT_FORMAT}; use bytes::{Bytes, BytesMut}; use futures_util::{ready, Stream}; use log::{debug, log_enabled, Level}; @@ -31,6 +31,7 @@ pub async fn query( client: &InnerClient, statement: Statement, params: I, + result_format: bool, ) -> Result where P: BorrowToSql, @@ -44,9 +45,9 @@ where statement.name(), BorrowToSqlParamsDebug(params.as_slice()), ); - encode(client, &statement, params)? + encode(client, &statement, params, result_format)? } else { - encode(client, &statement, params)? + encode(client, &statement, params, result_format)? }; let responses = start(client, buf).await?; Ok(RowStream { @@ -93,9 +94,9 @@ where statement.name(), BorrowToSqlParamsDebug(params.as_slice()), ); - encode(client, &statement, params)? + encode(client, &statement, params, DEFAULT_RESULT_FORMAT)? } else { - encode(client, &statement, params)? + encode(client, &statement, params, DEFAULT_RESULT_FORMAT)? }; let mut responses = start(client, buf).await?; @@ -131,14 +132,19 @@ async fn start(client: &InnerClient, buf: Bytes) -> Result { Ok(responses) } -pub fn encode(client: &InnerClient, statement: &Statement, params: I) -> Result +pub fn encode( + client: &InnerClient, + statement: &Statement, + params: I, + result_format: bool, +) -> Result where P: BorrowToSql, I: IntoIterator, I::IntoIter: ExactSizeIterator, { client.with_buf(|buf| { - encode_bind(statement, params, "", buf)?; + encode_bind(statement, params, "", buf, result_format)?; frontend::execute("", 0, buf).map_err(Error::encode)?; frontend::sync(buf); Ok(buf.split().freeze()) @@ -150,6 +156,7 @@ pub fn encode_bind( params: I, portal: &str, buf: &mut BytesMut, + result_format: bool, ) -> Result<(), Error> where P: BorrowToSql, @@ -174,6 +181,7 @@ where let params = params.into_iter(); let mut error_idx = 0; + let result_format = if result_format { Some(1) } else { Some(0) }; let r = frontend::bind( portal, statement.name(), @@ -187,7 +195,7 @@ where Err(e) } }, - Some(1), + result_format, buf, ); match r { diff --git a/tokio-postgres/src/row.rs b/tokio-postgres/src/row.rs index e3ed696c1..e37a3c0fc 100644 --- a/tokio-postgres/src/row.rs +++ b/tokio-postgres/src/row.rs @@ -182,6 +182,46 @@ impl Row { FromSql::from_sql_nullable(ty, self.col_buffer(idx)).map_err(|e| Error::from_sql(e, idx)) } + /// Returns a value(text format) from the row. + /// + /// The value can be specified either by its numeric index in the row, or by its column name. + /// + /// NOTE: user should gurantee the result is text format + /// + /// # Panics + /// + /// Panics if the index is out of bounds or if the value cannot be converted to the TEXT type. + pub fn get_text(&self, idx: I) -> Option<&str> + where + I: RowIndex + fmt::Display, + { + match self.get_text_inner(&idx) { + Ok(ok) => ok, + Err(err) => panic!("error retrieving column {}: {}", idx, err), + } + } + + /// Like `Row::get_text`, but returns a `Result` rather than panicking. + pub fn try_get_text(&self, idx: I) -> Result, Error> + where + I: RowIndex + fmt::Display, + { + self.get_text_inner(&idx) + } + + fn get_text_inner(&self, idx: &I) -> Result, Error> + where + I: RowIndex + fmt::Display, + { + let idx = match idx.__idx(self.columns()) { + Some(idx) => idx, + None => return Err(Error::column(idx.to_string())), + }; + + let buf = self.ranges[idx].clone().map(|r| &self.body.buffer()[r]); + FromSql::from_sql_nullable(&Type::TEXT, buf).map_err(|e| Error::from_sql(e, idx)) + } + /// Get the raw bytes for the column at the given index. fn col_buffer(&self, idx: usize) -> Option<&[u8]> { let range = self.ranges[idx].to_owned()?; diff --git a/tokio-postgres/src/transaction.rs b/tokio-postgres/src/transaction.rs index 96a324652..6fef59af4 100644 --- a/tokio-postgres/src/transaction.rs +++ b/tokio-postgres/src/transaction.rs @@ -25,6 +25,7 @@ pub struct Transaction<'a> { client: &'a mut Client, savepoint: Option, done: bool, + result_format: bool, } /// A representation of a PostgreSQL database savepoint. @@ -57,10 +58,12 @@ impl<'a> Drop for Transaction<'a> { impl<'a> Transaction<'a> { pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> { + let result_format = client.result_format(); Transaction { client, savepoint: None, done: false, + result_format, } } @@ -202,7 +205,7 @@ impl<'a> Transaction<'a> { I::IntoIter: ExactSizeIterator, { let statement = statement.__convert().into_statement(self.client).await?; - bind::bind(self.client.inner(), statement, params).await + bind::bind(self.client.inner(), statement, params, self.result_format).await } /// Continues execution of a portal, returning a stream of the resulting rows. @@ -304,6 +307,7 @@ impl<'a> Transaction<'a> { client: self.client, savepoint: Some(Savepoint { name, depth }), done: false, + result_format: self.result_format, }) }