731e0b0066
You shouldn't have to import both github.com/docker/distribution/context and golang.org/x/net/context just to use the distribution tools and implement the distribution interfaces. By pulling the Context interface from golang.org/x/net/context into the context package within the distribution project, you no longer have to import both packages.
Note: You do not have to change anything anywhere else yet! All current uses of both packages together will still work correctly because the Context interface from either package is identical.
I've also made some other minor changes:
- Added a RemoteIP function. It's like RemoteAddr but discards the port suffix.
- Added `.String()` to the response duration context value so that JSON log formatting shows a human-readable duration and not just a number of nanoseconds.
- Added WithMapContext(...) to the context package. This is a useful function, so I pulled it out of main.go in cmd/registry so that it can be used elsewhere.
Docker-DCO-1.1-Signed-off-by: Josh Hawn <josh.hawn@docker.com> (github: jlhawn)
273 lines
5.9 KiB
Go
273 lines
5.9 KiB
Go
package context
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/http/httputil"
|
|
"net/url"
|
|
"reflect"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestWithRequest(t *testing.T) {
|
|
var req http.Request
|
|
|
|
start := time.Now()
|
|
req.Method = "GET"
|
|
req.Host = "example.com"
|
|
req.RequestURI = "/test-test"
|
|
req.Header = make(http.Header)
|
|
req.Header.Set("Referer", "foo.com/referer")
|
|
req.Header.Set("User-Agent", "test/0.1")
|
|
|
|
ctx := WithRequest(Background(), &req)
|
|
for _, testcase := range []struct {
|
|
key string
|
|
expected interface{}
|
|
}{
|
|
{
|
|
key: "http.request",
|
|
expected: &req,
|
|
},
|
|
{
|
|
key: "http.request.id",
|
|
},
|
|
{
|
|
key: "http.request.method",
|
|
expected: req.Method,
|
|
},
|
|
{
|
|
key: "http.request.host",
|
|
expected: req.Host,
|
|
},
|
|
{
|
|
key: "http.request.uri",
|
|
expected: req.RequestURI,
|
|
},
|
|
{
|
|
key: "http.request.referer",
|
|
expected: req.Referer(),
|
|
},
|
|
{
|
|
key: "http.request.useragent",
|
|
expected: req.UserAgent(),
|
|
},
|
|
{
|
|
key: "http.request.remoteaddr",
|
|
expected: req.RemoteAddr,
|
|
},
|
|
{
|
|
key: "http.request.startedat",
|
|
},
|
|
} {
|
|
v := ctx.Value(testcase.key)
|
|
|
|
if v == nil {
|
|
t.Fatalf("value not found for %q", testcase.key)
|
|
}
|
|
|
|
if testcase.expected != nil && v != testcase.expected {
|
|
t.Fatalf("%s: %v != %v", testcase.key, v, testcase.expected)
|
|
}
|
|
|
|
// Key specific checks!
|
|
switch testcase.key {
|
|
case "http.request.id":
|
|
if _, ok := v.(string); !ok {
|
|
t.Fatalf("request id not a string: %v", v)
|
|
}
|
|
case "http.request.startedat":
|
|
vt, ok := v.(time.Time)
|
|
if !ok {
|
|
t.Fatalf("value not a time: %v", v)
|
|
}
|
|
|
|
now := time.Now()
|
|
if vt.After(now) {
|
|
t.Fatalf("time generated too late: %v > %v", vt, now)
|
|
}
|
|
|
|
if vt.Before(start) {
|
|
t.Fatalf("time generated too early: %v < %v", vt, start)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// testResponseWriter is an in-memory response writer for tests. It keeps no
// body data; it only records what was done to it so tests can inspect the
// result afterwards.
type testResponseWriter struct {
	flushed bool        // set to true by Flush
	status  int         // last recorded status; zero until Write or WriteHeader runs
	written int64       // running total of byte counts passed to Write
	header  http.Header // created on the first call to Header
}

// Header returns the writer's header map, allocating it on first use so the
// zero value of testResponseWriter is usable.
func (trw *testResponseWriter) Header() http.Header {
	if trw.header != nil {
		return trw.header
	}

	trw.header = make(http.Header)
	return trw.header
}

// Write discards p but adds its length to the written counter. If no status
// has been recorded yet, it records http.StatusOK first. It always reports
// len(p) bytes written and a nil error.
func (trw *testResponseWriter) Write(p []byte) (n int, err error) {
	if trw.status == 0 {
		trw.status = http.StatusOK
	}

	trw.written += int64(len(p))
	return len(p), nil
}

// WriteHeader records status, replacing any previously recorded value.
func (trw *testResponseWriter) WriteHeader(status int) {
	trw.status = status
}

// Flush marks the writer as flushed.
func (trw *testResponseWriter) Flush() {
	trw.flushed = true
}
|
|
|
|
func TestWithResponseWriter(t *testing.T) {
|
|
trw := testResponseWriter{}
|
|
ctx, rw := WithResponseWriter(Background(), &trw)
|
|
|
|
if ctx.Value("http.response") != &trw {
|
|
t.Fatalf("response not available in context: %v != %v", ctx.Value("http.response"), &trw)
|
|
}
|
|
|
|
if n, err := rw.Write(make([]byte, 1024)); err != nil {
|
|
t.Fatalf("unexpected error writing: %v", err)
|
|
} else if n != 1024 {
|
|
t.Fatalf("unexpected number of bytes written: %v != %v", n, 1024)
|
|
}
|
|
|
|
if ctx.Value("http.response.status") != http.StatusOK {
|
|
t.Fatalf("unexpected response status in context: %v != %v", ctx.Value("http.response.status"), http.StatusOK)
|
|
}
|
|
|
|
if ctx.Value("http.response.written") != int64(1024) {
|
|
t.Fatalf("unexpected number reported bytes written: %v != %v", ctx.Value("http.response.written"), 1024)
|
|
}
|
|
|
|
// Make sure flush propagates
|
|
rw.(http.Flusher).Flush()
|
|
|
|
if !trw.flushed {
|
|
t.Fatalf("response writer not flushed")
|
|
}
|
|
|
|
// Write another status and make sure context is correct. This normally
|
|
// wouldn't work except for in this contrived testcase.
|
|
rw.WriteHeader(http.StatusBadRequest)
|
|
|
|
if ctx.Value("http.response.status") != http.StatusBadRequest {
|
|
t.Fatalf("unexpected response status in context: %v != %v", ctx.Value("http.response.status"), http.StatusBadRequest)
|
|
}
|
|
}
|
|
|
|
func TestWithVars(t *testing.T) {
|
|
var req http.Request
|
|
vars := map[string]string{
|
|
"foo": "asdf",
|
|
"bar": "qwer",
|
|
}
|
|
|
|
getVarsFromRequest = func(r *http.Request) map[string]string {
|
|
if r != &req {
|
|
t.Fatalf("unexpected request: %v != %v", r, req)
|
|
}
|
|
|
|
return vars
|
|
}
|
|
|
|
ctx := WithVars(Background(), &req)
|
|
for _, testcase := range []struct {
|
|
key string
|
|
expected interface{}
|
|
}{
|
|
{
|
|
key: "vars",
|
|
expected: vars,
|
|
},
|
|
{
|
|
key: "vars.foo",
|
|
expected: "asdf",
|
|
},
|
|
{
|
|
key: "vars.bar",
|
|
expected: "qwer",
|
|
},
|
|
} {
|
|
v := ctx.Value(testcase.key)
|
|
|
|
if !reflect.DeepEqual(v, testcase.expected) {
|
|
t.Fatalf("%q: %v != %v", testcase.key, v, testcase.expected)
|
|
}
|
|
}
|
|
}
|
|
|
|
// SingleHostReverseProxy will insert an X-Forwarded-For header, and can be used to test
|
|
// RemoteAddr(). A fake RemoteAddr cannot be set on the HTTP request - it is overwritten
|
|
// at the transport layer to 127.0.0.1:<port> . However, as the X-Forwarded-For header
|
|
// just contains the IP address, it is different enough for testing.
|
|
func TestRemoteAddr(t *testing.T) {
|
|
var expectedRemote string
|
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
defer r.Body.Close()
|
|
|
|
if r.RemoteAddr == expectedRemote {
|
|
t.Errorf("Unexpected matching remote addresses")
|
|
}
|
|
|
|
actualRemote := RemoteAddr(r)
|
|
if expectedRemote != actualRemote {
|
|
t.Errorf("Mismatching remote hosts: %v != %v", expectedRemote, actualRemote)
|
|
}
|
|
|
|
w.WriteHeader(200)
|
|
}))
|
|
|
|
defer backend.Close()
|
|
backendURL, err := url.Parse(backend.URL)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
proxy := httputil.NewSingleHostReverseProxy(backendURL)
|
|
frontend := httptest.NewServer(proxy)
|
|
defer frontend.Close()
|
|
|
|
// X-Forwarded-For set by proxy
|
|
expectedRemote = "127.0.0.1"
|
|
proxyReq, err := http.NewRequest("GET", frontend.URL, nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
_, err = http.DefaultClient.Do(proxyReq)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// RemoteAddr in X-Real-Ip
|
|
getReq, err := http.NewRequest("GET", backend.URL, nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
expectedRemote = "1.2.3.4"
|
|
getReq.Header["X-Real-ip"] = []string{expectedRemote}
|
|
_, err = http.DefaultClient.Do(getReq)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Valid X-Real-Ip and invalid X-Forwarded-For
|
|
getReq.Header["X-forwarded-for"] = []string{"1.2.3"}
|
|
_, err = http.DefaultClient.Do(getReq)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|