Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 67 additions & 48 deletions runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -445,15 +445,23 @@ func (r *Runner) prepareInputPaths() {
}
}

var duplicateTargetErr = errors.New("duplicate target")

func (r *Runner) prepareInput() {
var numHosts int
// check if input target host(s) have been provided
if len(r.options.InputTargetHost) > 0 {
for _, target := range r.options.InputTargetHost {
expandedTarget, _ := r.countTargetFromRawTarget(target)
if expandedTarget > 0 {
expandedTarget, err := r.countTargetFromRawTarget(target)
if err == nil && expandedTarget > 0 {
numHosts += expandedTarget
r.hm.Set(target, nil) //nolint
r.hm.Set(target, []byte("1")) //nolint
} else if r.options.SkipDedupe && errors.Is(err, duplicateTargetErr) {
if v, ok := r.hm.Get(target); ok {
cnt, _ := strconv.Atoi(string(v))
r.hm.Set(target, []byte(strconv.Itoa(cnt+1)))
numHosts += 1
}
}
}
}
Expand Down Expand Up @@ -611,10 +619,16 @@ func (r *Runner) loadAndCloseFile(finput *os.File) (numTargets int, err error) {
for scanner.Scan() {
target := strings.TrimSpace(scanner.Text())
// Used just to get the exact number of targets
expandedTarget, _ := r.countTargetFromRawTarget(target)
if expandedTarget > 0 {
expandedTarget, err := r.countTargetFromRawTarget(target)
if err == nil && expandedTarget > 0 {
numTargets += expandedTarget
r.hm.Set(target, nil) //nolint
r.hm.Set(target, []byte("1")) //nolint
} else if r.options.SkipDedupe && errors.Is(err, duplicateTargetErr) {
if v, ok := r.hm.Get(target); ok {
cnt, _ := strconv.Atoi(string(v))
r.hm.Set(target, []byte(strconv.Itoa(cnt+1)))
numTargets += 1
}
}
Comment on lines 654 to 664
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Extract duplicate counting logic to reduce code duplication.

This logic is duplicated from prepareInput. Consider extracting it into a helper method to maintain DRY principles and ensure consistent behavior.

Consider creating a helper method:

func (r *Runner) incrementTargetCount(target string) {
    if v, ok := r.hm.Get(target); ok {
        cnt, err := strconv.Atoi(string(v))
        if err != nil {
            cnt = 1
        }
        r.hm.Set(target, []byte(strconv.Itoa(cnt+1)))
    }
}

Then use it in both locations:

 } else if r.options.SkipDedupe && errors.Is(err, duplicateTargetErr) {
-    if v, ok := r.hm.Get(target); ok {
-        cnt, _ := strconv.Atoi(string(v))
-        r.hm.Set(target, []byte(strconv.Itoa(cnt+1)))
-        numTargets += 1
-    }
+    r.incrementTargetCount(target)
+    numTargets += 1
 }
🤖 Prompt for AI Agents
In runner/runner.go around lines 622 to 632, the logic for incrementing the
count of duplicate targets is duplicated from the prepareInput method. Extract
this duplicate counting logic into a new helper method on the Runner struct, for
example incrementTargetCount, which safely retrieves the current count, handles
conversion errors by defaulting to 1, increments the count, and updates the map.
Replace the duplicated code in both places with calls to this new helper method
to adhere to DRY principles and maintain consistent behavior.

}
err = finput.Close()
Expand All @@ -625,8 +639,9 @@ func (r *Runner) countTargetFromRawTarget(rawTarget string) (numTargets int, err
if rawTarget == "" {
return 0, nil
}

if _, ok := r.hm.Get(rawTarget); ok {
return 0, nil
return 0, duplicateTargetErr
}

expandedTarget := 0
Expand Down Expand Up @@ -1064,42 +1079,12 @@ func (r *Runner) RunEnumeration() {
URL, _ := urlutil.Parse(resp.URL)
domainFile := resp.Method + ":" + URL.EscapedString()
hash := hashes.Sha1([]byte(domainFile))
domainResponseFile := fmt.Sprintf("%s.txt", hash)
screenshotResponseFile := fmt.Sprintf("%s.png", hash)
hostFilename := strings.ReplaceAll(URL.Host, ":", "_")
domainResponseBaseDir := filepath.Join(r.options.StoreResponseDir, "response")
domainScreenshotBaseDir := filepath.Join(r.options.StoreResponseDir, "screenshot")
responseBaseDir := filepath.Join(domainResponseBaseDir, hostFilename)
screenshotBaseDir := filepath.Join(domainScreenshotBaseDir, hostFilename)

var responsePath, screenshotPath, screenshotPathRel string
// store response
if r.scanopts.StoreResponse || r.scanopts.StoreChain {
if r.scanopts.OmitBody {
resp.Raw = strings.Replace(resp.Raw, resp.ResponseBody, "", -1)
}

responsePath = fileutilz.AbsPathOrDefault(filepath.Join(responseBaseDir, domainResponseFile))
// URL.EscapedString returns that can be used as filename
respRaw := resp.Raw
reqRaw := resp.RequestRaw
if len(respRaw) > r.scanopts.MaxResponseBodySizeToSave {
respRaw = respRaw[:r.scanopts.MaxResponseBodySizeToSave]
}
data := reqRaw
if r.options.StoreChain && resp.Response != nil && resp.Response.HasChain() {
data = append(data, append([]byte("\n"), []byte(resp.Response.GetChain())...)...)
}
data = append(data, respRaw...)
data = append(data, []byte("\n\n\n")...)
data = append(data, []byte(resp.URL)...)
_ = fileutil.CreateFolder(responseBaseDir)
writeErr := os.WriteFile(responsePath, data, 0644)
if writeErr != nil {
gologger.Error().Msgf("Could not write response at path '%s', to disk: %s", responsePath, writeErr)
}
resp.StoredResponsePath = responsePath
}
var screenshotPath, screenshotPathRel string

if r.scanopts.Screenshot {
screenshotPath = fileutilz.AbsPathOrDefault(filepath.Join(screenshotBaseDir, screenshotResponseFile))
Expand Down Expand Up @@ -1257,14 +1242,28 @@ func (r *Runner) RunEnumeration() {
}
}

if len(r.options.requestURIs) > 0 {
for _, p := range r.options.requestURIs {
scanopts := r.scanopts.Clone()
scanopts.RequestURI = p
r.process(k, wg, r.hp, protocol, scanopts, output)
runProcess := func(times int) {
for i := 0; i < times; i++ {
if len(r.options.requestURIs) > 0 {
for _, p := range r.options.requestURIs {
scanopts := r.scanopts.Clone()
scanopts.RequestURI = p
r.process(k, wg, r.hp, protocol, scanopts, output)
}
} else {
r.process(k, wg, r.hp, protocol, &r.scanopts, output)
}
}
} else {
r.process(k, wg, r.hp, protocol, &r.scanopts, output)
}

if r.options.Stream {
runProcess(1)
} else if v, ok := r.hm.Get(k); ok {
cnt, err := strconv.Atoi(string(v))
if err != nil || cnt <= 0 {
cnt = 1
}
runProcess(cnt)
}

return nil
Expand Down Expand Up @@ -2150,9 +2149,29 @@ retry:
data = append(data, []byte("\n\n\n")...)
data = append(data, []byte(fullURL)...)
_ = fileutil.CreateFolder(responseBaseDir)
writeErr := os.WriteFile(responsePath, data, 0644)
if writeErr != nil {
gologger.Error().Msgf("Could not write response at path '%s', to disk: %s", responsePath, writeErr)

finalPath := responsePath
idx := 0
for {
targetPath := finalPath
if idx > 0 {
basePath := strings.TrimSuffix(responsePath, ".txt")
targetPath = fmt.Sprintf("%s_%d.txt", basePath, idx)
}
f, err := os.OpenFile(targetPath, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0644)
if err == nil {
_, writeErr := f.Write(data)
f.Close()
if writeErr != nil {
gologger.Error().Msgf("Could not write to '%s': %s", targetPath, writeErr)
}
break
}
if !os.IsExist(err) {
gologger.Error().Msgf("Failed to create file '%s': %s", targetPath, err)
break
}
idx++
}
}

Expand Down
7 changes: 5 additions & 2 deletions runner/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"testing"
"time"

"github.com/pkg/errors"
_ "github.com/projectdiscovery/fdmax/autofdmax"
"github.com/projectdiscovery/httpx/common/httpx"
"github.com/projectdiscovery/mapcidr/asn"
Expand Down Expand Up @@ -124,7 +125,9 @@ func TestRunner_asn_targets(t *testing.T) {
}

func TestRunner_countTargetFromRawTarget(t *testing.T) {
options := &Options{}
options := &Options{
SkipDedupe: false,
}
r, err := New(options)
require.Nil(t, err, "could not create httpx runner")

Expand All @@ -139,7 +142,7 @@ func TestRunner_countTargetFromRawTarget(t *testing.T) {
err = r.hm.Set(input, nil)
require.Nil(t, err, "could not set value to hm")
got, err = r.countTargetFromRawTarget(input)
require.Nil(t, err, "could not count targets")
require.True(t, errors.Is(err, duplicateTargetErr), "expected duplicate target error")
require.Equal(t, expected, got, "got wrong output")

input = "173.0.84.0/24"
Expand Down