Skip to content

Commit 1901911

Browse files
authored
Add Set-based Create and Update support to Generics API (#7578)
1 parent cb65743 commit 1901911

File tree

9 files changed

+158
-32
lines changed

9 files changed

+158
-32
lines changed

.github/workflows/tests.yml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
sqlite:
1717
strategy:
1818
matrix:
19-
go: ['1.23', '1.24']
19+
go: ['1.24', '1.25']
2020
platform: [ubuntu-latest] # can not run in windows OS
2121
runs-on: ${{ matrix.platform }}
2222

@@ -42,7 +42,7 @@ jobs:
4242
strategy:
4343
matrix:
4444
dbversion: ['mysql:9', 'mysql:8', 'mysql:5.7']
45-
go: ['1.23', '1.24']
45+
go: ['1.24', '1.25']
4646
platform: [ubuntu-latest]
4747
runs-on: ${{ matrix.platform }}
4848

@@ -85,7 +85,7 @@ jobs:
8585
strategy:
8686
matrix:
8787
dbversion: [ 'mariadb:latest' ]
88-
go: ['1.23', '1.24']
88+
go: ['1.24', '1.25']
8989
platform: [ ubuntu-latest ]
9090
runs-on: ${{ matrix.platform }}
9191

@@ -128,7 +128,7 @@ jobs:
128128
strategy:
129129
matrix:
130130
dbversion: ['postgres:latest', 'postgres:15', 'postgres:14', 'postgres:13']
131-
go: ['1.23', '1.24']
131+
go: ['1.24', '1.25']
132132
platform: [ubuntu-latest] # can not run in macOS and Windows
133133
runs-on: ${{ matrix.platform }}
134134

@@ -170,7 +170,7 @@ jobs:
170170
sqlserver:
171171
strategy:
172172
matrix:
173-
go: ['1.23', '1.24']
173+
go: ['1.24', '1.25']
174174
platform: [ubuntu-latest] # can not run test in macOS and windows
175175
runs-on: ${{ matrix.platform }}
176176

@@ -212,7 +212,7 @@ jobs:
212212
strategy:
213213
matrix:
214214
dbversion: [ 'v6.5.0' ]
215-
go: ['1.23', '1.24']
215+
go: ['1.24', '1.25']
216216
platform: [ ubuntu-latest ]
217217
runs-on: ${{ matrix.platform }}
218218

@@ -245,7 +245,7 @@ jobs:
245245
strategy:
246246
matrix:
247247
dbversion: ['opengauss/opengauss:7.0.0-RC1.B023']
248-
go: ['1.23', '1.24']
248+
go: ['1.24', '1.25']
249249
platform: [ubuntu-latest] # can not run in macOS and Windows
250250
runs-on: ${{ matrix.platform }}
251251

@@ -307,4 +307,4 @@ jobs:
307307
key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }}
308308

309309
- name: Tests
310-
run: GITHUB_ACTION=true GORM_DIALECT=gaussdb GORM_DSN="user=gaussdb password=Gaussdb@123 dbname=gorm host=localhost port=9950 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh
310+
run: GITHUB_ACTION=true GORM_DIALECT=gaussdb GORM_DSN="user=gaussdb password=Gaussdb@123 dbname=gorm host=localhost port=9950 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh

generics.go

Lines changed: 75 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,33 @@ type Interface[T any] interface {
3535
}
3636

3737
type CreateInterface[T any] interface {
38-
ChainInterface[T]
38+
ExecInterface[T]
39+
// chain methods available at start; return ChainInterface
40+
Scopes(scopes ...func(db *Statement)) ChainInterface[T]
41+
Where(query interface{}, args ...interface{}) ChainInterface[T]
42+
Not(query interface{}, args ...interface{}) ChainInterface[T]
43+
Or(query interface{}, args ...interface{}) ChainInterface[T]
44+
Limit(offset int) ChainInterface[T]
45+
Offset(offset int) ChainInterface[T]
46+
Joins(query clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T]
47+
Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T]
48+
Select(query string, args ...interface{}) ChainInterface[T]
49+
Omit(columns ...string) ChainInterface[T]
50+
MapColumns(m map[string]string) ChainInterface[T]
51+
Distinct(args ...interface{}) ChainInterface[T]
52+
Group(name string) ChainInterface[T]
53+
Having(query interface{}, args ...interface{}) ChainInterface[T]
54+
Order(value interface{}) ChainInterface[T]
55+
Build(builder clause.Builder)
56+
57+
Delete(ctx context.Context) (rowsAffected int, err error)
58+
Update(ctx context.Context, name string, value any) (rowsAffected int, err error)
59+
Updates(ctx context.Context, t T) (rowsAffected int, err error)
60+
3961
Table(name string, args ...interface{}) CreateInterface[T]
4062
Create(ctx context.Context, r *T) error
4163
CreateInBatches(ctx context.Context, r *[]T, batchSize int) error
64+
Set(assignments ...clause.Assignment) SetCreateOrUpdateInterface[T]
4265
}
4366

4467
type ChainInterface[T any] interface {
@@ -58,15 +81,28 @@ type ChainInterface[T any] interface {
5881
Group(name string) ChainInterface[T]
5982
Having(query interface{}, args ...interface{}) ChainInterface[T]
6083
Order(value interface{}) ChainInterface[T]
84+
Set(assignments ...clause.Assignment) SetUpdateOnlyInterface[T]
6185

6286
Build(builder clause.Builder)
6387

88+
Table(name string, args ...interface{}) ChainInterface[T]
6489
Delete(ctx context.Context) (rowsAffected int, err error)
6590
Update(ctx context.Context, name string, value any) (rowsAffected int, err error)
6691
Updates(ctx context.Context, t T) (rowsAffected int, err error)
6792
Count(ctx context.Context, column string) (result int64, err error)
6893
}
6994

95+
// SetUpdateOnlyInterface is returned by Set after chaining; only Update is allowed
96+
type SetUpdateOnlyInterface[T any] interface {
97+
Update(ctx context.Context) (rowsAffected int, err error)
98+
}
99+
100+
// SetCreateOrUpdateInterface is returned by Set at start; Create or Update are allowed
101+
type SetCreateOrUpdateInterface[T any] interface {
102+
Create(ctx context.Context) error
103+
Update(ctx context.Context) (rowsAffected int, err error)
104+
}
105+
70106
type ExecInterface[T any] interface {
71107
Scan(ctx context.Context, r interface{}) error
72108
First(context.Context) (T, error)
@@ -163,6 +199,12 @@ func (c createG[T]) Table(name string, args ...interface{}) CreateInterface[T] {
163199
})}
164200
}
165201

202+
func (c createG[T]) Set(assignments ...clause.Assignment) SetCreateOrUpdateInterface[T] {
203+
assigns := make([]clause.Assignment, len(assignments))
204+
copy(assigns, assignments)
205+
return setCreateOrUpdateG[T]{c: c.chainG, assigns: assigns}
206+
}
207+
166208
func (c createG[T]) Create(ctx context.Context, r *T) error {
167209
return c.g.apply(ctx).Create(r).Error
168210
}
@@ -189,6 +231,12 @@ func (c chainG[T]) with(v op) chainG[T] {
189231
}
190232
}
191233

234+
func (c chainG[T]) Table(name string, args ...interface{}) ChainInterface[T] {
235+
return c.with(func(db *DB) *DB {
236+
return db.Table(name, args...)
237+
})
238+
}
239+
192240
func (c chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] {
193241
return c.with(func(db *DB) *DB {
194242
for _, fc := range scopes {
@@ -198,12 +246,6 @@ func (c chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] {
198246
})
199247
}
200248

201-
func (c chainG[T]) Table(name string, args ...interface{}) ChainInterface[T] {
202-
return c.with(func(db *DB) *DB {
203-
return db.Table(name, args...)
204-
})
205-
}
206-
207249
func (c chainG[T]) Where(query interface{}, args ...interface{}) ChainInterface[T] {
208250
return c.with(func(db *DB) *DB {
209251
return db.Where(query, args...)
@@ -390,6 +432,12 @@ func (c chainG[T]) MapColumns(m map[string]string) ChainInterface[T] {
390432
})
391433
}
392434

435+
func (c chainG[T]) Set(assignments ...clause.Assignment) SetUpdateOnlyInterface[T] {
436+
assigns := make([]clause.Assignment, len(assignments))
437+
copy(assigns, assignments)
438+
return setCreateOrUpdateG[T]{c: c, assigns: assigns}
439+
}
440+
393441
func (c chainG[T]) Distinct(args ...interface{}) ChainInterface[T] {
394442
return c.with(func(db *DB) *DB {
395443
return db.Distinct(args...)
@@ -557,6 +605,26 @@ func (c chainG[T]) Build(builder clause.Builder) {
557605
}
558606
}
559607

608+
type setCreateOrUpdateG[T any] struct {
609+
c chainG[T]
610+
assigns []clause.Assignment
611+
}
612+
613+
func (s setCreateOrUpdateG[T]) Update(ctx context.Context) (rowsAffected int, err error) {
614+
var r T
615+
res := s.c.g.apply(ctx).Model(r).Clauses(clause.Set(s.assigns)).Updates(map[string]interface{}{})
616+
return int(res.RowsAffected), res.Error
617+
}
618+
619+
func (s setCreateOrUpdateG[T]) Create(ctx context.Context) error {
620+
var r T
621+
data := make(map[string]interface{}, len(s.assigns))
622+
for _, a := range s.assigns {
623+
data[a.Column.Name] = a.Value
624+
}
625+
return s.c.g.apply(ctx).Model(r).Create(data).Error
626+
}
627+
560628
type execG[T any] struct {
561629
g *g[T]
562630
}

tests/connection_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package tests_test
22

33
import (
4-
"fmt"
54
"testing"
65

76
"gorm.io/driver/mysql"
@@ -28,7 +27,7 @@ func TestWithSingleConnection(t *testing.T) {
2827
return nil
2928
})
3029
if err != nil {
31-
t.Errorf(fmt.Sprintf("WithSingleConnection should work, but got err %v", err))
30+
t.Errorf("WithSingleConnection should work, but got err %v", err)
3231
}
3332

3433
if actualName != expectedName {

tests/count_test.go

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package tests_test
22

33
import (
4-
"fmt"
54
"regexp"
65
"sort"
76
"strings"
@@ -22,15 +21,15 @@ func TestCountWithGroup(t *testing.T) {
2221

2322
var count1 int64
2423
if err := DB.Model(&Company{}).Where("name = ?", "company_count_group_a").Group("name").Count(&count1).Error; err != nil {
25-
t.Errorf(fmt.Sprintf("Count should work, but got err %v", err))
24+
t.Errorf("Count should work, but got err %v", err)
2625
}
2726
if count1 != 1 {
2827
t.Errorf("Count with group should be 1, but got count: %v", count1)
2928
}
3029

3130
var count2 int64
3231
if err := DB.Model(&Company{}).Where("name in ?", []string{"company_count_group_b", "company_count_group_c"}).Group("name").Count(&count2).Error; err != nil {
33-
t.Errorf(fmt.Sprintf("Count should work, but got err %v", err))
32+
t.Errorf("Count should work, but got err %v", err)
3433
}
3534
if count2 != 2 {
3635
t.Errorf("Count with group should be 2, but got count: %v", count2)
@@ -49,15 +48,15 @@ func TestCount(t *testing.T) {
4948
DB.Save(&user1).Save(&user2).Save(&user3)
5049

5150
if err := DB.Where("name = ?", user1.Name).Or("name = ?", user3.Name).Find(&users).Count(&count).Error; err != nil {
52-
t.Errorf(fmt.Sprintf("Count should work, but got err %v", err))
51+
t.Errorf("Count should work, but got err %v", err)
5352
}
5453

5554
if count != int64(len(users)) {
5655
t.Errorf("Count() method should get correct value, expect: %v, got %v", count, len(users))
5756
}
5857

5958
if err := DB.Model(&User{}).Where("name = ?", user1.Name).Or("name = ?", user3.Name).Count(&count).Find(&users).Error; err != nil {
60-
t.Errorf(fmt.Sprintf("Count should work, but got err %v", err))
59+
t.Errorf("Count should work, but got err %v", err)
6160
}
6261

6362
if count != int64(len(users)) {
@@ -110,7 +109,7 @@ func TestCount(t *testing.T) {
110109
if err := DB.Model(&User{}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Select(
111110
"(CASE WHEN name=? THEN ? ELSE ? END) as name", "count-1", "main", "other",
112111
).Count(&count6).Find(&users).Error; err != nil || count6 != 3 {
113-
t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err))
112+
t.Fatalf("Count should work, but got err %v", err)
114113
}
115114

116115
expects := []User{{Name: "main"}, {Name: "other"}, {Name: "other"}}
@@ -124,7 +123,7 @@ func TestCount(t *testing.T) {
124123
if err := DB.Model(&User{}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Select(
125124
"(CASE WHEN name=? THEN ? ELSE ? END) as name, age", "count-1", "main", "other",
126125
).Count(&count7).Find(&users).Error; err != nil || count7 != 3 {
127-
t.Fatalf(fmt.Sprintf("Count should work, but got err %v", err))
126+
t.Fatalf("Count should work, but got err %v", err)
128127
}
129128

130129
expects = []User{{Name: "main", Age: 18}, {Name: "other", Age: 18}, {Name: "other", Age: 18}}

tests/generics_test.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -667,6 +667,64 @@ func TestGenericsDistinct(t *testing.T) {
667667
}
668668
}
669669

670+
func TestGenericsSetCreate(t *testing.T) {
671+
ctx := context.Background()
672+
673+
name := "GenericsSetCreate"
674+
age := uint(21)
675+
676+
err := gorm.G[User](DB).Set(
677+
clause.Assignment{Column: clause.Column{Name: "name"}, Value: name},
678+
clause.Assignment{Column: clause.Column{Name: "age"}, Value: age},
679+
).Create(ctx)
680+
if err != nil {
681+
t.Fatalf("Set Create failed: %v", err)
682+
}
683+
684+
u, err := gorm.G[User](DB).Where("name = ?", name).First(ctx)
685+
if err != nil {
686+
t.Fatalf("failed to find created user: %v", err)
687+
}
688+
if u.ID == 0 || u.Name != name || u.Age != age {
689+
t.Fatalf("created user mismatch, got %+v", u)
690+
}
691+
}
692+
693+
func TestGenericsSetUpdate(t *testing.T) {
694+
ctx := context.Background()
695+
696+
// prepare
697+
u := User{Name: "GenericsSetUpdate_Before", Age: 30}
698+
if err := gorm.G[User](DB).Create(ctx, &u); err != nil {
699+
t.Fatalf("prepare user failed: %v", err)
700+
}
701+
702+
// update with Set after chain
703+
newName := "GenericsSetUpdate_After"
704+
newAge := uint(31)
705+
rows, err := gorm.G[User](DB).
706+
Where("id = ?", u.ID).
707+
Set(
708+
clause.Assignment{Column: clause.Column{Name: "name"}, Value: newName},
709+
clause.Assignment{Column: clause.Column{Name: "age"}, Value: newAge},
710+
).
711+
Update(ctx)
712+
if err != nil {
713+
t.Fatalf("Set Update failed: %v", err)
714+
}
715+
if rows != 1 {
716+
t.Fatalf("expected 1 row affected, got %d", rows)
717+
}
718+
719+
nu, err := gorm.G[User](DB).Where("id = ?", u.ID).First(ctx)
720+
if err != nil {
721+
t.Fatalf("failed to query updated user: %v", err)
722+
}
723+
if nu.Name != newName || nu.Age != newAge {
724+
t.Fatalf("updated user mismatch, got %+v", nu)
725+
}
726+
}
727+
670728
func TestGenericsGroupHaving(t *testing.T) {
671729
ctx := context.Background()
672730

tests/go.mod

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module gorm.io/gorm/tests
22

3-
go 1.23.0
3+
go 1.24.0
44

55
require (
66
github.com/google/uuid v1.6.0
@@ -12,7 +12,7 @@ require (
1212
gorm.io/driver/postgres v1.6.0
1313
gorm.io/driver/sqlite v1.6.0
1414
gorm.io/driver/sqlserver v1.6.1
15-
gorm.io/gorm v1.30.2
15+
gorm.io/gorm v1.30.3
1616
)
1717

1818
require (
@@ -32,8 +32,8 @@ require (
3232
github.com/pmezard/go-difflib v1.0.0 // indirect
3333
github.com/tjfoc/gmsm v1.4.1 // indirect
3434
golang.org/x/crypto v0.41.0 // indirect
35-
golang.org/x/sync v0.16.0 // indirect
36-
golang.org/x/text v0.28.0 // indirect
35+
golang.org/x/sync v0.17.0 // indirect
36+
golang.org/x/text v0.29.0 // indirect
3737
gopkg.in/yaml.v3 v3.0.1 // indirect
3838
)
3939

0 commit comments

Comments
 (0)