diff --git a/contrib/token-server/token.go b/contrib/token-server/token.go index a66c767d..dc0956d4 100644 --- a/contrib/token-server/token.go +++ b/contrib/token-server/token.go @@ -180,7 +180,7 @@ func (issuer *TokenIssuer) CreateJWT(subject string, audience string, grantedAcc claimSet := token.ClaimSet{ Issuer: issuer.Issuer, Subject: subject, - Audience: audience, + Audience: []string{audience}, Expiration: now.Add(exp).Unix(), NotBefore: now.Unix(), IssuedAt: now.Unix(), diff --git a/registry/auth/token/token.go b/registry/auth/token/token.go index da016700..dfd1569b 100644 --- a/registry/auth/token/token.go +++ b/registry/auth/token/token.go @@ -42,13 +42,13 @@ type ResourceActions struct { // ClaimSet describes the main section of a JSON Web Token. type ClaimSet struct { // Public claims - Issuer string `json:"iss"` - Subject string `json:"sub"` - Audience string `json:"aud"` - Expiration int64 `json:"exp"` - NotBefore int64 `json:"nbf"` - IssuedAt int64 `json:"iat"` - JWTID string `json:"jti"` + Issuer string `json:"iss"` + Subject string `json:"sub"` + Audience AudienceList `json:"aud"` + Expiration int64 `json:"exp"` + NotBefore int64 `json:"nbf"` + IssuedAt int64 `json:"iat"` + JWTID string `json:"jti"` // Private claims Access []*ResourceActions `json:"access"` @@ -143,8 +143,8 @@ func (t *Token) Verify(verifyOpts VerifyOptions) error { } // Verify that the Audience claim is allowed. - if !contains(verifyOpts.AcceptedAudiences, t.Claims.Audience) { - log.Infof("token intended for another audience: %q", t.Claims.Audience) + if !containsAny(verifyOpts.AcceptedAudiences, t.Claims.Audience) { + log.Infof("token intended for another audience: %v", t.Claims.Audience) return ErrInvalidToken } diff --git a/registry/auth/token/token_test.go b/registry/auth/token/token_test.go index 7fa62be5..2837ad33 100644 --- a/registry/auth/token/token_test.go +++ b/registry/auth/token/token_test.go @@ -116,7 +116,7 @@ func makeTestToken(issuer, audience string, access []*ResourceActions, rootKey l claimSet := &ClaimSet{ Issuer: issuer, Subject: "foo", - Audience: audience, + Audience: []string{audience}, Expiration: exp.Unix(), NotBefore: now.Unix(), IssuedAt: now.Unix(), diff --git a/registry/auth/token/types.go b/registry/auth/token/types.go new file mode 100644 index 00000000..2aa5c9ba --- /dev/null +++ b/registry/auth/token/types.go @@ -0,0 +1,55 @@ +package token + +import ( + "encoding/json" + "reflect" +) + +// AudienceList is a slice of strings that can be deserialized from either a single string value or a list of strings. +type AudienceList []string + +func (s *AudienceList) UnmarshalJSON(data []byte) (err error) { + var value interface{} + + if err = json.Unmarshal(data, &value); err != nil { + return err + } + + switch v := value.(type) { + case string: + *s = []string{v} + + case []string: + *s = v + + case []interface{}: + var ss []string + + for _, vv := range v { + vs, ok := vv.(string) + if !ok { + return &json.UnsupportedTypeError{ + Type: reflect.TypeOf(vv), + } + } + + ss = append(ss, vs) + } + + *s = ss + + case nil: + return nil + + default: + return &json.UnsupportedTypeError{ + Type: reflect.TypeOf(v), + } + } + + return +} + +func (s AudienceList) MarshalJSON() (b []byte, err error) { + return json.Marshal([]string(s)) +} diff --git a/registry/auth/token/types_test.go b/registry/auth/token/types_test.go new file mode 100644 index 00000000..5e547761 --- /dev/null +++ b/registry/auth/token/types_test.go @@ -0,0 +1,85 @@ +package token + +import ( + "encoding/json" + "testing" +) + +func TestAudienceList_Unmarshal(t *testing.T) { + t.Run("OK", func(t *testing.T) { + testCases := []struct { + value string + expected AudienceList + }{ + { + value: `"audience"`, + expected: AudienceList{"audience"}, + }, + { + value: `["audience1", "audience2"]`, + expected: AudienceList{"audience1", "audience2"}, + }, + { + value: `null`, + expected: nil, + }, + } + + for _, testCase := range testCases { + testCase := testCase + + t.Run("", func(t *testing.T) { + var actual AudienceList + + err := json.Unmarshal([]byte(testCase.value), &actual) + if err != nil { + t.Fatal(err) + } + + assertStringListEqual(t, testCase.expected, actual) + }) + } + }) + + t.Run("Error", func(t *testing.T) { + var actual AudienceList + + err := json.Unmarshal([]byte("1234"), &actual) + if err == nil { + t.Fatal("expected unmarshal to fail") + } + }) +} + +func TestAudienceList_Marshal(t *testing.T) { + value := AudienceList{"audience"} + + expected := `["audience"]` + + actual, err := json.Marshal(value) + if err != nil { + t.Fatal(err) + } + + if expected != string(actual) { + t.Errorf("expected marshaled list to be %v, got %v", expected, actual) + } +} + +func assertStringListEqual(t *testing.T, expected []string, actual []string) { + t.Helper() + + if len(expected) != len(actual) { + t.Errorf("length mismatch: expected %d long slice, got %d", len(expected), len(actual)) + + return + } + + for i, v := range expected { + if v != actual[i] { + t.Errorf("expected %d. item to be %q, got %q", i, v, actual[i]) + } + + return + } +} diff --git a/registry/auth/token/util.go b/registry/auth/token/util.go index d7f95be4..a219df86 100644 --- a/registry/auth/token/util.go +++ b/registry/auth/token/util.go @@ -56,3 +56,14 @@ func contains(ss []string, q string) bool { return false } + +// containsAny returns true if any of q is found in ss. +func containsAny(ss []string, q []string) bool { + for _, s := range ss { + if contains(q, s) { + return true + } + } + + return false +}