diff --git a/mc2mc/internal/client/odps.go b/mc2mc/internal/client/odps.go index ce840f0..927941e 100644 --- a/mc2mc/internal/client/odps.go +++ b/mc2mc/internal/client/odps.go @@ -57,9 +57,12 @@ func (c *odpsClient) ExecSQL(ctx context.Context, query string, additionalHints // wait execution success select { case <-ctx.Done(): - c.logger.Info("context cancelled, terminating task instance") - err := c.terminate(taskIns) - return e.Join(ctx.Err(), err) + msg := "context canceled" + if err := context.Cause(ctx); err != nil { + msg = fmt.Sprintf("%s: %s", msg, err.Error()) + } + c.logger.Info(msg) + return errors.WithStack(c.terminate(taskIns)) case err := <-c.wait(taskIns): return errors.WithStack(err) } diff --git a/mc2mc/mc2mc.go b/mc2mc/mc2mc.go index 6e41877..8392796 100644 --- a/mc2mc/mc2mc.go +++ b/mc2mc/mc2mc.go @@ -33,7 +33,7 @@ func mc2mc(envs []string) error { } // graceful shutdown - ctx, cancelFn := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + ctx, cancelFn := signalAwareContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer cancelFn() // initiate client @@ -210,3 +210,26 @@ func execute(ctx context.Context, c *client.Client, queriesToExecute []string, a } return nil } + +// signalAwareContext creates a context that is aware of signals. +func signalAwareContext(parent context.Context, signals ...os.Signal) (context.Context, context.CancelFunc) { + ctx, cancelWithCause := context.WithCancelCause(parent) + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, signals...) + + // start a goroutine to handle signals + go func() { + select { + case sig := <-sigCh: + cancelWithCause(fmt.Errorf("signal: %v", sig)) + signal.Stop(sigCh) + case <-ctx.Done(): + signal.Stop(sigCh) + } + }() + + // return a standard CancelFunc that preserves the original behavior + return ctx, func() { + cancelWithCause(context.Canceled) + } +}