diff --git a/base.go b/base.go index 2bfe6f8..9d1d621 100644 --- a/base.go +++ b/base.go @@ -1,6 +1,7 @@ package hood import ( + "database/sql" "fmt" "reflect" "strings" @@ -67,6 +68,21 @@ func (d *base) SetModelValue(driverValue, fieldValue reflect.Value) error { } else { panic(fmt.Sprintf("cannot set created value %T", driverValue.Elem().Interface())) } + } else if fieldType == reflect.TypeOf(sql.NullBool{}) { + if b, ok := driverValue.Elem().Interface().(bool); ok { + fieldValue.Set(reflect.ValueOf(sql.NullBool{b, true})) + } + } else if fieldType == reflect.TypeOf(sql.NullFloat64{}) { + if f, ok := driverValue.Elem().Interface().(float64); ok { + fieldValue.Set(reflect.ValueOf(sql.NullFloat64{f, true})) + } + } else if fieldType == reflect.TypeOf(sql.NullInt64{}) { + if i, ok := driverValue.Elem().Interface().(int64); ok { + fieldValue.Set(reflect.ValueOf(sql.NullInt64{i, true})) + } + } else if fieldType == reflect.TypeOf(sql.NullString{}) { + str := string(driverValue.Elem().Bytes()) + fieldValue.Set(reflect.ValueOf(sql.NullString{str, true})) } } return nil diff --git a/mysql.go b/mysql.go index 64e7514..07baddb 100644 --- a/mysql.go +++ b/mysql.go @@ -1,6 +1,7 @@ package hood import ( + "database/sql" "fmt" "reflect" "time" @@ -38,20 +39,20 @@ func (d *mysql) SqlType(f interface{}, size int) string { return "bigint" case time.Time, Created, Updated: return "timestamp" - case bool: + case bool, sql.NullBool: return "boolean" case int, int8, int16, int32, uint, uint8, uint16, uint32: return "int" - case int64, uint64: + case int64, uint64, sql.NullInt64: return "bigint" - case float32, float64: + case float32, float64, sql.NullFloat64: return "double" case []byte: if size > 0 && size < 65532 { return fmt.Sprintf("varbinary(%d)", size) } return "longblob" - case string: + case string, sql.NullString: if size > 0 && size < 65532 { return fmt.Sprintf("varchar(%d)", size) } diff --git a/postgres.go b/postgres.go index 5f171c4..4c27fa2 100644 --- a/postgres.go +++ b/postgres.go @@ -1,6 +1,7 @@ package hood import ( + "database/sql" "fmt" _ "github.com/lib/pq" "strings" @@ -27,17 +28,17 @@ func (d *postgres) SqlType(f interface{}, size int) string { return "bigserial" case time.Time, Created, Updated: return "timestamp with time zone" - case bool: + case bool, sql.NullBool: return "boolean" case int, int8, int16, int32, uint, uint8, uint16, uint32: return "integer" - case int64, uint64: + case int64, uint64, sql.NullInt64: return "bigint" - case float32, float64: + case float32, float64, sql.NullFloat64: return "double precision" case []byte: return "bytea" - case string: + case string, sql.NullString: if size > 0 && size < 65532 { return fmt.Sprintf("varchar(%d)", size) }