mirror of
https://github.com/golang/oauth2.git
synced 2025-07-21 00:00:09 +08:00
Add Cacher interface.
This commit is contained in:
parent
2d3ce25e9a
commit
c048af9da2
69
oauth2.go
69
oauth2.go
@ -22,6 +22,18 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Cacher implementations read and write OAuth 2.0 tokens from a cache.
|
||||||
|
type Cacher interface {
|
||||||
|
// Read reads the token from the cache.
|
||||||
|
// If the read is successful, it should return the token and a nil error.
|
||||||
|
// The returned tokens may be expired tokens.
|
||||||
|
// If there is no token in the cache, it should return a nil token and a nil error.
|
||||||
|
// It should return a non-nil error when an unrecoverable failure occurs.
|
||||||
|
Read() (*Token, error)
|
||||||
|
// Write writes the token to the cache.
|
||||||
|
Write(*Token)
|
||||||
|
}
|
||||||
|
|
||||||
// Option represents a function that applies some state to
|
// Option represents a function that applies some state to
|
||||||
// an Options object.
|
// an Options object.
|
||||||
type Option func(*Options) error
|
type Option func(*Options) error
|
||||||
@ -91,6 +103,16 @@ func RoundTripper(tr http.RoundTripper) Option {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Cache requires a Cacher implementation. It will initially read
|
||||||
|
// the token if the transport is initialized with NewTransportFromCache
|
||||||
|
// and will write the refreshed tokens back to the cache.
|
||||||
|
func Cache(c Cacher) Option {
|
||||||
|
return func(o *Options) error {
|
||||||
|
o.Cache = c
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type Flow struct {
|
type Flow struct {
|
||||||
opts Options
|
opts Options
|
||||||
}
|
}
|
||||||
@ -109,11 +131,11 @@ func New(options ...Option) (*Flow, error) {
|
|||||||
case f.opts.TokenFetcherFunc != nil:
|
case f.opts.TokenFetcherFunc != nil:
|
||||||
return f, nil
|
return f, nil
|
||||||
case f.opts.AUD != nil:
|
case f.opts.AUD != nil:
|
||||||
// TODO(jbd): Assert required JWT params.
|
// TODO(jbd): Assert the required JWT params.
|
||||||
f.opts.TokenFetcherFunc = makeTwoLeggedFetcher(&f.opts)
|
f.opts.TokenFetcherFunc = makeTwoLeggedFetcher(&f.opts)
|
||||||
return f, nil
|
return f, nil
|
||||||
case f.opts.AuthURL != nil && f.opts.TokenURL != nil:
|
case f.opts.AuthURL != nil && f.opts.TokenURL != nil:
|
||||||
// TODO(jbd): Assert required OAuth2 params.
|
// TODO(jbd): Assert the required OAuth2 params.
|
||||||
f.opts.TokenFetcherFunc = makeThreeLeggedFetcher(&f.opts)
|
f.opts.TokenFetcherFunc = makeThreeLeggedFetcher(&f.opts)
|
||||||
return f, nil
|
return f, nil
|
||||||
default:
|
default:
|
||||||
@ -175,6 +197,23 @@ func (f *Flow) exchange(code string) (*Token, error) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewTransportFromCache reads the token from the cache and returns
|
||||||
|
// a Transport that is authorized and the authenticated
|
||||||
|
// by the returned token.
|
||||||
|
func (f *Flow) NewTransportFromCache() (*Transport, error) {
|
||||||
|
if f.opts.Cache == nil {
|
||||||
|
return nil, errors.New("oauth2: no cache is set")
|
||||||
|
}
|
||||||
|
tok, err := f.opts.Cache.Read()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if tok == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return f.newTransportFromToken(tok), nil
|
||||||
|
}
|
||||||
|
|
||||||
// NewTransportFromCode exchanges the code to retrieve a new access token
|
// NewTransportFromCode exchanges the code to retrieve a new access token
|
||||||
// and returns an authorized and authenticated Transport.
|
// and returns an authorized and authenticated Transport.
|
||||||
func (f *Flow) NewTransportFromCode(code string) (*Transport, error) {
|
func (f *Flow) NewTransportFromCode(code string) (*Transport, error) {
|
||||||
@ -182,22 +221,22 @@ func (f *Flow) NewTransportFromCode(code string) (*Transport, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return f.NewTransportFromToken(token), nil
|
return f.newTransportFromToken(token), nil
|
||||||
}
|
|
||||||
|
|
||||||
// NewTransportFromToken returns a new Transport that is authorized
|
|
||||||
// and authenticated with the provided token.
|
|
||||||
func (f *Flow) NewTransportFromToken(t *Token) *Transport {
|
|
||||||
tr := f.opts.Transport
|
|
||||||
if tr == nil {
|
|
||||||
tr = http.DefaultTransport
|
|
||||||
}
|
|
||||||
return newTransport(tr, f.opts.TokenFetcherFunc, t)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTransport returns a Transport.
|
// NewTransport returns a Transport.
|
||||||
func (f *Flow) NewTransport() *Transport {
|
func (f *Flow) NewTransport() *Transport {
|
||||||
return f.NewTransportFromToken(nil)
|
return f.newTransportFromToken(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// newTransportFromToken returns a new Transport that is authorized
|
||||||
|
// and authenticated with the provided token.
|
||||||
|
func (f *Flow) newTransportFromToken(t *Token) *Transport {
|
||||||
|
tr := f.opts.Transport
|
||||||
|
if tr == nil {
|
||||||
|
tr = http.DefaultTransport
|
||||||
|
}
|
||||||
|
return newTransport(tr, &f.opts, t)
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeThreeLeggedFetcher(o *Options) func(t *Token) (*Token, error) {
|
func makeThreeLeggedFetcher(o *Options) func(t *Token) (*Token, error) {
|
||||||
@ -255,6 +294,8 @@ type Options struct {
|
|||||||
// AUD represents the token endpoint required to complete the 2-legged JWT flow.
|
// AUD represents the token endpoint required to complete the 2-legged JWT flow.
|
||||||
AUD *url.URL
|
AUD *url.URL
|
||||||
|
|
||||||
|
Cache Cacher
|
||||||
|
|
||||||
TokenFetcherFunc func(t *Token) (*Token, error)
|
TokenFetcherFunc func(t *Token) (*Token, error)
|
||||||
|
|
||||||
Transport http.RoundTripper
|
Transport http.RoundTripper
|
||||||
|
@ -20,6 +20,19 @@ func (t *mockTransport) RoundTrip(req *http.Request) (resp *http.Response, err e
|
|||||||
return t.rt(req)
|
return t.rt(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mockCache struct {
|
||||||
|
token *Token
|
||||||
|
readErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *mockCache) Read() (*Token, error) {
|
||||||
|
return c.token, c.readErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *mockCache) Write(*Token) {
|
||||||
|
// do nothing
|
||||||
|
}
|
||||||
|
|
||||||
func newTestFlow(url string) *Flow {
|
func newTestFlow(url string) *Flow {
|
||||||
f, _ := New(
|
f, _ := New(
|
||||||
Client("CLIENT_ID", "CLIENT_SECRET"),
|
Client("CLIENT_ID", "CLIENT_SECRET"),
|
||||||
@ -211,7 +224,8 @@ func TestTokenRefreshRequest(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
f := newTestFlow(ts.URL)
|
f := newTestFlow(ts.URL)
|
||||||
tr := f.NewTransportFromToken(&Token{RefreshToken: "REFRESH_TOKEN"})
|
tr := f.NewTransport()
|
||||||
|
tr.SetToken(&Token{RefreshToken: "REFRESH_TOKEN"})
|
||||||
c := http.Client{Transport: tr}
|
c := http.Client{Transport: tr}
|
||||||
c.Get(ts.URL + "/somethingelse")
|
c.Get(ts.URL + "/somethingelse")
|
||||||
}
|
}
|
||||||
@ -235,10 +249,25 @@ func TestFetchWithNoRefreshToken(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
f := newTestFlow(ts.URL)
|
f := newTestFlow(ts.URL)
|
||||||
tr := f.NewTransportFromToken(&Token{})
|
tr := f.NewTransport()
|
||||||
c := http.Client{Transport: tr}
|
c := http.Client{Transport: tr}
|
||||||
_, err := c.Get(ts.URL + "/somethingelse")
|
_, err := c.Get(ts.URL + "/somethingelse")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("Fetch should return an error if no refresh token is set")
|
t.Errorf("Fetch should return an error if no refresh token is set")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCacheNoToken(t *testing.T) {
|
||||||
|
f, _ := New(
|
||||||
|
Client("CLIENT_ID", "CLIENT_SECRET"),
|
||||||
|
Endpoint("/auth", "/token"),
|
||||||
|
Cache(&mockCache{token: nil, readErr: nil}),
|
||||||
|
)
|
||||||
|
tr, err := f.NewTransportFromCache()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("No error expected, %v is found", err)
|
||||||
|
}
|
||||||
|
if tr != nil {
|
||||||
|
t.Errorf("No transport should have been initiated, tr is found to be %v", tr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
17
transport.go
17
transport.go
@ -66,8 +66,8 @@ func (t *Token) Expired() bool {
|
|||||||
|
|
||||||
// Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests.
|
// Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests.
|
||||||
type Transport struct {
|
type Transport struct {
|
||||||
fetcher func(t *Token) (*Token, error)
|
opts *Options
|
||||||
base http.RoundTripper
|
base http.RoundTripper
|
||||||
|
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
token *Token
|
token *Token
|
||||||
@ -76,8 +76,12 @@ type Transport struct {
|
|||||||
// NewTransport creates a new Transport that uses the provided
|
// NewTransport creates a new Transport that uses the provided
|
||||||
// token fetcher as token retrieving strategy. It authenticates
|
// token fetcher as token retrieving strategy. It authenticates
|
||||||
// the requests and delegates origTransport to make the actual requests.
|
// the requests and delegates origTransport to make the actual requests.
|
||||||
func newTransport(base http.RoundTripper, fn func(t *Token) (*Token, error), token *Token) *Transport {
|
func newTransport(base http.RoundTripper, opts *Options, token *Token) *Transport {
|
||||||
return &Transport{base: base, fetcher: fn, token: token}
|
return &Transport{
|
||||||
|
base: base,
|
||||||
|
opts: opts,
|
||||||
|
token: token,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoundTrip authorizes and authenticates the request with an
|
// RoundTrip authorizes and authenticates the request with an
|
||||||
@ -94,6 +98,9 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
token = t.Token()
|
token = t.Token()
|
||||||
|
if t.opts.Cache != nil {
|
||||||
|
t.opts.Cache.Write(token)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// To set the Authorization header, we must make a copy of the Request
|
// To set the Authorization header, we must make a copy of the Request
|
||||||
@ -129,7 +136,7 @@ func (t *Transport) SetToken(v *Token) {
|
|||||||
func (t *Transport) RefreshToken() error {
|
func (t *Transport) RefreshToken() error {
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
defer t.mu.Unlock()
|
defer t.mu.Unlock()
|
||||||
token, err := t.fetcher(t.token)
|
token, err := t.opts.TokenFetcherFunc(t.token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -15,10 +15,6 @@ func (f *mockTokenFetcher) Fn() func(*Token) (*Token, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *mockTokenFetcher) FetchToken(existing *Token) (*Token, error) {
|
|
||||||
return f.token, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestInitialTokenRead(t *testing.T) {
|
func TestInitialTokenRead(t *testing.T) {
|
||||||
tr := newTransport(http.DefaultTransport, nil, &Token{AccessToken: "abc"})
|
tr := newTransport(http.DefaultTransport, nil, &Token{AccessToken: "abc"})
|
||||||
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {
|
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@ -37,7 +33,7 @@ func TestTokenFetch(t *testing.T) {
|
|||||||
AccessToken: "abc",
|
AccessToken: "abc",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
tr := newTransport(http.DefaultTransport, fetcher.Fn(), nil)
|
tr := newTransport(http.DefaultTransport, &Options{TokenFetcherFunc: fetcher.Fn()}, nil)
|
||||||
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {
|
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Header.Get("Authorization") != "Bearer abc" {
|
if r.Header.Get("Authorization") != "Bearer abc" {
|
||||||
t.Errorf("Transport doesn't set the Authorization header from the fetched token")
|
t.Errorf("Transport doesn't set the Authorization header from the fetched token")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user