diff --git a/example_test.go b/example_test.go index cb4726f..d26a3dc 100644 --- a/example_test.go +++ b/example_test.go @@ -50,7 +50,6 @@ func ExampleConfig() { } func ExampleJWTConfig() { - var initialToken *oauth2.Token // nil means no initial token conf := &oauth2.JWTConfig{ Email: "xxx@developer.com", // The contents of your RSA private key or your PEM file @@ -67,6 +66,6 @@ func ExampleJWTConfig() { } // Initiate an http.Client, the following GET request will be // authorized and authenticated on the behalf of user@example.com. - client := conf.Client(oauth2.NoContext, initialToken) + client := conf.Client(oauth2.NoContext) client.Get("...") } diff --git a/google/example_test.go b/google/example_test.go index 6d21d5e..a59cfe9 100644 --- a/google/example_test.go +++ b/google/example_test.go @@ -69,7 +69,7 @@ func ExampleJWTConfigFromJSON() { // Initiate an http.Client. The following GET request will be // authorized and authenticated on the behalf of // your service account. - client := conf.Client(oauth2.NoContext, nil) + client := conf.Client(oauth2.NoContext) client.Get("...") } @@ -101,7 +101,7 @@ func Example_serviceAccount() { } // Initiate an http.Client, the following GET request will be // authorized and authenticated on the behalf of user@example.com. - client := conf.Client(oauth2.NoContext, nil) + client := conf.Client(oauth2.NoContext) client.Get("...") } diff --git a/google/google.go b/google/google.go index 4890776..eb6c92a 100644 --- a/google/google.go +++ b/google/google.go @@ -15,7 +15,6 @@ package google // import "golang.org/x/oauth2/google" import ( "encoding/json" - "fmt" "net" "net/http" @@ -24,6 +23,9 @@ import ( "golang.org/x/oauth2" ) +// TODO(bradfitz,jbd): import "google.golang.org/cloud/compute/metadata" instead of +// the metaClient and metadata.google.internal stuff below. + // Endpoint is Google's OAuth 2.0 endpoint. var Endpoint = oauth2.Endpoint{ AuthURL: "https://accounts.google.com/o/oauth2/auth", @@ -66,7 +68,7 @@ type metaTokenRespBody struct { // Further information about retrieving access tokens from the GCE metadata // server can be found at https://cloud.google.com/compute/docs/authentication. func ComputeTokenSource(account string) oauth2.TokenSource { - return &computeSource{account: account} + return oauth2.ReuseTokenSource(nil, &computeSource{account: account}) } type computeSource struct { diff --git a/google/source_appengine.go b/google/source_appengine.go index 9b8aa97..d0eb3da 100644 --- a/google/source_appengine.go +++ b/google/source_appengine.go @@ -29,13 +29,16 @@ type tokenLock struct { } type appEngineTokenSource struct { - ctx oauth2.Context - scopes []string - key string // guarded by package-level mutex, aeTokensMu + ctx oauth2.Context - // fetcherFunc makes the actual RPC to fetch a new access token with an expiry time. - // Provider of this function is responsible to assert that the given context is valid. - fetcherFunc func(ctx oauth2.Context, scope ...string) (string, time.Time, error) + // fetcherFunc makes the actual RPC to fetch a new access + // token with an expiry time. Provider of this function is + // responsible to assert that the given context is valid. + fetcherFunc func(ctx oauth2.Context, scope ...string) (accessToken string, expiry time.Time, err error) + + // scopes and key are guarded by the package-level mutex aeTokensMu + scopes []string + key string } func (ts *appEngineTokenSource) Token() (*oauth2.Token, error) { @@ -53,7 +56,7 @@ func (ts *appEngineTokenSource) Token() (*oauth2.Token, error) { tok.mu.Lock() defer tok.mu.Unlock() - if tok.t != nil && !tok.t.Expired() { + if tok.t.Valid() { return tok.t, nil } access, exp, err := ts.fetcherFunc(ts.ctx, ts.scopes...) diff --git a/jwt.go b/jwt.go index d7ec0d3..9507671 100644 --- a/jwt.go +++ b/jwt.go @@ -52,33 +52,21 @@ type JWTConfig struct { // TokenSource returns a JWT TokenSource using the configuration // in c and the HTTP client from the provided context. -// -// The returned TokenSource only does JWT requests when necessary but -// otherwise returns the same token repeatedly until it expires. -// -// The provided initialToken may be nil, in which case the first -// call to TokenSource will do a new JWT request. -func (c *JWTConfig) TokenSource(ctx Context, initialToken *Token) TokenSource { - return &newWhenNeededSource{ - t: initialToken, - new: jwtSource{ctx, c}, - } +func (c *JWTConfig) TokenSource(ctx Context) TokenSource { + return ReuseTokenSource(nil, jwtSource{ctx, c}) } // Client returns an HTTP client wrapping the context's // HTTP transport and adding Authorization headers with tokens // obtained from c. // -// The provided initialToken may be nil, in which case the first -// call to TokenSource will do a new JWT request. -// // The returned client and its Transport should not be modified. -func (c *JWTConfig) Client(ctx Context, initialToken *Token) *http.Client { - return NewClient(ctx, c.TokenSource(ctx, initialToken)) +func (c *JWTConfig) Client(ctx Context) *http.Client { + return NewClient(ctx, c.TokenSource(ctx)) } // jwtSource is a source that always does a signed JWT request for a token. -// It should typically be wrapped with a newWhenNeededSource. +// It should typically be wrapped with a reuseTokenSource. type jwtSource struct { ctx Context conf *JWTConfig diff --git a/jwt_test.go b/jwt_test.go index 8c2e62e..e9a732c 100644 --- a/jwt_test.go +++ b/jwt_test.go @@ -55,12 +55,12 @@ func TestJWTFetch_JSONResponse(t *testing.T) { PrivateKey: dummyPrivateKey, TokenURL: ts.URL, } - tok, err := conf.TokenSource(NoContext, nil).Token() + tok, err := conf.TokenSource(NoContext).Token() if err != nil { t.Fatal(err) } - if tok.Expired() { - t.Errorf("Token shouldn't be expired") + if !tok.Valid() { + t.Errorf("Token invalid") } if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" { t.Errorf("Unexpected access token, %#v", tok.AccessToken) @@ -89,19 +89,25 @@ func TestJWTFetch_BadResponse(t *testing.T) { PrivateKey: dummyPrivateKey, TokenURL: ts.URL, } - tok, err := conf.TokenSource(NoContext, nil).Token() + tok, err := conf.TokenSource(NoContext).Token() if err != nil { t.Fatal(err) } - if tok.AccessToken != "" { - t.Errorf("Unexpected access token, %#v.", tok.AccessToken) + if tok == nil { + t.Fatalf("token is nil") } - if tok.TokenType != "bearer" { - t.Errorf("Unexpected token type, %#v.", tok.TokenType) + if tok.Valid() { + t.Errorf("token is valid. want invalid.") + } + if tok.AccessToken != "" { + t.Errorf("Unexpected non-empty access token %q.", tok.AccessToken) + } + if want := "bearer"; tok.TokenType != want { + t.Errorf("TokenType = %q; want %q", tok.TokenType, want) } scope := tok.Extra("scope") - if scope != "user" { - t.Errorf("Unexpected value for scope: %v", scope) + if want := "user"; scope != want { + t.Errorf("token scope = %q; want %q", scope, want) } } @@ -116,7 +122,7 @@ func TestJWTFetch_BadResponseType(t *testing.T) { PrivateKey: dummyPrivateKey, TokenURL: ts.URL, } - tok, err := conf.TokenSource(NoContext, nil).Token() + tok, err := conf.TokenSource(NoContext).Token() if err == nil { t.Error("got a token; expected error") if tok.AccessToken != "" { diff --git a/oauth2.go b/oauth2.go index 3644683..5f2b145 100644 --- a/oauth2.go +++ b/oauth2.go @@ -26,7 +26,6 @@ import ( ) // Context can be an golang.org/x/net.Context, or an App Engine Context. -// In the future these will be unified. // If you don't care and aren't running on App Engine, you may use NoContext. type Context interface{} @@ -36,7 +35,7 @@ type Context interface{} var NoContext Context = nil // Config describes a typical 3-legged OAuth2 flow, with both the -// client application information and the server's URLs. +// client application information and the server's endpoint URLs. type Config struct { // ClientID is the application's ID. ClientID string @@ -45,9 +44,9 @@ type Config struct { ClientSecret string // Endpoint contains the resource server's token endpoint - // URLs. These are supplied by the server and are often - // available via site-specific packages (for example, - // google.Endpoint or github.Endpoint) + // URLs. These are constants specific to each server and are + // often available via site-specific packages, such as + // google.Endpoint or github.Endpoint. Endpoint Endpoint // RedirectURL is the URL to redirect users going through @@ -61,6 +60,7 @@ type Config struct { // A TokenSource is anything that can return a token. type TokenSource interface { // Token returns a token or an error. + // Token must be safe for concurrent use by multiple goroutines. Token() (*Token, error) } @@ -208,7 +208,7 @@ func (c *Config) Client(ctx Context, t *Token) *http.Client { // // Most users will use Config.Client instead. func (c *Config) TokenSource(ctx Context, t *Token) TokenSource { - nwn := &newWhenNeededSource{t: t} + nwn := &reuseTokenSource{t: t} nwn.new = tokenRefresher{ ctx: ctx, conf: c, @@ -239,13 +239,13 @@ func (tf tokenRefresher) Token() (*Token, error) { }) } -// newWhenNeededSource is a TokenSource that holds a single token in memory +// reuseTokenSource is a TokenSource that holds a single token in memory // and validates its expiry before each call to retrieve it with // Token. If it's expired, it will be auto-refreshed using the // new TokenSource. // // The first call to TokenRefresher must be SetToken. -type newWhenNeededSource struct { +type reuseTokenSource struct { new TokenSource // called when t is expired. mu sync.Mutex // guards t @@ -255,10 +255,10 @@ type newWhenNeededSource struct { // Token returns the current token if it's still valid, else will // refresh the current token (using r.Context for HTTP client // information) and return the new one. -func (s *newWhenNeededSource) Token() (*Token, error) { +func (s *reuseTokenSource) Token() (*Token, error) { s.mu.Lock() defer s.mu.Unlock() - if s.t != nil && !s.t.Expired() { + if s.t.Valid() { return s.t, nil } t, err := s.new.Token() @@ -410,12 +410,41 @@ var HTTPClient contextKey type contextKey struct{} // NewClient creates an *http.Client from a Context and TokenSource. -// The client's lifetime does not extend beyond the lifetime of the context. +// The returned client is not valid beyond the lifetime of the context. func NewClient(ctx Context, src TokenSource) *http.Client { return &http.Client{ Transport: &Transport{ Base: contextTransport(ctx), - Source: src, + Source: ReuseTokenSource(nil, src), }, } } + +// ReuseTokenSource returns a TokenSource which repeatedly returns the +// same token as long as it's valid, starting with t. +// When its cached token is invalid, a new token is obtained from src. +// +// ReuseTokenSource is typically used to reuse tokens from a cache +// (such as a file on disk) between runs of a program, rather than +// obtaining new tokens unnecessarily. +// +// The initial token t may be nil, in which case the TokenSource is +// wrapped in a caching version if it isn't one already. This also +// means it's always safe to wrap ReuseTokenSource around any other +// TokenSource without adverse effects. +func ReuseTokenSource(t *Token, src TokenSource) TokenSource { + // Don't wrap a reuseTokenSource in itself. That would work, + // but cause an unnecessary number of mutex operations. + // Just build the equivalent one. + if rt, ok := src.(*reuseTokenSource); ok { + if t == nil { + // Just use it directly. + return rt + } + src = rt.new + } + return &reuseTokenSource{ + t: t, + new: src, + } +} diff --git a/oauth2_test.go b/oauth2_test.go index c567c3a..804098a 100644 --- a/oauth2_test.go +++ b/oauth2_test.go @@ -99,8 +99,8 @@ func TestExchangeRequest(t *testing.T) { if err != nil { t.Error(err) } - if tok.Expired() { - t.Errorf("Token shouldn't be expired.") + if !tok.Valid() { + t.Fatalf("Token invalid. Got: %#v", tok) } if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" { t.Errorf("Unexpected access token, %#v.", tok.AccessToken) @@ -143,8 +143,8 @@ func TestExchangeRequest_JSONResponse(t *testing.T) { if err != nil { t.Error(err) } - if tok.Expired() { - t.Errorf("Token shouldn't be expired.") + if !tok.Valid() { + t.Fatalf("Token invalid. Got: %#v", tok) } if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" { t.Errorf("Unexpected access token, %#v.", tok.AccessToken) diff --git a/token.go b/token.go index 0c52888..6aa0b41 100644 --- a/token.go +++ b/token.go @@ -74,14 +74,16 @@ func (t *Token) Extra(key string) string { return "" } -// Expired returns true if there is no access token or the -// access token is expired. -func (t *Token) Expired() bool { - if t.AccessToken == "" { - return true - } +// expired reports whether the token is expired. +// t must be non-nil. +func (t *Token) expired() bool { if t.Expiry.IsZero() { return false } return t.Expiry.Before(time.Now()) } + +// Valid reports whether t is non-nil, has an AccessToken, and is not expired. +func (t *Token) Valid() bool { + return t != nil && t.AccessToken != "" && !t.expired() +} diff --git a/transport_test.go b/transport_test.go index b3414e3..efb8232 100644 --- a/transport_test.go +++ b/transport_test.go @@ -32,10 +32,10 @@ func TestTransportTokenSource(t *testing.T) { client.Get(server.URL) } -func TestExpiredWithNoAccessToken(t *testing.T) { +func TestTokenValidNoAccessToken(t *testing.T) { token := &Token{} - if !token.Expired() { - t.Errorf("Token should be expired if no access token is provided") + if token.Valid() { + t.Errorf("Token should not be valid with no access token") } } @@ -43,8 +43,8 @@ func TestExpiredWithExpiry(t *testing.T) { token := &Token{ Expiry: time.Now().Add(-5 * time.Hour), } - if !token.Expired() { - t.Errorf("Token should be expired if no access token is provided") + if token.Valid() { + t.Errorf("Token should not be valid if it expired in the past") } }