From e9e91faa11f136e91fdeb3b369eec388ecd7de1b Mon Sep 17 00:00:00 2001 From: 0x123456789 <0x123456789> Date: Sat, 30 Mar 2024 21:44:12 +0300 Subject: [PATCH] add http2 version support --- client/client.go | 5 ++++- client/reader.go | 12 +++++++++++- client/reader_test.go | 42 +++++++++++++++++++++++++++++++++++++----- 3 files changed, 52 insertions(+), 7 deletions(-) diff --git a/client/client.go b/client/client.go index 4cd5c0d..b822e91 100644 --- a/client/client.go +++ b/client/client.go @@ -18,7 +18,10 @@ type Version struct { } func (v *Version) String() string { - return fmt.Sprintf("HTTP/%d.%d", v.Major, v.Minor) + if v.Major < 2 { + return fmt.Sprintf("HTTP/%d.%d", v.Major, v.Minor) + } + return fmt.Sprintf("HTTP/%d", v.Major) } var ( diff --git a/client/reader.go b/client/reader.go index b29fd85..c68cadb 100644 --- a/client/reader.go +++ b/client/reader.go @@ -43,13 +43,23 @@ func (r *reader) ReadVersion() (Version, error) { major = int(int(c) - 0x30) } case 6: - if c != '.' { + // For HTTP/2 and HTTP/3 there is no any '.', just do nothing + if c != '.' && major == 1 { return readVersionErr(pos, '.', c) } + if c != ' ' && (major == 2 || major == 3) { + return readVersionErr(pos, ' ', c) + } + if c == ' ' && (major == 2 || major == 3) { + return Version{Major: major, Minor: minor}, nil + } case 7: switch c { case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': minor = int(int(c) - 0x30) + case ' ': + // HTTP/2 case + minor = 0 } case 8: if c != ' ' { diff --git a/client/reader_test.go b/client/reader_test.go index d0d7bed..77966f1 100644 --- a/client/reader_test.go +++ b/client/reader_test.go @@ -15,6 +15,14 @@ type statusTest struct { err bool } +type versionTest struct { + name string + version string + major int + minor int + err bool +} + func TestStatusCode(t *testing.T) { tests := []statusTest{ {"redirect 301", "301\r\n", 301, false}, @@ -25,11 +33,35 @@ func TestStatusCode(t *testing.T) { {"invalid string", "aaa ", 0, true}, {"number with status text", "1234 unknown", 1234, false}, } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + r := reader{bufio.NewReader(strings.NewReader(test.statusLine))} + result, err := r.ReadStatusCode() + hasError := err != nil + require.Equal(t, test.err, hasError, err) + require.Equal(t, test.result, result) + }) + } +} + +func TestReadVersion(t *testing.T) { + tests := []versionTest{ + {"HTTP/0.9", "HTTP/0.9 OK", 0, 9, false}, + {"HTTP/1.0", "HTTP/1.0 OK", 1, 0, false}, + {"HTTP/1.1", "HTTP/1.1 OK", 1, 1, false}, + {"HTTP/2", "HTTP/2 OK", 2, 0, false}, + } + for _, test := range tests { - r := reader{bufio.NewReader(strings.NewReader(test.statusLine))} - result, err := r.ReadStatusCode() - hasError := err != nil - require.Equal(t, test.err, hasError, err) - require.Equal(t, test.result, result) + t.Run(test.name, func(t *testing.T) { + r := reader{bufio.NewReader(strings.NewReader(test.version))} + result, err := r.ReadVersion() + hasError := err != nil + require.Equal(t, test.err, hasError, err) + require.Equal(t, strings.TrimSuffix(test.version, " OK"), result.String()) + require.Equal(t, test.major, result.Major) + require.Equal(t, test.minor, result.Minor) + }) } }