Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmd/ping/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ func main() {
}

req := ping.Request{
ID: os.Getpid(),
Dst: net.ParseIP(host.String()),
Src: net.ParseIP(getAddr(*iface)),
Data: data,
Expand Down
82 changes: 54 additions & 28 deletions ping.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,19 @@ type Response struct {
}

// Client is a ping client.
type Client struct{}
type Client struct {
// CheckResponse checks whether the response is valid.
CheckResponse func(*Response) bool
}

// DefaultClient is the default client used by Do.
var DefaultClient = &Client{}
var DefaultClient = &Client{CheckResponse: CheckResponse}

// CheckResponse checks whether the id and the seq are equal respectively
// between request and response.
func CheckResponse(resp *Response) bool {
return resp.ID == resp.Req.ID && int(resp.Seq) == resp.Req.Seq
}

// NewRequest resolves dst as an IPv4 address and returns a pointer to a request
// using that as the destination.
Expand Down Expand Up @@ -108,28 +117,18 @@ func (c *Client) Do(ctx context.Context, req *Request) (*Response, error) {
}
}

var (
resp *Response
readErr error
)

sentAt, err := send(ctx, conn, req)
if err != nil {
return nil, err
}

resp, readErr = read(ctx, conn)
if readErr != nil {
return nil, readErr
resp, err := read(ctx, conn, req, c.CheckResponse)
if err != nil {
return nil, err
}

resp.RTT = resp.rcvdAt.Sub(sentAt)
req.sentAt = sentAt
resp.Req = req

if readErr != nil {
return nil, readErr
}

return resp, nil
}
Expand Down Expand Up @@ -199,18 +198,20 @@ func (req *Request) proto() int {
return protocolIPv4ICMP
}

func read(ctx context.Context, conn *icmp.PacketConn) (*Response, error) {
func read(ctx context.Context, conn *icmp.PacketConn, req *Request,
check func(*Response) bool) (*Response, error) {
if c4 := conn.IPv4PacketConn(); c4 != nil {
return read4(ctx, c4)
return read4(ctx, c4, req, check)
}
c6 := conn.IPv6PacketConn()
if c6 == nil {
return nil, errors.New("bad icmp connection type")
}
return read6(ctx, c6)
return read6(ctx, c6, req, check)
}

func read4(ctx context.Context, conn *ipv4.PacketConn) (*Response, error) {
func read4(ctx context.Context, conn *ipv4.PacketConn, req *Request,
check func(*Response) bool) (*Response, error) {
for {
select {
case <-ctx.Done():
Expand Down Expand Up @@ -259,9 +260,18 @@ func read4(ctx context.Context, conn *ipv4.PacketConn) (*Response, error) {
ttl = cm.TTL
}

srcHost, _, _ := net.SplitHostPort(src.String())
dstHost, _, _ := net.SplitHostPort(conn.LocalAddr().String())
return &Response{
srcHost, _, err := net.SplitHostPort(src.String())
if err != nil {
srcHost = src.String()
}

dstHost, _, err := net.SplitHostPort(conn.LocalAddr().String())
if err != nil {
dstHost = conn.LocalAddr().String()
}

resp := &Response{
Req: req,
ID: id,
Seq: seq,
Data: bytesReceived[:n],
Expand All @@ -270,12 +280,16 @@ func read4(ctx context.Context, conn *ipv4.PacketConn) (*Response, error) {
Dst: net.ParseIP(dstHost),
TTL: ttl,
rcvdAt: rcv,
}, nil
}
if check == nil || check(resp) {
return resp, nil
}
}
}
}

func read6(ctx context.Context, conn *ipv6.PacketConn) (*Response, error) {
func read6(ctx context.Context, conn *ipv6.PacketConn, req *Request,
check func(*Response) bool) (*Response, error) {
for {
select {
case <-ctx.Done():
Expand Down Expand Up @@ -324,9 +338,18 @@ func read6(ctx context.Context, conn *ipv6.PacketConn) (*Response, error) {
ttl = cm.HopLimit
}

srcHost, _, _ := net.SplitHostPort(src.String())
dstHost, _, _ := net.SplitHostPort(conn.LocalAddr().String())
return &Response{
srcHost, _, err := net.SplitHostPort(src.String())
if err != nil {
srcHost = src.String()
}

dstHost, _, err := net.SplitHostPort(conn.LocalAddr().String())
if err != nil {
dstHost = conn.LocalAddr().String()
}

resp := &Response{
Req: req,
ID: id,
Seq: seq,
Data: bytesReceived[:n],
Expand All @@ -335,7 +358,10 @@ func read6(ctx context.Context, conn *ipv6.PacketConn) (*Response, error) {
Dst: net.ParseIP(dstHost),
TTL: ttl,
rcvdAt: rcv,
}, nil
}
if check == nil || check(resp) {
return resp, nil
}
}
}
}
Expand Down
36 changes: 36 additions & 0 deletions ping_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,42 @@ import (
"github.com/glinton/ping"
)

func TestClient(t *testing.T) {
req1, err := ping.NewRequest("www.baidu.com")
if err != nil {
t.Error(err)
return
}
req1.ID = 101
req1.Seq = 201

req2 := *req1
req2.ID = 102
req2.Seq = 202

req3 := *req1
req3.ID = 103
req3.Seq = 203

wg := new(sync.WaitGroup)
wg.Add(3)
go testClient(t, wg, req1)
go testClient(t, wg, &req2)
go testClient(t, wg, &req3)
wg.Wait()
}

func testClient(t *testing.T, wg *sync.WaitGroup, req *ping.Request) {
defer wg.Done()

resp, err := ping.Do(context.Background(), req)
if err != nil {
t.Error(err)
} else if resp.ID != req.ID || int(resp.Seq) != req.Seq {
t.Error(req, resp)
}
}

func TestE2E(t *testing.T) {
c := &ping.Client{}

Expand Down