Skip to content

Commit b298e21

Browse files
authored
table: fix the issue that the default value for BIT column is wrong (pingcap#57303) (pingcap#57356)
close pingcap#57301, close pingcap#57312
1 parent 2061937 commit b298e21

File tree

8 files changed

+92
-49
lines changed

8 files changed

+92
-49
lines changed

pkg/ddl/ddl_api.go

+24-7
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ import (
6464
"github.com/pingcap/tidb/pkg/util/dbterror"
6565
"github.com/pingcap/tidb/pkg/util/domainutil"
6666
"github.com/pingcap/tidb/pkg/util/hack"
67+
"github.com/pingcap/tidb/pkg/util/intest"
6768
"github.com/pingcap/tidb/pkg/util/logutil"
6869
"github.com/pingcap/tidb/pkg/util/mathutil"
6970
"github.com/pingcap/tidb/pkg/util/memory"
@@ -1031,6 +1032,21 @@ func checkColumnDefaultValue(ctx sessionctx.Context, col *table.Column, value in
10311032
}
10321033
}
10331034
}
1035+
if value != nil && col.GetType() == mysql.TypeBit {
1036+
v, ok := value.(string)
1037+
if !ok {
1038+
return hasDefaultValue, value, types.ErrInvalidDefault.GenWithStackByArgs(col.Name.O)
1039+
}
1040+
1041+
uintVal, err := types.BinaryLiteral(v).ToInt(ctx.GetSessionVars().StmtCtx)
1042+
if err != nil {
1043+
return hasDefaultValue, value, types.ErrInvalidDefault.GenWithStackByArgs(col.Name.O)
1044+
}
1045+
intest.Assert(col.GetFlen() > 0 && col.GetFlen() <= 64)
1046+
if col.GetFlen() < 64 && uintVal >= 1<<(uint64(col.GetFlen())) {
1047+
return hasDefaultValue, value, types.ErrInvalidDefault.GenWithStackByArgs(col.Name.O)
1048+
}
1049+
}
10341050
return hasDefaultValue, value, nil
10351051
}
10361052

@@ -5283,13 +5299,14 @@ func SetDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.Colu
52835299
}
52845300
col.DefaultIsExpr = isSeqExpr
52855301
}
5286-
5287-
if hasDefaultValue, value, err = checkColumnDefaultValue(ctx, col, value); err != nil {
5288-
return hasDefaultValue, errors.Trace(err)
5289-
}
5290-
value, err = convertTimestampDefaultValToUTC(ctx, value, col)
5291-
if err != nil {
5292-
return hasDefaultValue, errors.Trace(err)
5302+
if !col.DefaultIsExpr {
5303+
if hasDefaultValue, value, err = checkColumnDefaultValue(ctx, col, value); err != nil {
5304+
return hasDefaultValue, errors.Trace(err)
5305+
}
5306+
value, err = convertTimestampDefaultValToUTC(ctx, value, col)
5307+
if err != nil {
5308+
return hasDefaultValue, errors.Trace(err)
5309+
}
52935310
}
52945311
err = setDefaultValueWithBinaryPadding(col, value)
52955312
if err != nil {

pkg/executor/test/writetest/write_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -1301,7 +1301,7 @@ func TestIssue18681(t *testing.T) {
13011301
tk := testkit.NewTestKit(t, store)
13021302
tk.MustExec("use test")
13031303
createSQL := `drop table if exists load_data_test;
1304-
create table load_data_test (a bit(1),b bit(1),c bit(1),d bit(1));`
1304+
create table load_data_test (a bit(1),b bit(1),c bit(1),d bit(1),e bit(32),f bit(1));`
13051305
tk.MustExec(createSQL)
13061306
tk.MustExec("load data local infile '/tmp/nonexistence.csv' ignore into table load_data_test")
13071307
ctx := tk.Session().(sessionctx.Context)
@@ -1311,7 +1311,7 @@ func TestIssue18681(t *testing.T) {
13111311
require.NotNil(t, ld)
13121312

13131313
deleteSQL := "delete from load_data_test"
1314-
selectSQL := "select bin(a), bin(b), bin(c), bin(d) from load_data_test;"
1314+
selectSQL := "select bin(a), bin(b), bin(c), bin(d), bin(e), bin(f) from load_data_test;"
13151315
ctx.GetSessionVars().StmtCtx.DupKeyAsWarning = true
13161316
ctx.GetSessionVars().StmtCtx.BadNullAsWarning = true
13171317

@@ -1322,7 +1322,7 @@ func TestIssue18681(t *testing.T) {
13221322
}()
13231323
sc.IgnoreTruncate.Store(false)
13241324
tests := []testCase{
1325-
{[]byte("true\tfalse\t0\t1\n"), []string{"1|0|0|1"}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 0"},
1325+
{[]byte("true\tfalse\t0\t1\tb'1'\tb'1'\n"), []string{"1|1|1|1|1100010001001110011000100100111|1"}, "Records: 1 Deleted: 0 Skipped: 0 Warnings: 5"},
13261326
}
13271327
checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL)
13281328
require.Equal(t, uint16(0), sc.WarningCount())

pkg/types/datum.go

+1-26
Original file line numberDiff line numberDiff line change
@@ -1576,38 +1576,13 @@ func (d *Datum) ConvertToMysqlYear(sc *stmtctx.StatementContext, target *FieldTy
15761576
return ret, errors.Trace(err)
15771577
}
15781578

1579-
func (d *Datum) convertStringToMysqlBit(sc *stmtctx.StatementContext) (uint64, error) {
1580-
bitStr, err := ParseBitStr(BinaryLiteral(d.b).ToString())
1581-
if err != nil {
1582-
// It cannot be converted to bit type, so we need to convert it to int type.
1583-
return BinaryLiteral(d.b).ToInt(sc)
1584-
}
1585-
return bitStr.ToInt(sc)
1586-
}
1587-
15881579
func (d *Datum) convertToMysqlBit(sc *stmtctx.StatementContext, target *FieldType) (Datum, error) {
15891580
var ret Datum
15901581
var uintValue uint64
15911582
var err error
15921583
switch d.k {
1593-
case KindBytes:
1584+
case KindString, KindBytes:
15941585
uintValue, err = BinaryLiteral(d.b).ToInt(sc)
1595-
case KindString:
1596-
// For single bit value, we take string like "true", "1" as 1, and "false", "0" as 0,
1597-
// this behavior is not documented in MySQL, but it behaves so, for more information, see issue #18681
1598-
s := BinaryLiteral(d.b).ToString()
1599-
if target.GetFlen() == 1 {
1600-
switch strings.ToLower(s) {
1601-
case "true", "1":
1602-
uintValue = 1
1603-
case "false", "0":
1604-
uintValue = 0
1605-
default:
1606-
uintValue, err = d.convertStringToMysqlBit(sc)
1607-
}
1608-
} else {
1609-
uintValue, err = d.convertStringToMysqlBit(sc)
1610-
}
16111586
case KindInt64:
16121587
// if input kind is int64 (signed), when trans to bit, we need to treat it as unsigned
16131588
d.k = KindUint64

pkg/types/datum_test.go

+26-13
Original file line numberDiff line numberDiff line change
@@ -527,24 +527,37 @@ func prepareCompareDatums() ([]Datum, []Datum) {
527527

528528
func TestStringToMysqlBit(t *testing.T) {
529529
tests := []struct {
530-
a Datum
531-
out []byte
530+
a Datum
531+
out []byte
532+
flen int
533+
truncated bool
532534
}{
533-
{NewStringDatum("true"), []byte{1}},
534-
{NewStringDatum("false"), []byte{0}},
535-
{NewStringDatum("1"), []byte{1}},
536-
{NewStringDatum("0"), []byte{0}},
537-
{NewStringDatum("b'1'"), []byte{1}},
538-
{NewStringDatum("b'0'"), []byte{0}},
535+
{NewStringDatum("true"), []byte{1}, 1, true},
536+
{NewStringDatum("true"), []byte{0x74, 0x72, 0x75, 0x65}, 32, false},
537+
{NewStringDatum("false"), []byte{0x1}, 1, true},
538+
{NewStringDatum("false"), []byte{0x66, 0x61, 0x6c, 0x73, 0x65}, 40, false},
539+
{NewStringDatum("1"), []byte{1}, 1, true},
540+
{NewStringDatum("1"), []byte{0x31}, 8, false},
541+
{NewStringDatum("0"), []byte{1}, 1, true},
542+
{NewStringDatum("0"), []byte{0x30}, 8, false},
543+
{NewStringDatum("b'1'"), []byte{0x62, 0x27, 0x31, 0x27}, 32, false},
544+
{NewStringDatum("b'0'"), []byte{0x62, 0x27, 0x30, 0x27}, 32, false},
539545
}
540546
sc := stmtctx.NewStmtCtx()
541547
sc.IgnoreTruncate.Store(true)
542-
tp := NewFieldType(mysql.TypeBit)
543-
tp.SetFlen(1)
544548
for _, tt := range tests {
545-
bin, err := tt.a.convertToMysqlBit(nil, tp)
546-
require.NoError(t, err)
547-
require.Equal(t, tt.out, bin.b)
549+
t.Run(fmt.Sprintf("%s %d %t", tt.a.GetString(), tt.flen, tt.truncated), func(t *testing.T) {
550+
tp := NewFieldType(mysql.TypeBit)
551+
tp.SetFlen(tt.flen)
552+
553+
bin, err := tt.a.convertToMysqlBit(sc, tp)
554+
if tt.truncated {
555+
require.Contains(t, err.Error(), "Data Too Long")
556+
} else {
557+
require.NoError(t, err)
558+
}
559+
require.Equal(t, tt.out, bin.b)
560+
})
548561
}
549562
}
550563

tests/integrationtest/r/ddl/column.result

+12
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,15 @@ t CREATE TABLE `t` (
6565
`a` decimal(10,0) DEFAULT NULL,
6666
`b` decimal(10,0) DEFAULT NULL
6767
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin
68+
drop table if exists t;
69+
create table t(a bit(2) default b'111');
70+
Error 1067 (42000): Invalid default value for 'a'
71+
create table t(a bit(65) default b'111');
72+
Error 1439 (42000): Display width out of range for column 'a' (max = 64)
73+
create table t(a bit(64) default b'1111111111111111111111111111111111111111111111111111111111111111');
74+
drop table t;
75+
create table t(a bit(3) default b'111');
76+
drop table t;
77+
create table t(a bit(3) default b'000111');
78+
drop table t;
79+
create table t(a bit(32) default b'1111111111111111111111111111111');

tests/integrationtest/r/table/tables.result

+6
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,9 @@ select count(distinct(_tidb_rowid>>48)) from shard_t;
66
count(distinct(_tidb_rowid>>48))
77
4
88
set @@tidb_shard_allocate_step=default;
9+
drop table if exists t;
10+
create table t(a bit(32) default b'1100010001001110011000100100111');
11+
insert into t values ();
12+
select hex(a) from t;
13+
hex(a)
14+
62273127

tests/integrationtest/t/ddl/column.test

+14
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,17 @@ show create table t2;
2222
drop table if exists t;
2323
create table t(a decimal(0,0), b decimal(0));
2424
show create table t;
25+
26+
# TestTooLongDefaultValueForBit
27+
drop table if exists t;
28+
-- error 1067
29+
create table t(a bit(2) default b'111');
30+
-- error 1439
31+
create table t(a bit(65) default b'111');
32+
create table t(a bit(64) default b'1111111111111111111111111111111111111111111111111111111111111111');
33+
drop table t;
34+
create table t(a bit(3) default b'111');
35+
drop table t;
36+
create table t(a bit(3) default b'000111');
37+
drop table t;
38+
create table t(a bit(32) default b'1111111111111111111111111111111');

tests/integrationtest/t/table/tables.test

+6
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,9 @@ set @@tidb_shard_allocate_step=3;
55
insert into shard_t values (1), (2), (3), (4), (5), (6), (7), (8), (9), (10), (11);
66
select count(distinct(_tidb_rowid>>48)) from shard_t;
77
set @@tidb_shard_allocate_step=default;
8+
9+
# TestInsertBitDefaultValue
10+
drop table if exists t;
11+
create table t(a bit(32) default b'1100010001001110011000100100111');
12+
insert into t values ();
13+
select hex(a) from t;

0 commit comments

Comments
 (0)