diff --git a/generator.go b/generator.go index b90bbc79..56fa82c2 100644 --- a/generator.go +++ b/generator.go @@ -27,6 +27,26 @@ import ( "gorm.io/gen/internal/utils/pools" ) +// Page pagination info +type Page struct { + Page int // current page + Limit int // limit size +} + +func (p Page) GetLimit() int { + if p.Limit < 1 { + return 10 + } + return p.Limit +} + +func (p Page) GetOffset() int { + if p.Page <= 1 { + return 0 + } + return (p.Page - 1) * p.GetLimit() +} + // T generic type type T interface{} diff --git a/helper/clause.go b/helper/clause.go index 4bd18094..3eff1af4 100644 --- a/helper/clause.go +++ b/helper/clause.go @@ -110,6 +110,32 @@ func setValue(value string) string { return strings.Trim(value, ", ") } +// JoinRecordBuilder join records builder +func JoinRecordBuilder(src *strings.Builder, selectValue, suffix strings.Builder) { + value1 := trimAll(selectValue.String()) + if value1 != "" { + src.WriteString("SELECT ") + src.WriteString(value1) + src.WriteString(" ") + } + value2 := trimAll(suffix.String()) + if value2 != "" { + src.WriteString(strings.Trim(value2, " ;")) + src.WriteString(" ") + src.WriteString("LIMIT ? OFFSET ?; ") + } +} + +// JoinCountBuilder join count builder +func JoinCountBuilder(src *strings.Builder, suffix strings.Builder) { + value := trimAll(suffix.String()) + if value != "" { + src.WriteString("SELECT COUNT(*) ") + src.WriteString(strings.Trim(value, " ;")) + src.WriteString("; ") + } +} + // JoinWhereBuilder join where builder func JoinWhereBuilder(src *strings.Builder, whereValue strings.Builder) { value := trimAll(whereValue.String()) diff --git a/internal/generate/clause.go b/internal/generate/clause.go index 9c5d6abf..070b590b 100644 --- a/internal/generate/clause.go +++ b/internal/generate/clause.go @@ -51,9 +51,62 @@ func (s SQLClause) Create() string { return fmt.Sprintf("%s.WriteString(%s)", s.VarName, s.String()) } -// Finish finish clause -func (s SQLClause) Finish() string { - return fmt.Sprintf("%s.WriteString(%s)", s.VarName, s.String()) +// Finishes finish clause +func (s SQLClause) Finishes(conds ...bool) []string { + var lines []string + if s.VarName == "generateSQL" && conds != nil && len(conds) > 0 && conds[0] { + if strings.Trim(s.String(), " ;\"") != "" { + lines = append(lines, fmt.Sprintf("recordSQL.WriteString(%s)", s.String())) + lines = append(lines, fmt.Sprintf("countSQL.WriteString(%s)", s.String())) + } + } else { + lines = append(lines, fmt.Sprintf("%s.WriteString(%s)", s.VarName, s.String())) + } + return lines +} + +// SelectClause select clause +type SelectClause struct { + clause + Value []Clause +} + +// String string clause +func (s SelectClause) String() string { + return fmt.Sprintf("helper.SelectTrim(%s.String())", s.VarName) +} + +// Create create clause +func (s SelectClause) Create() string { + return "" +} + +// Finishes finish clause +func (s SelectClause) Finishes() []string { + return nil +} + +// OrderByClause order by clause +type OrderByClause struct { + clause + Value []Clause +} + +// String string clause +func (s OrderByClause) String() string { + return fmt.Sprintf("helper.OrderByTrim(%s.String())", s.VarName) +} + +// Create create clause +func (s OrderByClause) Create() string { + return "helper.JoinCountBuilder(&countSQL, generateSQL)" +} + +// Finishes finish clause +func (s OrderByClause) Finishes() []string { + return []string{ + "helper.JoinRecordBuilder(&recordSQL, selectSQL, generateSQL)", + } } // IfClause if clause diff --git a/internal/generate/export.go b/internal/generate/export.go index 793a1059..76e8932c 100644 --- a/internal/generate/export.go +++ b/internal/generate/export.go @@ -207,6 +207,14 @@ func BuildDIYMethod(f *parser.InterfaceSet, s *QueryStructMeta, data []*Interfac err = fmt.Errorf("sql [%s] build err:%w", t.SQLString, err) return } + if !t.NeedPaginate && t.Section.ClauseTotal[model.SELECT] > 0 { + err = fmt.Errorf("sql [%s] check err:select block can only be used if the page parameter exists", t.SQLString) + return + } + if !t.NeedCount && t.Section.ClauseTotal[model.ORDERBY] > 0 { + err = fmt.Errorf("sql [%s] check err:order by block can only be used if the count result exists", t.SQLString) + return + } checkResults = append(checkResults, t) } } diff --git a/internal/generate/interface.go b/internal/generate/interface.go index f43bf56f..bb807d55 100644 --- a/internal/generate/interface.go +++ b/internal/generate/interface.go @@ -27,6 +27,8 @@ type InterfaceMethod struct { // feature will replace InterfaceMethod to parser. InterfaceName string // origin interface name Package string // interface package name HasForParams bool // + NeedPaginate bool // need paginate or not + NeedCount bool // need count or not } // FuncSign function signature @@ -127,7 +129,7 @@ func (m *InterfaceMethod) IsRepeatFromSameInterface(newMethod *InterfaceMethod) return m.MethodName == newMethod.MethodName && m.InterfaceName == newMethod.InterfaceName && m.TargetStruct == newMethod.TargetStruct } -//GetParamInTmpl return param list +// GetParamInTmpl return param list func (m *InterfaceMethod) GetParamInTmpl() string { return paramToString(m.Params) } @@ -193,6 +195,9 @@ func (m *InterfaceMethod) checkParams(params []parser.Param) (err error) { case param.IsGenT(): param.Type = m.OriginStruct.Type param.Package = m.OriginStruct.Package + case param.IsGenPage(): + param.SetName("page") + m.NeedPaginate = true // need paginate } paramList[i] = param } @@ -200,7 +205,7 @@ func (m *InterfaceMethod) checkParams(params []parser.Param) (err error) { return } -// checkResult check all parameters and replace gen.T by target structure. Parameters must be one of int/string/struct/map +// checkResult check all parameters and replace gen.T by target structure. Parameters must be one of int/int64/string/struct/map func (m *InterfaceMethod) checkResult(result []parser.Param) (err error) { resList := make([]parser.Param, len(result)) var hasError bool @@ -215,6 +220,12 @@ func (m *InterfaceMethod) checkResult(result []parser.Param) (err error) { switch { case param.InMainPkg(): return fmt.Errorf("query method cannot return struct of main package in [%s.%s]", m.InterfaceName, m.MethodName) + case m.NeedPaginate && !m.ResultData.IsNull() && param.IsCount(): + if m.NeedCount { + return fmt.Errorf("query method cannot return more than 1 count value in [%s.%s]", m.InterfaceName, m.MethodName) + } + param.SetName("count") + m.NeedCount = true case param.IsError(): if hasError { return fmt.Errorf("query method cannot return more than 1 error value in [%s.%s]", m.InterfaceName, m.MethodName) diff --git a/internal/generate/section.go b/internal/generate/section.go index 07276422..6eb38750 100644 --- a/internal/generate/section.go +++ b/internal/generate/section.go @@ -56,8 +56,8 @@ func (s *Section) current() section { return s.members[s.currentIndex] } -func (s *Section) appendTmpl(value string) { - s.Tmpls = append(s.Tmpls, value) +func (s *Section) appendTmpl(value ...string) { + s.Tmpls = append(s.Tmpls, value...) } func (s *Section) hasSameName(value string) bool { @@ -76,13 +76,29 @@ func (s *Section) BuildSQL() ([]Clause, error) { } name := "generateSQL" res := make([]Clause, 0, len(s.members)) + ordWrite := false for { c := s.current() switch c.Type { case model.SQL, model.DATA, model.VARIABLE: sqlClause := s.parseSQL(name) res = append(res, sqlClause) - s.appendTmpl(sqlClause.Finish()) + s.appendTmpl(sqlClause.Finishes(ordWrite)...) + case model.SELECT: + selectClause, err := s.parseSelect() + if err != nil { + return nil, err + } + res = append(res, selectClause) + s.appendTmpl(selectClause.Finishes()...) + case model.ORDERBY: + ordWrite = true + orderByClause, err := s.parseOrderBy() + if err != nil { + return nil, err + } + res = append(res, orderByClause) + s.appendTmpl(orderByClause.Finishes()...) case model.IF: ifClause, err := s.parseIF(name) if err != nil { @@ -131,6 +147,78 @@ func (s *Section) BuildSQL() ([]Clause, error) { return res, nil } +// parseSelect parse select clause +func (s *Section) parseSelect() (res SelectClause, err error) { + c := s.current() + s.current() + res.VarName = s.GetName(c.Type) + s.appendTmpl(res.Create()) + res.Type = c.Type + + if !s.HasMore() { + return + } + c = s.next() + for { + switch c.Type { + case model.SQL: + sqlClause := s.parseSQL(res.VarName) + res.Value = append(res.Value, sqlClause) + s.appendTmpl(sqlClause.Finishes()...) + case model.END: + return + default: + err = fmt.Errorf("unknow clause : %s", c.Value) + return + } + if !s.HasMore() { + break + } + c = s.next() + } + if c.isEnd() { + return + } + err = fmt.Errorf("incomplete SQL,select not end") + return +} + +// parseOrderBy parse order by clause +func (s *Section) parseOrderBy() (res OrderByClause, err error) { + c := s.current() + s.current() + res.VarName = s.GetName(c.Type) + s.appendTmpl(res.Create()) + res.Type = c.Type + + if !s.HasMore() { + return + } + c = s.next() + for { + switch c.Type { + case model.SQL: + sqlClause := s.parseSQL(res.VarName) + res.Value = append(res.Value, sqlClause) + s.appendTmpl(sqlClause.Finishes()...) + case model.END: + return + default: + err = fmt.Errorf("unknow clause : %s", c.Value) + return + } + if !s.HasMore() { + break + } + c = s.next() + } + if c.isEnd() { + return + } + err = fmt.Errorf("incomplete SQL,order by not end") + return +} + // parseIF parse if clause func (s *Section) parseIF(name string) (res IfClause, err error) { c := s.current() @@ -146,7 +234,7 @@ func (s *Section) parseIF(name string) (res IfClause, err error) { case model.SQL, model.DATA, model.VARIABLE: sqlClause := s.parseSQL(name) res.Value = append(res.Value, sqlClause) - s.appendTmpl(sqlClause.Finish()) + s.appendTmpl(sqlClause.Finishes()...) case model.IF: var ifClause IfClause ifClause, err = s.parseIF(name) @@ -301,7 +389,7 @@ func (s *Section) parseWhere() (res WhereClause, err error) { case model.SQL, model.DATA, model.VARIABLE: sqlClause := s.parseSQL(res.VarName) res.Value = append(res.Value, sqlClause) - s.appendTmpl(sqlClause.Finish()) + s.appendTmpl(sqlClause.Finishes()...) case model.IF: var ifClause IfClause ifClause, err = s.parseIF(res.VarName) @@ -368,7 +456,7 @@ func (s *Section) parseSet() (res SetClause, err error) { case model.SQL, model.DATA, model.VARIABLE: sqlClause := s.parseSQL(res.VarName) res.Value = append(res.Value, sqlClause) - s.appendTmpl(sqlClause.Finish()) + s.appendTmpl(sqlClause.Finishes()...) case model.IF: var ifClause IfClause ifClause, err = s.parseIF(res.VarName) @@ -434,7 +522,7 @@ func (s *Section) parseTrim() (res TrimClause, err error) { case model.SQL, model.DATA, model.VARIABLE: sqlClause := s.parseSQL(res.VarName) res.Value = append(res.Value, sqlClause) - s.appendTmpl(sqlClause.Finish()) + s.appendTmpl(sqlClause.Finishes()...) case model.IF: var ifClause IfClause ifClause, err = s.parseIF(res.VarName) @@ -593,6 +681,12 @@ func (s *Section) GetName(status model.Status) string { case model.TRIM: defer func() { s.ClauseTotal[model.TRIM]++ }() return fmt.Sprintf("trimSQL%d", s.ClauseTotal[model.TRIM]) + case model.SELECT: + defer func() { s.ClauseTotal[model.SELECT]++ }() + return "selectSQL" + case model.ORDERBY: + defer func() { s.ClauseTotal[model.ORDERBY]++ }() + return "generateSQL" default: return "generateSQL" } @@ -677,6 +771,10 @@ func (s *section) sectionType(str string) error { s.Type = model.END case "trim": s.Type = model.TRIM + case "select": + s.Type = model.SELECT + case "orderby": + s.Type = model.ORDERBY default: return fmt.Errorf("unknown syntax: %s", str) } diff --git a/internal/model/base.go b/internal/model/base.go index e5cdcfd3..ba14bdd8 100644 --- a/internal/model/base.go +++ b/internal/model/base.go @@ -38,6 +38,10 @@ const ( END // TRIM ... TRIM + // SELECT ... + SELECT + // ORDERBY ... + ORDERBY ) // SourceCode source code diff --git a/internal/parser/parser.go b/internal/parser/parser.go index 5c4cdda0..fc96ab76 100644 --- a/internal/parser/parser.go +++ b/internal/parser/parser.go @@ -141,6 +141,16 @@ func (p *Param) IsError() bool { return p.Type == "error" } +// IsGenPage ... +func (p *Param) IsGenPage() bool { + return p.Package == "gen" && p.Type == "Page" && !p.IsPointer +} + +// IsCount ... +func (p *Param) IsCount() bool { + return p.Package == "" && p.Type == "int64" && !p.IsPointer +} + // IsGenM ... func (p *Param) IsGenM() bool { return p.Package == "gen" && p.Type == "M" @@ -273,8 +283,10 @@ func (p *Param) astGetParamType(param *ast.Field) { p.astGetEltType(v.X) case *ast.IndexExpr: p.astGetEltType(v.X) + p.astGetGenericType(v.Index) case *ast.IndexListExpr: p.astGetEltType(v.X) + p.astGetGenericType(v.Indices...) default: log.Printf("Unsupported param type: %+v", v) } @@ -302,11 +314,33 @@ func (p *Param) astGetEltType(expr ast.Expr) { p.Type = "[]" + p.Type case *ast.IndexExpr: p.astGetEltType(v.X) + p.astGetGenericType(v.Index) + case *ast.IndexListExpr: + p.astGetEltType(v.X) + p.astGetGenericType(v.Indices...) default: log.Printf("Unsupported param type: %+v", v) } } +func (p *Param) astGetGenericType(exprList ...ast.Expr) { + if p.Package == "" { + p.Package = "UNDEFINED" // Generic types are definitely not built-in types. + } + if len(exprList) == 0 { + return + } + var types []string + for _, expr := range exprList { + typeStr := astGetType(expr) + if typeStr == "" { + typeStr = "interface{}" // fallback for unsupported types + } + types = append(types, typeStr) + } + p.Type = fmt.Sprintf("%s[%s]", p.Type, strings.Join(types, ", ")) +} + func (p *Param) astGetPackageName(expr ast.Expr) { switch v := expr.(type) { case *ast.Ident: @@ -324,6 +358,21 @@ func astGetType(expr ast.Expr) string { return v.Name case *ast.InterfaceType: return "interface{}" + case *ast.SelectorExpr: + typ := v.Sel.Name + pkg := astGetPackageName(v.X) + if pkg != "" { + return fmt.Sprintf("%s.%s", pkg, typ) + } + return typ + } + return "" +} + +func astGetPackageName(expr ast.Expr) string { + switch v := expr.(type) { + case *ast.Ident: + return v.Name } return "" } diff --git a/internal/template/method.go b/internal/template/method.go index 5263166e..299e6a12 100644 --- a/internal/template/method.go +++ b/internal/template/method.go @@ -5,23 +5,44 @@ const DIYMethod = ` // {{.DocComment }} func ({{.S}} {{.TargetStruct}}Do){{.FuncSign}}{ - {{if .HasSQLData}}var params []interface{} + {{- $needParams := or .HasSQLData .NeedPaginate }} + {{- $execSqlName := "generateSQL"}} + {{- $needRecordError := or .NeedCount .ReturnError}} + {{if $needParams}}var params []interface{} {{end}}var generateSQL strings.Builder + {{if .NeedPaginate}}{{$execSqlName = "recordSQL"}}var recordSQL, selectSQL strings.Builder{{end}} + {{if .NeedCount}}var countSQL strings.Builder{{end}} {{range $line:=.Section.Tmpls}}{{$line}} {{end}} - + {{if .NeedPaginate}} + {{if not .NeedCount}}helper.JoinRecordBuilder(&recordSQL, selectSQL, generateSQL){{end}} + params = append(params, page.GetLimit(), page.GetOffset()) + {{end}} {{if .HasNeedNewResult}}result ={{if .ResultData.IsMap}}make{{else}}new{{end}}({{if ne .ResultData.Package ""}}{{.ResultData.Package}}.{{end}}{{.ResultData.Type}}){{end}} {{if .ReturnSQLResult}}stmt := {{.S}}.UnderlyingDB().Statement - result,{{if .ReturnError}}err{{else}}_{{end}} = stmt.ConnPool.ExecContext(stmt.Context,generateSQL.String(){{if .HasSQLData}},params...{{end}}) // ignore_security_alert - {{else if .ReturnSQLRow}}row = {{.S}}.UnderlyingDB().Raw(generateSQL.String(){{if .HasSQLData}},params...{{end}}).Row() // ignore_security_alert - {{else if .ReturnSQLRows}}rows,{{if .ReturnError}}err{{else}}_{{end}} = {{.S}}.UnderlyingDB().Raw(generateSQL.String(){{if .HasSQLData}},params...{{end}}).Rows() // ignore_security_alert + result,{{if $needRecordError}}err{{else}}_{{end}} = stmt.ConnPool.ExecContext(stmt.Context,{{$execSqlName}}.String(){{if $needParams}},params...{{end}}) // ignore_security_alert + {{else if .ReturnSQLRow}}row = {{.S}}.UnderlyingDB().Raw({{$execSqlName}}.String(){{if $needParams}},params...{{end}}).Row() // ignore_security_alert + {{else if .ReturnSQLRows}}rows,{{if $needRecordError}}err{{else}}_{{end}} = {{.S}}.UnderlyingDB().Raw({{$execSqlName}}.String(){{if $needParams}},params...{{end}}).Rows() // ignore_security_alert {{else}}var executeSQL *gorm.DB - executeSQL = {{.S}}.UnderlyingDB().{{.GormOption}}(generateSQL.String(){{if .HasSQLData}},params...{{end}}){{if not .ResultData.IsNull}}.{{.GormRunMethodName}}({{if .HasGotPoint}}&{{end}}{{.ResultData.Name}}){{end}} // ignore_security_alert + executeSQL = {{.S}}.UnderlyingDB().{{.GormOption}}({{$execSqlName}}.String(){{if $needParams}},params...{{end}}){{if not .ResultData.IsNull}}.{{.GormRunMethodName}}({{if .HasGotPoint}}&{{end}}{{.ResultData.Name}}){{end}} // ignore_security_alert {{if .ReturnRowsAffected}}rowsAffected = executeSQL.RowsAffected - {{end}}{{if .ReturnError}}err = executeSQL.Error + {{end}}{{if $needRecordError}}err = executeSQL.Error {{end}}{{if .ReturnNothing}}_ = executeSQL {{end}}{{end}} + {{- if .NeedCount}}if err != nil { + return + } + + if size := len({{.ResultData.Name}}); 0 < page.GetLimit() && 0 < size && size < page.GetLimit() { + count = int64(size+page.GetOffset()) + return + } + + executeSQL = {{.S}}.UnderlyingDB().{{.GormOption}}(countSQL.String(){{if $needParams}},params[:len(params)-2]...{{end}}).Take(&count) // ignore_security_alert + {{if .ReturnError}}err = executeSQL.Error + {{else}}_ = executeSQL{{end}} + {{end}} return }