package middleware import ( "context" "encoding/json" "fmt" "io/ioutil" "net" "net/http" "strings" "sync" "time" dcontext "github.com/distribution/distribution/v3/context" ) const ( // ipRangesURL is the URL to get definition of AWS IPs defaultIPRangesURL = "https://ip-ranges.amazonaws.com/ip-ranges.json" // updateFrequency tells how frequently AWS IPs need to be updated defaultUpdateFrequency = time.Hour * 12 ) // newAWSIPs returns a New awsIP object. // If awsRegion is `nil`, it accepts any region. Otherwise, it only allow the regions specified func newAWSIPs(host string, updateFrequency time.Duration, awsRegion []string) *awsIPs { ips := &awsIPs{ host: host, updateFrequency: updateFrequency, awsRegion: awsRegion, updaterStopChan: make(chan bool), } if err := ips.tryUpdate(); err != nil { dcontext.GetLogger(context.Background()).WithError(err).Warn("failed to update AWS IP") } go ips.updater() return ips } // awsIPs tracks a list of AWS ips, filtered by awsRegion type awsIPs struct { host string updateFrequency time.Duration ipv4 []net.IPNet ipv6 []net.IPNet mutex sync.RWMutex awsRegion []string updaterStopChan chan bool initialized bool } type awsIPResponse struct { Prefixes []prefixEntry `json:"prefixes"` V6Prefixes []prefixEntry `json:"ipv6_prefixes"` } type prefixEntry struct { IPV4Prefix string `json:"ip_prefix"` IPV6Prefix string `json:"ipv6_prefix"` Region string `json:"region"` Service string `json:"service"` } func fetchAWSIPs(url string) (awsIPResponse, error) { var response awsIPResponse resp, err := http.Get(url) if err != nil { return response, err } if resp.StatusCode != 200 { body, _ := ioutil.ReadAll(resp.Body) return response, fmt.Errorf("failed to fetch network data. response = %s", body) } decoder := json.NewDecoder(resp.Body) err = decoder.Decode(&response) if err != nil { return response, err } return response, nil } // tryUpdate attempts to download the new set of ip addresses. // tryUpdate must be thread safe with contains func (s *awsIPs) tryUpdate() error { response, err := fetchAWSIPs(s.host) if err != nil { return err } var ipv4 []net.IPNet var ipv6 []net.IPNet processAddress := func(output *[]net.IPNet, prefix string, region string) { regionAllowed := false if len(s.awsRegion) > 0 { for _, ar := range s.awsRegion { if strings.ToLower(region) == ar { regionAllowed = true break } } } else { regionAllowed = true } _, network, err := net.ParseCIDR(prefix) if err != nil { dcontext.GetLoggerWithFields(dcontext.Background(), map[interface{}]interface{}{ "cidr": prefix, }).Error("unparseable cidr") return } if regionAllowed { *output = append(*output, *network) } } for _, prefix := range response.Prefixes { processAddress(&ipv4, prefix.IPV4Prefix, prefix.Region) } for _, prefix := range response.V6Prefixes { processAddress(&ipv6, prefix.IPV6Prefix, prefix.Region) } s.mutex.Lock() defer s.mutex.Unlock() // Update each attr of awsips atomically. s.ipv4 = ipv4 s.ipv6 = ipv6 s.initialized = true return nil } // This function is meant to be run in a background goroutine. // It will periodically update the ips from aws. func (s *awsIPs) updater() { defer close(s.updaterStopChan) for { time.Sleep(s.updateFrequency) select { case <-s.updaterStopChan: dcontext.GetLogger(context.Background()).Info("aws ip updater received stop signal") return default: err := s.tryUpdate() if err != nil { dcontext.GetLogger(context.Background()).WithError(err).Error("git AWS IP") } } } } // getCandidateNetworks returns either the ipv4 or ipv6 networks // that were last read from aws. The networks returned // have the same type as the ip address provided. func (s *awsIPs) getCandidateNetworks(ip net.IP) []net.IPNet { s.mutex.RLock() defer s.mutex.RUnlock() if ip.To4() != nil { return s.ipv4 } else if ip.To16() != nil { return s.ipv6 } else { dcontext.GetLoggerWithFields(dcontext.Background(), map[interface{}]interface{}{ "ip": ip, }).Error("unknown ip address format") // assume mismatch, pass through cloudfront return nil } } // Contains determines whether the host is within aws. func (s *awsIPs) contains(ip net.IP) bool { networks := s.getCandidateNetworks(ip) for _, network := range networks { if network.Contains(ip) { return true } } return false } // parseIPFromRequest attempts to extract the ip address of the // client that made the request func parseIPFromRequest(ctx context.Context) (net.IP, error) { request, err := dcontext.GetRequest(ctx) if err != nil { return nil, err } ipStr := dcontext.RemoteIP(request) ip := net.ParseIP(ipStr) if ip == nil { return nil, fmt.Errorf("invalid ip address from requester: %s", ipStr) } return ip, nil } // eligibleForS3 checks if a request is eligible for using S3 directly // Return true only when the IP belongs to a specific aws region and user-agent is docker func eligibleForS3(ctx context.Context, awsIPs *awsIPs) bool { if awsIPs != nil && awsIPs.initialized { if addr, err := parseIPFromRequest(ctx); err == nil { request, err := dcontext.GetRequest(ctx) if err != nil { dcontext.GetLogger(ctx).Warnf("the CloudFront middleware cannot parse the request: %s", err) } else { loggerField := map[interface{}]interface{}{ "user-client": request.UserAgent(), "ip": dcontext.RemoteIP(request), } if awsIPs.contains(addr) { dcontext.GetLoggerWithFields(ctx, loggerField).Info("request from the allowed AWS region, skipping CloudFront") return true } dcontext.GetLoggerWithFields(ctx, loggerField).Warn("request not from the allowed AWS region, fallback to CloudFront") } } else { dcontext.GetLogger(ctx).WithError(err).Warn("failed to parse ip address from context, fallback to CloudFront") } } return false }