mirror of
https://github.com/golang/oauth2.git
synced 2025-07-21 00:00:09 +08:00
internal: include clientID in auth style cache key
Fixes golang/oauth2#654 Change-Id: I735891f2a77c3797662b2eadab7e7828ff14bf5f Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/666915 LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Junyang Shao <shaojunyang@google.com> Reviewed-by: Matt Hickford <matt.hickford@gmail.com> Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
This commit is contained in:
parent
2d34e3091b
commit
32d34ef364
@ -105,14 +105,6 @@ func (e *expirationTime) UnmarshalJSON(b []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterBrokenAuthHeaderProvider previously did something. It is now a no-op.
|
|
||||||
//
|
|
||||||
// Deprecated: this function no longer does anything. Caller code that
|
|
||||||
// wants to avoid potential extra HTTP requests made during
|
|
||||||
// auto-probing of the provider's auth style should set
|
|
||||||
// Endpoint.AuthStyle.
|
|
||||||
func RegisterBrokenAuthHeaderProvider(tokenURL string) {}
|
|
||||||
|
|
||||||
// AuthStyle is a copy of the golang.org/x/oauth2 package's AuthStyle type.
|
// AuthStyle is a copy of the golang.org/x/oauth2 package's AuthStyle type.
|
||||||
type AuthStyle int
|
type AuthStyle int
|
||||||
|
|
||||||
@ -149,6 +141,11 @@ func (lc *LazyAuthStyleCache) Get() *AuthStyleCache {
|
|||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type authStyleCacheKey struct {
|
||||||
|
url string
|
||||||
|
clientID string
|
||||||
|
}
|
||||||
|
|
||||||
// AuthStyleCache is the set of tokenURLs we've successfully used via
|
// AuthStyleCache is the set of tokenURLs we've successfully used via
|
||||||
// RetrieveToken and which style auth we ended up using.
|
// RetrieveToken and which style auth we ended up using.
|
||||||
// It's called a cache, but it doesn't (yet?) shrink. It's expected that
|
// It's called a cache, but it doesn't (yet?) shrink. It's expected that
|
||||||
@ -156,26 +153,26 @@ func (lc *LazyAuthStyleCache) Get() *AuthStyleCache {
|
|||||||
// small.
|
// small.
|
||||||
type AuthStyleCache struct {
|
type AuthStyleCache struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
m map[string]AuthStyle // keyed by tokenURL
|
m map[authStyleCacheKey]AuthStyle
|
||||||
}
|
}
|
||||||
|
|
||||||
// lookupAuthStyle reports which auth style we last used with tokenURL
|
// lookupAuthStyle reports which auth style we last used with tokenURL
|
||||||
// when calling RetrieveToken and whether we have ever done so.
|
// when calling RetrieveToken and whether we have ever done so.
|
||||||
func (c *AuthStyleCache) lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) {
|
func (c *AuthStyleCache) lookupAuthStyle(tokenURL, clientID string) (style AuthStyle, ok bool) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
style, ok = c.m[tokenURL]
|
style, ok = c.m[authStyleCacheKey{tokenURL, clientID}]
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// setAuthStyle adds an entry to authStyleCache, documented above.
|
// setAuthStyle adds an entry to authStyleCache, documented above.
|
||||||
func (c *AuthStyleCache) setAuthStyle(tokenURL string, v AuthStyle) {
|
func (c *AuthStyleCache) setAuthStyle(tokenURL, clientID string, v AuthStyle) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
if c.m == nil {
|
if c.m == nil {
|
||||||
c.m = make(map[string]AuthStyle)
|
c.m = make(map[authStyleCacheKey]AuthStyle)
|
||||||
}
|
}
|
||||||
c.m[tokenURL] = v
|
c.m[authStyleCacheKey{tokenURL, clientID}] = v
|
||||||
}
|
}
|
||||||
|
|
||||||
// newTokenRequest returns a new *http.Request to retrieve a new token
|
// newTokenRequest returns a new *http.Request to retrieve a new token
|
||||||
@ -218,7 +215,7 @@ func cloneURLValues(v url.Values) url.Values {
|
|||||||
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle, styleCache *AuthStyleCache) (*Token, error) {
|
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle, styleCache *AuthStyleCache) (*Token, error) {
|
||||||
needsAuthStyleProbe := authStyle == AuthStyleUnknown
|
needsAuthStyleProbe := authStyle == AuthStyleUnknown
|
||||||
if needsAuthStyleProbe {
|
if needsAuthStyleProbe {
|
||||||
if style, ok := styleCache.lookupAuthStyle(tokenURL); ok {
|
if style, ok := styleCache.lookupAuthStyle(tokenURL, clientID); ok {
|
||||||
authStyle = style
|
authStyle = style
|
||||||
needsAuthStyleProbe = false
|
needsAuthStyleProbe = false
|
||||||
} else {
|
} else {
|
||||||
@ -248,7 +245,7 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string,
|
|||||||
token, err = doTokenRoundTrip(ctx, req)
|
token, err = doTokenRoundTrip(ctx, req)
|
||||||
}
|
}
|
||||||
if needsAuthStyleProbe && err == nil {
|
if needsAuthStyleProbe && err == nil {
|
||||||
styleCache.setAuthStyle(tokenURL, authStyle)
|
styleCache.setAuthStyle(tokenURL, clientID, authStyle)
|
||||||
}
|
}
|
||||||
// Don't overwrite `RefreshToken` with an empty value
|
// Don't overwrite `RefreshToken` with an empty value
|
||||||
// if this was a token refreshing request.
|
// if this was a token refreshing request.
|
||||||
|
@ -75,3 +75,48 @@ func TestExpiresInUpperBound(t *testing.T) {
|
|||||||
t.Errorf("expiration time = %v; want %v", e, want)
|
t.Errorf("expiration time = %v; want %v", e, want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAuthStyleCache(t *testing.T) {
|
||||||
|
var c LazyAuthStyleCache
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
url string
|
||||||
|
clientID string
|
||||||
|
style AuthStyle
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"https://host1.example.com/token",
|
||||||
|
"client_1",
|
||||||
|
AuthStyleInHeader,
|
||||||
|
}, {
|
||||||
|
"https://host2.example.com/token",
|
||||||
|
"client_2",
|
||||||
|
AuthStyleInParams,
|
||||||
|
}, {
|
||||||
|
"https://host1.example.com/token",
|
||||||
|
"client_3",
|
||||||
|
AuthStyleInParams,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.clientID, func(t *testing.T) {
|
||||||
|
cc := c.Get()
|
||||||
|
got, ok := cc.lookupAuthStyle(tt.url, tt.clientID)
|
||||||
|
if ok {
|
||||||
|
t.Fatalf("unexpected auth style found on first request: %v", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
cc.setAuthStyle(tt.url, tt.clientID, tt.style)
|
||||||
|
|
||||||
|
got, ok = cc.lookupAuthStyle(tt.url, tt.clientID)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("auth style not found in cache")
|
||||||
|
}
|
||||||
|
|
||||||
|
if got != tt.style {
|
||||||
|
t.Fatalf("auth style mismatch, got=%v, want=%v", got, tt.style)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user