diff --git a/.github/workflows/backup.yaml b/.github/workflows/backup.yaml index bcf5b24..6aefd20 100644 --- a/.github/workflows/backup.yaml +++ b/.github/workflows/backup.yaml @@ -8,7 +8,7 @@ on: jobs: BackupGit: - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v3.6.0 - name: backup diff --git a/.github/workflows/coverage-report.yaml b/.github/workflows/coverage-report.yaml index 17908f5..5ba6a15 100644 --- a/.github/workflows/coverage-report.yaml +++ b/.github/workflows/coverage-report.yaml @@ -7,7 +7,7 @@ on: jobs: TestAndReport: - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 steps: - name: Set up Go uses: actions/setup-go@v4 diff --git a/.github/workflows/pull-request.yaml b/.github/workflows/pull-request.yaml index 73d1d34..fd6e426 100644 --- a/.github/workflows/pull-request.yaml +++ b/.github/workflows/pull-request.yaml @@ -10,7 +10,7 @@ env: jobs: build: name: Build - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 steps: - name: Set up Go uses: actions/setup-go@v4 @@ -51,7 +51,7 @@ jobs: GoLint: name: Lint - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 steps: - name: Set up Go uses: actions/setup-go@v4 @@ -66,7 +66,7 @@ jobs: golint-path: ./... Security: name: Security - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 env: GO111MODULE: on steps: @@ -78,7 +78,7 @@ jobs: args: '-exclude=G402,G204,G304,G110,G306,G107 ./...' CodeQL: name: CodeQL - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 env: GO111MODULE: on steps: @@ -92,7 +92,7 @@ jobs: uses: github/codeql-action/analyze@v1 MarkdownLinkCheck: name: MarkdownLinkCheck - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v3.6.0 - uses: gaurav-nelson/github-action-markdown-link-check@1.0.13 @@ -100,7 +100,7 @@ jobs: use-verbose-mode: 'yes' image: - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 steps: - name: Checkout uses: actions/checkout@v4 diff --git a/.github/workflows/release-drafter.yml b/.github/workflows/release-drafter.yml index 73fcc4f..acc8db4 100644 --- a/.github/workflows/release-drafter.yml +++ b/.github/workflows/release-drafter.yml @@ -7,7 +7,7 @@ on: jobs: UpdateReleaseDraft: - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 steps: - uses: release-drafter/release-drafter@v5 env: diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index eeb6472..01e91dc 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -9,7 +9,7 @@ env: jobs: goreleaser: - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 steps: - name: Checkout uses: actions/checkout@v3.6.0 @@ -43,7 +43,7 @@ jobs: oras push ${{ env.REGISTRY }}/linuxsuren/hd:$TAG release image: - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 steps: - name: Checkout uses: actions/checkout@v4 diff --git a/pkg/net/http.go b/pkg/net/http.go index 3811240..c965ec6 100644 --- a/pkg/net/http.go +++ b/pkg/net/http.go @@ -97,8 +97,8 @@ func (h *HTTPDownloader) fetchProxyFromEnv(scheme string) { } } -// DownloadFile download a file with the progress -func (h *HTTPDownloader) DownloadFile() error { +// DownloadAsStream downloads the file as stream +func (h *HTTPDownloader) DownloadAsStream(writer io.Writer) (err error) { filepath, downloadURL, showProgress := h.TargetFilePath, h.URL, h.ShowProgress // Get the data if h.Context == nil { @@ -115,7 +115,10 @@ func (h *HTTPDownloader) DownloadFile() error { if h.UserName != "" && h.Password != "" { req.SetBasicAuth(h.UserName, h.Password) + } else if h.Password != "" { + req.Header.Set("Authorization", "Bearer "+h.Password) } + var tr http.RoundTripper if h.RoundTripper != nil { tr = h.RoundTripper @@ -178,22 +181,32 @@ func (h *HTTPDownloader) DownloadFile() error { } } - if err := os.MkdirAll(path.Dir(filepath), os.FileMode(0755)); err != nil { - return err + h.progressIndicator.Writer = writer + h.progressIndicator.Init() + + // Write the body to file + _, err = io.Copy(h.progressIndicator, resp.Body) + return +} + +// DownloadFile download a file with the progress +func (h *HTTPDownloader) DownloadFile() (err error) { + filepath := h.TargetFilePath + if err = os.MkdirAll(path.Dir(filepath), os.FileMode(0755)); err != nil { + return } // Create the file - out, err := os.Create(filepath) + var out io.WriteCloser + out, err = os.Create(filepath) if err != nil { - _ = out.Close() - return err + return } + defer func() { + _ = out.Close() + }() - h.progressIndicator.Writer = out - h.progressIndicator.Init() - - // Write the body to file - _, err = io.Copy(h.progressIndicator, resp.Body) + err = h.DownloadAsStream(out) return err } @@ -269,6 +282,39 @@ func (c *ContinueDownloader) WithBasicAuth(username, password string) *ContinueD return c } +// DownloadWithContinueAsStream downloads the files continuously +func (c *ContinueDownloader) DownloadWithContinueAsStream(targetURL string, output io.Writer, index, continueAt, end int64, showProgress bool) (err error) { + c.downloader = &HTTPDownloader{ + URL: targetURL, + ShowProgress: showProgress, + NoProxy: c.noProxy, + RoundTripper: c.roundTripper, + InsecureSkipVerify: c.insecureSkipVerify, + UserName: c.UserName, + Password: c.Password, + Context: c.Context, + Timeout: c.Timeout, + } + if index >= 0 { + c.downloader.Title = fmt.Sprintf("Downloading part %d", index) + } + + if continueAt >= 0 { + c.downloader.Header = make(map[string]string, 1) + + if end > continueAt { + c.downloader.Header["Range"] = fmt.Sprintf("bytes=%d-%d", continueAt, end) + } else { + c.downloader.Header["Range"] = fmt.Sprintf("bytes=%d-", continueAt) + } + } + + if err = c.downloader.DownloadAsStream(output); err != nil { + err = fmt.Errorf("cannot download from %s, error: %v", targetURL, err) + } + return +} + // DownloadWithContinue downloads the files continuously func (c *ContinueDownloader) DownloadWithContinue(targetURL, output string, index, continueAt, end int64, showProgress bool) (err error) { c.downloader = &HTTPDownloader{ @@ -303,6 +349,45 @@ func (c *ContinueDownloader) DownloadWithContinue(targetURL, output string, inde return } +// DetectSizeWithRoundTripperAndAuthStream returns the size of target resource +func DetectSizeWithRoundTripperAndAuthStream(targetURL string, output io.Writer, showProgress, noProxy, insecureSkipVerify bool, + roundTripper http.RoundTripper, username, password string, timeout time.Duration) (total int64, rangeSupport bool, err error) { + downloader := HTTPDownloader{ + URL: targetURL, + ShowProgress: showProgress, + RoundTripper: roundTripper, + NoProxy: false, // below HTTP request does not need proxy + InsecureSkipVerify: insecureSkipVerify, + UserName: username, + Password: password, + Timeout: timeout, + } + + var detectOffset int64 + var lenErr error + + detectOffset = 2 + downloader.Header = make(map[string]string, 1) + downloader.Header["Range"] = fmt.Sprintf("bytes=%d-", detectOffset) + + downloader.PreStart = func(resp *http.Response) bool { + rangeSupport = resp.StatusCode == http.StatusPartialContent + contentLen := resp.Header.Get("Content-Length") + if total, lenErr = strconv.ParseInt(contentLen, 10, 0); lenErr == nil { + total += detectOffset + } else { + rangeSupport = false + } + // always return false because we just want to get the header from response + return false + } + + if err = downloader.DownloadAsStream(output); err != nil || lenErr != nil { + err = fmt.Errorf("cannot download from %s, response error: %v, content length error: %v", targetURL, err, lenErr) + } + return +} + // DetectSizeWithRoundTripperAndAuth returns the size of target resource func DetectSizeWithRoundTripperAndAuth(targetURL, output string, showProgress, noProxy, insecureSkipVerify bool, roundTripper http.RoundTripper, username, password string, timeout time.Duration) (total int64, rangeSupport bool, err error) { diff --git a/pkg/net/multi_thread.go b/pkg/net/multi_thread.go index a69ec34..19538d0 100644 --- a/pkg/net/multi_thread.go +++ b/pkg/net/multi_thread.go @@ -3,6 +3,7 @@ package net import ( "context" "fmt" + "io" "net/http" "os" "os/signal" @@ -70,6 +71,127 @@ func (d *MultiThreadDownloader) WithBasicAuth(username, password string) *MultiT return d } +// WithBearerToken sets the bearer token +func (d *MultiThreadDownloader) WithBearerToken(bearerToken string) *MultiThreadDownloader { + d.password = bearerToken + return d +} + +// DownloadWithContext starts to download the target URL with context +func (d *MultiThreadDownloader) DownloadWithContext(ctx context.Context, targetURL string, outputWriter io.Writer, thread int) (err error) { + // get the total size of the target file + var total int64 + var rangeSupport bool + if total, rangeSupport, err = DetectSizeWithRoundTripperAndAuthStream(targetURL, outputWriter, d.showProgress, + d.noProxy, d.insecureSkipVerify, d.roundTripper, d.username, d.password, d.timeout); rangeSupport && err != nil { + return + } + + if rangeSupport { + unit := total / int64(thread) + offset := total - unit*int64(thread) + var wg sync.WaitGroup + var m sync.Mutex + partItems := make(map[int]string) + + defer func() { + // remove all partial files + for _, part := range partItems { + _ = os.RemoveAll(part) + } + }() + + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt) + ctx, cancel := context.WithCancel(context.Background()) + var canceled bool + + go func() { + <-c + canceled = true + cancel() + }() + + fmt.Printf("start to download with %d threads, size: %d, unit: %d", thread, total, unit) + for i := 0; i < thread; i++ { + fmt.Println() // TODO take position, should take over by progerss bars + wg.Add(1) + go func(index int, wg *sync.WaitGroup, ctx context.Context) { + defer wg.Done() + outputFile, err := os.CreateTemp(os.TempDir(), fmt.Sprintf("part-%d", index)) + if err != nil { + fmt.Println("failed to create template file", err) + } + outputFile.Close() + + m.Lock() + partItems[index] = outputFile.Name() + m.Unlock() + + end := unit*int64(index+1) - 1 + if index == thread-1 { + // this is the last part + end += offset + } + start := unit * int64(index) + + downloader := &ContinueDownloader{} + downloader.WithoutProxy(d.noProxy). + WithRoundTripper(d.roundTripper). + WithInsecureSkipVerify(d.insecureSkipVerify). + WithBasicAuth(d.username, d.password). + WithContext(ctx).WithTimeout(d.timeout) + if downloadErr := downloader.DownloadWithContinue(targetURL, outputFile.Name(), + int64(index), start, end, d.showProgress); downloadErr != nil { + fmt.Println(downloadErr) + } + }(i, &wg, ctx) + } + + wg.Wait() + // ProgressIndicator{}.Close() + if canceled { + err = fmt.Errorf("download process canceled") + return + } + + // make the cursor right + // TODO the progress component should take over it + if thread > 1 { + // line := GetCurrentLine() + time.Sleep(time.Second) + fmt.Printf("\033[%dE\n", thread) // move to the target line + time.Sleep(time.Second * 5) + } + + for i := 0; i < thread; i++ { + partFile := partItems[i] + if data, ferr := os.ReadFile(partFile); ferr == nil { + if _, err = outputWriter.Write(data); err != nil { + err = fmt.Errorf("failed to write file: '%s'", partFile) + break + } else if !d.keepParts { + _ = os.RemoveAll(partFile) + } + } else { + err = fmt.Errorf("failed to read file: '%s'", partFile) + break + } + } + } else { + fmt.Println("cannot download it using multiple threads, failed to one") + downloader := &ContinueDownloader{} + downloader.WithoutProxy(d.noProxy) + downloader.WithRoundTripper(d.roundTripper) + downloader.WithInsecureSkipVerify(d.insecureSkipVerify) + downloader.WithTimeout(d.timeout) + downloader.WithBasicAuth(d.username, d.password) + err = downloader.DownloadWithContinueAsStream(targetURL, outputWriter, -1, 0, 0, true) + d.suggestedFilename = downloader.GetSuggestedFilename() + } + return +} + // Download starts to download the target URL func (d *MultiThreadDownloader) Download(targetURL, targetFilePath string, thread int) (err error) { // get the total size of the target file