Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions mc2mc/internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type Loader interface {
}

// OdpsClient is the set of table-metadata and query operations that Client
// calls on its ODPS backend.
type OdpsClient interface {
// GetOrderedColumns returns the column names of the table identified by
// tableID. Note that, unlike the other methods, it takes no context.
GetOrderedColumns(tableID string) ([]string, error)
// GetPartitionNames returns the names of the partition columns of the table
// identified by tableID.
GetPartitionNames(ctx context.Context, tableID string) ([]string, error)
// ExecSQL runs the given query and reports any failure as an error.
// NOTE(review): presumably blocks until the query finishes — confirm against
// the implementation.
ExecSQL(ctx context.Context, query string) error
}
Expand Down Expand Up @@ -65,6 +66,17 @@ func (c *Client) Execute(ctx context.Context, tableID, queryFilePath string) err
if err != nil {
return errors.WithStack(err)
}

// get column names
if tableID != "" {
columnNames, err := c.OdpsClient.GetOrderedColumns(tableID)
if err != nil {
return errors.WithStack(err)
}
// construct query with ordered columns
queryRaw = constructQueryWithOrderedColumns(queryRaw, columnNames)
}

if c.enablePartitionValue && !c.enableAutoPartition {
queryRaw = addPartitionValueColumn(queryRaw)
}
Expand Down Expand Up @@ -98,3 +110,8 @@ func addPartitionValueColumn(rawQuery []byte) []byte {
header, qr := loader.SeparateHeadersAndQuery(string(rawQuery))
return []byte(fmt.Sprintf("%s SELECT *, STRING(CURRENT_DATE()) as __partitionvalue FROM (%s)", header, qr))
}

// constructQueryWithOrderedColumns splits the raw query into its header part
// and its SQL part, rewrites the SQL part with
// loader.ConstructQueryWithOrderedColumns so that it selects orderedColumns,
// and returns the header and the rewritten SQL joined by a single space.
func constructQueryWithOrderedColumns(query []byte, orderedColumns []string) []byte {
	header, body := loader.SeparateHeadersAndQuery(string(query))
	ordered := loader.ConstructQueryWithOrderedColumns(body, orderedColumns)
	combined := fmt.Sprintf("%s %s", header, ordered)
	return []byte(combined)
}
41 changes: 37 additions & 4 deletions mc2mc/internal/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,30 @@ func TestExecute(t *testing.T) {
// assert
assert.Error(t, err)
})
t.Run("should return error when getting ordered columns fails", func(t *testing.T) {
// arrange
client, err := client.NewClient(context.TODO(), client.SetupLogger("error"))
require.NoError(t, err)
client.OdpsClient = &mockOdpsClient{
orderedColumns: func() ([]string, error) {
return nil, fmt.Errorf("error get ordered columns")
},
}
assert.NoError(t, os.WriteFile("/tmp/query.sql", []byte("SELECT * FROM table;"), 0644))
// act
err = client.Execute(context.TODO(), "project_test.table_test", "/tmp/query.sql")
// assert
assert.Error(t, err)
assert.ErrorContains(t, err, "error get ordered columns")
})
t.Run("should return error when getting partition name fails", func(t *testing.T) {
// arrange
client, err := client.NewClient(context.TODO(), client.SetupLogger("error"))
require.NoError(t, err)
client.OdpsClient = &mockOdpsClient{
orderedColumns: func() ([]string, error) {
return []string{"col1", "col2"}, nil
},
partitionResult: func() ([]string, error) {
return nil, fmt.Errorf("error get partition name")
},
Expand All @@ -44,6 +63,9 @@ func TestExecute(t *testing.T) {
client, err := client.NewClient(context.TODO(), client.SetupLogger("error"), client.SetupLoader("APPEND"))
require.NoError(t, err)
client.OdpsClient = &mockOdpsClient{
orderedColumns: func() ([]string, error) {
return []string{"col1", "col2"}, nil
},
partitionResult: func() ([]string, error) {
return nil, nil
},
Expand All @@ -63,6 +85,9 @@ func TestExecute(t *testing.T) {
client, err := client.NewClient(context.TODO(), client.SetupLogger("error"), client.SetupLoader("REPLACE"))
require.NoError(t, err)
client.OdpsClient = &mockOdpsClient{
orderedColumns: func() ([]string, error) {
return []string{"col1", "col2"}, nil
},
partitionResult: func() ([]string, error) {
return []string{"event_date"}, nil
},
Expand All @@ -72,11 +97,11 @@ func TestExecute(t *testing.T) {
}
client.Loader = &mockLoader{
getQueryFunc: func(tableID, query string) string {
return "INSERT OVERWRITE TABLE project_test.table_test SELECT * FROM table;"
return "INSERT OVERWRITE TABLE project_test.table_test SELECT col1, col2 FROM (SELECT * FROM table);"
},
getPartitionedQueryFunc: func(tableID, query string, partitionNames []string) string {
assert.True(t, true, "should be called")
return "INSERT OVERWRITE TABLE project_test.table_test PARTITION(event_date) SELECT * FROM table;"
return "INSERT OVERWRITE TABLE project_test.table_test PARTITION(event_date) SELECT col1, col2 FROM (SELECT * FROM table);"
},
}
require.NoError(t, os.WriteFile("/tmp/query.sql", []byte("SELECT * FROM table;"), 0644))
Expand All @@ -90,6 +115,9 @@ func TestExecute(t *testing.T) {
client, err := client.NewClient(context.TODO(), client.SetupLogger("error"), client.SetupLoader("REPLACE"), client.EnableAutoPartition(true))
require.NoError(t, err)
client.OdpsClient = &mockOdpsClient{
orderedColumns: func() ([]string, error) {
return []string{"col1", "col2"}, nil
},
partitionResult: func() ([]string, error) {
return []string{"_partition_value"}, nil
},
Expand All @@ -99,11 +127,11 @@ func TestExecute(t *testing.T) {
}
client.Loader = &mockLoader{
getQueryFunc: func(tableID, query string) string {
return "INSERT OVERWRITE TABLE project_test.table_test SELECT * FROM table;"
return "INSERT OVERWRITE TABLE project_test.table_test SELECT col1, col2 FROM (SELECT * FROM table);"
},
getPartitionedQueryFunc: func(tableID, query string, partitionNames []string) string {
assert.False(t, true, "should not be called")
return "INSERT OVERWRITE TABLE project_test.table_test PARTITION(event_date) SELECT * FROM table;"
return "INSERT OVERWRITE TABLE project_test.table_test PARTITION(_partition_value) SELECT col1, col2 FROM (SELECT * FROM table);"
},
}
require.NoError(t, os.WriteFile("/tmp/query.sql", []byte("SELECT * FROM table;"), 0644))
Expand All @@ -117,6 +145,7 @@ func TestExecute(t *testing.T) {
// mockOdpsClient is a test double for the ODPS client whose behaviour is
// supplied per test through function-valued fields.
type mockOdpsClient struct {
// partitionResult is the stub for partition-name lookups.
// NOTE(review): presumably invoked by GetPartitionNames — confirm.
partitionResult func() ([]string, error)
// execSQLResult is the stub whose result ExecSQL returns.
execSQLResult func() error
// orderedColumns is the stub whose result GetOrderedColumns returns.
orderedColumns func() ([]string, error)
}

func (m *mockOdpsClient) GetPartitionNames(ctx context.Context, tableID string) ([]string, error) {
Expand All @@ -127,6 +156,10 @@ func (m *mockOdpsClient) ExecSQL(ctx context.Context, query string) error {
return m.execSQLResult()
}

// GetOrderedColumns ignores tableID and returns whatever the configured
// orderedColumns stub produces.
func (m *mockOdpsClient) GetOrderedColumns(tableID string) ([]string, error) {
	columns, err := m.orderedColumns()
	return columns, err
}

type mockLoader struct {
getQueryFunc func(tableID, query string) string
getPartitionedQueryFunc func(tableID, query string, partitionNames []string) string
Expand Down
21 changes: 21 additions & 0 deletions mc2mc/internal/client/odps.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,30 @@ func (c *odpsClient) GetPartitionNames(_ context.Context, tableID string) ([]str
for _, partition := range table.Schema().PartitionColumns {
partitionNames = append(partitionNames, partition.Name)
}

return partitionNames, nil
}

// GetOrderedColumns loads the table identified by tableID and returns the
// names of its schema columns in the order the schema lists them.
//
// tableID must consist of exactly three dot-separated parts
// (project.schema.table); any other shape yields an error. A failure while
// loading the table is returned through errors.WithStack.
func (c *odpsClient) GetOrderedColumns(tableID string) ([]string, error) {
	parts := strings.Split(tableID, ".")
	if len(parts) != 3 {
		return nil, errors.Errorf("invalid tableID (tableID should be in format project.schema.table): %s", tableID)
	}

	table := odps.NewTable(c.client, parts[0], parts[1], parts[2])
	err := table.Load()
	if err != nil {
		return nil, errors.WithStack(err)
	}

	// Left as a nil slice on purpose: a table without columns yields nil.
	var names []string
	for _, col := range table.Schema().Columns {
		names = append(names, col.Name)
	}
	return names, nil
}

// wait waits for the task instance to finish on a separate goroutine
func wait(taskIns *odps.Instance) <-chan error {
errChan := make(chan error)
Expand Down
7 changes: 7 additions & 0 deletions mc2mc/internal/loader/helper.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package loader

import (
"fmt"
"strings"
)

Expand All @@ -23,3 +24,9 @@ func SeparateHeadersAndQuery(query string) (string, string) {
}
return headers, last
}

// ConstructQueryWithOrderedColumns wraps query in an outer SELECT that
// projects orderedColumns in the given order:
//
//	SELECT <col1, col2, ...> FROM (<query>);
//
// Everything from the first ";" in query onwards is dropped before wrapping,
// so the statement terminator does not end up inside the sub-query. A query
// that contains no ";" is wrapped as-is.
func ConstructQueryWithOrderedColumns(query string, orderedColumns []string) string {
	// strings.Index reports -1 when no terminator is present; slicing with
	// that value would panic, so truncate only when a ";" was actually found.
	if idx := strings.Index(query, ";"); idx >= 0 {
		query = query[:idx]
	}
	return fmt.Sprintf("SELECT %s FROM (%s);", strings.Join(orderedColumns, ", "), query)
}
10 changes: 10 additions & 0 deletions mc2mc/internal/loader/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,13 @@ where CAST(event_timestamp as DATE) = '{{ .DSTART | Date }}'
assert.Contains(t, query, expectedQuery)
})
}

// TestConstructQueryWithOrderedColumns checks that the loader wraps a query in
// an outer SELECT that lists the ordered columns.
func TestConstructQueryWithOrderedColumns(t *testing.T) {
	t.Run("returns query with ordered columns", func(t *testing.T) {
		// arrange
		const input = `select col_2 as col2, col_3 as col3, col_1 as col1 from project.schema.table;`
		const want = "SELECT col1, col2, col3 FROM (select col_2 as col2, col_3 as col3, col_1 as col1 from project.schema.table);"

		// act
		got := loader.ConstructQueryWithOrderedColumns(input, []string{"col1", "col2", "col3"})

		// assert
		assert.Equal(t, want, got)
	})
}
Loading