Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,26 @@ import (
"gorm.io/gen/internal/utils/pools"
)

// Page pagination info
type Page struct {
Page int // current page
Limit int // limit size
}

func (p Page) GetLimit() int {
if p.Limit < 1 {
return 10
}
return p.Limit
}

func (p Page) GetOffset() int {
if p.Page <= 1 {
return 0
}
return (p.Page - 1) * p.GetLimit()
}

// T generic type
type T interface{}

Expand Down
26 changes: 26 additions & 0 deletions helper/clause.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,32 @@ func setValue(value string) string {
return strings.Trim(value, ", ")
}

// JoinRecordBuilder join records builder
func JoinRecordBuilder(src *strings.Builder, selectValue, suffix strings.Builder) {
value1 := trimAll(selectValue.String())
if value1 != "" {
src.WriteString("SELECT ")
src.WriteString(value1)
src.WriteString(" ")
}
value2 := trimAll(suffix.String())
if value2 != "" {
src.WriteString(strings.Trim(value2, " ;"))
src.WriteString(" ")
src.WriteString("LIMIT ? OFFSET ?; ")
}
}

// JoinCountBuilder join count builder
func JoinCountBuilder(src *strings.Builder, suffix strings.Builder) {
value := trimAll(suffix.String())
if value != "" {
src.WriteString("SELECT COUNT(*) ")
src.WriteString(strings.Trim(value, " ;"))
src.WriteString("; ")
}
}

// JoinWhereBuilder join where builder
func JoinWhereBuilder(src *strings.Builder, whereValue strings.Builder) {
value := trimAll(whereValue.String())
Expand Down
59 changes: 56 additions & 3 deletions internal/generate/clause.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,62 @@ func (s SQLClause) Create() string {
return fmt.Sprintf("%s.WriteString(%s)", s.VarName, s.String())
}

// Finish finish clause
func (s SQLClause) Finish() string {
return fmt.Sprintf("%s.WriteString(%s)", s.VarName, s.String())
// Finishes finish clause
func (s SQLClause) Finishes(conds ...bool) []string {
var lines []string
if s.VarName == "generateSQL" && conds != nil && len(conds) > 0 && conds[0] {
if strings.Trim(s.String(), " ;\"") != "" {
lines = append(lines, fmt.Sprintf("recordSQL.WriteString(%s)", s.String()))
lines = append(lines, fmt.Sprintf("countSQL.WriteString(%s)", s.String()))
}
} else {
lines = append(lines, fmt.Sprintf("%s.WriteString(%s)", s.VarName, s.String()))
}
return lines
}

// SelectClause select clause
type SelectClause struct {
clause
Value []Clause
}

// String string clause
func (s SelectClause) String() string {
return fmt.Sprintf("helper.SelectTrim(%s.String())", s.VarName)
}

// Create create clause
func (s SelectClause) Create() string {
return ""
}

// Finishes finish clause
func (s SelectClause) Finishes() []string {
return nil
}

// OrderByClause order by clause
type OrderByClause struct {
clause
Value []Clause
}

// String string clause
func (s OrderByClause) String() string {
return fmt.Sprintf("helper.OrderByTrim(%s.String())", s.VarName)
}

// Create create clause
func (s OrderByClause) Create() string {
return "helper.JoinCountBuilder(&countSQL, generateSQL)"
}

// Finishes finish clause
func (s OrderByClause) Finishes() []string {
return []string{
"helper.JoinRecordBuilder(&recordSQL, selectSQL, generateSQL)",
}
}

// IfClause if clause
Expand Down
8 changes: 8 additions & 0 deletions internal/generate/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,14 @@ func BuildDIYMethod(f *parser.InterfaceSet, s *QueryStructMeta, data []*Interfac
err = fmt.Errorf("sql [%s] build err:%w", t.SQLString, err)
return
}
if !t.NeedPaginate && t.Section.ClauseTotal[model.SELECT] > 0 {
err = fmt.Errorf("sql [%s] check err:select block can only be used if the page parameter exists", t.SQLString)
return
}
if !t.NeedCount && t.Section.ClauseTotal[model.ORDERBY] > 0 {
err = fmt.Errorf("sql [%s] check err:order by block can only be used if the count result exists", t.SQLString)
return
}
checkResults = append(checkResults, t)
}
}
Expand Down
15 changes: 13 additions & 2 deletions internal/generate/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ type InterfaceMethod struct { // feature will replace InterfaceMethod to parser.
InterfaceName string // origin interface name
Package string // interface package name
HasForParams bool //
NeedPaginate bool // need paginate or not
NeedCount bool // need count or not
}

// FuncSign function signature
Expand Down Expand Up @@ -127,7 +129,7 @@ func (m *InterfaceMethod) IsRepeatFromSameInterface(newMethod *InterfaceMethod)
return m.MethodName == newMethod.MethodName && m.InterfaceName == newMethod.InterfaceName && m.TargetStruct == newMethod.TargetStruct
}

//GetParamInTmpl return param list
// GetParamInTmpl return param list
func (m *InterfaceMethod) GetParamInTmpl() string {
return paramToString(m.Params)
}
Expand Down Expand Up @@ -193,14 +195,17 @@ func (m *InterfaceMethod) checkParams(params []parser.Param) (err error) {
case param.IsGenT():
param.Type = m.OriginStruct.Type
param.Package = m.OriginStruct.Package
case param.IsGenPage():
param.SetName("page")
m.NeedPaginate = true // need paginate
}
paramList[i] = param
}
m.Params = paramList
return
}

// checkResult check all parameters and replace gen.T by target structure. Parameters must be one of int/string/struct/map
// checkResult check all parameters and replace gen.T by target structure. Parameters must be one of int/int64/string/struct/map
func (m *InterfaceMethod) checkResult(result []parser.Param) (err error) {
resList := make([]parser.Param, len(result))
var hasError bool
Expand All @@ -215,6 +220,12 @@ func (m *InterfaceMethod) checkResult(result []parser.Param) (err error) {
switch {
case param.InMainPkg():
return fmt.Errorf("query method cannot return struct of main package in [%s.%s]", m.InterfaceName, m.MethodName)
case m.NeedPaginate && !m.ResultData.IsNull() && param.IsCount():
if m.NeedCount {
return fmt.Errorf("query method cannot return more than 1 count value in [%s.%s]", m.InterfaceName, m.MethodName)
}
param.SetName("count")
m.NeedCount = true
case param.IsError():
if hasError {
return fmt.Errorf("query method cannot return more than 1 error value in [%s.%s]", m.InterfaceName, m.MethodName)
Expand Down
112 changes: 105 additions & 7 deletions internal/generate/section.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ func (s *Section) current() section {
return s.members[s.currentIndex]
}

func (s *Section) appendTmpl(value string) {
s.Tmpls = append(s.Tmpls, value)
func (s *Section) appendTmpl(value ...string) {
s.Tmpls = append(s.Tmpls, value...)
}

func (s *Section) hasSameName(value string) bool {
Expand All @@ -76,13 +76,29 @@ func (s *Section) BuildSQL() ([]Clause, error) {
}
name := "generateSQL"
res := make([]Clause, 0, len(s.members))
ordWrite := false
for {
c := s.current()
switch c.Type {
case model.SQL, model.DATA, model.VARIABLE:
sqlClause := s.parseSQL(name)
res = append(res, sqlClause)
s.appendTmpl(sqlClause.Finish())
s.appendTmpl(sqlClause.Finishes(ordWrite)...)
case model.SELECT:
selectClause, err := s.parseSelect()
if err != nil {
return nil, err
}
res = append(res, selectClause)
s.appendTmpl(selectClause.Finishes()...)
case model.ORDERBY:
ordWrite = true
orderByClause, err := s.parseOrderBy()
if err != nil {
return nil, err
}
res = append(res, orderByClause)
s.appendTmpl(orderByClause.Finishes()...)
case model.IF:
ifClause, err := s.parseIF(name)
if err != nil {
Expand Down Expand Up @@ -131,6 +147,78 @@ func (s *Section) BuildSQL() ([]Clause, error) {
return res, nil
}

// parseSelect parse select clause
func (s *Section) parseSelect() (res SelectClause, err error) {
c := s.current()
s.current()
res.VarName = s.GetName(c.Type)
s.appendTmpl(res.Create())
res.Type = c.Type

if !s.HasMore() {
return
}
c = s.next()
for {
switch c.Type {
case model.SQL:
sqlClause := s.parseSQL(res.VarName)
res.Value = append(res.Value, sqlClause)
s.appendTmpl(sqlClause.Finishes()...)
case model.END:
return
default:
err = fmt.Errorf("unknow clause : %s", c.Value)
return
}
if !s.HasMore() {
break
}
c = s.next()
}
if c.isEnd() {
return
}
err = fmt.Errorf("incomplete SQL,select not end")
return
}

// parseOrderBy parse order by clause
func (s *Section) parseOrderBy() (res OrderByClause, err error) {
c := s.current()
s.current()
res.VarName = s.GetName(c.Type)
s.appendTmpl(res.Create())
res.Type = c.Type

if !s.HasMore() {
return
}
c = s.next()
for {
switch c.Type {
case model.SQL:
sqlClause := s.parseSQL(res.VarName)
res.Value = append(res.Value, sqlClause)
s.appendTmpl(sqlClause.Finishes()...)
case model.END:
return
default:
err = fmt.Errorf("unknow clause : %s", c.Value)
return
}
if !s.HasMore() {
break
}
c = s.next()
}
if c.isEnd() {
return
}
err = fmt.Errorf("incomplete SQL,order by not end")
return
}

// parseIF parse if clause
func (s *Section) parseIF(name string) (res IfClause, err error) {
c := s.current()
Expand All @@ -146,7 +234,7 @@ func (s *Section) parseIF(name string) (res IfClause, err error) {
case model.SQL, model.DATA, model.VARIABLE:
sqlClause := s.parseSQL(name)
res.Value = append(res.Value, sqlClause)
s.appendTmpl(sqlClause.Finish())
s.appendTmpl(sqlClause.Finishes()...)
case model.IF:
var ifClause IfClause
ifClause, err = s.parseIF(name)
Expand Down Expand Up @@ -301,7 +389,7 @@ func (s *Section) parseWhere() (res WhereClause, err error) {
case model.SQL, model.DATA, model.VARIABLE:
sqlClause := s.parseSQL(res.VarName)
res.Value = append(res.Value, sqlClause)
s.appendTmpl(sqlClause.Finish())
s.appendTmpl(sqlClause.Finishes()...)
case model.IF:
var ifClause IfClause
ifClause, err = s.parseIF(res.VarName)
Expand Down Expand Up @@ -368,7 +456,7 @@ func (s *Section) parseSet() (res SetClause, err error) {
case model.SQL, model.DATA, model.VARIABLE:
sqlClause := s.parseSQL(res.VarName)
res.Value = append(res.Value, sqlClause)
s.appendTmpl(sqlClause.Finish())
s.appendTmpl(sqlClause.Finishes()...)
case model.IF:
var ifClause IfClause
ifClause, err = s.parseIF(res.VarName)
Expand Down Expand Up @@ -434,7 +522,7 @@ func (s *Section) parseTrim() (res TrimClause, err error) {
case model.SQL, model.DATA, model.VARIABLE:
sqlClause := s.parseSQL(res.VarName)
res.Value = append(res.Value, sqlClause)
s.appendTmpl(sqlClause.Finish())
s.appendTmpl(sqlClause.Finishes()...)
case model.IF:
var ifClause IfClause
ifClause, err = s.parseIF(res.VarName)
Expand Down Expand Up @@ -593,6 +681,12 @@ func (s *Section) GetName(status model.Status) string {
case model.TRIM:
defer func() { s.ClauseTotal[model.TRIM]++ }()
return fmt.Sprintf("trimSQL%d", s.ClauseTotal[model.TRIM])
case model.SELECT:
defer func() { s.ClauseTotal[model.SELECT]++ }()
return "selectSQL"
case model.ORDERBY:
defer func() { s.ClauseTotal[model.ORDERBY]++ }()
return "generateSQL"
default:
return "generateSQL"
}
Expand Down Expand Up @@ -677,6 +771,10 @@ func (s *section) sectionType(str string) error {
s.Type = model.END
case "trim":
s.Type = model.TRIM
case "select":
s.Type = model.SELECT
case "orderby":
s.Type = model.ORDERBY
default:
return fmt.Errorf("unknown syntax: %s", str)
}
Expand Down
4 changes: 4 additions & 0 deletions internal/model/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ const (
END
// TRIM ...
TRIM
// SELECT ...
SELECT
// ORDERBY ...
ORDERBY
)

// SourceCode source code
Expand Down
Loading
Loading