Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/backup.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ on:

jobs:
BackupGit:
runs-on: ubuntu-20.04
runs-on: ubuntu-22.04
steps:
- uses: actions/[email protected]
- name: backup
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/coverage-report.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ on:

jobs:
TestAndReport:
runs-on: ubuntu-20.04
runs-on: ubuntu-22.04
steps:
- name: Set up Go
uses: actions/setup-go@v4
Expand Down
12 changes: 6 additions & 6 deletions .github/workflows/pull-request.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ env:
jobs:
build:
name: Build
runs-on: ubuntu-20.04
runs-on: ubuntu-22.04
steps:
- name: Set up Go
uses: actions/setup-go@v4
Expand Down Expand Up @@ -51,7 +51,7 @@ jobs:

GoLint:
name: Lint
runs-on: ubuntu-20.04
runs-on: ubuntu-22.04
steps:
- name: Set up Go
uses: actions/setup-go@v4
Expand All @@ -66,7 +66,7 @@ jobs:
golint-path: ./...
Security:
name: Security
runs-on: ubuntu-20.04
runs-on: ubuntu-22.04
env:
GO111MODULE: on
steps:
Expand All @@ -78,7 +78,7 @@ jobs:
args: '-exclude=G402,G204,G304,G110,G306,G107 ./...'
CodeQL:
name: CodeQL
runs-on: ubuntu-20.04
runs-on: ubuntu-22.04
env:
GO111MODULE: on
steps:
Expand All @@ -92,15 +92,15 @@ jobs:
uses: github/codeql-action/analyze@v1
MarkdownLinkCheck:
name: MarkdownLinkCheck
runs-on: ubuntu-20.04
runs-on: ubuntu-22.04
steps:
- uses: actions/[email protected]
- uses: gaurav-nelson/[email protected]
with:
use-verbose-mode: 'yes'

image:
runs-on: ubuntu-20.04
runs-on: ubuntu-22.04
steps:
- name: Checkout
uses: actions/checkout@v4
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/release-drafter.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ on:

jobs:
UpdateReleaseDraft:
runs-on: ubuntu-20.04
runs-on: ubuntu-22.04
steps:
- uses: release-drafter/release-drafter@v5
env:
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ env:

jobs:
goreleaser:
runs-on: ubuntu-20.04
runs-on: ubuntu-22.04
steps:
- name: Checkout
uses: actions/[email protected]
Expand Down Expand Up @@ -43,7 +43,7 @@ jobs:
oras push ${{ env.REGISTRY }}/linuxsuren/hd:$TAG release

image:
runs-on: ubuntu-20.04
runs-on: ubuntu-22.04
steps:
- name: Checkout
uses: actions/checkout@v4
Expand Down
109 changes: 97 additions & 12 deletions pkg/net/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ func (h *HTTPDownloader) fetchProxyFromEnv(scheme string) {
}
}

// DownloadFile download a file with the progress
func (h *HTTPDownloader) DownloadFile() error {
// DownloadAsStream downloads the file as stream
func (h *HTTPDownloader) DownloadAsStream(writer io.Writer) (err error) {
filepath, downloadURL, showProgress := h.TargetFilePath, h.URL, h.ShowProgress
// Get the data
if h.Context == nil {
Expand All @@ -115,7 +115,10 @@ func (h *HTTPDownloader) DownloadFile() error {

if h.UserName != "" && h.Password != "" {
req.SetBasicAuth(h.UserName, h.Password)
} else if h.Password != "" {
req.Header.Set("Authorization", "Bearer "+h.Password)
}

var tr http.RoundTripper
if h.RoundTripper != nil {
tr = h.RoundTripper
Expand Down Expand Up @@ -178,22 +181,32 @@ func (h *HTTPDownloader) DownloadFile() error {
}
}

if err := os.MkdirAll(path.Dir(filepath), os.FileMode(0755)); err != nil {
return err
h.progressIndicator.Writer = writer
h.progressIndicator.Init()

// Write the body to file
_, err = io.Copy(h.progressIndicator, resp.Body)
return
}

// DownloadFile download a file with the progress
func (h *HTTPDownloader) DownloadFile() (err error) {
filepath := h.TargetFilePath
if err = os.MkdirAll(path.Dir(filepath), os.FileMode(0755)); err != nil {
return
}

// Create the file
out, err := os.Create(filepath)
var out io.WriteCloser
out, err = os.Create(filepath)
if err != nil {
_ = out.Close()
return err
return
}
defer func() {
_ = out.Close()
}()

h.progressIndicator.Writer = out
h.progressIndicator.Init()

// Write the body to file
_, err = io.Copy(h.progressIndicator, resp.Body)
err = h.DownloadAsStream(out)
return err
}

Expand Down Expand Up @@ -269,6 +282,39 @@ func (c *ContinueDownloader) WithBasicAuth(username, password string) *ContinueD
return c
}

// DownloadWithContinueAsStream downloads the files continuously
func (c *ContinueDownloader) DownloadWithContinueAsStream(targetURL string, output io.Writer, index, continueAt, end int64, showProgress bool) (err error) {
c.downloader = &HTTPDownloader{
URL: targetURL,
ShowProgress: showProgress,
NoProxy: c.noProxy,
RoundTripper: c.roundTripper,
InsecureSkipVerify: c.insecureSkipVerify,
UserName: c.UserName,
Password: c.Password,
Context: c.Context,
Timeout: c.Timeout,
}
if index >= 0 {
c.downloader.Title = fmt.Sprintf("Downloading part %d", index)
}

if continueAt >= 0 {
c.downloader.Header = make(map[string]string, 1)

if end > continueAt {
c.downloader.Header["Range"] = fmt.Sprintf("bytes=%d-%d", continueAt, end)
} else {
c.downloader.Header["Range"] = fmt.Sprintf("bytes=%d-", continueAt)
}
}

if err = c.downloader.DownloadAsStream(output); err != nil {
err = fmt.Errorf("cannot download from %s, error: %v", targetURL, err)
}
return
}

// DownloadWithContinue downloads the files continuously
func (c *ContinueDownloader) DownloadWithContinue(targetURL, output string, index, continueAt, end int64, showProgress bool) (err error) {
c.downloader = &HTTPDownloader{
Expand Down Expand Up @@ -303,6 +349,45 @@ func (c *ContinueDownloader) DownloadWithContinue(targetURL, output string, inde
return
}

// DetectSizeWithRoundTripperAndAuthStream returns the size of target resource
func DetectSizeWithRoundTripperAndAuthStream(targetURL string, output io.Writer, showProgress, noProxy, insecureSkipVerify bool,
roundTripper http.RoundTripper, username, password string, timeout time.Duration) (total int64, rangeSupport bool, err error) {
downloader := HTTPDownloader{
URL: targetURL,
ShowProgress: showProgress,
RoundTripper: roundTripper,
NoProxy: false, // below HTTP request does not need proxy
InsecureSkipVerify: insecureSkipVerify,
UserName: username,
Password: password,
Timeout: timeout,
}

var detectOffset int64
var lenErr error

detectOffset = 2
downloader.Header = make(map[string]string, 1)
downloader.Header["Range"] = fmt.Sprintf("bytes=%d-", detectOffset)

downloader.PreStart = func(resp *http.Response) bool {
rangeSupport = resp.StatusCode == http.StatusPartialContent
contentLen := resp.Header.Get("Content-Length")
if total, lenErr = strconv.ParseInt(contentLen, 10, 0); lenErr == nil {
total += detectOffset
} else {
rangeSupport = false
}
// always return false because we just want to get the header from response
return false
}

if err = downloader.DownloadAsStream(output); err != nil || lenErr != nil {
err = fmt.Errorf("cannot download from %s, response error: %v, content length error: %v", targetURL, err, lenErr)
}
return
}

// DetectSizeWithRoundTripperAndAuth returns the size of target resource
func DetectSizeWithRoundTripperAndAuth(targetURL, output string, showProgress, noProxy, insecureSkipVerify bool,
roundTripper http.RoundTripper, username, password string, timeout time.Duration) (total int64, rangeSupport bool, err error) {
Expand Down
122 changes: 122 additions & 0 deletions pkg/net/multi_thread.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package net
import (
"context"
"fmt"
"io"
"net/http"
"os"
"os/signal"
Expand Down Expand Up @@ -70,6 +71,127 @@ func (d *MultiThreadDownloader) WithBasicAuth(username, password string) *MultiT
return d
}

// WithBearerToken sets the bearer token
func (d *MultiThreadDownloader) WithBearerToken(bearerToken string) *MultiThreadDownloader {
d.password = bearerToken
return d
}

// DownloadWithContext starts to download the target URL with context
func (d *MultiThreadDownloader) DownloadWithContext(ctx context.Context, targetURL string, outputWriter io.Writer, thread int) (err error) {
// get the total size of the target file
var total int64
var rangeSupport bool
if total, rangeSupport, err = DetectSizeWithRoundTripperAndAuthStream(targetURL, outputWriter, d.showProgress,
d.noProxy, d.insecureSkipVerify, d.roundTripper, d.username, d.password, d.timeout); rangeSupport && err != nil {
return
}

if rangeSupport {
unit := total / int64(thread)
offset := total - unit*int64(thread)
var wg sync.WaitGroup
var m sync.Mutex
partItems := make(map[int]string)

defer func() {
// remove all partial files
for _, part := range partItems {
_ = os.RemoveAll(part)
}
}()

c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt)
ctx, cancel := context.WithCancel(context.Background())
var canceled bool

go func() {
<-c
canceled = true
cancel()
}()

fmt.Printf("start to download with %d threads, size: %d, unit: %d", thread, total, unit)
for i := 0; i < thread; i++ {
fmt.Println() // TODO take position, should take over by progerss bars
wg.Add(1)
go func(index int, wg *sync.WaitGroup, ctx context.Context) {
defer wg.Done()
outputFile, err := os.CreateTemp(os.TempDir(), fmt.Sprintf("part-%d", index))
if err != nil {
fmt.Println("failed to create template file", err)
}
outputFile.Close()

m.Lock()
partItems[index] = outputFile.Name()
m.Unlock()

end := unit*int64(index+1) - 1
if index == thread-1 {
// this is the last part
end += offset
}
start := unit * int64(index)

downloader := &ContinueDownloader{}
downloader.WithoutProxy(d.noProxy).
WithRoundTripper(d.roundTripper).
WithInsecureSkipVerify(d.insecureSkipVerify).
WithBasicAuth(d.username, d.password).
WithContext(ctx).WithTimeout(d.timeout)
if downloadErr := downloader.DownloadWithContinue(targetURL, outputFile.Name(),
int64(index), start, end, d.showProgress); downloadErr != nil {
fmt.Println(downloadErr)
}
}(i, &wg, ctx)
}

wg.Wait()
// ProgressIndicator{}.Close()
if canceled {
err = fmt.Errorf("download process canceled")
return
}

// make the cursor right
// TODO the progress component should take over it
if thread > 1 {
// line := GetCurrentLine()
time.Sleep(time.Second)
fmt.Printf("\033[%dE\n", thread) // move to the target line
time.Sleep(time.Second * 5)
}

for i := 0; i < thread; i++ {
partFile := partItems[i]
if data, ferr := os.ReadFile(partFile); ferr == nil {
if _, err = outputWriter.Write(data); err != nil {
err = fmt.Errorf("failed to write file: '%s'", partFile)
break
} else if !d.keepParts {
_ = os.RemoveAll(partFile)
}
} else {
err = fmt.Errorf("failed to read file: '%s'", partFile)
break
}
}
} else {
fmt.Println("cannot download it using multiple threads, failed to one")
downloader := &ContinueDownloader{}
downloader.WithoutProxy(d.noProxy)
downloader.WithRoundTripper(d.roundTripper)
downloader.WithInsecureSkipVerify(d.insecureSkipVerify)
downloader.WithTimeout(d.timeout)
downloader.WithBasicAuth(d.username, d.password)
err = downloader.DownloadWithContinueAsStream(targetURL, outputWriter, -1, 0, 0, true)
d.suggestedFilename = downloader.GetSuggestedFilename()
}
return
}

// Download starts to download the target URL
func (d *MultiThreadDownloader) Download(targetURL, targetFilePath string, thread int) (err error) {
// get the total size of the target file
Expand Down
Loading