diff --git a/lang-v2/derive/src/parse.rs b/lang-v2/derive/src/parse.rs index a87730d735..345e4cce33 100644 --- a/lang-v2/derive/src/parse.rs +++ b/lang-v2/derive/src/parse.rs @@ -364,11 +364,15 @@ pub fn parse_account_attrs(attrs: &[Attribute]) -> syn::Result { // supplied bump could be non-canonical and either create the wrong // PDA or fail under the runtime's curve check, so we don't allow the // combination at all. - if (result.is_init || result.is_init_if_needed) && matches!(result.bump, Some(Some(_))) { + // + // `init_if_needed` is different: the create branch still uses the + // canonical bump, while the existing-account branch can verify an + // explicit stored bump after loading the account. + if result.is_init && matches!(result.bump, Some(Some(_))) { if let Some(Some(ref bump_expr)) = result.bump { return Err(syn::Error::new( syn::spanned::Spanned::span(bump_expr), - "`bump = ` is not allowed with `init` / `init_if_needed`: account creation \ + "`bump = ` is not allowed with `init`: account creation \ must use the canonical bump (write `bump` without a value)", )); } @@ -1252,6 +1256,12 @@ pub fn parse_field( }; let has_bump = attrs.seeds.is_some(); + let init_if_needed_existed = attrs.is_init_if_needed.then(|| { + Ident::new( + &format!("__anchor_{}_existed", field_name), + proc_macro2::Span::call_site(), + ) + }); // --- Load --- if is_nested_type(field_ty) { @@ -1391,7 +1401,18 @@ pub fn parse_field( )?) } }; + let init_if_needed_existed_binding = init_if_needed_existed.as_ref().map(|existed| { + quote! { + let #existed = { + let __target = __views[#offset_expr]; + !anchor_lang_v2::address_eq(__target.address(), __program_id) + && __target.data_len() > 0 + && !__target.owned_by(&anchor_lang_v2::programs::System::id()) + }; + } + }); let load = quote! { + #init_if_needed_existed_binding let mut #field_name: #field_ty = { let __target = __views[#offset_expr]; if anchor_lang_v2::address_eq(__target.address(), __program_id) { @@ -1430,10 +1451,16 @@ pub fn parse_field( }; let init_body_with_constraints = wrap_init_body_with_constraints(field_ty, &attrs, &init_body); + let existed = init_if_needed_existed.as_ref().unwrap(); deferred_load = Some(quote! { + let #existed = { + let __target = __views[#offset_expr]; + __target.data_len() > 0 + && !__target.owned_by(&anchor_lang_v2::programs::System::id()) + }; let mut #field_name: #field_ty = { let __target = __views[#offset_expr]; - if __target.data_len() > 0 && !__target.owned_by(&anchor_lang_v2::programs::System::id()) { + if #existed { // SAFETY: the bitvec duplicate-account check below ensures // no other mutable reference to this account's data exists. unsafe { <#field_ty as anchor_lang_v2::AnchorAccount>::load_mut(__target, __program_id)? } @@ -1527,7 +1554,7 @@ pub fn parse_field( if let Expr::Array(arr) = seeds_expr { // Array-literal seeds: `seeds = [b"vault", user.address().as_ref()]` let seed_elems: Vec<&Expr> = arr.elems.iter().collect(); - if let Some(Some(ref bump_expr)) = attrs.bump { + let seed_constraint = if let Some(Some(ref bump_expr)) = attrs.bump { let bump_assign = if is_optional { quote! { Some(__bump_val) } } else { @@ -1535,7 +1562,7 @@ pub fn parse_field( }; let (seed_bindings, seed_refs) = materialize_seed_refs(&seed_elems, field_names); - constraints.push(quote! { + quote! { { #(#seed_bindings)* let __bump_val: u8 = #bump_expr; @@ -1546,10 +1573,10 @@ pub fn parse_field( )?; __bumps.#field_name = #bump_assign; } - }); + } } else { let target_addr_ref = quote! { #field_name.account().address() }; - constraints.push(emit_seeds_check( + emit_seeds_check( &seed_elems, field_names, &pda_program, @@ -1559,8 +1586,17 @@ pub fn parse_field( false, using_our_program_id, is_optional, - )); - } + ) + }; + constraints.push(if let Some(existed) = init_if_needed_existed.as_ref() { + quote! { + if #existed { + #seed_constraint + } + } + } else { + seed_constraint + }); } else { // Opaque expression: `seeds = Counter::seeds()` etc. let bump_assign = if is_optional { @@ -1568,9 +1604,9 @@ pub fn parse_field( } else { quote! { __bump } }; - if let Some(Some(ref bump_expr)) = attrs.bump { + let seed_constraint = if let Some(Some(ref bump_expr)) = attrs.bump { // Explicit bump + expression seeds: verify with appended bump - constraints.push(quote! { + quote! { { let __seed_val = #seeds_expr; let __seed_ref: &[&[u8]] = __seed_val.as_ref(); @@ -1590,7 +1626,7 @@ pub fn parse_field( )?; __bumps.#field_name = #bump_assign; } - }); + } } else { // Bare bump: use find_and_verify with skip_curve // when the account type guarantees non-zero data. @@ -1598,7 +1634,7 @@ pub fn parse_field( <#field_ty as anchor_lang_v2::AnchorAccount>::MIN_DATA_LEN > 0 }; let target_addr = quote! { #field_name.account().address() }; - constraints.push(quote! { + quote! { { let __seed_val = #seeds_expr; let __seed_ref: &[&[u8]] = __seed_val.as_ref(); @@ -1613,8 +1649,17 @@ pub fn parse_field( }; __bumps.#field_name = #bump_assign; } - }); - } + } + }; + constraints.push(if let Some(existed) = init_if_needed_existed.as_ref() { + quote! { + if #existed { + #seed_constraint + } + } + } else { + seed_constraint + }); } } } @@ -2142,19 +2187,13 @@ mod tests { } #[test] - fn init_if_needed_with_explicit_bump_is_rejected() { + fn init_if_needed_with_explicit_bump_is_accepted() { let attrs: Vec = vec![syn::parse_quote!( #[account(init_if_needed, payer = payer, space = 8, seeds = [b"x"], bump = 0)] )]; - let err = match parse_account_attrs(&attrs) { - Ok(_) => panic!("init_if_needed + bump= must be rejected"), - Err(err) => err, - }; - assert!( - err.to_string() - .contains("`bump = ` is not allowed with `init`"), - "unexpected error: {err}" - ); + let parsed = parse_account_attrs(&attrs).expect("init_if_needed + bump="); + assert!(parsed.is_init_if_needed); + assert!(matches!(parsed.bump, Some(Some(_)))); } #[test] diff --git a/tests-v2/programs/constraints/src/lib.rs b/tests-v2/programs/constraints/src/lib.rs index f14f827771..b2c4c53b8a 100644 --- a/tests-v2/programs/constraints/src/lib.rs +++ b/tests-v2/programs/constraints/src/lib.rs @@ -58,6 +58,13 @@ pub struct Data { pub value: u64, } +#[account] +pub struct DataWithBump { + pub value: u64, + pub bump: u8, + pub _padding: [u8; 7], +} + // -- Handlers ---------------------------------------------------------------- #[program] @@ -183,6 +190,17 @@ pub mod constraints { pub fn check_address_into_ref(_ctx: &mut Context) -> Result<()> { Ok(()) } + + /// First call creates the PDA with the canonical bump; subsequent + /// calls verify against the bump stored in the account. + #[discrim = 19] + pub fn do_init_if_needed_explicit_bump( + ctx: &mut Context, + ) -> Result<()> { + ctx.accounts.data.bump = ctx.bumps.data; + ctx.accounts.data.value = ctx.accounts.data.value.wrapping_add(1); + Ok(()) + } } // -- Accounts structs -------------------------------------------------------- @@ -311,6 +329,20 @@ pub struct DoInitIfNeeded { pub system_program: Program, } +#[derive(Accounts)] +pub struct DoInitIfNeededExplicitBump { + #[account(mut)] + pub payer: Signer, + #[account( + init_if_needed, + payer = payer, + seeds = [b"maybe-explicit"], + bump = data.bump, + )] + pub data: Account, + pub system_program: Program, +} + // 13. zeroed #[derive(Accounts)] pub struct CheckZeroed { diff --git a/tests-v2/tests/constraints.rs b/tests-v2/tests/constraints.rs index 7e85ef7e1a..a209eea75e 100644 --- a/tests-v2/tests/constraints.rs +++ b/tests-v2/tests/constraints.rs @@ -64,6 +64,10 @@ fn maybe_pda() -> Pubkey { Pubkey::find_program_address(&[b"maybe"], &program_id()).0 } +fn maybe_explicit_bump_pda() -> Pubkey { + Pubkey::find_program_address(&[b"maybe-explicit"], &program_id()).0 +} + fn other_pda() -> Pubkey { Pubkey::find_program_address(&[b"other"], &other_program()).0 } @@ -471,6 +475,18 @@ fn read_value(svm: &LiteSVM, pda: &Pubkey) -> Option { Some(u64::from_le_bytes(account.data[40..48].try_into().unwrap())) } +/// `DataWithBump` layout: disc(8) + value(u64) + bump(u8). +fn read_value_and_bump(svm: &LiteSVM, pda: &Pubkey) -> Option<(u64, u8)> { + let account = svm.get_account(pda)?; + if account.data.len() < 17 { + return None; + } + Some(( + u64::from_le_bytes(account.data[8..16].try_into().unwrap()), + account.data[16], + )) +} + #[test] fn close_transfers_lamports_and_zeros_account() { let (mut svm, payer, authority) = setup(); @@ -625,6 +641,88 @@ fn init_if_needed_creates_then_reuses() { ); } +#[test] +fn init_if_needed_allows_explicit_bump_on_reuse() { + let (mut svm, payer, _) = setup(); + let pda = maybe_explicit_bump_pda(); + let canonical_bump = Pubkey::find_program_address(&[b"maybe-explicit"], &program_id()).1; + + call( + &mut svm, + &payer, + 19, + vec![ + AccountMeta::new(payer.pubkey(), true), + AccountMeta::new(pda, false), + AccountMeta::new_readonly(solana_sdk_ids::system_program::ID, false), + ], + &[], + ) + .expect("first init_if_needed explicit-bump call"); + assert_eq!( + read_value_and_bump(&svm, &pda), + Some((1, canonical_bump)), + "create branch should store the canonical bump for later reuse", + ); + + svm.expire_blockhash(); + + call( + &mut svm, + &payer, + 19, + vec![ + AccountMeta::new(payer.pubkey(), true), + AccountMeta::new(pda, false), + AccountMeta::new_readonly(solana_sdk_ids::system_program::ID, false), + ], + &[], + ) + .expect("second init_if_needed explicit-bump call"); + assert_eq!( + read_value_and_bump(&svm, &pda), + Some((2, canonical_bump)), + "existing branch should verify against the stored bump and reuse account data", + ); +} + +#[test] +fn init_if_needed_rejects_wrong_explicit_bump_on_reuse() { + let (mut svm, payer, _) = setup(); + let pda = maybe_explicit_bump_pda(); + + call( + &mut svm, + &payer, + 19, + vec![ + AccountMeta::new(payer.pubkey(), true), + AccountMeta::new(pda, false), + AccountMeta::new_readonly(solana_sdk_ids::system_program::ID, false), + ], + &[], + ) + .expect("first init_if_needed explicit-bump call"); + + let mut account = svm.get_account(&pda).expect("account exists"); + account.data[16] = account.data[16].wrapping_add(1); + svm.set_account(pda, account).expect("corrupt stored bump"); + svm.expire_blockhash(); + + let result = call_raw( + &mut svm, + &payer, + 19, + vec![ + AccountMeta::new(payer.pubkey(), true), + AccountMeta::new(pda, false), + AccountMeta::new_readonly(solana_sdk_ids::system_program::ID, false), + ], + &[], + ); + assert_err_contains(&result, "InvalidSeeds"); +} + // ---- 13. zeroed ----------------------------------------------------------- #[test]