From 32d34ef364e670a650fe59267b92301ff7ed08f1 Mon Sep 17 00:00:00 2001 From: Sean Liao Date: Sun, 20 Apr 2025 15:32:12 +0100 Subject: [PATCH] internal: include clientID in auth style cache key Fixes golang/oauth2#654 Change-Id: I735891f2a77c3797662b2eadab7e7828ff14bf5f Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/666915 LUCI-TryBot-Result: Go LUCI Reviewed-by: Junyang Shao Reviewed-by: Matt Hickford Reviewed-by: Dmitri Shuralyov --- internal/token.go | 29 ++++++++++++--------------- internal/token_test.go | 45 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 16 deletions(-) diff --git a/internal/token.go b/internal/token.go index b417456..8389f24 100644 --- a/internal/token.go +++ b/internal/token.go @@ -105,14 +105,6 @@ func (e *expirationTime) UnmarshalJSON(b []byte) error { return nil } -// RegisterBrokenAuthHeaderProvider previously did something. It is now a no-op. -// -// Deprecated: this function no longer does anything. Caller code that -// wants to avoid potential extra HTTP requests made during -// auto-probing of the provider's auth style should set -// Endpoint.AuthStyle. -func RegisterBrokenAuthHeaderProvider(tokenURL string) {} - // AuthStyle is a copy of the golang.org/x/oauth2 package's AuthStyle type. type AuthStyle int @@ -149,6 +141,11 @@ func (lc *LazyAuthStyleCache) Get() *AuthStyleCache { return c } +type authStyleCacheKey struct { + url string + clientID string +} + // AuthStyleCache is the set of tokenURLs we've successfully used via // RetrieveToken and which style auth we ended up using. // It's called a cache, but it doesn't (yet?) shrink. It's expected that @@ -156,26 +153,26 @@ func (lc *LazyAuthStyleCache) Get() *AuthStyleCache { // small. type AuthStyleCache struct { mu sync.Mutex - m map[string]AuthStyle // keyed by tokenURL + m map[authStyleCacheKey]AuthStyle } // lookupAuthStyle reports which auth style we last used with tokenURL // when calling RetrieveToken and whether we have ever done so. -func (c *AuthStyleCache) lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) { +func (c *AuthStyleCache) lookupAuthStyle(tokenURL, clientID string) (style AuthStyle, ok bool) { c.mu.Lock() defer c.mu.Unlock() - style, ok = c.m[tokenURL] + style, ok = c.m[authStyleCacheKey{tokenURL, clientID}] return } // setAuthStyle adds an entry to authStyleCache, documented above. -func (c *AuthStyleCache) setAuthStyle(tokenURL string, v AuthStyle) { +func (c *AuthStyleCache) setAuthStyle(tokenURL, clientID string, v AuthStyle) { c.mu.Lock() defer c.mu.Unlock() if c.m == nil { - c.m = make(map[string]AuthStyle) + c.m = make(map[authStyleCacheKey]AuthStyle) } - c.m[tokenURL] = v + c.m[authStyleCacheKey{tokenURL, clientID}] = v } // newTokenRequest returns a new *http.Request to retrieve a new token @@ -218,7 +215,7 @@ func cloneURLValues(v url.Values) url.Values { func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle, styleCache *AuthStyleCache) (*Token, error) { needsAuthStyleProbe := authStyle == AuthStyleUnknown if needsAuthStyleProbe { - if style, ok := styleCache.lookupAuthStyle(tokenURL); ok { + if style, ok := styleCache.lookupAuthStyle(tokenURL, clientID); ok { authStyle = style needsAuthStyleProbe = false } else { @@ -248,7 +245,7 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, token, err = doTokenRoundTrip(ctx, req) } if needsAuthStyleProbe && err == nil { - styleCache.setAuthStyle(tokenURL, authStyle) + styleCache.setAuthStyle(tokenURL, clientID, authStyle) } // Don't overwrite `RefreshToken` with an empty value // if this was a token refreshing request. diff --git a/internal/token_test.go b/internal/token_test.go index c08862a..ef28c11 100644 --- a/internal/token_test.go +++ b/internal/token_test.go @@ -75,3 +75,48 @@ func TestExpiresInUpperBound(t *testing.T) { t.Errorf("expiration time = %v; want %v", e, want) } } + +func TestAuthStyleCache(t *testing.T) { + var c LazyAuthStyleCache + + cases := []struct { + url string + clientID string + style AuthStyle + }{ + { + "https://host1.example.com/token", + "client_1", + AuthStyleInHeader, + }, { + "https://host2.example.com/token", + "client_2", + AuthStyleInParams, + }, { + "https://host1.example.com/token", + "client_3", + AuthStyleInParams, + }, + } + + for _, tt := range cases { + t.Run(tt.clientID, func(t *testing.T) { + cc := c.Get() + got, ok := cc.lookupAuthStyle(tt.url, tt.clientID) + if ok { + t.Fatalf("unexpected auth style found on first request: %v", got) + } + + cc.setAuthStyle(tt.url, tt.clientID, tt.style) + + got, ok = cc.lookupAuthStyle(tt.url, tt.clientID) + if !ok { + t.Fatalf("auth style not found in cache") + } + + if got != tt.style { + t.Fatalf("auth style mismatch, got=%v, want=%v", got, tt.style) + } + }) + } +}