Skip to content

Commit 2f36367

Browse files
committed
support to download as stream
1 parent 9798ea1 commit 2f36367

File tree

2 files changed

+206
-12
lines changed

2 files changed

+206
-12
lines changed

pkg/net/http.go

Lines changed: 91 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,7 @@ func (h *HTTPDownloader) fetchProxyFromEnv(scheme string) {
9797
}
9898
}
9999

100-
// DownloadFile download a file with the progress
101-
func (h *HTTPDownloader) DownloadFile() error {
100+
func (h *HTTPDownloader) DownloadAsStream(writer io.Writer) (err error) {
102101
filepath, downloadURL, showProgress := h.TargetFilePath, h.URL, h.ShowProgress
103102
// Get the data
104103
if h.Context == nil {
@@ -181,22 +180,32 @@ func (h *HTTPDownloader) DownloadFile() error {
181180
}
182181
}
183182

184-
if err := os.MkdirAll(path.Dir(filepath), os.FileMode(0755)); err != nil {
185-
return err
183+
h.progressIndicator.Writer = writer
184+
h.progressIndicator.Init()
185+
186+
// Write the body to file
187+
_, err = io.Copy(h.progressIndicator, resp.Body)
188+
return
189+
}
190+
191+
// DownloadFile download a file with the progress
192+
func (h *HTTPDownloader) DownloadFile() (err error) {
193+
filepath := h.TargetFilePath
194+
if err = os.MkdirAll(path.Dir(filepath), os.FileMode(0755)); err != nil {
195+
return
186196
}
187197

188198
// Create the file
189-
out, err := os.Create(filepath)
199+
var out io.WriteCloser
200+
out, err = os.Create(filepath)
190201
if err != nil {
191-
_ = out.Close()
192-
return err
202+
return
193203
}
204+
defer func() {
205+
_ = out.Close()
206+
}()
194207

195-
h.progressIndicator.Writer = out
196-
h.progressIndicator.Init()
197-
198-
// Write the body to file
199-
_, err = io.Copy(h.progressIndicator, resp.Body)
208+
err = h.DownloadAsStream(out)
200209
return err
201210
}
202211

@@ -272,6 +281,38 @@ func (c *ContinueDownloader) WithBasicAuth(username, password string) *ContinueD
272281
return c
273282
}
274283

284+
func (c *ContinueDownloader) DownloadWithContinueAsStream(targetURL string, output io.Writer, index, continueAt, end int64, showProgress bool) (err error) {
285+
c.downloader = &HTTPDownloader{
286+
URL: targetURL,
287+
ShowProgress: showProgress,
288+
NoProxy: c.noProxy,
289+
RoundTripper: c.roundTripper,
290+
InsecureSkipVerify: c.insecureSkipVerify,
291+
UserName: c.UserName,
292+
Password: c.Password,
293+
Context: c.Context,
294+
Timeout: c.Timeout,
295+
}
296+
if index >= 0 {
297+
c.downloader.Title = fmt.Sprintf("Downloading part %d", index)
298+
}
299+
300+
if continueAt >= 0 {
301+
c.downloader.Header = make(map[string]string, 1)
302+
303+
if end > continueAt {
304+
c.downloader.Header["Range"] = fmt.Sprintf("bytes=%d-%d", continueAt, end)
305+
} else {
306+
c.downloader.Header["Range"] = fmt.Sprintf("bytes=%d-", continueAt)
307+
}
308+
}
309+
310+
if err = c.downloader.DownloadAsStream(output); err != nil {
311+
err = fmt.Errorf("cannot download from %s, error: %v", targetURL, err)
312+
}
313+
return
314+
}
315+
275316
// DownloadWithContinue downloads the files continuously
276317
func (c *ContinueDownloader) DownloadWithContinue(targetURL, output string, index, continueAt, end int64, showProgress bool) (err error) {
277318
c.downloader = &HTTPDownloader{
@@ -306,6 +347,44 @@ func (c *ContinueDownloader) DownloadWithContinue(targetURL, output string, inde
306347
return
307348
}
308349

350+
func DetectSizeWithRoundTripperAndAuthStream(targetURL string, output io.Writer, showProgress, noProxy, insecureSkipVerify bool,
351+
roundTripper http.RoundTripper, username, password string, timeout time.Duration) (total int64, rangeSupport bool, err error) {
352+
downloader := HTTPDownloader{
353+
URL: targetURL,
354+
ShowProgress: showProgress,
355+
RoundTripper: roundTripper,
356+
NoProxy: false, // below HTTP request does not need proxy
357+
InsecureSkipVerify: insecureSkipVerify,
358+
UserName: username,
359+
Password: password,
360+
Timeout: timeout,
361+
}
362+
363+
var detectOffset int64
364+
var lenErr error
365+
366+
detectOffset = 2
367+
downloader.Header = make(map[string]string, 1)
368+
downloader.Header["Range"] = fmt.Sprintf("bytes=%d-", detectOffset)
369+
370+
downloader.PreStart = func(resp *http.Response) bool {
371+
rangeSupport = resp.StatusCode == http.StatusPartialContent
372+
contentLen := resp.Header.Get("Content-Length")
373+
if total, lenErr = strconv.ParseInt(contentLen, 10, 0); lenErr == nil {
374+
total += detectOffset
375+
} else {
376+
rangeSupport = false
377+
}
378+
// always return false because we just want to get the header from response
379+
return false
380+
}
381+
382+
if err = downloader.DownloadAsStream(output); err != nil || lenErr != nil {
383+
err = fmt.Errorf("cannot download from %s, response error: %v, content length error: %v", targetURL, err, lenErr)
384+
}
385+
return
386+
}
387+
309388
// DetectSizeWithRoundTripperAndAuth returns the size of target resource
310389
func DetectSizeWithRoundTripperAndAuth(targetURL, output string, showProgress, noProxy, insecureSkipVerify bool,
311390
roundTripper http.RoundTripper, username, password string, timeout time.Duration) (total int64, rangeSupport bool, err error) {

pkg/net/multi_thread.go

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package net
33
import (
44
"context"
55
"fmt"
6+
"io"
67
"net/http"
78
"os"
89
"os/signal"
@@ -75,6 +76,120 @@ func (d *MultiThreadDownloader) WithBearerToken(bearerToken string) *MultiThread
7576
return d
7677
}
7778

79+
func (d *MultiThreadDownloader) DownloadWithContext(ctx context.Context, targetURL string, outputWriter io.Writer, thread int) (err error) {
80+
// get the total size of the target file
81+
var total int64
82+
var rangeSupport bool
83+
if total, rangeSupport, err = DetectSizeWithRoundTripperAndAuthStream(targetURL, outputWriter, d.showProgress,
84+
d.noProxy, d.insecureSkipVerify, d.roundTripper, d.username, d.password, d.timeout); rangeSupport && err != nil {
85+
return
86+
}
87+
88+
if rangeSupport {
89+
unit := total / int64(thread)
90+
offset := total - unit*int64(thread)
91+
var wg sync.WaitGroup
92+
var partItems map[int]string
93+
var m sync.Mutex
94+
95+
defer func() {
96+
// remove all partial files
97+
for _, part := range partItems {
98+
_ = os.RemoveAll(part)
99+
}
100+
}()
101+
102+
c := make(chan os.Signal, 1)
103+
signal.Notify(c, os.Interrupt)
104+
ctx, cancel := context.WithCancel(context.Background())
105+
var canceled bool
106+
107+
go func() {
108+
<-c
109+
canceled = true
110+
cancel()
111+
}()
112+
113+
fmt.Printf("start to download with %d threads, size: %d, unit: %d", thread, total, unit)
114+
for i := 0; i < thread; i++ {
115+
fmt.Println() // TODO take position, should take over by progerss bars
116+
wg.Add(1)
117+
go func(index int, wg *sync.WaitGroup, ctx context.Context) {
118+
defer wg.Done()
119+
outputFile, err := os.CreateTemp(os.TempDir(), fmt.Sprintf("part-%d", index))
120+
if err != nil {
121+
fmt.Println("failed to create template file", err)
122+
}
123+
outputFile.Close()
124+
125+
m.Lock()
126+
partItems[index] = outputFile.Name()
127+
m.Unlock()
128+
129+
end := unit*int64(index+1) - 1
130+
if index == thread-1 {
131+
// this is the last part
132+
end += offset
133+
}
134+
start := unit * int64(index)
135+
136+
downloader := &ContinueDownloader{}
137+
downloader.WithoutProxy(d.noProxy).
138+
WithRoundTripper(d.roundTripper).
139+
WithInsecureSkipVerify(d.insecureSkipVerify).
140+
WithBasicAuth(d.username, d.password).
141+
WithContext(ctx).WithTimeout(d.timeout)
142+
if downloadErr := downloader.DownloadWithContinue(targetURL, outputFile.Name(),
143+
int64(index), start, end, d.showProgress); downloadErr != nil {
144+
fmt.Println(downloadErr)
145+
}
146+
}(i, &wg, ctx)
147+
}
148+
149+
wg.Wait()
150+
// ProgressIndicator{}.Close()
151+
if canceled {
152+
err = fmt.Errorf("download process canceled")
153+
return
154+
}
155+
156+
// make the cursor right
157+
// TODO the progress component should take over it
158+
if thread > 1 {
159+
// line := GetCurrentLine()
160+
time.Sleep(time.Second)
161+
fmt.Printf("\033[%dE\n", thread) // move to the target line
162+
time.Sleep(time.Second * 5)
163+
}
164+
165+
for i := 0; i < thread; i++ {
166+
partFile := partItems[i]
167+
if data, ferr := os.ReadFile(partFile); ferr == nil {
168+
if _, err = outputWriter.Write(data); err != nil {
169+
err = fmt.Errorf("failed to write file: '%s'", partFile)
170+
break
171+
} else if !d.keepParts {
172+
_ = os.RemoveAll(partFile)
173+
}
174+
} else {
175+
err = fmt.Errorf("failed to read file: '%s'", partFile)
176+
break
177+
}
178+
}
179+
} else {
180+
fmt.Println("cannot download it using multiple threads, failed to one")
181+
downloader := &ContinueDownloader{}
182+
downloader.WithoutProxy(d.noProxy)
183+
downloader.WithRoundTripper(d.roundTripper)
184+
downloader.WithInsecureSkipVerify(d.insecureSkipVerify)
185+
downloader.WithTimeout(d.timeout)
186+
downloader.WithBasicAuth(d.username, d.password)
187+
err = downloader.DownloadWithContinueAsStream(targetURL, outputWriter, -1, 0, 0, true)
188+
d.suggestedFilename = downloader.GetSuggestedFilename()
189+
}
190+
return
191+
}
192+
78193
// Download starts to download the target URL
79194
func (d *MultiThreadDownloader) Download(targetURL, targetFilePath string, thread int) (err error) {
80195
// get the total size of the target file

0 commit comments

Comments
 (0)