Skip to content

Commit bae606f

Browse files
committed
added more tests and refactored asSQLNull
1 parent 9fa14d0 commit bae606f

File tree

2 files changed

+55
-35
lines changed

2 files changed

+55
-35
lines changed

internal/bind/params.go

Lines changed: 15 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -51,52 +51,32 @@ func asUUID(v any) (value.Value, bool) {
5151
func asSQLNull(v any) (value.Value, bool) {
5252
switch x := v.(type) {
5353
case sql.NullBool:
54-
if x.Valid {
55-
return value.OptionalValue(value.BoolValue(x.Bool)), true
56-
}
57-
58-
return value.NullValue(types.Bool), true
54+
return wrapWithNulls(x.Valid, value.BoolValue(x.Bool), types.Bool), true
5955
case sql.NullFloat64:
60-
if x.Valid {
61-
return value.OptionalValue(value.DoubleValue(x.Float64)), true
62-
}
63-
64-
return value.NullValue(types.Double), true
56+
return wrapWithNulls(x.Valid, value.DoubleValue(x.Float64), types.Double), true
6557
case sql.NullInt16:
66-
if x.Valid {
67-
return value.OptionalValue(value.Int16Value(x.Int16)), true
68-
}
69-
70-
return value.NullValue(types.Int16), true
58+
return wrapWithNulls(x.Valid, value.Int16Value(x.Int16), types.Int16), true
7159
case sql.NullInt32:
72-
if x.Valid {
73-
return value.OptionalValue(value.Int32Value(x.Int32)), true
74-
}
75-
76-
return value.NullValue(types.Int32), true
60+
return wrapWithNulls(x.Valid, value.Int32Value(x.Int32), types.Int32), true
7761
case sql.NullInt64:
78-
if x.Valid {
79-
return value.OptionalValue(value.Int64Value(x.Int64)), true
80-
}
81-
82-
return value.NullValue(types.Int64), true
62+
return wrapWithNulls(x.Valid, value.Int64Value(x.Int64), types.Int64), true
8363
case sql.NullString:
84-
if x.Valid {
85-
return value.OptionalValue(value.TextValue(x.String)), true
86-
}
87-
88-
return value.NullValue(types.Text), true
64+
return wrapWithNulls(x.Valid, value.TextValue(x.String), types.Text), true
8965
case sql.NullTime:
90-
if x.Valid {
91-
return value.OptionalValue(value.TimestampValueFromTime(x.Time)), true
92-
}
93-
94-
return value.NullValue(types.Timestamp), true
66+
return wrapWithNulls(x.Valid, value.TimestampValueFromTime(x.Time), types.Timestamp), true
9567
}
9668

9769
return asSQLNullGeneric(v)
9870
}
9971

72+
func wrapWithNulls(valid bool, val value.Value, t types.Type) value.Value {
73+
if valid {
74+
return value.OptionalValue(val)
75+
}
76+
77+
return value.NullValue(t)
78+
}
79+
10080
func asSQLNullGeneric(v any) (value.Value, bool) {
10181
if v == nil {
10282
return nil, false

internal/bind/params_test.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1095,6 +1095,46 @@ func TestSQLNullTypes(t *testing.T) {
10951095
src: sql.Null[int64]{V: 0, Valid: false},
10961096
expected: value.NullValue(types.Int64),
10971097
},
1098+
{
1099+
name: "sql.Null[bool] valid",
1100+
src: sql.Null[bool]{V: true, Valid: true},
1101+
expected: value.OptionalValue(value.BoolValue(true)),
1102+
},
1103+
{
1104+
name: "sql.Null[bool] invalid",
1105+
src: sql.Null[bool]{V: false, Valid: false},
1106+
expected: value.NullValue(types.Bool),
1107+
},
1108+
{
1109+
name: "sql.Null[float64] valid",
1110+
src: sql.Null[float64]{V: 3.14, Valid: true},
1111+
expected: value.OptionalValue(value.DoubleValue(3.14)),
1112+
},
1113+
{
1114+
name: "sql.Null[float64] invalid",
1115+
src: sql.Null[float64]{V: 0, Valid: false},
1116+
expected: value.NullValue(types.Double),
1117+
},
1118+
{
1119+
name: "sql.Null[time.Time] valid",
1120+
src: sql.Null[time.Time]{V: time.Date(2024, 2, 3, 4, 5, 6, 7, time.UTC), Valid: true},
1121+
expected: value.OptionalValue(value.TimestampValueFromTime(time.Date(2024, 2, 3, 4, 5, 6, 7, time.UTC))),
1122+
},
1123+
{
1124+
name: "sql.Null[time.Time] invalid",
1125+
src: sql.Null[time.Time]{V: time.Time{}, Valid: false},
1126+
expected: value.NullValue(types.Timestamp),
1127+
},
1128+
{
1129+
name: "sql.Null[[]byte] valid",
1130+
src: sql.Null[[]byte]{V: []byte("abc"), Valid: true},
1131+
expected: value.OptionalValue(value.BytesValue([]byte("abc"))),
1132+
},
1133+
{
1134+
name: "sql.Null[[]byte] invalid",
1135+
src: sql.Null[[]byte]{V: nil, Valid: false},
1136+
expected: value.NullValue(types.Bytes),
1137+
},
10981138
}
10991139

11001140
for _, tt := range tests {

0 commit comments

Comments
 (0)