diff --git a/x/tron/types/address.go b/x/tron/types/address.go index b8dfcfb0..1b9914a9 100644 --- a/x/tron/types/address.go +++ b/x/tron/types/address.go @@ -49,12 +49,18 @@ func ValidateTronAddress(address string) error { return errors.New("empty") } if len(address) != tronaddress.AddressLengthBase58 { - return errors.New("wrong length") + return fmt.Errorf("invalid address length: expected %d chars, got %d", tronaddress.AddressLengthBase58, len(address)) } tronAddr, err := common.DecodeCheck(address) if err != nil { return errors.New("doesn't pass format validation") } + if len(tronAddr) != tronaddress.AddressLength { + return fmt.Errorf("invalid address length: expected decoded %d bytes, got %d", tronaddress.AddressLength, len(tronAddr)) + } + if tronAddr[0] != tronaddress.TronBytePrefix { + return errors.New("invalid tron prefix") + } expectAddress := common.EncodeCheck(tronAddr) if expectAddress != address { return fmt.Errorf("mismatch expected: %s, got: %s", expectAddress, address) diff --git a/x/tron/types/address_test.go b/x/tron/types/address_test.go index d6c9cd01..394115c6 100644 --- a/x/tron/types/address_test.go +++ b/x/tron/types/address_test.go @@ -1,14 +1,21 @@ package types_test import ( + "bytes" "testing" + tronaddress "github.com/fbsobreira/gotron-sdk/pkg/address" + troncommon "github.com/fbsobreira/gotron-sdk/pkg/common" "github.com/stretchr/testify/require" "github.com/openmetaearth/me-hub/x/tron/types" ) func TestValidateTronAddress(t *testing.T) { + nonTronPrefix := byte(0x00) + nonTronPayload := append([]byte{nonTronPrefix}, bytes.Repeat([]byte{0x11}, tronaddress.AddressLength-1)...) + shortTronPayload := append([]byte{tronaddress.TronBytePrefix}, bytes.Repeat([]byte{0x11}, tronaddress.AddressLength-2)...) + testCases := []struct { testName string value string @@ -25,13 +32,13 @@ func TestValidateTronAddress(t *testing.T) { testName: "address length not match", value: "abcdddddd", expectPass: false, - errStr: "wrong length", + errStr: "invalid address length", }, { testName: "address length great than tron address", value: "TR7NHqjeKQxGTCi8q8ZY4pL8otSzgjLj6t6666", expectPass: false, - errStr: "wrong length", + errStr: "invalid address length", }, { testName: "lowercase address", @@ -45,6 +52,18 @@ func TestValidateTronAddress(t *testing.T) { expectPass: false, errStr: "doesn't pass format validation", }, + { + testName: "base58check address with non-tron prefix", + value: troncommon.EncodeCheck(nonTronPayload), + expectPass: false, + errStr: "invalid tron prefix", + }, + { + testName: "base58check address with invalid decoded length", + value: troncommon.EncodeCheck(shortTronPayload), + expectPass: false, + errStr: "invalid address length", + }, { testName: "normal address", value: "TR7NHqjeKQxGTCi8q8ZY4pL8otSzgjLj6t", @@ -59,7 +78,7 @@ func TestValidateTronAddress(t *testing.T) { require.NoError(t, err) return } - require.EqualValues(t, testCase.errStr, err.Error(), testCase.value) + require.Contains(t, err.Error(), testCase.errStr, testCase.value) }) } }