diff --git a/dialects_test.go b/dialects_test.go index eafe357..9b0d481 100644 --- a/dialects_test.go +++ b/dialects_test.go @@ -619,6 +619,89 @@ func DoTestFind(t *testing.T, info dialectInfo) { } } +func TestFindOne(t *testing.T) { + for _, info := range toRun { + DoTestFindOne(t, info) + } +} + +func DoTestFindOne(t *testing.T, info dialectInfo) { + t.Logf("Dialect %T\n", info.dialect) + hd := info.setupDbFunc(t) + + type findOneModel struct { + Id Id + A string + B int + } + + model1 := findOneModel{ + A: "string!", + B: 2, + } + + hd.DropTable(&model1) + + tx := hd.Begin() + tx.CreateTable(&model1) + err := tx.Commit() + if err != nil { + t.Fatal("error not nil", err) + } + + // Test with no data (should be empty) + + // Test with struct + var out findOneModel + err = hd.Where("a", "=", "string!").And("b", "=", 2).FindOne(&out) + if err != nil { + t.Fatal("error not nil", err) + } + if out.Id != 0 || out.A != "" || out.B != 0 { + t.Fatal("struct should be zeroed", out) + } + + // Test with pointer to struct + var out2 *findOneModel + err = hd.Where("a", "=", "string!").And("b", "=", 2).FindOne(&out) + if err != nil { + t.Fatal("error not nil", err) + } + if out2 != nil { + t.Fatal("ptr should be nil", out) + } + + // Insert a row, test again: + id, err := hd.Save(&model1) + if err != nil { + t.Fatal("error not nil", err) + } + if id != 1 { + t.Fatal("wrong id", id) + } + + // Test with struct + err = hd.Where("a", "=", "string!").And("b", "=", 2).FindOne(&out) + if err != nil { + t.Fatal("error not nil", err) + } + if out.Id != 1 || out.A != "string!" || out.B != 2 { + t.Fatal("struct is not expected", out) + } + + // Test with pointer to struct + err = hd.Where("a", "=", "string!").And("b", "=", 2).FindOne(&out2) + if err != nil { + t.Fatal("error not nil", err) + } + if out2 == nil { + t.Fatal("ptr should not be nil", out2) + } + if out2.Id != 1 || out2.A != "string!" || out2.B != 2 { + t.Fatal("struct is not expected", out) + } +} + func TestCreateTable(t *testing.T) { for _, info := range toRun { DoTestCreateTable(t, info) diff --git a/hood.go b/hood.go index 149436d..963f414 100644 --- a/hood.go +++ b/hood.go @@ -692,6 +692,92 @@ func (hood *Hood) Find(out interface{}) error { query, args := hood.Dialect.QuerySql(hood) return hood.FindSql(out, query, args...) } +func (hood *Hood) FindOne(out interface{}) error { + if hood.selectTable == "" { + hood.Select(out) + } + hood.Limit(1) + query, args := hood.Dialect.QuerySql(hood) + return hood.FindOneSql(out, query, args...) +} +func (hood *Hood) FindOneSql(out interface{}, query string, args ...interface{}) error { + hood.mutex.Lock() + defer hood.mutex.Unlock() + defer hood.Reset() + + // Validate arg: it can either be a pointer to a struct, or a pointer to a pointer to a struct + // Must be pointer + if x := reflect.TypeOf(out).Kind(); x != reflect.Ptr { + panic("argument must be a pointer") + } + + outValue := reflect.Indirect(reflect.ValueOf(out)) + outKind := outValue.Kind() + var elementType reflect.Type + if outKind == reflect.Ptr { + // Must be a pointer to a pointer to a struct... + elementType = outValue.Type().Elem() + if elementType.Kind() != reflect.Struct { + panic("argument must be a pointer to a pointer to a struct") + } + + } else if outKind == reflect.Struct { + // Or just a pointer to a struct ... + elementType = outValue.Type() + + } else { + panic("argument must be a pointer to a struct, or a pointer to a pointer to a struct") + } + + hood.logSql(query, args...) + stmt, err := hood.qo.Prepare(query) + if err != nil { + return hood.updateTxError(err) + } + defer stmt.Close() + rows, err := stmt.Query(args...) + if err != nil { + return hood.updateTxError(err) + } + defer rows.Close() + cols, err := rows.Columns() + if err != nil { + return hood.updateTxError(err) + } + for rows.Next() { + containers := make([]interface{}, 0, len(cols)) + for i := 0; i < cap(containers); i++ { + var v interface{} + containers = append(containers, &v) + } + err := rows.Scan(containers...) + if err != nil { + return err + } + // create a new row and fill + rowValue := reflect.New(elementType) + for i, v := range containers { + key := cols[i] + value := reflect.Indirect(reflect.ValueOf(v)) + name := snakeToUpperCamel(key) + field := rowValue.Elem().FieldByName(name) + if field.IsValid() { + err = hood.Dialect.SetModelValue(value, field) + if err != nil { + return err + } + } + } + + if outKind == reflect.Struct { + outValue.Set(rowValue.Elem()) + } else { + reflect.ValueOf(out).Elem().Set(rowValue) + } + break + } + return nil +} // FindSql performs a find using the specified custom sql query and arguments and // writes the results to the specified out interface{}.