Skip to content

Commit c44405a

Browse files
authored
Implement Generics API (#7424)
* Implement Generics API * Add more generics tests * Add more tests and Take method * use delayed‑ops pipeline for generics API * fix generics tests for mysql * Support SubQuery for Generics * Add clause.JoinTable helper method * Fix golangci-lint error * Complete the design and implementation of generic version Join * improve generics version Joins support * allow configuring select/omit columns for joins via subqueries * finish generic version Preload * handle error of generics Joins/Preload * fix tests * Add LimitPerRecord for generic version Preload * fix tests for mysql 5.7 * test for nested generic version Join/Preload * Add WithResult support for generics API * test reuse generics db conditions * fix data race * remove ExampleLRU test * Add default transaction timeout support * fix test
1 parent 751c1d6 commit c44405a

20 files changed

+1644
-92
lines changed

callbacks/create.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,10 @@ func Create(config *Config) func(db *gorm.DB) {
8989
db.AddError(rows.Close())
9090
}()
9191
gorm.Scan(rows, db, mode)
92+
93+
if db.Statement.Result != nil {
94+
db.Statement.Result.RowsAffected = db.RowsAffected
95+
}
9296
}
9397

9498
return
@@ -103,6 +107,12 @@ func Create(config *Config) func(db *gorm.DB) {
103107
}
104108

105109
db.RowsAffected, _ = result.RowsAffected()
110+
111+
if db.Statement.Result != nil {
112+
db.Statement.Result.Result = result
113+
db.Statement.Result.RowsAffected = db.RowsAffected
114+
}
115+
106116
if db.RowsAffected == 0 {
107117
return
108118
}

callbacks/delete.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,15 +157,25 @@ func Delete(config *Config) func(db *gorm.DB) {
157157
ok, mode := hasReturning(db, supportReturning)
158158
if !ok {
159159
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
160+
160161
if db.AddError(err) == nil {
161162
db.RowsAffected, _ = result.RowsAffected()
163+
164+
if db.Statement.Result != nil {
165+
db.Statement.Result.Result = result
166+
db.Statement.Result.RowsAffected = db.RowsAffected
167+
}
162168
}
163169

164170
return
165171
}
166172

167173
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
168174
gorm.Scan(rows, db, mode)
175+
176+
if db.Statement.Result != nil {
177+
db.Statement.Result.RowsAffected = db.RowsAffected
178+
}
169179
db.AddError(rows.Close())
170180
}
171181
}

callbacks/preload.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,8 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
275275
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)
276276

277277
if len(values) != 0 {
278+
tx = tx.Model(reflectResults.Addr().Interface()).Where(clause.IN{Column: column, Values: values})
279+
278280
for _, cond := range conds {
279281
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
280282
tx = fc(tx)
@@ -283,7 +285,11 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
283285
}
284286
}
285287

286-
if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil {
288+
if len(inlineConds) > 0 {
289+
tx = tx.Where(inlineConds[0], inlineConds[1:]...)
290+
}
291+
292+
if err := tx.Find(reflectResults.Addr().Interface()).Error; err != nil {
287293
return err
288294
}
289295
}

callbacks/query.go

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ func Query(db *gorm.DB) {
2525
db.AddError(rows.Close())
2626
}()
2727
gorm.Scan(rows, db, 0)
28+
29+
if db.Statement.Result != nil {
30+
db.Statement.Result.RowsAffected = db.RowsAffected
31+
}
2832
}
2933
}
3034
}
@@ -110,7 +114,7 @@ func BuildQuerySQL(db *gorm.DB) {
110114
}
111115
}
112116

113-
specifiedRelationsName := make(map[string]interface{})
117+
specifiedRelationsName := map[string]string{clause.CurrentTable: clause.CurrentTable}
114118
for _, join := range db.Statement.Joins {
115119
if db.Statement.Schema != nil {
116120
var isRelations bool // is relations or raw sql
@@ -124,12 +128,12 @@ func BuildQuerySQL(db *gorm.DB) {
124128
nestedJoinNames := strings.Split(join.Name, ".")
125129
if len(nestedJoinNames) > 1 {
126130
isNestedJoin := true
127-
gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
131+
guessNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
128132
currentRelations := db.Statement.Schema.Relationships.Relations
129133
for _, relname := range nestedJoinNames {
130134
// incomplete match, only treated as raw sql
131135
if relation, ok = currentRelations[relname]; ok {
132-
gussNestedRelations = append(gussNestedRelations, relation)
136+
guessNestedRelations = append(guessNestedRelations, relation)
133137
currentRelations = relation.FieldSchema.Relationships.Relations
134138
} else {
135139
isNestedJoin = false
@@ -139,18 +143,13 @@ func BuildQuerySQL(db *gorm.DB) {
139143

140144
if isNestedJoin {
141145
isRelations = true
142-
relations = gussNestedRelations
146+
relations = guessNestedRelations
143147
}
144148
}
145149
}
146150

147151
if isRelations {
148-
genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join {
149-
tableAliasName := relation.Name
150-
if parentTableName != clause.CurrentTable {
151-
tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
152-
}
153-
152+
genJoinClause := func(joinType clause.JoinType, tableAliasName string, parentTableName string, relation *schema.Relationship) clause.Join {
154153
columnStmt := gorm.Statement{
155154
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
156155
Selects: join.Selects, Omits: join.Omits,
@@ -167,6 +166,13 @@ func BuildQuerySQL(db *gorm.DB) {
167166
}
168167
}
169168

169+
if join.Expression != nil {
170+
return clause.Join{
171+
Type: join.JoinType,
172+
Expression: join.Expression,
173+
}
174+
}
175+
170176
exprs := make([]clause.Expression, len(relation.References))
171177
for idx, ref := range relation.References {
172178
if ref.OwnPrimaryKey {
@@ -226,19 +232,24 @@ func BuildQuerySQL(db *gorm.DB) {
226232
}
227233

228234
parentTableName := clause.CurrentTable
229-
for _, rel := range relations {
235+
for idx, rel := range relations {
230236
// joins table alias like "Manager, Company, Manager__Company"
231-
nestedAlias := utils.NestedRelationName(parentTableName, rel.Name)
232-
if _, ok := specifiedRelationsName[nestedAlias]; !ok {
233-
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel))
234-
specifiedRelationsName[nestedAlias] = nil
237+
curAliasName := rel.Name
238+
if parentTableName != clause.CurrentTable {
239+
curAliasName = utils.NestedRelationName(parentTableName, curAliasName)
235240
}
236241

237-
if parentTableName != clause.CurrentTable {
238-
parentTableName = utils.NestedRelationName(parentTableName, rel.Name)
239-
} else {
240-
parentTableName = rel.Name
242+
if _, ok := specifiedRelationsName[curAliasName]; !ok {
243+
aliasName := curAliasName
244+
if idx == len(relations)-1 && join.Alias != "" {
245+
aliasName = join.Alias
246+
}
247+
248+
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, aliasName, specifiedRelationsName[parentTableName], rel))
249+
specifiedRelationsName[curAliasName] = aliasName
241250
}
251+
252+
parentTableName = curAliasName
242253
}
243254
} else {
244255
fromClause.Joins = append(fromClause.Joins, clause.Join{

callbacks/raw.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,10 @@ func RawExec(db *gorm.DB) {
1313
}
1414

1515
db.RowsAffected, _ = result.RowsAffected()
16+
17+
if db.Statement.Result != nil {
18+
db.Statement.Result.Result = result
19+
db.Statement.Result.RowsAffected = db.RowsAffected
20+
}
1621
}
1722
}

callbacks/update.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,22 @@ func Update(config *Config) func(db *gorm.DB) {
9292
gorm.Scan(rows, db, mode)
9393
db.Statement.Dest = dest
9494
db.AddError(rows.Close())
95+
96+
if db.Statement.Result != nil {
97+
db.Statement.Result.RowsAffected = db.RowsAffected
98+
}
9599
}
96100
} else {
97101
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
98102

99103
if db.AddError(err) == nil {
100104
db.RowsAffected, _ = result.RowsAffected()
101105
}
106+
107+
if db.Statement.Result != nil {
108+
db.Statement.Result.Result = result
109+
db.Statement.Result.RowsAffected = db.RowsAffected
110+
}
102111
}
103112
}
104113
}

chainable_api.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -448,9 +448,10 @@ func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
448448
// Unscoped allows queries to include records marked as deleted,
449449
// overriding the soft deletion behavior.
450450
// Example:
451-
// var users []User
452-
// db.Unscoped().Find(&users)
453-
// // Retrieves all users, including deleted ones.
451+
//
452+
// var users []User
453+
// db.Unscoped().Find(&users)
454+
// // Retrieves all users, including deleted ones.
454455
func (db *DB) Unscoped() (tx *DB) {
455456
tx = db.getInstance()
456457
tx.Statement.Unscoped = true

clause/joins.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package clause
22

3+
import "gorm.io/gorm/utils"
4+
35
type JoinType string
46

57
const (
@@ -9,6 +11,30 @@ const (
911
RightJoin JoinType = "RIGHT"
1012
)
1113

14+
type JoinTarget struct {
15+
Type JoinType
16+
Association string
17+
Subquery Expression
18+
Table string
19+
}
20+
21+
func Has(name string) JoinTarget {
22+
return JoinTarget{Type: InnerJoin, Association: name}
23+
}
24+
25+
func (jt JoinType) Association(name string) JoinTarget {
26+
return JoinTarget{Type: jt, Association: name}
27+
}
28+
29+
func (jt JoinType) AssociationFrom(name string, subquery Expression) JoinTarget {
30+
return JoinTarget{Type: jt, Association: name, Subquery: subquery}
31+
}
32+
33+
func (jt JoinTarget) As(name string) JoinTarget {
34+
jt.Table = name
35+
return jt
36+
}
37+
1238
// Join clause for from
1339
type Join struct {
1440
Type JoinType
@@ -18,6 +44,12 @@ type Join struct {
1844
Expression Expression
1945
}
2046

47+
func JoinTable(names ...string) Table {
48+
return Table{
49+
Name: utils.JoinNestedRelationNames(names),
50+
}
51+
}
52+
2153
func (join Join) Build(builder Builder) {
2254
if join.Expression != nil {
2355
join.Expression.Build(builder)

finisher_api.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package gorm
22

33
import (
4+
"context"
45
"database/sql"
56
"errors"
67
"fmt"
@@ -673,11 +674,18 @@ func (db *DB) Begin(opts ...*sql.TxOptions) *DB {
673674
opt = opts[0]
674675
}
675676

677+
ctx := tx.Statement.Context
678+
if _, ok := ctx.Deadline(); !ok {
679+
if db.Config.DefaultTransactionTimeout > 0 {
680+
ctx, _ = context.WithTimeout(ctx, db.Config.DefaultTransactionTimeout)
681+
}
682+
}
683+
676684
switch beginner := tx.Statement.ConnPool.(type) {
677685
case TxBeginner:
678-
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
686+
tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
679687
case ConnPoolBeginner:
680-
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
688+
tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
681689
default:
682690
err = ErrInvalidTransaction
683691
}

0 commit comments

Comments
 (0)