Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mc2mc/internal/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ import (
"os"
"testing"

"github.com/goto/transformers/mc2mc/internal/client"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/goto/transformers/mc2mc/internal/client"
)

func TestExecute(t *testing.T) {
Expand Down
6 changes: 4 additions & 2 deletions mc2mc/internal/loader/append.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ func NewAppendLoader(logger *slog.Logger) (*appendLoader, error) {
}

func (l *appendLoader) GetQuery(tableID, query string) string {
return fmt.Sprintf("INSERT INTO TABLE %s %s;", tableID, query)
headers, qr := SeparateHeadersAndQuery(query)
return fmt.Sprintf("%s INSERT INTO TABLE %s %s;", headers, tableID, qr)
}

func (l *appendLoader) GetPartitionedQuery(tableID, query string, partitionNames []string) string {
return fmt.Sprintf("INSERT INTO TABLE %s PARTITION (%s) %s;", tableID, strings.Join(partitionNames, ", "), query)
headers, qr := SeparateHeadersAndQuery(query)
return fmt.Sprintf("%s INSERT INTO TABLE %s PARTITION (%s) %s;", headers, tableID, strings.Join(partitionNames, ", "), qr)
}
25 changes: 25 additions & 0 deletions mc2mc/internal/loader/helper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package loader

import (
"strings"
)

func SeparateHeadersAndQuery(query string) (string, string) {
parts := strings.Split(query, ";")

last := ""
idx := len(parts) - 1
for idx >= 0 {
last = parts[idx]
if strings.TrimSpace(last) != "" {
break
}
idx = idx - 1
}

headers := strings.Join(parts[:idx], ";")
if headers != "" {
headers += ";"
}
return headers, last
}
58 changes: 58 additions & 0 deletions mc2mc/internal/loader/helper_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package loader_test

import (
"strings"
"testing"

"github.com/stretchr/testify/assert"

"github.com/goto/transformers/mc2mc/internal/loader"
)

func TestMacroSeparator(t *testing.T) {
t.Run("returns query without macros", func(t *testing.T) {
q1 := `select * from playground`
macros, query := loader.SeparateHeadersAndQuery(q1)
assert.Empty(t, macros)
assert.Equal(t, q1, query)
})
t.Run("returns query removing whitespace", func(t *testing.T) {
q1 := `
select * from playground`

header, query := loader.SeparateHeadersAndQuery(q1)
assert.Empty(t, header)
assert.Contains(t, query, q1)
})
t.Run("splits headers and query", func(t *testing.T) {
q1 := `set odps.sql.allow.fullscan=true;
select * from playground`
headers, query := loader.SeparateHeadersAndQuery(q1)
assert.Equal(t, "set odps.sql.allow.fullscan=true;", headers)
assert.Equal(t, "select * from playground", strings.TrimSpace(query))
})
t.Run("works with query of multiple headers", func(t *testing.T) {
q1 := `set odps.sql.allow.fullscan=true;
set odps.sql.python.version=cp37;

select distinct event_timestamp,
client_id,
country_code,
from presentation.main.important_date
where CAST(event_timestamp as DATE) = '{{ .DSTART | Date }}'
and client_id in ('123')
`
headers, query := loader.SeparateHeadersAndQuery(q1)
expectedHeader := `set odps.sql.allow.fullscan=true;
set odps.sql.python.version=cp37;`
assert.Equal(t, expectedHeader, headers)

expectedQuery := `select distinct event_timestamp,
client_id,
country_code,
from presentation.main.important_date
where CAST(event_timestamp as DATE) = '{{ .DSTART | Date }}'
and client_id in ('123')`
assert.Contains(t, query, expectedQuery)
})
}
6 changes: 4 additions & 2 deletions mc2mc/internal/loader/replace.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ func NewReplaceLoader(logger *slog.Logger) *replaceLoader {
}

func (l *replaceLoader) GetQuery(tableID, query string) string {
return fmt.Sprintf("INSERT OVERWRITE TABLE %s %s;", tableID, query)
headers, qr := SeparateHeadersAndQuery(query)
return fmt.Sprintf("%s INSERT OVERWRITE TABLE %s %s;", headers, tableID, qr)
}

func (l *replaceLoader) GetPartitionedQuery(tableID, query string, partitionNames []string) string {
return fmt.Sprintf("INSERT OVERWRITE TABLE %s PARTITION (%s) %s;", tableID, strings.Join(partitionNames, ", "), query)
headers, qr := SeparateHeadersAndQuery(query)
return fmt.Sprintf("%s INSERT OVERWRITE TABLE %s PARTITION (%s) %s;", headers, tableID, strings.Join(partitionNames, ", "), qr)
}
Loading