Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
282 changes: 282 additions & 0 deletions auto_registry_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
package gen

import (
"os"
"path/filepath"
"strings"
"testing"

"gorm.io/driver/sqlite"
"gorm.io/gorm"
)

func TestAutoRegistryInitGeneration(t *testing.T) {
// 创建临时目录
tempDir := t.TempDir()

// 创建测试数据库
db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
if err != nil {
t.Fatalf("failed to create test database: %v", err)
}

// 创建简单的测试表
err = db.Exec(`CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)`).Error
if err != nil {
t.Fatalf("failed to create users table: %v", err)
}

err = db.Exec(`CREATE TABLE orders (id INTEGER PRIMARY KEY, amount DECIMAL)`).Error
if err != nil {
t.Fatalf("failed to create orders table: %v", err)
}

tests := []struct {
name string
configTables []string
expectInitFunc map[string]bool // table -> shouldHaveInit
}{
{
name: "AllTables",
configTables: []string{}, // 空数组表示所有表
expectInitFunc: map[string]bool{
"users": true,
"orders": true,
},
},
{
name: "OnlyUsersTable",
configTables: []string{"users"},
expectInitFunc: map[string]bool{
"users": true,
"orders": false,
},
},
{
name: "BothTables",
configTables: []string{"users", "orders"},
expectInitFunc: map[string]bool{
"users": true,
"orders": true,
},
},
{
name: "NoTables",
configTables: []string{"nonexistent"},
expectInitFunc: map[string]bool{
"users": false,
"orders": false,
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
modelDir := filepath.Join(tempDir, tt.name, "model")

// 配置生成器
g := NewGenerator(Config{
OutPath: filepath.Join(tempDir, tt.name, "query"),
ModelPkgPath: modelDir,
Mode: WithDefaultQuery | WithAutoRegistry,
})

// 根据测试配置启用自动注册
if len(tt.configTables) == 0 {
g.WithAutoRegistry() // 不传参数,所有表
} else {
g.WithAutoRegistry(tt.configTables...) // 传入指定表名
}

g.UseDB(db)
g.GenerateAllTable()
g.Execute()

// 验证每个表的 init 函数生成情况
for tableName, shouldHaveInit := range tt.expectInitFunc {
t.Run(tableName, func(t *testing.T) {
checkInitFunction(t, modelDir, tableName, shouldHaveInit)
})
}

// 验证注册表文件是否生成
registryFile := filepath.Join(modelDir, "gen.go")
if _, err := os.Stat(registryFile); os.IsNotExist(err) {
t.Errorf("registry file %s should exist", registryFile)
}
})
}
}

// checkInitFunction 检查指定表的模型文件是否包含正确的 init 函数
func checkInitFunction(t *testing.T, modelDir, tableName string, shouldHaveInit bool) {
fileName := filepath.Join(modelDir, tableName+".gen.go")

// 检查文件是否存在
if _, err := os.Stat(fileName); os.IsNotExist(err) {
t.Errorf("model file %s does not exist", fileName)
return
}

// 读取文件内容
content, err := os.ReadFile(fileName)
if err != nil {
t.Errorf("failed to read file %s: %v", fileName, err)
return
}

fileContent := string(content)

// 检查是否包含 init 函数
hasInitFunc := strings.Contains(fileContent, "func init() {")
hasRegisterCall := strings.Contains(fileContent, "RegisterModel(")

if shouldHaveInit {
if !hasInitFunc {
t.Errorf("file %s should contain 'func init() {' but doesn't", fileName)
}
if !hasRegisterCall {
t.Errorf("file %s should contain 'RegisterModel(' call but doesn't", fileName)
}

// 验证 RegisterModel 调用格式
if hasInitFunc && hasRegisterCall {
expectedModelName := getExpectedModelName(tableName)
expectedCall := "RegisterModel(" + expectedModelName + "{}, \"" + expectedModelName + "\")"

if !strings.Contains(fileContent, expectedCall) {
t.Errorf("file %s should contain %s", fileName, expectedCall)
t.Logf("Actual file content:\n%s", fileContent)
}
}
} else {
if hasInitFunc && hasRegisterCall {
t.Errorf("file %s should not contain init function with RegisterModel call", fileName)
}
}
}

// getExpectedModelName 根据表名获取期望的模型名
func getExpectedModelName(tableName string) string {
switch tableName {
case "users":
return "User"
case "orders":
return "Order"
default:
// 简单的首字母大写
return strings.Title(tableName)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[BestPractice]

Deprecated function usage detected: strings.Title is deprecated since Go 1.18. Use cases.Title from golang.org/x/text/cases package instead:

Suggested Change
Suggested change
// getExpectedModelName 根据表名获取期望的模型名
func getExpectedModelName(tableName string) string {
switch tableName {
case "users":
return "User"
case "orders":
return "Order"
default:
// 简单的首字母大写
return strings.Title(tableName)
func getExpectedModelName(tableName string) string {
switch tableName {
case "users":
return "User"
case "orders":
return "Order"
default:
// Convert first letter to uppercase
if len(tableName) == 0 {
return ""
}
return strings.ToUpper(tableName[:1]) + tableName[1:]
}
}

Committable suggestion

Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Context for Agents
[**BestPractice**]

Deprecated function usage detected: `strings.Title` is deprecated since Go 1.18. Use `cases.Title` from golang.org/x/text/cases package instead:

<details>
<summary>Suggested Change</summary>

```suggestion
func getExpectedModelName(tableName string) string {
	switch tableName {
	case "users":
		return "User"
	case "orders":
		return "Order"
	default:
		// Convert first letter to uppercase
		if len(tableName) == 0 {
			return ""
		}
		return strings.ToUpper(tableName[:1]) + tableName[1:]
	}
}
```

⚡ **Committable suggestion**

Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

</details>

File: auto_registry_test.go
Line: 168

}
}

// TestShouldEnableAutoRegistry 测试表过滤逻辑
func TestShouldEnableAutoRegistry(t *testing.T) {
tests := []struct {
name string
configuredList []string
tableName string
expected bool
}{
{
name: "EmptyList_AllTablesEnabled",
configuredList: []string{},
tableName: "users",
expected: true,
},
{
name: "TableInList_ShouldEnable",
configuredList: []string{"users", "orders"},
tableName: "users",
expected: true,
},
{
name: "TableNotInList_ShouldDisable",
configuredList: []string{"users", "orders"},
tableName: "products",
expected: false,
},
{
name: "SingleTable_Match",
configuredList: []string{"users"},
tableName: "users",
expected: true,
},
{
name: "SingleTable_NoMatch",
configuredList: []string{"users"},
tableName: "orders",
expected: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
g := &Generator{
Config: Config{
RegistryTableList: tt.configuredList,
},
}

result := g.shouldEnableAutoRegistry(tt.tableName)
if result != tt.expected {
t.Errorf("shouldEnableAutoRegistry(%s) = %v, want %v",
tt.tableName, result, tt.expected)
}
})
}
}

// TestWithAutoRegistryConfig 测试配置方法
func TestWithAutoRegistryConfig(t *testing.T) {
tests := []struct {
name string
tableNames []string
expectedList []string
expectedMode GenerateMode
}{
{
name: "NoTables",
tableNames: []string{},
expectedList: []string{},
expectedMode: WithAutoRegistry,
},
{
name: "SingleTable",
tableNames: []string{"users"},
expectedList: []string{"users"},
expectedMode: WithAutoRegistry,
},
{
name: "MultipleTables",
tableNames: []string{"users", "orders", "products"},
expectedList: []string{"users", "orders", "products"},
expectedMode: WithAutoRegistry,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := Config{}

// 调用 WithAutoRegistry 方法
cfg.WithAutoRegistry(tt.tableNames...)

// 验证配置结果
if cfg.Mode&WithAutoRegistry == 0 {
t.Error("WithAutoRegistry mode should be enabled")
}

if len(cfg.RegistryTableList) != len(tt.expectedList) {
t.Errorf("RegistryTableList length = %d, want %d",
len(cfg.RegistryTableList), len(tt.expectedList))
}

for i, expected := range tt.expectedList {
if i >= len(cfg.RegistryTableList) || cfg.RegistryTableList[i] != expected {
t.Errorf("RegistryTableList[%d] = %s, want %s",
i, cfg.RegistryTableList[i], expected)
}
Comment on lines 275 to 283

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[BestPractice]

Potential slice bounds error: Accessing cfg.RegistryTableList[i] in the error message without verifying i < len(cfg.RegistryTableList) could panic if the slice is shorter than expected:

Suggested Change
Suggested change
for i, expected := range tt.expectedList {
if i >= len(cfg.RegistryTableList) || cfg.RegistryTableList[i] != expected {
t.Errorf("RegistryTableList[%d] = %s, want %s",
i, cfg.RegistryTableList[i], expected)
}
for i, expected := range tt.expectedList {
if i >= len(cfg.RegistryTableList) {
t.Errorf("RegistryTableList[%d] missing, want %s", i, expected)
} else if cfg.RegistryTableList[i] != expected {
t.Errorf("RegistryTableList[%d] = %s, want %s",
i, cfg.RegistryTableList[i], expected)
}
}

Committable suggestion

Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Context for Agents
[**BestPractice**]

Potential slice bounds error: Accessing `cfg.RegistryTableList[i]` in the error message without verifying `i < len(cfg.RegistryTableList)` could panic if the slice is shorter than expected:

<details>
<summary>Suggested Change</summary>

```suggestion
			for i, expected := range tt.expectedList {
				if i >= len(cfg.RegistryTableList) {
					t.Errorf("RegistryTableList[%d] missing, want %s", i, expected)
				} else if cfg.RegistryTableList[i] != expected {
					t.Errorf("RegistryTableList[%d] = %s, want %s",
						i, cfg.RegistryTableList[i], expected)
				}
			}
```

⚡ **Committable suggestion**

Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

</details>

File: auto_registry_test.go
Line: 278

}
})
}
}
13 changes: 13 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ const (

// WithGeneric generate code with generic
WithGeneric

// WithAutoRegistry generate init functions to auto-register models
WithAutoRegistry
)

// Config generator's basic configuration
Expand All @@ -38,6 +41,9 @@ type Config struct {
ModelPkgPath string // generated model code's package name
WithUnitTest bool // generate unit test for query code

// auto registry configuration
RegistryTableList []string // specific table names to enable auto registry, empty means all tables

// generate model global configuration
FieldNullable bool // generate pointer when field is nullable
FieldCoverable bool // generate pointer when field has default value, to fix problem zero value cannot be assign: https://gorm.io/docs/create.html#Default-Values
Expand Down Expand Up @@ -106,6 +112,13 @@ func (cfg *Config) WithJSONTagNameStrategy(ns func(columnName string) (tagConten
cfg.fieldJSONTagNS = ns
}

// WithAutoRegistry enable auto registry feature for generated models
// tableNames: optional table names to enable auto registry, if empty, all tables will be enabled
func (cfg *Config) WithAutoRegistry(tableNames ...string) {
cfg.Mode |= WithAutoRegistry
cfg.RegistryTableList = tableNames
}

// WithImportPkgPath specify import package path
func (cfg *Config) WithImportPkgPath(paths ...string) {
for i, path := range paths {
Expand Down
2 changes: 1 addition & 1 deletion examples/biz/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"context"
"fmt"

"gorm.io/gen/examples/dal/query"
"examples/dal/query"
)

var q = query.Q
Expand Down
13 changes: 9 additions & 4 deletions examples/cmd/gen/generate.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,29 @@
package main

import (
"examples/conf"
"examples/dal"

"gorm.io/gen"
"gorm.io/gen/examples/conf"
"gorm.io/gen/examples/dal"
)

func init() {
dal.DB = dal.ConnectDB(conf.MySQLDSN).Debug()
dal.DB = dal.ConnectDB(conf.SQLiteDBName).Debug()

prepare(dal.DB) // prepare table for generate
}

func main() {
g := gen.NewGenerator(gen.Config{
OutPath: "../../dal/query",
OutPath: "../../dal/query",
ModelPkgPath: "../../dal/model",
})

g.UseDB(dal.DB)

// auto registry to models
g.WithAutoRegistry()

// generate all table from database
g.ApplyBasic(g.GenerateAllTable()...)

Expand Down
14 changes: 8 additions & 6 deletions examples/cmd/gen/prepare.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ import (
// prepare table for test

const mytableSQL = "CREATE TABLE IF NOT EXISTS `mytables` (" +
" `ID` int(11) NOT NULL," +
" `username` varchar(16) DEFAULT NULL," +
" `age` int(8) NOT NULL," +
" `phone` varchar(11) NOT NULL," +
" INDEX `idx_username` (`username`)" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;"
" `ID` INTEGER NOT NULL PRIMARY KEY," +
" `username` TEXT," +
" `age` INTEGER NOT NULL," +
" `phone` TEXT NOT NULL" +
");"

const indexSQL = "CREATE INDEX IF NOT EXISTS `idx_username` ON `mytables` (`username`);"

func prepare(db *gorm.DB) {
db.Exec(mytableSQL)
db.Exec(indexSQL)
}
5 changes: 3 additions & 2 deletions examples/cmd/only_model/generate.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package main

import (
"examples/conf"
"examples/dal"

"gorm.io/gen"
"gorm.io/gen/examples/conf"
"gorm.io/gen/examples/dal"
)

func init() {
Expand Down
Loading
Loading