From 676d37c371c28d39b4420b0cdb44e225e84857b8 Mon Sep 17 00:00:00 2001 From: xgfone Date: Tue, 8 Oct 2019 14:15:06 +0800 Subject: [PATCH 1/5] add CheckResponse to check the id and the seq --- ping.go | 45 +++++++++++++++++++++++++++----------- ping_test.go | 62 +++++++++++++++++++++++++++++++++++++++------------- 2 files changed, 79 insertions(+), 28 deletions(-) diff --git a/ping.go b/ping.go index 5da9bd5..f77e8aa 100644 --- a/ping.go +++ b/ping.go @@ -50,10 +50,19 @@ type Response struct { } // Client is a ping client. -type Client struct{} +type Client struct { + // CheckResponse checks whether the response is valid. + CheckResponse func(*Response) bool +} // DefaultClient is the default client used by Do. -var DefaultClient = &Client{} +var DefaultClient = &Client{CheckResponse: CheckResponse} + +// CheckResponse checks whether the id and the seq are equal respectively +// between request and response. +func CheckResponse(resp *Response) bool { + return resp.ID == resp.Req.ID && int(resp.Seq) == resp.Req.Seq +} // NewRequest resolves dst as an IPv4 address and returns a pointer to a request // using that as the destination. @@ -118,14 +127,13 @@ func (c *Client) Do(ctx context.Context, req *Request) (*Response, error) { return nil, err } - resp, readErr = read(ctx, conn) + resp, readErr = read(ctx, conn, req, c.CheckResponse) if readErr != nil { return nil, readErr } resp.RTT = resp.rcvdAt.Sub(sentAt) req.sentAt = sentAt - resp.Req = req if readErr != nil { return nil, readErr @@ -199,18 +207,20 @@ func (req *Request) proto() int { return protocolIPv4ICMP } -func read(ctx context.Context, conn *icmp.PacketConn) (*Response, error) { +func read(ctx context.Context, conn *icmp.PacketConn, req *Request, + check func(*Response) bool) (*Response, error) { if c4 := conn.IPv4PacketConn(); c4 != nil { - return read4(ctx, c4) + return read4(ctx, c4, req, check) } c6 := conn.IPv6PacketConn() if c6 == nil { return nil, errors.New("bad icmp connection type") } - return read6(ctx, c6) + return read6(ctx, c6, req, check) } -func read4(ctx context.Context, conn *ipv4.PacketConn) (*Response, error) { +func read4(ctx context.Context, conn *ipv4.PacketConn, req *Request, + check func(*Response) bool) (*Response, error) { for { select { case <-ctx.Done(): @@ -261,7 +271,8 @@ func read4(ctx context.Context, conn *ipv4.PacketConn) (*Response, error) { srcHost, _, _ := net.SplitHostPort(src.String()) dstHost, _, _ := net.SplitHostPort(conn.LocalAddr().String()) - return &Response{ + resp := &Response{ + Req: req, ID: id, Seq: seq, Data: bytesReceived[:n], @@ -270,12 +281,16 @@ func read4(ctx context.Context, conn *ipv4.PacketConn) (*Response, error) { Dst: net.ParseIP(dstHost), TTL: ttl, rcvdAt: rcv, - }, nil + } + if check == nil || check(resp) { + return resp, nil + } } } } -func read6(ctx context.Context, conn *ipv6.PacketConn) (*Response, error) { +func read6(ctx context.Context, conn *ipv6.PacketConn, req *Request, + check func(*Response) bool) (*Response, error) { for { select { case <-ctx.Done(): @@ -326,7 +341,8 @@ func read6(ctx context.Context, conn *ipv6.PacketConn) (*Response, error) { srcHost, _, _ := net.SplitHostPort(src.String()) dstHost, _, _ := net.SplitHostPort(conn.LocalAddr().String()) - return &Response{ + resp := &Response{ + Req: req, ID: id, Seq: seq, Data: bytesReceived[:n], @@ -335,7 +351,10 @@ func read6(ctx context.Context, conn *ipv6.PacketConn) (*Response, error) { Dst: net.ParseIP(dstHost), TTL: ttl, rcvdAt: rcv, - }, nil + } + if check == nil || check(resp) { + return resp, nil + } } } } diff --git a/ping_test.go b/ping_test.go index f589eaf..04c2d04 100644 --- a/ping_test.go +++ b/ping_test.go @@ -1,4 +1,4 @@ -package ping_test +package ping import ( "context" @@ -8,13 +8,45 @@ import ( "sync" "testing" "time" - - "github.com/glinton/ping" ) -func TestE2E(t *testing.T) { - c := &ping.Client{} +func TestClient(t *testing.T) { + req1, err := NewRequest("www.baidu.com") + if err != nil { + t.Error(err) + return + } + req1.ID = 101 + req1.Seq = 201 + + req2 := *req1 + req2.ID = 102 + req2.Seq = 202 + + req3 := *req1 + req3.ID = 103 + req3.Seq = 203 + + wg := new(sync.WaitGroup) + wg.Add(3) + go testClient(t, wg, req1) + go testClient(t, wg, &req2) + go testClient(t, wg, &req3) + wg.Wait() +} +func testClient(t *testing.T, wg *sync.WaitGroup, req *Request) { + defer wg.Done() + + resp, err := Do(context.Background(), req) + if err != nil { + t.Error(err) + } else if resp.ID != req.ID || int(resp.Seq) != req.Seq { + t.Error(req, resp) + } +} + +func TestE2E(t *testing.T) { hostIPs := []string{"8.8.8.8", "8.8.4.4", "1.1.1.1"} count := 3 deadline := time.Second * 5 @@ -34,7 +66,7 @@ func TestE2E(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), deadline) defer cancel() - resps := make(chan *ping.Response, count) + resps := make(chan *Response, count) packetsSent := 0 for count == 0 || packetsSent < count { @@ -50,7 +82,7 @@ func TestE2E(t *testing.T) { wg.Add(1) go func(seq int) { defer wg.Done() - resp, err := c.Do(ctx, &ping.Request{ + resp, err := DefaultClient.Do(ctx, &Request{ Dst: net.ParseIP(host), Seq: seq, }) @@ -68,7 +100,7 @@ func TestE2E(t *testing.T) { wg.Wait() close(resps) - rsps := []*ping.Response{} + rsps := []*Response{} for res := range resps { rsps = append(rsps, res) } @@ -79,12 +111,12 @@ func TestE2E(t *testing.T) { pwg.Wait() } -func onRcv(res *ping.Response) { +func onRcv(res *Response) { fmt.Printf("%d bytes from %s: icmp_seq=%d time=%v ttl=%v\n", res.TotalLength, res.Src.String(), res.Seq, res.RTT, res.TTL) } -func onFin(packetsSent int, resps []*ping.Response) { +func onFin(packetsSent int, resps []*Response) { if len(resps) == 0 { fmt.Println("Sent:", packetsSent, "Received: 0") return @@ -119,12 +151,12 @@ func onFin(packetsSent int, resps []*ping.Response) { } func ExampleDo() { - req, err := ping.NewRequest("localhost") + req, err := NewRequest("localhost") if err != nil { panic(err) } - res, err := ping.Do(context.Background(), req) + res, err := Do(context.Background(), req) if err != nil { panic(err) } @@ -134,7 +166,7 @@ func ExampleDo() { } func ExampleIPv4() { - res, err := ping.IPv4(context.Background(), "google.com") + res, err := IPv4(context.Background(), "google.com") if err != nil { panic(err) } @@ -144,7 +176,7 @@ func ExampleIPv4() { } func ExampleNewRequest_withSource() { - req, err := ping.NewRequest("localhost") + req, err := NewRequest("localhost") if err != nil { panic(err) } @@ -153,7 +185,7 @@ func ExampleNewRequest_withSource() { // and want to ping from a specific interface, set the source. req.Src = net.ParseIP("127.0.0.2") - res, err := ping.Do(context.Background(), req) + res, err := Do(context.Background(), req) if err != nil { panic(err) } From 1a63ebb450f626a15e506275e74c13db6adeb83e Mon Sep 17 00:00:00 2001 From: xgfone Date: Tue, 8 Oct 2019 15:02:16 +0800 Subject: [PATCH 2/5] use the origin address when SplitHostPort fails --- cmd/ping/main.go | 1 + ping.go | 24 ++++++++++++++++++++---- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/cmd/ping/main.go b/cmd/ping/main.go index ad0a891..1ac429c 100644 --- a/cmd/ping/main.go +++ b/cmd/ping/main.go @@ -93,6 +93,7 @@ func main() { } req := ping.Request{ + ID: os.Getpid(), Dst: net.ParseIP(host.String()), Src: net.ParseIP(getAddr(*iface)), Data: data, diff --git a/ping.go b/ping.go index f77e8aa..38f5f94 100644 --- a/ping.go +++ b/ping.go @@ -269,8 +269,16 @@ func read4(ctx context.Context, conn *ipv4.PacketConn, req *Request, ttl = cm.TTL } - srcHost, _, _ := net.SplitHostPort(src.String()) - dstHost, _, _ := net.SplitHostPort(conn.LocalAddr().String()) + srcHost, _, err := net.SplitHostPort(src.String()) + if err != nil { + srcHost = src.String() + } + + dstHost, _, err := net.SplitHostPort(conn.LocalAddr().String()) + if err != nil { + dstHost = conn.LocalAddr().String() + } + resp := &Response{ Req: req, ID: id, @@ -339,8 +347,16 @@ func read6(ctx context.Context, conn *ipv6.PacketConn, req *Request, ttl = cm.HopLimit } - srcHost, _, _ := net.SplitHostPort(src.String()) - dstHost, _, _ := net.SplitHostPort(conn.LocalAddr().String()) + srcHost, _, err := net.SplitHostPort(src.String()) + if err != nil { + srcHost = src.String() + } + + dstHost, _, err := net.SplitHostPort(conn.LocalAddr().String()) + if err != nil { + dstHost = conn.LocalAddr().String() + } + resp := &Response{ Req: req, ID: id, From a3c3f914b85dc79a4df0fd089fdb10ebad9f39d5 Mon Sep 17 00:00:00 2001 From: xgfone Date: Tue, 8 Oct 2019 17:03:56 +0800 Subject: [PATCH 3/5] remove the redundant codes --- ping.go | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/ping.go b/ping.go index 38f5f94..bec7fa5 100644 --- a/ping.go +++ b/ping.go @@ -117,28 +117,19 @@ func (c *Client) Do(ctx context.Context, req *Request) (*Response, error) { } } - var ( - resp *Response - readErr error - ) - sentAt, err := send(ctx, conn, req) if err != nil { return nil, err } - resp, readErr = read(ctx, conn, req, c.CheckResponse) - if readErr != nil { - return nil, readErr + resp, err := read(ctx, conn, req, c.CheckResponse) + if err != nil { + return nil, err } resp.RTT = resp.rcvdAt.Sub(sentAt) req.sentAt = sentAt - if readErr != nil { - return nil, readErr - } - return resp, nil } From c70928fad22e9967129221cb01259e194bfd37f3 Mon Sep 17 00:00:00 2001 From: xiegaofeng Date: Tue, 22 Oct 2019 10:13:33 +0800 Subject: [PATCH 4/5] change the tests to a separate package --- ping_test.go | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/ping_test.go b/ping_test.go index 04c2d04..09de13e 100644 --- a/ping_test.go +++ b/ping_test.go @@ -1,4 +1,4 @@ -package ping +package ping_test import ( "context" @@ -8,10 +8,12 @@ import ( "sync" "testing" "time" + + "github.com/glinton/ping" ) func TestClient(t *testing.T) { - req1, err := NewRequest("www.baidu.com") + req1, err := ping.NewRequest("www.baidu.com") if err != nil { t.Error(err) return @@ -35,10 +37,10 @@ func TestClient(t *testing.T) { wg.Wait() } -func testClient(t *testing.T, wg *sync.WaitGroup, req *Request) { +func testClient(t *testing.T, wg *sync.WaitGroup, req *ping.Request) { defer wg.Done() - resp, err := Do(context.Background(), req) + resp, err := ping.Do(context.Background(), req) if err != nil { t.Error(err) } else if resp.ID != req.ID || int(resp.Seq) != req.Seq { @@ -66,7 +68,7 @@ func TestE2E(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), deadline) defer cancel() - resps := make(chan *Response, count) + resps := make(chan *ping.Response, count) packetsSent := 0 for count == 0 || packetsSent < count { @@ -82,7 +84,7 @@ func TestE2E(t *testing.T) { wg.Add(1) go func(seq int) { defer wg.Done() - resp, err := DefaultClient.Do(ctx, &Request{ + resp, err := ping.DefaultClient.Do(ctx, &ping.Request{ Dst: net.ParseIP(host), Seq: seq, }) @@ -100,7 +102,7 @@ func TestE2E(t *testing.T) { wg.Wait() close(resps) - rsps := []*Response{} + rsps := []*ping.Response{} for res := range resps { rsps = append(rsps, res) } @@ -111,12 +113,12 @@ func TestE2E(t *testing.T) { pwg.Wait() } -func onRcv(res *Response) { +func onRcv(res *ping.Response) { fmt.Printf("%d bytes from %s: icmp_seq=%d time=%v ttl=%v\n", res.TotalLength, res.Src.String(), res.Seq, res.RTT, res.TTL) } -func onFin(packetsSent int, resps []*Response) { +func onFin(packetsSent int, resps []*ping.Response) { if len(resps) == 0 { fmt.Println("Sent:", packetsSent, "Received: 0") return @@ -151,12 +153,12 @@ func onFin(packetsSent int, resps []*Response) { } func ExampleDo() { - req, err := NewRequest("localhost") + req, err := ping.NewRequest("localhost") if err != nil { panic(err) } - res, err := Do(context.Background(), req) + res, err := ping.Do(context.Background(), req) if err != nil { panic(err) } @@ -166,7 +168,7 @@ func ExampleDo() { } func ExampleIPv4() { - res, err := IPv4(context.Background(), "google.com") + res, err := ping.IPv4(context.Background(), "google.com") if err != nil { panic(err) } @@ -176,7 +178,7 @@ func ExampleIPv4() { } func ExampleNewRequest_withSource() { - req, err := NewRequest("localhost") + req, err := ping.NewRequest("localhost") if err != nil { panic(err) } @@ -185,7 +187,7 @@ func ExampleNewRequest_withSource() { // and want to ping from a specific interface, set the source. req.Src = net.ParseIP("127.0.0.2") - res, err := Do(context.Background(), req) + res, err := ping.Do(context.Background(), req) if err != nil { panic(err) } From 955d59485facb4e150769d9e540e4292f085edbc Mon Sep 17 00:00:00 2001 From: xiegaofeng Date: Tue, 22 Oct 2019 10:18:01 +0800 Subject: [PATCH 5/5] use the new client instead of the default --- ping_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ping_test.go b/ping_test.go index 09de13e..f06c7b4 100644 --- a/ping_test.go +++ b/ping_test.go @@ -49,6 +49,8 @@ func testClient(t *testing.T, wg *sync.WaitGroup, req *ping.Request) { } func TestE2E(t *testing.T) { + c := &ping.Client{} + hostIPs := []string{"8.8.8.8", "8.8.4.4", "1.1.1.1"} count := 3 deadline := time.Second * 5 @@ -84,7 +86,7 @@ func TestE2E(t *testing.T) { wg.Add(1) go func(seq int) { defer wg.Done() - resp, err := ping.DefaultClient.Do(ctx, &ping.Request{ + resp, err := c.Do(ctx, &ping.Request{ Dst: net.ParseIP(host), Seq: seq, })