Skip to content

Commit 2407aba

Browse files
zhangenyaozey1996
authored andcommitted
add timeout handler and max rtt option
you can set MaxRtt and OnTimeout func Ontime func while be call when a request was not answered within a specified time Signed-off-by: zhangenyao <[email protected]>
1 parent 23b417c commit 2407aba

File tree

5 files changed

+226
-20
lines changed

5 files changed

+226
-20
lines changed

cmd/ping/ping.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ Examples:
3838

3939
func main() {
4040
timeout := flag.Duration("t", time.Second*100000, "")
41+
maxRtt := flag.Duration("mr", time.Second*3, "")
4142
interval := flag.Duration("i", time.Second, "")
4243
count := flag.Int("c", -1, "")
4344
size := flag.Int("s", 24, "")
@@ -84,11 +85,15 @@ func main() {
8485
fmt.Printf("round-trip min/avg/max/stddev = %v/%v/%v/%v\n",
8586
stats.MinRtt, stats.AvgRtt, stats.MaxRtt, stats.StdDevRtt)
8687
}
88+
pinger.OnTimeOut = func(packet *probing.Packet) {
89+
fmt.Println("timeout", packet.Addr, packet.Rtt, packet.TTL)
90+
}
8791

8892
pinger.Count = *count
8993
pinger.Size = *size
9094
pinger.Interval = *interval
9195
pinger.Timeout = *timeout
96+
pinger.MaxRtt = *maxRtt
9297
pinger.TTL = *ttl
9398
pinger.SetPrivileged(*privileged)
9499

ping.go

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -92,24 +92,22 @@ var (
9292
func New(addr string) *Pinger {
9393
r := rand.New(rand.NewSource(getSeed()))
9494
firstUUID := uuid.New()
95-
var firstSequence = map[uuid.UUID]map[int]struct{}{}
96-
firstSequence[firstUUID] = make(map[int]struct{})
9795
return &Pinger{
9896
Count: -1,
9997
Interval: time.Second,
10098
RecordRtts: true,
10199
Size: timeSliceLength + trackerLength,
102100
Timeout: time.Duration(math.MaxInt64),
101+
MaxRtt: time.Duration(math.MaxInt64),
103102

104103
addr: addr,
105104
done: make(chan interface{}),
106105
id: r.Intn(math.MaxUint16),
107-
trackerUUIDs: []uuid.UUID{firstUUID},
108106
ipaddr: nil,
109107
ipv4: false,
110108
network: "ip",
111109
protocol: "udp",
112-
awaitingSequences: firstSequence,
110+
awaitingSequences: newSeqMap(firstUUID),
113111
TTL: 64,
114112
logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())},
115113
}
@@ -129,6 +127,9 @@ type Pinger struct {
129127
// Timeout specifies a timeout before ping exits, regardless of how many
130128
// packets have been received.
131129
Timeout time.Duration
130+
// MaxRtt If no response is received after this time, OnTimeout is called
131+
// important! This option is not guaranteed. and if we receive the packet that was timeout, the function OnDuplicateRecv will be called
132+
MaxRtt time.Duration
132133

133134
// Count tells pinger to stop after sending (and receiving) Count echo
134135
// packets. If this option is not specified, pinger will operate until
@@ -183,6 +184,8 @@ type Pinger struct {
183184
// OnRecvError is called when an error occurs while Pinger attempts to receive a packet
184185
OnRecvError func(error)
185186

187+
// OnTimeOut is called when a packet don't have received after MaxRtt.
188+
OnTimeOut func(*Packet)
186189
// Size of packet being sent
187190
Size int
188191

@@ -205,14 +208,11 @@ type Pinger struct {
205208
// df when true sets the do-not-fragment bit in the outer IP or IPv6 header
206209
df bool
207210

208-
// trackerUUIDs is the list of UUIDs being used for sending packets.
209-
trackerUUIDs []uuid.UUID
210-
211211
ipv4 bool
212212
id int
213213
sequence int
214214
// awaitingSequences are in-flight sequence numbers we keep track of to help remove duplicate receipts
215-
awaitingSequences map[uuid.UUID]map[int]struct{}
215+
awaitingSequences seqMap
216216
// network is one of "ip", "ip4", or "ip6".
217217
network string
218218
// protocol is "icmp" or "udp".
@@ -520,20 +520,50 @@ func (p *Pinger) runLoop(
520520

521521
timeout := time.NewTicker(p.Timeout)
522522
interval := time.NewTicker(p.Interval)
523+
timeoutTimer := time.NewTimer(time.Duration(math.MaxInt64))
524+
skip := false
523525
defer func() {
524526
interval.Stop()
525527
timeout.Stop()
528+
timeoutTimer.Stop()
526529
}()
527530

528531
if err := p.sendICMP(conn); err != nil {
529532
return err
530533
}
531534

532535
for {
536+
if !skip {
537+
if !timeoutTimer.Stop() {
538+
<-timeoutTimer.C
539+
}
540+
}
541+
skip = false
542+
first := p.awaitingSequences.peekFirst()
543+
if first != nil {
544+
timeoutTimer.Reset(time.Until(first.time.Add(p.MaxRtt)))
545+
} else {
546+
timeoutTimer.Reset(time.Duration(math.MaxInt64))
547+
}
548+
533549
select {
534550
case <-p.done:
535551
return nil
536552

553+
case <-timeoutTimer.C:
554+
skip = true
555+
p.awaitingSequences.removeElem(first)
556+
if p.OnTimeOut != nil {
557+
inPkt := &Packet{
558+
IPAddr: p.ipaddr,
559+
Addr: p.addr,
560+
Rtt: p.MaxRtt,
561+
Seq: first.seq,
562+
TTL: -1,
563+
ID: p.id,
564+
}
565+
p.OnTimeOut(inPkt)
566+
}
537567
case <-timeout.C:
538568
return nil
539569

@@ -681,7 +711,7 @@ func (p *Pinger) getPacketUUID(pkt []byte) (*uuid.UUID, error) {
681711
return nil, fmt.Errorf("error decoding tracking UUID: %w", err)
682712
}
683713

684-
for _, item := range p.trackerUUIDs {
714+
for _, item := range p.awaitingSequences.trackerUUIDs {
685715
if item == packetUUID {
686716
return &packetUUID, nil
687717
}
@@ -691,7 +721,7 @@ func (p *Pinger) getPacketUUID(pkt []byte) (*uuid.UUID, error) {
691721

692722
// getCurrentTrackerUUID grabs the latest tracker UUID.
693723
func (p *Pinger) getCurrentTrackerUUID() uuid.UUID {
694-
return p.trackerUUIDs[len(p.trackerUUIDs)-1]
724+
return p.awaitingSequences.trackerUUIDs[len(p.awaitingSequences.trackerUUIDs)-1]
695725
}
696726

697727
func (p *Pinger) processPacket(recv *packet) error {
@@ -744,15 +774,16 @@ func (p *Pinger) processPacket(recv *packet) error {
744774
inPkt.Rtt = receivedAt.Sub(timestamp)
745775
inPkt.Seq = pkt.Seq
746776
// If we've already received this sequence, ignore it.
747-
if _, inflight := p.awaitingSequences[*pktUUID][pkt.Seq]; !inflight {
777+
e, inflight := p.awaitingSequences.getElem(*pktUUID, pkt.Seq)
778+
if !inflight {
748779
p.PacketsRecvDuplicates++
749780
if p.OnDuplicateRecv != nil {
750781
p.OnDuplicateRecv(inPkt)
751782
}
752783
return nil
753784
}
754785
// remove it from the list of sequences we're waiting for so we don't get duplicates.
755-
delete(p.awaitingSequences[*pktUUID], pkt.Seq)
786+
p.awaitingSequences.removeElem(e)
756787
p.updateStatistics(inPkt)
757788
default:
758789
// Very bad, not sure how this can happen
@@ -777,7 +808,8 @@ func (p *Pinger) sendICMP(conn packetConn) error {
777808
if err != nil {
778809
return fmt.Errorf("unable to marshal UUID binary: %w", err)
779810
}
780-
t := append(timeToBytes(time.Now()), uuidEncoded...)
811+
now := time.Now()
812+
t := append(timeToBytes(now), uuidEncoded...)
781813
if remainSize := p.Size - timeSliceLength - trackerLength; remainSize > 0 {
782814
t = append(t, bytes.Repeat([]byte{1}, remainSize)...)
783815
}
@@ -829,13 +861,12 @@ func (p *Pinger) sendICMP(conn packetConn) error {
829861
p.OnSend(outPkt)
830862
}
831863
// mark this sequence as in-flight
832-
p.awaitingSequences[currentUUID][p.sequence] = struct{}{}
864+
p.awaitingSequences.putElem(currentUUID, p.sequence, now)
833865
p.PacketsSent++
834866
p.sequence++
835867
if p.sequence > 65535 {
836868
newUUID := uuid.New()
837-
p.trackerUUIDs = append(p.trackerUUIDs, newUUID)
838-
p.awaitingSequences[newUUID] = make(map[int]struct{})
869+
p.awaitingSequences.newSeqMap(newUUID)
839870
p.sequence = 0
840871
}
841872
break

ping_test.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ func TestProcessPacket(t *testing.T) {
2929
if err != nil {
3030
t.Fatalf("unable to marshal UUID binary: %s", err)
3131
}
32-
data := append(timeToBytes(time.Now()), uuidEncoded...)
32+
now := time.Now()
33+
data := append(timeToBytes(now), uuidEncoded...)
3334
if remainSize := pinger.Size - timeSliceLength - trackerLength; remainSize > 0 {
3435
data = append(data, bytes.Repeat([]byte{1}, remainSize)...)
3536
}
@@ -39,7 +40,8 @@ func TestProcessPacket(t *testing.T) {
3940
Seq: pinger.sequence,
4041
Data: data,
4142
}
42-
pinger.awaitingSequences[currentUUID][pinger.sequence] = struct{}{}
43+
pinger.awaitingSequences.putElem(currentUUID, pinger.sequence, now)
44+
//pinger.awaitingSequences[currentUUID][pinger.sequence] = struct{}{}
4345

4446
msg := &icmp.Message{
4547
Type: ipv4.ICMPTypeEchoReply,
@@ -598,7 +600,8 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) {
598600
if err != nil {
599601
t.Fatalf("unable to marshal UUID binary: %s", err)
600602
}
601-
data := append(timeToBytes(time.Now()), uuidEncoded...)
603+
now := time.Now()
604+
data := append(timeToBytes(now), uuidEncoded...)
602605
if remainSize := pinger.Size - timeSliceLength - trackerLength; remainSize > 0 {
603606
data = append(data, bytes.Repeat([]byte{1}, remainSize)...)
604607
}
@@ -609,7 +612,9 @@ func TestProcessPacket_IgnoresDuplicateSequence(t *testing.T) {
609612
Data: data,
610613
}
611614
// register the sequence as sent
612-
pinger.awaitingSequences[currentUUID][0] = struct{}{}
615+
616+
pinger.awaitingSequences.putElem(currentUUID, 0, now)
617+
//pinger.awaitingSequences[currentUUID][0] = struct{}{}
613618

614619
msg := &icmp.Message{
615620
Type: ipv4.ICMPTypeEchoReply,

seq_map.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
package probing
2+
3+
import (
4+
"github.com/google/uuid"
5+
"time"
6+
)
7+
8+
type seqMap struct {
9+
trackerUUIDs []uuid.UUID
10+
head *elem
11+
tail *elem
12+
seqMap map[uuid.UUID]map[int]*elem
13+
}
14+
type elem struct {
15+
uuid uuid.UUID
16+
seq int
17+
time time.Time
18+
prev *elem
19+
next *elem
20+
}
21+
22+
func newSeqMap(u uuid.UUID) seqMap {
23+
s := seqMap{
24+
head: &elem{},
25+
tail: &elem{},
26+
seqMap: map[uuid.UUID]map[int]*elem{},
27+
}
28+
s.trackerUUIDs = append(s.trackerUUIDs, u)
29+
s.seqMap[u] = make(map[int]*elem)
30+
s.head.next = s.tail
31+
s.tail.prev = s.head
32+
return s
33+
}
34+
35+
func (s seqMap) newSeqMap(u uuid.UUID) {
36+
s.trackerUUIDs = append(s.trackerUUIDs, u)
37+
s.seqMap[u] = make(map[int]*elem)
38+
s.head.next = s.tail
39+
s.tail.prev = s.head
40+
}
41+
42+
func (s seqMap) putElem(uuid uuid.UUID, seq int, now time.Time) {
43+
e := &elem{
44+
uuid: uuid,
45+
seq: seq,
46+
time: now,
47+
prev: s.tail.prev,
48+
next: s.tail,
49+
}
50+
s.tail.prev.next = e
51+
s.tail.prev = e
52+
s.seqMap[uuid][seq] = e
53+
}
54+
func (s seqMap) getElem(uuid uuid.UUID, seq int) (*elem, bool) {
55+
e, ok := s.seqMap[uuid][seq]
56+
return e, ok
57+
}
58+
func (s seqMap) removeElem(e *elem) {
59+
e.prev.next = e.next
60+
e.next.prev = e.prev
61+
if m, ok := s.seqMap[e.uuid]; ok {
62+
delete(m, e.seq)
63+
}
64+
}
65+
func (s seqMap) peekFirst() *elem {
66+
if s.head.next == s.tail {
67+
return nil
68+
}
69+
return s.head.next
70+
}

0 commit comments

Comments
 (0)