diff --git a/cmd/ping/main.go b/cmd/ping/main.go index ad0a891..1ac429c 100644 --- a/cmd/ping/main.go +++ b/cmd/ping/main.go @@ -93,6 +93,7 @@ func main() { } req := ping.Request{ + ID: os.Getpid(), Dst: net.ParseIP(host.String()), Src: net.ParseIP(getAddr(*iface)), Data: data, diff --git a/ping.go b/ping.go index 5da9bd5..bec7fa5 100644 --- a/ping.go +++ b/ping.go @@ -50,10 +50,19 @@ type Response struct { } // Client is a ping client. -type Client struct{} +type Client struct { + // CheckResponse checks whether the response is valid. + CheckResponse func(*Response) bool +} // DefaultClient is the default client used by Do. -var DefaultClient = &Client{} +var DefaultClient = &Client{CheckResponse: CheckResponse} + +// CheckResponse checks whether the id and the seq are equal respectively +// between request and response. +func CheckResponse(resp *Response) bool { + return resp.ID == resp.Req.ID && int(resp.Seq) == resp.Req.Seq +} // NewRequest resolves dst as an IPv4 address and returns a pointer to a request // using that as the destination. @@ -108,28 +117,18 @@ func (c *Client) Do(ctx context.Context, req *Request) (*Response, error) { } } - var ( - resp *Response - readErr error - ) - sentAt, err := send(ctx, conn, req) if err != nil { return nil, err } - resp, readErr = read(ctx, conn) - if readErr != nil { - return nil, readErr + resp, err := read(ctx, conn, req, c.CheckResponse) + if err != nil { + return nil, err } resp.RTT = resp.rcvdAt.Sub(sentAt) req.sentAt = sentAt - resp.Req = req - - if readErr != nil { - return nil, readErr - } return resp, nil } @@ -199,18 +198,20 @@ func (req *Request) proto() int { return protocolIPv4ICMP } -func read(ctx context.Context, conn *icmp.PacketConn) (*Response, error) { +func read(ctx context.Context, conn *icmp.PacketConn, req *Request, + check func(*Response) bool) (*Response, error) { if c4 := conn.IPv4PacketConn(); c4 != nil { - return read4(ctx, c4) + return read4(ctx, c4, req, check) } c6 := conn.IPv6PacketConn() if c6 == nil { return nil, errors.New("bad icmp connection type") } - return read6(ctx, c6) + return read6(ctx, c6, req, check) } -func read4(ctx context.Context, conn *ipv4.PacketConn) (*Response, error) { +func read4(ctx context.Context, conn *ipv4.PacketConn, req *Request, + check func(*Response) bool) (*Response, error) { for { select { case <-ctx.Done(): @@ -259,9 +260,18 @@ func read4(ctx context.Context, conn *ipv4.PacketConn) (*Response, error) { ttl = cm.TTL } - srcHost, _, _ := net.SplitHostPort(src.String()) - dstHost, _, _ := net.SplitHostPort(conn.LocalAddr().String()) - return &Response{ + srcHost, _, err := net.SplitHostPort(src.String()) + if err != nil { + srcHost = src.String() + } + + dstHost, _, err := net.SplitHostPort(conn.LocalAddr().String()) + if err != nil { + dstHost = conn.LocalAddr().String() + } + + resp := &Response{ + Req: req, ID: id, Seq: seq, Data: bytesReceived[:n], @@ -270,12 +280,16 @@ func read4(ctx context.Context, conn *ipv4.PacketConn) (*Response, error) { Dst: net.ParseIP(dstHost), TTL: ttl, rcvdAt: rcv, - }, nil + } + if check == nil || check(resp) { + return resp, nil + } } } } -func read6(ctx context.Context, conn *ipv6.PacketConn) (*Response, error) { +func read6(ctx context.Context, conn *ipv6.PacketConn, req *Request, + check func(*Response) bool) (*Response, error) { for { select { case <-ctx.Done(): @@ -324,9 +338,18 @@ func read6(ctx context.Context, conn *ipv6.PacketConn) (*Response, error) { ttl = cm.HopLimit } - srcHost, _, _ := net.SplitHostPort(src.String()) - dstHost, _, _ := net.SplitHostPort(conn.LocalAddr().String()) - return &Response{ + srcHost, _, err := net.SplitHostPort(src.String()) + if err != nil { + srcHost = src.String() + } + + dstHost, _, err := net.SplitHostPort(conn.LocalAddr().String()) + if err != nil { + dstHost = conn.LocalAddr().String() + } + + resp := &Response{ + Req: req, ID: id, Seq: seq, Data: bytesReceived[:n], @@ -335,7 +358,10 @@ func read6(ctx context.Context, conn *ipv6.PacketConn) (*Response, error) { Dst: net.ParseIP(dstHost), TTL: ttl, rcvdAt: rcv, - }, nil + } + if check == nil || check(resp) { + return resp, nil + } } } } diff --git a/ping_test.go b/ping_test.go index f589eaf..f06c7b4 100644 --- a/ping_test.go +++ b/ping_test.go @@ -12,6 +12,42 @@ import ( "github.com/glinton/ping" ) +func TestClient(t *testing.T) { + req1, err := ping.NewRequest("www.baidu.com") + if err != nil { + t.Error(err) + return + } + req1.ID = 101 + req1.Seq = 201 + + req2 := *req1 + req2.ID = 102 + req2.Seq = 202 + + req3 := *req1 + req3.ID = 103 + req3.Seq = 203 + + wg := new(sync.WaitGroup) + wg.Add(3) + go testClient(t, wg, req1) + go testClient(t, wg, &req2) + go testClient(t, wg, &req3) + wg.Wait() +} + +func testClient(t *testing.T, wg *sync.WaitGroup, req *ping.Request) { + defer wg.Done() + + resp, err := ping.Do(context.Background(), req) + if err != nil { + t.Error(err) + } else if resp.ID != req.ID || int(resp.Seq) != req.Seq { + t.Error(req, resp) + } +} + func TestE2E(t *testing.T) { c := &ping.Client{}