diff --git a/sign.go b/sign.go index 2ddff2c..bc303ff 100644 --- a/sign.go +++ b/sign.go @@ -10,8 +10,11 @@ import ( "encoding/base64" "io" "net/http" + "net/url" "sort" + "strconv" "strings" + "time" ) var signParams = map[string]bool{ @@ -71,6 +74,10 @@ func Sign(r *http.Request, k Keys) { DefaultService.Sign(r, k) } +func SignQuery(r *http.Request, t time.Time, k Keys) { + DefaultService.SignQuery(r, t, k) +} + // Service represents an S3-compatible service. type Service struct { Domain string // service root domain, used to extract subdomain from an http.Request and pass it to Bucket @@ -89,6 +96,21 @@ func (s *Service) Sign(r *http.Request, k Keys) { r.Header.Set("Authorization", "AWS "+k.AccessKey+":"+string(sig)) } +func (s *Service) SignQuery(r *http.Request, t time.Time, k Keys) { + if k.SecurityToken != "" { + r.Header.Set("X-Amz-Security-Token", k.SecurityToken) + } + h := hmac.New(sha1.New, []byte(k.SecretKey)) + s.writeSigQueryData(h, r, t) + sig := make([]byte, base64.StdEncoding.EncodedLen(h.Size())) + base64.StdEncoding.Encode(sig, h.Sum(nil)) + if r.URL.RawQuery == "" { + r.URL.RawQuery = "AWSAccessKeyId=" + k.AccessKey + "&Signature=" + url.QueryEscape(string(sig)) + "&Expires=" + strconv.FormatInt(t.Unix(), 10) + } else { + r.URL.RawQuery += "&AWSAccessKeyId=" + k.AccessKey + "&Signature=" + url.QueryEscape(string(sig)) + "&Expires=" + strconv.FormatInt(t.Unix(), 10) + } +} + func (s *Service) writeSigData(w io.Writer, r *http.Request) { w.Write([]byte(r.Method)) w.Write([]byte{'\n'}) @@ -104,6 +126,19 @@ func (s *Service) writeSigData(w io.Writer, r *http.Request) { s.writeResource(w, r) } +func (s *Service) writeSigQueryData(w io.Writer, r *http.Request, t time.Time) { + w.Write([]byte(r.Method)) + w.Write([]byte{'\n'}) + w.Write([]byte(r.Header.Get("content-md5"))) + w.Write([]byte{'\n'}) + w.Write([]byte(r.Header.Get("content-type"))) + w.Write([]byte{'\n'}) + w.Write([]byte(strconv.FormatInt(t.Unix(), 10))) + w.Write([]byte{'\n'}) + writeAmzHeaders(w, r) + s.writeResource(w, r) +} + func (s *Service) writeResource(w io.Writer, r *http.Request) { s.writeVhostBucket(w, strings.ToLower(r.Host)) path := r.URL.RequestURI() diff --git a/sign_test.go b/sign_test.go index 9187c81..f8f1b17 100644 --- a/sign_test.go +++ b/sign_test.go @@ -4,6 +4,7 @@ import ( "bytes" "net/http" "testing" + "time" ) var exKeys = Keys{ @@ -297,3 +298,70 @@ func TestVhostBucket(t *testing.T) { } } } + +var signQueryTest = []struct { + service *Service + method string + url string + more http.Header + expires int64 + expBuf string + expQuery string +}{ + { + DefaultService, + "GET", + "http://johnsmith.s3.amazonaws.com/photos/puppy.jpg", + nil, + 1175139620, + "GET\n\n\n1175139620\n/johnsmith/photos/puppy.jpg", + "AWSAccessKeyId=AKIAIOSFODNN7EXAMPLE&Signature=NpgCjnDzrM%2BWFzoENXmpNDUsSn8%3D&Expires=1175139620", + }, { + DefaultService, + "GET", + "http://s3.amazonaws.com/johnsmith/photos/puppy.jpg", + nil, + 1175139620, + "GET\n\n\n1175139620\n/johnsmith/photos/puppy.jpg", + "AWSAccessKeyId=AKIAIOSFODNN7EXAMPLE&Signature=NpgCjnDzrM%2BWFzoENXmpNDUsSn8%3D&Expires=1175139620", + }, { + DefaultService, + "GET", + "http://johnsmith.s3.amazonaws.com/photos/puppy.jpg?acl", + nil, + 1175139620, + "GET\n\n\n1175139620\n/johnsmith/photos/puppy.jpg?acl", + "acl&AWSAccessKeyId=AKIAIOSFODNN7EXAMPLE&Signature=2t9CVTYWEqyKpbsimoCqHNLfxsA%3D&Expires=1175139620", + }, +} + +func TestQuerySign(t *testing.T) { + for _, ts := range signQueryTest { + r, err := http.NewRequest(ts.method, ts.url, nil) + if err != nil { + panic(err) + } + + for k, vs := range ts.more { + for _, v := range vs { + r.Header.Add(k, v) + } + } + var buf bytes.Buffer + ts.service.writeSigQueryData(&buf, r, time.Unix(ts.expires, 0)) + if buf.String() != ts.expBuf { + t.Errorf("in %s:", r.Method) + t.Logf("url %s", r.URL.String()) + t.Logf("exp %q", ts.expBuf) + t.Logf("got %q", buf.String()) + } + + ts.service.SignQuery(r, time.Unix(ts.expires, 0), exKeys) + if got := r.URL.RawQuery; got != ts.expQuery { + t.Errorf("in %s:", r.Method) + t.Logf("url %s", r.URL.String()) + t.Logf("exp %q", ts.expQuery) + t.Logf("got %q", got) + } + } +}