add s3 region filters for cloudfront
Signed-off-by: tifayuki <tifayuki@gmail.com>
This commit is contained in:
parent
3800056b88
commit
e8ecc6dc55
1
.gitignore
vendored
1
.gitignore
vendored
@ -35,3 +35,4 @@ bin/*
|
|||||||
# Editor/IDE specific files.
|
# Editor/IDE specific files.
|
||||||
*.sublime-project
|
*.sublime-project
|
||||||
*.sublime-workspace
|
*.sublime-workspace
|
||||||
|
.idea/*
|
||||||
|
@ -39,6 +39,8 @@ type Logger interface {
|
|||||||
Warn(args ...interface{})
|
Warn(args ...interface{})
|
||||||
Warnf(format string, args ...interface{})
|
Warnf(format string, args ...interface{})
|
||||||
Warnln(args ...interface{})
|
Warnln(args ...interface{})
|
||||||
|
|
||||||
|
WithError(err error) *logrus.Entry
|
||||||
}
|
}
|
||||||
|
|
||||||
type loggerKey struct{}
|
type loggerKey struct{}
|
||||||
|
@ -183,6 +183,10 @@ middleware:
|
|||||||
privatekey: /path/to/pem
|
privatekey: /path/to/pem
|
||||||
keypairid: cloudfrontkeypairid
|
keypairid: cloudfrontkeypairid
|
||||||
duration: 3000s
|
duration: 3000s
|
||||||
|
ipfilteredby: awsregion
|
||||||
|
awsregion: us-east-1, use-east-2
|
||||||
|
updatefrenquency: 12h
|
||||||
|
iprangesurl: https://ip-ranges.amazonaws.com/ip-ranges.json
|
||||||
storage:
|
storage:
|
||||||
- name: redirect
|
- name: redirect
|
||||||
options:
|
options:
|
||||||
@ -636,6 +640,10 @@ middleware:
|
|||||||
privatekey: /path/to/pem
|
privatekey: /path/to/pem
|
||||||
keypairid: cloudfrontkeypairid
|
keypairid: cloudfrontkeypairid
|
||||||
duration: 3000s
|
duration: 3000s
|
||||||
|
ipfilteredby: awsregion
|
||||||
|
awsregion: us-east-1, use-east-2
|
||||||
|
updatefrenquency: 12h
|
||||||
|
iprangesurl: https://ip-ranges.amazonaws.com/ip-ranges.json
|
||||||
```
|
```
|
||||||
|
|
||||||
Each middleware entry has `name` and `options` entries. The `name` must
|
Each middleware entry has `name` and `options` entries. The `name` must
|
||||||
@ -655,6 +663,14 @@ interpretation of the options.
|
|||||||
| `privatekey` | yes | The private key for Cloudfront, provided by AWS. |
|
| `privatekey` | yes | The private key for Cloudfront, provided by AWS. |
|
||||||
| `keypairid` | yes | The key pair ID provided by AWS. |
|
| `keypairid` | yes | The key pair ID provided by AWS. |
|
||||||
| `duration` | no | An integer and unit for the duration of the Cloudfront session. Valid time units are `ns`, `us` (or `µs`), `ms`, `s`, `m`, or `h`. For example, `3000s` is valid, but `3000 s` is not. If you do not specify a `duration` or you specify an integer without a time unit, the duration defaults to `20m` (20 minutes).|
|
| `duration` | no | An integer and unit for the duration of the Cloudfront session. Valid time units are `ns`, `us` (or `µs`), `ms`, `s`, `m`, or `h`. For example, `3000s` is valid, but `3000 s` is not. If you do not specify a `duration` or you specify an integer without a time unit, the duration defaults to `20m` (20 minutes).|
|
||||||
|
|`ipfilteredby`|no | A string with the following value `none|aws|awsregion`. |
|
||||||
|
|`awsregion`|no | A comma separated string of AWS regions, only available when `ipfilteredby` is `awsregion`. For example, `us-east-1, us-west-2`|
|
||||||
|
|`updatefrenquency`|no | The frequency to update AWS IP regions, default: `12h`|
|
||||||
|
|`iprangesurl`|no | The URL contains the AWS IP ranges information, default: `https://ip-ranges.amazonaws.com/ip-ranges.json`|
|
||||||
|
Then value of ipfilteredby:
|
||||||
|
`none`: default, do not filter by IP
|
||||||
|
`aws`: IP from AWS goes to S3 directly
|
||||||
|
`awsregion`: IP from certain AWS regions goes to S3 directly, use together with `awsregion`
|
||||||
|
|
||||||
### `redirect`
|
### `redirect`
|
||||||
|
|
||||||
|
@ -16,7 +16,7 @@ import (
|
|||||||
"github.com/aws/aws-sdk-go/service/cloudfront/sign"
|
"github.com/aws/aws-sdk-go/service/cloudfront/sign"
|
||||||
dcontext "github.com/docker/distribution/context"
|
dcontext "github.com/docker/distribution/context"
|
||||||
storagedriver "github.com/docker/distribution/registry/storage/driver"
|
storagedriver "github.com/docker/distribution/registry/storage/driver"
|
||||||
storagemiddleware "github.com/docker/distribution/registry/storage/driver/middleware"
|
"github.com/docker/distribution/registry/storage/driver/middleware"
|
||||||
)
|
)
|
||||||
|
|
||||||
// cloudFrontStorageMiddleware provides a simple implementation of layerHandler that
|
// cloudFrontStorageMiddleware provides a simple implementation of layerHandler that
|
||||||
@ -24,6 +24,7 @@ import (
|
|||||||
// then issues HTTP Temporary Redirects to this CloudFront content URL.
|
// then issues HTTP Temporary Redirects to this CloudFront content URL.
|
||||||
type cloudFrontStorageMiddleware struct {
|
type cloudFrontStorageMiddleware struct {
|
||||||
storagedriver.StorageDriver
|
storagedriver.StorageDriver
|
||||||
|
awsIPs *awsIPs
|
||||||
urlSigner *sign.URLSigner
|
urlSigner *sign.URLSigner
|
||||||
baseURL string
|
baseURL string
|
||||||
duration time.Duration
|
duration time.Duration
|
||||||
@ -34,7 +35,13 @@ var _ storagedriver.StorageDriver = &cloudFrontStorageMiddleware{}
|
|||||||
// newCloudFrontLayerHandler constructs and returns a new CloudFront
|
// newCloudFrontLayerHandler constructs and returns a new CloudFront
|
||||||
// LayerHandler implementation.
|
// LayerHandler implementation.
|
||||||
// Required options: baseurl, privatekey, keypairid
|
// Required options: baseurl, privatekey, keypairid
|
||||||
|
|
||||||
|
// Optional options: ipFilteredBy, awsregion
|
||||||
|
// ipfilteredby: valid value "none|aws|awsregion". "none", do not filter any IP, default value. "aws", only aws IP goes
|
||||||
|
// to S3 directly. "awsregion", only regions listed in awsregion options goes to S3 directly
|
||||||
|
// awsregion: a comma separated string of AWS regions.
|
||||||
func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) {
|
func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, options map[string]interface{}) (storagedriver.StorageDriver, error) {
|
||||||
|
// parse baseurl
|
||||||
base, ok := options["baseurl"]
|
base, ok := options["baseurl"]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("no baseurl provided")
|
return nil, fmt.Errorf("no baseurl provided")
|
||||||
@ -52,6 +59,8 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o
|
|||||||
if _, err := url.Parse(baseURL); err != nil {
|
if _, err := url.Parse(baseURL); err != nil {
|
||||||
return nil, fmt.Errorf("invalid baseurl: %v", err)
|
return nil, fmt.Errorf("invalid baseurl: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parse privatekey to get pkPath
|
||||||
pk, ok := options["privatekey"]
|
pk, ok := options["privatekey"]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("no privatekey provided")
|
return nil, fmt.Errorf("no privatekey provided")
|
||||||
@ -60,6 +69,8 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o
|
|||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("privatekey must be a string")
|
return nil, fmt.Errorf("privatekey must be a string")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parse keypairid
|
||||||
kpid, ok := options["keypairid"]
|
kpid, ok := options["keypairid"]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("no keypairid provided")
|
return nil, fmt.Errorf("no keypairid provided")
|
||||||
@ -69,6 +80,7 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o
|
|||||||
return nil, fmt.Errorf("keypairid must be a string")
|
return nil, fmt.Errorf("keypairid must be a string")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get urlSigner from the file specified in pkPath
|
||||||
pkBytes, err := ioutil.ReadFile(pkPath)
|
pkBytes, err := ioutil.ReadFile(pkPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to read privatekey file: %s", err)
|
return nil, fmt.Errorf("failed to read privatekey file: %s", err)
|
||||||
@ -82,12 +94,11 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
urlSigner := sign.NewURLSigner(keypairID, privateKey)
|
urlSigner := sign.NewURLSigner(keypairID, privateKey)
|
||||||
|
|
||||||
|
// parse duration
|
||||||
duration := 20 * time.Minute
|
duration := 20 * time.Minute
|
||||||
d, ok := options["duration"]
|
if d, ok := options["duration"]; ok {
|
||||||
if ok {
|
|
||||||
switch d := d.(type) {
|
switch d := d.(type) {
|
||||||
case time.Duration:
|
case time.Duration:
|
||||||
duration = d
|
duration = d
|
||||||
@ -100,11 +111,62 @@ func newCloudFrontStorageMiddleware(storageDriver storagedriver.StorageDriver, o
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parse updatefrenquency
|
||||||
|
updateFrequency := defaultUpdateFrequency
|
||||||
|
if u, ok := options["updatefrenquency"]; ok {
|
||||||
|
switch u := u.(type) {
|
||||||
|
case time.Duration:
|
||||||
|
updateFrequency = u
|
||||||
|
case string:
|
||||||
|
updateFreq, err := time.ParseDuration(u)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid updatefrenquency: %s", err)
|
||||||
|
}
|
||||||
|
duration = updateFreq
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parse iprangesurl
|
||||||
|
ipRangesURL := defaultIPRangesURL
|
||||||
|
if i, ok := options["iprangesurl"]; ok {
|
||||||
|
if iprangeurl, ok := i.(string); ok {
|
||||||
|
ipRangesURL = iprangeurl
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("iprangesurl must be a string")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parse ipfilteredby
|
||||||
|
var awsIPs *awsIPs
|
||||||
|
if ipFilteredBy := options["ipfilteredby"].(string); ok {
|
||||||
|
switch strings.ToLower(strings.TrimSpace(ipFilteredBy)) {
|
||||||
|
case "", "none":
|
||||||
|
awsIPs = nil
|
||||||
|
case "aws":
|
||||||
|
newAWSIPs(ipRangesURL, updateFrequency, nil)
|
||||||
|
case "awsregion":
|
||||||
|
var awsRegion []string
|
||||||
|
if regions, ok := options["awsregion"].(string); ok {
|
||||||
|
for _, awsRegions := range strings.Split(regions, ",") {
|
||||||
|
awsRegion = append(awsRegion, strings.ToLower(strings.TrimSpace(awsRegions)))
|
||||||
|
}
|
||||||
|
awsIPs = newAWSIPs(ipRangesURL, updateFrequency, awsRegion)
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("awsRegion must be a comma separated string of valid aws regions")
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("ipfilteredby only allows a string the following value: none|aws|awsregion")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("ipfilteredby only allows a string with the following value: none|aws|awsregion")
|
||||||
|
}
|
||||||
|
|
||||||
return &cloudFrontStorageMiddleware{
|
return &cloudFrontStorageMiddleware{
|
||||||
StorageDriver: storageDriver,
|
StorageDriver: storageDriver,
|
||||||
urlSigner: urlSigner,
|
urlSigner: urlSigner,
|
||||||
baseURL: baseURL,
|
baseURL: baseURL,
|
||||||
duration: duration,
|
duration: duration,
|
||||||
|
awsIPs: awsIPs,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -114,8 +176,8 @@ type S3BucketKeyer interface {
|
|||||||
S3BucketKey(path string) string
|
S3BucketKey(path string) string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve returns an http.Handler which can serve the contents of the given
|
// URLFor attempts to find a url which may be used to retrieve the file at the given path.
|
||||||
// Layer, or an error if not supported by the storagedriver.
|
// Returns an error if the file cannot be found.
|
||||||
func (lh *cloudFrontStorageMiddleware) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
|
func (lh *cloudFrontStorageMiddleware) URLFor(ctx context.Context, path string, options map[string]interface{}) (string, error) {
|
||||||
// TODO(endophage): currently only supports S3
|
// TODO(endophage): currently only supports S3
|
||||||
keyer, ok := lh.StorageDriver.(S3BucketKeyer)
|
keyer, ok := lh.StorageDriver.(S3BucketKeyer)
|
||||||
@ -124,6 +186,11 @@ func (lh *cloudFrontStorageMiddleware) URLFor(ctx context.Context, path string,
|
|||||||
return lh.StorageDriver.URLFor(ctx, path, options)
|
return lh.StorageDriver.URLFor(ctx, path, options)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if eligibleForS3(ctx, lh.awsIPs) {
|
||||||
|
return lh.StorageDriver.URLFor(ctx, path, options)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get signed cloudfront url.
|
||||||
cfURL, err := lh.urlSigner.Sign(lh.baseURL+keyer.S3BucketKey(path), time.Now().Add(lh.duration))
|
cfURL, err := lh.urlSigner.Sign(lh.baseURL+keyer.S3BucketKey(path), time.Now().Add(lh.duration))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
223
registry/storage/driver/middleware/cloudfront/s3filter.go
Normal file
223
registry/storage/driver/middleware/cloudfront/s3filter.go
Normal file
@ -0,0 +1,223 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
dcontext "github.com/docker/distribution/context"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ipRangesURL is the URL to get definition of AWS IPs
|
||||||
|
defaultIPRangesURL = "https://ip-ranges.amazonaws.com/ip-ranges.json"
|
||||||
|
// updateFrequency tells how frequently AWS IPs need to be updated
|
||||||
|
defaultUpdateFrequency = time.Hour * 12
|
||||||
|
)
|
||||||
|
|
||||||
|
// newAWSIPs returns a New awsIP object.
|
||||||
|
// If awsRegion is `nil`, it accepts any region. Otherwise, it only allow the regions specified
|
||||||
|
func newAWSIPs(host string, updateFrequency time.Duration, awsRegion []string) *awsIPs {
|
||||||
|
ips := &awsIPs{
|
||||||
|
host: host,
|
||||||
|
updateFrequency: updateFrequency,
|
||||||
|
awsRegion: awsRegion,
|
||||||
|
updaterStopChan: make(chan bool),
|
||||||
|
}
|
||||||
|
if err := ips.tryUpdate(); err != nil {
|
||||||
|
dcontext.GetLogger(context.Background()).WithError(err).Warn("failed to update AWS IP")
|
||||||
|
}
|
||||||
|
go ips.updater()
|
||||||
|
return ips
|
||||||
|
}
|
||||||
|
|
||||||
|
// awsIPs tracks a list of AWS ips, filtered by awsRegion
|
||||||
|
type awsIPs struct {
|
||||||
|
host string
|
||||||
|
updateFrequency time.Duration
|
||||||
|
ipv4 []net.IPNet
|
||||||
|
ipv6 []net.IPNet
|
||||||
|
mutex sync.RWMutex
|
||||||
|
awsRegion []string
|
||||||
|
updaterStopChan chan bool
|
||||||
|
initialized bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type awsIPResponse struct {
|
||||||
|
Prefixes []prefixEntry `json:"prefixes"`
|
||||||
|
V6Prefixes []prefixEntry `json:"ipv6_prefixes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type prefixEntry struct {
|
||||||
|
IPV4Prefix string `json:"ip_prefix"`
|
||||||
|
IPV6Prefix string `json:"ipv6_prefix"`
|
||||||
|
Region string `json:"region"`
|
||||||
|
Service string `json:"service"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func fetchAWSIPs(url string) (awsIPResponse, error) {
|
||||||
|
var response awsIPResponse
|
||||||
|
resp, err := http.Get(url)
|
||||||
|
if err != nil {
|
||||||
|
return response, err
|
||||||
|
}
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
body, _ := ioutil.ReadAll(resp.Body)
|
||||||
|
return response, fmt.Errorf("failed to fetch network data. response = %s", body)
|
||||||
|
}
|
||||||
|
decoder := json.NewDecoder(resp.Body)
|
||||||
|
err = decoder.Decode(&response)
|
||||||
|
if err != nil {
|
||||||
|
return response, err
|
||||||
|
}
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryUpdate attempts to download the new set of ip addresses.
|
||||||
|
// tryUpdate must be thread safe with contains
|
||||||
|
func (s *awsIPs) tryUpdate() error {
|
||||||
|
response, err := fetchAWSIPs(s.host)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var ipv4 []net.IPNet
|
||||||
|
var ipv6 []net.IPNet
|
||||||
|
|
||||||
|
processAddress := func(output *[]net.IPNet, prefix string, region string) {
|
||||||
|
regionAllowed := false
|
||||||
|
if len(s.awsRegion) > 0 {
|
||||||
|
for _, ar := range s.awsRegion {
|
||||||
|
if strings.ToLower(region) == ar {
|
||||||
|
regionAllowed = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
regionAllowed = true
|
||||||
|
}
|
||||||
|
|
||||||
|
_, network, err := net.ParseCIDR(prefix)
|
||||||
|
if err != nil {
|
||||||
|
dcontext.GetLoggerWithFields(dcontext.Background(), map[interface{}]interface{}{
|
||||||
|
"cidr": prefix,
|
||||||
|
}).Error("unparseable cidr")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if regionAllowed {
|
||||||
|
*output = append(*output, *network)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, prefix := range response.Prefixes {
|
||||||
|
processAddress(&ipv4, prefix.IPV4Prefix, prefix.Region)
|
||||||
|
}
|
||||||
|
for _, prefix := range response.V6Prefixes {
|
||||||
|
processAddress(&ipv6, prefix.IPV6Prefix, prefix.Region)
|
||||||
|
}
|
||||||
|
s.mutex.Lock()
|
||||||
|
defer s.mutex.Unlock()
|
||||||
|
// Update each attr of awsips atomically.
|
||||||
|
s.ipv4 = ipv4
|
||||||
|
s.ipv6 = ipv6
|
||||||
|
s.initialized = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// This function is meant to be run in a background goroutine.
|
||||||
|
// It will periodically update the ips from aws.
|
||||||
|
func (s *awsIPs) updater() {
|
||||||
|
defer close(s.updaterStopChan)
|
||||||
|
for {
|
||||||
|
time.Sleep(s.updateFrequency)
|
||||||
|
select {
|
||||||
|
case <-s.updaterStopChan:
|
||||||
|
dcontext.GetLogger(context.Background()).Info("aws ip updater received stop signal")
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
err := s.tryUpdate()
|
||||||
|
if err != nil {
|
||||||
|
dcontext.GetLogger(context.Background()).WithError(err).Error("git AWS IP")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getCandidateNetworks returns either the ipv4 or ipv6 networks
|
||||||
|
// that were last read from aws. The networks returned
|
||||||
|
// have the same type as the ip address provided.
|
||||||
|
func (s *awsIPs) getCandidateNetworks(ip net.IP) []net.IPNet {
|
||||||
|
s.mutex.RLock()
|
||||||
|
defer s.mutex.RUnlock()
|
||||||
|
if ip.To4() != nil {
|
||||||
|
return s.ipv4
|
||||||
|
} else if ip.To16() != nil {
|
||||||
|
return s.ipv6
|
||||||
|
} else {
|
||||||
|
dcontext.GetLoggerWithFields(dcontext.Background(), map[interface{}]interface{}{
|
||||||
|
"ip": ip,
|
||||||
|
}).Error("unknown ip address format")
|
||||||
|
// assume mismatch, pass through cloudfront
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Contains determines whether the host is within aws.
|
||||||
|
func (s *awsIPs) contains(ip net.IP) bool {
|
||||||
|
networks := s.getCandidateNetworks(ip)
|
||||||
|
for _, network := range networks {
|
||||||
|
if network.Contains(ip) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseIPFromRequest attempts to extract the ip address of the
|
||||||
|
// client that made the request
|
||||||
|
func parseIPFromRequest(ctx context.Context) (net.IP, error) {
|
||||||
|
request, err := dcontext.GetRequest(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ipStr := dcontext.RemoteIP(request)
|
||||||
|
ip := net.ParseIP(ipStr)
|
||||||
|
if ip == nil {
|
||||||
|
return nil, fmt.Errorf("invalid ip address from requester: %s", ipStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ip, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// eligibleForS3 checks if a request is eligible for using S3 directly
|
||||||
|
// Return true only when the IP belongs to a specific aws region and user-agent is docker
|
||||||
|
func eligibleForS3(ctx context.Context, awsIPs *awsIPs) bool {
|
||||||
|
if awsIPs != nil && awsIPs.initialized {
|
||||||
|
if addr, err := parseIPFromRequest(ctx); err == nil {
|
||||||
|
request, err := dcontext.GetRequest(ctx)
|
||||||
|
if err != nil {
|
||||||
|
dcontext.GetLogger(ctx).Warnf("the CloudFront middleware cannot parse the request: %s", err)
|
||||||
|
} else {
|
||||||
|
loggerField := map[interface{}]interface{}{
|
||||||
|
"user-client": request.UserAgent(),
|
||||||
|
"ip": dcontext.RemoteIP(request),
|
||||||
|
}
|
||||||
|
if awsIPs.contains(addr) {
|
||||||
|
dcontext.GetLoggerWithFields(ctx, loggerField).Info("request from the allowed AWS region, skipping CloudFront")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
dcontext.GetLoggerWithFields(ctx, loggerField).Warn("request not from the allowed AWS region, fallback to CloudFront")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
dcontext.GetLogger(ctx).WithError(err).Warn("failed to parse ip address from context, fallback to CloudFront")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
401
registry/storage/driver/middleware/cloudfront/s3filter_test.go
Normal file
401
registry/storage/driver/middleware/cloudfront/s3filter_test.go
Normal file
@ -0,0 +1,401 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
dcontext "github.com/docker/distribution/context"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"reflect" // used as a replacement for testify
|
||||||
|
)
|
||||||
|
|
||||||
|
// Rather than pull in all of testify
|
||||||
|
func assertEqual(t *testing.T, x, y interface{}) {
|
||||||
|
if !reflect.DeepEqual(x, y) {
|
||||||
|
t.Errorf("%s: Not equal! Expected='%v', Actual='%v'\n", t.Name(), x, y)
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockIPRangeHandler struct {
|
||||||
|
data awsIPResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockIPRangeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
bytes, err := json.Marshal(m.data)
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(500)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Write(bytes)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestHandler(data awsIPResponse) *httptest.Server {
|
||||||
|
return httptest.NewServer(mockIPRangeHandler{
|
||||||
|
data: data,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func serverIPRanges(server *httptest.Server) string {
|
||||||
|
return fmt.Sprintf("%s/", server.URL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupTest(data awsIPResponse) *httptest.Server {
|
||||||
|
// This is a basic schema which only claims the exact ip
|
||||||
|
// is in aws.
|
||||||
|
server := newTestHandler(data)
|
||||||
|
return server
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestS3TryUpdate(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
server := setupTest(awsIPResponse{
|
||||||
|
Prefixes: []prefixEntry{
|
||||||
|
{IPV4Prefix: "123.231.123.231/32"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
|
||||||
|
|
||||||
|
assertEqual(t, 1, len(ips.ipv4))
|
||||||
|
assertEqual(t, 0, len(ips.ipv6))
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatchIPV6(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
server := setupTest(awsIPResponse{
|
||||||
|
V6Prefixes: []prefixEntry{
|
||||||
|
{IPV6Prefix: "ff00::/16"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
|
||||||
|
ips.tryUpdate()
|
||||||
|
assertEqual(t, true, ips.contains(net.ParseIP("ff00::")))
|
||||||
|
assertEqual(t, 1, len(ips.ipv6))
|
||||||
|
assertEqual(t, 0, len(ips.ipv4))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatchIPV4(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
server := setupTest(awsIPResponse{
|
||||||
|
Prefixes: []prefixEntry{
|
||||||
|
{IPV4Prefix: "192.168.0.0/24"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
|
||||||
|
ips.tryUpdate()
|
||||||
|
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
|
||||||
|
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
|
||||||
|
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatchIPV4_2(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
server := setupTest(awsIPResponse{
|
||||||
|
Prefixes: []prefixEntry{
|
||||||
|
{
|
||||||
|
IPV4Prefix: "192.168.0.0/24",
|
||||||
|
Region: "us-east-1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
|
||||||
|
ips.tryUpdate()
|
||||||
|
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
|
||||||
|
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
|
||||||
|
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatchIPV4WithRegionMatched(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
server := setupTest(awsIPResponse{
|
||||||
|
Prefixes: []prefixEntry{
|
||||||
|
{
|
||||||
|
IPV4Prefix: "192.168.0.0/24",
|
||||||
|
Region: "us-east-1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-east-1"})
|
||||||
|
ips.tryUpdate()
|
||||||
|
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
|
||||||
|
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
|
||||||
|
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatchIPV4WithRegionMatch_2(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
server := setupTest(awsIPResponse{
|
||||||
|
Prefixes: []prefixEntry{
|
||||||
|
{
|
||||||
|
IPV4Prefix: "192.168.0.0/24",
|
||||||
|
Region: "us-east-1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-west-2", "us-east-1"})
|
||||||
|
ips.tryUpdate()
|
||||||
|
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.0")))
|
||||||
|
assertEqual(t, true, ips.contains(net.ParseIP("192.168.0.1")))
|
||||||
|
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatchIPV4WithRegionNotMatched(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
server := setupTest(awsIPResponse{
|
||||||
|
Prefixes: []prefixEntry{
|
||||||
|
{
|
||||||
|
IPV4Prefix: "192.168.0.0/24",
|
||||||
|
Region: "us-east-1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
ips := newAWSIPs(serverIPRanges(server), time.Hour, []string{"us-west-2"})
|
||||||
|
ips.tryUpdate()
|
||||||
|
assertEqual(t, false, ips.contains(net.ParseIP("192.168.0.0")))
|
||||||
|
assertEqual(t, false, ips.contains(net.ParseIP("192.168.0.1")))
|
||||||
|
assertEqual(t, false, ips.contains(net.ParseIP("192.169.0.0")))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInvalidData(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
// Invalid entries from aws should be ignored.
|
||||||
|
server := setupTest(awsIPResponse{
|
||||||
|
Prefixes: []prefixEntry{
|
||||||
|
{IPV4Prefix: "9000"},
|
||||||
|
{IPV4Prefix: "192.168.0.0/24"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
|
||||||
|
ips.tryUpdate()
|
||||||
|
assertEqual(t, 1, len(ips.ipv4))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInvalidNetworkType(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
server := setupTest(awsIPResponse{
|
||||||
|
Prefixes: []prefixEntry{
|
||||||
|
{IPV4Prefix: "192.168.0.0/24"},
|
||||||
|
},
|
||||||
|
V6Prefixes: []prefixEntry{
|
||||||
|
{IPV6Prefix: "ff00::/8"},
|
||||||
|
{IPV6Prefix: "fe00::/8"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
ips := newAWSIPs(serverIPRanges(server), time.Hour, nil)
|
||||||
|
assertEqual(t, 0, len(ips.getCandidateNetworks(make([]byte, 17)))) // 17 bytes does not correspond to any net type
|
||||||
|
assertEqual(t, 1, len(ips.getCandidateNetworks(make([]byte, 4)))) // netv4 networks
|
||||||
|
assertEqual(t, 2, len(ips.getCandidateNetworks(make([]byte, 16)))) // netv6 networks
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsing(t *testing.T) {
|
||||||
|
var data = `{
|
||||||
|
"prefixes": [{
|
||||||
|
"ip_prefix": "192.168.0.0",
|
||||||
|
"region": "someregion",
|
||||||
|
"service": "s3"}],
|
||||||
|
"ipv6_prefixes": [{
|
||||||
|
"ipv6_prefix": "2001:4860:4860::8888",
|
||||||
|
"region": "anotherregion",
|
||||||
|
"service": "ec2"}]
|
||||||
|
}`
|
||||||
|
rawMockHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte(data)) })
|
||||||
|
t.Parallel()
|
||||||
|
server := httptest.NewServer(rawMockHandler)
|
||||||
|
defer server.Close()
|
||||||
|
schema, err := fetchAWSIPs(server.URL)
|
||||||
|
|
||||||
|
assertEqual(t, nil, err)
|
||||||
|
assertEqual(t, 1, len(schema.Prefixes))
|
||||||
|
assertEqual(t, prefixEntry{
|
||||||
|
IPV4Prefix: "192.168.0.0",
|
||||||
|
Region: "someregion",
|
||||||
|
Service: "s3",
|
||||||
|
}, schema.Prefixes[0])
|
||||||
|
assertEqual(t, 1, len(schema.V6Prefixes))
|
||||||
|
assertEqual(t, prefixEntry{
|
||||||
|
IPV6Prefix: "2001:4860:4860::8888",
|
||||||
|
Region: "anotherregion",
|
||||||
|
Service: "ec2",
|
||||||
|
}, schema.V6Prefixes[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateCalledRegularly(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
updateCount := 0
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(
|
||||||
|
func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
updateCount++
|
||||||
|
rw.Write([]byte("ok"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
newAWSIPs(fmt.Sprintf("%s/", server.URL), time.Second, nil)
|
||||||
|
time.Sleep(time.Second*4 + time.Millisecond*500)
|
||||||
|
if updateCount < 4 {
|
||||||
|
t.Errorf("Update should have been called at least 4 times, actual=%d", updateCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEligibleForS3(t *testing.T) {
|
||||||
|
awsIPs := &awsIPs{
|
||||||
|
ipv4: []net.IPNet{{
|
||||||
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
|
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||||
|
}},
|
||||||
|
initialized: true,
|
||||||
|
}
|
||||||
|
empty := context.TODO()
|
||||||
|
makeContext := func(ip string) context.Context {
|
||||||
|
req := &http.Request{
|
||||||
|
RemoteAddr: ip,
|
||||||
|
}
|
||||||
|
|
||||||
|
return dcontext.WithRequest(empty, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
Context context.Context
|
||||||
|
Expected bool
|
||||||
|
}{
|
||||||
|
{Context: empty, Expected: false},
|
||||||
|
{Context: makeContext("192.168.1.2"), Expected: true},
|
||||||
|
{Context: makeContext("192.168.0.2"), Expected: false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range cases {
|
||||||
|
name := fmt.Sprintf("Client IP = %v",
|
||||||
|
testCase.Context.Value("http.request.ip"))
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
assertEqual(t, testCase.Expected, eligibleForS3(testCase.Context, awsIPs))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEligibleForS3WithAWSIPNotInitialized(t *testing.T) {
|
||||||
|
awsIPs := &awsIPs{
|
||||||
|
ipv4: []net.IPNet{{
|
||||||
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
|
Mask: net.IPv4Mask(255, 255, 255, 0),
|
||||||
|
}},
|
||||||
|
initialized: false,
|
||||||
|
}
|
||||||
|
empty := context.TODO()
|
||||||
|
makeContext := func(ip string) context.Context {
|
||||||
|
req := &http.Request{
|
||||||
|
RemoteAddr: ip,
|
||||||
|
}
|
||||||
|
|
||||||
|
return dcontext.WithRequest(empty, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
Context context.Context
|
||||||
|
Expected bool
|
||||||
|
}{
|
||||||
|
{Context: empty, Expected: false},
|
||||||
|
{Context: makeContext("192.168.1.2"), Expected: false},
|
||||||
|
{Context: makeContext("192.168.0.2"), Expected: false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range cases {
|
||||||
|
name := fmt.Sprintf("Client IP = %v",
|
||||||
|
testCase.Context.Value("http.request.ip"))
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
assertEqual(t, testCase.Expected, eligibleForS3(testCase.Context, awsIPs))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// populate ips with a number of different ipv4 and ipv6 networks, for the purposes
|
||||||
|
// of benchmarking contains() performance.
|
||||||
|
func populateRandomNetworks(b *testing.B, ips *awsIPs, ipv4Count, ipv6Count int) {
|
||||||
|
generateNetworks := func(dest *[]net.IPNet, bytes int, count int) {
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
ip := make([]byte, bytes)
|
||||||
|
_, err := rand.Read(ip)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("failed to generate network for test : %s", err.Error())
|
||||||
|
}
|
||||||
|
mask := make([]byte, bytes)
|
||||||
|
for i := 0; i < bytes; i++ {
|
||||||
|
mask[i] = 0xff
|
||||||
|
}
|
||||||
|
*dest = append(*dest, net.IPNet{
|
||||||
|
IP: ip,
|
||||||
|
Mask: mask,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
generateNetworks(&ips.ipv4, 4, ipv4Count)
|
||||||
|
generateNetworks(&ips.ipv6, 16, ipv6Count)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkContainsRandom(b *testing.B) {
|
||||||
|
// Generate a random network configuration, of size comparable to
|
||||||
|
// aws official networks list
|
||||||
|
// curl -s https://ip-ranges.amazonaws.com/ip-ranges.json | jq '.prefixes | length'
|
||||||
|
// 941
|
||||||
|
numNetworksPerType := 1000 // keep in sync with the above
|
||||||
|
// intentionally skip constructor when creating awsIPs, to avoid updater routine.
|
||||||
|
// This benchmark is only concerned with contains() performance.
|
||||||
|
awsIPs := awsIPs{}
|
||||||
|
populateRandomNetworks(b, &awsIPs, numNetworksPerType, numNetworksPerType)
|
||||||
|
|
||||||
|
ipv4 := make([][]byte, b.N)
|
||||||
|
ipv6 := make([][]byte, b.N)
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
ipv4[i] = make([]byte, 4)
|
||||||
|
ipv6[i] = make([]byte, 16)
|
||||||
|
rand.Read(ipv4[i])
|
||||||
|
rand.Read(ipv6[i])
|
||||||
|
}
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
awsIPs.contains(ipv4[i])
|
||||||
|
awsIPs.contains(ipv6[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkContainsProd(b *testing.B) {
|
||||||
|
awsIPs := newAWSIPs(defaultIPRangesURL, defaultUpdateFrequency, nil)
|
||||||
|
ipv4 := make([][]byte, b.N)
|
||||||
|
ipv6 := make([][]byte, b.N)
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
ipv4[i] = make([]byte, 4)
|
||||||
|
ipv6[i] = make([]byte, 16)
|
||||||
|
rand.Read(ipv4[i])
|
||||||
|
rand.Read(ipv6[i])
|
||||||
|
}
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
awsIPs.contains(ipv4[i])
|
||||||
|
awsIPs.contains(ipv6[i])
|
||||||
|
}
|
||||||
|
}
|
340
vendor/github.com/miekg/dns/msg_generate.go
generated
vendored
340
vendor/github.com/miekg/dns/msg_generate.go
generated
vendored
@ -1,340 +0,0 @@
|
|||||||
//+build ignore
|
|
||||||
|
|
||||||
// msg_generate.go is meant to run with go generate. It will use
|
|
||||||
// go/{importer,types} to track down all the RR struct types. Then for each type
|
|
||||||
// it will generate pack/unpack methods based on the struct tags. The generated source is
|
|
||||||
// written to zmsg.go, and is meant to be checked into git.
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"go/format"
|
|
||||||
"go/importer"
|
|
||||||
"go/types"
|
|
||||||
"log"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
var packageHdr = `
|
|
||||||
// *** DO NOT MODIFY ***
|
|
||||||
// AUTOGENERATED BY go generate from msg_generate.go
|
|
||||||
|
|
||||||
package dns
|
|
||||||
|
|
||||||
`
|
|
||||||
|
|
||||||
// getTypeStruct will take a type and the package scope, and return the
|
|
||||||
// (innermost) struct if the type is considered a RR type (currently defined as
|
|
||||||
// those structs beginning with a RR_Header, could be redefined as implementing
|
|
||||||
// the RR interface). The bool return value indicates if embedded structs were
|
|
||||||
// resolved.
|
|
||||||
func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) {
|
|
||||||
st, ok := t.Underlying().(*types.Struct)
|
|
||||||
if !ok {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
if st.Field(0).Type() == scope.Lookup("RR_Header").Type() {
|
|
||||||
return st, false
|
|
||||||
}
|
|
||||||
if st.Field(0).Anonymous() {
|
|
||||||
st, _ := getTypeStruct(st.Field(0).Type(), scope)
|
|
||||||
return st, true
|
|
||||||
}
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
// Import and type-check the package
|
|
||||||
pkg, err := importer.Default().Import("github.com/miekg/dns")
|
|
||||||
fatalIfErr(err)
|
|
||||||
scope := pkg.Scope()
|
|
||||||
|
|
||||||
// Collect actual types (*X)
|
|
||||||
var namedTypes []string
|
|
||||||
for _, name := range scope.Names() {
|
|
||||||
o := scope.Lookup(name)
|
|
||||||
if o == nil || !o.Exported() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if st, _ := getTypeStruct(o.Type(), scope); st == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if name == "PrivateRR" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if corresponding TypeX exists
|
|
||||||
if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" {
|
|
||||||
log.Fatalf("Constant Type%s does not exist.", o.Name())
|
|
||||||
}
|
|
||||||
|
|
||||||
namedTypes = append(namedTypes, o.Name())
|
|
||||||
}
|
|
||||||
|
|
||||||
b := &bytes.Buffer{}
|
|
||||||
b.WriteString(packageHdr)
|
|
||||||
|
|
||||||
fmt.Fprint(b, "// pack*() functions\n\n")
|
|
||||||
for _, name := range namedTypes {
|
|
||||||
o := scope.Lookup(name)
|
|
||||||
st, _ := getTypeStruct(o.Type(), scope)
|
|
||||||
|
|
||||||
fmt.Fprintf(b, "func (rr *%s) pack(msg []byte, off int, compression map[string]int, compress bool) (int, error) {\n", name)
|
|
||||||
fmt.Fprint(b, `off, err := rr.Hdr.pack(msg, off, compression, compress)
|
|
||||||
if err != nil {
|
|
||||||
return off, err
|
|
||||||
}
|
|
||||||
headerEnd := off
|
|
||||||
`)
|
|
||||||
for i := 1; i < st.NumFields(); i++ {
|
|
||||||
o := func(s string) {
|
|
||||||
fmt.Fprintf(b, s, st.Field(i).Name())
|
|
||||||
fmt.Fprint(b, `if err != nil {
|
|
||||||
return off, err
|
|
||||||
}
|
|
||||||
`)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := st.Field(i).Type().(*types.Slice); ok {
|
|
||||||
switch st.Tag(i) {
|
|
||||||
case `dns:"-"`: // ignored
|
|
||||||
case `dns:"txt"`:
|
|
||||||
o("off, err = packStringTxt(rr.%s, msg, off)\n")
|
|
||||||
case `dns:"opt"`:
|
|
||||||
o("off, err = packDataOpt(rr.%s, msg, off)\n")
|
|
||||||
case `dns:"nsec"`:
|
|
||||||
o("off, err = packDataNsec(rr.%s, msg, off)\n")
|
|
||||||
case `dns:"domain-name"`:
|
|
||||||
o("off, err = packDataDomainNames(rr.%s, msg, off, compression, compress)\n")
|
|
||||||
default:
|
|
||||||
log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case st.Tag(i) == `dns:"-"`: // ignored
|
|
||||||
case st.Tag(i) == `dns:"cdomain-name"`:
|
|
||||||
fallthrough
|
|
||||||
case st.Tag(i) == `dns:"domain-name"`:
|
|
||||||
o("off, err = PackDomainName(rr.%s, msg, off, compression, compress)\n")
|
|
||||||
case st.Tag(i) == `dns:"a"`:
|
|
||||||
o("off, err = packDataA(rr.%s, msg, off)\n")
|
|
||||||
case st.Tag(i) == `dns:"aaaa"`:
|
|
||||||
o("off, err = packDataAAAA(rr.%s, msg, off)\n")
|
|
||||||
case st.Tag(i) == `dns:"uint48"`:
|
|
||||||
o("off, err = packUint48(rr.%s, msg, off)\n")
|
|
||||||
case st.Tag(i) == `dns:"txt"`:
|
|
||||||
o("off, err = packString(rr.%s, msg, off)\n")
|
|
||||||
|
|
||||||
case strings.HasPrefix(st.Tag(i), `dns:"size-base32`): // size-base32 can be packed just like base32
|
|
||||||
fallthrough
|
|
||||||
case st.Tag(i) == `dns:"base32"`:
|
|
||||||
o("off, err = packStringBase32(rr.%s, msg, off)\n")
|
|
||||||
|
|
||||||
case strings.HasPrefix(st.Tag(i), `dns:"size-base64`): // size-base64 can be packed just like base64
|
|
||||||
fallthrough
|
|
||||||
case st.Tag(i) == `dns:"base64"`:
|
|
||||||
o("off, err = packStringBase64(rr.%s, msg, off)\n")
|
|
||||||
|
|
||||||
case strings.HasPrefix(st.Tag(i), `dns:"size-hex:SaltLength`): // Hack to fix empty salt length for NSEC3
|
|
||||||
o("if rr.%s == \"-\" { /* do nothing, empty salt */ }\n")
|
|
||||||
continue
|
|
||||||
case strings.HasPrefix(st.Tag(i), `dns:"size-hex`): // size-hex can be packed just like hex
|
|
||||||
fallthrough
|
|
||||||
case st.Tag(i) == `dns:"hex"`:
|
|
||||||
o("off, err = packStringHex(rr.%s, msg, off)\n")
|
|
||||||
|
|
||||||
case st.Tag(i) == `dns:"octet"`:
|
|
||||||
o("off, err = packStringOctet(rr.%s, msg, off)\n")
|
|
||||||
case st.Tag(i) == "":
|
|
||||||
switch st.Field(i).Type().(*types.Basic).Kind() {
|
|
||||||
case types.Uint8:
|
|
||||||
o("off, err = packUint8(rr.%s, msg, off)\n")
|
|
||||||
case types.Uint16:
|
|
||||||
o("off, err = packUint16(rr.%s, msg, off)\n")
|
|
||||||
case types.Uint32:
|
|
||||||
o("off, err = packUint32(rr.%s, msg, off)\n")
|
|
||||||
case types.Uint64:
|
|
||||||
o("off, err = packUint64(rr.%s, msg, off)\n")
|
|
||||||
case types.String:
|
|
||||||
o("off, err = packString(rr.%s, msg, off)\n")
|
|
||||||
default:
|
|
||||||
log.Fatalln(name, st.Field(i).Name())
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// We have packed everything, only now we know the rdlength of this RR
|
|
||||||
fmt.Fprintln(b, "rr.Header().Rdlength = uint16(off-headerEnd)")
|
|
||||||
fmt.Fprintln(b, "return off, nil }\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Fprint(b, "// unpack*() functions\n\n")
|
|
||||||
for _, name := range namedTypes {
|
|
||||||
o := scope.Lookup(name)
|
|
||||||
st, _ := getTypeStruct(o.Type(), scope)
|
|
||||||
|
|
||||||
fmt.Fprintf(b, "func unpack%s(h RR_Header, msg []byte, off int) (RR, int, error) {\n", name)
|
|
||||||
fmt.Fprintf(b, "rr := new(%s)\n", name)
|
|
||||||
fmt.Fprint(b, "rr.Hdr = h\n")
|
|
||||||
fmt.Fprint(b, `if noRdata(h) {
|
|
||||||
return rr, off, nil
|
|
||||||
}
|
|
||||||
var err error
|
|
||||||
rdStart := off
|
|
||||||
_ = rdStart
|
|
||||||
|
|
||||||
`)
|
|
||||||
for i := 1; i < st.NumFields(); i++ {
|
|
||||||
o := func(s string) {
|
|
||||||
fmt.Fprintf(b, s, st.Field(i).Name())
|
|
||||||
fmt.Fprint(b, `if err != nil {
|
|
||||||
return rr, off, err
|
|
||||||
}
|
|
||||||
`)
|
|
||||||
}
|
|
||||||
|
|
||||||
// size-* are special, because they reference a struct member we should use for the length.
|
|
||||||
if strings.HasPrefix(st.Tag(i), `dns:"size-`) {
|
|
||||||
structMember := structMember(st.Tag(i))
|
|
||||||
structTag := structTag(st.Tag(i))
|
|
||||||
switch structTag {
|
|
||||||
case "hex":
|
|
||||||
fmt.Fprintf(b, "rr.%s, off, err = unpackStringHex(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
|
|
||||||
case "base32":
|
|
||||||
fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase32(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
|
|
||||||
case "base64":
|
|
||||||
fmt.Fprintf(b, "rr.%s, off, err = unpackStringBase64(msg, off, off + int(rr.%s))\n", st.Field(i).Name(), structMember)
|
|
||||||
default:
|
|
||||||
log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
|
|
||||||
}
|
|
||||||
fmt.Fprint(b, `if err != nil {
|
|
||||||
return rr, off, err
|
|
||||||
}
|
|
||||||
`)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, ok := st.Field(i).Type().(*types.Slice); ok {
|
|
||||||
switch st.Tag(i) {
|
|
||||||
case `dns:"-"`: // ignored
|
|
||||||
case `dns:"txt"`:
|
|
||||||
o("rr.%s, off, err = unpackStringTxt(msg, off)\n")
|
|
||||||
case `dns:"opt"`:
|
|
||||||
o("rr.%s, off, err = unpackDataOpt(msg, off)\n")
|
|
||||||
case `dns:"nsec"`:
|
|
||||||
o("rr.%s, off, err = unpackDataNsec(msg, off)\n")
|
|
||||||
case `dns:"domain-name"`:
|
|
||||||
o("rr.%s, off, err = unpackDataDomainNames(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
|
|
||||||
default:
|
|
||||||
log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
switch st.Tag(i) {
|
|
||||||
case `dns:"-"`: // ignored
|
|
||||||
case `dns:"cdomain-name"`:
|
|
||||||
fallthrough
|
|
||||||
case `dns:"domain-name"`:
|
|
||||||
o("rr.%s, off, err = UnpackDomainName(msg, off)\n")
|
|
||||||
case `dns:"a"`:
|
|
||||||
o("rr.%s, off, err = unpackDataA(msg, off)\n")
|
|
||||||
case `dns:"aaaa"`:
|
|
||||||
o("rr.%s, off, err = unpackDataAAAA(msg, off)\n")
|
|
||||||
case `dns:"uint48"`:
|
|
||||||
o("rr.%s, off, err = unpackUint48(msg, off)\n")
|
|
||||||
case `dns:"txt"`:
|
|
||||||
o("rr.%s, off, err = unpackString(msg, off)\n")
|
|
||||||
case `dns:"base32"`:
|
|
||||||
o("rr.%s, off, err = unpackStringBase32(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
|
|
||||||
case `dns:"base64"`:
|
|
||||||
o("rr.%s, off, err = unpackStringBase64(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
|
|
||||||
case `dns:"hex"`:
|
|
||||||
o("rr.%s, off, err = unpackStringHex(msg, off, rdStart + int(rr.Hdr.Rdlength))\n")
|
|
||||||
case `dns:"octet"`:
|
|
||||||
o("rr.%s, off, err = unpackStringOctet(msg, off)\n")
|
|
||||||
case "":
|
|
||||||
switch st.Field(i).Type().(*types.Basic).Kind() {
|
|
||||||
case types.Uint8:
|
|
||||||
o("rr.%s, off, err = unpackUint8(msg, off)\n")
|
|
||||||
case types.Uint16:
|
|
||||||
o("rr.%s, off, err = unpackUint16(msg, off)\n")
|
|
||||||
case types.Uint32:
|
|
||||||
o("rr.%s, off, err = unpackUint32(msg, off)\n")
|
|
||||||
case types.Uint64:
|
|
||||||
o("rr.%s, off, err = unpackUint64(msg, off)\n")
|
|
||||||
case types.String:
|
|
||||||
o("rr.%s, off, err = unpackString(msg, off)\n")
|
|
||||||
default:
|
|
||||||
log.Fatalln(name, st.Field(i).Name())
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
|
|
||||||
}
|
|
||||||
// If we've hit len(msg) we return without error.
|
|
||||||
if i < st.NumFields()-1 {
|
|
||||||
fmt.Fprintf(b, `if off == len(msg) {
|
|
||||||
return rr, off, nil
|
|
||||||
}
|
|
||||||
`)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fmt.Fprintf(b, "return rr, off, err }\n\n")
|
|
||||||
}
|
|
||||||
// Generate typeToUnpack map
|
|
||||||
fmt.Fprintln(b, "var typeToUnpack = map[uint16]func(RR_Header, []byte, int) (RR, int, error){")
|
|
||||||
for _, name := range namedTypes {
|
|
||||||
if name == "RFC3597" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
fmt.Fprintf(b, "Type%s: unpack%s,\n", name, name)
|
|
||||||
}
|
|
||||||
fmt.Fprintln(b, "}\n")
|
|
||||||
|
|
||||||
// gofmt
|
|
||||||
res, err := format.Source(b.Bytes())
|
|
||||||
if err != nil {
|
|
||||||
b.WriteTo(os.Stderr)
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// write result
|
|
||||||
f, err := os.Create("zmsg.go")
|
|
||||||
fatalIfErr(err)
|
|
||||||
defer f.Close()
|
|
||||||
f.Write(res)
|
|
||||||
}
|
|
||||||
|
|
||||||
// structMember will take a tag like dns:"size-base32:SaltLength" and return the last part of this string.
|
|
||||||
func structMember(s string) string {
|
|
||||||
fields := strings.Split(s, ":")
|
|
||||||
if len(fields) == 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
f := fields[len(fields)-1]
|
|
||||||
// f should have a closing "
|
|
||||||
if len(f) > 1 {
|
|
||||||
return f[:len(f)-1]
|
|
||||||
}
|
|
||||||
return f
|
|
||||||
}
|
|
||||||
|
|
||||||
// structTag will take a tag like dns:"size-base32:SaltLength" and return base32.
|
|
||||||
func structTag(s string) string {
|
|
||||||
fields := strings.Split(s, ":")
|
|
||||||
if len(fields) < 2 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return fields[1][len("\"size-"):]
|
|
||||||
}
|
|
||||||
|
|
||||||
func fatalIfErr(err error) {
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
271
vendor/github.com/miekg/dns/types_generate.go
generated
vendored
271
vendor/github.com/miekg/dns/types_generate.go
generated
vendored
@ -1,271 +0,0 @@
|
|||||||
//+build ignore
|
|
||||||
|
|
||||||
// types_generate.go is meant to run with go generate. It will use
|
|
||||||
// go/{importer,types} to track down all the RR struct types. Then for each type
|
|
||||||
// it will generate conversion tables (TypeToRR and TypeToString) and banal
|
|
||||||
// methods (len, Header, copy) based on the struct tags. The generated source is
|
|
||||||
// written to ztypes.go, and is meant to be checked into git.
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"go/format"
|
|
||||||
"go/importer"
|
|
||||||
"go/types"
|
|
||||||
"log"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
"text/template"
|
|
||||||
)
|
|
||||||
|
|
||||||
var skipLen = map[string]struct{}{
|
|
||||||
"NSEC": {},
|
|
||||||
"NSEC3": {},
|
|
||||||
"OPT": {},
|
|
||||||
}
|
|
||||||
|
|
||||||
var packageHdr = `
|
|
||||||
// *** DO NOT MODIFY ***
|
|
||||||
// AUTOGENERATED BY go generate from type_generate.go
|
|
||||||
|
|
||||||
package dns
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/base64"
|
|
||||||
"net"
|
|
||||||
)
|
|
||||||
|
|
||||||
`
|
|
||||||
|
|
||||||
var TypeToRR = template.Must(template.New("TypeToRR").Parse(`
|
|
||||||
// TypeToRR is a map of constructors for each RR type.
|
|
||||||
var TypeToRR = map[uint16]func() RR{
|
|
||||||
{{range .}}{{if ne . "RFC3597"}} Type{{.}}: func() RR { return new({{.}}) },
|
|
||||||
{{end}}{{end}} }
|
|
||||||
|
|
||||||
`))
|
|
||||||
|
|
||||||
var typeToString = template.Must(template.New("typeToString").Parse(`
|
|
||||||
// TypeToString is a map of strings for each RR type.
|
|
||||||
var TypeToString = map[uint16]string{
|
|
||||||
{{range .}}{{if ne . "NSAPPTR"}} Type{{.}}: "{{.}}",
|
|
||||||
{{end}}{{end}} TypeNSAPPTR: "NSAP-PTR",
|
|
||||||
}
|
|
||||||
|
|
||||||
`))
|
|
||||||
|
|
||||||
var headerFunc = template.Must(template.New("headerFunc").Parse(`
|
|
||||||
// Header() functions
|
|
||||||
{{range .}} func (rr *{{.}}) Header() *RR_Header { return &rr.Hdr }
|
|
||||||
{{end}}
|
|
||||||
|
|
||||||
`))
|
|
||||||
|
|
||||||
// getTypeStruct will take a type and the package scope, and return the
|
|
||||||
// (innermost) struct if the type is considered a RR type (currently defined as
|
|
||||||
// those structs beginning with a RR_Header, could be redefined as implementing
|
|
||||||
// the RR interface). The bool return value indicates if embedded structs were
|
|
||||||
// resolved.
|
|
||||||
func getTypeStruct(t types.Type, scope *types.Scope) (*types.Struct, bool) {
|
|
||||||
st, ok := t.Underlying().(*types.Struct)
|
|
||||||
if !ok {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
if st.Field(0).Type() == scope.Lookup("RR_Header").Type() {
|
|
||||||
return st, false
|
|
||||||
}
|
|
||||||
if st.Field(0).Anonymous() {
|
|
||||||
st, _ := getTypeStruct(st.Field(0).Type(), scope)
|
|
||||||
return st, true
|
|
||||||
}
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
// Import and type-check the package
|
|
||||||
pkg, err := importer.Default().Import("github.com/miekg/dns")
|
|
||||||
fatalIfErr(err)
|
|
||||||
scope := pkg.Scope()
|
|
||||||
|
|
||||||
// Collect constants like TypeX
|
|
||||||
var numberedTypes []string
|
|
||||||
for _, name := range scope.Names() {
|
|
||||||
o := scope.Lookup(name)
|
|
||||||
if o == nil || !o.Exported() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
b, ok := o.Type().(*types.Basic)
|
|
||||||
if !ok || b.Kind() != types.Uint16 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !strings.HasPrefix(o.Name(), "Type") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
name := strings.TrimPrefix(o.Name(), "Type")
|
|
||||||
if name == "PrivateRR" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
numberedTypes = append(numberedTypes, name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Collect actual types (*X)
|
|
||||||
var namedTypes []string
|
|
||||||
for _, name := range scope.Names() {
|
|
||||||
o := scope.Lookup(name)
|
|
||||||
if o == nil || !o.Exported() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if st, _ := getTypeStruct(o.Type(), scope); st == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if name == "PrivateRR" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if corresponding TypeX exists
|
|
||||||
if scope.Lookup("Type"+o.Name()) == nil && o.Name() != "RFC3597" {
|
|
||||||
log.Fatalf("Constant Type%s does not exist.", o.Name())
|
|
||||||
}
|
|
||||||
|
|
||||||
namedTypes = append(namedTypes, o.Name())
|
|
||||||
}
|
|
||||||
|
|
||||||
b := &bytes.Buffer{}
|
|
||||||
b.WriteString(packageHdr)
|
|
||||||
|
|
||||||
// Generate TypeToRR
|
|
||||||
fatalIfErr(TypeToRR.Execute(b, namedTypes))
|
|
||||||
|
|
||||||
// Generate typeToString
|
|
||||||
fatalIfErr(typeToString.Execute(b, numberedTypes))
|
|
||||||
|
|
||||||
// Generate headerFunc
|
|
||||||
fatalIfErr(headerFunc.Execute(b, namedTypes))
|
|
||||||
|
|
||||||
// Generate len()
|
|
||||||
fmt.Fprint(b, "// len() functions\n")
|
|
||||||
for _, name := range namedTypes {
|
|
||||||
if _, ok := skipLen[name]; ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
o := scope.Lookup(name)
|
|
||||||
st, isEmbedded := getTypeStruct(o.Type(), scope)
|
|
||||||
if isEmbedded {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
fmt.Fprintf(b, "func (rr *%s) len() int {\n", name)
|
|
||||||
fmt.Fprintf(b, "l := rr.Hdr.len()\n")
|
|
||||||
for i := 1; i < st.NumFields(); i++ {
|
|
||||||
o := func(s string) { fmt.Fprintf(b, s, st.Field(i).Name()) }
|
|
||||||
|
|
||||||
if _, ok := st.Field(i).Type().(*types.Slice); ok {
|
|
||||||
switch st.Tag(i) {
|
|
||||||
case `dns:"-"`:
|
|
||||||
// ignored
|
|
||||||
case `dns:"cdomain-name"`, `dns:"domain-name"`, `dns:"txt"`:
|
|
||||||
o("for _, x := range rr.%s { l += len(x) + 1 }\n")
|
|
||||||
default:
|
|
||||||
log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case st.Tag(i) == `dns:"-"`:
|
|
||||||
// ignored
|
|
||||||
case st.Tag(i) == `dns:"cdomain-name"`, st.Tag(i) == `dns:"domain-name"`:
|
|
||||||
o("l += len(rr.%s) + 1\n")
|
|
||||||
case st.Tag(i) == `dns:"octet"`:
|
|
||||||
o("l += len(rr.%s)\n")
|
|
||||||
case strings.HasPrefix(st.Tag(i), `dns:"size-base64`):
|
|
||||||
fallthrough
|
|
||||||
case st.Tag(i) == `dns:"base64"`:
|
|
||||||
o("l += base64.StdEncoding.DecodedLen(len(rr.%s))\n")
|
|
||||||
case strings.HasPrefix(st.Tag(i), `dns:"size-hex`):
|
|
||||||
fallthrough
|
|
||||||
case st.Tag(i) == `dns:"hex"`:
|
|
||||||
o("l += len(rr.%s)/2 + 1\n")
|
|
||||||
case st.Tag(i) == `dns:"a"`:
|
|
||||||
o("l += net.IPv4len // %s\n")
|
|
||||||
case st.Tag(i) == `dns:"aaaa"`:
|
|
||||||
o("l += net.IPv6len // %s\n")
|
|
||||||
case st.Tag(i) == `dns:"txt"`:
|
|
||||||
o("for _, t := range rr.%s { l += len(t) + 1 }\n")
|
|
||||||
case st.Tag(i) == `dns:"uint48"`:
|
|
||||||
o("l += 6 // %s\n")
|
|
||||||
case st.Tag(i) == "":
|
|
||||||
switch st.Field(i).Type().(*types.Basic).Kind() {
|
|
||||||
case types.Uint8:
|
|
||||||
o("l += 1 // %s\n")
|
|
||||||
case types.Uint16:
|
|
||||||
o("l += 2 // %s\n")
|
|
||||||
case types.Uint32:
|
|
||||||
o("l += 4 // %s\n")
|
|
||||||
case types.Uint64:
|
|
||||||
o("l += 8 // %s\n")
|
|
||||||
case types.String:
|
|
||||||
o("l += len(rr.%s) + 1\n")
|
|
||||||
default:
|
|
||||||
log.Fatalln(name, st.Field(i).Name())
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
log.Fatalln(name, st.Field(i).Name(), st.Tag(i))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fmt.Fprintf(b, "return l }\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate copy()
|
|
||||||
fmt.Fprint(b, "// copy() functions\n")
|
|
||||||
for _, name := range namedTypes {
|
|
||||||
o := scope.Lookup(name)
|
|
||||||
st, isEmbedded := getTypeStruct(o.Type(), scope)
|
|
||||||
if isEmbedded {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
fmt.Fprintf(b, "func (rr *%s) copy() RR {\n", name)
|
|
||||||
fields := []string{"*rr.Hdr.copyHeader()"}
|
|
||||||
for i := 1; i < st.NumFields(); i++ {
|
|
||||||
f := st.Field(i).Name()
|
|
||||||
if sl, ok := st.Field(i).Type().(*types.Slice); ok {
|
|
||||||
t := sl.Underlying().String()
|
|
||||||
t = strings.TrimPrefix(t, "[]")
|
|
||||||
if strings.Contains(t, ".") {
|
|
||||||
splits := strings.Split(t, ".")
|
|
||||||
t = splits[len(splits)-1]
|
|
||||||
}
|
|
||||||
fmt.Fprintf(b, "%s := make([]%s, len(rr.%s)); copy(%s, rr.%s)\n",
|
|
||||||
f, t, f, f, f)
|
|
||||||
fields = append(fields, f)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if st.Field(i).Type().String() == "net.IP" {
|
|
||||||
fields = append(fields, "copyIP(rr."+f+")")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
fields = append(fields, "rr."+f)
|
|
||||||
}
|
|
||||||
fmt.Fprintf(b, "return &%s{%s}\n", name, strings.Join(fields, ","))
|
|
||||||
fmt.Fprintf(b, "}\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
// gofmt
|
|
||||||
res, err := format.Source(b.Bytes())
|
|
||||||
if err != nil {
|
|
||||||
b.WriteTo(os.Stderr)
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// write result
|
|
||||||
f, err := os.Create("ztypes.go")
|
|
||||||
fatalIfErr(err)
|
|
||||||
defer f.Close()
|
|
||||||
f.Write(res)
|
|
||||||
}
|
|
||||||
|
|
||||||
func fatalIfErr(err error) {
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
68
vendor/golang.org/x/net/idna/idna.go
generated
vendored
68
vendor/golang.org/x/net/idna/idna.go
generated
vendored
@ -1,68 +0,0 @@
|
|||||||
// Copyright 2012 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// Package idna implements IDNA2008 (Internationalized Domain Names for
|
|
||||||
// Applications), defined in RFC 5890, RFC 5891, RFC 5892, RFC 5893 and
|
|
||||||
// RFC 5894.
|
|
||||||
package idna // import "golang.org/x/net/idna"
|
|
||||||
|
|
||||||
import (
|
|
||||||
"strings"
|
|
||||||
"unicode/utf8"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TODO(nigeltao): specify when errors occur. For example, is ToASCII(".") or
|
|
||||||
// ToASCII("foo\x00") an error? See also http://www.unicode.org/faq/idn.html#11
|
|
||||||
|
|
||||||
// acePrefix is the ASCII Compatible Encoding prefix.
|
|
||||||
const acePrefix = "xn--"
|
|
||||||
|
|
||||||
// ToASCII converts a domain or domain label to its ASCII form. For example,
|
|
||||||
// ToASCII("bücher.example.com") is "xn--bcher-kva.example.com", and
|
|
||||||
// ToASCII("golang") is "golang".
|
|
||||||
func ToASCII(s string) (string, error) {
|
|
||||||
if ascii(s) {
|
|
||||||
return s, nil
|
|
||||||
}
|
|
||||||
labels := strings.Split(s, ".")
|
|
||||||
for i, label := range labels {
|
|
||||||
if !ascii(label) {
|
|
||||||
a, err := encode(acePrefix, label)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
labels[i] = a
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return strings.Join(labels, "."), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ToUnicode converts a domain or domain label to its Unicode form. For example,
|
|
||||||
// ToUnicode("xn--bcher-kva.example.com") is "bücher.example.com", and
|
|
||||||
// ToUnicode("golang") is "golang".
|
|
||||||
func ToUnicode(s string) (string, error) {
|
|
||||||
if !strings.Contains(s, acePrefix) {
|
|
||||||
return s, nil
|
|
||||||
}
|
|
||||||
labels := strings.Split(s, ".")
|
|
||||||
for i, label := range labels {
|
|
||||||
if strings.HasPrefix(label, acePrefix) {
|
|
||||||
u, err := decode(label[len(acePrefix):])
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
labels[i] = u
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return strings.Join(labels, "."), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func ascii(s string) bool {
|
|
||||||
for i := 0; i < len(s); i++ {
|
|
||||||
if s[i] >= utf8.RuneSelf {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
200
vendor/golang.org/x/net/idna/punycode.go
generated
vendored
200
vendor/golang.org/x/net/idna/punycode.go
generated
vendored
@ -1,200 +0,0 @@
|
|||||||
// Copyright 2012 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package idna
|
|
||||||
|
|
||||||
// This file implements the Punycode algorithm from RFC 3492.
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"math"
|
|
||||||
"strings"
|
|
||||||
"unicode/utf8"
|
|
||||||
)
|
|
||||||
|
|
||||||
// These parameter values are specified in section 5.
|
|
||||||
//
|
|
||||||
// All computation is done with int32s, so that overflow behavior is identical
|
|
||||||
// regardless of whether int is 32-bit or 64-bit.
|
|
||||||
const (
|
|
||||||
base int32 = 36
|
|
||||||
damp int32 = 700
|
|
||||||
initialBias int32 = 72
|
|
||||||
initialN int32 = 128
|
|
||||||
skew int32 = 38
|
|
||||||
tmax int32 = 26
|
|
||||||
tmin int32 = 1
|
|
||||||
)
|
|
||||||
|
|
||||||
// decode decodes a string as specified in section 6.2.
|
|
||||||
func decode(encoded string) (string, error) {
|
|
||||||
if encoded == "" {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
pos := 1 + strings.LastIndex(encoded, "-")
|
|
||||||
if pos == 1 {
|
|
||||||
return "", fmt.Errorf("idna: invalid label %q", encoded)
|
|
||||||
}
|
|
||||||
if pos == len(encoded) {
|
|
||||||
return encoded[:len(encoded)-1], nil
|
|
||||||
}
|
|
||||||
output := make([]rune, 0, len(encoded))
|
|
||||||
if pos != 0 {
|
|
||||||
for _, r := range encoded[:pos-1] {
|
|
||||||
output = append(output, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
i, n, bias := int32(0), initialN, initialBias
|
|
||||||
for pos < len(encoded) {
|
|
||||||
oldI, w := i, int32(1)
|
|
||||||
for k := base; ; k += base {
|
|
||||||
if pos == len(encoded) {
|
|
||||||
return "", fmt.Errorf("idna: invalid label %q", encoded)
|
|
||||||
}
|
|
||||||
digit, ok := decodeDigit(encoded[pos])
|
|
||||||
if !ok {
|
|
||||||
return "", fmt.Errorf("idna: invalid label %q", encoded)
|
|
||||||
}
|
|
||||||
pos++
|
|
||||||
i += digit * w
|
|
||||||
if i < 0 {
|
|
||||||
return "", fmt.Errorf("idna: invalid label %q", encoded)
|
|
||||||
}
|
|
||||||
t := k - bias
|
|
||||||
if t < tmin {
|
|
||||||
t = tmin
|
|
||||||
} else if t > tmax {
|
|
||||||
t = tmax
|
|
||||||
}
|
|
||||||
if digit < t {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
w *= base - t
|
|
||||||
if w >= math.MaxInt32/base {
|
|
||||||
return "", fmt.Errorf("idna: invalid label %q", encoded)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
x := int32(len(output) + 1)
|
|
||||||
bias = adapt(i-oldI, x, oldI == 0)
|
|
||||||
n += i / x
|
|
||||||
i %= x
|
|
||||||
if n > utf8.MaxRune || len(output) >= 1024 {
|
|
||||||
return "", fmt.Errorf("idna: invalid label %q", encoded)
|
|
||||||
}
|
|
||||||
output = append(output, 0)
|
|
||||||
copy(output[i+1:], output[i:])
|
|
||||||
output[i] = n
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
return string(output), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// encode encodes a string as specified in section 6.3 and prepends prefix to
|
|
||||||
// the result.
|
|
||||||
//
|
|
||||||
// The "while h < length(input)" line in the specification becomes "for
|
|
||||||
// remaining != 0" in the Go code, because len(s) in Go is in bytes, not runes.
|
|
||||||
func encode(prefix, s string) (string, error) {
|
|
||||||
output := make([]byte, len(prefix), len(prefix)+1+2*len(s))
|
|
||||||
copy(output, prefix)
|
|
||||||
delta, n, bias := int32(0), initialN, initialBias
|
|
||||||
b, remaining := int32(0), int32(0)
|
|
||||||
for _, r := range s {
|
|
||||||
if r < 0x80 {
|
|
||||||
b++
|
|
||||||
output = append(output, byte(r))
|
|
||||||
} else {
|
|
||||||
remaining++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
h := b
|
|
||||||
if b > 0 {
|
|
||||||
output = append(output, '-')
|
|
||||||
}
|
|
||||||
for remaining != 0 {
|
|
||||||
m := int32(0x7fffffff)
|
|
||||||
for _, r := range s {
|
|
||||||
if m > r && r >= n {
|
|
||||||
m = r
|
|
||||||
}
|
|
||||||
}
|
|
||||||
delta += (m - n) * (h + 1)
|
|
||||||
if delta < 0 {
|
|
||||||
return "", fmt.Errorf("idna: invalid label %q", s)
|
|
||||||
}
|
|
||||||
n = m
|
|
||||||
for _, r := range s {
|
|
||||||
if r < n {
|
|
||||||
delta++
|
|
||||||
if delta < 0 {
|
|
||||||
return "", fmt.Errorf("idna: invalid label %q", s)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if r > n {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
q := delta
|
|
||||||
for k := base; ; k += base {
|
|
||||||
t := k - bias
|
|
||||||
if t < tmin {
|
|
||||||
t = tmin
|
|
||||||
} else if t > tmax {
|
|
||||||
t = tmax
|
|
||||||
}
|
|
||||||
if q < t {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
output = append(output, encodeDigit(t+(q-t)%(base-t)))
|
|
||||||
q = (q - t) / (base - t)
|
|
||||||
}
|
|
||||||
output = append(output, encodeDigit(q))
|
|
||||||
bias = adapt(delta, h+1, h == b)
|
|
||||||
delta = 0
|
|
||||||
h++
|
|
||||||
remaining--
|
|
||||||
}
|
|
||||||
delta++
|
|
||||||
n++
|
|
||||||
}
|
|
||||||
return string(output), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func decodeDigit(x byte) (digit int32, ok bool) {
|
|
||||||
switch {
|
|
||||||
case '0' <= x && x <= '9':
|
|
||||||
return int32(x - ('0' - 26)), true
|
|
||||||
case 'A' <= x && x <= 'Z':
|
|
||||||
return int32(x - 'A'), true
|
|
||||||
case 'a' <= x && x <= 'z':
|
|
||||||
return int32(x - 'a'), true
|
|
||||||
}
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
func encodeDigit(digit int32) byte {
|
|
||||||
switch {
|
|
||||||
case 0 <= digit && digit < 26:
|
|
||||||
return byte(digit + 'a')
|
|
||||||
case 26 <= digit && digit < 36:
|
|
||||||
return byte(digit + ('0' - 26))
|
|
||||||
}
|
|
||||||
panic("idna: internal error in punycode encoding")
|
|
||||||
}
|
|
||||||
|
|
||||||
// adapt is the bias adaptation function specified in section 6.1.
|
|
||||||
func adapt(delta, numPoints int32, firstTime bool) int32 {
|
|
||||||
if firstTime {
|
|
||||||
delta /= damp
|
|
||||||
} else {
|
|
||||||
delta /= 2
|
|
||||||
}
|
|
||||||
delta += delta / numPoints
|
|
||||||
k := int32(0)
|
|
||||||
for delta > ((base-tmin)*tmax)/2 {
|
|
||||||
delta /= base - tmin
|
|
||||||
k += base
|
|
||||||
}
|
|
||||||
return k + (base-tmin+1)*delta/(delta+skew)
|
|
||||||
}
|
|
663
vendor/golang.org/x/net/publicsuffix/gen.go
generated
vendored
663
vendor/golang.org/x/net/publicsuffix/gen.go
generated
vendored
@ -1,663 +0,0 @@
|
|||||||
// Copyright 2012 The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
// +build ignore
|
|
||||||
|
|
||||||
package main
|
|
||||||
|
|
||||||
// This program generates table.go and table_test.go.
|
|
||||||
// Invoke as:
|
|
||||||
//
|
|
||||||
// go run gen.go -version "xxx" >table.go
|
|
||||||
// go run gen.go -version "xxx" -test >table_test.go
|
|
||||||
//
|
|
||||||
// Pass -v to print verbose progress information.
|
|
||||||
//
|
|
||||||
// The version is derived from information found at
|
|
||||||
// https://github.com/publicsuffix/list/commits/master/public_suffix_list.dat
|
|
||||||
//
|
|
||||||
// To fetch a particular git revision, such as 5c70ccd250, pass
|
|
||||||
// -url "https://raw.githubusercontent.com/publicsuffix/list/5c70ccd250/public_suffix_list.dat"
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"bytes"
|
|
||||||
"flag"
|
|
||||||
"fmt"
|
|
||||||
"go/format"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"regexp"
|
|
||||||
"sort"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"golang.org/x/net/idna"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// These sum of these four values must be no greater than 32.
|
|
||||||
nodesBitsChildren = 9
|
|
||||||
nodesBitsICANN = 1
|
|
||||||
nodesBitsTextOffset = 15
|
|
||||||
nodesBitsTextLength = 6
|
|
||||||
|
|
||||||
// These sum of these four values must be no greater than 32.
|
|
||||||
childrenBitsWildcard = 1
|
|
||||||
childrenBitsNodeType = 2
|
|
||||||
childrenBitsHi = 14
|
|
||||||
childrenBitsLo = 14
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
maxChildren int
|
|
||||||
maxTextOffset int
|
|
||||||
maxTextLength int
|
|
||||||
maxHi uint32
|
|
||||||
maxLo uint32
|
|
||||||
)
|
|
||||||
|
|
||||||
func max(a, b int) int {
|
|
||||||
if a < b {
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
|
|
||||||
func u32max(a, b uint32) uint32 {
|
|
||||||
if a < b {
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
nodeTypeNormal = 0
|
|
||||||
nodeTypeException = 1
|
|
||||||
nodeTypeParentOnly = 2
|
|
||||||
numNodeType = 3
|
|
||||||
)
|
|
||||||
|
|
||||||
func nodeTypeStr(n int) string {
|
|
||||||
switch n {
|
|
||||||
case nodeTypeNormal:
|
|
||||||
return "+"
|
|
||||||
case nodeTypeException:
|
|
||||||
return "!"
|
|
||||||
case nodeTypeParentOnly:
|
|
||||||
return "o"
|
|
||||||
}
|
|
||||||
panic("unreachable")
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
labelEncoding = map[string]uint32{}
|
|
||||||
labelsList = []string{}
|
|
||||||
labelsMap = map[string]bool{}
|
|
||||||
rules = []string{}
|
|
||||||
|
|
||||||
// validSuffix is used to check that the entries in the public suffix list
|
|
||||||
// are in canonical form (after Punycode encoding). Specifically, capital
|
|
||||||
// letters are not allowed.
|
|
||||||
validSuffix = regexp.MustCompile(`^[a-z0-9_\!\*\-\.]+$`)
|
|
||||||
|
|
||||||
subset = flag.Bool("subset", false, "generate only a subset of the full table, for debugging")
|
|
||||||
url = flag.String("url",
|
|
||||||
"https://publicsuffix.org/list/effective_tld_names.dat",
|
|
||||||
"URL of the publicsuffix.org list. If empty, stdin is read instead")
|
|
||||||
v = flag.Bool("v", false, "verbose output (to stderr)")
|
|
||||||
version = flag.String("version", "", "the effective_tld_names.dat version")
|
|
||||||
test = flag.Bool("test", false, "generate table_test.go")
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
if err := main1(); err != nil {
|
|
||||||
fmt.Fprintln(os.Stderr, err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func main1() error {
|
|
||||||
flag.Parse()
|
|
||||||
if nodesBitsTextLength+nodesBitsTextOffset+nodesBitsICANN+nodesBitsChildren > 32 {
|
|
||||||
return fmt.Errorf("not enough bits to encode the nodes table")
|
|
||||||
}
|
|
||||||
if childrenBitsLo+childrenBitsHi+childrenBitsNodeType+childrenBitsWildcard > 32 {
|
|
||||||
return fmt.Errorf("not enough bits to encode the children table")
|
|
||||||
}
|
|
||||||
if *version == "" {
|
|
||||||
return fmt.Errorf("-version was not specified")
|
|
||||||
}
|
|
||||||
var r io.Reader = os.Stdin
|
|
||||||
if *url != "" {
|
|
||||||
res, err := http.Get(*url)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if res.StatusCode != http.StatusOK {
|
|
||||||
return fmt.Errorf("bad GET status for %s: %d", *url, res.Status)
|
|
||||||
}
|
|
||||||
r = res.Body
|
|
||||||
defer res.Body.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
var root node
|
|
||||||
icann := false
|
|
||||||
buf := new(bytes.Buffer)
|
|
||||||
br := bufio.NewReader(r)
|
|
||||||
for {
|
|
||||||
s, err := br.ReadString('\n')
|
|
||||||
if err != nil {
|
|
||||||
if err == io.EOF {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
s = strings.TrimSpace(s)
|
|
||||||
if strings.Contains(s, "BEGIN ICANN DOMAINS") {
|
|
||||||
icann = true
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if strings.Contains(s, "END ICANN DOMAINS") {
|
|
||||||
icann = false
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if s == "" || strings.HasPrefix(s, "//") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
s, err = idna.ToASCII(s)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if !validSuffix.MatchString(s) {
|
|
||||||
return fmt.Errorf("bad publicsuffix.org list data: %q", s)
|
|
||||||
}
|
|
||||||
|
|
||||||
if *subset {
|
|
||||||
switch {
|
|
||||||
case s == "ac.jp" || strings.HasSuffix(s, ".ac.jp"):
|
|
||||||
case s == "ak.us" || strings.HasSuffix(s, ".ak.us"):
|
|
||||||
case s == "ao" || strings.HasSuffix(s, ".ao"):
|
|
||||||
case s == "ar" || strings.HasSuffix(s, ".ar"):
|
|
||||||
case s == "arpa" || strings.HasSuffix(s, ".arpa"):
|
|
||||||
case s == "cy" || strings.HasSuffix(s, ".cy"):
|
|
||||||
case s == "dyndns.org" || strings.HasSuffix(s, ".dyndns.org"):
|
|
||||||
case s == "jp":
|
|
||||||
case s == "kobe.jp" || strings.HasSuffix(s, ".kobe.jp"):
|
|
||||||
case s == "kyoto.jp" || strings.HasSuffix(s, ".kyoto.jp"):
|
|
||||||
case s == "om" || strings.HasSuffix(s, ".om"):
|
|
||||||
case s == "uk" || strings.HasSuffix(s, ".uk"):
|
|
||||||
case s == "uk.com" || strings.HasSuffix(s, ".uk.com"):
|
|
||||||
case s == "tw" || strings.HasSuffix(s, ".tw"):
|
|
||||||
case s == "zw" || strings.HasSuffix(s, ".zw"):
|
|
||||||
case s == "xn--p1ai" || strings.HasSuffix(s, ".xn--p1ai"):
|
|
||||||
// xn--p1ai is Russian-Cyrillic "рф".
|
|
||||||
default:
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
rules = append(rules, s)
|
|
||||||
|
|
||||||
nt, wildcard := nodeTypeNormal, false
|
|
||||||
switch {
|
|
||||||
case strings.HasPrefix(s, "*."):
|
|
||||||
s, nt = s[2:], nodeTypeParentOnly
|
|
||||||
wildcard = true
|
|
||||||
case strings.HasPrefix(s, "!"):
|
|
||||||
s, nt = s[1:], nodeTypeException
|
|
||||||
}
|
|
||||||
labels := strings.Split(s, ".")
|
|
||||||
for n, i := &root, len(labels)-1; i >= 0; i-- {
|
|
||||||
label := labels[i]
|
|
||||||
n = n.child(label)
|
|
||||||
if i == 0 {
|
|
||||||
if nt != nodeTypeParentOnly && n.nodeType == nodeTypeParentOnly {
|
|
||||||
n.nodeType = nt
|
|
||||||
}
|
|
||||||
n.icann = n.icann && icann
|
|
||||||
n.wildcard = n.wildcard || wildcard
|
|
||||||
}
|
|
||||||
labelsMap[label] = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
labelsList = make([]string, 0, len(labelsMap))
|
|
||||||
for label := range labelsMap {
|
|
||||||
labelsList = append(labelsList, label)
|
|
||||||
}
|
|
||||||
sort.Strings(labelsList)
|
|
||||||
|
|
||||||
p := printReal
|
|
||||||
if *test {
|
|
||||||
p = printTest
|
|
||||||
}
|
|
||||||
if err := p(buf, &root); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
b, err := format.Source(buf.Bytes())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = os.Stdout.Write(b)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func printTest(w io.Writer, n *node) error {
|
|
||||||
fmt.Fprintf(w, "// generated by go run gen.go; DO NOT EDIT\n\n")
|
|
||||||
fmt.Fprintf(w, "package publicsuffix\n\nvar rules = [...]string{\n")
|
|
||||||
for _, rule := range rules {
|
|
||||||
fmt.Fprintf(w, "%q,\n", rule)
|
|
||||||
}
|
|
||||||
fmt.Fprintf(w, "}\n\nvar nodeLabels = [...]string{\n")
|
|
||||||
if err := n.walk(w, printNodeLabel); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
fmt.Fprintf(w, "}\n")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func printReal(w io.Writer, n *node) error {
|
|
||||||
const header = `// generated by go run gen.go; DO NOT EDIT
|
|
||||||
|
|
||||||
package publicsuffix
|
|
||||||
|
|
||||||
const version = %q
|
|
||||||
|
|
||||||
const (
|
|
||||||
nodesBitsChildren = %d
|
|
||||||
nodesBitsICANN = %d
|
|
||||||
nodesBitsTextOffset = %d
|
|
||||||
nodesBitsTextLength = %d
|
|
||||||
|
|
||||||
childrenBitsWildcard = %d
|
|
||||||
childrenBitsNodeType = %d
|
|
||||||
childrenBitsHi = %d
|
|
||||||
childrenBitsLo = %d
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
nodeTypeNormal = %d
|
|
||||||
nodeTypeException = %d
|
|
||||||
nodeTypeParentOnly = %d
|
|
||||||
)
|
|
||||||
|
|
||||||
// numTLD is the number of top level domains.
|
|
||||||
const numTLD = %d
|
|
||||||
|
|
||||||
`
|
|
||||||
fmt.Fprintf(w, header, *version,
|
|
||||||
nodesBitsChildren, nodesBitsICANN, nodesBitsTextOffset, nodesBitsTextLength,
|
|
||||||
childrenBitsWildcard, childrenBitsNodeType, childrenBitsHi, childrenBitsLo,
|
|
||||||
nodeTypeNormal, nodeTypeException, nodeTypeParentOnly, len(n.children))
|
|
||||||
|
|
||||||
text := combineText(labelsList)
|
|
||||||
if text == "" {
|
|
||||||
return fmt.Errorf("internal error: makeText returned no text")
|
|
||||||
}
|
|
||||||
for _, label := range labelsList {
|
|
||||||
offset, length := strings.Index(text, label), len(label)
|
|
||||||
if offset < 0 {
|
|
||||||
return fmt.Errorf("internal error: could not find %q in text %q", label, text)
|
|
||||||
}
|
|
||||||
maxTextOffset, maxTextLength = max(maxTextOffset, offset), max(maxTextLength, length)
|
|
||||||
if offset >= 1<<nodesBitsTextOffset {
|
|
||||||
return fmt.Errorf("text offset %d is too large, or nodeBitsTextOffset is too small", offset)
|
|
||||||
}
|
|
||||||
if length >= 1<<nodesBitsTextLength {
|
|
||||||
return fmt.Errorf("text length %d is too large, or nodeBitsTextLength is too small", length)
|
|
||||||
}
|
|
||||||
labelEncoding[label] = uint32(offset)<<nodesBitsTextLength | uint32(length)
|
|
||||||
}
|
|
||||||
fmt.Fprintf(w, "// Text is the combined text of all labels.\nconst text = ")
|
|
||||||
for len(text) > 0 {
|
|
||||||
n, plus := len(text), ""
|
|
||||||
if n > 64 {
|
|
||||||
n, plus = 64, " +"
|
|
||||||
}
|
|
||||||
fmt.Fprintf(w, "%q%s\n", text[:n], plus)
|
|
||||||
text = text[n:]
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := n.walk(w, assignIndexes); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Fprintf(w, `
|
|
||||||
|
|
||||||
// nodes is the list of nodes. Each node is represented as a uint32, which
|
|
||||||
// encodes the node's children, wildcard bit and node type (as an index into
|
|
||||||
// the children array), ICANN bit and text.
|
|
||||||
//
|
|
||||||
// In the //-comment after each node's data, the nodes indexes of the children
|
|
||||||
// are formatted as (n0x1234-n0x1256), with * denoting the wildcard bit. The
|
|
||||||
// nodeType is printed as + for normal, ! for exception, and o for parent-only
|
|
||||||
// nodes that have children but don't match a domain label in their own right.
|
|
||||||
// An I denotes an ICANN domain.
|
|
||||||
//
|
|
||||||
// The layout within the uint32, from MSB to LSB, is:
|
|
||||||
// [%2d bits] unused
|
|
||||||
// [%2d bits] children index
|
|
||||||
// [%2d bits] ICANN bit
|
|
||||||
// [%2d bits] text index
|
|
||||||
// [%2d bits] text length
|
|
||||||
var nodes = [...]uint32{
|
|
||||||
`,
|
|
||||||
32-nodesBitsChildren-nodesBitsICANN-nodesBitsTextOffset-nodesBitsTextLength,
|
|
||||||
nodesBitsChildren, nodesBitsICANN, nodesBitsTextOffset, nodesBitsTextLength)
|
|
||||||
if err := n.walk(w, printNode); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
fmt.Fprintf(w, `}
|
|
||||||
|
|
||||||
// children is the list of nodes' children, the parent's wildcard bit and the
|
|
||||||
// parent's node type. If a node has no children then their children index
|
|
||||||
// will be in the range [0, 6), depending on the wildcard bit and node type.
|
|
||||||
//
|
|
||||||
// The layout within the uint32, from MSB to LSB, is:
|
|
||||||
// [%2d bits] unused
|
|
||||||
// [%2d bits] wildcard bit
|
|
||||||
// [%2d bits] node type
|
|
||||||
// [%2d bits] high nodes index (exclusive) of children
|
|
||||||
// [%2d bits] low nodes index (inclusive) of children
|
|
||||||
var children=[...]uint32{
|
|
||||||
`,
|
|
||||||
32-childrenBitsWildcard-childrenBitsNodeType-childrenBitsHi-childrenBitsLo,
|
|
||||||
childrenBitsWildcard, childrenBitsNodeType, childrenBitsHi, childrenBitsLo)
|
|
||||||
for i, c := range childrenEncoding {
|
|
||||||
s := "---------------"
|
|
||||||
lo := c & (1<<childrenBitsLo - 1)
|
|
||||||
hi := (c >> childrenBitsLo) & (1<<childrenBitsHi - 1)
|
|
||||||
if lo != hi {
|
|
||||||
s = fmt.Sprintf("n0x%04x-n0x%04x", lo, hi)
|
|
||||||
}
|
|
||||||
nodeType := int(c>>(childrenBitsLo+childrenBitsHi)) & (1<<childrenBitsNodeType - 1)
|
|
||||||
wildcard := c>>(childrenBitsLo+childrenBitsHi+childrenBitsNodeType) != 0
|
|
||||||
fmt.Fprintf(w, "0x%08x, // c0x%04x (%s)%s %s\n",
|
|
||||||
c, i, s, wildcardStr(wildcard), nodeTypeStr(nodeType))
|
|
||||||
}
|
|
||||||
fmt.Fprintf(w, "}\n\n")
|
|
||||||
fmt.Fprintf(w, "// max children %d (capacity %d)\n", maxChildren, 1<<nodesBitsChildren-1)
|
|
||||||
fmt.Fprintf(w, "// max text offset %d (capacity %d)\n", maxTextOffset, 1<<nodesBitsTextOffset-1)
|
|
||||||
fmt.Fprintf(w, "// max text length %d (capacity %d)\n", maxTextLength, 1<<nodesBitsTextLength-1)
|
|
||||||
fmt.Fprintf(w, "// max hi %d (capacity %d)\n", maxHi, 1<<childrenBitsHi-1)
|
|
||||||
fmt.Fprintf(w, "// max lo %d (capacity %d)\n", maxLo, 1<<childrenBitsLo-1)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type node struct {
|
|
||||||
label string
|
|
||||||
nodeType int
|
|
||||||
icann bool
|
|
||||||
wildcard bool
|
|
||||||
// nodesIndex and childrenIndex are the index of this node in the nodes
|
|
||||||
// and the index of its children offset/length in the children arrays.
|
|
||||||
nodesIndex, childrenIndex int
|
|
||||||
// firstChild is the index of this node's first child, or zero if this
|
|
||||||
// node has no children.
|
|
||||||
firstChild int
|
|
||||||
// children are the node's children, in strictly increasing node label order.
|
|
||||||
children []*node
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *node) walk(w io.Writer, f func(w1 io.Writer, n1 *node) error) error {
|
|
||||||
if err := f(w, n); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
for _, c := range n.children {
|
|
||||||
if err := c.walk(w, f); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// child returns the child of n with the given label. The child is created if
|
|
||||||
// it did not exist beforehand.
|
|
||||||
func (n *node) child(label string) *node {
|
|
||||||
for _, c := range n.children {
|
|
||||||
if c.label == label {
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c := &node{
|
|
||||||
label: label,
|
|
||||||
nodeType: nodeTypeParentOnly,
|
|
||||||
icann: true,
|
|
||||||
}
|
|
||||||
n.children = append(n.children, c)
|
|
||||||
sort.Sort(byLabel(n.children))
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
type byLabel []*node
|
|
||||||
|
|
||||||
func (b byLabel) Len() int { return len(b) }
|
|
||||||
func (b byLabel) Swap(i, j int) { b[i], b[j] = b[j], b[i] }
|
|
||||||
func (b byLabel) Less(i, j int) bool { return b[i].label < b[j].label }
|
|
||||||
|
|
||||||
var nextNodesIndex int
|
|
||||||
|
|
||||||
// childrenEncoding are the encoded entries in the generated children array.
|
|
||||||
// All these pre-defined entries have no children.
|
|
||||||
var childrenEncoding = []uint32{
|
|
||||||
0 << (childrenBitsLo + childrenBitsHi), // Without wildcard bit, nodeTypeNormal.
|
|
||||||
1 << (childrenBitsLo + childrenBitsHi), // Without wildcard bit, nodeTypeException.
|
|
||||||
2 << (childrenBitsLo + childrenBitsHi), // Without wildcard bit, nodeTypeParentOnly.
|
|
||||||
4 << (childrenBitsLo + childrenBitsHi), // With wildcard bit, nodeTypeNormal.
|
|
||||||
5 << (childrenBitsLo + childrenBitsHi), // With wildcard bit, nodeTypeException.
|
|
||||||
6 << (childrenBitsLo + childrenBitsHi), // With wildcard bit, nodeTypeParentOnly.
|
|
||||||
}
|
|
||||||
|
|
||||||
var firstCallToAssignIndexes = true
|
|
||||||
|
|
||||||
func assignIndexes(w io.Writer, n *node) error {
|
|
||||||
if len(n.children) != 0 {
|
|
||||||
// Assign nodesIndex.
|
|
||||||
n.firstChild = nextNodesIndex
|
|
||||||
for _, c := range n.children {
|
|
||||||
c.nodesIndex = nextNodesIndex
|
|
||||||
nextNodesIndex++
|
|
||||||
}
|
|
||||||
|
|
||||||
// The root node's children is implicit.
|
|
||||||
if firstCallToAssignIndexes {
|
|
||||||
firstCallToAssignIndexes = false
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Assign childrenIndex.
|
|
||||||
maxChildren = max(maxChildren, len(childrenEncoding))
|
|
||||||
if len(childrenEncoding) >= 1<<nodesBitsChildren {
|
|
||||||
return fmt.Errorf("children table size %d is too large, or nodeBitsChildren is too small", len(childrenEncoding))
|
|
||||||
}
|
|
||||||
n.childrenIndex = len(childrenEncoding)
|
|
||||||
lo := uint32(n.firstChild)
|
|
||||||
hi := lo + uint32(len(n.children))
|
|
||||||
maxLo, maxHi = u32max(maxLo, lo), u32max(maxHi, hi)
|
|
||||||
if lo >= 1<<childrenBitsLo {
|
|
||||||
return fmt.Errorf("children lo %d is too large, or childrenBitsLo is too small", lo)
|
|
||||||
}
|
|
||||||
if hi >= 1<<childrenBitsHi {
|
|
||||||
return fmt.Errorf("children hi %d is too large, or childrenBitsHi is too small", hi)
|
|
||||||
}
|
|
||||||
enc := hi<<childrenBitsLo | lo
|
|
||||||
enc |= uint32(n.nodeType) << (childrenBitsLo + childrenBitsHi)
|
|
||||||
if n.wildcard {
|
|
||||||
enc |= 1 << (childrenBitsLo + childrenBitsHi + childrenBitsNodeType)
|
|
||||||
}
|
|
||||||
childrenEncoding = append(childrenEncoding, enc)
|
|
||||||
} else {
|
|
||||||
n.childrenIndex = n.nodeType
|
|
||||||
if n.wildcard {
|
|
||||||
n.childrenIndex += numNodeType
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func printNode(w io.Writer, n *node) error {
|
|
||||||
for _, c := range n.children {
|
|
||||||
s := "---------------"
|
|
||||||
if len(c.children) != 0 {
|
|
||||||
s = fmt.Sprintf("n0x%04x-n0x%04x", c.firstChild, c.firstChild+len(c.children))
|
|
||||||
}
|
|
||||||
encoding := labelEncoding[c.label]
|
|
||||||
if c.icann {
|
|
||||||
encoding |= 1 << (nodesBitsTextLength + nodesBitsTextOffset)
|
|
||||||
}
|
|
||||||
encoding |= uint32(c.childrenIndex) << (nodesBitsTextLength + nodesBitsTextOffset + nodesBitsICANN)
|
|
||||||
fmt.Fprintf(w, "0x%08x, // n0x%04x c0x%04x (%s)%s %s %s %s\n",
|
|
||||||
encoding, c.nodesIndex, c.childrenIndex, s, wildcardStr(c.wildcard),
|
|
||||||
nodeTypeStr(c.nodeType), icannStr(c.icann), c.label,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func printNodeLabel(w io.Writer, n *node) error {
|
|
||||||
for _, c := range n.children {
|
|
||||||
fmt.Fprintf(w, "%q,\n", c.label)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func icannStr(icann bool) string {
|
|
||||||
if icann {
|
|
||||||
return "I"
|
|
||||||
}
|
|
||||||
return " "
|
|
||||||
}
|
|
||||||
|
|
||||||
func wildcardStr(wildcard bool) string {
|
|
||||||
if wildcard {
|
|
||||||
return "*"
|
|
||||||
}
|
|
||||||
return " "
|
|
||||||
}
|
|
||||||
|
|
||||||
// combineText combines all the strings in labelsList to form one giant string.
|
|
||||||
// Overlapping strings will be merged: "arpa" and "parliament" could yield
|
|
||||||
// "arparliament".
|
|
||||||
func combineText(labelsList []string) string {
|
|
||||||
beforeLength := 0
|
|
||||||
for _, s := range labelsList {
|
|
||||||
beforeLength += len(s)
|
|
||||||
}
|
|
||||||
|
|
||||||
text := crush(removeSubstrings(labelsList))
|
|
||||||
if *v {
|
|
||||||
fmt.Fprintf(os.Stderr, "crushed %d bytes to become %d bytes\n", beforeLength, len(text))
|
|
||||||
}
|
|
||||||
return text
|
|
||||||
}
|
|
||||||
|
|
||||||
type byLength []string
|
|
||||||
|
|
||||||
func (s byLength) Len() int { return len(s) }
|
|
||||||
func (s byLength) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
|
||||||
func (s byLength) Less(i, j int) bool { return len(s[i]) < len(s[j]) }
|
|
||||||
|
|
||||||
// removeSubstrings returns a copy of its input with any strings removed
|
|
||||||
// that are substrings of other provided strings.
|
|
||||||
func removeSubstrings(input []string) []string {
|
|
||||||
// Make a copy of input.
|
|
||||||
ss := append(make([]string, 0, len(input)), input...)
|
|
||||||
sort.Sort(byLength(ss))
|
|
||||||
|
|
||||||
for i, shortString := range ss {
|
|
||||||
// For each string, only consider strings higher than it in sort order, i.e.
|
|
||||||
// of equal length or greater.
|
|
||||||
for _, longString := range ss[i+1:] {
|
|
||||||
if strings.Contains(longString, shortString) {
|
|
||||||
ss[i] = ""
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove the empty strings.
|
|
||||||
sort.Strings(ss)
|
|
||||||
for len(ss) > 0 && ss[0] == "" {
|
|
||||||
ss = ss[1:]
|
|
||||||
}
|
|
||||||
return ss
|
|
||||||
}
|
|
||||||
|
|
||||||
// crush combines a list of strings, taking advantage of overlaps. It returns a
|
|
||||||
// single string that contains each input string as a substring.
|
|
||||||
func crush(ss []string) string {
|
|
||||||
maxLabelLen := 0
|
|
||||||
for _, s := range ss {
|
|
||||||
if maxLabelLen < len(s) {
|
|
||||||
maxLabelLen = len(s)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for prefixLen := maxLabelLen; prefixLen > 0; prefixLen-- {
|
|
||||||
prefixes := makePrefixMap(ss, prefixLen)
|
|
||||||
for i, s := range ss {
|
|
||||||
if len(s) <= prefixLen {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
mergeLabel(ss, i, prefixLen, prefixes)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return strings.Join(ss, "")
|
|
||||||
}
|
|
||||||
|
|
||||||
// mergeLabel merges the label at ss[i] with the first available matching label
|
|
||||||
// in prefixMap, where the last "prefixLen" characters in ss[i] match the first
|
|
||||||
// "prefixLen" characters in the matching label.
|
|
||||||
// It will merge ss[i] repeatedly until no more matches are available.
|
|
||||||
// All matching labels merged into ss[i] are replaced by "".
|
|
||||||
func mergeLabel(ss []string, i, prefixLen int, prefixes prefixMap) {
|
|
||||||
s := ss[i]
|
|
||||||
suffix := s[len(s)-prefixLen:]
|
|
||||||
for _, j := range prefixes[suffix] {
|
|
||||||
// Empty strings mean "already used." Also avoid merging with self.
|
|
||||||
if ss[j] == "" || i == j {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if *v {
|
|
||||||
fmt.Fprintf(os.Stderr, "%d-length overlap at (%4d,%4d): %q and %q share %q\n",
|
|
||||||
prefixLen, i, j, ss[i], ss[j], suffix)
|
|
||||||
}
|
|
||||||
ss[i] += ss[j][prefixLen:]
|
|
||||||
ss[j] = ""
|
|
||||||
// ss[i] has a new suffix, so merge again if possible.
|
|
||||||
// Note: we only have to merge again at the same prefix length. Shorter
|
|
||||||
// prefix lengths will be handled in the next iteration of crush's for loop.
|
|
||||||
// Can there be matches for longer prefix lengths, introduced by the merge?
|
|
||||||
// I believe that any such matches would by necessity have been eliminated
|
|
||||||
// during substring removal or merged at a higher prefix length. For
|
|
||||||
// instance, in crush("abc", "cde", "bcdef"), combining "abc" and "cde"
|
|
||||||
// would yield "abcde", which could be merged with "bcdef." However, in
|
|
||||||
// practice "cde" would already have been elimintated by removeSubstrings.
|
|
||||||
mergeLabel(ss, i, prefixLen, prefixes)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// prefixMap maps from a prefix to a list of strings containing that prefix. The
|
|
||||||
// list of strings is represented as indexes into a slice of strings stored
|
|
||||||
// elsewhere.
|
|
||||||
type prefixMap map[string][]int
|
|
||||||
|
|
||||||
// makePrefixMap constructs a prefixMap from a slice of strings.
|
|
||||||
func makePrefixMap(ss []string, prefixLen int) prefixMap {
|
|
||||||
prefixes := make(prefixMap)
|
|
||||||
for i, s := range ss {
|
|
||||||
// We use < rather than <= because if a label matches on a prefix equal to
|
|
||||||
// its full length, that's actually a substring match handled by
|
|
||||||
// removeSubstrings.
|
|
||||||
if prefixLen < len(s) {
|
|
||||||
prefix := s[:prefixLen]
|
|
||||||
prefixes[prefix] = append(prefixes[prefix], i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return prefixes
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user