diff --git a/CHANGELOG.md b/CHANGELOG.md index ce4899139..dbbd9ca92 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - program: prevent user from enabling HLM when they are failing maintenance margin check [#2116](https://github.com/drift-labs/protocol-v2/pull/2116) - program: fix bug where users are stuck in liquidation status after completed liquidation [#2122](https://github.com/drift-labs/protocol-v2/pull/2122) - program: skip isolated positions when checking for cross margin bankruptcy [#2123](https://github.com/drift-labs/protocol-v2/pull/2123) +- program: prevent OOM when close account ix included for swap [#2148](https://github.com/drift-labs/protocol-v2/pull/2148) ### Breaking diff --git a/programs/drift/src/instructions/user.rs b/programs/drift/src/instructions/user.rs index a90a36ce1..8a5425ff5 100644 --- a/programs/drift/src/instructions/user.rs +++ b/programs/drift/src/instructions/user.rs @@ -120,6 +120,8 @@ use crate::{controller, math}; use crate::{load_mut, ExchangeStatus}; use anchor_lang::solana_program::sysvar::instructions; use borsh::{BorshDeserialize, BorshSerialize}; +use solana_program::pubkey::PUBKEY_BYTES; +use solana_program::serialize_utils; use solana_program::sysvar::instructions::ID as IX_ID; use super::optional_accounts::get_high_leverage_mode_config; @@ -3754,10 +3756,144 @@ pub fn handle_enable_user_high_leverage_mode<'c: 'info, 'info>( Ok(()) } +// We intentionally parse the instructions sysvar with zero-copy byte slices here +// to avoid heap growth/OOM risk from repeatedly deserializing full `Instruction`s, +// especially when post-swap CloseAccount instructions are included. +const INSTRUCTION_ACCOUNT_META_SIZE: usize = 1 + PUBKEY_BYTES; +const INSTRUCTION_ACCOUNT_META_IS_WRITABLE_BIT: u8 = 1 << 1; + +struct InstructionSysvarView<'a> { + program_id: Pubkey, + account_meta_bytes: &'a [u8], + account_metas_len: usize, + data: &'a [u8], +} + +impl<'a> InstructionSysvarView<'a> { + fn accounts_len(&self) -> usize { + self.account_metas_len + } + + fn account_meta_bytes_at(&self, index: usize) -> std::result::Result<&[u8], ProgramError> { + if index >= self.account_metas_len { + return Err(ProgramError::InvalidInstructionData); + } + + let start = index + .checked_mul(INSTRUCTION_ACCOUNT_META_SIZE) + .ok_or(ProgramError::InvalidInstructionData)?; + let end = start + .checked_add(INSTRUCTION_ACCOUNT_META_SIZE) + .ok_or(ProgramError::InvalidInstructionData)?; + + self.account_meta_bytes + .get(start..end) + .ok_or(ProgramError::InvalidInstructionData) + } + + fn account_pubkey_bytes_at(&self, index: usize) -> std::result::Result<&[u8], ProgramError> { + let account_meta = self.account_meta_bytes_at(index)?; + Ok(&account_meta[1..]) + } + + fn account_meta_bytes_iter(&self) -> impl Iterator { + self.account_meta_bytes + .chunks_exact(INSTRUCTION_ACCOUNT_META_SIZE) + } + + fn account_pubkey_equals( + &self, + index: usize, + key: &Pubkey, + ) -> std::result::Result { + Ok(self.account_pubkey_bytes_at(index)? == key.as_ref()) + } +} + +fn read_u16_le( + instruction_sysvar_data: &[u8], + offset: &mut usize, +) -> std::result::Result { + serialize_utils::read_u16(offset, instruction_sysvar_data) + .map_err(|_| ProgramError::InvalidInstructionData) +} + +fn read_pubkey( + instruction_sysvar_data: &[u8], + offset: &mut usize, +) -> std::result::Result { + serialize_utils::read_pubkey(offset, instruction_sysvar_data) + .map_err(|_| ProgramError::InvalidInstructionData) +} + +fn read_slice<'a>( + instruction_sysvar_data: &'a [u8], + offset: &mut usize, + len: usize, +) -> std::result::Result<&'a [u8], ProgramError> { + let end = offset + .checked_add(len) + .ok_or(ProgramError::InvalidInstructionData)?; + let slice = instruction_sysvar_data + .get(*offset..end) + .ok_or(ProgramError::InvalidInstructionData)?; + *offset = end; + Ok(slice) +} + +fn load_instruction_sysvar_view_at<'a>( + index: usize, + instruction_sysvar_data: &'a [u8], +) -> std::result::Result, ProgramError> { + let mut offset = 0; + let num_instructions = read_u16_le(instruction_sysvar_data, &mut offset)? as usize; + if index >= num_instructions { + return Err(ProgramError::InvalidArgument); + } + + offset = 2usize + .checked_add( + index + .checked_mul(2) + .ok_or(ProgramError::InvalidInstructionData)?, + ) + .ok_or(ProgramError::InvalidInstructionData)?; + + let ix_offset = read_u16_le(instruction_sysvar_data, &mut offset)? as usize; + + let mut ix_read_offset = ix_offset; + let account_metas_len = read_u16_le(instruction_sysvar_data, &mut ix_read_offset)? as usize; + + let account_meta_bytes_len = account_metas_len + .checked_mul(INSTRUCTION_ACCOUNT_META_SIZE) + .ok_or(ProgramError::InvalidInstructionData)?; + let account_meta_bytes = read_slice( + instruction_sysvar_data, + &mut ix_read_offset, + account_meta_bytes_len, + )?; + + let program_id = read_pubkey(instruction_sysvar_data, &mut ix_read_offset)?; + + let instruction_data_len = read_u16_le(instruction_sysvar_data, &mut ix_read_offset)? as usize; + let data = read_slice( + instruction_sysvar_data, + &mut ix_read_offset, + instruction_data_len, + )?; + + Ok(InstructionSysvarView { + program_id, + account_meta_bytes, + account_metas_len, + data, + }) +} + /// Checks if an instruction is a SPL Token CloseAccount targeting /// one of the swap's token accounts. fn is_token_close_account_for_swap_ix( - ix: &solana_program::instruction::Instruction, + ix: &InstructionSysvarView, in_token_account: &Pubkey, out_token_account: &Pubkey, ) -> bool { @@ -3774,12 +3910,22 @@ fn is_token_close_account_for_swap_ix( } // The first account in CloseAccount is the account being closed - if ix.accounts.is_empty() { + if ix.accounts_len() == 0 { return false; } - let account_to_close = &ix.accounts[0].pubkey; - account_to_close == in_token_account || account_to_close == out_token_account + let first_account_meta = match ix.account_meta_bytes_at(0) { + Ok(account_meta) => account_meta, + Err(_) => return false, + }; + + let account_to_close = &first_account_meta[1..]; + let is_in_token_account = account_to_close == in_token_account.as_ref(); + if is_in_token_account { + return true; + } + + account_to_close == out_token_account.as_ref() } #[access_control( @@ -3919,27 +4065,35 @@ pub fn handle_begin_swap<'c: 'info, 'info>( )?; let ixs = ctx.accounts.instructions.as_ref(); + validate!( + instructions::check_id(ixs.key), + ErrorCode::InvalidSwap, + "invalid instructions sysvar account" + )?; + let current_index = instructions::load_current_index_checked(ixs)? as usize; + let instruction_sysvar_data = ixs.try_borrow_data()?; - let current_ix = instructions::load_instruction_at_checked(current_index, ixs)?; + let current_ix = load_instruction_sysvar_view_at(current_index, &instruction_sysvar_data)?; validate!( current_ix.program_id == *ctx.program_id, ErrorCode::InvalidSwap, "SwapBegin must be a top-level instruction (cant be cpi)" )?; + let drift_program_id = crate::id(); // The only other drift program allowed is SwapEnd let mut index = current_index + 1; let mut found_end = false; loop { - let ix = match instructions::load_instruction_at_checked(index, ixs) { + let ix = match load_instruction_sysvar_view_at(index, &instruction_sysvar_data) { Ok(ix) => ix, Err(ProgramError::InvalidArgument) => break, Err(e) => return Err(e.into()), }; // Check that the drift program key is not used - if ix.program_id == crate::id() { + if ix.program_id == drift_program_id { // must be the last ix -- this could possibly be relaxed validate!( !found_end, @@ -3951,85 +4105,102 @@ pub fn handle_begin_swap<'c: 'info, 'info>( // must be the SwapEnd instruction let discriminator = crate::instruction::EndSwap::discriminator(); validate!( - ix.data[0..8] == discriminator, + ix.data.len() >= 8 && ix.data[0..8] == discriminator, ErrorCode::InvalidSwap, "last drift ix must be end of swap" )?; validate!( - ctx.accounts.user.key() == ix.accounts[1].pubkey, + ix.accounts_len() >= 11, + ErrorCode::InvalidSwap, + "SwapEnd instruction has insufficient accounts" + )?; + + validate!( + ix.account_pubkey_equals(1, &ctx.accounts.user.key())?, ErrorCode::InvalidSwap, "the user passed to SwapBegin and End must match" )?; validate!( - ctx.accounts.authority.key() == ix.accounts[3].pubkey, + ix.account_pubkey_equals(3, &ctx.accounts.authority.key())?, ErrorCode::InvalidSwap, "the authority passed to SwapBegin and End must match" )?; validate!( - ctx.accounts.out_spot_market_vault.key() == ix.accounts[4].pubkey, + ix.account_pubkey_equals(4, &ctx.accounts.out_spot_market_vault.key())?, ErrorCode::InvalidSwap, "the out_spot_market_vault passed to SwapBegin and End must match" )?; validate!( - ctx.accounts.in_spot_market_vault.key() == ix.accounts[5].pubkey, + ix.account_pubkey_equals(5, &ctx.accounts.in_spot_market_vault.key())?, ErrorCode::InvalidSwap, "the in_spot_market_vault passed to SwapBegin and End must match" )?; validate!( - ctx.accounts.out_token_account.key() == ix.accounts[6].pubkey, + ix.account_pubkey_equals(6, &ctx.accounts.out_token_account.key())?, ErrorCode::InvalidSwap, "the out_token_account passed to SwapBegin and End must match" )?; validate!( - ctx.accounts.in_token_account.key() == ix.accounts[7].pubkey, + ix.account_pubkey_equals(7, &ctx.accounts.in_token_account.key())?, ErrorCode::InvalidSwap, "the in_token_account passed to SwapBegin and End must match" )?; validate!( - ctx.remaining_accounts.len() == ix.accounts.len() - 11, + ctx.remaining_accounts.len() == ix.accounts_len() - 11, ErrorCode::InvalidSwap, "begin and end ix must have the same number of accounts" )?; - for i in 11..ix.accounts.len() { - validate!( - *ctx.remaining_accounts[i - 11].key == ix.accounts[i].pubkey, - ErrorCode::InvalidSwap, - "begin and end ix must have the same accounts. {}th account mismatch. begin: {}, end: {}", - i, - ctx.remaining_accounts[i - 11].key, - ix.accounts[i].pubkey - )?; - } - } else { - if found_end { - if ix.program_id == lighthouse::ID { - continue; - } - - // Allow closing the swap's token accounts after end_swap - if is_token_close_account_for_swap_ix( - &ix, - &ctx.accounts.in_token_account.key(), - &ctx.accounts.out_token_account.key(), - ) { - continue; - } - - for meta in ix.accounts.iter() { + let start_offset = 11 * INSTRUCTION_ACCOUNT_META_SIZE; + let end_remaining_accounts_meta_bytes = ix + .account_meta_bytes + .get(start_offset..) + .ok_or(ProgramError::InvalidInstructionData)?; + + for (i, (account_meta_bytes, begin_remaining_account)) in + end_remaining_accounts_meta_bytes + .chunks_exact(INSTRUCTION_ACCOUNT_META_SIZE) + .zip(ctx.remaining_accounts.iter()) + .enumerate() + { + if &account_meta_bytes[1..] != begin_remaining_account.key.as_ref() { + let mut end_account_bytes = [0_u8; 32]; + end_account_bytes.copy_from_slice(&account_meta_bytes[1..]); validate!( - meta.is_writable == false, + false, ErrorCode::InvalidSwap, - "instructions after swap end must not have writable accounts" + "begin and end ix must have the same accounts. {}th account mismatch. begin: {}, end: {}", + i + 11, + begin_remaining_account.key, + Pubkey::new_from_array(end_account_bytes) )?; } + } + } else { + if found_end { + let is_allowed_post_end_ix = ix.program_id == lighthouse::ID + || is_token_close_account_for_swap_ix( + &ix, + &ctx.accounts.in_token_account.key(), + &ctx.accounts.out_token_account.key(), + ); + + if !is_allowed_post_end_ix { + for account_meta_bytes in ix.account_meta_bytes_iter() { + validate!( + account_meta_bytes[0] & INSTRUCTION_ACCOUNT_META_IS_WRITABLE_BIT == 0, + ErrorCode::InvalidSwap, + "instructions after swap end must not have writable accounts" + )?; + } + } } else { let mut whitelisted_programs = WHITELISTED_SWAP_PROGRAMS.to_vec(); if !delegate_is_signer { @@ -4044,9 +4215,9 @@ pub fn handle_begin_swap<'c: 'info, 'info>( "only allowed to pass in ixs to ATA, openbook, Jupiter v3/v4/v6, dflow, or titan programs" )?; - for meta in ix.accounts.iter() { + for account_meta_bytes in ix.account_meta_bytes_iter() { validate!( - meta.pubkey != crate::id(), + &account_meta_bytes[1..] != drift_program_id.as_ref(), ErrorCode::InvalidSwap, "instructions between begin and end must not be drift instructions" )?; diff --git a/tests/spotSwap.ts b/tests/spotSwap.ts index d0c9683a4..66278a95c 100644 --- a/tests/spotSwap.ts +++ b/tests/spotSwap.ts @@ -40,6 +40,8 @@ import { import { NATIVE_MINT, TOKEN_PROGRAM_ID, + getAssociatedTokenAddressSync, + createAssociatedTokenAccountIdempotentInstruction, createCloseAccountInstruction, createTransferInstruction, } from '@solana/spl-token'; @@ -746,20 +748,35 @@ describe('spot swap', () => { assert(failed); }); - it.skip('swap and close token account after end_swap', async () => { + it('swap and close token account after end_swap', async () => { // takerUSDC has 0 balance - it can be closed after endSwap const amountIn = new BN(100).mul(QUOTE_PRECISION); + + const takerUsdcAta = getAssociatedTokenAddressSync( + usdcMint.publicKey, + takerDriftClient.wallet.publicKey + ); + const { beginSwapIx, endSwapIx } = await takerDriftClient.getSwapIx({ amountIn, inMarketIndex: 0, outMarketIndex: 1, - inTokenAccount: takerUSDC, + inTokenAccount: takerUsdcAta, outTokenAccount: takerWSOL, }); + // Include explicit open ix for ATA that gets closed after endSwap. + const openIx = createAssociatedTokenAccountIdempotentInstruction( + takerDriftClient.wallet.publicKey, + takerUsdcAta, + takerDriftClient.wallet.publicKey, + usdcMint.publicKey, + TOKEN_PROGRAM_ID + ); + // Simulate swap: send all USDC to maker const transferIn = createTransferInstruction( - takerUSDC, + takerUsdcAta, makerUSDC.publicKey, takerDriftClient.wallet.publicKey, amountIn.toNumber() @@ -773,9 +790,9 @@ describe('spot swap', () => { LAMPORTS_PER_SOL ); - // Close takerUSDC after endSwap (balance will be 0) + // Close takerUsdcAta after endSwap (balance will be 0) const closeIx = createCloseAccountInstruction( - takerUSDC, + takerUsdcAta, takerDriftClient.wallet.publicKey, takerDriftClient.wallet.publicKey, undefined, @@ -783,6 +800,7 @@ describe('spot swap', () => { ); const tx = new Transaction() + .add(openIx) .add(beginSwapIx) .add(transferIn) .add(transferOut) @@ -798,9 +816,9 @@ describe('spot swap', () => { // Verify the token account is actually closed const accountInfo = await bankrunContextWrapper.connection.getAccountInfo( - takerUSDC + takerUsdcAta ); - assert(accountInfo === null, 'takerUSDC should be closed'); + assert(accountInfo === null, 'takerUsdcAta should be closed'); }); it('donate to revenue pool for a great feature!', async () => {