Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions cmd/ping/ping.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Examples:

func main() {
timeout := flag.Duration("t", time.Second*100000, "")
maxRtt := flag.Duration("mr", time.Second*3, "")
interval := flag.Duration("i", time.Second, "")
count := flag.Int("c", -1, "")
size := flag.Int("s", 24, "")
Expand Down Expand Up @@ -84,11 +85,15 @@ func main() {
fmt.Printf("round-trip min/avg/max/stddev = %v/%v/%v/%v\n",
stats.MinRtt, stats.AvgRtt, stats.MaxRtt, stats.StdDevRtt)
}
pinger.OnTimeOut = func(packet *probing.Packet) {
fmt.Println("timeout", packet.Addr, packet.Rtt, packet.TTL)
}

pinger.Count = *count
pinger.Size = *size
pinger.Interval = *interval
pinger.Timeout = *timeout
pinger.MaxRtt = *maxRtt
pinger.TTL = *ttl
pinger.SetPrivileged(*privileged)

Expand Down
63 changes: 41 additions & 22 deletions ping.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,24 +92,21 @@ var (
func New(addr string) *Pinger {
r := rand.New(rand.NewSource(getSeed()))
firstUUID := uuid.New()
var firstSequence = map[uuid.UUID]map[int]struct{}{}
firstSequence[firstUUID] = make(map[int]struct{})
return &Pinger{
Count: -1,
Interval: time.Second,
RecordRtts: true,
Size: timeSliceLength + trackerLength,
Timeout: time.Duration(math.MaxInt64),

Count: -1,
Interval: time.Second,
RecordRtts: true,
Size: timeSliceLength + trackerLength,
Timeout: time.Duration(math.MaxInt64),
MaxRtt: time.Duration(math.MaxInt64),
addr: addr,
done: make(chan interface{}),
id: r.Intn(math.MaxUint16),
trackerUUIDs: []uuid.UUID{firstUUID},
ipaddr: nil,
ipv4: false,
network: "ip",
protocol: "udp",
awaitingSequences: firstSequence,
awaitingSequences: newSeqMap(firstUUID),
TTL: 64,
logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())},
}
Expand All @@ -129,6 +126,9 @@ type Pinger struct {
// Timeout specifies a timeout before ping exits, regardless of how many
// packets have been received.
Timeout time.Duration
// MaxRtt If no response is received after this time, OnTimeout is called
// important! This option is not guaranteed. and if we receive the packet that was timeout, the function OnDuplicateRecv will be called
MaxRtt time.Duration

// Count tells pinger to stop after sending (and receiving) Count echo
// packets. If this option is not specified, pinger will operate until
Expand Down Expand Up @@ -183,6 +183,9 @@ type Pinger struct {
// OnRecvError is called when an error occurs while Pinger attempts to receive a packet
OnRecvError func(error)

// OnTimeOut is called when a packet don't have received after MaxRtt.
OnTimeOut func(*Packet)

// Size of packet being sent
Size int

Expand All @@ -206,13 +209,12 @@ type Pinger struct {
df bool

// trackerUUIDs is the list of UUIDs being used for sending packets.
trackerUUIDs []uuid.UUID

ipv4 bool
id int
sequence int
// awaitingSequences are in-flight sequence numbers we keep track of to help remove duplicate receipts
awaitingSequences map[uuid.UUID]map[int]struct{}
awaitingSequences seqMap
// network is one of "ip", "ip4", or "ip6".
network string
// protocol is "icmp" or "udp".
Expand Down Expand Up @@ -530,10 +532,27 @@ func (p *Pinger) runLoop(
}

for {
subTime := time.Duration(math.MaxInt64)
first := p.awaitingSequences.peekFirst()
if first != nil {
subTime = time.Until(first.time.Add(p.MaxRtt))
}

select {
case <-p.done:
return nil

case <-time.After(subTime):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure about the performance implications of this. IIRC this will spawn a new sleeping goroutine for every packet. If there are a lot of outstanding packets it could bloat memory quite a lot.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I create a new pr: #49 and fix some problem

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure about it. but I rewrite it use time.timer

p.awaitingSequences.removeElem(first)
if p.OnTimeOut != nil {
inPkt := &Packet{
IPAddr: p.ipaddr,
Addr: p.addr,
Rtt: p.MaxRtt,
TTL: -1,
ID: p.id,
}
p.OnTimeOut(inPkt)
}
case <-timeout.C:
return nil

Expand Down Expand Up @@ -681,7 +700,7 @@ func (p *Pinger) getPacketUUID(pkt []byte) (*uuid.UUID, error) {
return nil, fmt.Errorf("error decoding tracking UUID: %w", err)
}

for _, item := range p.trackerUUIDs {
for _, item := range p.awaitingSequences.trackerUUIDs {
if item == packetUUID {
return &packetUUID, nil
}
Expand All @@ -691,7 +710,7 @@ func (p *Pinger) getPacketUUID(pkt []byte) (*uuid.UUID, error) {

// getCurrentTrackerUUID grabs the latest tracker UUID.
func (p *Pinger) getCurrentTrackerUUID() uuid.UUID {
return p.trackerUUIDs[len(p.trackerUUIDs)-1]
return p.awaitingSequences.trackerUUIDs[len(p.awaitingSequences.trackerUUIDs)-1]
}

func (p *Pinger) processPacket(recv *packet) error {
Expand Down Expand Up @@ -744,15 +763,16 @@ func (p *Pinger) processPacket(recv *packet) error {
inPkt.Rtt = receivedAt.Sub(timestamp)
inPkt.Seq = pkt.Seq
// If we've already received this sequence, ignore it.
if _, inflight := p.awaitingSequences[*pktUUID][pkt.Seq]; !inflight {
e, inflight := p.awaitingSequences.getElem(*pktUUID, pkt.Seq)
if !inflight {
p.PacketsRecvDuplicates++
if p.OnDuplicateRecv != nil {
p.OnDuplicateRecv(inPkt)
}
return nil
}
// remove it from the list of sequences we're waiting for so we don't get duplicates.
delete(p.awaitingSequences[*pktUUID], pkt.Seq)
p.awaitingSequences.removeElem(e)
p.updateStatistics(inPkt)
default:
// Very bad, not sure how this can happen
Expand All @@ -777,11 +797,11 @@ func (p *Pinger) sendICMP(conn packetConn) error {
if err != nil {
return fmt.Errorf("unable to marshal UUID binary: %w", err)
}
t := append(timeToBytes(time.Now()), uuidEncoded...)
now := time.Now()
t := append(timeToBytes(now), uuidEncoded...)
if remainSize := p.Size - timeSliceLength - trackerLength; remainSize > 0 {
t = append(t, bytes.Repeat([]byte{1}, remainSize)...)
}

body := &icmp.Echo{
ID: p.id,
Seq: p.sequence,
Expand Down Expand Up @@ -829,13 +849,12 @@ func (p *Pinger) sendICMP(conn packetConn) error {
p.OnSend(outPkt)
}
// mark this sequence as in-flight
p.awaitingSequences[currentUUID][p.sequence] = struct{}{}
p.awaitingSequences.putElem(currentUUID, p.sequence, now)
p.PacketsSent++
p.sequence++
if p.sequence > 65535 {
newUUID := uuid.New()
p.trackerUUIDs = append(p.trackerUUIDs, newUUID)
p.awaitingSequences[newUUID] = make(map[int]struct{})
p.awaitingSequences.newSeqMap(newUUID)
p.sequence = 0
}
break
Expand Down
13 changes: 9 additions & 4 deletions ping_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ func TestProcessPacket(t *testing.T) {
if err != nil {
t.Fatalf("unable to marshal UUID binary: %s", err)
}
data := append(timeToBytes(time.Now()), uuidEncoded...)
now := time.Now()
data := append(timeToBytes(now), uuidEncoded...)
if remainSize := pinger.Size - timeSliceLength - trackerLength; remainSize > 0 {
data = append(data, bytes.Repeat([]byte{1}, remainSize)...)
}
Expand All @@ -39,7 +40,8 @@ func TestProcessPacket(t *testing.T) {
Seq: pinger.sequence,
Data: data,
}
pinger.awaitingSequences[currentUUID][pinger.sequence] = struct{}{}
pinger.awaitingSequences.putElem(currentUUID, pinger.sequence, now)
//pinger.awaitingSequences[currentUUID][pinger.sequence] = struct{}{}

msg := &icmp.Message{
Type: ipv4.ICMPTypeEchoReply,
Expand Down Expand Up @@ -598,7 +600,8 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) {
if err != nil {
t.Fatalf("unable to marshal UUID binary: %s", err)
}
data := append(timeToBytes(time.Now()), uuidEncoded...)
now := time.Now()
data := append(timeToBytes(now), uuidEncoded...)
if remainSize := pinger.Size - timeSliceLength - trackerLength; remainSize > 0 {
data = append(data, bytes.Repeat([]byte{1}, remainSize)...)
}
Expand All @@ -609,7 +612,9 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) {
Data: data,
}
// register the sequence as sent
pinger.awaitingSequences[currentUUID][0] = struct{}{}

pinger.awaitingSequences.putElem(currentUUID, 0, now)
//pinger.awaitingSequences[currentUUID][0] = struct{}{}

msg := &icmp.Message{
Type: ipv4.ICMPTypeEchoReply,
Expand Down
70 changes: 70 additions & 0 deletions seq_map.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package probing

import (
"github.com/google/uuid"
"time"
)

type seqMap struct {
trackerUUIDs []uuid.UUID
head *elem
tail *elem
seqMap map[uuid.UUID]map[int]*elem
}
type elem struct {
uuid uuid.UUID
seq int
time time.Time
prev *elem
next *elem
}

func newSeqMap(u uuid.UUID) seqMap {
s := seqMap{
head: &elem{},
tail: &elem{},
seqMap: map[uuid.UUID]map[int]*elem{},
}
s.trackerUUIDs = append(s.trackerUUIDs, u)
s.seqMap[u] = make(map[int]*elem)
s.head.next = s.tail
s.tail.prev = s.head
return s
}

func (s seqMap) newSeqMap(u uuid.UUID) {
s.trackerUUIDs = append(s.trackerUUIDs, u)
s.seqMap[u] = make(map[int]*elem)
s.head.next = s.tail
s.tail.prev = s.head
}

func (s seqMap) putElem(uuid uuid.UUID, seq int, now time.Time) {
e := &elem{
uuid: uuid,
seq: seq,
time: now,
prev: s.tail.prev,
next: s.tail,
}
s.tail.prev.next = e
s.tail.prev = e
s.seqMap[uuid][seq] = e
}
func (s seqMap) getElem(uuid uuid.UUID, seq int) (*elem, bool) {
e, ok := s.seqMap[uuid][seq]
return e, ok
}
func (s seqMap) removeElem(e *elem) {
e.prev.next = e.next
e.next.prev = e.prev
if m, ok := s.seqMap[e.uuid]; ok {
delete(m, e.seq)
}
}
func (s seqMap) peekFirst() *elem {
if s.head.next == s.tail {
return nil
}
return s.head.next
}
95 changes: 95 additions & 0 deletions seq_map_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package probing

import (
"github.com/google/uuid"
"testing"
"time"
)

func TestSeqMap(t *testing.T) {
u := uuid.New()
s := newSeqMap(u)
t.Run("newSeqMap", func(t *testing.T) {
u2 := uuid.New()
s.newSeqMap(u2)

if len(s.trackerUUIDs) != 1 {
t.Errorf("Expected length of trackerUUIDs to be 2, got %d", len(s.trackerUUIDs))
}

if _, ok := s.seqMap[u2]; !ok {
t.Errorf("Expected seqMap to contain UUID %s", u2.String())
}
})

t.Run("putElem", func(t *testing.T) {
seq := 1
s.putElem(u, seq, time.Now())

// 检查 seqMap 中是否包含正确的元素
if _, ok := s.seqMap[u][seq]; !ok {
t.Errorf("Expected seqMap[%s][%d] to exist", u.String(), seq)
}

if s.peekFirst().seq != seq {
t.Errorf("Expected tail.prev.seq to be %d, got %d", seq, s.tail.prev.seq)
}
})

t.Run("getElem", func(t *testing.T) {
seq := 1
elem, ok := s.getElem(u, seq)

if !ok {
t.Errorf("Expected getElem to return true for existing element")
}

if elem.seq != seq {
t.Errorf("Expected element's seq to be %d, got %d", seq, elem.seq)
}
})

t.Run("removeElem", func(t *testing.T) {
seq := 1
elem, ok := s.getElem(u, seq)
if !ok {
t.Fatalf("Expected getElem to return true for existing element")
}

s.removeElem(elem)

if _, ok := s.seqMap[u][seq]; ok {
t.Errorf("Expected seqMap[%s][%d] to be removed", u.String(), seq)
}
})

// test peekFirst
t.Run("peekFirst", func(t *testing.T) {
seq := 2
s.putElem(u, seq, time.Now())

elem := s.peekFirst()

// 检查 peekFirst 是否返回链表的第一个元素
if elem.seq != seq {
t.Errorf("Expected peekFirst to return element with seq %d, got %d", seq, elem.seq)
}
})
}

func TestSeqMap2(t *testing.T) {
u := uuid.New()
s := newSeqMap(u)
for i := 0; i < 100; i++ {
s.putElem(u, i, time.Now())
}
for i := 0; i < 100; i++ {
e, ok := s.getElem(u, i)
AssertTrue(t, ok && e.seq == i)
}
for i := 0; i < 20; i++ {
first := s.peekFirst()
AssertTrue(t, first.seq == i)
s.removeElem(first)
}
}