Skip to content

Commit 65dd3b1

Browse files
committed
chore: context aware signal
1 parent cecda59 commit 65dd3b1

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

mc2mc/internal/client/odps.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,12 @@ func (c *odpsClient) ExecSQL(ctx context.Context, query string, additionalHints
5757
// wait execution success
5858
select {
5959
case <-ctx.Done():
60-
c.logger.Info("context cancelled, terminating task instance")
61-
err := c.terminate(taskIns)
62-
return e.Join(ctx.Err(), err)
60+
msg := "context canceled"
61+
if err := context.Cause(ctx); err != nil {
62+
msg = fmt.Sprintf("%s: %s", msg, err.Error())
63+
}
64+
c.logger.Info(msg)
65+
return errors.WithStack(c.terminate(taskIns))
6366
case err := <-c.wait(taskIns):
6467
return errors.WithStack(err)
6568
}

mc2mc/mc2mc.go

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ func mc2mc(envs []string) error {
3333
}
3434

3535
// graceful shutdown
36-
ctx, cancelFn := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
36+
ctx, cancelFn := signalAwareContext(context.Background(), os.Interrupt, syscall.SIGTERM)
3737
defer cancelFn()
3838

3939
// initiate client
@@ -210,3 +210,26 @@ func execute(ctx context.Context, c *client.Client, queriesToExecute []string, a
210210
}
211211
return nil
212212
}
213+
214+
// signalAwareContext creates a context that is aware of signals.
215+
func signalAwareContext(parent context.Context, signals ...os.Signal) (context.Context, context.CancelFunc) {
216+
ctx, cancelWithCause := context.WithCancelCause(parent)
217+
sigCh := make(chan os.Signal, 1)
218+
signal.Notify(sigCh, signals...)
219+
220+
// start a goroutine to handle signals
221+
go func() {
222+
select {
223+
case sig := <-sigCh:
224+
cancelWithCause(fmt.Errorf("signal: %v", sig))
225+
signal.Stop(sigCh)
226+
case <-ctx.Done():
227+
signal.Stop(sigCh)
228+
}
229+
}()
230+
231+
// return a standard CancelFunc that preserves the original behavior
232+
return ctx, func() {
233+
cancelWithCause(context.Canceled)
234+
}
235+
}

0 commit comments

Comments
 (0)