Skip to content

Commit 715cafd

Browse files
authored
Merge pull request #1573 from bryanvaz/fix/psql_create_if_not_exists
fix: Use existing PostgreSQL table if it exists
2 parents ab88d43 + 5c2bcfe commit 715cafd

File tree

3 files changed

+227
-16
lines changed

3 files changed

+227
-16
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
*.fasthttp.gz
2323
*.pprof
2424
*.workspace
25+
/tmp/
2526

2627
# Dependencies
2728
/vendor/

postgres/postgres.go

Lines changed: 74 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,19 +26,31 @@ type Storage struct {
2626
}
2727

2828
var (
29-
checkSchemaMsg = "The `v` row has an incorrect data type. " +
30-
"It should be BYTEA but is instead %s. This will cause encoding-related panics if the DB is not migrated (see https://github.com/gofiber/storage/blob/main/MIGRATE.md)."
31-
dropQuery = `DROP TABLE IF EXISTS %s;`
29+
checkSchemaMsg = "The `%s` row has an incorrect data type. " +
30+
"It should be %s but is instead %s. This will cause encoding-related panics if the DB is not migrated (see https://github.com/gofiber/storage/blob/main/MIGRATE.md)."
31+
dropQuery = `DROP TABLE IF EXISTS %s;`
32+
checkTableExistsQuery = `SELECT COUNT(table_name)
33+
FROM information_schema.tables
34+
WHERE table_schema = '%s'
35+
AND table_name = '%s';`
3236
initQuery = []string{
33-
`CREATE TABLE IF NOT EXISTS %s (
37+
`CREATE TABLE %s (
3438
k VARCHAR(64) PRIMARY KEY NOT NULL DEFAULT '',
3539
v BYTEA NOT NULL,
3640
e BIGINT NOT NULL DEFAULT '0'
3741
);`,
3842
`CREATE INDEX IF NOT EXISTS e ON %s (e);`,
3943
}
40-
checkSchemaQuery = `SELECT DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS
41-
WHERE table_name = '%s' AND COLUMN_NAME = 'v';`
44+
checkSchemaQuery = `SELECT column_name, data_type
45+
FROM information_schema.columns
46+
WHERE table_schema = '%s'
47+
AND table_name = '%s'
48+
AND column_name IN ('k','v','e');`
49+
checkSchemaTargetDataType = map[string]string{
50+
"k": "character varying",
51+
"v": "bytea",
52+
"e": "bigint",
53+
}
4254
)
4355

4456
// New creates a new storage
@@ -61,6 +73,14 @@ func New(config ...Config) *Storage {
6173
panic(err)
6274
}
6375

76+
// Parse out schema in config, if provided
77+
schema := "public"
78+
tableName := cfg.Table
79+
if strings.Contains(cfg.Table, ".") {
80+
schema = strings.Split(cfg.Table, ".")[0]
81+
tableName = strings.Split(cfg.Table, ".")[1]
82+
}
83+
6484
// Drop table if set to true
6585
if cfg.Reset {
6686
if _, err := db.Exec(context.Background(), fmt.Sprintf(dropQuery, cfg.Table)); err != nil {
@@ -69,11 +89,23 @@ func New(config ...Config) *Storage {
6989
}
7090
}
7191

92+
// Determine if table exists
93+
tableExists := false
94+
row := db.QueryRow(context.Background(), fmt.Sprintf(checkTableExistsQuery, schema, tableName))
95+
var count int
96+
if err := row.Scan(&count); err != nil {
97+
db.Close()
98+
panic(err)
99+
}
100+
tableExists = count > 0
101+
72102
// Init database queries
73-
for _, query := range initQuery {
74-
if _, err := db.Exec(context.Background(), fmt.Sprintf(query, cfg.Table)); err != nil {
75-
db.Close()
76-
panic(err)
103+
if !tableExists {
104+
for _, query := range initQuery {
105+
if _, err := db.Exec(context.Background(), fmt.Sprintf(query, cfg.Table)); err != nil {
106+
db.Close()
107+
panic(err)
108+
}
77109
}
78110
}
79111

@@ -185,15 +217,41 @@ func (s *Storage) gc(t time.Time) {
185217
_, _ = s.db.Exec(context.Background(), s.sqlGC, t.Unix())
186218
}
187219

188-
func (s *Storage) checkSchema(tableName string) {
189-
var data []byte
220+
func (s *Storage) checkSchema(fullTableName string) {
221+
schema := "public"
222+
tableName := fullTableName
223+
if strings.Contains(fullTableName, ".") {
224+
schema = strings.Split(fullTableName, ".")[0]
225+
tableName = strings.Split(fullTableName, ".")[1]
226+
}
190227

191-
row := s.db.QueryRow(context.Background(), fmt.Sprintf(checkSchemaQuery, tableName))
192-
if err := row.Scan(&data); err != nil {
228+
rows, err := s.db.Query(context.Background(), fmt.Sprintf(checkSchemaQuery, schema, tableName))
229+
if err != nil {
193230
panic(err)
194231
}
232+
defer rows.Close()
233+
234+
data := make(map[string]string)
195235

196-
if strings.ToLower(string(data)) != "bytea" {
197-
fmt.Printf(checkSchemaMsg, string(data))
236+
rowCount := 0
237+
for rows.Next() {
238+
var columnName, dataType string
239+
if err := rows.Scan(&columnName, &dataType); err != nil {
240+
panic(err)
241+
}
242+
data[columnName] = dataType
243+
rowCount++
244+
}
245+
if rowCount == 0 {
246+
panic(fmt.Errorf("table %s does not exist", tableName))
247+
}
248+
for columnName, dataType := range checkSchemaTargetDataType {
249+
dt, ok := data[columnName]
250+
if !ok {
251+
panic(fmt.Errorf("required column %s does not exist in table %s", columnName, tableName))
252+
}
253+
if dt != dataType {
254+
panic(fmt.Errorf(checkSchemaMsg, columnName, dataType, dt))
255+
}
198256
}
199257
}

postgres/postgres_test.go

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package postgres
33
import (
44
"context"
55
"os"
6+
"strconv"
67
"testing"
78
"time"
89

@@ -17,6 +18,157 @@ var testStore = New(Config{
1718
Reset: true,
1819
})
1920

21+
func TestNoCreateUser(t *testing.T) {
22+
// Create a new user
23+
// give the use usage permissions to the database (but not create)
24+
ctx := context.Background()
25+
conn := testStore.Conn()
26+
27+
username := "testuser" + strconv.Itoa(int(time.Now().UnixNano()))
28+
password := "testpassword"
29+
30+
_, err := conn.Exec(ctx, "CREATE USER "+username+" WITH PASSWORD '"+password+"'")
31+
require.NoError(t, err)
32+
33+
_, err = conn.Exec(ctx, "GRANT CONNECT ON DATABASE "+os.Getenv("POSTGRES_DATABASE")+" TO "+username)
34+
require.NoError(t, err)
35+
36+
_, err = conn.Exec(ctx, "GRANT USAGE ON SCHEMA public TO "+username)
37+
require.NoError(t, err)
38+
39+
_, err = conn.Exec(ctx, "GRANT SELECT, INSERT, UPDATE, DELETE ON ALL TABLES IN SCHEMA public TO "+username)
40+
require.NoError(t, err)
41+
42+
_, err = conn.Exec(ctx, "ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT, INSERT, UPDATE, DELETE ON TABLES TO "+username)
43+
require.NoError(t, err)
44+
45+
_, err = conn.Exec(ctx, "REVOKE CREATE ON SCHEMA public FROM "+username)
46+
require.NoError(t, err)
47+
48+
t.Run("should panic if limited user tries to create table", func(t *testing.T) {
49+
tableThatDoesNotExist := "public.table_does_not_exists_" + strconv.Itoa(int(time.Now().UnixNano()))
50+
51+
defer func() {
52+
r := recover()
53+
require.NotNil(t, r, "Expected a panic when creating a table without permissions")
54+
}()
55+
56+
// This should panic since the user doesn't have CREATE permissions
57+
New(Config{
58+
Database: os.Getenv("POSTGRES_DATABASE"),
59+
Username: username,
60+
Password: password,
61+
Reset: true,
62+
Table: tableThatDoesNotExist,
63+
})
64+
})
65+
66+
// connect to an existing table using an unprivileged user
67+
limitedStore := New(Config{
68+
Database: os.Getenv("POSTGRES_DATABASE"),
69+
Username: username,
70+
Password: password,
71+
Reset: false,
72+
})
73+
74+
defer func() {
75+
limitedStore.Close()
76+
conn.Exec(ctx, "DROP USER "+username)
77+
}()
78+
79+
t.Run("should set", func(t *testing.T) {
80+
var (
81+
key = "john" + strconv.Itoa(int(time.Now().UnixNano()))
82+
val = []byte("doe" + strconv.Itoa(int(time.Now().UnixNano())))
83+
)
84+
85+
err := limitedStore.Set(key, val, 0)
86+
require.NoError(t, err)
87+
})
88+
t.Run("should set override", func(t *testing.T) {
89+
var (
90+
key = "john" + strconv.Itoa(int(time.Now().UnixNano()))
91+
val = []byte("doe" + strconv.Itoa(int(time.Now().UnixNano())))
92+
)
93+
err := limitedStore.Set(key, val, 0)
94+
require.NoError(t, err)
95+
err = limitedStore.Set(key, val, 0)
96+
require.NoError(t, err)
97+
})
98+
t.Run("should get", func(t *testing.T) {
99+
var (
100+
key = "john" + strconv.Itoa(int(time.Now().UnixNano()))
101+
val = []byte("doe" + strconv.Itoa(int(time.Now().UnixNano())))
102+
)
103+
err := limitedStore.Set(key, val, 0)
104+
require.NoError(t, err)
105+
result, err := limitedStore.Get(key)
106+
require.NoError(t, err)
107+
require.Equal(t, val, result)
108+
})
109+
t.Run("should set expiration", func(t *testing.T) {
110+
var (
111+
key = "john" + strconv.Itoa(int(time.Now().UnixNano()))
112+
val = []byte("doe" + strconv.Itoa(int(time.Now().UnixNano())))
113+
exp = 100 * time.Millisecond
114+
)
115+
err := limitedStore.Set(key, val, exp)
116+
require.NoError(t, err)
117+
})
118+
t.Run("should get expired", func(t *testing.T) {
119+
var (
120+
key = "john" + strconv.Itoa(int(time.Now().UnixNano()))
121+
val = []byte("doe" + strconv.Itoa(int(time.Now().UnixNano())))
122+
exp = 100 * time.Millisecond
123+
)
124+
err := limitedStore.Set(key, val, exp)
125+
require.NoError(t, err)
126+
time.Sleep(200 * time.Millisecond)
127+
result, err := limitedStore.Get(key)
128+
require.NoError(t, err)
129+
require.Zero(t, len(result))
130+
})
131+
t.Run("should get not exists", func(t *testing.T) {
132+
result, err := limitedStore.Get("nonexistentkey")
133+
require.NoError(t, err)
134+
require.Zero(t, len(result))
135+
})
136+
t.Run("should delete", func(t *testing.T) {
137+
var (
138+
key = "john" + strconv.Itoa(int(time.Now().UnixNano()))
139+
val = []byte("doe" + strconv.Itoa(int(time.Now().UnixNano())))
140+
)
141+
err := limitedStore.Set(key, val, 0)
142+
require.NoError(t, err)
143+
err = limitedStore.Delete(key)
144+
require.NoError(t, err)
145+
result, err := limitedStore.Get(key)
146+
require.NoError(t, err)
147+
require.Zero(t, len(result))
148+
})
149+
150+
}
151+
func Test_Should_Panic_On_Wrong_Schema(t *testing.T) {
152+
// Create a test table with wrong schema
153+
_, err := testStore.Conn().Exec(context.Background(), `
154+
CREATE TABLE IF NOT EXISTS test_schema_table (
155+
k VARCHAR(64) PRIMARY KEY NOT NULL DEFAULT '',
156+
v BYTEA NOT NULL,
157+
e VARCHAR(64) NOT NULL DEFAULT '' -- Changed e from BIGINT to VARCHAR
158+
);
159+
`)
160+
require.NoError(t, err)
161+
defer func() {
162+
_, err := testStore.Conn().Exec(context.Background(), "DROP TABLE IF EXISTS test_schema_table;")
163+
require.NoError(t, err)
164+
}()
165+
166+
// Call checkSchema with the wrong table
167+
require.Panics(t, func() {
168+
testStore.checkSchema("test_schema_table")
169+
})
170+
}
171+
20172
func Test_Postgres_Set(t *testing.T) {
21173
var (
22174
key = "john"

0 commit comments

Comments
 (0)