internal: include clientID in auth style cache key

Fixes golang/oauth2#654

Change-Id: I735891f2a77c3797662b2eadab7e7828ff14bf5f
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/666915
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Junyang Shao <shaojunyang@google.com>
Reviewed-by: Matt Hickford <matt.hickford@gmail.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
This commit is contained in:
Sean Liao 2025-04-20 15:32:12 +01:00
parent 2d34e3091b
commit 32d34ef364
2 changed files with 58 additions and 16 deletions

View File

@ -105,14 +105,6 @@ func (e *expirationTime) UnmarshalJSON(b []byte) error {
return nil
}
// RegisterBrokenAuthHeaderProvider previously did something. It is now a no-op.
//
// Deprecated: this function no longer does anything. Caller code that
// wants to avoid potential extra HTTP requests made during
// auto-probing of the provider's auth style should set
// Endpoint.AuthStyle.
func RegisterBrokenAuthHeaderProvider(tokenURL string) {}
// AuthStyle is a copy of the golang.org/x/oauth2 package's AuthStyle type.
type AuthStyle int
@ -149,6 +141,11 @@ func (lc *LazyAuthStyleCache) Get() *AuthStyleCache {
return c
}
type authStyleCacheKey struct {
url string
clientID string
}
// AuthStyleCache is the set of tokenURLs we've successfully used via
// RetrieveToken and which style auth we ended up using.
// It's called a cache, but it doesn't (yet?) shrink. It's expected that
@ -156,26 +153,26 @@ func (lc *LazyAuthStyleCache) Get() *AuthStyleCache {
// small.
type AuthStyleCache struct {
mu sync.Mutex
m map[string]AuthStyle // keyed by tokenURL
m map[authStyleCacheKey]AuthStyle
}
// lookupAuthStyle reports which auth style we last used with tokenURL
// when calling RetrieveToken and whether we have ever done so.
func (c *AuthStyleCache) lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) {
func (c *AuthStyleCache) lookupAuthStyle(tokenURL, clientID string) (style AuthStyle, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
style, ok = c.m[tokenURL]
style, ok = c.m[authStyleCacheKey{tokenURL, clientID}]
return
}
// setAuthStyle adds an entry to authStyleCache, documented above.
func (c *AuthStyleCache) setAuthStyle(tokenURL string, v AuthStyle) {
func (c *AuthStyleCache) setAuthStyle(tokenURL, clientID string, v AuthStyle) {
c.mu.Lock()
defer c.mu.Unlock()
if c.m == nil {
c.m = make(map[string]AuthStyle)
c.m = make(map[authStyleCacheKey]AuthStyle)
}
c.m[tokenURL] = v
c.m[authStyleCacheKey{tokenURL, clientID}] = v
}
// newTokenRequest returns a new *http.Request to retrieve a new token
@ -218,7 +215,7 @@ func cloneURLValues(v url.Values) url.Values {
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle, styleCache *AuthStyleCache) (*Token, error) {
needsAuthStyleProbe := authStyle == AuthStyleUnknown
if needsAuthStyleProbe {
if style, ok := styleCache.lookupAuthStyle(tokenURL); ok {
if style, ok := styleCache.lookupAuthStyle(tokenURL, clientID); ok {
authStyle = style
needsAuthStyleProbe = false
} else {
@ -248,7 +245,7 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string,
token, err = doTokenRoundTrip(ctx, req)
}
if needsAuthStyleProbe && err == nil {
styleCache.setAuthStyle(tokenURL, authStyle)
styleCache.setAuthStyle(tokenURL, clientID, authStyle)
}
// Don't overwrite `RefreshToken` with an empty value
// if this was a token refreshing request.

View File

@ -75,3 +75,48 @@ func TestExpiresInUpperBound(t *testing.T) {
t.Errorf("expiration time = %v; want %v", e, want)
}
}
func TestAuthStyleCache(t *testing.T) {
var c LazyAuthStyleCache
cases := []struct {
url string
clientID string
style AuthStyle
}{
{
"https://host1.example.com/token",
"client_1",
AuthStyleInHeader,
}, {
"https://host2.example.com/token",
"client_2",
AuthStyleInParams,
}, {
"https://host1.example.com/token",
"client_3",
AuthStyleInParams,
},
}
for _, tt := range cases {
t.Run(tt.clientID, func(t *testing.T) {
cc := c.Get()
got, ok := cc.lookupAuthStyle(tt.url, tt.clientID)
if ok {
t.Fatalf("unexpected auth style found on first request: %v", got)
}
cc.setAuthStyle(tt.url, tt.clientID, tt.style)
got, ok = cc.lookupAuthStyle(tt.url, tt.clientID)
if !ok {
t.Fatalf("auth style not found in cache")
}
if got != tt.style {
t.Fatalf("auth style mismatch, got=%v, want=%v", got, tt.style)
}
})
}
}