diff --git a/docs/src/concepts/factories.md b/docs/src/concepts/factories.md index e081327..4fffcda 100644 --- a/docs/src/concepts/factories.md +++ b/docs/src/concepts/factories.md @@ -43,6 +43,36 @@ assert_eq!(product.name, "Anvil 3000"); Fields you don't set are filled with generated values automatically. +## Building Without Persisting + +Use `make()` instead of `create()` to build a model instance +in memory without hitting the database: + +```rust +# extern crate fabrique; +# extern crate sqlx; +# extern crate uuid; +# use fabrique::prelude::*; +# use uuid::Uuid; +# +# #[derive(Model, Factory)] +# pub struct Product { +# id: Uuid, +# name: String, +# price_cents: i32, +# } +# fn main() { +let product: Product = Product::factory::() + .name("Anvil 3000".to_string()) + .make(); + +assert_eq!(product.name, "Anvil 3000"); +# } +``` + +This is useful for preparing data for bulk operations like +[`upsert()`](../cookbook/bulk-update-and-upsert.md#bulk-upsert-a-collection). + ## Random Value Generation By default, factories generate random values for all fields using diff --git a/docs/src/concepts/models.md b/docs/src/concepts/models.md index 2f928b6..10953af 100644 --- a/docs/src/concepts/models.md +++ b/docs/src/concepts/models.md @@ -255,8 +255,12 @@ let anvil: Product = Product::find(&pool, anvil.id).await?; // Upsert (insert or update on PK conflict) let anvil: Product = anvil.save(&pool).await?; +// Bulk upsert a collection +let products = vec![anvil]; +products.upsert(&pool, Product::ID).await?; + // Delete by primary key, no instance needed -Product::destroy(&pool, anvil.id).await?; +Product::destroy(&pool, Uuid::new_v4()).await?; # Ok(()) # } ``` diff --git a/docs/src/cookbook/bulk-update-and-upsert.md b/docs/src/cookbook/bulk-update-and-upsert.md index 594d216..661b7ad 100644 --- a/docs/src/cookbook/bulk-update-and-upsert.md +++ b/docs/src/cookbook/bulk-update-and-upsert.md @@ -189,6 +189,73 @@ let saved: Product = Product::insert() values from the INSERT — the SQL equivalent of `SET col = EXCLUDED.col` for each column. +## Bulk Upsert a Collection + +When you have a `Vec` of models to upsert — e.g. syncing an +external data source — use `.upsert()` to insert or update +them all in a single statement: + +```rust +# extern crate fabrique; +# extern crate sqlx; +# extern crate tokio; +# extern crate uuid; +# use fabrique::prelude::*; +# use uuid::Uuid; +# +# #[derive(Clone, Debug, Factory, Model)] +# pub struct Product { +# pub id: Uuid, +# pub name: String, +# pub price_cents: i32, +# pub in_stock: bool, +# } +# +# #[fabrique::doctest] +# async fn main(pool: Pool) -> Result<(), fabrique::Error> { +# let products = vec![ +# Product::factory().name("Anvil 3000".to_string()).make(), +# Product::factory().name("Rocket Skates".to_string()).make(), +# ]; +// Insert all products, or update on ID conflict +products.upsert(&pool, Product::ID).await?; +# Ok(()) +# } +``` + +The second argument specifies the conflict target — the column +(or columns) that identify a unique row. All other columns are +updated when a conflict is detected. + +For a composite unique key, pass a tuple: + +```rust,no_run +# extern crate fabrique; +# extern crate sqlx; +# extern crate uuid; +# use fabrique::prelude::*; +# use uuid::Uuid; +# #[derive(Clone, Debug, Factory, Model)] +# pub struct Product { +# pub id: Uuid, +# pub name: String, +# pub price_cents: i32, +# pub in_stock: bool, +# } +# async fn example( +# products: Vec, +# pool: Pool, +# ) -> Result<(), fabrique::Error> { +products + .upsert(&pool, (Product::NAME, Product::PRICE_CENTS)) + .await?; +# Ok(()) +# } +``` + +> **Note:** The conflict target columns must have a UNIQUE +> constraint or be the primary key in your database schema. + If you only want to skip duplicates without updating, use `.do_nothing()`: diff --git a/fabrique-core/src/lib.rs b/fabrique-core/src/lib.rs index 32e9b1e..f27fc46 100644 --- a/fabrique-core/src/lib.rs +++ b/fabrique-core/src/lib.rs @@ -12,6 +12,7 @@ pub mod factory; pub mod model; pub mod relation; pub mod sql; +pub mod upsert; // Re-export for use in generated code pub use database::Nil; @@ -24,3 +25,5 @@ pub use factory::SetForeignKey; pub use relation::Alias; pub use relation::BelongsTo; pub use relation::Joinable; +pub use upsert::UniqueBy; +pub use upsert::Upsert; diff --git a/fabrique-core/src/model.rs b/fabrique-core/src/model.rs index bf6be25..79c1b26 100644 --- a/fabrique-core/src/model.rs +++ b/fabrique-core/src/model.rs @@ -116,6 +116,10 @@ pub trait Persist: Model { ) -> impl Future> + Send + 'e where A: sqlx::Acquire<'e, Database = DB> + Send + 'e; + + /// Pushes all field values as bind parameters into a separated + /// query builder row. + fn push_bind_values(self, separated: sqlx::query_builder::Separated<'_, DB, &'static str>); } /// Delete operations diff --git a/fabrique-core/src/upsert.rs b/fabrique-core/src/upsert.rs new file mode 100644 index 0000000..f3d9dc2 --- /dev/null +++ b/fabrique-core/src/upsert.rs @@ -0,0 +1,84 @@ +use crate::{ + database::Column, + dialect::Dialect, + model::{Model, Persist}, +}; + +pub trait UniqueBy { + fn column_names() -> &'static [&'static str]; +} + +macro_rules! impl_unique_by { + ($($C:ident),+) => { + impl UniqueBy for ($($C,)+) + where + $($C: Column,)+ + { + fn column_names() -> &'static [&'static str] { + &[$($C::NAME),+] + } + } + }; +} +impl_unique_by!(C0); +impl_unique_by!(C0, C1); +impl_unique_by!(C0, C1, C2); +impl_unique_by!(C0, C1, C2, C3); + +pub trait Upsert { + type Model: Model; + + fn upsert<'e, A, U>( + self, + executor: A, + unique_by: U, + ) -> impl Future> + where + A: sqlx::Acquire<'e, Database = DB> + Send + 'e, + U: UniqueBy; +} + +impl Upsert for Vec +where + M: Persist + Send, + DB: Dialect, + ::Arguments: sqlx::IntoArguments, + for<'c> &'c mut ::Connection: sqlx::Executor<'c, Database = DB>, +{ + type Model = M; + + async fn upsert<'e, A, U>(self, executor: A, _unique_by: U) -> Result<(), crate::Error> + where + A: sqlx::Acquire<'e, Database = DB> + Send + 'e, + U: UniqueBy, + { + if self.is_empty() { + return Ok(()); + } + + let mut conn = executor.acquire().await.map_err(crate::Error::from)?; + + let mut qb = sqlx::QueryBuilder::new("INSERT INTO "); + qb.push(M::table_name()); + qb.push(" ("); + qb.push(M::columns().join(", ")); + qb.push(") "); + + qb.push_values(self, |separated, model| { + model.push_bind_values(separated); + }); + + let unique_cols = U::column_names(); + qb.push(DB::on_conflict_sql(unique_cols)); + + let update_cols: Vec<&str> = M::columns() + .iter() + .copied() + .filter(|c| !unique_cols.contains(c)) + .collect(); + qb.push(DB::do_update_sql(&update_cols)); + + qb.build().execute(&mut *conn).await?; + Ok(()) + } +} diff --git a/fabrique-derive/src/codegen/columns.rs b/fabrique-derive/src/codegen/columns.rs index 36d5f51..e35104b 100644 --- a/fabrique-derive/src/codegen/columns.rs +++ b/fabrique-derive/src/codegen/columns.rs @@ -70,6 +70,12 @@ impl<'a> ColumnsCodegen<'a> { #into_db_body } } + + impl ::fabrique::UniqueBy<#base_struct_ident> for #type_name { + fn column_names() -> &'static [&'static str] { + &[#column_name] + } + } } }); @@ -127,6 +133,13 @@ mod tests { value } } + + impl ::fabrique::UniqueBy for AnvilIdColumn { + fn column_names() -> &'static [&'static str] { + &["id"] + } + } + impl ::fabrique::Column for AnvilNameColumn { type Type = String; type DbType = String; @@ -139,6 +152,12 @@ mod tests { } } + impl ::fabrique::UniqueBy for AnvilNameColumn { + fn column_names() -> &'static [&'static str] { + &["name"] + } + } + impl Anvil { pub const ID: AnvilIdColumn = AnvilIdColumn; pub const NAME: AnvilNameColumn = AnvilNameColumn; @@ -182,6 +201,13 @@ mod tests { value } } + + impl ::fabrique::UniqueBy for AccountIdColumn { + fn column_names() -> &'static [&'static str] { + &["id"] + } + } + impl ::fabrique::Column for AccountStatusColumn { type Type = Status; type DbType = String; @@ -194,6 +220,12 @@ mod tests { } } + impl ::fabrique::UniqueBy for AccountStatusColumn { + fn column_names() -> &'static [&'static str] { + &["status"] + } + } + impl Account { pub const ID: AccountIdColumn = AccountIdColumn; pub const STATUS: AccountStatusColumn = AccountStatusColumn; diff --git a/fabrique-derive/src/codegen/factory.rs b/fabrique-derive/src/codegen/factory.rs index 4ca3906..6b7736b 100644 --- a/fabrique-derive/src/codegen/factory.rs +++ b/fabrique-derive/src/codegen/factory.rs @@ -42,6 +42,7 @@ impl<'a> FactoryCodegen<'a> { let factory_ident = &self.ident; let factory_fields = self.generate_factory_fields(); let factory_method_new = self.generate_factory_method_new(); + let factory_method_make = self.generate_factory_method_make(); let factory_method_fields = self.generate_factory_method_fields(); let factory_methods_for_relation = self.generate_factory_methods_for_relation(); let factory_relation_fields = self.generate_factory_relation_fields(); @@ -87,6 +88,8 @@ impl<'a> FactoryCodegen<'a> { impl #factory_ident { #factory_method_new + #factory_method_make + #(#factory_method_fields)* #(#factory_methods_for_relation)* @@ -369,6 +372,47 @@ impl<'a> FactoryCodegen<'a> { } } + /// Generates the `make()` method that builds a model instance + /// without persisting it. + fn generate_factory_method_make(&self) -> TokenStream { + let struct_ident = &self.analysis.ident; + + let has_custom_faker = self + .analysis + .column_fields + .iter() + .any(|f| f.faker.is_some()); + + let fake_import = if has_custom_faker { + quote! { use ::fabrique::fake::Fake; } + } else { + quote! {} + }; + + let column_fields = self.analysis.column_fields.iter().map(|field| { + let name = &field.ident; + let ty = &field.ty; + + match &field.faker { + Some(faker_expr) => quote! { + #name: self.#name.unwrap_or_else(|| #faker_expr.fake()) + }, + None => quote! { + #name: self.#name.unwrap_or_else(::fabrique::seeded_value::<#ty>) + }, + } + }); + + quote! { + pub fn make(self) -> #struct_ident { + #fake_import + #struct_ident { + #(#column_fields,)* + } + } + } + } + /// Generates setter methods for each field in the factory struct. /// /// Each setter method takes a value and stores it in the factory's optional @@ -661,6 +705,15 @@ mod tests { } } + pub fn make(self) -> Anvil { + Anvil { + id: self.id.unwrap_or_else(::fabrique::seeded_value::), + hammer_id: self.hammer_id.unwrap_or_else(::fabrique::seeded_value::), + hardness: self.hardness.unwrap_or_else(::fabrique::seeded_value::), + weight: self.weight.unwrap_or_else(::fabrique::seeded_value::), + } + } + pub fn id(mut self, id: u32) -> Self { self.id = Some(id); self @@ -1061,4 +1114,38 @@ mod tests { generated ); } + + #[test] + fn test_generate_factory_method_make_with_custom_faker() { + // Arrange + let input = parse_quote! { + struct User { + id: u32, + #[fabrique(faker = "Name()")] + name: String, + } + }; + let analysis = Analysis::from(&input).unwrap(); + let factory = FactoryCodegen::new(&analysis); + + // Act + let generated = factory.generate_factory_method_make().to_string(); + + // Assert + assert!( + generated.contains("use :: fabrique :: fake :: Fake"), + "Should import Fake trait when custom faker is used. Generated: {}", + generated + ); + assert!( + generated.contains("Name () . fake ()"), + "Should use custom faker expression. Generated: {}", + generated + ); + assert!( + generated.contains("seeded_value :: < u32 >"), + "Fields without faker should use seeded_value. Generated: {}", + generated + ); + } } diff --git a/fabrique-derive/src/codegen/persist.rs b/fabrique-derive/src/codegen/persist.rs index e4fb72e..4f7c161 100644 --- a/fabrique-derive/src/codegen/persist.rs +++ b/fabrique-derive/src/codegen/persist.rs @@ -18,6 +18,7 @@ impl<'a> PersistCodegen<'a> { let base_struct_ident = &self.analysis.ident; let fn_create = self.generate_fn_create(); let fn_save = self.generate_fn_save(); + let fn_push_bind_values = self.generate_fn_push_bind_values(); // Per-column-field Encode/Type bounds (using `as` type when present) let field_bounds = self.analysis.column_fields.iter().map(|f| { @@ -38,6 +39,8 @@ impl<'a> PersistCodegen<'a> { #fn_create #fn_save + + #fn_push_bind_values } } } @@ -70,6 +73,29 @@ impl<'a> PersistCodegen<'a> { } } + /// Generates the `push_bind_values()` method. + fn generate_fn_push_bind_values(&self) -> TokenStream { + let base_struct_ident = &self.analysis.ident; + let bind_calls = self.analysis.column_fields.iter().map(|field| { + let ident = &field.ident; + let column_type = &field.column_type; + quote! { + separated.push_bind( + <#column_type as ::fabrique::Column<#base_struct_ident>>::into_db(self.#ident) + ); + } + }); + + quote! { + fn push_bind_values( + self, + mut separated: ::sqlx::query_builder::Separated<'_, DB, &'static str>, + ) { + #(#bind_calls)* + } + } + } + /// Generates the `save()` method using execute + find (universal). fn generate_fn_save(&self) -> TokenStream { let set_calls = self.analysis.column_fields.iter().map(|field| { @@ -162,6 +188,15 @@ mod tests { .map_err(Into::into) } } + + fn push_bind_values( + self, + mut separated: ::sqlx::query_builder::Separated<'_, DB, &'static str>, + ) { + separated.push_bind( + >::into_db(self.id) + ); + } } } .to_string() diff --git a/fabrique/src/lib.rs b/fabrique/src/lib.rs index f1ac3ec..e247bb0 100644 --- a/fabrique/src/lib.rs +++ b/fabrique/src/lib.rs @@ -88,6 +88,7 @@ pub use error::*; pub use factory::*; pub use model::*; pub use relation::*; +pub use upsert::*; #[cfg(feature = "testing")] pub use fake; @@ -117,6 +118,7 @@ pub mod model; pub mod prelude; pub mod relation; pub mod sql; +pub mod upsert; #[cfg(feature = "testing")] #[doc(hidden)] pub use fabrique_core::__private; diff --git a/fabrique/src/prelude.rs b/fabrique/src/prelude.rs index 5c390d8..916ea36 100644 --- a/fabrique/src/prelude.rs +++ b/fabrique/src/prelude.rs @@ -5,3 +5,4 @@ pub use crate::error::*; pub use crate::factory::*; pub use crate::model::*; pub use crate::relation::*; +pub use crate::upsert::*; diff --git a/fabrique/src/upsert.rs b/fabrique/src/upsert.rs new file mode 100644 index 0000000..ff8e1f8 --- /dev/null +++ b/fabrique/src/upsert.rs @@ -0,0 +1 @@ +pub use fabrique_core::upsert::*; diff --git a/fabrique/tests/upsert.rs b/fabrique/tests/upsert.rs new file mode 100644 index 0000000..20ca8f5 --- /dev/null +++ b/fabrique/tests/upsert.rs @@ -0,0 +1,120 @@ +use fabrique::prelude::*; +use uuid::Uuid; + +#[derive(Clone, Debug, Default, Factory, PartialEq, Model)] +#[allow(dead_code)] +pub struct Product { + pub id: Uuid, + pub name: String, + pub price_cents: i32, + pub in_stock: bool, +} + +#[fabrique::test] +async fn test_upsert_inserts_new_records(pool: Pool) { + let products = vec![ + Product::factory::() + .name("Anvil 3000".to_owned()) + .make(), + Product::factory::() + .name("Rocket Skates".to_owned()) + .make(), + ]; + + products.upsert(&pool, Product::ID).await.unwrap(); + + let all = Product::all(&pool).await.unwrap(); + assert_eq!(all.len(), 2); +} + +#[fabrique::test] +async fn test_upsert_updates_existing_records(pool: Pool) { + let product = Product::factory() + .name("Anvil 3000".to_owned()) + .price_cents(9999) + .create(&pool) + .await + .unwrap(); + + let updated = vec![Product { + price_cents: 4999, + ..product + }]; + + updated.upsert(&pool, Product::ID).await.unwrap(); + + let all = Product::all(&pool).await.unwrap(); + assert_eq!(all.len(), 1); + assert_eq!(all[0].price_cents, 4999); +} + +#[fabrique::test] +async fn test_upsert_with_composite_unique_by(pool: Pool) { + let products = vec![ + Product::factory::() + .name("Anvil 3000".to_owned()) + .price_cents(9999) + .make(), + Product::factory::() + .name("Anvil 3000".to_owned()) + .price_cents(4999) + .make(), + ]; + + products + .upsert(&pool, (Product::NAME, Product::PRICE_CENTS)) + .await + .unwrap(); + + let all = Product::all(&pool).await.unwrap(); + assert_eq!(all.len(), 2); +} + +#[fabrique::test] +async fn test_upsert_empty_vec_is_noop(pool: Pool) { + let products: Vec = vec![]; + + products.upsert(&pool, Product::ID).await.unwrap(); + + let all = Product::all(&pool).await.unwrap(); + assert_eq!(all.len(), 0); +} + +#[fabrique::test] +async fn test_upsert_with_tuples(pool: Pool) { + let products = vec![ + Product::factory::() + .name("Anvil 3000".to_owned()) + .price_cents(100) + .make(), + Product::factory::() + .name("Rocket Skates".to_owned()) + .price_cents(200) + .make(), + ]; + + // Arity 1: tuple with single column + products + .clone() + .upsert(&pool, (Product::ID,)) + .await + .unwrap(); + + let all = Product::all(&pool).await.unwrap(); + assert_eq!(all.len(), 2); + + // Arity 2: tuple with composite unique (name, price_cents) + let updated = vec![Product { + in_stock: false, + ..products[0].clone() + }]; + updated + .upsert(&pool, (Product::NAME, Product::PRICE_CENTS)) + .await + .unwrap(); + + let all = Product::all(&pool).await.unwrap(); + assert_eq!(all.len(), 2); + let anvil = all.iter().find(|p| p.name == "Anvil 3000").unwrap(); + assert!(!anvil.in_stock); +} diff --git a/migrations/mysql/00001_initial_schema.sql b/migrations/mysql/00001_initial_schema.sql index 52b4233..1d94587 100644 --- a/migrations/mysql/00001_initial_schema.sql +++ b/migrations/mysql/00001_initial_schema.sql @@ -20,7 +20,8 @@ CREATE TABLE products ( name VARCHAR(255) NOT NULL, price_cents INTEGER NOT NULL, in_stock BOOLEAN NOT NULL DEFAULT true, - deleted_at DATETIME + deleted_at DATETIME, + UNIQUE (name, price_cents) ); CREATE TABLE orders ( diff --git a/migrations/postgres/00001_initial_schema.sql b/migrations/postgres/00001_initial_schema.sql index 280563f..a3b53b6 100644 --- a/migrations/postgres/00001_initial_schema.sql +++ b/migrations/postgres/00001_initial_schema.sql @@ -19,7 +19,8 @@ CREATE TABLE products ( name VARCHAR(255) NOT NULL, price_cents INTEGER NOT NULL, in_stock BOOLEAN NOT NULL DEFAULT true, - deleted_at TIMESTAMPTZ + deleted_at TIMESTAMPTZ, + UNIQUE (name, price_cents) ); CREATE TABLE orders ( diff --git a/migrations/sqlite/00001_initial_schema.sql b/migrations/sqlite/00001_initial_schema.sql index 3a6d749..36c4857 100644 --- a/migrations/sqlite/00001_initial_schema.sql +++ b/migrations/sqlite/00001_initial_schema.sql @@ -21,7 +21,8 @@ CREATE TABLE products ( name TEXT NOT NULL, price_cents INTEGER NOT NULL, in_stock BOOLEAN NOT NULL DEFAULT 1, - deleted_at TEXT + deleted_at TEXT, + UNIQUE (name, price_cents) ); CREATE TABLE orders (