diff --git a/ts/packages/anchor/src/coder/borsh/instruction.ts b/ts/packages/anchor/src/coder/borsh/instruction.ts index 8b99a0f830..27e31b63e0 100644 --- a/ts/packages/anchor/src/coder/borsh/instruction.ts +++ b/ts/packages/anchor/src/coder/borsh/instruction.ts @@ -25,7 +25,7 @@ export class BorshInstructionCoder implements InstructionCoder { // Instruction args layout. Maps namespaced method private ixLayouts: Map< string, - { discriminator: IdlDiscriminator; layout: Layout } + { discriminator: IdlDiscriminator; layout: Layout; args: IdlField[] } >; public constructor(private idl: Idl) { @@ -35,7 +35,10 @@ export class BorshInstructionCoder implements InstructionCoder { IdlCoder.fieldLayout(arg, idl.types) ); const layout = borsh.struct(fieldLayouts, name); - return [name, { discriminator: ix.discriminator, layout }] as const; + return [ + name, + { discriminator: ix.discriminator, layout, args: ix.args }, + ] as const; }); this.ixLayouts = new Map(ixLayouts); } @@ -50,12 +53,186 @@ export class BorshInstructionCoder implements InstructionCoder { throw new Error(`Unknown method: ${ixName}`); } - const len = encoder.layout.encode(ix, buffer); + const ixWithDefinedOptions = BorshInstructionCoder.convertUndefinedOptions( + encoder.args, + ix, + this.idl.types + ); + const len = encoder.layout.encode(ixWithDefinedOptions, buffer); const data = buffer.slice(0, len); return Buffer.concat([Buffer.from(encoder.discriminator), data]); } + private static convertUndefinedOptions( + args: IdlField[], + ix: any, + types: IdlTypeDef[] = [] + ): any { + const converted = { ...ix }; + args.forEach((arg) => { + converted[arg.name] = BorshInstructionCoder.convertUndefinedOption( + arg.type, + ix[arg.name], + types + ); + }); + return converted; + } + + private static convertUndefinedOption( + idlType: IdlType, + value: any, + types: IdlTypeDef[] + ): any { + if (typeof idlType === "string") { + return value; + } + + if ("option" in idlType) { + if (value === undefined) { + return null; + } + if (value === null) { + return null; + } + return BorshInstructionCoder.convertUndefinedOption( + idlType.option, + value, + types + ); + } + + if ("vec" in idlType) { + return Array.isArray(value) + ? value.map((item) => + BorshInstructionCoder.convertUndefinedOption( + idlType.vec, + item, + types + ) + ) + : value; + } + + if ("array" in idlType) { + return Array.isArray(value) + ? value.map((item) => + BorshInstructionCoder.convertUndefinedOption( + idlType.array[0], + item, + types + ) + ) + : value; + } + + if ("defined" in idlType) { + const typeDef = types.find((t) => t.name === idlType.defined.name); + if (!typeDef) { + return value; + } + return BorshInstructionCoder.convertUndefinedOptionDefined( + typeDef, + value, + types + ); + } + + return value; + } + + private static convertUndefinedOptionDefined( + typeDef: IdlTypeDef, + value: any, + types: IdlTypeDef[] + ): any { + if (value === null || value === undefined || typeof value !== "object") { + return value; + } + + switch (typeDef.type.kind) { + case "struct": { + return handleDefinedFields( + typeDef.type.fields, + () => value, + (fields) => { + const converted = { ...value }; + fields.forEach((field) => { + converted[field.name] = + BorshInstructionCoder.convertUndefinedOption( + field.type, + value[field.name], + types + ); + }); + return converted; + }, + (fields) => { + const converted = Array.isArray(value) ? [...value] : { ...value }; + fields.forEach((field, index) => { + converted[index] = BorshInstructionCoder.convertUndefinedOption( + field, + value[index], + types + ); + }); + return converted; + } + ); + } + case "enum": { + const variantName = Object.keys(value)[0]; + const variant = typeDef.type.variants.find( + (v) => v.name === variantName + ); + if (!variant) { + return value; + } + + return { + ...value, + [variantName]: handleDefinedFields( + variant.fields, + () => value[variantName], + (fields) => { + const converted = { ...value[variantName] }; + fields.forEach((field) => { + converted[field.name] = + BorshInstructionCoder.convertUndefinedOption( + field.type, + value[variantName][field.name], + types + ); + }); + return converted; + }, + (fields) => { + const converted = Array.isArray(value[variantName]) + ? [...value[variantName]] + : { ...value[variantName] }; + fields.forEach((field, index) => { + converted[index] = BorshInstructionCoder.convertUndefinedOption( + field, + value[variantName][index], + types + ); + }); + return converted; + } + ), + }; + } + case "type": { + return BorshInstructionCoder.convertUndefinedOption( + typeDef.type.alias, + value, + types + ); + } + } + } + /** * Decodes a program instruction. */ diff --git a/ts/packages/anchor/tests/coder-instructions.spec.ts b/ts/packages/anchor/tests/coder-instructions.spec.ts index a87519c3c5..61bd1c28fe 100644 --- a/ts/packages/anchor/tests/coder-instructions.spec.ts +++ b/ts/packages/anchor/tests/coder-instructions.spec.ts @@ -1,4 +1,5 @@ import * as assert from "assert"; +import BN from "bn.js"; import { BorshCoder } from "../src"; import { Idl, IdlType } from "../src/idl"; import { toInstruction } from "../src/program/common"; @@ -53,4 +54,68 @@ describe("coder.instructions", () => { assert.deepStrictEqual(decoded?.data[idlIx.args[0].name], expected); }); + + test("Encodes undefined option instruction arguments as null", () => { + const idl: Idl = { + address: "Test111111111111111111111111111111111111111", + metadata: { + name: "test", + version: "0.0.0", + spec: "0.1.0", + }, + instructions: [ + { + name: "initialize", + discriminator: [0, 1, 2, 3, 4, 5, 6, 7], + accounts: [], + args: [ + { + name: "required", + type: "bool", + }, + { + name: "someArg", + type: { + option: "u64", + }, + }, + { + name: "someArg2", + type: { + option: "u64", + }, + }, + ], + }, + ], + types: [], + }; + + const idlIx = idl.instructions[0]; + const coder = new BorshCoder(idl); + const ix = toInstruction(idlIx, true, undefined, new BN(0x3030)); + + const encoded = coder.instruction.encode(idlIx.name, ix); + const decoded = coder.instruction.decode(encoded); + + assert.deepStrictEqual( + [...encoded], + [ + ...idlIx.discriminator, + 1, // required bool + 0, // someArg: None + 1, // someArg2: Some + 0x30, + 0x30, + 0, + 0, + 0, + 0, + 0, + 0, + ] + ); + assert.strictEqual(decoded?.data["someArg"], null); + assert.ok(decoded?.data["someArg2"].eq(new BN(0x3030))); + }); });