Skip to content

Commit

Permalink
Allow enableHigherPrecision to be used in arrow batches (#1080)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yifeng-Sigma authored Mar 27, 2024
1 parent bd8b73b commit 2141603
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 7 deletions.
17 changes: 12 additions & 5 deletions converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1007,8 +1007,9 @@ func getArrowBatchesTimestampOption(ctx context.Context) snowflakeArrowBatchesTi

func arrowToRecord(ctx context.Context, record arrow.Record, pool memory.Allocator, rowType []execResponseRowType, loc *time.Location) (arrow.Record, error) {
arrowBatchesTimestampOption := getArrowBatchesTimestampOption(ctx)
higherPrecisionEnabled := higherPrecisionEnabled(ctx)

s, err := recordToSchema(record.Schema(), rowType, loc, arrowBatchesTimestampOption)
s, err := recordToSchema(record.Schema(), rowType, loc, arrowBatchesTimestampOption, higherPrecisionEnabled)
if err != nil {
return nil, err
}
Expand All @@ -1026,7 +1027,9 @@ func arrowToRecord(ctx context.Context, record arrow.Record, pool memory.Allocat
switch snowflakeType {
case fixedType:
var toType arrow.DataType
if col.DataType().ID() == arrow.DECIMAL || col.DataType().ID() == arrow.DECIMAL256 {
if higherPrecisionEnabled {
// do nothing - return decimal as is
} else if col.DataType().ID() == arrow.DECIMAL || col.DataType().ID() == arrow.DECIMAL256 {
if srcColumnMeta.Scale == 0 {
toType = arrow.PrimitiveTypes.Int64
} else {
Expand Down Expand Up @@ -1151,7 +1154,7 @@ func arrowToRecord(ctx context.Context, record arrow.Record, pool memory.Allocat
return array.NewRecord(s, cols, numRows), nil
}

func recordToSchema(sc *arrow.Schema, rowType []execResponseRowType, loc *time.Location, timestampOption snowflakeArrowBatchesTimestampOption) (*arrow.Schema, error) {
func recordToSchema(sc *arrow.Schema, rowType []execResponseRowType, loc *time.Location, timestampOption snowflakeArrowBatchesTimestampOption, withHigherPrecision bool) (*arrow.Schema, error) {
var fields []arrow.Field
for i := 0; i < len(sc.Fields()); i++ {
f := sc.Field(i)
Expand All @@ -1163,13 +1166,17 @@ func recordToSchema(sc *arrow.Schema, rowType []execResponseRowType, loc *time.L
case fixedType:
switch f.Type.ID() {
case arrow.DECIMAL:
if srcColumnMeta.Scale == 0 {
if withHigherPrecision {
converted = false
} else if srcColumnMeta.Scale == 0 {
t = &arrow.Int64Type{}
} else {
t = &arrow.Float64Type{}
}
default:
if srcColumnMeta.Scale != 0 {
if withHigherPrecision {
converted = false
} else if srcColumnMeta.Scale != 0 {
t = &arrow.Float64Type{}
} else {
converted = false
Expand Down
152 changes: 150 additions & 2 deletions converter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -918,6 +918,7 @@ func TestArrowToRecord(t *testing.T) {
error string
arrowBatchesTimestampOption snowflakeArrowBatchesTimestampOption
enableArrowBatchesUtf8Validation bool
withHigherPrecision bool
nrows int
builder array.Builder
append func(b array.Builder, vs interface{})
Expand All @@ -934,7 +935,7 @@ func TestArrowToRecord(t *testing.T) {
},
{
logical: "fixed",
physical: "number(38,0)",
physical: "int64",
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Decimal128Type{Precision: 38, Scale: 0}}}, nil),
values: []string{"10000000000000000000000000000000000000", "-12345678901234567890123456789012345678"},
nrows: 2,
Expand Down Expand Up @@ -963,9 +964,40 @@ func TestArrowToRecord(t *testing.T) {
return -1
},
},
{
logical: "fixed",
physical: "number(38,0)",
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Decimal128Type{Precision: 38, Scale: 0}}}, nil),
values: []string{"10000000000000000000000000000000000000", "-12345678901234567890123456789012345678"},
withHigherPrecision: true,
nrows: 2,
builder: array.NewDecimal128Builder(pool, &arrow.Decimal128Type{Precision: 38, Scale: 0}),
append: func(b array.Builder, vs interface{}) {
for _, s := range vs.([]string) {
num, ok := stringIntToDecimal(s)
if !ok {
t.Fatalf("failed to convert to Int64")
}
b.(*array.Decimal128Builder).Append(num)
}
},
compare: func(src interface{}, expected interface{}, convertedRec arrow.Record) int {
srcvs := src.([]string)
for i, dec := range convertedRec.Column(0).(*array.Decimal128).Values() {
srcDec, ok := stringIntToDecimal(srcvs[i])
if !ok {
return i
}
if srcDec != dec {
return i
}
}
return -1
},
},
{
logical: "fixed",
physical: "number(38,37)",
physical: "float64",
rowType: execResponseRowType{Scale: 37},
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Decimal128Type{Precision: 38, Scale: 37}}}, nil),
values: []string{"1.2345678901234567890123456789012345678", "-9.999999999999999"},
Expand Down Expand Up @@ -995,6 +1027,38 @@ func TestArrowToRecord(t *testing.T) {
return -1
},
},
{
logical: "fixed",
physical: "number(38,37)",
rowType: execResponseRowType{Scale: 37},
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Decimal128Type{Precision: 38, Scale: 37}}}, nil),
values: []string{"1.2345678901234567890123456789012345678", "-9.999999999999999"},
withHigherPrecision: true,
nrows: 2,
builder: array.NewDecimal128Builder(pool, &arrow.Decimal128Type{Precision: 38, Scale: 37}),
append: func(b array.Builder, vs interface{}) {
for _, s := range vs.([]string) {
num, err := decimal128.FromString(s, 38, 37)
if err != nil {
t.Fatalf("failed to convert to decimal: %s", err)
}
b.(*array.Decimal128Builder).Append(num)
}
},
compare: func(src interface{}, expected interface{}, convertedRec arrow.Record) int {
srcvs := src.([]string)
for i, dec := range convertedRec.Column(0).(*array.Decimal128).Values() {
srcDec, err := decimal128.FromString(srcvs[i], 38, 37)
if err != nil {
return i
}
if srcDec != dec {
return i
}
}
return -1
},
},
{
logical: "fixed",
physical: "int8",
Expand Down Expand Up @@ -1051,6 +1115,26 @@ func TestArrowToRecord(t *testing.T) {
return -1
},
},
{
logical: "fixed",
physical: "int8",
rowType: execResponseRowType{Scale: 1},
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int8Type{}}}, nil),
values: []int8{10, 16},
withHigherPrecision: true,
nrows: 2,
builder: array.NewInt8Builder(pool),
append: func(b array.Builder, vs interface{}) { b.(*array.Int8Builder).AppendValues(vs.([]int8), valids) },
compare: func(src interface{}, expected interface{}, convertedRec arrow.Record) int {
srcvs := src.([]int8)
for i, f := range convertedRec.Column(0).(*array.Int8).Int8Values() {
if srcvs[i] != f {
return i
}
}
return -1
},
},
{
logical: "fixed",
physical: "float16",
Expand All @@ -1071,6 +1155,26 @@ func TestArrowToRecord(t *testing.T) {
return -1
},
},
{
logical: "fixed",
physical: "int16",
rowType: execResponseRowType{Scale: 1},
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int16Type{}}}, nil),
values: []int16{20, 26},
withHigherPrecision: true,
nrows: 2,
builder: array.NewInt16Builder(pool),
append: func(b array.Builder, vs interface{}) { b.(*array.Int16Builder).AppendValues(vs.([]int16), valids) },
compare: func(src interface{}, expected interface{}, convertedRec arrow.Record) int {
srcvs := src.([]int16)
for i, f := range convertedRec.Column(0).(*array.Int16).Int16Values() {
if srcvs[i] != f {
return i
}
}
return -1
},
},
{
logical: "fixed",
physical: "float32",
Expand All @@ -1091,6 +1195,26 @@ func TestArrowToRecord(t *testing.T) {
return -1
},
},
{
logical: "fixed",
physical: "int32",
rowType: execResponseRowType{Scale: 2},
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int32Type{}}}, nil),
values: []int32{200, 265},
withHigherPrecision: true,
nrows: 2,
builder: array.NewInt32Builder(pool),
append: func(b array.Builder, vs interface{}) { b.(*array.Int32Builder).AppendValues(vs.([]int32), valids) },
compare: func(src interface{}, expected interface{}, convertedRec arrow.Record) int {
srcvs := src.([]int32)
for i, f := range convertedRec.Column(0).(*array.Int32).Int32Values() {
if srcvs[i] != f {
return i
}
}
return -1
},
},
{
logical: "fixed",
physical: "float64",
Expand All @@ -1111,6 +1235,26 @@ func TestArrowToRecord(t *testing.T) {
return -1
},
},
{
logical: "fixed",
physical: "int64",
rowType: execResponseRowType{Scale: 5},
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.Int64Type{}}}, nil),
values: []int64{12345, 234567},
withHigherPrecision: true,
nrows: 2,
builder: array.NewInt64Builder(pool),
append: func(b array.Builder, vs interface{}) { b.(*array.Int64Builder).AppendValues(vs.([]int64), valids) },
compare: func(src interface{}, expected interface{}, convertedRec arrow.Record) int {
srcvs := src.([]int64)
for i, f := range convertedRec.Column(0).(*array.Int64).Int64Values() {
if srcvs[i] != f {
return i
}
}
return -1
},
},
{
logical: "boolean",
sc: arrow.NewSchema([]arrow.Field{{Type: &arrow.BooleanType{}}}, nil),
Expand Down Expand Up @@ -1880,6 +2024,10 @@ func TestArrowToRecord(t *testing.T) {
ctx = WithArrowBatchesUtf8Validation(ctx)
}

if tc.withHigherPrecision {
ctx = WithHigherPrecision(ctx)
}

transformedRec, err := arrowToRecord(ctx, rawRec, pool, []execResponseRowType{meta}, localTime.Location())
if err != nil {
if tc.error == "" || !strings.Contains(err.Error(), tc.error) {
Expand Down
6 changes: 6 additions & 0 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,12 @@ To address this issue and prevent potential downstream disruptions, the context
When enabled, this feature iterates through all values in string columns, identifying and replacing any invalid characters with `�`.
This ensures that Arrow records conform to the UTF-8 standards, preventing validation failures in downstream services like the Rust Arrow library that impose strict validation checks.
### WithHigherPrecision in Arrow batches
To preserve BigDecimal values within Arrow batches, use `WithHigherPrecision`.
This offers two main benefits: it helps avoid precision loss and defers the conversion to upstream services.
Alternatively, without this setting, all non-zero scale numbers will be converted to float64, potentially resulting in loss of precision.
Zero-scale numbers (DECIMAL256, DECIMAL128) will be converted to int64, which could lead to overflow.
# Binding Parameters
Binding allows a SQL statement to use a value that is stored in a Golang variable.
Expand Down
1 change: 1 addition & 0 deletions util.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ func WithDescribeOnly(ctx context.Context) context.Context {
// WithHigherPrecision returns a context that enables higher precision by
// returning a *big.Int or *big.Float variable when querying rows for column
// types with numbers that don't fit into its native Golang counterpart
// When used in combination with WithArrowBatches, original BigDecimal in arrow batches will be preserved.
func WithHigherPrecision(ctx context.Context) context.Context {
return context.WithValue(ctx, enableHigherPrecision, true)
}
Expand Down

0 comments on commit 2141603

Please sign in to comment.