Skip to content

Commit ae4b2e9

Browse files
committed
feat: add round-robin retries via Retry Rounds/Retry Delay
1 parent 6147653 commit ae4b2e9

File tree

3 files changed

+36
-39
lines changed

3 files changed

+36
-39
lines changed

runner/options.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -765,6 +765,10 @@ func (options *Options) ValidateOptions() error {
765765
return errors.New(fmt.Sprintf("invalid retry-delay: must be >0 when retry-rounds=%d (got %d)", options.RetryRounds, options.RetryDelay))
766766
}
767767

768+
if options.RetryRounds > 0 && options.RetryDelay <= 0 {
769+
return errors.New(fmt.Sprintf("invalid retry-delay: must be >0 when retry-rounds=%d (got %d)", options.RetryRounds, options.RetryDelay))
770+
}
771+
768772
return nil
769773
}
770774

runner/runner.go

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"strconv"
2424
"strings"
2525
"sync"
26+
"sync/atomic"
2627
"time"
2728

2829
"golang.org/x/exp/maps"
@@ -1259,7 +1260,7 @@ func (r *Runner) RunEnumeration() {
12591260
wg, _ := syncutil.New(syncutil.WithSize(r.options.Threads))
12601261
retryCh := make(chan retryJob)
12611262

1262-
retryCancel, retryWait := r.retryLoop(context.Background(), retryCh, output, r.analyze)
1263+
_, drainedCh := r.retryLoop(context.Background(), retryCh, output, r.analyze)
12631264

12641265
processItem := func(k string) error {
12651266
if r.options.resumeCfg != nil {
@@ -1302,11 +1303,9 @@ func (r *Runner) RunEnumeration() {
13021303
}
13031304

13041305
wg.Wait()
1305-
1306-
retryWait()
1307-
retryCancel()
1308-
close(retryCh)
1309-
1306+
if r.options.RetryRounds > 0 {
1307+
<-drainedCh
1308+
}
13101309
close(output)
13111310
wgoutput.Wait()
13121311

@@ -1333,24 +1332,30 @@ type analyzeFunc func(*httpx.HTTPX, string, httpx.Target, string, string, *ScanO
13331332

13341333
func (r *Runner) retryLoop(
13351334
parent context.Context,
1336-
ch chan retryJob,
1335+
retryCh chan retryJob,
13371336
output chan<- Result,
13381337
analyze analyzeFunc,
1339-
) (func(), func()) {
1338+
) (stop func(), drained <-chan struct{}) {
1339+
var remaining atomic.Int64
13401340
ctx, cancel := context.WithCancel(parent)
1341-
var jobWG sync.WaitGroup
1341+
drainedCh := make(chan struct{})
13421342

13431343
go func() {
1344+
defer close(retryCh)
1345+
13441346
for {
13451347
select {
13461348
case <-ctx.Done():
13471349
return
1348-
case job := <-ch:
1349-
jobWG.Add(1)
1350+
case job, ok := <-retryCh:
1351+
if !ok {
1352+
return
1353+
}
1354+
if job.attempt == 1 {
1355+
remaining.Add(1)
1356+
}
13501357

13511358
go func(j retryJob) {
1352-
defer jobWG.Done()
1353-
13541359
if wait := time.Until(j.when); wait > 0 {
13551360
timer := time.NewTimer(wait)
13561361
select {
@@ -1364,28 +1369,27 @@ func (r *Runner) retryLoop(
13641369
res := analyze(j.hp, j.protocol, j.target, j.method, j.origInput, j.scanopts)
13651370
output <- res
13661371

1367-
if res.StatusCode == http.StatusTooManyRequests {
1368-
if j.attempt >= r.options.RetryRounds {
1369-
return
1370-
}
1371-
1372+
if res.StatusCode == http.StatusTooManyRequests && j.attempt < r.options.RetryRounds {
13721373
j.attempt++
13731374
j.when = time.Now().Add(time.Duration(r.options.RetryDelay) * time.Millisecond)
13741375

13751376
select {
13761377
case <-ctx.Done():
13771378
return
1378-
case ch <- j:
1379+
case retryCh <- j:
1380+
return
13791381
}
13801382
}
1383+
1384+
if remaining.Add(-1) == 0 {
1385+
close(drainedCh)
1386+
}
13811387
}(job)
13821388
}
13831389
}
13841390
}()
13851391

1386-
stop := func() { cancel() }
1387-
wait := func() { jobWG.Wait() }
1388-
return stop, wait
1392+
return func() { cancel() }, drainedCh
13891393
}
13901394

13911395
func logFilteredErrorPage(fileName, url string) {

runner/runner_test.go

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package runner
33
import (
44
"context"
55
"fmt"
6-
"log"
76
"net/http"
87
"net/http/httptest"
98
"os"
@@ -324,41 +323,34 @@ func TestRunner_Process_And_RetryLoop(t *testing.T) {
324323
var hits1, hits2 int32
325324
srv1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
326325
if atomic.AddInt32(&hits1, 1) != 4 {
327-
log.Println("serv1 429")
328326
w.WriteHeader(http.StatusTooManyRequests)
329327
return
330328
}
331-
log.Println("serv1 200")
332329
w.WriteHeader(http.StatusOK)
333330
}))
334331
defer srv1.Close()
335332

336333
srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
337334
if atomic.AddInt32(&hits2, 1) != 3 {
338-
log.Println("serv2 429")
339335
w.WriteHeader(http.StatusTooManyRequests)
340336
return
341337
}
342-
log.Println("serv2 200")
343338
w.WriteHeader(http.StatusOK)
344339
}))
345340
defer srv2.Close()
346341

347342
r, err := New(&Options{
348343
Threads: 1,
349-
Delay: 0,
350-
RetryRounds: 3,
351-
RetryDelay: 200, // Duration 권장
352-
Timeout: 2,
344+
RetryRounds: 2,
345+
RetryDelay: 5,
346+
Timeout: 3,
353347
})
354348
require.NoError(t, err)
355349

356350
output := make(chan Result)
357351
retryCh := make(chan retryJob)
358352

359-
// ctx, timeout := context.WithTimeout(context.Background(), time.Duration(r.options.Timeout))
360-
// defer timeout()
361-
cancel, wait := r.retryLoop(context.Background(), retryCh, output, r.analyze)
353+
_, drainedCh := r.retryLoop(context.Background(), retryCh, output, r.analyze)
362354

363355
wg, _ := syncutil.New(syncutil.WithSize(r.options.Threads))
364356
so := r.scanopts.Clone()
@@ -399,15 +391,12 @@ func TestRunner_Process_And_RetryLoop(t *testing.T) {
399391
}
400392

401393
wg.Wait()
402-
wait()
403-
cancel()
404-
405-
close(retryCh)
394+
<-drainedCh
406395
close(output)
407396
drainWG.Wait()
408397

409398
require.Equal(t, 3, s1n429)
410-
require.Equal(t, 1, s1n200)
399+
require.Equal(t, 0, s1n200)
411400
require.Equal(t, 2, s2n429)
412401
require.Equal(t, 1, s2n200)
413402
}

0 commit comments

Comments
 (0)