Add unit tests for auth challenge and endpoint

Signed-off-by: Derek McGowan <derek@mcgstyle.net> (github: dmcgowan)
This commit is contained in:
Derek McGowan 2015-05-07 13:16:52 -07:00
parent 174a732c94
commit b1ba2183ee
7 changed files with 315 additions and 14 deletions

View File

@ -127,7 +127,7 @@ func expectTokenOrQuoted(s string) (value string, rest string) {
p := make([]byte, len(s)-1) p := make([]byte, len(s)-1)
j := copy(p, s[:i]) j := copy(p, s[:i])
escape := true escape := true
for i = i + i; i < len(s); i++ { for i = i + 1; i < len(s); i++ {
b := s[i] b := s[i]
switch { switch {
case escape: case escape:

View File

@ -0,0 +1,37 @@
package client
import (
"net/http"
"testing"
)
func TestAuthChallengeParse(t *testing.T) {
header := http.Header{}
header.Add("WWW-Authenticate", `Bearer realm="https://auth.example.com/token",service="registry.example.com",other=fun,slashed="he\"\l\lo"`)
challenges := parseAuthHeader(header)
if len(challenges) != 1 {
t.Fatalf("Unexpected number of auth challenges: %d, expected 1", len(challenges))
}
if expected := "bearer"; challenges[0].Scheme != expected {
t.Fatalf("Unexpected scheme: %s, expected: %s", challenges[0].Scheme, expected)
}
if expected := "https://auth.example.com/token"; challenges[0].Parameters["realm"] != expected {
t.Fatalf("Unexpected param: %s, expected: %s", challenges[0].Parameters["realm"], expected)
}
if expected := "registry.example.com"; challenges[0].Parameters["service"] != expected {
t.Fatalf("Unexpected param: %s, expected: %s", challenges[0].Parameters["service"], expected)
}
if expected := "fun"; challenges[0].Parameters["other"] != expected {
t.Fatalf("Unexpected param: %s, expected: %s", challenges[0].Parameters["other"], expected)
}
if expected := "he\"llo"; challenges[0].Parameters["slashed"] != expected {
t.Fatalf("Unexpected param: %s, expected: %s", challenges[0].Parameters["slashed"], expected)
}
}

View File

@ -117,6 +117,8 @@ func (e *RepositoryEndpoint) URLBuilder() (*v2.URLBuilder, error) {
// HTTPClient returns a new HTTP client configured for this endpoint // HTTPClient returns a new HTTP client configured for this endpoint
func (e *RepositoryEndpoint) HTTPClient(name string) (*http.Client, error) { func (e *RepositoryEndpoint) HTTPClient(name string) (*http.Client, error) {
// TODO(dmcgowan): create http.Transport
transport := &repositoryTransport{ transport := &repositoryTransport{
Header: e.Header, Header: e.Header,
} }

View File

@ -0,0 +1,259 @@
package client
import (
"encoding/base64"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/docker/distribution/testutil"
)
type testAuthenticationWrapper struct {
headers http.Header
authCheck func(string) bool
next http.Handler
}
func (w *testAuthenticationWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization")
if auth == "" || !w.authCheck(auth) {
h := rw.Header()
for k, values := range w.headers {
h[k] = values
}
rw.WriteHeader(http.StatusUnauthorized)
return
}
w.next.ServeHTTP(rw, r)
}
func testServerWithAuth(rrm testutil.RequestResponseMap, authenticate string, authCheck func(string) bool) (*RepositoryEndpoint, func()) {
h := testutil.NewHandler(rrm)
wrapper := &testAuthenticationWrapper{
headers: http.Header(map[string][]string{
"Docker-Distribution-API-Version": {"registry/2.0"},
"WWW-Authenticate": {authenticate},
}),
authCheck: authCheck,
next: h,
}
s := httptest.NewServer(wrapper)
e := RepositoryEndpoint{Endpoint: s.URL, Mirror: false}
return &e, s.Close
}
type testCredentialStore struct {
username string
password string
}
func (tcs *testCredentialStore) Basic(*url.URL) (string, string) {
return tcs.username, tcs.password
}
func TestEndpointAuthorizeToken(t *testing.T) {
service := "localhost.localdomain"
repo1 := "some/registry"
repo2 := "other/registry"
scope1 := fmt.Sprintf("repository:%s:pull,push", repo1)
scope2 := fmt.Sprintf("repository:%s:pull,push", repo2)
tokenMap := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
{
Request: testutil.Request{
Method: "GET",
Route: fmt.Sprintf("/token?scope=%s&service=%s", url.QueryEscape(scope1), service),
},
Response: testutil.Response{
StatusCode: http.StatusOK,
Body: []byte(`{"token":"statictoken"}`),
},
},
{
Request: testutil.Request{
Method: "GET",
Route: fmt.Sprintf("/token?scope=%s&service=%s", url.QueryEscape(scope2), service),
},
Response: testutil.Response{
StatusCode: http.StatusOK,
Body: []byte(`{"token":"badtoken"}`),
},
},
})
te, tc := testServer(tokenMap)
defer tc()
m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
{
Request: testutil.Request{
Method: "GET",
Route: "/hello",
},
Response: testutil.Response{
StatusCode: http.StatusAccepted,
},
},
})
authenicate := fmt.Sprintf("Bearer realm=%q,service=%q", te.Endpoint+"/token", service)
validCheck := func(a string) bool {
return a == "Bearer statictoken"
}
e, c := testServerWithAuth(m, authenicate, validCheck)
defer c()
client, err := e.HTTPClient(repo1)
if err != nil {
t.Fatalf("Error creating http client: %s", err)
}
req, _ := http.NewRequest("GET", e.Endpoint+"/hello", nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Error sending get request: %s", err)
}
if resp.StatusCode != http.StatusAccepted {
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted)
}
badCheck := func(a string) bool {
return a == "Bearer statictoken"
}
e2, c2 := testServerWithAuth(m, authenicate, badCheck)
defer c2()
client2, err := e2.HTTPClient(repo2)
if err != nil {
t.Fatalf("Error creating http client: %s", err)
}
req, _ = http.NewRequest("GET", e.Endpoint+"/hello", nil)
resp, err = client2.Do(req)
if err != nil {
t.Fatalf("Error sending get request: %s", err)
}
if resp.StatusCode != http.StatusUnauthorized {
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusUnauthorized)
}
}
func basicAuth(username, password string) string {
auth := username + ":" + password
return base64.StdEncoding.EncodeToString([]byte(auth))
}
func TestEndpointAuthorizeTokenBasic(t *testing.T) {
service := "localhost.localdomain"
repo := "some/fun/registry"
scope := fmt.Sprintf("repository:%s:pull,push", repo)
username := "tokenuser"
password := "superSecretPa$$word"
tokenMap := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
{
Request: testutil.Request{
Method: "GET",
Route: fmt.Sprintf("/token?account=%s&scope=%s&service=%s", username, url.QueryEscape(scope), service),
},
Response: testutil.Response{
StatusCode: http.StatusOK,
Body: []byte(`{"token":"statictoken"}`),
},
},
})
authenicate1 := fmt.Sprintf("Basic realm=localhost")
basicCheck := func(a string) bool {
return a == fmt.Sprintf("Basic %s", basicAuth(username, password))
}
te, tc := testServerWithAuth(tokenMap, authenicate1, basicCheck)
defer tc()
m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
{
Request: testutil.Request{
Method: "GET",
Route: "/hello",
},
Response: testutil.Response{
StatusCode: http.StatusAccepted,
},
},
})
authenicate2 := fmt.Sprintf("Bearer realm=%q,service=%q", te.Endpoint+"/token", service)
bearerCheck := func(a string) bool {
return a == "Bearer statictoken"
}
e, c := testServerWithAuth(m, authenicate2, bearerCheck)
defer c()
e.Credentials = &testCredentialStore{
username: username,
password: password,
}
client, err := e.HTTPClient(repo)
if err != nil {
t.Fatalf("Error creating http client: %s", err)
}
req, _ := http.NewRequest("GET", e.Endpoint+"/hello", nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Error sending get request: %s", err)
}
if resp.StatusCode != http.StatusAccepted {
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted)
}
}
func TestEndpointAuthorizeBasic(t *testing.T) {
m := testutil.RequestResponseMap([]testutil.RequestResponseMapping{
{
Request: testutil.Request{
Method: "GET",
Route: "/hello",
},
Response: testutil.Response{
StatusCode: http.StatusAccepted,
},
},
})
username := "user1"
password := "funSecretPa$$word"
authenicate := fmt.Sprintf("Basic realm=localhost")
validCheck := func(a string) bool {
return a == fmt.Sprintf("Basic %s", basicAuth(username, password))
}
e, c := testServerWithAuth(m, authenicate, validCheck)
defer c()
e.Credentials = &testCredentialStore{
username: username,
password: password,
}
client, err := e.HTTPClient("test/repo/basic")
if err != nil {
t.Fatalf("Error creating http client: %s", err)
}
req, _ := http.NewRequest("GET", e.Endpoint+"/hello", nil)
resp, err := client.Do(req)
if err != nil {
t.Fatalf("Error sending get request: %s", err)
}
if resp.StatusCode != http.StatusAccepted {
t.Fatalf("Unexpected status code: %d, expected %d", resp.StatusCode, http.StatusAccepted)
}
}

View File

@ -25,8 +25,8 @@ import (
"golang.org/x/net/context" "golang.org/x/net/context"
) )
// NewRepositoryClient creates a new Repository for the given repository name and endpoint // NewRepository creates a new Repository for the given repository name and endpoint
func NewRepositoryClient(ctx context.Context, name string, endpoint *RepositoryEndpoint) (distribution.Repository, error) { func NewRepository(ctx context.Context, name string, endpoint *RepositoryEndpoint) (distribution.Repository, error) {
if err := v2.ValidateRespositoryName(name); err != nil { if err := v2.ValidateRespositoryName(name); err != nil {
return nil, err return nil, err
} }

View File

@ -97,7 +97,7 @@ func TestLayerFetch(t *testing.T) {
e, c := testServer(m) e, c := testServer(m)
defer c() defer c()
r, err := NewRepositoryClient(context.Background(), "test.example.com/repo1", e) r, err := NewRepository(context.Background(), "test.example.com/repo1", e)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -127,7 +127,7 @@ func TestLayerExists(t *testing.T) {
e, c := testServer(m) e, c := testServer(m)
defer c() defer c()
r, err := NewRepositoryClient(context.Background(), "test.example.com/repo1", e) r, err := NewRepository(context.Background(), "test.example.com/repo1", e)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -227,7 +227,7 @@ func TestLayerUploadChunked(t *testing.T) {
e, c := testServer(m) e, c := testServer(m)
defer c() defer c()
r, err := NewRepositoryClient(context.Background(), repo, e) r, err := NewRepository(context.Background(), repo, e)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -334,7 +334,7 @@ func TestLayerUploadMonolithic(t *testing.T) {
e, c := testServer(m) e, c := testServer(m)
defer c() defer c()
r, err := NewRepositoryClient(context.Background(), repo, e) r, err := NewRepository(context.Background(), repo, e)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -475,7 +475,7 @@ func TestManifestFetch(t *testing.T) {
e, c := testServer(m) e, c := testServer(m)
defer c() defer c()
r, err := NewRepositoryClient(context.Background(), repo, e) r, err := NewRepository(context.Background(), repo, e)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -508,7 +508,7 @@ func TestManifestFetchByTag(t *testing.T) {
e, c := testServer(m) e, c := testServer(m)
defer c() defer c()
r, err := NewRepositoryClient(context.Background(), repo, e) r, err := NewRepository(context.Background(), repo, e)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -553,7 +553,7 @@ func TestManifestDelete(t *testing.T) {
e, c := testServer(m) e, c := testServer(m)
defer c() defer c()
r, err := NewRepositoryClient(context.Background(), repo, e) r, err := NewRepository(context.Background(), repo, e)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -591,7 +591,7 @@ func TestManifestPut(t *testing.T) {
e, c := testServer(m) e, c := testServer(m)
defer c() defer c()
r, err := NewRepositoryClient(context.Background(), repo, e) r, err := NewRepository(context.Background(), repo, e)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -6,6 +6,7 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url"
"sort" "sort"
"strings" "strings"
) )
@ -40,16 +41,18 @@ type Request struct {
func (r Request) String() string { func (r Request) String() string {
queryString := "" queryString := ""
if len(r.QueryParams) > 0 { if len(r.QueryParams) > 0 {
queryString = "?"
keys := make([]string, 0, len(r.QueryParams)) keys := make([]string, 0, len(r.QueryParams))
queryParts := make([]string, 0, len(r.QueryParams))
for k := range r.QueryParams { for k := range r.QueryParams {
keys = append(keys, k) keys = append(keys, k)
} }
sort.Strings(keys) sort.Strings(keys)
for _, k := range keys { for _, k := range keys {
queryString += strings.Join(r.QueryParams[k], "&") + "&" for _, val := range r.QueryParams[k] {
queryParts = append(queryParts, fmt.Sprintf("%s=%s", k, url.QueryEscape(val)))
}
} }
queryString = queryString[:len(queryString)-1] queryString = "?" + strings.Join(queryParts, "&")
} }
return fmt.Sprintf("%s %s%s\n%s", r.Method, r.Route, queryString, r.Body) return fmt.Sprintf("%s %s%s\n%s", r.Method, r.Route, queryString, r.Body)
} }