mirror of
https://github.com/golang/oauth2.git
synced 2025-07-21 00:00:09 +08:00
internal: include clientID in auth style cache key
Fixes golang/oauth2#654 Change-Id: I735891f2a77c3797662b2eadab7e7828ff14bf5f Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/666915 LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Junyang Shao <shaojunyang@google.com> Reviewed-by: Matt Hickford <matt.hickford@gmail.com> Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
This commit is contained in:
parent
2d34e3091b
commit
32d34ef364
@ -105,14 +105,6 @@ func (e *expirationTime) UnmarshalJSON(b []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterBrokenAuthHeaderProvider previously did something. It is now a no-op.
|
||||
//
|
||||
// Deprecated: this function no longer does anything. Caller code that
|
||||
// wants to avoid potential extra HTTP requests made during
|
||||
// auto-probing of the provider's auth style should set
|
||||
// Endpoint.AuthStyle.
|
||||
func RegisterBrokenAuthHeaderProvider(tokenURL string) {}
|
||||
|
||||
// AuthStyle is a copy of the golang.org/x/oauth2 package's AuthStyle type.
|
||||
type AuthStyle int
|
||||
|
||||
@ -149,6 +141,11 @@ func (lc *LazyAuthStyleCache) Get() *AuthStyleCache {
|
||||
return c
|
||||
}
|
||||
|
||||
type authStyleCacheKey struct {
|
||||
url string
|
||||
clientID string
|
||||
}
|
||||
|
||||
// AuthStyleCache is the set of tokenURLs we've successfully used via
|
||||
// RetrieveToken and which style auth we ended up using.
|
||||
// It's called a cache, but it doesn't (yet?) shrink. It's expected that
|
||||
@ -156,26 +153,26 @@ func (lc *LazyAuthStyleCache) Get() *AuthStyleCache {
|
||||
// small.
|
||||
type AuthStyleCache struct {
|
||||
mu sync.Mutex
|
||||
m map[string]AuthStyle // keyed by tokenURL
|
||||
m map[authStyleCacheKey]AuthStyle
|
||||
}
|
||||
|
||||
// lookupAuthStyle reports which auth style we last used with tokenURL
|
||||
// when calling RetrieveToken and whether we have ever done so.
|
||||
func (c *AuthStyleCache) lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) {
|
||||
func (c *AuthStyleCache) lookupAuthStyle(tokenURL, clientID string) (style AuthStyle, ok bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
style, ok = c.m[tokenURL]
|
||||
style, ok = c.m[authStyleCacheKey{tokenURL, clientID}]
|
||||
return
|
||||
}
|
||||
|
||||
// setAuthStyle adds an entry to authStyleCache, documented above.
|
||||
func (c *AuthStyleCache) setAuthStyle(tokenURL string, v AuthStyle) {
|
||||
func (c *AuthStyleCache) setAuthStyle(tokenURL, clientID string, v AuthStyle) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.m == nil {
|
||||
c.m = make(map[string]AuthStyle)
|
||||
c.m = make(map[authStyleCacheKey]AuthStyle)
|
||||
}
|
||||
c.m[tokenURL] = v
|
||||
c.m[authStyleCacheKey{tokenURL, clientID}] = v
|
||||
}
|
||||
|
||||
// newTokenRequest returns a new *http.Request to retrieve a new token
|
||||
@ -218,7 +215,7 @@ func cloneURLValues(v url.Values) url.Values {
|
||||
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle, styleCache *AuthStyleCache) (*Token, error) {
|
||||
needsAuthStyleProbe := authStyle == AuthStyleUnknown
|
||||
if needsAuthStyleProbe {
|
||||
if style, ok := styleCache.lookupAuthStyle(tokenURL); ok {
|
||||
if style, ok := styleCache.lookupAuthStyle(tokenURL, clientID); ok {
|
||||
authStyle = style
|
||||
needsAuthStyleProbe = false
|
||||
} else {
|
||||
@ -248,7 +245,7 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string,
|
||||
token, err = doTokenRoundTrip(ctx, req)
|
||||
}
|
||||
if needsAuthStyleProbe && err == nil {
|
||||
styleCache.setAuthStyle(tokenURL, authStyle)
|
||||
styleCache.setAuthStyle(tokenURL, clientID, authStyle)
|
||||
}
|
||||
// Don't overwrite `RefreshToken` with an empty value
|
||||
// if this was a token refreshing request.
|
||||
|
@ -75,3 +75,48 @@ func TestExpiresInUpperBound(t *testing.T) {
|
||||
t.Errorf("expiration time = %v; want %v", e, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthStyleCache(t *testing.T) {
|
||||
var c LazyAuthStyleCache
|
||||
|
||||
cases := []struct {
|
||||
url string
|
||||
clientID string
|
||||
style AuthStyle
|
||||
}{
|
||||
{
|
||||
"https://host1.example.com/token",
|
||||
"client_1",
|
||||
AuthStyleInHeader,
|
||||
}, {
|
||||
"https://host2.example.com/token",
|
||||
"client_2",
|
||||
AuthStyleInParams,
|
||||
}, {
|
||||
"https://host1.example.com/token",
|
||||
"client_3",
|
||||
AuthStyleInParams,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.clientID, func(t *testing.T) {
|
||||
cc := c.Get()
|
||||
got, ok := cc.lookupAuthStyle(tt.url, tt.clientID)
|
||||
if ok {
|
||||
t.Fatalf("unexpected auth style found on first request: %v", got)
|
||||
}
|
||||
|
||||
cc.setAuthStyle(tt.url, tt.clientID, tt.style)
|
||||
|
||||
got, ok = cc.lookupAuthStyle(tt.url, tt.clientID)
|
||||
if !ok {
|
||||
t.Fatalf("auth style not found in cache")
|
||||
}
|
||||
|
||||
if got != tt.style {
|
||||
t.Fatalf("auth style mismatch, got=%v, want=%v", got, tt.style)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user