oauth2/transport.go
Aaron Torres a8c019d04a oauth2: add support for client credential grant type
Creates a new package called clientcredentials and
adds transport and token information to the internal
package. Also modifies the oauth2 package to make
use of the newly added files in the internal package.

The clientcredentials package allows for token requests
using a "client credentials" grant type.

Fixes https://github.com/golang/oauth2/issues/7

Change-Id: Iec649d1029870c27a2d1023baa9d52db42ff45e8
Reviewed-on: https://go-review.googlesource.com/2983
Reviewed-by: Burcu Dogan <jbd@google.com>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
2015-04-18 00:13:27 +00:00

133 lines
3.0 KiB
Go

// Copyright 2014 The oauth2 Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package oauth2
import (
"errors"
"io"
"net/http"
"sync"
)
// Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests,
// wrapping a base RoundTripper and adding an Authorization header
// with a token from the supplied Source.
//
// Transport is a low-level mechanism. Most code will use the
// higher-level Config.Client method instead.
//
// A Transport contains a mutex and therefore must not be copied
// after first use; always pass it around as *Transport.
type Transport struct {
// Source supplies the token to add to outgoing requests'
// Authorization headers. RoundTrip fails with an error when
// Source is nil.
Source TokenSource
// Base is the base RoundTripper used to make HTTP requests.
// If nil, http.DefaultTransport is used.
Base http.RoundTripper
// mu serializes every access to modReq.
mu sync.Mutex // guards modReq
// modReq remembers, for each request handed to RoundTrip, the
// authorized clone that was actually sent to the base RoundTripper.
// CancelRequest uses it to translate the caller's request into the
// one the base knows about. The map is created lazily by setModReq,
// and entries are removed when the round trip fails, when the
// response body reaches EOF or is closed, or when the request is
// canceled.
modReq map[*http.Request]*http.Request // original -> modified
}
// RoundTrip authorizes and authenticates the request with an
// access token. If no token exists or token is expired,
// tries to refresh/fetch a new token.
//
// The caller's request is never modified: the Authorization header is
// set on a clone, and that clone is what the base RoundTripper receives.
// The original-to-clone mapping is recorded so CancelRequest can cancel
// the in-flight clone, and is dropped again once the round trip fails
// or the response body hits EOF or is closed.
//
// It returns an error when Source is nil, when the Source cannot
// produce a token, or when the base RoundTripper fails. In the first two
// cases the request body, if any, is closed before returning, as the
// http.RoundTripper contract requires.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
	// http.RoundTripper implementations must always close the request
	// body, including on errors. Until the request is handed to the base
	// RoundTripper that duty is ours; afterwards it belongs to the base.
	bodyHandedOff := false
	if req.Body != nil {
		defer func() {
			if !bodyHandedOff {
				// Best effort: the caller already receives the
				// primary error, so a Close failure is not reported.
				_ = req.Body.Close()
			}
		}()
	}

	if t.Source == nil {
		return nil, errors.New("oauth2: Transport's Source is nil")
	}
	token, err := t.Source.Token()
	if err != nil {
		return nil, err
	}

	req2 := cloneRequest(req) // per RoundTripper contract
	token.SetAuthHeader(req2)
	t.setModReq(req, req2)

	bodyHandedOff = true
	res, err := t.base().RoundTrip(req2)
	if err != nil {
		t.setModReq(req, nil)
		return nil, err
	}

	// Forget the original->clone mapping once the caller is done with
	// the response body (EOF or Close, whichever comes first).
	res.Body = &onEOFReader{
		rc: res.Body,
		fn: func() { t.setModReq(req, nil) },
	}
	return res, nil
}
// CancelRequest cancels an in-flight request by closing its connection.
//
// req must be the request that was originally passed to RoundTrip; it is
// translated to the authorized clone that was actually sent before being
// forwarded to the base RoundTripper. The call is a no-op when the base
// RoundTripper has no CancelRequest method, or when req is not currently
// tracked (it already completed, or was never sent through t).
func (t *Transport) CancelRequest(req *http.Request) {
	type canceler interface {
		CancelRequest(*http.Request)
	}
	cr, ok := t.base().(canceler)
	if !ok {
		return
	}

	t.mu.Lock()
	modReq := t.modReq[req] // reading a nil map is fine and yields nil
	delete(t.modReq, req)
	t.mu.Unlock()

	// Never hand a nil request to the base: an untracked request has
	// nothing in flight to cancel, and a custom canceler may not
	// tolerate nil.
	if modReq == nil {
		return
	}
	cr.CancelRequest(modReq)
}
// base reports the RoundTripper that actually performs requests:
// t.Base when one is configured, http.DefaultTransport otherwise.
func (t *Transport) base() http.RoundTripper {
	if t.Base == nil {
		return http.DefaultTransport
	}
	return t.Base
}
// setModReq records mod as the request sent on behalf of orig.
// Passing a nil mod removes any entry for orig instead. The tracking
// map is allocated on first use, and all access happens under t.mu.
func (t *Transport) setModReq(orig, mod *http.Request) {
	t.mu.Lock()
	defer t.mu.Unlock()

	if t.modReq == nil {
		t.modReq = make(map[*http.Request]*http.Request)
	}
	if mod != nil {
		t.modReq[orig] = mod
		return
	}
	delete(t.modReq, orig)
}
// cloneRequest returns a copy of r that is safe to add headers to.
// Every struct field is copied by plain assignment, so pointer-like
// fields such as URL and Body are shared with r. The Header map is the
// exception: the clone gets its own map and its own value slices, so
// header changes on the clone never show through on r. The clone's
// Header is always non-nil, even when r.Header is nil.
func cloneRequest(r *http.Request) *http.Request {
	clone := new(http.Request)
	*clone = *r

	hdr := make(http.Header, len(r.Header))
	for name, values := range r.Header {
		var dup []string
		dup = append(dup, values...)
		hdr[name] = dup
	}
	clone.Header = hdr
	return clone
}
// onEOFReader wraps an io.ReadCloser and invokes fn at most once:
// the first time a Read reports io.EOF or Close is called, whichever
// happens first. It is not safe for concurrent use.
type onEOFReader struct {
	rc io.ReadCloser // underlying stream
	fn func()        // pending callback; nil once it has run
}

// Read forwards to the wrapped reader and fires the callback when the
// underlying Read reports io.EOF. The byte count and error are passed
// through unchanged.
func (r *onEOFReader) Read(p []byte) (int, error) {
	count, readErr := r.rc.Read(p)
	if readErr == io.EOF {
		r.runFunc()
	}
	return count, readErr
}

// Close closes the wrapped reader, then fires the callback regardless
// of whether closing succeeded, and returns the wrapped reader's error.
func (r *onEOFReader) Close() error {
	closeErr := r.rc.Close()
	r.runFunc()
	return closeErr
}

// runFunc calls the callback if it is still pending and then clears it
// so later calls do nothing.
func (r *onEOFReader) runFunc() {
	if r.fn == nil {
		return
	}
	r.fn()
	r.fn = nil
}