From c048af9da2ff4db86f1ad989c73c6ec7a62c57a4 Mon Sep 17 00:00:00 2001 From: Burcu Dogan Date: Thu, 13 Nov 2014 15:41:14 +1100 Subject: [PATCH] Add Cacher interface. --- oauth2.go | 69 +++++++++++++++++++++++++++++++++++++---------- oauth2_test.go | 33 +++++++++++++++++++++-- transport.go | 17 ++++++++---- transport_test.go | 6 +---- 4 files changed, 99 insertions(+), 26 deletions(-) diff --git a/oauth2.go b/oauth2.go index ce166b1..38f4a19 100644 --- a/oauth2.go +++ b/oauth2.go @@ -22,6 +22,18 @@ import ( "strings" ) +// Cacher implementations read and write OAuth 2.0 tokens from a cache. +type Cacher interface { + // Read reads the token from the cache. + // If the read is successful, it should return the token and a nil error. + // The returned tokens may be expired tokens. + // If there is no token in the cache, it should return a nil token and a nil error. + // It should return a non-nil error when an unrecoverable failure occurs. + Read() (*Token, error) + // Write writes the token to the cache. + Write(*Token) +} + // Option represents a function that applies some state to // an Options object. type Option func(*Options) error @@ -91,6 +103,16 @@ func RoundTripper(tr http.RoundTripper) Option { } } +// Cache requires a Cacher implementation. It will initially read +// the token if the transport is initialized with NewTransportFromCache +// and will write the refreshed tokens back to the cache. +func Cache(c Cacher) Option { + return func(o *Options) error { + o.Cache = c + return nil + } +} + type Flow struct { opts Options } @@ -109,11 +131,11 @@ func New(options ...Option) (*Flow, error) { case f.opts.TokenFetcherFunc != nil: return f, nil case f.opts.AUD != nil: - // TODO(jbd): Assert required JWT params. + // TODO(jbd): Assert the required JWT params. f.opts.TokenFetcherFunc = makeTwoLeggedFetcher(&f.opts) return f, nil case f.opts.AuthURL != nil && f.opts.TokenURL != nil: - // TODO(jbd): Assert required OAuth2 params. + // TODO(jbd): Assert the required OAuth2 params. f.opts.TokenFetcherFunc = makeThreeLeggedFetcher(&f.opts) return f, nil default: @@ -175,6 +197,23 @@ func (f *Flow) exchange(code string) (*Token, error) { }) } +// NewTransportFromCache reads the token from the cache and returns +// a Transport that is authorized and the authenticated +// by the returned token. +func (f *Flow) NewTransportFromCache() (*Transport, error) { + if f.opts.Cache == nil { + return nil, errors.New("oauth2: no cache is set") + } + tok, err := f.opts.Cache.Read() + if err != nil { + return nil, err + } + if tok == nil { + return nil, nil + } + return f.newTransportFromToken(tok), nil +} + // NewTransportFromCode exchanges the code to retrieve a new access token // and returns an authorized and authenticated Transport. func (f *Flow) NewTransportFromCode(code string) (*Transport, error) { @@ -182,22 +221,22 @@ func (f *Flow) NewTransportFromCode(code string) (*Transport, error) { if err != nil { return nil, err } - return f.NewTransportFromToken(token), nil -} - -// NewTransportFromToken returns a new Transport that is authorized -// and authenticated with the provided token. -func (f *Flow) NewTransportFromToken(t *Token) *Transport { - tr := f.opts.Transport - if tr == nil { - tr = http.DefaultTransport - } - return newTransport(tr, f.opts.TokenFetcherFunc, t) + return f.newTransportFromToken(token), nil } // NewTransport returns a Transport. func (f *Flow) NewTransport() *Transport { - return f.NewTransportFromToken(nil) + return f.newTransportFromToken(nil) +} + +// newTransportFromToken returns a new Transport that is authorized +// and authenticated with the provided token. +func (f *Flow) newTransportFromToken(t *Token) *Transport { + tr := f.opts.Transport + if tr == nil { + tr = http.DefaultTransport + } + return newTransport(tr, &f.opts, t) } func makeThreeLeggedFetcher(o *Options) func(t *Token) (*Token, error) { @@ -255,6 +294,8 @@ type Options struct { // AUD represents the token endpoint required to complete the 2-legged JWT flow. AUD *url.URL + Cache Cacher + TokenFetcherFunc func(t *Token) (*Token, error) Transport http.RoundTripper diff --git a/oauth2_test.go b/oauth2_test.go index 92419c6..d356f62 100644 --- a/oauth2_test.go +++ b/oauth2_test.go @@ -20,6 +20,19 @@ func (t *mockTransport) RoundTrip(req *http.Request) (resp *http.Response, err e return t.rt(req) } +type mockCache struct { + token *Token + readErr error +} + +func (c *mockCache) Read() (*Token, error) { + return c.token, c.readErr +} + +func (c *mockCache) Write(*Token) { + // do nothing +} + func newTestFlow(url string) *Flow { f, _ := New( Client("CLIENT_ID", "CLIENT_SECRET"), @@ -211,7 +224,8 @@ func TestTokenRefreshRequest(t *testing.T) { })) defer ts.Close() f := newTestFlow(ts.URL) - tr := f.NewTransportFromToken(&Token{RefreshToken: "REFRESH_TOKEN"}) + tr := f.NewTransport() + tr.SetToken(&Token{RefreshToken: "REFRESH_TOKEN"}) c := http.Client{Transport: tr} c.Get(ts.URL + "/somethingelse") } @@ -235,10 +249,25 @@ func TestFetchWithNoRefreshToken(t *testing.T) { })) defer ts.Close() f := newTestFlow(ts.URL) - tr := f.NewTransportFromToken(&Token{}) + tr := f.NewTransport() c := http.Client{Transport: tr} _, err := c.Get(ts.URL + "/somethingelse") if err == nil { t.Errorf("Fetch should return an error if no refresh token is set") } } + +func TestCacheNoToken(t *testing.T) { + f, _ := New( + Client("CLIENT_ID", "CLIENT_SECRET"), + Endpoint("/auth", "/token"), + Cache(&mockCache{token: nil, readErr: nil}), + ) + tr, err := f.NewTransportFromCache() + if err != nil { + t.Errorf("No error expected, %v is found", err) + } + if tr != nil { + t.Errorf("No transport should have been initiated, tr is found to be %v", tr) + } +} diff --git a/transport.go b/transport.go index e1a35b0..9df11d8 100644 --- a/transport.go +++ b/transport.go @@ -66,8 +66,8 @@ func (t *Token) Expired() bool { // Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests. type Transport struct { - fetcher func(t *Token) (*Token, error) - base http.RoundTripper + opts *Options + base http.RoundTripper mu sync.RWMutex token *Token @@ -76,8 +76,12 @@ type Transport struct { // NewTransport creates a new Transport that uses the provided // token fetcher as token retrieving strategy. It authenticates // the requests and delegates origTransport to make the actual requests. -func newTransport(base http.RoundTripper, fn func(t *Token) (*Token, error), token *Token) *Transport { - return &Transport{base: base, fetcher: fn, token: token} +func newTransport(base http.RoundTripper, opts *Options, token *Token) *Transport { + return &Transport{ + base: base, + opts: opts, + token: token, + } } // RoundTrip authorizes and authenticates the request with an @@ -94,6 +98,9 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error return nil, err } token = t.Token() + if t.opts.Cache != nil { + t.opts.Cache.Write(token) + } } // To set the Authorization header, we must make a copy of the Request @@ -129,7 +136,7 @@ func (t *Transport) SetToken(v *Token) { func (t *Transport) RefreshToken() error { t.mu.Lock() defer t.mu.Unlock() - token, err := t.fetcher(t.token) + token, err := t.opts.TokenFetcherFunc(t.token) if err != nil { return err } diff --git a/transport_test.go b/transport_test.go index f7cdbc4..5fbccf6 100644 --- a/transport_test.go +++ b/transport_test.go @@ -15,10 +15,6 @@ func (f *mockTokenFetcher) Fn() func(*Token) (*Token, error) { } } -func (f *mockTokenFetcher) FetchToken(existing *Token) (*Token, error) { - return f.token, nil -} - func TestInitialTokenRead(t *testing.T) { tr := newTransport(http.DefaultTransport, nil, &Token{AccessToken: "abc"}) server := newMockServer(func(w http.ResponseWriter, r *http.Request) { @@ -37,7 +33,7 @@ func TestTokenFetch(t *testing.T) { AccessToken: "abc", }, } - tr := newTransport(http.DefaultTransport, fetcher.Fn(), nil) + tr := newTransport(http.DefaultTransport, &Options{TokenFetcherFunc: fetcher.Fn()}, nil) server := newMockServer(func(w http.ResponseWriter, r *http.Request) { if r.Header.Get("Authorization") != "Bearer abc" { t.Errorf("Transport doesn't set the Authorization header from the fetched token")