diff --git a/builder/delete.go b/builder/delete.go index 80ab943..17366cd 100644 --- a/builder/delete.go +++ b/builder/delete.go @@ -15,6 +15,7 @@ type DeleteBuilder struct { limit string offset string suffixes []Sqlizer + returning []string } func (b *DeleteBuilder) ToSql() (sqlStr string, args []any, err error) { @@ -60,6 +61,11 @@ func (b *DeleteBuilder) ToSql() (sqlStr string, args []any, err error) { sql.WriteString(b.offset) } + if len(b.returning) > 0 { + sql.WriteString(" RETURNING ") + sql.WriteString(strings.Join(b.returning, ",")) + } + if len(b.suffixes) > 0 { sql.WriteString(" ") args, err = appendToSql(b.suffixes, sql, " ", args) @@ -125,3 +131,9 @@ func (b *DeleteBuilder) SuffixExpr(expr Sqlizer) *DeleteBuilder { b.suffixes = append(b.suffixes, expr) return b } + +// Returning adds a RETURNING clause to the query. +func (b *DeleteBuilder) Returning(columns ...string) *DeleteBuilder { + b.returning = append(b.returning, columns...) + return b +} diff --git a/builder/expr.go b/builder/expr.go index a21d7db..eef7d5d 100644 --- a/builder/expr.go +++ b/builder/expr.go @@ -2,11 +2,15 @@ package builder import ( "bytes" + "cmp" "database/sql/driver" "fmt" "reflect" + "slices" "sort" "strings" + + "github.com/samber/lo" ) const ( @@ -109,7 +113,7 @@ func (ce concatExpr) ToSql() (sql string, args []any, err error) { // // name_expr := Expr("CONCAT(?, ' ', ?)", firstName, lastName) // ConcatExpr("COALESCE(full_name,", name_expr, ")") -func ConcatExpr(parts ...any) concatExpr { +func ConcatExpr(parts ...any) Sqlizer { return concatExpr(parts) } @@ -124,7 +128,7 @@ type aliasExpr struct { // Ex: // // .AddColumn(Alias(caseStmt, "case_column")) -func Alias(expr Sqlizer, alias string) aliasExpr { +func Alias(expr Sqlizer, alias string) Sqlizer { return aliasExpr{expr, alias} } @@ -147,7 +151,7 @@ func (eq Eq) toSQL(useNotOpr bool) (sql string, args []any, err error) { } var ( - exprs []string + exprs []Sqlizer equalOpr = "=" nullOpr = "IS" ) @@ -156,10 +160,10 @@ func (eq Eq) toSQL(useNotOpr bool) (sql string, args []any, err error) { equalOpr = "<>" nullOpr = "IS NOT" } - - sortedKeys := getSortedKeys(eq) - for _, key := range sortedKeys { - var expr string + keys := lo.Keys(eq) + slices.Sort(keys) + for _, key := range keys { + var e Sqlizer val := eq[key] switch v := val.(type) { @@ -175,22 +179,38 @@ func (eq Eq) toSQL(useNotOpr bool) (sql string, args []any, err error) { val = nil } else { val = r.Elem().Interface() + r = reflect.ValueOf(val) } } if val == nil { - expr = fmt.Sprintf("%s %s NULL", key, nullOpr) + e = Expr(fmt.Sprintf("%s %s NULL", key, nullOpr)) } else { - if isListType(val) { - err = fmt.Errorf("cannot use array or slice with Eq operators") - return + if r.Kind() == reflect.Slice || r.Kind() == reflect.Array { + if _, ok := val.([]byte); !ok { + err = fmt.Errorf("cannot use array or slice with Eq operators") + return + } } - expr = fmt.Sprintf("%s %s ?", key, equalOpr) - args = append(args, val) + e = Expr(fmt.Sprintf("%s %s ?", key, equalOpr), val) + } + exprs = append(exprs, e) + } + + var sqlParts []string + for _, sqlizer := range exprs { + partSQL, partArgs, err := sqlizer.ToSql() + if err != nil { + return "", nil, err + } + if partSQL != "" { + sqlParts = append(sqlParts, partSQL) + args = append(args, partArgs...) } - exprs = append(exprs, expr) } - sql = strings.Join(exprs, " AND ") + if len(sqlParts) > 0 { + sql = strings.Join(sqlParts, " AND ") + } return } @@ -208,228 +228,92 @@ func (neq NotEq) ToSql() (sql string, args []any, err error) { return Eq(neq).toSQL(true) } -// In is syntactic sugar for IN conditions -type In struct { - Col string - Val any -} - -func (in In) toSql(useNotOpr bool) (sql string, args []any, err error) { - var ( - inOpr = "IN" - inEmptyExpr = sqlFalse - ) - - if useNotOpr { - inOpr = "NOT IN" - inEmptyExpr = sqlTrue - } - - val := in.Val - - switch v := val.(type) { - case driver.Valuer: - if val, err = v.Value(); err != nil { - return - } - } - - r := reflect.ValueOf(val) - if r.Kind() == reflect.Ptr { - if r.IsNil() { - val = nil - } else { - val = r.Elem().Interface() - } - } - - if val == nil { - err = fmt.Errorf("cannot use null with in operators") - return - } - - if isListType(val) { - valVal := reflect.ValueOf(val) - if valVal.Len() == 0 { - sql = inEmptyExpr - if args == nil { - args = []any{} - } - } else { - for i := 0; i < valVal.Len(); i++ { - args = append(args, valVal.Index(i).Interface()) - } - sql = fmt.Sprintf("%s %s (%s)", in.Col, inOpr, Placeholders(valVal.Len())) - } - } else { - sql = fmt.Sprintf("%s %s (?)", in.Col, inOpr) - args = append(args, val) +func In[T any](field string, val []T) Sqlizer { + if len(val) == 0 { + return expr{sql: sqlFalse, args: []any{}} } - return -} - -func (in In) ToSql() (sql string, args []any, err error) { - return in.toSql(false) + s := lo.Map(val, func(item T, index int) any { + return any(item) + }) + return Expr(fmt.Sprintf("%s IN (%s)", field, Placeholders(len(val))), s...) } -// NotIn is syntactic sugar for NOT IN conditions -type NotIn In - -func (ni NotIn) ToSql() (sql string, args []any, err error) { - return In(ni).toSql(true) +func NotIn[T any](field string, val []T) Sqlizer { + if len(val) == 0 { + return expr{sql: sqlTrue, args: []any{}} + } + s := lo.Map(val, func(item T, index int) any { + return any(item) + }) + return Expr(fmt.Sprintf("%s NOT IN (%s)", field, Placeholders(len(val))), s...) } // Like is syntactic sugar for use with LIKE conditions. // Ex: // -// .Where(Like{"name": "%irrel"}) -type Like map[string]any - -func (lk Like) toSql(opr string) (sql string, args []any, err error) { - var exprs []string - for key, val := range lk { - expr := "" - - switch v := val.(type) { - case driver.Valuer: - if val, err = v.Value(); err != nil { - return - } - } - - if val == nil { - err = fmt.Errorf("cannot use null with like operators") - return - } else { - if isListType(val) { - err = fmt.Errorf("cannot use array or slice with like operators") - return - } else { - expr = fmt.Sprintf("%s %s ?", key, opr) - args = append(args, val) - } - } - exprs = append(exprs, expr) - } - sql = strings.Join(exprs, " AND ") - return -} - -func (lk Like) ToSql() (sql string, args []any, err error) { - return lk.toSql("LIKE") +// .Where(Like("name", "%irrel")) +func Like(field, value string) Sqlizer { + return Expr(fmt.Sprintf("%s LIKE ?", field), value) } // NotLike is syntactic sugar for use with LIKE conditions. // Ex: // -// .Where(NotLike{"name": "%irrel"}) -type NotLike Like - -func (nlk NotLike) ToSql() (sql string, args []any, err error) { - return Like(nlk).toSql("NOT LIKE") +// .Where(NotLike("name": "%irrel")) +func NotLike(field, value string) Sqlizer { + return Expr(fmt.Sprintf("%s NOT LIKE ?", field), value) } // ILike is syntactic sugar for use with ILIKE conditions. // Ex: // -// .Where(ILike{"name": "sq%"}) -type ILike Like - -func (ilk ILike) ToSql() (sql string, args []any, err error) { - return Like(ilk).toSql("ILIKE") +// .Where(ILike("name", "sq%")) +func ILike(field, value string) Sqlizer { + return Expr(fmt.Sprintf("%s ILIKE ?", field), value) } // NotILike is syntactic sugar for use with ILIKE conditions. // Ex: // -// .Where(NotILike{"name": "sq%"}) -type NotILike Like - -func (nilk NotILike) ToSql() (sql string, args []any, err error) { - return Like(nilk).toSql("NOT ILIKE") +// .Where(NotILike("name", "sq%")) +func NotILike(field, value string) Sqlizer { + return Expr(fmt.Sprintf("%s NOT ILIKE ?", field), value) } // Lt is syntactic sugar for use with Where/Having/Set methods. // Ex: // -// .Where(Lt{"id": 1}) -type Lt map[string]any - -func (lt Lt) toSql(opposite, orEq bool) (sql string, args []any, err error) { - var ( - exprs []string - opr = "<" - ) - - if opposite { - opr = ">" - } - - if orEq { - opr = fmt.Sprintf("%s%s", opr, "=") - } - - sortedKeys := getSortedKeys(lt) - for _, key := range sortedKeys { - var expr string - val := lt[key] - - switch v := val.(type) { - case driver.Valuer: - if val, err = v.Value(); err != nil { - return - } - } - - if val == nil { - err = fmt.Errorf("cannot use null with less than or greater than operators") - return - } - if isListType(val) { - err = fmt.Errorf("cannot use array or slice with less than or greater than operators") - return - } - expr = fmt.Sprintf("%s %s ?", key, opr) - args = append(args, val) - - exprs = append(exprs, expr) - } - sql = strings.Join(exprs, " AND ") - return -} - -func (lt Lt) ToSql() (sql string, args []any, err error) { - return lt.toSql(false, false) +// .Where(Lt("id", 1)) +func Lt[T cmp.Ordered](field string, value T) Sqlizer { + return Expr(fmt.Sprintf("%s < ?", field), value) } // Lte is syntactic sugar for use with Where/Having/Set methods. // Ex: // -// .Where(Lte{"id": 1}) == "id <= 1" -type Lte Lt - -func (ltOrEq Lte) ToSql() (sql string, args []any, err error) { - return Lt(ltOrEq).toSql(false, true) +// .Where(Lte("id", 1)) == "id <= 1" +func Lte[T cmp.Ordered](field string, value T) Sqlizer { + return Expr(fmt.Sprintf("%s <= ?", field), value) } // Gt is syntactic sugar for use with Where/Having/Set methods. // Ex: // -// .Where(Gt{"id": 1}) == "id > 1" -type Gt Lt - -func (gt Gt) ToSql() (sql string, args []any, err error) { - return Lt(gt).toSql(true, false) +// .Where(Gt("id", 1)) == "id > 1" +func Gt[T cmp.Ordered](field string, value T) Sqlizer { + return Expr(fmt.Sprintf("%s > ?", field), value) } // Gte is syntactic sugar for use with Where/Having/Set methods. // Ex: // -// .Where(Gte{"id": 1}) == "id >= 1" -type Gte Lt +// .Where(Gte("id", 1)) == "id >= 1" +func Gte[T cmp.Ordered](field string, value T) Sqlizer { + return Expr(fmt.Sprintf("%s >= ?", field), value) +} -func (gtOrEq Gte) ToSql() (sql string, args []any, err error) { - return Lt(gtOrEq).toSql(true, true) +func Between[T cmp.Ordered](field string, start, end T) Sqlizer { + return Expr(fmt.Sprintf("%s BETWEEN ? AND ?", field), start, end) } type conj []Sqlizer @@ -477,11 +361,3 @@ func getSortedKeys(exp map[string]any) []string { sort.Strings(sortedKeys) return sortedKeys } - -func isListType(val any) bool { - if driver.IsValue(val) { - return false - } - valVal := reflect.ValueOf(val) - return valVal.Kind() == reflect.Array || valVal.Kind() == reflect.Slice -} diff --git a/builder/expr_test.go b/builder/expr_test.go index 8c3425e..562dc99 100644 --- a/builder/expr_test.go +++ b/builder/expr_test.go @@ -88,7 +88,7 @@ func TestNotEqToSql(t *testing.T) { } func TestInToSql(t *testing.T) { - b := In{"id", []int{1, 2, 3}} + b := In("id", []int{1, 2, 3}) sql, args, err := b.ToSql() assert.NoError(t, err) @@ -100,7 +100,7 @@ func TestInToSql(t *testing.T) { } func TestNotInToSql(t *testing.T) { - b := NotIn{"id", []int{1, 2, 3}} + b := NotIn("id", []int{1, 2, 3}) sql, args, err := b.ToSql() assert.NoError(t, err) @@ -112,7 +112,7 @@ func TestNotInToSql(t *testing.T) { } func TestInEmptyToSql(t *testing.T) { - b := In{"id", []int{}} + b := In("id", []int{}) sql, args, err := b.ToSql() assert.NoError(t, err) @@ -124,7 +124,7 @@ func TestInEmptyToSql(t *testing.T) { } func TestNotInEmptyToSql(t *testing.T) { - b := NotIn{"id", []int{}} + b := NotIn("id", []int{}) sql, args, err := b.ToSql() assert.NoError(t, err) @@ -148,7 +148,7 @@ func TestEqBytesToSql(t *testing.T) { } func TestLtToSql(t *testing.T) { - b := Lt{"id": 1} + b := Lt("id", 1) sql, args, err := b.ToSql() assert.NoError(t, err) @@ -160,7 +160,7 @@ func TestLtToSql(t *testing.T) { } func TestLtOrEqToSql(t *testing.T) { - b := Lte{"id": 1} + b := Lte("id", 1) sql, args, err := b.ToSql() assert.NoError(t, err) @@ -172,7 +172,7 @@ func TestLtOrEqToSql(t *testing.T) { } func TestGtToSql(t *testing.T) { - b := Gt{"id": 1} + b := Gt("id", 1) sql, args, err := b.ToSql() assert.NoError(t, err) @@ -184,7 +184,7 @@ func TestGtToSql(t *testing.T) { } func TestGtOrEqToSql(t *testing.T) { - b := Gte{"id": 1} + b := Gte("id", 1) sql, args, err := b.ToSql() assert.NoError(t, err) @@ -362,7 +362,7 @@ func TestEmptyOrToSql(t *testing.T) { } func TestLikeToSql(t *testing.T) { - b := Like{"name": "%irrel"} + b := Like("name", "%irrel") sql, args, err := b.ToSql() assert.NoError(t, err) @@ -374,7 +374,7 @@ func TestLikeToSql(t *testing.T) { } func TestNotLikeToSql(t *testing.T) { - b := NotLike{"name": "%irrel"} + b := NotLike("name", "%irrel") sql, args, err := b.ToSql() assert.NoError(t, err) @@ -386,7 +386,7 @@ func TestNotLikeToSql(t *testing.T) { } func TestILikeToSql(t *testing.T) { - b := ILike{"name": "sq%"} + b := ILike("name", "sq%") sql, args, err := b.ToSql() assert.NoError(t, err) @@ -398,7 +398,7 @@ func TestILikeToSql(t *testing.T) { } func TestNotILikeToSql(t *testing.T) { - b := NotILike{"name": "sq%"} + b := NotILike("name", "sq%") sql, args, err := b.ToSql() assert.NoError(t, err) @@ -422,11 +422,11 @@ func TestSqlEqOrder(t *testing.T) { } func TestSqlLtOrder(t *testing.T) { - b := Lt{"a": 1, "b": 2, "c": 3} + b := And{Lt("a", 1), Lt("b", 2), Lt("c", 3)} sql, args, err := b.ToSql() assert.NoError(t, err) - expectedSql := "a < ? AND b < ? AND c < ?" + expectedSql := "(a < ? AND b < ? AND c < ?)" assert.Equal(t, expectedSql, sql) expectedArgs := []any{1, 2, 3} @@ -481,6 +481,42 @@ func TestExprRecursion(t *testing.T) { } } +func TestBetweenIntToSql(t *testing.T) { + b := Between("age", 18, 65) + sql, args, err := b.ToSql() + assert.NoError(t, err) + + expectedSql := "age BETWEEN ? AND ?" + assert.Equal(t, expectedSql, sql) + + expectedArgs := []any{18, 65} + assert.Equal(t, expectedArgs, args) +} + +func TestBetweenFloat64ToSql(t *testing.T) { + b := Between("price", 10.5, 99.99) + sql, args, err := b.ToSql() + assert.NoError(t, err) + + expectedSql := "price BETWEEN ? AND ?" + assert.Equal(t, expectedSql, sql) + + expectedArgs := []any{10.5, 99.99} + assert.Equal(t, expectedArgs, args) +} + +func TestBetweenStringToSql(t *testing.T) { + b := Between("name", "A", "Z") + sql, args, err := b.ToSql() + assert.NoError(t, err) + + expectedSql := "name BETWEEN ? AND ?" + assert.Equal(t, expectedSql, sql) + + expectedArgs := []any{"A", "Z"} + assert.Equal(t, expectedArgs, args) +} + func ExampleEq() { Select("id", "created", "first_name").From("users").Where(Eq{ "company": 20, diff --git a/builder/insert.go b/builder/insert.go index 2d01c4d..fa8b6a7 100644 --- a/builder/insert.go +++ b/builder/insert.go @@ -18,6 +18,7 @@ type InsertBuilder struct { values [][]any suffixes []Sqlizer selectBuilder *SelectBuilder + returning []string } func (b *InsertBuilder) ToSql() (sqlStr string, args []any, err error) { @@ -72,6 +73,11 @@ func (b *InsertBuilder) ToSql() (sqlStr string, args []any, err error) { return } + if len(b.returning) > 0 { + sql.WriteString(" RETURNING ") + sql.WriteString(strings.Join(b.returning, ",")) + } + if len(b.suffixes) > 0 { sql.WriteString(" ") args, err = appendToSql(b.suffixes, sql, " ", args) @@ -88,8 +94,7 @@ func (b *InsertBuilder) appendValuesToSQL(w io.Writer, args []any) ([]any, error if len(b.values) == 0 { return args, errors.New("values for insert statements are not set") } - - io.WriteString(w, "VALUES ") + _, _ = io.WriteString(w, "VALUES ") valuesStrings := make([]string, len(b.values)) for r, row := range b.values { @@ -109,8 +114,7 @@ func (b *InsertBuilder) appendValuesToSQL(w io.Writer, args []any) ([]any, error } valuesStrings[r] = fmt.Sprintf("(%s)", strings.Join(valueStrings, ",")) } - - io.WriteString(w, strings.Join(valuesStrings, ",")) + _, _ = io.WriteString(w, strings.Join(valuesStrings, ",")) return args, nil } @@ -207,3 +211,9 @@ func (b *InsertBuilder) StatementKeyword(keyword string) *InsertBuilder { b.statementKeyword = keyword return b } + +// Returning adds a RETURNING clause to the query. +func (b *InsertBuilder) Returning(columns ...string) *InsertBuilder { + b.returning = append(b.returning, columns...) + return b +} diff --git a/builder/returning_test.go b/builder/returning_test.go new file mode 100644 index 0000000..7c4cb6f --- /dev/null +++ b/builder/returning_test.go @@ -0,0 +1,70 @@ +package builder + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestInsertBuilderReturning(t *testing.T) { + b := Insert("users"). + Columns("name", "email"). + Values("John", "john@example.com"). + Returning("id") + + sql, args, err := b.ToSql() + assert.NoError(t, err) + + expectedSQL := "INSERT INTO users (name,email) VALUES (?,?) RETURNING id" + assert.Equal(t, expectedSQL, sql) + + expectedArgs := []any{"John", "john@example.com"} + assert.Equal(t, expectedArgs, args) +} + +func TestInsertBuilderReturningMultiple(t *testing.T) { + b := Insert("users"). + Columns("name", "email"). + Values("John", "john@example.com"). + Returning("id", "created_at") + + sql, args, err := b.ToSql() + assert.NoError(t, err) + + expectedSQL := "INSERT INTO users (name,email) VALUES (?,?) RETURNING id,created_at" + assert.Equal(t, expectedSQL, sql) + + expectedArgs := []any{"John", "john@example.com"} + assert.Equal(t, expectedArgs, args) +} + +func TestUpdateBuilderReturning(t *testing.T) { + b := Update("users"). + Set("name", "Jane"). + Where(Eq{"id": 1}). + Returning("updated_at") + + sql, args, err := b.ToSql() + assert.NoError(t, err) + + expectedSQL := "UPDATE users SET name = ? WHERE id = ? RETURNING updated_at" + assert.Equal(t, expectedSQL, sql) + + expectedArgs := []any{"Jane", 1} + assert.Equal(t, expectedArgs, args) +} + +func TestDeleteBuilderReturning(t *testing.T) { + b := Delete("users"). + Where(Eq{"id": 1}). + Returning("id", "name") + + sql, args, err := b.ToSql() + assert.NoError(t, err) + + expectedSQL := "DELETE FROM users WHERE id = ? RETURNING id,name" + assert.Equal(t, expectedSQL, sql) + + expectedArgs := []any{1} + assert.Equal(t, expectedArgs, args) +} diff --git a/builder/select_test.go b/builder/select_test.go index b5eee6e..de70350 100644 --- a/builder/select_test.go +++ b/builder/select_test.go @@ -18,7 +18,7 @@ func TestSelectBuilderToSql(t *testing.T) { AddColumn("c"). AddColumn("IF(d IN ("+Placeholders(3)+"), 1, 0) as stat_column", 1, 2, 3). AddColumn(Expr("a > ?", 100)). - AddColumn(Alias(In{"b", []int{101, 102, 103}}, "b_alias")). + AddColumn(Alias(In("b", []int{101, 102, 103}), "b_alias")). AddColumn(Alias(subQ, "subq")). From("e"). JoinClause("CROSS JOIN j1"). @@ -30,7 +30,7 @@ func TestSelectBuilderToSql(t *testing.T) { Where("f = ?", 4). Where(Eq{"g": 5}). Where(map[string]any{"h": 6}). - Where(In{"i", []int{7, 8, 9}}). + Where(In("i", []int{7, 8, 9})). Where(Or{Expr("j = ?", 10), And{Eq{"k": 11}, Expr("true")}}). GroupBy("l"). Having("m = n"). @@ -182,10 +182,10 @@ func TestSelectBuilderFromSelect(t *testing.T) { func TestSelectBuilderFromSelectNestedDollarPlaceholders(t *testing.T) { subQ := Select("c"). From("t"). - Where(Gt{"c": 1}) + Where(Gt("c", 1)) b := Select("c"). FromSelect(subQ, "subq"). - Where(Lt{"c": 2}) + Where(Lt("c", 2)) sql, args, err := b.ToSql() assert.NoError(t, err) sql, err = Dollar.ReplacePlaceholders(sql) @@ -390,14 +390,10 @@ func ExampleSelectBuilder_Where_helpers() { "company": companyId, }) - Select("id", "created", "first_name").From("users").Where(Gte{ - "created": time.Now().AddDate(0, 0, -7), - }) + Select("id", "created", "first_name").From("users").Where(Expr("created >= ?", time.Now().AddDate(0, 0, -7))) Select("id", "created", "first_name").From("users").Where(And{ - Gte{ - "created": time.Now().AddDate(0, 0, -7), - }, + Expr("created >= ?", time.Now().AddDate(0, 0, -7)), Eq{ "company": companyId, }, @@ -412,9 +408,7 @@ func ExampleSelectBuilder_Where_multiple() { Select("id", "created", "first_name"). From("users"). Where("company = ?", companyId). - Where(Gte{ - "created": time.Now().AddDate(0, 0, -7), - }) + Where(Expr("created >= ?", time.Now().AddDate(0, 0, -7))) } func ExampleSelectBuilder_FromSelect() { diff --git a/builder/sqlizer_test.go b/builder/sqlizer_test.go index 66581b1..8c05e47 100644 --- a/builder/sqlizer_test.go +++ b/builder/sqlizer_test.go @@ -56,6 +56,6 @@ func TestDebugSqlizerErrors(t *testing.T) { errorMsg = DebugSqlizer(Expr("x = ? AND y = ?", 1)) // Too many placeholders assert.True(t, strings.HasPrefix(errorMsg, "[DebugSqlizer error: ")) - errorMsg = DebugSqlizer(Lt{"x": nil}) // Cannot use nil values with Lt + errorMsg = DebugSqlizer(Eq{"x": []int{1, 2}}) // Cannot use array with Eq assert.True(t, strings.HasPrefix(errorMsg, "[ToSql error: ")) } diff --git a/builder/try/try.go b/builder/try/try.go index 3367fb1..b0782f6 100644 --- a/builder/try/try.go +++ b/builder/try/try.go @@ -8,95 +8,90 @@ import ( "github.com/yvvlee/lorm/builder" ) -type Ordered interface { +type ordered interface { cmp.Ordered | ~bool } -// Equal 如果value不为空,则添加 dbField = value 的条件 -func Equal[T Ordered](dbField string, value *T) builder.Sqlizer { +// Equal adds condition dbField = value if value is not nil +func Equal[T ordered](dbField string, value *T) builder.Sqlizer { if value == nil { return nil } return builder.Eq{dbField: *value} } -// NotEqual 如果value不为空,则添加 dbField != value 的条件 -func NotEqual[T Ordered](dbField string, value *T) builder.Sqlizer { +// NotEqual adds condition dbField != value if value is not nil +func NotEqual[T ordered](dbField string, value *T) builder.Sqlizer { if value == nil { return nil } return builder.NotEq{dbField: *value} } -// Gt 如果value不为空,则添加 dbField > value 的条件 -func Gt[T Ordered](dbField string, value *T) builder.Sqlizer { +// Gt adds condition dbField > value if value is not nil +func Gt[T cmp.Ordered](dbField string, value *T) builder.Sqlizer { if value == nil { return nil } - return builder.Gt{dbField: *value} + return builder.Gt(dbField, *value) } -// Gte 如果value不为空,则添加 dbField >= value 的条件 -func Gte[T Ordered](dbField string, value *T) builder.Sqlizer { +// Gte adds condition dbField >= value if value is not nil +func Gte[T cmp.Ordered](dbField string, value *T) builder.Sqlizer { if value == nil { return nil } - return builder.Gte{dbField: *value} + return builder.Gte(dbField, *value) } -// Lt 如果value不为空,则添加 dbField < value 的条件 -func Lt[T Ordered](dbField string, value *T) builder.Sqlizer { +// Lt adds condition dbField < value if value is not nil +func Lt[T cmp.Ordered](dbField string, value *T) builder.Sqlizer { if value == nil { return nil } - return builder.Lt{dbField: *value} + return builder.Lt(dbField, *value) } -// Lte 如果value不为空,则添加 dbField <= value 的条件 -func Lte[T Ordered](dbField string, value *T) builder.Sqlizer { +// Lte adds condition dbField <= value if value is not nil +func Lte[T cmp.Ordered](dbField string, value *T) builder.Sqlizer { if value == nil { return nil } - return builder.Lte{dbField: *value} + return builder.Lte(dbField, *value) } -// Like 如果value不为空,则添加 dbField like "%${value}%" 的条件 +// Like adds condition dbField like "%${value}%" if value is not empty func Like(dbField, value string) builder.Sqlizer { if v := strings.TrimSpace(value); v != "" { - return builder.Like{dbField: v} + return builder.Like(dbField, v) } return nil } -// Likes 如果values不为空,则添加 dbField like "%${value1}%" OR dbField like "%${value2}%" 的条件 +// Likes adds conditions dbField like "%${value1}%" OR dbField like "%${value2}%" if values are not empty func Likes(dbField string, values []string) builder.Sqlizer { if len(values) == 0 { return nil } var c []builder.Sqlizer for _, v := range values { - c = append(c, builder.Like{dbField: v}) + c = append(c, builder.Like(dbField, v)) } return builder.Or(c) } -// Range 如果min不为空,则添加 dbField >= min 的条件;如果max不为空,则添加 dbField <= max 的条件 -func Range[T Ordered](dbField string, min, max *T) builder.Sqlizer { +// Range adds condition dbField >= min if min is not nil, and dbField <= max if max is not nil +func Range[T cmp.Ordered](dbField string, min, max *T) builder.Sqlizer { if min == nil { if max == nil { return nil - } else { - return builder.Lte{dbField: *max} } + return builder.Lte(dbField, *max) } else { if max == nil { - return builder.Gte{dbField: *min} - } else { - return builder.And{ - builder.Gte{dbField: *min}, - builder.Lte{dbField: *max}, - } + return builder.Gte(dbField, *min) } + return builder.Between(dbField, *min, *max) } } @@ -104,53 +99,53 @@ func timeToString(t *time.Time) string { if t == nil { return "" } - return t.Format("2006-01-02 15:04:05") + return t.Format(time.DateTime) } -// TimeRange 如果start不为空,则添加 dbField >= min 的条件;如果end不为空,则添加 dbField < max 的条件 +// TimeRange adds condition dbField >= start if start is not zero, and dbField < end if end is not zero func TimeRange(dbField string, start, end *time.Time) builder.Sqlizer { if start == nil || start.IsZero() { if end == nil || end.IsZero() { return nil } else { - return builder.Lt{dbField: timeToString(end)} + return builder.Lt(dbField, timeToString(end)) } } else { if end == nil || end.IsZero() { - return builder.Gte{dbField: timeToString(start)} + return builder.Gte(dbField, timeToString(start)) } else { return builder.And{ - builder.Gte{dbField: timeToString(start)}, - builder.Lt{dbField: timeToString(end)}, + builder.Gte(dbField, timeToString(start)), + builder.Lt(dbField, timeToString(end)), } } } } -// MultiLike 如果value不为空,则添加 dbField1 like "%${value}%" OR dbField2 like "%${value}%" 的条件 +// MultiLike adds conditions dbField1 like "%${value}%" OR dbField2 like "%${value}%" if value is not empty func MultiLike(dbFields []string, value string) builder.Sqlizer { if v := strings.TrimSpace(value); v != "" { var conds []builder.Sqlizer for _, field := range dbFields { - conds = append(conds, builder.Like{field: v}) + conds = append(conds, builder.Like(field, v)) } return builder.Or(conds) } return nil } -// In 如果values不为空,则添加 dbField IN (values) 的条件 +// In adds condition dbField IN (values) if values are not empty func In[T any](dbField string, values *[]T) builder.Sqlizer { if values == nil || len(*values) == 0 { return nil } - return builder.In{Col: dbField, Val: *values} + return builder.In(dbField, *values) } -// NotIn 如果values不为空,则添加 dbField NOT IN (values) 的条件 +// NotIn adds condition dbField NOT IN (values) if values are not empty func NotIn[T any](dbField string, values *[]T) builder.Sqlizer { if values == nil || len(*values) == 0 { return nil } - return builder.NotIn{Col: dbField, Val: *values} + return builder.NotIn(dbField, *values) } diff --git a/builder/try/try_test.go b/builder/try/try_test.go index d1af9f6..5021295 100644 --- a/builder/try/try_test.go +++ b/builder/try/try_test.go @@ -52,7 +52,7 @@ func TestRange(t *testing.T) { c := Range("age", &min, &max) sql, args, err := c.ToSql() assert.NoError(t, err) - assert.Equal(t, "(age >= ? AND age <= ?)", sql) + assert.Equal(t, "age BETWEEN ? AND ?", sql) assert.Equal(t, []any{10, 20}, args) c = Range("age", &min, nil) diff --git a/builder/update.go b/builder/update.go index 5ab17e4..bd76666 100644 --- a/builder/update.go +++ b/builder/update.go @@ -18,6 +18,7 @@ type UpdateBuilder struct { limit string offset string suffixes []Sqlizer + returning []string } type setClause struct { @@ -103,6 +104,11 @@ func (b *UpdateBuilder) ToSql() (sqlStr string, args []any, err error) { sql.WriteString(b.offset) } + if len(b.returning) > 0 { + sql.WriteString(" RETURNING ") + sql.WriteString(strings.Join(b.returning, ",")) + } + if len(b.suffixes) > 0 { sql.WriteString(" ") args, err = appendToSql(b.suffixes, sql, " ", args) @@ -203,3 +209,9 @@ func (b *UpdateBuilder) SuffixExpr(expr Sqlizer) *UpdateBuilder { b.suffixes = append(b.suffixes, expr) return b } + +// Returning adds a RETURNING clause to the query. +func (b *UpdateBuilder) Returning(columns ...string) *UpdateBuilder { + b.returning = append(b.returning, columns...) + return b +} diff --git a/insert.go b/insert.go index 57ca1e0..864ce13 100644 --- a/insert.go +++ b/insert.go @@ -4,49 +4,56 @@ import ( "context" "database/sql" "slices" - "time" "github.com/samber/lo" - "github.com/spf13/cast" "github.com/yvvlee/lorm/builder" ) func Insert[T Table](ctx context.Context, engine *Engine, table T) (rowsAffected int64, err error) { - result, err := inserts(ctx, engine, []T{table}) - if err != nil { - return 0, err - } - rowsAffected, err = result.RowsAffected() - if err != nil { - return - } - return rowsAffected, fillModelID(table, result) + return InsertAll(ctx, engine, []T{table}) } func InsertAll[T Table](ctx context.Context, engine *Engine, models []T) (rowsAffected int64, err error) { if len(models) == 0 { return } - if len(models) == 1 { - return Insert(ctx, engine, models[0]) + table := models[0] + descriptor := table.LormModelDescriptor() + primaryKeys := descriptor.FlagFields(FlagPrimaryKey) + // Check if we can use RETURNING or LastInsertId + var useReturning bool + var pkColumn string + if len(primaryKeys) == 1 { + flagAutoIncrementFields := descriptor.FlagFields(FlagAutoIncrement) + if slices.Contains(flagAutoIncrementFields, primaryKeys[0]) { + pkColumn = primaryKeys[0] + useReturning = engine.SupportsReturning() + } + } + if useReturning { + return insertsWithReturning(ctx, engine, models, pkColumn) } - - result, err := inserts(ctx, engine, models) + result, err := inserts(ctx, engine, models, pkColumn) if err != nil { return 0, err } - return result.RowsAffected() + rowsAffected, err = result.RowsAffected() + if err != nil { + return + } + if len(models) > 1 { + return rowsAffected, nil + } + return rowsAffected, fillModelID(table, result) } -func inserts[T Table](ctx context.Context, engine *Engine, models []T) (sql.Result, error) { - table := models[0].TableName() - insertBuilder := builder.Insert(table) - fields, values := ModelsToInsertData(models) +func inserts[T Table](ctx context.Context, engine *Engine, models []T, pkColumn string) (sql.Result, error) { + fields, values := ModelsToInsertData(models, pkColumn) escaper := engine.Escaper() - insertBuilder.Into(escaper.Escape(table)) - insertBuilder.Columns(lo.Map(fields, func(field string, _ int) string { + columns := lo.Map(fields, func(field string, _ int) string { return escaper.Escape(field) - })...) + }) + insertBuilder := builder.Insert(models[0].TableName()).Columns(columns...) for _, value := range values { insertBuilder.Values(value...) } @@ -57,77 +64,33 @@ func inserts[T Table](ctx context.Context, engine *Engine, models []T) (sql.Resu return engine.Exec(ctx, query, args...) } -func fillCurrentTime(value any, now time.Time) { - switch v := value.(type) { - case *time.Time: - if v.IsZero() { - *v = now - } - case *int64: - if *v == 0 { - *v = now.Unix() - } - case *uint64: - if *v == 0 { - *v = uint64(now.Unix()) - } - case *int32: - if *v == 0 { - *v = int32(now.Unix()) - } - case *uint32: - if *v == 0 { - *v = uint32(now.Unix()) - } - case *int: - if *v == 0 { - *v = int(now.Unix()) - } - case *string: - if *v == "" { - *v = now.Format(time.DateTime) - } - } -} +func insertsWithReturning[T Table](ctx context.Context, engine *Engine, models []T, pkColumn string) (rowsAffected int64, err error) { + fields, values := ModelsToInsertData(models, pkColumn) + escaper := engine.Escaper() + columns := lo.Map(fields, func(field string, _ int) string { + return escaper.Escape(field) + }) -func fillModelID(table Table, result sql.Result) error { - descriptor := table.LormModelDescriptor() - primaryKeys := descriptor.FlagFields(FlagPrimaryKey) - if len(primaryKeys) != 1 { - return nil + insertBuilder := builder.Insert(models[0].TableName()). + Columns(columns...). + Returning(escaper.Escape(pkColumn)) + for _, value := range values { + insertBuilder.Values(value...) } - flagAutoIncrementFields := descriptor.FlagFields(FlagAutoIncrement) - if !slices.Contains(flagAutoIncrementFields, primaryKeys[0]) { - return nil + + query, args, err := insertBuilder.ToSql() + if err != nil { + return 0, err } - primaryPointer := table.LormFieldMap()[primaryKeys[0]] - if cast.ToUint64(primaryPointer) == 0 { - lastInsertId, err := result.LastInsertId() - if err != nil { - return err - } - switch id := primaryPointer.(type) { - case *uint64: - *id = cast.ToUint64(lastInsertId) - case *int64: - *id = cast.ToInt64(lastInsertId) - case *uint32: - *id = cast.ToUint32(lastInsertId) - case *int32: - *id = cast.ToInt32(lastInsertId) - case *uint16: - *id = cast.ToUint16(lastInsertId) - case *int16: - *id = cast.ToInt16(lastInsertId) - case *uint8: - *id = cast.ToUint8(lastInsertId) - case *int8: - *id = cast.ToInt8(lastInsertId) - case *uint: - *id = cast.ToUint(lastInsertId) - case *int: - *id = cast.ToInt(lastInsertId) - } + + // Execute query and scan the returned ID + primaryPointers := lo.Map(models, func(item T, _ int) any { + return item.LormFieldMap()[pkColumn] + }) + err = engine.Query(ctx, NewColsScanner(&primaryPointers), query, args...) + + if err != nil { + return 0, err } - return nil + return int64(len(models)), nil } diff --git a/insert_single_test.go b/insert_test.go similarity index 75% rename from insert_single_test.go rename to insert_test.go index ee861a3..3481ed4 100644 --- a/insert_single_test.go +++ b/insert_test.go @@ -21,3 +21,10 @@ func TestInsertSingle(t *testing.T) { assert.EqualValues(t, 1, rows) assert.True(t, m.ID > 0) } + +func TestInsertAllEmpty(t *testing.T) { + var models []*Test + rows, err := InsertAll(context.TODO(), &Engine{config: &Config{}}, models) + assert.NoError(t, err) + assert.EqualValues(t, 0, rows) +} diff --git a/lorm.go b/lorm.go index 860de11..2a96288 100644 --- a/lorm.go +++ b/lorm.go @@ -65,6 +65,30 @@ func (e *Engine) Escaper() names.Escaper { return e.config.escaper } +func (e *Engine) DriverName() string { + return e.config.driverName +} + +// SupportsReturning returns true if the database driver supports RETURNING clause +func (e *Engine) SupportsReturning() bool { + switch e.config.driverName { + case "postgres", "pgx", "pq-timeouts", "cloudsqlpostgres", "nrpostgres", "cockroach": + return true + default: + return false + } +} + +// SupportsLastInsertId returns true if the database driver supports LastInsertId +func (e *Engine) SupportsLastInsertId() bool { + switch e.config.driverName { + case "postgres", "pgx", "pq-timeouts", "cloudsqlpostgres", "nrpostgres", "cockroach": + return false + default: + return true + } +} + func (e *Engine) session(ctx context.Context) *session { if s, ok := ctx.Value(e).(*session); ok { return s diff --git a/lorm_test.go b/lorm_test.go index 17d2b90..99842b0 100644 --- a/lorm_test.go +++ b/lorm_test.go @@ -259,6 +259,7 @@ func testEngine(t *testing.T, engine *Engine) { From("test"). Columns("id"). Where("id < ?", 3). + OrderBy("id"). Find(ctx) assert.Nil(t, err) assert.Equal(t, ids, []uint64{1, 2}) diff --git a/model.go b/model.go index 46eef28..133149e 100644 --- a/model.go +++ b/model.go @@ -41,17 +41,22 @@ type UnimplementedTable struct{} func (u UnimplementedTable) mustEmbedUnimplementedModel() {} func (u UnimplementedTable) mustEmbedUnimplementedTable() {} -func ModelToInsertData[T Model](model T) (columns []string, values []any) { - fields, v := ModelsToInsertData([]T{model}) +func ModelToInsertData[T Model](model T, ignoreFields ...string) (columns []string, values []any) { + fields, v := ModelsToInsertData([]T{model}, ignoreFields...) return fields, v[0] } -func ModelsToInsertData[T Model](models []T) (columns []string, values [][]any) { +func ModelsToInsertData[T Model](models []T, ignoreFields ...string) (columns []string, values [][]any) { if len(models) == 0 { return } descriptor := models[0].LormModelDescriptor() columns = descriptor.AllFields() + if len(ignoreFields) > 0 { + columns = lo.Filter(columns, func(item string, _ int) bool { + return !slices.Contains(ignoreFields, item) + }) + } createdFields := descriptor.FlagFields(FlagCreated) updatedFields := descriptor.FlagFields(FlagUpdated) jsonFields := descriptor.FlagFields(FlagJson) diff --git a/scanner.go b/scanner.go index fecbff2..d2f82fb 100644 --- a/scanner.go +++ b/scanner.go @@ -74,6 +74,17 @@ func (m *ColsScanner[T]) Scan(rows *sql.Rows) error { if len(columns) != 1 { return fmt.Errorf("expected exactly one column, got %d", len(columns)) } + if len(*m.v) > 0 { + i := 0 + for rows.Next() && i < len(*m.v) { + item := (*m.v)[i] + if err = rows.Scan(item); err != nil { + return err + } + i++ + } + return nil + } var v []T for rows.Next() { var item T @@ -140,7 +151,7 @@ func (m *ColScanner[T]) Scan(row *sql.Rows) error { return scanRow(row, m.v) } -func scanRow(rows *sql.Rows, dest ...interface{}) error { +func scanRow(rows *sql.Rows, dest ...any) error { for _, dp := range dest { if _, ok := dp.(*sql.RawBytes); ok { return errors.New("sql: RawBytes isn't allowed on Row.Scan") diff --git a/tools.go b/tools.go new file mode 100644 index 0000000..3236ecc --- /dev/null +++ b/tools.go @@ -0,0 +1,84 @@ +package lorm + +import ( + "database/sql" + "slices" + "time" + + "github.com/spf13/cast" +) + +func fillModelID(table Table, result sql.Result) error { + descriptor := table.LormModelDescriptor() + primaryKeys := descriptor.FlagFields(FlagPrimaryKey) + if len(primaryKeys) != 1 { + return nil + } + flagAutoIncrementFields := descriptor.FlagFields(FlagAutoIncrement) + if !slices.Contains(flagAutoIncrementFields, primaryKeys[0]) { + return nil + } + primaryPointer := table.LormFieldMap()[primaryKeys[0]] + if cast.ToUint64(primaryPointer) == 0 { + lastInsertId, err := result.LastInsertId() + if err != nil { + return err + } + switch id := primaryPointer.(type) { + case *uint64: + *id = cast.ToUint64(lastInsertId) + case *int64: + *id = cast.ToInt64(lastInsertId) + case *uint32: + *id = cast.ToUint32(lastInsertId) + case *int32: + *id = cast.ToInt32(lastInsertId) + case *uint16: + *id = cast.ToUint16(lastInsertId) + case *int16: + *id = cast.ToInt16(lastInsertId) + case *uint8: + *id = cast.ToUint8(lastInsertId) + case *int8: + *id = cast.ToInt8(lastInsertId) + case *uint: + *id = cast.ToUint(lastInsertId) + case *int: + *id = cast.ToInt(lastInsertId) + } + } + return nil +} + +func fillCurrentTime(value any, now time.Time) { + switch v := value.(type) { + case *time.Time: + if v.IsZero() { + *v = now + } + case *int64: + if *v == 0 { + *v = now.Unix() + } + case *uint64: + if *v == 0 { + *v = uint64(now.Unix()) + } + case *int32: + if *v == 0 { + *v = int32(now.Unix()) + } + case *uint32: + if *v == 0 { + *v = uint32(now.Unix()) + } + case *int: + if *v == 0 { + *v = int(now.Unix()) + } + case *string: + if *v == "" { + *v = now.Format(time.DateTime) + } + } +} diff --git a/insert_extra_test.go b/tools_test.go similarity index 96% rename from insert_extra_test.go rename to tools_test.go index 44cc891..f7dc995 100644 --- a/insert_extra_test.go +++ b/tools_test.go @@ -1,7 +1,6 @@ package lorm import ( - "context" "testing" "time" @@ -47,13 +46,6 @@ func TestFillCurrentTimeAllBranches(t *testing.T) { } } -func TestInsertAllEmpty(t *testing.T) { - var models []*Test - rows, err := InsertAll(context.TODO(), &Engine{config: &Config{}}, models) - assert.NoError(t, err) - assert.EqualValues(t, 0, rows) -} - type _pkInt64 struct { UnimplementedTable ID int64