docker-distribution/vendor/github.com/AzureAD/microsoft-authentication-li.../apps/internal/oauth/ops/authority/authority.go

553 lines
19 KiB
Go

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package authority
import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path"
"strings"
"time"
"github.com/google/uuid"
)
const (
	// authorizationEndpoint is formatted with the authority host and the tenant, in that order.
	authorizationEndpoint = "https://%v/%v/oauth2/v2.0/authorize"
	// instanceDiscoveryEndpoint is formatted with the host that performs AAD instance discovery.
	instanceDiscoveryEndpoint = "https://%v/common/discovery/instance"
	// tenantDiscoveryEndpointWithRegion is formatted with the region, the environment host and the
	// tenant (in that order) to build a regional OpenID configuration URL.
	tenantDiscoveryEndpointWithRegion = "https://%s.%s/%s/v2.0/.well-known/openid-configuration"
	// regionName is the environment variable consulted first when auto-detecting the region.
	regionName = "REGION_NAME"
	// defaultAPIVersion is the API version sent to the instance metadata service (imdsEndpoint).
	defaultAPIVersion = "2021-10-01"
	// imdsEndpoint asks the Azure Instance Metadata Service (IMDS) for the compute location as plain text.
	imdsEndpoint = "http://169.254.169.254/metadata/instance/compute/location?format=text&api-version=" + defaultAPIVersion
	// autoDetectRegion is the Info.Region value that requests runtime region detection.
	autoDetectRegion = "TryAutoDetect"
)
// These are various hosts that host AAD instance discovery endpoints.
const (
	// defaultHost is the public cloud authority host. Instance discovery for a host that isn't in
	// aadTrustedHostList is sent here.
	defaultHost = "login.microsoftonline.com"
	// loginMicrosoft is the environment used for regional endpoints of the public cloud hosts.
	loginMicrosoft  = "login.microsoft.com"
	loginWindows    = "login.windows.net"
	loginSTSWindows = "sts.windows.net"
	// loginMicrosoftOnline is an alias of defaultHost.
	loginMicrosoftOnline = defaultHost
)
// jsonCaller is an interface that allows us to mock the JSONCall method.
// JSONCall sends a request to endpoint with the given headers and query values (and body, when
// non-nil). It presumably decodes the JSON response into resp, which callers in this file pass as
// a pointer — confirm against the comm package.
type jsonCaller interface {
	JSONCall(ctx context.Context, endpoint string, headers http.Header, qv url.Values, body, resp interface{}) error
}
// aadTrustedHostList is the set of authority hosts known to belong to Microsoft clouds.
// AADInstanceDiscovery sends instance discovery directly to these hosts. Discovery for any other
// host goes through defaultHost.
var aadTrustedHostList = map[string]bool{
	"login.windows.net":            true, // Microsoft Azure Worldwide - Used in validation scenarios where host is not this list
	"login.chinacloudapi.cn":       true, // Microsoft Azure China
	"login.microsoftonline.de":     true, // Microsoft Azure Blackforest
	"login-us.microsoftonline.com": true, // Microsoft Azure US Government - Legacy
	"login.microsoftonline.us":     true, // Microsoft Azure US Government
	"login.microsoftonline.com":    true, // Microsoft Azure Worldwide
	"login.cloudgovapi.us":         true, // Microsoft Azure US Government
}

// TrustedHost reports whether host is one of the AAD hosts in aadTrustedHostList.
// The comparison is exact, so case and any ":port" suffix matter.
func TrustedHost(host string) bool {
	_, listed := aadTrustedHostList[host]
	return listed
}
// OAuthResponseBase is the base JSON return message for an OAuth call.
// This is embedded in other calls to get the base fields from every response.
type OAuthResponseBase struct {
	Error            string `json:"error"`
	SubError         string `json:"suberror"`
	ErrorDescription string `json:"error_description"`
	ErrorCodes       []int  `json:"error_codes"`
	CorrelationID    string `json:"correlation_id"`
	Claims           string `json:"claims"`
}

// TenantDiscoveryResponse is the tenant endpoints from the OpenID configuration endpoint.
type TenantDiscoveryResponse struct {
	OAuthResponseBase

	AuthorizationEndpoint string `json:"authorization_endpoint"`
	TokenEndpoint         string `json:"token_endpoint"`
	Issuer                string `json:"issuer"`

	// NOTE(review): AdditionalFields has no json tag; presumably the decoder behind Client.Comm fills
	// it with unrecognized fields — confirm.
	AdditionalFields map[string]interface{}
}

// Validate returns an error naming the first required field that is empty in the OpenID
// configuration. Fields are checked in this order: authorization endpoint, token endpoint, issuer.
// It returns nil when all three are present.
func (r *TenantDiscoveryResponse) Validate() error {
	required := [...]struct {
		value   string
		message string
	}{
		{r.AuthorizationEndpoint, "TenantDiscoveryResponse: authorize endpoint was not found in the openid configuration"},
		{r.TokenEndpoint, "TenantDiscoveryResponse: token endpoint was not found in the openid configuration"},
		{r.Issuer, "TenantDiscoveryResponse: issuer was not found in the openid configuration"},
	}
	for _, field := range required {
		if field.value == "" {
			return errors.New(field.message)
		}
	}
	return nil
}
// InstanceDiscoveryMetadata describes a group of authority hosts. It is either returned by AAD
// instance discovery or synthesized by AADInstanceDiscovery for a regional authority.
type InstanceDiscoveryMetadata struct {
	// PreferredNetwork is the preferred host for network calls; "<region>.<host>" for a regional authority.
	PreferredNetwork string `json:"preferred_network"`
	// PreferredCache is the authority host itself for a regional authority. Presumably it is the
	// host used when keying cached data — confirm.
	PreferredCache string `json:"preferred_cache"`
	// Aliases lists the hosts in this group, e.g. the regional host and the authority host.
	Aliases []string `json:"aliases"`
	// NOTE(review): no json tag; presumably filled with unrecognized fields by the decoder behind
	// Client.Comm — confirm.
	AdditionalFields map[string]interface{}
}

// InstanceDiscoveryResponse is the result of AADInstanceDiscovery.
type InstanceDiscoveryResponse struct {
	// TenantDiscoveryEndpoint is the tenant's OpenID configuration URL.
	TenantDiscoveryEndpoint string                      `json:"tenant_discovery_endpoint"`
	Metadata                []InstanceDiscoveryMetadata `json:"metadata"`
	AdditionalFields        map[string]interface{}
}
//go:generate stringer -type=AuthorizeType
// AuthorizeType represents the type of token flow.
type AuthorizeType int

// These are all the types of token flows. Values are assigned by iota in declaration order, so
// inserting or reordering a constant renumbers every constant after it.
const (
	ATUnknown AuthorizeType = iota
	ATUsernamePassword
	ATWindowsIntegrated
	ATAuthCode
	ATInteractive
	ATClientCredentials
	ATDeviceCode
	ATRefreshToken
	// AccountByID lacks the AT prefix but is an AuthorizeType like the others. AuthParams.CacheKey
	// treats it like ATRefreshToken.
	AccountByID
	ATOnBehalfOf
)
// These are all authority types, as stored in Info.AuthorityType.
const (
	// AAD is the authority type NewInfoFromAuthorityURI assigns to every tenant other than "adfs".
	AAD = "MSSTS"
	// ADFS is the authority type of authorities whose first path segment is "adfs".
	ADFS = "ADFS"
)
// AuthParams represents the parameters used for authorization for token acquisition.
type AuthParams struct {
	// AuthorityInfo describes the authority tokens are requested from.
	AuthorityInfo Info
	// CorrelationID identifies the request. NewAuthParams sets it to a random UUID, and
	// Client.UserRealm sends it as the "client-request-id" header.
	CorrelationID string
	// Endpoints are the authority's endpoints. Client.UserRealm builds its URL from the authority host.
	Endpoints Endpoints
	// ClientID is the application's client ID. It is part of the app cache key (see AppKey).
	ClientID string
	// Redirecturi is used for auth flows that specify a redirect URI (e.g. local server for interactive auth flow).
	Redirecturi string
	// HomeAccountID is the cache key for refresh-token and account-by-ID requests (see CacheKey).
	HomeAccountID string
	// Username is the user-name portion for username/password auth flow.
	Username string
	// Password is the password portion for username/password auth flow.
	Password string
	// Scopes is the list of scopes the user consents to.
	Scopes []string
	// AuthorizationType specifies the auth flow being used.
	AuthorizationType AuthorizeType
	// State is a random value used to prevent cross-site request forgery attacks.
	State string
	// CodeChallenge is derived from a code verifier and is sent in the auth request.
	CodeChallenge string
	// CodeChallengeMethod describes the method used to create the CodeChallenge.
	CodeChallengeMethod string
	// Prompt specifies the user prompt type during interactive auth.
	Prompt string
	// IsConfidentialClient specifies if it is a confidential client.
	IsConfidentialClient bool
	// SendX5C specifies if x5c claim(public key of the certificate) should be sent to STS.
	SendX5C bool
	// UserAssertion is the access token used to acquire token on behalf of user.
	// Its SHA-256 hash (AssertionHash) is the cache key for on-behalf-of requests.
	UserAssertion string
	// Capabilities the client will include with each token request, for example "CP1".
	// Call [NewClientCapabilities] to construct a value for this field.
	Capabilities ClientCapabilities
	// Claims required for an access token to satisfy a conditional access policy
	Claims string
	// KnownAuthorityHosts don't require metadata discovery because they're known to the user
	KnownAuthorityHosts []string
	// LoginHint is a username with which to pre-populate account selection during interactive auth
	LoginHint string
	// DomainHint is a directive that can be used to accelerate the user to their federated IdP sign-in page
	DomainHint string
}
// NewAuthParams returns AuthParams for clientID and authorityInfo. CorrelationID is set to a new
// random UUID, and every other field is left at its zero value.
func NewAuthParams(clientID string, authorityInfo Info) AuthParams {
	var params AuthParams
	params.ClientID = clientID
	params.AuthorityInfo = authorityInfo
	params.CorrelationID = uuid.New().String()
	return params
}
// WithTenant returns a copy of the AuthParams having the specified tenant ID. If the given ID is
// empty or equals the current tenant, the copy is identical to the original. This function returns
// an error in several cases:
//   - ID isn't specific (for example, it's "common")
//   - ID is non-empty and the authority doesn't support tenants (for example, it's an ADFS authority)
//   - the client is configured to authenticate only Microsoft accounts via the "consumers" endpoint
//   - the resulting authority URL is invalid
//
// On success, the new authority keeps the original host, ValidateAuthority,
// InstanceDiscoveryDisabled and Region. On error, the unmodified receiver is returned.
func (p AuthParams) WithTenant(ID string) (AuthParams, error) {
	if ID == "" || ID == p.AuthorityInfo.Tenant {
		// keep the default tenant because the caller didn't override it
		return p, nil
	}
	if p.AuthorityInfo.AuthorityType != AAD {
		return p, errors.New("the authority doesn't support tenants")
	}
	switch ID {
	case "common", "consumers", "organizations":
		return p, fmt.Errorf(`tenant ID must be a specific tenant, not "%s"`, ID)
	}
	if p.AuthorityInfo.Tenant == "consumers" {
		return p, errors.New(`client is configured to authenticate only personal Microsoft accounts, via the "consumers" endpoint`)
	}
	// NOTE(review): path.Join cleans its result, so an ID containing "/" or ".." segments can change
	// the URL's path or even its host. Confirm callers only pass plain tenant IDs.
	authority := "https://" + path.Join(p.AuthorityInfo.Host, ID)
	info, err := NewInfoFromAuthorityURI(authority, p.AuthorityInfo.ValidateAuthority, p.AuthorityInfo.InstanceDiscoveryDisabled)
	if err != nil {
		return p, err
	}
	info.Region = p.AuthorityInfo.Region
	p.AuthorityInfo = info
	return p, nil
}
// MergeCapabilitiesAndClaims combines client capabilities and challenge claims into a value suitable
// for an authentication request's "claims" parameter.
//
// The result depends on what is present:
//   - No capabilities: the claims are returned unchanged (possibly "").
//   - Capabilities but no claims: the precomputed capabilities JSON is returned.
//   - Both: the claims must be a JSON object. JSON null is treated as an empty object. The
//     capabilities are merged into it without overwriting any challenge value, and the re-encoded
//     object is returned.
//
// An error is returned when the claims aren't a JSON object or conflict with the capabilities.
func (p AuthParams) MergeCapabilitiesAndClaims() (string, error) {
	if len(p.Capabilities.asMap) == 0 {
		return p.Claims, nil
	}
	if p.Claims == "" {
		// without claims the result is simply the capabilities
		return p.Capabilities.asJSON, nil
	}
	// Otherwise, merge claims and capabilities into a single JSON object.
	// We handle the claims challenge as a map because we don't know its structure.
	var challenge map[string]any
	if err := json.Unmarshal([]byte(p.Claims), &challenge); err != nil {
		return "", fmt.Errorf(`claims must be JSON. Are they base64 encoded? json.Unmarshal returned "%w"`, err)
	}
	if challenge == nil {
		// The claims were JSON null, which decodes to a nil map. merge writes into its second
		// argument, so it would panic on a nil map.
		challenge = map[string]any{}
	}
	if err := merge(p.Capabilities.asMap, challenge); err != nil {
		return "", err
	}
	merged, err := json.Marshal(challenge)
	if err != nil {
		return "", err
	}
	return string(merged), nil
}
// merge merges a into b, modifying b in place, without overwriting any of b's values.
//
// Keys found only in a are copied into b; the values are shared, not deep-copied. Keys found in
// both are merged recursively when both values are JSON objects (map[string]any). Every key of a is
// processed, including the keys after a recursive merge.
//
// merge returns an error when a and b share a key for which either has a non-object value, because
// merging would then overwrite part of b. When it returns an error, b may already be partially updated.
func merge(a, b map[string]any) error {
	for k, av := range a {
		bv, exists := b[k]
		if !exists {
			// b doesn't contain this key => simply set it to a's value
			b[k] = av
			continue
		}
		// b does contain this key => recursively merge a[k] into b[k], provided both are maps. If a[k]
		// or b[k] isn't a map, return an error because merging would overwrite some value in b. Errors
		// shouldn't occur in practice because the challenge will be from AAD, which knows the
		// capabilities format.
		am, aIsMap := av.(map[string]any)
		bm, bIsMap := bv.(map[string]any)
		if !aIsMap || !bIsMap {
			return errors.New("challenge claims conflict with client capabilities")
		}
		// Don't return the recursive result directly: a's remaining keys still have to be merged.
		if err := merge(am, bm); err != nil {
			return err
		}
	}
	return nil
}
// ClientCapabilities stores capabilities in the formats used by AuthParams.MergeCapabilitiesAndClaims.
// [NewClientCapabilities] precomputes these representations because capabilities are static for the
// lifetime of a client and are included with every authentication request i.e., these computations
// always have the same result and would otherwise have to be repeated for every request.
type ClientCapabilities struct {
	// asJSON is for the common case: adding the capabilities to an auth request with no challenge claims
	asJSON string
	// asMap is for merging the capabilities with challenge claims
	asMap map[string]any
}

// NewClientCapabilities precomputes the claims representations of capabilities, in the form
// {"access_token":{"xms_cc":{"values":[...]}}}. An empty or nil slice yields the zero
// ClientCapabilities, meaning the client declares no capabilities.
//
// Each capability is JSON-encoded. A value containing quotes, backslashes or control characters
// therefore produces well-formed JSON holding exactly that string, instead of malformed JSON or
// injected fields. As with encoding/json generally, '<', '>' and '&' are written as \u escapes;
// the decoded value is unchanged.
func NewClientCapabilities(capabilities []string) (ClientCapabilities, error) {
	if len(capabilities) == 0 {
		return ClientCapabilities{}, nil
	}
	values, err := json.Marshal(capabilities)
	if err != nil {
		return ClientCapabilities{}, err
	}
	c := ClientCapabilities{
		asJSON: fmt.Sprintf(`{"access_token":{"xms_cc":{"values":%s}}}`, values),
	}
	if err := json.Unmarshal([]byte(c.asJSON), &c.asMap); err != nil {
		return ClientCapabilities{}, err
	}
	return c, nil
}
// Info consists of information about the authority. The descriptions below are of the values set
// by NewInfoFromAuthorityURI, which lowercases the authority URL.
type Info struct {
	// Host is the authority's host, including any port.
	Host string
	// CanonicalAuthorityURI is "https://<Host>/<Tenant>/".
	CanonicalAuthorityURI string
	// AuthorityType is ADFS when Tenant is "adfs", otherwise AAD.
	AuthorityType string
	// UserRealmURIPrefix is "https://<host without port>/common/userrealm/".
	UserRealmURIPrefix string
	// ValidateAuthority is carried through from the caller; it isn't read in this file.
	ValidateAuthority bool
	// Tenant is the authority URL's first path segment, e.g. a tenant ID, "common" or "adfs".
	Tenant string
	// Region is "" (not regional), autoDetectRegion (detect at runtime), or an explicit region name.
	Region string
	// InstanceDiscoveryDisabled is carried through from the caller; it isn't read in this file.
	InstanceDiscoveryDisabled bool
}
// firstPathSegment returns the first segment of u's escaped path. For example, it returns "tenant"
// for "https://host/tenant/v2.0" and "" for "https://host/". It returns an error when the path
// contains no "/" at all, as in "https://host".
func firstPathSegment(u *url.URL) (string, error) {
	_, afterFirstSlash, hasSlash := strings.Cut(u.EscapedPath(), "/")
	if !hasSlash {
		return "", errors.New(`authority must be an https URL such as "https://login.microsoftonline.com/<your tenant>"`)
	}
	segment, _, _ := strings.Cut(afterFirstSlash, "/")
	return segment, nil
}
// NewInfoFromAuthorityURI creates an AuthorityInfo instance from the authority URL provided.
// The URL is lowercased before parsing and must use the https scheme. Its first path segment
// becomes the tenant. The tenant "adfs" selects the ADFS authority type; any other tenant selects AAD.
func NewInfoFromAuthorityURI(authority string, validateAuthority bool, instanceDiscoveryDisabled bool) (Info, error) {
	parsed, err := url.Parse(strings.ToLower(authority))
	if err != nil || parsed.Scheme != "https" {
		return Info{}, errors.New(`authority must be an https URL such as "https://login.microsoftonline.com/<your tenant>"`)
	}
	tenant, err := firstPathSegment(parsed)
	if err != nil {
		return Info{}, err
	}

	var info Info
	// parsed.Host includes the port, if any, which is required for private cloud deployments
	info.Host = parsed.Host
	info.CanonicalAuthorityURI = fmt.Sprintf("https://%v/%v/", parsed.Host, tenant)
	switch tenant {
	case "adfs":
		info.AuthorityType = ADFS
	default:
		info.AuthorityType = AAD
	}
	// NOTE(review): unlike CanonicalAuthorityURI, this prefix uses Hostname(), which drops any port.
	// Confirm that this is intended.
	info.UserRealmURIPrefix = fmt.Sprintf("https://%v/common/userrealm/", parsed.Hostname())
	info.ValidateAuthority = validateAuthority
	info.Tenant = tenant
	info.InstanceDiscoveryDisabled = instanceDiscoveryDisabled
	return info, nil
}
// Endpoints consists of the endpoints from the tenant discovery response.
type Endpoints struct {
	AuthorizationEndpoint string
	TokenEndpoint         string
	selfSignedJwtAudience string
	authorityHost         string
}

// NewEndpoints creates an Endpoints object from its four components. authorityHost is the host
// Client.UserRealm uses to build the user realm URL.
func NewEndpoints(authorizationEndpoint string, tokenEndpoint string, selfSignedJwtAudience string, authorityHost string) Endpoints {
	return Endpoints{
		AuthorizationEndpoint: authorizationEndpoint,
		TokenEndpoint:         tokenEndpoint,
		selfSignedJwtAudience: selfSignedJwtAudience,
		authorityHost:         authorityHost,
	}
}
// UserRealmAccountType refers to the type of user realm.
type UserRealmAccountType string

// These are the different types of user realms.
const (
	Unknown   UserRealmAccountType = ""
	Federated UserRealmAccountType = "Federated"
	Managed   UserRealmAccountType = "Managed"
)

// UserRealm is used for the username password request to determine user type
type UserRealm struct {
	AccountType       UserRealmAccountType `json:"account_type"`
	DomainName        string               `json:"domain_name"`
	CloudInstanceName string               `json:"cloud_instance_name"`
	CloudAudienceURN  string               `json:"cloud_audience_urn"`

	// required if accountType is Federated
	FederationProtocol    string `json:"federation_protocol"`
	FederationMetadataURL string `json:"federation_metadata_url"`

	AdditionalFields map[string]interface{}
}

// validate returns an error describing the first missing required field, or nil.
// The account type, domain name, cloud instance name and cloud audience URN are always required.
// The federation protocol and metadata URL are additionally required for Federated realms.
func (u UserRealm) validate() error {
	switch {
	case u.AccountType == Unknown:
		return errors.New("the account type (Federated or Managed) is missing")
	case u.DomainName == "":
		return errors.New("domain name of user realm is missing")
	case u.CloudInstanceName == "":
		return errors.New("cloud instance name of user realm is missing")
	case u.CloudAudienceURN == "":
		return errors.New("cloud Instance URN is missing")
	case u.AccountType != Federated:
		// the remaining checks only apply to federated realms
		return nil
	case u.FederationProtocol == "":
		return errors.New("federation protocol of user realm is missing")
	case u.FederationMetadataURL == "":
		return errors.New("federation metadata URL of user realm is missing")
	default:
		return nil
	}
}
// Client represents the REST calls to authority backends.
type Client struct {
	// Comm provides the HTTP transport client. It is an interface so tests can substitute a mock;
	// presumably a *comm.Client in production.
	Comm jsonCaller // *comm.Client
}
// UserRealm queries the authority host's user realm endpoint for authParams.Username
// (path-escaped), sending authParams.CorrelationID as the "client-request-id" header.
// If the call fails, the response decoded so far is returned together with that error. Otherwise
// the response is returned together with the result of validating it.
func (c Client) UserRealm(ctx context.Context, authParams AuthParams) (UserRealm, error) {
	var realm UserRealm
	endpoint := fmt.Sprintf("https://%s/common/UserRealm/%s", authParams.Endpoints.authorityHost, url.PathEscape(authParams.Username))
	headers := http.Header{"client-request-id": []string{authParams.CorrelationID}}
	query := url.Values{"api-version": []string{"1.0"}}
	if err := c.Comm.JSONCall(ctx, endpoint, headers, query, nil, &realm); err != nil {
		return realm, err
	}
	return realm, realm.validate()
}
// GetTenantDiscoveryResponse fetches the OpenID configuration document at
// openIDConfigurationEndpoint. It does not validate the response; see TenantDiscoveryResponse.Validate.
func (c Client) GetTenantDiscoveryResponse(ctx context.Context, openIDConfigurationEndpoint string) (TenantDiscoveryResponse, error) {
	var discovery TenantDiscoveryResponse
	err := c.Comm.JSONCall(ctx, openIDConfigurationEndpoint, http.Header{}, nil, nil, &discovery)
	return discovery, err
}
// AADInstanceDiscovery attempts to discover a tenant endpoint (used in OIDC auth with an authorization endpoint).
// AAD allows aliasing of authority hosts; for example, sts.windows.net and login.windows.net refer
// to the same authority.
//
// A region can be set explicitly, or requested with autoDetectRegion and then found by
// detectRegion. When a region is known, no request is made: a regional tenant discovery endpoint
// and matching metadata are synthesized. Otherwise AAD's instance discovery endpoint is called, on
// the authority host itself when it's trusted and on defaultHost when it isn't.
func (c Client) AADInstanceDiscovery(ctx context.Context, authorityInfo Info) (InstanceDiscoveryResponse, error) {
	var region string
	switch authorityInfo.Region {
	case "":
		// not a regional authority
	case autoDetectRegion:
		region = detectRegion(ctx)
	default:
		region = authorityInfo.Region
	}

	var resp InstanceDiscoveryResponse
	if region == "" {
		qv := url.Values{}
		qv.Set("api-version", "1.1")
		qv.Set("authorization_endpoint", fmt.Sprintf(authorizationEndpoint, authorityInfo.Host, authorityInfo.Tenant))
		discoveryHost := authorityInfo.Host
		if !TrustedHost(discoveryHost) {
			discoveryHost = defaultHost
		}
		err := c.Comm.JSONCall(ctx, fmt.Sprintf(instanceDiscoveryEndpoint, discoveryHost), http.Header{}, qv, nil, &resp)
		return resp, err
	}

	// The public cloud hosts all use login.microsoft.com as their regional environment.
	environment := authorityInfo.Host
	switch environment {
	case loginMicrosoft, loginWindows, loginSTSWindows, defaultHost:
		environment = loginMicrosoft
	}
	regionalHost := fmt.Sprintf("%v.%v", region, authorityInfo.Host)
	resp.TenantDiscoveryEndpoint = fmt.Sprintf(tenantDiscoveryEndpointWithRegion, region, environment, authorityInfo.Tenant)
	resp.Metadata = []InstanceDiscoveryMetadata{{
		PreferredNetwork: regionalHost,
		PreferredCache:   authorityInfo.Host,
		Aliases:          []string{regionalHost, authorityInfo.Host},
	}}
	return resp, nil
}
// detectRegion returns the Azure region the process runs in, or "" when it can't be determined.
//
// A non-empty REGION_NAME environment variable takes precedence. Its value is returned lowercased
// and with all spaces removed. Otherwise the region is requested from the Azure Instance Metadata
// Service (IMDS), and the response body is returned as-is.
//
// A failed IMDS attempt (transport error, non-200 status or unreadable body) is retried once. Each
// attempt is bounded by the client's 2 second timeout and by ctx, which must be non-nil.
func detectRegion(ctx context.Context) string {
	region := os.Getenv(regionName)
	if region != "" {
		region = strings.ReplaceAll(region, " ", "")
		return strings.ToLower(region)
	}
	// HTTP call to IMDS endpoint to get region
	// Refer : https://identitydivision.visualstudio.com/DevEx/_git/AuthLibrariesApiReview?path=%2FPinAuthToRegion%2FAAD%20SDK%20Proposal%20to%20Pin%20Auth%20to%20region.md&_a=preview&version=GBdev
	// Set a 2 second timeout for this http client which only does calls to IMDS endpoint
	client := &http.Client{
		Timeout: 2 * time.Second,
	}
	// If the request times out or there is an error, it is retried once
	for attempt := 0; attempt < 2; attempt++ {
		if detected, ok := queryIMDSRegion(ctx, client); ok {
			return detected
		}
	}
	return ""
}

// queryIMDSRegion makes a single IMDS location request. ok is false on a transport error, a
// non-200 status, or a body read failure. The response body is always closed, so neither the
// body nor its connection leaks between attempts.
func queryIMDSRegion(ctx context.Context, client *http.Client) (region string, ok bool) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, imdsEndpoint, nil)
	if err != nil {
		return "", false
	}
	req.Header.Set("Metadata", "true")
	resp, err := client.Do(req)
	if err != nil {
		return "", false
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		// drain the body so the keep-alive connection can be reused by the retry
		_, _ = io.Copy(io.Discard, resp.Body)
		return "", false
	}
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", false
	}
	return string(body), true
}
// CacheKey returns the cache key for these parameters, chosen by the first rule that matches:
//   - on-behalf-of requests: the assertion hash
//   - client-credential requests, or whenever isAppCache is set: the app key
//   - refresh-token and account-by-ID requests: the home account ID
//   - anything else: ""
func (a *AuthParams) CacheKey(isAppCache bool) string {
	switch {
	case a.AuthorizationType == ATOnBehalfOf:
		return a.AssertionHash()
	case a.AuthorizationType == ATClientCredentials, isAppCache:
		return a.AppKey()
	case a.AuthorizationType == ATRefreshToken, a.AuthorizationType == AccountByID:
		return a.HomeAccountID
	default:
		return ""
	}
}
// AssertionHash returns the SHA-256 digest of UserAssertion, encoded as padded URL-safe base64.
func (a *AuthParams) AssertionHash() string {
	digest := sha256.Sum256([]byte(a.UserAssertion))
	return base64.URLEncoding.EncodeToString(digest[:])
}
// AppKey returns the app token cache key "<ClientID>_<Tenant>_AppTokenCache". When the tenant is
// empty, the tenant position is left blank: "<ClientID>__AppTokenCache".
func (a *AuthParams) AppKey() string {
	tenant := a.AuthorityInfo.Tenant
	if tenant == "" {
		return fmt.Sprintf("%s__AppTokenCache", a.ClientID)
	}
	return fmt.Sprintf("%s_%s_AppTokenCache", a.ClientID, tenant)
}