diff --git a/copy/copy.go b/copy/copy.go index 44675f37c..b24c8a1c6 100644 --- a/copy/copy.go +++ b/copy/copy.go @@ -241,10 +241,13 @@ func Image(ctx context.Context, policyContext *signature.PolicyContext, destRef, // If reportWriter is not a TTY (e.g., when piping to a file), do not // print the progress bars to avoid long and hard to parse output. - // Instead use printCopyInfo() to print single line "Copying ..." messages. + // Instead use text-based aggregate progress via nonTTYProgressWriter. progressOutput := reportWriter if !isTTY(reportWriter) { progressOutput = io.Discard + + cleanupProgress := setupNonTTYProgressWriter(reportWriter, options) + defer cleanupProgress() } c := &copier{ diff --git a/copy/progress_nontty.go b/copy/progress_nontty.go new file mode 100644 index 000000000..02e302318 --- /dev/null +++ b/copy/progress_nontty.go @@ -0,0 +1,91 @@ +package copy + +import ( + "fmt" + "io" + "time" + + "github.com/containers/image/v5/types" + "github.com/vbauerster/mpb/v8/decor" +) + +const ( + // nonTTYProgressChannelSize is the buffer size for the progress channel + // in non-TTY mode. Buffered to prevent blocking during parallel downloads. + nonTTYProgressChannelSize = 10 + + // nonTTYProgressInterval is how often aggregate progress is printed + // in non-TTY mode. + nonTTYProgressInterval = 500 * time.Millisecond +) + +// nonTTYProgressWriter consumes ProgressProperties from a channel and writes +// aggregate text-based progress output suitable for non-TTY environments. +// No mutex needed - single goroutine processes events sequentially from channel. +type nonTTYProgressWriter struct { + output io.Writer + + // Aggregate tracking (no per-blob state needed) + totalSize int64 // Sum of all known blob sizes + downloaded int64 // Total bytes downloaded (accumulated from OffsetUpdate) + + // Output throttling + lastOutput time.Time + outputInterval time.Duration +} + +// newNonTTYProgressWriter creates a writer that outputs aggregate download +// progress as simple text lines, suitable for non-TTY environments like +// CI/CD pipelines or redirected output. +func newNonTTYProgressWriter(output io.Writer, interval time.Duration) *nonTTYProgressWriter { + return &nonTTYProgressWriter{ + output: output, + outputInterval: interval, + } +} + +// setupNonTTYProgressWriter configures text-based progress output for non-TTY +// environments unless the caller already provided a buffered Progress channel. +// Returns a cleanup function that must be deferred by the caller. +func setupNonTTYProgressWriter(reportWriter io.Writer, options *Options) func() { + if options.Progress != nil && cap(options.Progress) > 0 { + return func() {} + } + + // Use user's interval if greater than our default, otherwise use default. + // This allows users to slow down output while maintaining a sensible minimum. + interval := max(options.ProgressInterval, nonTTYProgressInterval) + if options.ProgressInterval <= 0 { + options.ProgressInterval = nonTTYProgressInterval + } + + progressChan := make(chan types.ProgressProperties, nonTTYProgressChannelSize) + options.Progress = progressChan + + pw := newNonTTYProgressWriter(reportWriter, interval) + go pw.Run(progressChan) + + return func() { close(progressChan) } +} + +// Run consumes progress events from the channel and prints throttled +// aggregate progress. Blocks until the channel is closed. Intended to +// be called as a goroutine: go tw.Run(progressChan) +func (w *nonTTYProgressWriter) Run(ch <-chan types.ProgressProperties) { + for props := range ch { + switch props.Event { + case types.ProgressEventNewArtifact: + // New blob starting - add its size to total + w.totalSize += props.Artifact.Size + + case types.ProgressEventRead: + // Bytes downloaded - accumulate and maybe print + w.downloaded += int64(props.OffsetUpdate) + if time.Since(w.lastOutput) > w.outputInterval { + fmt.Fprintf(w.output, "Progress: %.1f / %.1f\n", + decor.SizeB1024(w.downloaded), decor.SizeB1024(w.totalSize)) + w.lastOutput = time.Now() + } + } + } +} diff --git a/copy/progress_nontty_test.go b/copy/progress_nontty_test.go new file mode 100644 index 000000000..06415b0dc --- /dev/null +++ b/copy/progress_nontty_test.go @@ -0,0 +1,181 @@ +package copy + +import ( + "bytes" + "strings" + "testing" + "time" + + "github.com/containers/image/v5/types" + "github.com/stretchr/testify/assert" +) + +func TestNonTTYProgressWriterRun(t *testing.T) { + tests := []struct { + name string + interval time.Duration + events []types.ProgressProperties + wantTotalSize int64 + wantDownloaded int64 + wantLines int + wantContains string + }{ + { + name: "new artifacts only", + interval: time.Nanosecond, + events: []types.ProgressProperties{ + {Event: types.ProgressEventNewArtifact, Artifact: types.BlobInfo{Size: 1024}}, + {Event: types.ProgressEventNewArtifact, Artifact: types.BlobInfo{Size: 2048}}, + }, + wantTotalSize: 3072, + wantDownloaded: 0, + wantLines: 0, + }, + { + name: "read events produce output", + interval: time.Nanosecond, + events: []types.ProgressProperties{ + {Event: types.ProgressEventNewArtifact, Artifact: types.BlobInfo{Size: 10240}}, + {Event: types.ProgressEventRead, OffsetUpdate: 5120}, + }, + wantTotalSize: 10240, + wantDownloaded: 5120, + wantLines: 1, + wantContains: "Progress:", + }, + { + name: "throttling limits output", + interval: 5 * time.Second, + events: []types.ProgressProperties{ + {Event: types.ProgressEventNewArtifact, Artifact: types.BlobInfo{Size: 10240}}, + {Event: types.ProgressEventRead, OffsetUpdate: 1024}, + {Event: types.ProgressEventRead, OffsetUpdate: 1024}, + {Event: types.ProgressEventRead, OffsetUpdate: 1024}, + }, + wantTotalSize: 10240, + wantDownloaded: 3072, + wantLines: 1, + }, + { + name: "unknown events ignored", + interval: time.Nanosecond, + events: []types.ProgressProperties{ + {Event: types.ProgressEventDone}, + }, + wantTotalSize: 0, + wantDownloaded: 0, + wantLines: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + pw := newNonTTYProgressWriter(&buf, tt.interval) + + ch := make(chan types.ProgressProperties, len(tt.events)) + for _, e := range tt.events { + ch <- e + } + close(ch) + + pw.Run(ch) + + assert.Equal(t, tt.wantTotalSize, pw.totalSize) + assert.Equal(t, tt.wantDownloaded, pw.downloaded) + + output := buf.String() + if tt.wantLines == 0 { + assert.Empty(t, output) + } else { + lines := strings.Split(strings.TrimSpace(output), "\n") + assert.Equal(t, tt.wantLines, len(lines)) + } + if tt.wantContains != "" { + assert.Contains(t, output, tt.wantContains) + } + }) + } +} + +func TestSetupNonTTYProgressWriter(t *testing.T) { + tests := []struct { + name string + progress chan types.ProgressProperties + progressInterval time.Duration + wantProgressSet bool + wantIntervalSet bool + wantMinInterval time.Duration + }{ + { + name: "nil channel gets default setup", + progress: nil, + progressInterval: 0, + wantProgressSet: true, + wantIntervalSet: true, + wantMinInterval: nonTTYProgressInterval, + }, + { + name: "unbuffered channel gets replaced", + progress: make(chan types.ProgressProperties), + progressInterval: 0, + wantProgressSet: true, + wantIntervalSet: true, + wantMinInterval: nonTTYProgressInterval, + }, + { + name: "buffered channel is kept", + progress: make(chan types.ProgressProperties, 5), + progressInterval: 0, + wantProgressSet: false, + wantIntervalSet: false, + }, + { + name: "caller interval larger than default is respected", + progress: nil, + progressInterval: 2 * time.Second, + wantProgressSet: true, + wantIntervalSet: false, + wantMinInterval: 2 * time.Second, + }, + { + name: "caller interval smaller than default is kept", + progress: nil, + progressInterval: 100 * time.Millisecond, + wantProgressSet: true, + wantIntervalSet: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + opts := &Options{ + Progress: tt.progress, + ProgressInterval: tt.progressInterval, + } + originalProgress := opts.Progress + + cleanup := setupNonTTYProgressWriter(&buf, opts) + defer cleanup() + + if tt.wantProgressSet { + assert.NotNil(t, opts.Progress) + assert.Greater(t, cap(opts.Progress), 0) + if originalProgress != nil { + assert.NotEqual(t, originalProgress, opts.Progress) + } + } else { + assert.Equal(t, originalProgress, opts.Progress) + } + + if tt.wantIntervalSet { + assert.Equal(t, nonTTYProgressInterval, opts.ProgressInterval) + } + + if tt.wantMinInterval > 0 { + assert.GreaterOrEqual(t, opts.ProgressInterval, tt.wantMinInterval) + } + }) + } +}