add s3 region filters for cloudfront
Signed-off-by: tifayuki <tifayuki@gmail.com>
This commit is contained in:
@@ -16,7 +16,7 @@ import (
|
||||
"github.com/aws/aws-sdk-go/service/cloudfront/sign"
|
||||
dcontext "github.com/docker/distribution/context"
|
||||
storagedriver "github.com/docker/distribution/registry/storage/driver"
|
||||
storagemiddleware "github.com/docker/distribution/registry/storage/driver/middleware"
|
||||
"github.com/docker/distribution/registry/storage/driver/middleware"
|
||||
)
|
||||
|
||||
// cloudFrontStorageMiddleware provides a simple implementation of layerHandler that
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
// then issues HTTP Temporary Redirects to this CloudFront content URL.
|
||||
type cloudFrontStorageMiddleware struct {
|
||||
storagedriver.StorageDriver
|
||||
awsIPs *awsIPs
|
||||
urlSigner *sign.URLSigner
|
||||
baseURL string
|
||||
duration time.Duration
|
||||
@@ -34,7 +35,13 @@ var _ storagedriver.StorageDriver = &cloudFrontStorageMiddleware{}
|
||||
// newCloudFrontLayerHandler constructs and returns a new CloudFront
|
||||
// LayerHandler implementation.
|
||||
// Required options: baseurl, privatekey, keypairid
|
||||
|
||||
// Optional options: ipFilteredBy, awsregion
|
||||
// ipfilteredby: valid value "none|aws|awsregion". "none", do not filter any IP, default value. "aws", only aws IP goes
|
||||
// to S3 directly. "awsregion", only regions listed in awsregion options goes to S3 directly
|
||||
// awsregion: a comma separated string of AWS regions.
|
||||
func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) {
|
||||
// parse baseurl
|
||||
base, ok := options["baseurl"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no baseurl provided")
|
||||
@@ -52,6 +59,8 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o
|
||||
if _, err := url.Parse(baseURL); err != nil {
|
||||
return nil, fmt.Errorf("invalid baseurl: %v", err)
|
||||
}
|
||||
|
||||
// parse privatekey to get pkPath
|
||||
pk, ok := options["privatekey"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no privatekey provided")
|
||||
@@ -60,6 +69,8 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("privatekey must be a string")
|
||||
}
|
||||
|
||||
// parse keypairid
|
||||
kpid, ok := options["keypairid"]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no keypairid provided")
|
||||
@@ -69,6 +80,7 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o
|
||||
return nil, fmt.Errorf("keypairid must be a string")
|
||||
}
|
||||
|
||||
// get urlSigner from the file specified in pkPath
|
||||
pkBytes, err := ioutil.ReadFile(pkPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read privatekey file: %s", err)
|
||||
@@ -82,12 +94,11 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
urlSigner := sign.NewURLSigner(keypairID, privateKey)
|
||||
|
||||
// parse duration
|
||||
duration := 20 * time.Minute
|
||||
d, ok := options["duration"]
|
||||
if ok {
|
||||
if d, ok := options["duration"]; ok {
|
||||
switch d := d.(type) {
|
||||
case time.Duration:
|
||||
duration = d
|
||||
@@ -100,11 +111,62 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o
|
||||
}
|
||||
}
|
||||
|
||||
// parse updatefrenquency
|
||||
updateFrequency := defaultUpdateFrequency
|
||||
if u, ok := options["updatefrenquency"]; ok {
|
||||
switch u := u.(type) {
|
||||
case time.Duration:
|
||||
updateFrequency = u
|
||||
case string:
|
||||
updateFreq, err := time.ParseDuration(u)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid updatefrenquency: %s", err)
|
||||
}
|
||||
duration = updateFreq
|
||||
}
|
||||
}
|
||||
|
||||
// parse iprangesurl
|
||||
ipRangesURL := defaultIPRangesURL
|
||||
if i, ok := options["iprangesurl"]; ok {
|
||||
if iprangeurl, ok := i.(string); ok {
|
||||
ipRangesURL = iprangeurl
|
||||
} else {
|
||||
return nil, fmt.Errorf("iprangesurl must be a string")
|
||||
}
|
||||
}
|
||||
|
||||
// parse ipfilteredby
|
||||
var awsIPs *awsIPs
|
||||
if ipFilteredBy := options["ipfilteredby"].(string); ok {
|
||||
switch strings.ToLower(strings.TrimSpace(ipFilteredBy)) {
|
||||
case "", "none":
|
||||
awsIPs = nil
|
||||
case "aws":
|
||||
newAWSIPs(ipRangesURL, updateFrequency, nil)
|
||||
case "awsregion":
|
||||
var awsRegion []string
|
||||
if regions, ok := options["awsregion"].(string); ok {
|
||||
for _, awsRegions := range strings.Split(regions, ",") {
|
||||
awsRegion = append(awsRegion, strings.ToLower(strings.TrimSpace(awsRegions)))
|
||||
}
|
||||
awsIPs = newAWSIPs(ipRangesURL, updateFrequency, awsRegion)
|
||||
} else {
|
||||
return nil, fmt.Errorf("awsRegion must be a comma separated string of valid aws regions")
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("ipfilteredby only allows a string the following value: none|aws|awsregion")
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("ipfilteredby only allows a string with the following value: none|aws|awsregion")
|
||||
}
|
||||
|
||||
return &cloudFrontStorageMiddleware{
|
||||
StorageDriver: storageDriver,
|
||||
urlSigner: urlSigner,
|
||||
baseURL: baseURL,
|
||||
duration: duration,
|
||||
awsIPs: awsIPs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -114,8 +176,8 @@ type S3BucketKeyer interface {
|
||||
S3BucketKey(path string) string
|
||||
}
|
||||
|
||||
// Resolve returns an http.Handler which can serve the contents of the given
|
||||
// Layer, or an error if not supported by the storagedriver.
|
||||
// URLFor attempts to find a url which may be used to retrieve the file at the given path.
|
||||
// Returns an error if the file cannot be found.
|
||||
func (lh *cloudFrontStorageMiddleware) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
|
||||
// TODO(endophage): currently only supports S3
|
||||
keyer, ok := lh.StorageDriver.(S3BucketKeyer)
|
||||
@@ -124,6 +186,11 @@ func (lh *cloudFrontStorageMiddleware) URLFor(ctx context.Context, path string,
|
||||
return lh.StorageDriver.URLFor(ctx, path, options)
|
||||
}
|
||||
|
||||
if eligibleForS3(ctx, lh.awsIPs) {
|
||||
return lh.StorageDriver.URLFor(ctx, path, options)
|
||||
}
|
||||
|
||||
// Get signed cloudfront url.
|
||||
cfURL, err := lh.urlSigner.Sign(lh.baseURL+keyer.S3BucketKey(path), time.Now().Add(lh.duration))
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
223
registry/storage/driver/middleware/cloudfront/s3filter.go
Normal file
223
registry/storage/driver/middleware/cloudfront/s3filter.go
Normal file
@@ -0,0 +1,223 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
dcontext "github.com/docker/distribution/context"
|
||||
)
|
||||
|
||||
const (
|
||||
// ipRangesURL is the URL to get definition of AWS IPs
|
||||
defaultIPRangesURL = "https://ip-ranges.amazonaws.com/ip-ranges.json"
|
||||
// updateFrequency tells how frequently AWS IPs need to be updated
|
||||
defaultUpdateFrequency = time.Hour * 12
|
||||
)
|
||||
|
||||
// newAWSIPs returns a New awsIP object.
|
||||
// If awsRegion is `nil`, it accepts any region. Otherwise, it only allow the regions specified
|
||||
func newAWSIPs(host string, updateFrequency time.Duration, awsRegion []string) *awsIPs {
|
||||
ips := &awsIPs{
|
||||
host: host,
|
||||
updateFrequency: updateFrequency,
|
||||
awsRegion: awsRegion,
|
||||
updaterStopChan: make(chan bool),
|
||||
}
|
||||
if err := ips.tryUpdate(); err != nil {
|
||||
dcontext.GetLogger(context.Background()).WithError(err).Warn("failed to update AWS IP")
|
||||
}
|
||||
go ips.updater()
|
||||
return ips
|
||||
}
|
||||
|
||||
// awsIPs tracks a list of AWS ips, filtered by awsRegion
|
||||
type awsIPs struct {
|
||||
host string
|
||||
updateFrequency time.Duration
|
||||
ipv4 []net.IPNet
|
||||
ipv6 []net.IPNet
|
||||
mutex sync.RWMutex
|
||||
awsRegion []string
|
||||
updaterStopChan chan bool
|
||||
initialized bool
|
||||
}
|
||||
|
||||
type awsIPResponse struct {
|
||||
Prefixes []prefixEntry `json:"prefixes"`
|
||||
V6Prefixes []prefixEntry `json:"ipv6_prefixes"`
|
||||
}
|
||||
|
||||
type prefixEntry struct {
|
||||
IPV4Prefix string `json:"ip_prefix"`
|
||||
IPV6Prefix string `json:"ipv6_prefix"`
|
||||
Region string `json:"region"`
|
||||
Service string `json:"service"`
|
||||
}
|
||||
|
||||
func fetchAWSIPs(url string) (awsIPResponse, error) {
|
||||
var response awsIPResponse
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return response, err
|
||||
}
|
||||
if resp.StatusCode != 200 {
|
||||
body, _ := ioutil.ReadAll(resp.Body)
|
||||
return response, fmt.Errorf("failed to fetch network data. response = %s", body)
|
||||
}
|
||||
decoder := json.NewDecoder(resp.Body)
|
||||
err = decoder.Decode(&response)
|
||||
if err != nil {
|
||||
return response, err
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// tryUpdate attempts to download the new set of ip addresses.
|
||||
// tryUpdate must be thread safe with contains
|
||||
func (s *awsIPs) tryUpdate() error {
|
||||
response, err := fetchAWSIPs(s.host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var ipv4 []net.IPNet
|
||||
var ipv6 []net.IPNet
|
||||
|
||||
processAddress := func(output *[]net.IPNet, prefix string, region string) {
|
||||
regionAllowed := false
|
||||
if len(s.awsRegion) > 0 {
|
||||
for _, ar := range s.awsRegion {
|
||||
if strings.ToLower(region) == ar {
|
||||
regionAllowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
regionAllowed = true
|
||||
}
|
||||
|
||||
_, network, err := net.ParseCIDR(prefix)
|
||||
if err != nil {
|
||||
dcontext.GetLoggerWithFields(dcontext.Background(), map[interface{}]interface{}{
|
||||
"cidr": prefix,
|
||||
}).Error("unparseable cidr")
|
||||
return
|
||||
}
|
||||
if regionAllowed {
|
||||
*output = append(*output, *network)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
for _, prefix := range response.Prefixes {
|
||||
processAddress(&ipv4, prefix.IPV4Prefix, prefix.Region)
|
||||
}
|
||||
for _, prefix := range response.V6Prefixes {
|
||||
processAddress(&ipv6, prefix.IPV6Prefix, prefix.Region)
|
||||
}
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
// Update each attr of awsips atomically.
|
||||
s.ipv4 = ipv4
|
||||
s.ipv6 = ipv6
|
||||
s.initialized = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// This function is meant to be run in a background goroutine.
|
||||
// It will periodically update the ips from aws.
|
||||
func (s *awsIPs) updater() {
|
||||
defer close(s.updaterStopChan)
|
||||
for {
|
||||
time.Sleep(s.updateFrequency)
|
||||
select {
|
||||
case <-s.updaterStopChan:
|
||||
dcontext.GetLogger(context.Background()).Info("aws ip updater received stop signal")
|
||||
return
|
||||
default:
|
||||
err := s.tryUpdate()
|
||||
if err != nil {
|
||||
dcontext.GetLogger(context.Background()).WithError(err).Error("git AWS IP")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getCandidateNetworks returns either the ipv4 or ipv6 networks
|
||||
// that were last read from aws. The networks returned
|
||||
// have the same type as the ip address provided.
|
||||
func (s *awsIPs) getCandidateNetworks(ip net.IP) []net.IPNet {
|
||||
s.mutex.RLock()
|
||||
defer s.mutex.RUnlock()
|
||||
if ip.To4() != nil {
|
||||
return s.ipv4
|
||||
} else if ip.To16() != nil {
|
||||
return s.ipv6
|
||||
} else {
|
||||
dcontext.GetLoggerWithFields(dcontext.Background(), map[interface{}]interface{}{
|
||||
"ip": ip,
|
||||
}).Error("unknown ip address format")
|
||||
// assume mismatch, pass through cloudfront
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Contains determines whether the host is within aws.
|
||||
func (s *awsIPs) contains(ip net.IP) bool {
|
||||
networks := s.getCandidateNetworks(ip)
|
||||
for _, network := range networks {
|
||||
if network.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// parseIPFromRequest attempts to extract the ip address of the
|
||||
// client that made the request
|
||||
func parseIPFromRequest(ctx context.Context) (net.IP, error) {
|
||||
request, err := dcontext.GetRequest(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ipStr := dcontext.RemoteIP(request)
|
||||
ip := net.ParseIP(ipStr)
|
||||
if ip == nil {
|
||||
return nil, fmt.Errorf("invalid ip address from requester: %s", ipStr)
|
||||
}
|
||||
|
||||
return ip, nil
|
||||
}
|
||||
|
||||
// eligibleForS3 checks if a request is eligible for using S3 directly
|
||||
// Return true only when the IP belongs to a specific aws region and user-agent is docker
|
||||
func eligibleForS3(ctx context.Context, awsIPs *awsIPs) bool {
|
||||
if awsIPs != nil && awsIPs.initialized {
|
||||
if addr, err := parseIPFromRequest(ctx); err == nil {
|
||||
request, err := dcontext.GetRequest(ctx)
|
||||
if err != nil {
|
||||
dcontext.GetLogger(ctx).Warnf("the CloudFront middleware cannot parse the request: %s", err)
|
||||
} else {
|
||||
loggerField := map[interface{}]interface{}{
|
||||
"user-client": request.UserAgent(),
|
||||
"ip": dcontext.RemoteIP(request),
|
||||
}
|
||||
if awsIPs.contains(addr) {
|
||||
dcontext.GetLoggerWithFields(ctx, loggerField).Info("request from the allowed AWS region, skipping CloudFront")
|
||||
return true
|
||||
}
|
||||
dcontext.GetLoggerWithFields(ctx, loggerField).Warn("request not from the allowed AWS region, fallback to CloudFront")
|
||||
}
|
||||
} else {
|
||||
dcontext.GetLogger(ctx).WithError(err).Warn("failed to parse ip address from context, fallback to CloudFront")
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
401
registry/storage/driver/middleware/cloudfront/s3filter_test.go
Normal file
401
registry/storage/driver/middleware/cloudfront/s3filter_test.go
Normal file
@@ -0,0 +1,401 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
dcontext "github.com/docker/distribution/context"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"reflect" // used as a replacement for testify
|
||||
)
|
||||
|
||||
// Rather than pull in all of testify
|
||||
func assertEqual(t *testing.T, x, y interface{}) {
|
||||
if !reflect.DeepEqual(x, y) {
|
||||
t.Errorf("%s: Not equal! Expected='%v', Actual='%v'\n", t.Name(), x, y)
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
type mockIPRangeHandler struct {
|
||||
data awsIPResponse
|
||||
}
|
||||
|
||||
func (m mockIPRangeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
bytes, err := json.Marshal(m.data)
|
||||
if err != nil {
|
||||
w.WriteHeader(500)
|
||||
return
|
||||
}
|
||||
w.Write(bytes)
|
||||
|
||||
}
|
||||
|
||||
func newTestHandler(data awsIPResponse) *httptest.Server {
|
||||
return httptest.NewServer(mockIPRangeHandler{
|
||||
data: data,
|
||||
})
|
||||
}
|
||||
|
||||
func serverIPRanges(server *httptest.Server) string {
|
||||
return fmt.Sprintf("%s/", server.URL)
|
||||
}
|
||||
|
||||
func setupTest(data awsIPResponse) *httptest.Server {
|
||||
// This is a basic schema which only claims the exact ip
|
||||
// is in aws.
|
||||
server := newTestHandler(data)
|
||||
return server
|
||||
}
|
||||
|
||||
func TestS3TryUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := setupTest(awsIPResponse{
|
||||
Prefixes: []prefixEntry{
|
||||
{IPV4Prefix: "123.231.123.231/32"},
|
||||
},
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
|
||||
|
||||
assertEqual(t, 1, len(ips.ipv4))
|
||||
assertEqual(t, 0, len(ips.ipv6))
|
||||
|
||||
}
|
||||
|
||||
func TestMatchIPV6(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := setupTest(awsIPResponse{
|
||||
V6Prefixes: []prefixEntry{
|
||||
{IPV6Prefix: "ff00::/16"},
|
||||
},
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
|
||||
ips.tryUpdate()
|
||||
assertEqual(t, true, ips.contains(net.ParseIP("ff00::")))
|
||||
assertEqual(t, 1, len(ips.ipv6))
|
||||
assertEqual(t, 0, len(ips.ipv4))
|
||||
}
|
||||
|
||||
func TestMatchIPV4(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := setupTest(awsIPResponse{
|
||||
Prefixes: []prefixEntry{
|
||||
{IPV4Prefix: "192.168.0.0/24"},
|
||||
},
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
|
||||
ips.tryUpdate()
|
||||
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
|
||||
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
|
||||
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
|
||||
}
|
||||
|
||||
func TestMatchIPV4_2(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := setupTest(awsIPResponse{
|
||||
Prefixes: []prefixEntry{
|
||||
{
|
||||
IPV4Prefix: "192.168.0.0/24",
|
||||
Region: "us-east-1",
|
||||
},
|
||||
},
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
|
||||
ips.tryUpdate()
|
||||
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
|
||||
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
|
||||
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
|
||||
}
|
||||
|
||||
func TestMatchIPV4WithRegionMatched(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := setupTest(awsIPResponse{
|
||||
Prefixes: []prefixEntry{
|
||||
{
|
||||
IPV4Prefix: "192.168.0.0/24",
|
||||
Region: "us-east-1",
|
||||
},
|
||||
},
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-east-1"})
|
||||
ips.tryUpdate()
|
||||
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
|
||||
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
|
||||
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
|
||||
}
|
||||
|
||||
func TestMatchIPV4WithRegionMatch_2(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := setupTest(awsIPResponse{
|
||||
Prefixes: []prefixEntry{
|
||||
{
|
||||
IPV4Prefix: "192.168.0.0/24",
|
||||
Region: "us-east-1",
|
||||
},
|
||||
},
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-west-2", "us-east-1"})
|
||||
ips.tryUpdate()
|
||||
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
|
||||
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
|
||||
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
|
||||
}
|
||||
|
||||
func TestMatchIPV4WithRegionNotMatched(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := setupTest(awsIPResponse{
|
||||
Prefixes: []prefixEntry{
|
||||
{
|
||||
IPV4Prefix: "192.168.0.0/24",
|
||||
Region: "us-east-1",
|
||||
},
|
||||
},
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-west-2"})
|
||||
ips.tryUpdate()
|
||||
assertEqual(t, false, ips.contains(net.ParseIP("192.168.0.0")))
|
||||
assertEqual(t, false, ips.contains(net.ParseIP("192.168.0.1")))
|
||||
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
|
||||
}
|
||||
|
||||
func TestInvalidData(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Invalid entries from aws should be ignored.
|
||||
server := setupTest(awsIPResponse{
|
||||
Prefixes: []prefixEntry{
|
||||
{IPV4Prefix: "9000"},
|
||||
{IPV4Prefix: "192.168.0.0/24"},
|
||||
},
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
|
||||
ips.tryUpdate()
|
||||
assertEqual(t, 1, len(ips.ipv4))
|
||||
}
|
||||
|
||||
func TestInvalidNetworkType(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := setupTest(awsIPResponse{
|
||||
Prefixes: []prefixEntry{
|
||||
{IPV4Prefix: "192.168.0.0/24"},
|
||||
},
|
||||
V6Prefixes: []prefixEntry{
|
||||
{IPV6Prefix: "ff00::/8"},
|
||||
{IPV6Prefix: "fe00::/8"},
|
||||
},
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
|
||||
assertEqual(t, 0, len(ips.getCandidateNetworks(make([]byte, 17)))) // 17 bytes does not correspond to any net type
|
||||
assertEqual(t, 1, len(ips.getCandidateNetworks(make([]byte, 4)))) // netv4 networks
|
||||
assertEqual(t, 2, len(ips.getCandidateNetworks(make([]byte, 16)))) // netv6 networks
|
||||
}
|
||||
|
||||
func TestParsing(t *testing.T) {
|
||||
var data = `{
|
||||
"prefixes": [{
|
||||
"ip_prefix": "192.168.0.0",
|
||||
"region": "someregion",
|
||||
"service": "s3"}],
|
||||
"ipv6_prefixes": [{
|
||||
"ipv6_prefix": "2001:4860:4860::8888",
|
||||
"region": "anotherregion",
|
||||
"service": "ec2"}]
|
||||
}`
|
||||
rawMockHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte(data)) })
|
||||
t.Parallel()
|
||||
server := httptest.NewServer(rawMockHandler)
|
||||
defer server.Close()
|
||||
schema, err := fetchAWSIPs(server.URL)
|
||||
|
||||
assertEqual(t, nil, err)
|
||||
assertEqual(t, 1, len(schema.Prefixes))
|
||||
assertEqual(t, prefixEntry{
|
||||
IPV4Prefix: "192.168.0.0",
|
||||
Region: "someregion",
|
||||
Service: "s3",
|
||||
}, schema.Prefixes[0])
|
||||
assertEqual(t, 1, len(schema.V6Prefixes))
|
||||
assertEqual(t, prefixEntry{
|
||||
IPV6Prefix: "2001:4860:4860::8888",
|
||||
Region: "anotherregion",
|
||||
Service: "ec2",
|
||||
}, schema.V6Prefixes[0])
|
||||
}
|
||||
|
||||
func TestUpdateCalledRegularly(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
updateCount := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(
|
||||
func(rw http.ResponseWriter, req *http.Request) {
|
||||
updateCount++
|
||||
rw.Write([]byte("ok"))
|
||||
}))
|
||||
defer server.Close()
|
||||
newAWSIPs(fmt.Sprintf("%s/", server.URL), time.Second, nil)
|
||||
time.Sleep(time.Second*4 + time.Millisecond*500)
|
||||
if updateCount < 4 {
|
||||
t.Errorf("Update should have been called at least 4 times, actual=%d", updateCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEligibleForS3(t *testing.T) {
|
||||
awsIPs := &awsIPs{
|
||||
ipv4: []net.IPNet{{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
}},
|
||||
initialized: true,
|
||||
}
|
||||
empty := context.TODO()
|
||||
makeContext := func(ip string) context.Context {
|
||||
req := &http.Request{
|
||||
RemoteAddr: ip,
|
||||
}
|
||||
|
||||
return dcontext.WithRequest(empty, req)
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
Context context.Context
|
||||
Expected bool
|
||||
}{
|
||||
{Context: empty, Expected: false},
|
||||
{Context: makeContext("192.168.1.2"), Expected: true},
|
||||
{Context: makeContext("192.168.0.2"), Expected: false},
|
||||
}
|
||||
|
||||
for _, testCase := range cases {
|
||||
name := fmt.Sprintf("Client IP = %v",
|
||||
testCase.Context.Value("http.request.ip"))
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assertEqual(t, testCase.Expected, eligibleForS3(testCase.Context, awsIPs))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEligibleForS3WithAWSIPNotInitialized(t *testing.T) {
|
||||
awsIPs := &awsIPs{
|
||||
ipv4: []net.IPNet{{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||
}},
|
||||
initialized: false,
|
||||
}
|
||||
empty := context.TODO()
|
||||
makeContext := func(ip string) context.Context {
|
||||
req := &http.Request{
|
||||
RemoteAddr: ip,
|
||||
}
|
||||
|
||||
return dcontext.WithRequest(empty, req)
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
Context context.Context
|
||||
Expected bool
|
||||
}{
|
||||
{Context: empty, Expected: false},
|
||||
{Context: makeContext("192.168.1.2"), Expected: false},
|
||||
{Context: makeContext("192.168.0.2"), Expected: false},
|
||||
}
|
||||
|
||||
for _, testCase := range cases {
|
||||
name := fmt.Sprintf("Client IP = %v",
|
||||
testCase.Context.Value("http.request.ip"))
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assertEqual(t, testCase.Expected, eligibleForS3(testCase.Context, awsIPs))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// populate ips with a number of different ipv4 and ipv6 networks, for the purposes
|
||||
// of benchmarking contains() performance.
|
||||
func populateRandomNetworks(b *testing.B, ips *awsIPs, ipv4Count, ipv6Count int) {
|
||||
generateNetworks := func(dest *[]net.IPNet, bytes int, count int) {
|
||||
for i := 0; i < count; i++ {
|
||||
ip := make([]byte, bytes)
|
||||
_, err := rand.Read(ip)
|
||||
if err != nil {
|
||||
b.Fatalf("failed to generate network for test : %s", err.Error())
|
||||
}
|
||||
mask := make([]byte, bytes)
|
||||
for i := 0; i < bytes; i++ {
|
||||
mask[i] = 0xff
|
||||
}
|
||||
*dest = append(*dest, net.IPNet{
|
||||
IP: ip,
|
||||
Mask: mask,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
generateNetworks(&ips.ipv4, 4, ipv4Count)
|
||||
generateNetworks(&ips.ipv6, 16, ipv6Count)
|
||||
}
|
||||
|
||||
func BenchmarkContainsRandom(b *testing.B) {
|
||||
// Generate a random network configuration, of size comparable to
|
||||
// aws official networks list
|
||||
// curl -s https://ip-ranges.amazonaws.com/ip-ranges.json | jq '.prefixes | length'
|
||||
// 941
|
||||
numNetworksPerType := 1000 // keep in sync with the above
|
||||
// intentionally skip constructor when creating awsIPs, to avoid updater routine.
|
||||
// This benchmark is only concerned with contains() performance.
|
||||
awsIPs := awsIPs{}
|
||||
populateRandomNetworks(b, &awsIPs, numNetworksPerType, numNetworksPerType)
|
||||
|
||||
ipv4 := make([][]byte, b.N)
|
||||
ipv6 := make([][]byte, b.N)
|
||||
for i := 0; i < b.N; i++ {
|
||||
ipv4[i] = make([]byte, 4)
|
||||
ipv6[i] = make([]byte, 16)
|
||||
rand.Read(ipv4[i])
|
||||
rand.Read(ipv6[i])
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
awsIPs.contains(ipv4[i])
|
||||
awsIPs.contains(ipv6[i])
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkContainsProd(b *testing.B) {
|
||||
awsIPs := newAWSIPs(defaultIPRangesURL, defaultUpdateFrequency, nil)
|
||||
ipv4 := make([][]byte, b.N)
|
||||
ipv6 := make([][]byte, b.N)
|
||||
for i := 0; i < b.N; i++ {
|
||||
ipv4[i] = make([]byte, 4)
|
||||
ipv6[i] = make([]byte, 16)
|
||||
rand.Read(ipv4[i])
|
||||
rand.Read(ipv6[i])
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
awsIPs.contains(ipv4[i])
|
||||
awsIPs.contains(ipv6[i])
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user