mirror of
https://github.com/golang/oauth2.git
synced 2025-07-21 00:00:09 +08:00
Compare commits
20 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
cf14319341 | ||
|
32d34ef364 | ||
|
2d34e3091b | ||
|
696f7b3128 | ||
|
471209bbe2 | ||
|
6968da209b | ||
|
d2c4e0a625 | ||
|
883dc3c9d8 | ||
|
1c06e8705e | ||
|
65c15a3514 | ||
|
ce56909505 | ||
|
0042180b24 | ||
|
ce350bff61 | ||
|
44967abe90 | ||
|
9c82a8cf7a | ||
|
681b4d8edc | ||
|
3f78298bee | ||
|
109dabf901 | ||
|
ac571fa341 | ||
|
314ee5b92b |
@ -34,7 +34,7 @@ type PKCEParams struct {
|
|||||||
// and returns an auth code and state upon approval.
|
// and returns an auth code and state upon approval.
|
||||||
type AuthorizationHandler func(authCodeURL string) (code string, state string, err error)
|
type AuthorizationHandler func(authCodeURL string) (code string, state string, err error)
|
||||||
|
|
||||||
// TokenSourceWithPKCE is an enhanced version of TokenSource with PKCE support.
|
// TokenSourceWithPKCE is an enhanced version of [oauth2.TokenSource] with PKCE support.
|
||||||
//
|
//
|
||||||
// The pkce parameter supports PKCE flow, which uses code challenge and code verifier
|
// The pkce parameter supports PKCE flow, which uses code challenge and code verifier
|
||||||
// to prevent CSRF attacks. A unique code challenge and code verifier should be generated
|
// to prevent CSRF attacks. A unique code challenge and code verifier should be generated
|
||||||
@ -43,12 +43,12 @@ func TokenSourceWithPKCE(ctx context.Context, config *oauth2.Config, state strin
|
|||||||
return oauth2.ReuseTokenSource(nil, authHandlerSource{config: config, ctx: ctx, authHandler: authHandler, state: state, pkce: pkce})
|
return oauth2.ReuseTokenSource(nil, authHandlerSource{config: config, ctx: ctx, authHandler: authHandler, state: state, pkce: pkce})
|
||||||
}
|
}
|
||||||
|
|
||||||
// TokenSource returns an oauth2.TokenSource that fetches access tokens
|
// TokenSource returns an [oauth2.TokenSource] that fetches access tokens
|
||||||
// using 3-legged-OAuth flow.
|
// using 3-legged-OAuth flow.
|
||||||
//
|
//
|
||||||
// The provided context.Context is used for oauth2 Exchange operation.
|
// The provided [context.Context] is used for oauth2 Exchange operation.
|
||||||
//
|
//
|
||||||
// The provided oauth2.Config should be a full configuration containing AuthURL,
|
// The provided [oauth2.Config] should be a full configuration containing AuthURL,
|
||||||
// TokenURL, and Scope.
|
// TokenURL, and Scope.
|
||||||
//
|
//
|
||||||
// An environment-specific AuthorizationHandler is used to obtain user consent.
|
// An environment-specific AuthorizationHandler is used to obtain user consent.
|
||||||
|
@ -55,7 +55,7 @@ type Config struct {
|
|||||||
|
|
||||||
// Token uses client credentials to retrieve a token.
|
// Token uses client credentials to retrieve a token.
|
||||||
//
|
//
|
||||||
// The provided context optionally controls which HTTP client is used. See the oauth2.HTTPClient variable.
|
// The provided context optionally controls which HTTP client is used. See the [oauth2.HTTPClient] variable.
|
||||||
func (c *Config) Token(ctx context.Context) (*oauth2.Token, error) {
|
func (c *Config) Token(ctx context.Context) (*oauth2.Token, error) {
|
||||||
return c.TokenSource(ctx).Token()
|
return c.TokenSource(ctx).Token()
|
||||||
}
|
}
|
||||||
@ -64,18 +64,18 @@ func (c *Config) Token(ctx context.Context) (*oauth2.Token, error) {
|
|||||||
// The token will auto-refresh as necessary.
|
// The token will auto-refresh as necessary.
|
||||||
//
|
//
|
||||||
// The provided context optionally controls which HTTP client
|
// The provided context optionally controls which HTTP client
|
||||||
// is returned. See the oauth2.HTTPClient variable.
|
// is returned. See the [oauth2.HTTPClient] variable.
|
||||||
//
|
//
|
||||||
// The returned Client and its Transport should not be modified.
|
// The returned [http.Client] and its Transport should not be modified.
|
||||||
func (c *Config) Client(ctx context.Context) *http.Client {
|
func (c *Config) Client(ctx context.Context) *http.Client {
|
||||||
return oauth2.NewClient(ctx, c.TokenSource(ctx))
|
return oauth2.NewClient(ctx, c.TokenSource(ctx))
|
||||||
}
|
}
|
||||||
|
|
||||||
// TokenSource returns a TokenSource that returns t until t expires,
|
// TokenSource returns a [oauth2.TokenSource] that returns t until t expires,
|
||||||
// automatically refreshing it as necessary using the provided context and the
|
// automatically refreshing it as necessary using the provided context and the
|
||||||
// client ID and client secret.
|
// client ID and client secret.
|
||||||
//
|
//
|
||||||
// Most users will use Config.Client instead.
|
// Most users will use [Config.Client] instead.
|
||||||
func (c *Config) TokenSource(ctx context.Context) oauth2.TokenSource {
|
func (c *Config) TokenSource(ctx context.Context) oauth2.TokenSource {
|
||||||
source := &tokenSource{
|
source := &tokenSource{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
@ -7,7 +7,6 @@ package clientcredentials
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
@ -36,9 +35,9 @@ func TestTokenSourceGrantTypeOverride(t *testing.T) {
|
|||||||
wantGrantType := "password"
|
wantGrantType := "password"
|
||||||
var gotGrantType string
|
var gotGrantType string
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
body, err := ioutil.ReadAll(r.Body)
|
body, err := io.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("ioutil.ReadAll(r.Body) == %v, %v, want _, <nil>", body, err)
|
t.Errorf("io.ReadAll(r.Body) == %v, %v, want _, <nil>", body, err)
|
||||||
}
|
}
|
||||||
if err := r.Body.Close(); err != nil {
|
if err := r.Body.Close(); err != nil {
|
||||||
t.Errorf("r.Body.Close() == %v, want <nil>", err)
|
t.Errorf("r.Body.Close() == %v, want <nil>", err)
|
||||||
@ -81,7 +80,7 @@ func TestTokenRequest(t *testing.T) {
|
|||||||
if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want {
|
if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want {
|
||||||
t.Errorf("Content-Type header = %q; want %q", got, want)
|
t.Errorf("Content-Type header = %q; want %q", got, want)
|
||||||
}
|
}
|
||||||
body, err := ioutil.ReadAll(r.Body)
|
body, err := io.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.Body.Close()
|
r.Body.Close()
|
||||||
}
|
}
|
||||||
@ -123,7 +122,7 @@ func TestTokenRefreshRequest(t *testing.T) {
|
|||||||
if got, want := headerContentType, "application/x-www-form-urlencoded"; got != want {
|
if got, want := headerContentType, "application/x-www-form-urlencoded"; got != want {
|
||||||
t.Errorf("Content-Type = %q; want %q", got, want)
|
t.Errorf("Content-Type = %q; want %q", got, want)
|
||||||
}
|
}
|
||||||
body, _ := ioutil.ReadAll(r.Body)
|
body, _ := io.ReadAll(r.Body)
|
||||||
const want = "audience=audience1&grant_type=client_credentials&scope=scope1+scope2"
|
const want = "audience=audience1&grant_type=client_credentials&scope=scope1+scope2"
|
||||||
if string(body) != want {
|
if string(body) != want {
|
||||||
t.Errorf("Unexpected refresh token payload.\n got: %s\nwant: %s\n", body, want)
|
t.Errorf("Unexpected refresh token payload.\n got: %s\nwant: %s\n", body, want)
|
||||||
|
@ -7,9 +7,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
|
||||||
"github.com/google/go-cmp/cmp/cmpopts"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDeviceAuthResponseMarshalJson(t *testing.T) {
|
func TestDeviceAuthResponseMarshalJson(t *testing.T) {
|
||||||
@ -74,7 +71,16 @@ func TestDeviceAuthResponseUnmarshalJson(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if !cmp.Equal(got, tc.want, cmpopts.IgnoreUnexported(DeviceAuthResponse{}), cmpopts.EquateApproxTime(time.Second+time.Since(begin))) {
|
margin := time.Second + time.Since(begin)
|
||||||
|
timeDiff := got.Expiry.Sub(tc.want.Expiry)
|
||||||
|
if timeDiff < 0 {
|
||||||
|
timeDiff *= -1
|
||||||
|
}
|
||||||
|
if timeDiff > margin {
|
||||||
|
t.Errorf("expiry time difference too large, got=%v, want=%v margin=%v", got.Expiry, tc.want.Expiry, margin)
|
||||||
|
}
|
||||||
|
got.Expiry, tc.want.Expiry = time.Time{}, time.Time{}
|
||||||
|
if got != tc.want {
|
||||||
t.Errorf("want=%#v, got=%#v", tc.want, got)
|
t.Errorf("want=%#v, got=%#v", tc.want, got)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -6,7 +6,7 @@
|
|||||||
package endpoints
|
package endpoints
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
"net/url"
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
@ -17,6 +17,30 @@ var Amazon = oauth2.Endpoint{
|
|||||||
TokenURL: "https://api.amazon.com/auth/o2/token",
|
TokenURL: "https://api.amazon.com/auth/o2/token",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Apple is the endpoint for "Sign in with Apple".
|
||||||
|
//
|
||||||
|
// Documentation: https://developer.apple.com/documentation/signinwithapplerestapi
|
||||||
|
var Apple = oauth2.Endpoint{
|
||||||
|
AuthURL: "https://appleid.apple.com/auth/authorize",
|
||||||
|
TokenURL: "https://appleid.apple.com/auth/token",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Asana is the endpoint for Asana.
|
||||||
|
//
|
||||||
|
// Documentation: https://developers.asana.com/docs/oauth
|
||||||
|
var Asana = oauth2.Endpoint{
|
||||||
|
AuthURL: "https://app.asana.com/-/oauth_authorize",
|
||||||
|
TokenURL: "https://app.asana.com/-/oauth_token",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Badgr is the endpoint for Canvas Badges.
|
||||||
|
//
|
||||||
|
// Documentation: https://community.canvaslms.com/t5/Canvas-Badges-Credentials/Developers-Build-an-app-that-integrates-with-the-Canvas-Badges/ta-p/528727
|
||||||
|
var Badgr = oauth2.Endpoint{
|
||||||
|
AuthURL: "https://badgr.com/auth/oauth2/authorize",
|
||||||
|
TokenURL: "https://api.badgr.io/o/token",
|
||||||
|
}
|
||||||
|
|
||||||
// Battlenet is the endpoint for Battlenet.
|
// Battlenet is the endpoint for Battlenet.
|
||||||
var Battlenet = oauth2.Endpoint{
|
var Battlenet = oauth2.Endpoint{
|
||||||
AuthURL: "https://battle.net/oauth/authorize",
|
AuthURL: "https://battle.net/oauth/authorize",
|
||||||
@ -35,10 +59,44 @@ var Cern = oauth2.Endpoint{
|
|||||||
TokenURL: "https://oauth.web.cern.ch/OAuth/Token",
|
TokenURL: "https://oauth.web.cern.ch/OAuth/Token",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Coinbase is the endpoint for Coinbase.
|
||||||
|
//
|
||||||
|
// Documentation: https://docs.cdp.coinbase.com/coinbase-app/docs/coinbase-app-reference
|
||||||
|
var Coinbase = oauth2.Endpoint{
|
||||||
|
AuthURL: "https://login.coinbase.com/oauth2/auth",
|
||||||
|
TokenURL: "https://login.coinbase.com/oauth2/token",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Discord is the endpoint for Discord.
|
||||||
|
//
|
||||||
|
// Documentation: https://discord.com/developers/docs/topics/oauth2#shared-resources-oauth2-urls
|
||||||
|
var Discord = oauth2.Endpoint{
|
||||||
|
AuthURL: "https://discord.com/oauth2/authorize",
|
||||||
|
TokenURL: "https://discord.com/api/oauth2/token",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dropbox is the endpoint for Dropbox.
|
||||||
|
//
|
||||||
|
// Documentation: https://developers.dropbox.com/oauth-guide
|
||||||
|
var Dropbox = oauth2.Endpoint{
|
||||||
|
AuthURL: "https://www.dropbox.com/oauth2/authorize",
|
||||||
|
TokenURL: "https://api.dropboxapi.com/oauth2/token",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Endpoint is Ebay's OAuth 2.0 endpoint.
|
||||||
|
//
|
||||||
|
// Documentation: https://developer.ebay.com/api-docs/static/authorization_guide_landing.html
|
||||||
|
var Endpoint = oauth2.Endpoint{
|
||||||
|
AuthURL: "https://auth.ebay.com/oauth2/authorize",
|
||||||
|
TokenURL: "https://api.ebay.com/identity/v1/oauth2/token",
|
||||||
|
}
|
||||||
|
|
||||||
// Facebook is the endpoint for Facebook.
|
// Facebook is the endpoint for Facebook.
|
||||||
|
//
|
||||||
|
// Documentation: https://developers.facebook.com/docs/facebook-login/guides/advanced/manual-flow
|
||||||
var Facebook = oauth2.Endpoint{
|
var Facebook = oauth2.Endpoint{
|
||||||
AuthURL: "https://www.facebook.com/v3.2/dialog/oauth",
|
AuthURL: "https://www.facebook.com/v22.0/dialog/oauth",
|
||||||
TokenURL: "https://graph.facebook.com/v3.2/oauth/access_token",
|
TokenURL: "https://graph.facebook.com/v22.0/oauth/access_token",
|
||||||
}
|
}
|
||||||
|
|
||||||
// Foursquare is the endpoint for Foursquare.
|
// Foursquare is the endpoint for Foursquare.
|
||||||
@ -98,6 +156,14 @@ var KaKao = oauth2.Endpoint{
|
|||||||
TokenURL: "https://kauth.kakao.com/oauth/token",
|
TokenURL: "https://kauth.kakao.com/oauth/token",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Line is the endpoint for Line.
|
||||||
|
//
|
||||||
|
// Documentation: https://developers.line.biz/en/docs/line-login/integrate-line-login/
|
||||||
|
var Line = oauth2.Endpoint{
|
||||||
|
AuthURL: "https://access.line.me/oauth2/v2.1/authorize",
|
||||||
|
TokenURL: "https://api.line.me/oauth2/v2.1/token",
|
||||||
|
}
|
||||||
|
|
||||||
// LinkedIn is the endpoint for LinkedIn.
|
// LinkedIn is the endpoint for LinkedIn.
|
||||||
var LinkedIn = oauth2.Endpoint{
|
var LinkedIn = oauth2.Endpoint{
|
||||||
AuthURL: "https://www.linkedin.com/oauth/v2/authorization",
|
AuthURL: "https://www.linkedin.com/oauth/v2/authorization",
|
||||||
@ -134,7 +200,17 @@ var Microsoft = oauth2.Endpoint{
|
|||||||
TokenURL: "https://login.live.com/oauth20_token.srf",
|
TokenURL: "https://login.live.com/oauth20_token.srf",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Naver is the endpoint for Naver.
|
||||||
|
//
|
||||||
|
// Documentation: https://developers.naver.com/docs/login/devguide/devguide.md
|
||||||
|
var Naver = oauth2.Endpoint{
|
||||||
|
AuthURL: "https://nid.naver.com/oauth2/authorize",
|
||||||
|
TokenURL: "https://nid.naver.com/oauth2/token",
|
||||||
|
}
|
||||||
|
|
||||||
// NokiaHealth is the endpoint for Nokia Health.
|
// NokiaHealth is the endpoint for Nokia Health.
|
||||||
|
//
|
||||||
|
// Deprecated: Nokia Health is now Withings.
|
||||||
var NokiaHealth = oauth2.Endpoint{
|
var NokiaHealth = oauth2.Endpoint{
|
||||||
AuthURL: "https://account.health.nokia.com/oauth2_user/authorize2",
|
AuthURL: "https://account.health.nokia.com/oauth2_user/authorize2",
|
||||||
TokenURL: "https://account.health.nokia.com/oauth2/token",
|
TokenURL: "https://account.health.nokia.com/oauth2/token",
|
||||||
@ -146,6 +222,20 @@ var Odnoklassniki = oauth2.Endpoint{
|
|||||||
TokenURL: "https://api.odnoklassniki.ru/oauth/token.do",
|
TokenURL: "https://api.odnoklassniki.ru/oauth/token.do",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OpenStreetMap is the endpoint for OpenStreetMap.org.
|
||||||
|
//
|
||||||
|
// Documentation: https://wiki.openstreetmap.org/wiki/OAuth
|
||||||
|
var OpenStreetMap = oauth2.Endpoint{
|
||||||
|
AuthURL: "https://www.openstreetmap.org/oauth2/authorize",
|
||||||
|
TokenURL: "https://www.openstreetmap.org/oauth2/token",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Patreon is the endpoint for Patreon.
|
||||||
|
var Patreon = oauth2.Endpoint{
|
||||||
|
AuthURL: "https://www.patreon.com/oauth2/authorize",
|
||||||
|
TokenURL: "https://www.patreon.com/api/oauth2/token",
|
||||||
|
}
|
||||||
|
|
||||||
// PayPal is the endpoint for PayPal.
|
// PayPal is the endpoint for PayPal.
|
||||||
var PayPal = oauth2.Endpoint{
|
var PayPal = oauth2.Endpoint{
|
||||||
AuthURL: "https://www.paypal.com/webapps/auth/protocol/openidconnect/v1/authorize",
|
AuthURL: "https://www.paypal.com/webapps/auth/protocol/openidconnect/v1/authorize",
|
||||||
@ -158,10 +248,52 @@ var PayPalSandbox = oauth2.Endpoint{
|
|||||||
TokenURL: "https://api.sandbox.paypal.com/v1/identity/openidconnect/tokenservice",
|
TokenURL: "https://api.sandbox.paypal.com/v1/identity/openidconnect/tokenservice",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Pinterest is the endpoint for Pinterest.
|
||||||
|
//
|
||||||
|
// Documentation: https://developers.pinterest.com/docs/getting-started/set-up-authentication-and-authorization/
|
||||||
|
var Pinterest = oauth2.Endpoint{
|
||||||
|
AuthURL: "https://www.pinterest.com/oauth",
|
||||||
|
TokenURL: "https://api.pinterest.com/v5/oauth/token",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pipedrive is the endpoint for Pipedrive.
|
||||||
|
//
|
||||||
|
// Documentation: https://developers.pipedrive.com/docs/api/v1/Oauth
|
||||||
|
var Pipedrive = oauth2.Endpoint{
|
||||||
|
AuthURL: "https://oauth.pipedrive.com/oauth/authorize",
|
||||||
|
TokenURL: "https://oauth.pipedrive.com/oauth/token",
|
||||||
|
}
|
||||||
|
|
||||||
|
// QQ is the endpoint for QQ.
|
||||||
|
//
|
||||||
|
// Documentation: https://wiki.connect.qq.com/%e5%bc%80%e5%8f%91%e6%94%bb%e7%95%a5_server-side
|
||||||
|
var QQ = oauth2.Endpoint{
|
||||||
|
AuthURL: "https://graph.qq.com/oauth2.0/authorize",
|
||||||
|
TokenURL: "https://graph.qq.com/oauth2.0/token",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rakuten is the endpoint for Rakuten.
|
||||||
|
//
|
||||||
|
// Documentation: https://webservice.rakuten.co.jp/documentation
|
||||||
|
var Rakuten = oauth2.Endpoint{
|
||||||
|
AuthURL: "https://app.rakuten.co.jp/services/authorize",
|
||||||
|
TokenURL: "https://app.rakuten.co.jp/services/token",
|
||||||
|
}
|
||||||
|
|
||||||
// Slack is the endpoint for Slack.
|
// Slack is the endpoint for Slack.
|
||||||
|
//
|
||||||
|
// Documentation: https://api.slack.com/authentication/oauth-v2
|
||||||
var Slack = oauth2.Endpoint{
|
var Slack = oauth2.Endpoint{
|
||||||
AuthURL: "https://slack.com/oauth/authorize",
|
AuthURL: "https://slack.com/oauth/v2/authorize",
|
||||||
TokenURL: "https://slack.com/api/oauth.access",
|
TokenURL: "https://slack.com/api/oauth.v2.access",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Splitwise is the endpoint for Splitwise.
|
||||||
|
//
|
||||||
|
// Documentation: https://dev.splitwise.com/
|
||||||
|
var Splitwise = oauth2.Endpoint{
|
||||||
|
AuthURL: "https://www.splitwise.com/oauth/authorize",
|
||||||
|
TokenURL: "https://www.splitwise.com/oauth/token",
|
||||||
}
|
}
|
||||||
|
|
||||||
// Spotify is the endpoint for Spotify.
|
// Spotify is the endpoint for Spotify.
|
||||||
@ -200,6 +332,22 @@ var Vk = oauth2.Endpoint{
|
|||||||
TokenURL: "https://oauth.vk.com/access_token",
|
TokenURL: "https://oauth.vk.com/access_token",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Withings is the endpoint for Withings.
|
||||||
|
//
|
||||||
|
// Documentation: https://account.withings.com/oauth2_user/authorize2
|
||||||
|
var Withings = oauth2.Endpoint{
|
||||||
|
AuthURL: "https://account.withings.com/oauth2_user/authorize2",
|
||||||
|
TokenURL: "https://account.withings.com/oauth2/token",
|
||||||
|
}
|
||||||
|
|
||||||
|
// X is the endpoint for X (Twitter).
|
||||||
|
//
|
||||||
|
// Documentation: https://docs.x.com/resources/fundamentals/authentication/oauth-2-0/user-access-token
|
||||||
|
var X = oauth2.Endpoint{
|
||||||
|
AuthURL: "https://x.com/i/oauth2/authorize",
|
||||||
|
TokenURL: "https://api.x.com/2/oauth2/token",
|
||||||
|
}
|
||||||
|
|
||||||
// Yahoo is the endpoint for Yahoo.
|
// Yahoo is the endpoint for Yahoo.
|
||||||
var Yahoo = oauth2.Endpoint{
|
var Yahoo = oauth2.Endpoint{
|
||||||
AuthURL: "https://api.login.yahoo.com/oauth2/request_auth",
|
AuthURL: "https://api.login.yahoo.com/oauth2/request_auth",
|
||||||
@ -218,6 +366,20 @@ var Zoom = oauth2.Endpoint{
|
|||||||
TokenURL: "https://zoom.us/oauth/token",
|
TokenURL: "https://zoom.us/oauth/token",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Asgardeo returns a new oauth2.Endpoint for the given tenant.
|
||||||
|
//
|
||||||
|
// Documentation: https://wso2.com/asgardeo/docs/guides/authentication/oidc/discover-oidc-configs/
|
||||||
|
func AsgardeoEndpoint(tenant string) oauth2.Endpoint {
|
||||||
|
u := url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "api.asgardeo.io",
|
||||||
|
}
|
||||||
|
return oauth2.Endpoint{
|
||||||
|
AuthURL: u.JoinPath("t", tenant, "/oauth2/authorize").String(),
|
||||||
|
TokenURL: u.JoinPath("t", tenant, "/oauth2/token").String(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// AzureAD returns a new oauth2.Endpoint for the given tenant at Azure Active Directory.
|
// AzureAD returns a new oauth2.Endpoint for the given tenant at Azure Active Directory.
|
||||||
// If tenant is empty, it uses the tenant called `common`.
|
// If tenant is empty, it uses the tenant called `common`.
|
||||||
//
|
//
|
||||||
@ -227,19 +389,29 @@ func AzureAD(tenant string) oauth2.Endpoint {
|
|||||||
if tenant == "" {
|
if tenant == "" {
|
||||||
tenant = "common"
|
tenant = "common"
|
||||||
}
|
}
|
||||||
|
u := url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "login.microsoftonline.com",
|
||||||
|
}
|
||||||
return oauth2.Endpoint{
|
return oauth2.Endpoint{
|
||||||
AuthURL: "https://login.microsoftonline.com/" + tenant + "/oauth2/v2.0/authorize",
|
AuthURL: u.JoinPath(tenant, "/oauth2/v2.0/authorize").String(),
|
||||||
TokenURL: "https://login.microsoftonline.com/" + tenant + "/oauth2/v2.0/token",
|
TokenURL: u.JoinPath(tenant, "/oauth2/v2.0/token").String(),
|
||||||
DeviceAuthURL: "https://login.microsoftonline.com/" + tenant + "/oauth2/v2.0/devicecode",
|
DeviceAuthURL: u.JoinPath(tenant, "/oauth2/v2.0/devicecode").String(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// HipChatServer returns a new oauth2.Endpoint for a HipChat Server instance
|
// AzureADB2CEndpoint returns a new oauth2.Endpoint for the given tenant and policy at Azure Active Directory B2C.
|
||||||
// running on the given domain or host.
|
// policy is the Azure B2C User flow name Example: `B2C_1_SignUpSignIn`.
|
||||||
func HipChatServer(host string) oauth2.Endpoint {
|
//
|
||||||
|
// Documentation: https://docs.microsoft.com/en-us/azure/active-directory-b2c/tokens-overview#endpoints
|
||||||
|
func AzureADB2CEndpoint(tenant string, policy string) oauth2.Endpoint {
|
||||||
|
u := url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: tenant + ".b2clogin.com",
|
||||||
|
}
|
||||||
return oauth2.Endpoint{
|
return oauth2.Endpoint{
|
||||||
AuthURL: "https://" + host + "/users/authorize",
|
AuthURL: u.JoinPath(tenant+".onmicrosoft.com", policy, "/oauth2/v2.0/authorize").String(),
|
||||||
TokenURL: "https://" + host + "/v2/oauth/token",
|
TokenURL: u.JoinPath(tenant+".onmicrosoft.com", policy, "/oauth2/v2.0/token").String(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -252,9 +424,42 @@ func HipChatServer(host string) oauth2.Endpoint {
|
|||||||
// https://docs.aws.amazon.com/cognito/latest/developerguide/cognito-user-pools-assign-domain.html
|
// https://docs.aws.amazon.com/cognito/latest/developerguide/cognito-user-pools-assign-domain.html
|
||||||
// https://docs.aws.amazon.com/cognito/latest/developerguide/cognito-userpools-server-contract-reference.html
|
// https://docs.aws.amazon.com/cognito/latest/developerguide/cognito-userpools-server-contract-reference.html
|
||||||
func AWSCognito(domain string) oauth2.Endpoint {
|
func AWSCognito(domain string) oauth2.Endpoint {
|
||||||
domain = strings.TrimRight(domain, "/")
|
u, err := url.Parse(domain)
|
||||||
|
if err != nil || u.Scheme == "" || u.Host == "" {
|
||||||
|
panic("endpoints: invalid domain" + domain)
|
||||||
|
}
|
||||||
return oauth2.Endpoint{
|
return oauth2.Endpoint{
|
||||||
AuthURL: domain + "/oauth2/authorize",
|
AuthURL: u.JoinPath("/oauth2/authorize").String(),
|
||||||
TokenURL: domain + "/oauth2/token",
|
TokenURL: u.JoinPath("/oauth2/token").String(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HipChatServer returns a new oauth2.Endpoint for a HipChat Server instance.
|
||||||
|
// host should be a hostname, without any scheme prefix.
|
||||||
|
//
|
||||||
|
// Documentation: https://developer.atlassian.com/server/hipchat/hipchat-rest-api-access-tokens/
|
||||||
|
func HipChatServer(host string) oauth2.Endpoint {
|
||||||
|
u := url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: host,
|
||||||
|
}
|
||||||
|
return oauth2.Endpoint{
|
||||||
|
AuthURL: u.JoinPath("/users/authorize").String(),
|
||||||
|
TokenURL: u.JoinPath("/v2/oauth/token").String(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shopify returns a new oauth2.Endpoint for the supplied shop domain name.
|
||||||
|
// host should be a hostname, without any scheme prefix.
|
||||||
|
//
|
||||||
|
// Documentation: https://shopify.dev/docs/apps/auth/oauth
|
||||||
|
func Shopify(host string) oauth2.Endpoint {
|
||||||
|
u := url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: host,
|
||||||
|
}
|
||||||
|
return oauth2.Endpoint{
|
||||||
|
AuthURL: u.JoinPath("/admin/oauth/authorize").String(),
|
||||||
|
TokenURL: u.JoinPath("/admin/oauth/access_token").String(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
7
go.mod
7
go.mod
@ -1,8 +1,5 @@
|
|||||||
module golang.org/x/oauth2
|
module golang.org/x/oauth2
|
||||||
|
|
||||||
go 1.18
|
go 1.23.0
|
||||||
|
|
||||||
require (
|
require cloud.google.com/go/compute/metadata v0.3.0
|
||||||
cloud.google.com/go/compute/metadata v0.3.0
|
|
||||||
github.com/google/go-cmp v0.5.9
|
|
||||||
)
|
|
||||||
|
2
go.sum
2
go.sum
@ -1,4 +1,2 @@
|
|||||||
cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
|
cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
|
||||||
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
|
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
|
||||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
|
||||||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
|
||||||
|
@ -39,7 +39,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
@ -198,7 +198,7 @@ func (dts downscopingTokenSource) Token() (*oauth2.Token, error) {
|
|||||||
return nil, fmt.Errorf("unable to generate POST Request %v", err)
|
return nil, fmt.Errorf("unable to generate POST Request %v", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
respBody, err := ioutil.ReadAll(resp.Body)
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("downscope: unable to read response body: %v", err)
|
return nil, fmt.Errorf("downscope: unable to read response body: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -6,7 +6,7 @@ package downscope
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"io/ioutil"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
@ -27,7 +27,7 @@ func Test_DownscopedTokenSource(t *testing.T) {
|
|||||||
if r.URL.String() != "/" {
|
if r.URL.String() != "/" {
|
||||||
t.Errorf("Unexpected request URL, %v is found", r.URL)
|
t.Errorf("Unexpected request URL, %v is found", r.URL)
|
||||||
}
|
}
|
||||||
body, err := ioutil.ReadAll(r.Body)
|
body, err := io.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to read request body: %v", err)
|
t.Fatalf("Failed to read request body: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -7,9 +7,9 @@ package google_test
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"golang.org/x/oauth2/google"
|
"golang.org/x/oauth2/google"
|
||||||
@ -60,7 +60,7 @@ func ExampleJWTConfigFromJSON() {
|
|||||||
// To create a service account client, click "Create new Client ID",
|
// To create a service account client, click "Create new Client ID",
|
||||||
// select "Service Account", and click "Create Client ID". A JSON
|
// select "Service Account", and click "Create Client ID". A JSON
|
||||||
// key file will then be downloaded to your computer.
|
// key file will then be downloaded to your computer.
|
||||||
data, err := ioutil.ReadFile("/path/to/your-project-key.json")
|
data, err := os.ReadFile("/path/to/your-project-key.json")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -136,7 +136,7 @@ func ExampleComputeTokenSource() {
|
|||||||
|
|
||||||
func ExampleCredentialsFromJSON() {
|
func ExampleCredentialsFromJSON() {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
data, err := ioutil.ReadFile("/path/to/key-file.json")
|
data, err := os.ReadFile("/path/to/key-file.json")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -14,7 +14,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
@ -170,7 +169,7 @@ func requestDataHash(req *http.Request) (string, error) {
|
|||||||
}
|
}
|
||||||
defer requestBody.Close()
|
defer requestBody.Close()
|
||||||
|
|
||||||
requestData, err = ioutil.ReadAll(io.LimitReader(requestBody, 1<<20))
|
requestData, err = io.ReadAll(io.LimitReader(requestBody, 1<<20))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@ -419,7 +418,7 @@ func (cs *awsCredentialSource) getAWSSessionToken() (string, error) {
|
|||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@ -462,7 +461,7 @@ func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, err
|
|||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
@ -531,7 +530,7 @@ func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string, h
|
|||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return result, err
|
return result, err
|
||||||
}
|
}
|
||||||
@ -564,7 +563,7 @@ func (cs *awsCredentialSource) getMetadataRoleName(headers map[string]string) (s
|
|||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
@ -263,7 +263,7 @@ const (
|
|||||||
fileTypeJSON = "json"
|
fileTypeJSON = "json"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Format contains information needed to retireve a subject token for URL or File sourced credentials.
|
// Format contains information needed to retrieve a subject token for URL or File sourced credentials.
|
||||||
type Format struct {
|
type Format struct {
|
||||||
// Type should be either "text" or "json". This determines whether the file or URL sourced credentials
|
// Type should be either "text" or "json". This determines whether the file or URL sourced credentials
|
||||||
// expect a simple text subject token or if the subject token will be contained in a JSON object.
|
// expect a simple text subject token or if the subject token will be contained in a JSON object.
|
||||||
@ -486,11 +486,11 @@ func (ts tokenSource) Token() (*oauth2.Token, error) {
|
|||||||
ClientID: conf.ClientID,
|
ClientID: conf.ClientID,
|
||||||
ClientSecret: conf.ClientSecret,
|
ClientSecret: conf.ClientSecret,
|
||||||
}
|
}
|
||||||
var options map[string]interface{}
|
var options map[string]any
|
||||||
// Do not pass workforce_pool_user_project when client authentication is used.
|
// Do not pass workforce_pool_user_project when client authentication is used.
|
||||||
// The client ID is sufficient for determining the user project.
|
// The client ID is sufficient for determining the user project.
|
||||||
if conf.WorkforcePoolUserProject != "" && conf.ClientID == "" {
|
if conf.WorkforcePoolUserProject != "" && conf.ClientID == "" {
|
||||||
options = map[string]interface{}{
|
options = map[string]any{
|
||||||
"userProject": conf.WorkforcePoolUserProject,
|
"userProject": conf.WorkforcePoolUserProject,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -8,7 +8,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
@ -77,7 +77,7 @@ func run(t *testing.T, config *Config, tets *testExchangeTokenServer) (*oauth2.T
|
|||||||
if got, want := headerMetrics, tets.metricsHeader; got != want {
|
if got, want := headerMetrics, tets.metricsHeader; got != want {
|
||||||
t.Errorf("got %v but want %v", got, want)
|
t.Errorf("got %v but want %v", got, want)
|
||||||
}
|
}
|
||||||
body, err := ioutil.ReadAll(r.Body)
|
body, err := io.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed reading request body: %s.", err)
|
t.Fatalf("Failed reading request body: %s.", err)
|
||||||
}
|
}
|
||||||
@ -131,7 +131,7 @@ func createImpersonationServer(urlWanted, authWanted, bodyWanted, response strin
|
|||||||
if got, want := headerContentType, "application/json"; got != want {
|
if got, want := headerContentType, "application/json"; got != want {
|
||||||
t.Errorf("got %v but want %v", got, want)
|
t.Errorf("got %v but want %v", got, want)
|
||||||
}
|
}
|
||||||
body, err := ioutil.ReadAll(r.Body)
|
body, err := io.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed reading request body: %v.", err)
|
t.Fatalf("Failed reading request body: %v.", err)
|
||||||
}
|
}
|
||||||
@ -160,7 +160,7 @@ func createTargetServer(metricsHeaderWanted string, t *testing.T) *httptest.Serv
|
|||||||
if got, want := headerMetrics, metricsHeaderWanted; got != want {
|
if got, want := headerMetrics, metricsHeaderWanted; got != want {
|
||||||
t.Errorf("got %v but want %v", got, want)
|
t.Errorf("got %v but want %v", got, want)
|
||||||
}
|
}
|
||||||
body, err := ioutil.ReadAll(r.Body)
|
body, err := io.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed reading request body: %v.", err)
|
t.Fatalf("Failed reading request body: %v.", err)
|
||||||
}
|
}
|
||||||
@ -347,12 +347,12 @@ func TestNonworkforceWithWorkforcePoolUserProject(t *testing.T) {
|
|||||||
t.Fatalf("Expected error but found none")
|
t.Fatalf("Expected error but found none")
|
||||||
}
|
}
|
||||||
if got, want := err.Error(), "oauth2/google/externalaccount: Workforce pool user project should not be set for non-workforce pool credentials"; got != want {
|
if got, want := err.Error(), "oauth2/google/externalaccount: Workforce pool user project should not be set for non-workforce pool credentials"; got != want {
|
||||||
t.Errorf("Incorrect error received.\nExpected: %s\nRecieved: %s", want, got)
|
t.Errorf("Incorrect error received.\nExpected: %s\nReceived: %s", want, got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWorkforcePoolCreation(t *testing.T) {
|
func TestWorkforcePoolCreation(t *testing.T) {
|
||||||
var audienceValidatyTests = []struct {
|
var audienceValidityTests = []struct {
|
||||||
audience string
|
audience string
|
||||||
expectSuccess bool
|
expectSuccess bool
|
||||||
}{
|
}{
|
||||||
@ -371,7 +371,7 @@ func TestWorkforcePoolCreation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
for _, tt := range audienceValidatyTests {
|
for _, tt := range audienceValidityTests {
|
||||||
t.Run(" "+tt.audience, func(t *testing.T) { // We prepend a space ahead of the test input when outputting for sake of readability.
|
t.Run(" "+tt.audience, func(t *testing.T) { // We prepend a space ahead of the test input when outputting for sake of readability.
|
||||||
config := testConfig
|
config := testConfig
|
||||||
config.TokenURL = "https://sts.googleapis.com" // Setting the most basic acceptable tokenURL
|
config.TokenURL = "https://sts.googleapis.com" // Setting the most basic acceptable tokenURL
|
||||||
|
@ -11,7 +11,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"regexp"
|
"regexp"
|
||||||
@ -258,7 +257,7 @@ func (cs executableCredentialSource) getTokenFromOutputFile() (token string, err
|
|||||||
}
|
}
|
||||||
defer file.Close()
|
defer file.Close()
|
||||||
|
|
||||||
data, err := ioutil.ReadAll(io.LimitReader(file, 1<<20))
|
data, err := io.ReadAll(io.LimitReader(file, 1<<20))
|
||||||
if err != nil || len(data) == 0 {
|
if err != nil || len(data) == 0 {
|
||||||
// Cachefile exists, but no data found. Get new credential.
|
// Cachefile exists, but no data found. Get new credential.
|
||||||
return "", nil
|
return "", nil
|
||||||
|
@ -8,13 +8,10 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
|
||||||
"os"
|
"os"
|
||||||
"sort"
|
"slices"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type testEnvironment struct {
|
type testEnvironment struct {
|
||||||
@ -254,14 +251,12 @@ func TestExecutableCredentialGetEnvironment(t *testing.T) {
|
|||||||
|
|
||||||
ecs.env = &tt.environment
|
ecs.env = &tt.environment
|
||||||
|
|
||||||
// This Transformer sorts a []string.
|
got := ecs.executableEnvironment()
|
||||||
sorter := cmp.Transformer("Sort", func(in []string) []string {
|
slices.Sort(got)
|
||||||
out := append([]string(nil), in...) // Copy input to avoid mutating it
|
want := tt.expectedEnvironment
|
||||||
sort.Strings(out)
|
slices.Sort(want)
|
||||||
return out
|
|
||||||
})
|
|
||||||
|
|
||||||
if got, want := ecs.executableEnvironment(), tt.expectedEnvironment; !cmp.Equal(got, want, sorter) {
|
if !slices.Equal(got, want) {
|
||||||
t.Errorf("Incorrect environment received.\nReceived: %s\nExpected: %s", got, want)
|
t.Errorf("Incorrect environment received.\nReceived: %s\nExpected: %s", got, want)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@ -614,7 +609,7 @@ func TestRetrieveExecutableSubjectTokenSuccesses(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRetrieveOutputFileSubjectTokenNotJSON(t *testing.T) {
|
func TestRetrieveOutputFileSubjectTokenNotJSON(t *testing.T) {
|
||||||
outputFile, err := ioutil.TempFile("testdata", "result.*.json")
|
outputFile, err := os.CreateTemp("testdata", "result.*.json")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Tempfile failed: %v", err)
|
t.Fatalf("Tempfile failed: %v", err)
|
||||||
}
|
}
|
||||||
@ -654,7 +649,7 @@ func TestRetrieveOutputFileSubjectTokenNotJSON(t *testing.T) {
|
|||||||
if _, err = base.subjectToken(); err == nil {
|
if _, err = base.subjectToken(); err == nil {
|
||||||
t.Fatalf("Expected error but found none")
|
t.Fatalf("Expected error but found none")
|
||||||
} else if got, want := err.Error(), jsonParsingError(outputFileSource, "tokentokentoken").Error(); got != want {
|
} else if got, want := err.Error(), jsonParsingError(outputFileSource, "tokentokentoken").Error(); got != want {
|
||||||
t.Errorf("Incorrect error received.\nExpected: %s\nRecieved: %s", want, got)
|
t.Errorf("Incorrect error received.\nExpected: %s\nReceived: %s", want, got)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, deadlineSet := te.getDeadline()
|
_, deadlineSet := te.getDeadline()
|
||||||
@ -763,7 +758,7 @@ var cacheFailureTests = []struct {
|
|||||||
func TestRetrieveOutputFileSubjectTokenFailureTests(t *testing.T) {
|
func TestRetrieveOutputFileSubjectTokenFailureTests(t *testing.T) {
|
||||||
for _, tt := range cacheFailureTests {
|
for _, tt := range cacheFailureTests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
outputFile, err := ioutil.TempFile("testdata", "result.*.json")
|
outputFile, err := os.CreateTemp("testdata", "result.*.json")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Tempfile failed: %v", err)
|
t.Fatalf("Tempfile failed: %v", err)
|
||||||
}
|
}
|
||||||
@ -801,7 +796,7 @@ func TestRetrieveOutputFileSubjectTokenFailureTests(t *testing.T) {
|
|||||||
if _, err = ecs.subjectToken(); err == nil {
|
if _, err = ecs.subjectToken(); err == nil {
|
||||||
t.Errorf("Expected error but found none")
|
t.Errorf("Expected error but found none")
|
||||||
} else if got, want := err.Error(), tt.expectedErr.Error(); got != want {
|
} else if got, want := err.Error(), tt.expectedErr.Error(); got != want {
|
||||||
t.Errorf("Incorrect error received.\nExpected: %s\nRecieved: %s", want, got)
|
t.Errorf("Incorrect error received.\nExpected: %s\nReceived: %s", want, got)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, deadlineSet := te.getDeadline(); deadlineSet {
|
if _, deadlineSet := te.getDeadline(); deadlineSet {
|
||||||
@ -866,7 +861,7 @@ var invalidCacheTests = []struct {
|
|||||||
func TestRetrieveOutputFileSubjectTokenInvalidCache(t *testing.T) {
|
func TestRetrieveOutputFileSubjectTokenInvalidCache(t *testing.T) {
|
||||||
for _, tt := range invalidCacheTests {
|
for _, tt := range invalidCacheTests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
outputFile, err := ioutil.TempFile("testdata", "result.*.json")
|
outputFile, err := os.CreateTemp("testdata", "result.*.json")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Tempfile failed: %v", err)
|
t.Fatalf("Tempfile failed: %v", err)
|
||||||
}
|
}
|
||||||
@ -923,7 +918,7 @@ func TestRetrieveOutputFileSubjectTokenInvalidCache(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if got, want := out, "tokentokentoken"; got != want {
|
if got, want := out, "tokentokentoken"; got != want {
|
||||||
t.Errorf("Incorrect token received.\nExpected: %s\nRecieved: %s", want, got)
|
t.Errorf("Incorrect token received.\nExpected: %s\nReceived: %s", want, got)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -970,8 +965,7 @@ var cacheSuccessTests = []struct {
|
|||||||
func TestRetrieveOutputFileSubjectTokenJwt(t *testing.T) {
|
func TestRetrieveOutputFileSubjectTokenJwt(t *testing.T) {
|
||||||
for _, tt := range cacheSuccessTests {
|
for _, tt := range cacheSuccessTests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
outputFile, err := os.CreateTemp("testdata", "result.*.json")
|
||||||
outputFile, err := ioutil.TempFile("testdata", "result.*.json")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Tempfile failed: %v", err)
|
t.Fatalf("Tempfile failed: %v", err)
|
||||||
}
|
}
|
||||||
@ -1012,7 +1006,7 @@ func TestRetrieveOutputFileSubjectTokenJwt(t *testing.T) {
|
|||||||
if out, err := ecs.subjectToken(); err != nil {
|
if out, err := ecs.subjectToken(); err != nil {
|
||||||
t.Errorf("retrieveSubjectToken() failed: %v", err)
|
t.Errorf("retrieveSubjectToken() failed: %v", err)
|
||||||
} else if got, want := out, "tokentokentoken"; got != want {
|
} else if got, want := out, "tokentokentoken"; got != want {
|
||||||
t.Errorf("Incorrect token received.\nExpected: %s\nRecieved: %s", want, got)
|
t.Errorf("Incorrect token received.\nExpected: %s\nReceived: %s", want, got)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, deadlineSet := te.getDeadline(); deadlineSet {
|
if _, deadlineSet := te.getDeadline(); deadlineSet {
|
||||||
|
@ -10,7 +10,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
|
||||||
"os"
|
"os"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -29,14 +28,14 @@ func (cs fileCredentialSource) subjectToken() (string, error) {
|
|||||||
return "", fmt.Errorf("oauth2/google/externalaccount: failed to open credential file %q", cs.File)
|
return "", fmt.Errorf("oauth2/google/externalaccount: failed to open credential file %q", cs.File)
|
||||||
}
|
}
|
||||||
defer tokenFile.Close()
|
defer tokenFile.Close()
|
||||||
tokenBytes, err := ioutil.ReadAll(io.LimitReader(tokenFile, 1<<20))
|
tokenBytes, err := io.ReadAll(io.LimitReader(tokenFile, 1<<20))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("oauth2/google/externalaccount: failed to read credential file: %v", err)
|
return "", fmt.Errorf("oauth2/google/externalaccount: failed to read credential file: %v", err)
|
||||||
}
|
}
|
||||||
tokenBytes = bytes.TrimSpace(tokenBytes)
|
tokenBytes = bytes.TrimSpace(tokenBytes)
|
||||||
switch cs.Format.Type {
|
switch cs.Format.Type {
|
||||||
case "json":
|
case "json":
|
||||||
jsonData := make(map[string]interface{})
|
jsonData := make(map[string]any)
|
||||||
err = json.Unmarshal(tokenBytes, &jsonData)
|
err = json.Unmarshal(tokenBytes, &jsonData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("oauth2/google/externalaccount: failed to unmarshal subject token file: %v", err)
|
return "", fmt.Errorf("oauth2/google/externalaccount: failed to unmarshal subject token file: %v", err)
|
||||||
|
@ -7,8 +7,6 @@ package externalaccount
|
|||||||
import (
|
import (
|
||||||
"runtime"
|
"runtime"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGoVersion(t *testing.T) {
|
func TestGoVersion(t *testing.T) {
|
||||||
@ -40,8 +38,8 @@ func TestGoVersion(t *testing.T) {
|
|||||||
} {
|
} {
|
||||||
version = tst.v
|
version = tst.v
|
||||||
got := goVersion()
|
got := goVersion()
|
||||||
if diff := cmp.Diff(got, tst.want); diff != "" {
|
if got != tst.want {
|
||||||
t.Errorf("got(-),want(+):\n%s", diff)
|
t.Errorf("go version = %q, want = %q", got, tst.want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
version = runtime.Version
|
version = runtime.Version
|
||||||
|
@ -10,7 +10,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
@ -44,7 +43,7 @@ func (cs urlCredentialSource) subjectToken() (string, error) {
|
|||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("oauth2/google/externalaccount: invalid body in subject token URL query: %v", err)
|
return "", fmt.Errorf("oauth2/google/externalaccount: invalid body in subject token URL query: %v", err)
|
||||||
}
|
}
|
||||||
@ -54,7 +53,7 @@ func (cs urlCredentialSource) subjectToken() (string, error) {
|
|||||||
|
|
||||||
switch cs.Format.Type {
|
switch cs.Format.Type {
|
||||||
case "json":
|
case "json":
|
||||||
jsonData := make(map[string]interface{})
|
jsonData := make(map[string]any)
|
||||||
err = json.Unmarshal(respBody, &jsonData)
|
err = json.Unmarshal(respBody, &jsonData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("oauth2/google/externalaccount: failed to unmarshal subject token file: %v", err)
|
return "", fmt.Errorf("oauth2/google/externalaccount: failed to unmarshal subject token file: %v", err)
|
||||||
|
@ -285,27 +285,23 @@ func (cs computeSource) Token() (*oauth2.Token, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var res struct {
|
var res oauth2.Token
|
||||||
AccessToken string `json:"access_token"`
|
|
||||||
ExpiresInSec int `json:"expires_in"`
|
|
||||||
TokenType string `json:"token_type"`
|
|
||||||
}
|
|
||||||
err = json.NewDecoder(strings.NewReader(tokenJSON)).Decode(&res)
|
err = json.NewDecoder(strings.NewReader(tokenJSON)).Decode(&res)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("oauth2/google: invalid token JSON from metadata: %v", err)
|
return nil, fmt.Errorf("oauth2/google: invalid token JSON from metadata: %v", err)
|
||||||
}
|
}
|
||||||
if res.ExpiresInSec == 0 || res.AccessToken == "" {
|
if res.ExpiresIn == 0 || res.AccessToken == "" {
|
||||||
return nil, fmt.Errorf("oauth2/google: incomplete token received from metadata")
|
return nil, fmt.Errorf("oauth2/google: incomplete token received from metadata")
|
||||||
}
|
}
|
||||||
tok := &oauth2.Token{
|
tok := &oauth2.Token{
|
||||||
AccessToken: res.AccessToken,
|
AccessToken: res.AccessToken,
|
||||||
TokenType: res.TokenType,
|
TokenType: res.TokenType,
|
||||||
Expiry: time.Now().Add(time.Duration(res.ExpiresInSec) * time.Second),
|
Expiry: time.Now().Add(time.Duration(res.ExpiresIn) * time.Second),
|
||||||
}
|
}
|
||||||
// NOTE(cbro): add hidden metadata about where the token is from.
|
// NOTE(cbro): add hidden metadata about where the token is from.
|
||||||
// This is needed for detection by client libraries to know that credentials come from the metadata server.
|
// This is needed for detection by client libraries to know that credentials come from the metadata server.
|
||||||
// This may be removed in a future version of this library.
|
// This may be removed in a future version of this library.
|
||||||
return tok.WithExtra(map[string]interface{}{
|
return tok.WithExtra(map[string]any{
|
||||||
"oauth2.google.tokenSource": "compute-metadata",
|
"oauth2.google.tokenSource": "compute-metadata",
|
||||||
"oauth2.google.serviceAccount": acct,
|
"oauth2.google.serviceAccount": acct,
|
||||||
}), nil
|
}), nil
|
||||||
|
@ -72,7 +72,7 @@ func TestConfigFromJSON(t *testing.T) {
|
|||||||
t.Errorf("ClientSecret = %q; want %q", got, want)
|
t.Errorf("ClientSecret = %q; want %q", got, want)
|
||||||
}
|
}
|
||||||
if got, want := conf.RedirectURL, "https://www.example.com/oauth2callback"; got != want {
|
if got, want := conf.RedirectURL, "https://www.example.com/oauth2callback"; got != want {
|
||||||
t.Errorf("RedictURL = %q; want %q", got, want)
|
t.Errorf("RedirectURL = %q; want %q", got, want)
|
||||||
}
|
}
|
||||||
if got, want := strings.Join(conf.Scopes, ","), "scope1,scope2"; got != want {
|
if got, want := strings.Join(conf.Scopes, ","), "scope1,scope2"; got != want {
|
||||||
t.Errorf("Scopes = %q; want %q", got, want)
|
t.Errorf("Scopes = %q; want %q", got, want)
|
||||||
|
@ -8,7 +8,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"io/ioutil"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
@ -38,7 +38,7 @@ type testRefreshTokenServer struct {
|
|||||||
server *httptest.Server
|
server *httptest.Server
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExernalAccountAuthorizedUser_JustToken(t *testing.T) {
|
func TestExternalAccountAuthorizedUser_JustToken(t *testing.T) {
|
||||||
config := &Config{
|
config := &Config{
|
||||||
Token: "AAAAAAA",
|
Token: "AAAAAAA",
|
||||||
Expiry: now().Add(time.Hour),
|
Expiry: now().Add(time.Hour),
|
||||||
@ -57,7 +57,7 @@ func TestExernalAccountAuthorizedUser_JustToken(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExernalAccountAuthorizedUser_TokenRefreshWithRefreshTokenInRespondse(t *testing.T) {
|
func TestExternalAccountAuthorizedUser_TokenRefreshWithRefreshTokenInResponse(t *testing.T) {
|
||||||
server := &testRefreshTokenServer{
|
server := &testRefreshTokenServer{
|
||||||
URL: "/",
|
URL: "/",
|
||||||
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
|
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
|
||||||
@ -99,7 +99,7 @@ func TestExernalAccountAuthorizedUser_TokenRefreshWithRefreshTokenInRespondse(t
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExernalAccountAuthorizedUser_MinimumFieldsRequiredForRefresh(t *testing.T) {
|
func TestExternalAccountAuthorizedUser_MinimumFieldsRequiredForRefresh(t *testing.T) {
|
||||||
server := &testRefreshTokenServer{
|
server := &testRefreshTokenServer{
|
||||||
URL: "/",
|
URL: "/",
|
||||||
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
|
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
|
||||||
@ -187,7 +187,7 @@ func TestExternalAccountAuthorizedUser_MissingRefreshFields(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "missing client secrect",
|
name: "missing client secret",
|
||||||
config: Config{
|
config: Config{
|
||||||
RefreshToken: "BBBBBBBBB",
|
RefreshToken: "BBBBBBBBB",
|
||||||
TokenURL: url,
|
TokenURL: url,
|
||||||
@ -227,7 +227,7 @@ func (trts *testRefreshTokenServer) run(t *testing.T) (string, error) {
|
|||||||
if got, want := headerContentType, trts.ContentType; got != want {
|
if got, want := headerContentType, trts.ContentType; got != want {
|
||||||
t.Errorf("got %v but want %v", got, want)
|
t.Errorf("got %v but want %v", got, want)
|
||||||
}
|
}
|
||||||
body, err := ioutil.ReadAll(r.Body)
|
body, err := io.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed reading request body: %s.", err)
|
t.Fatalf("Failed reading request body: %s.", err)
|
||||||
}
|
}
|
||||||
|
@ -10,7 +10,6 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -81,7 +80,7 @@ func (its ImpersonateTokenSource) Token() (*oauth2.Token, error) {
|
|||||||
return nil, fmt.Errorf("oauth2/google: unable to generate access token: %v", err)
|
return nil, fmt.Errorf("oauth2/google: unable to generate access token: %v", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("oauth2/google: unable to read body: %v", err)
|
return nil, fmt.Errorf("oauth2/google: unable to read body: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -9,7 +9,6 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -28,7 +27,7 @@ func defaultHeader() http.Header {
|
|||||||
// The first 4 fields are all mandatory. headers can be used to pass additional
|
// The first 4 fields are all mandatory. headers can be used to pass additional
|
||||||
// headers beyond the bare minimum required by the token exchange. options can
|
// headers beyond the bare minimum required by the token exchange. options can
|
||||||
// be used to pass additional JSON-structured options to the remote server.
|
// be used to pass additional JSON-structured options to the remote server.
|
||||||
func ExchangeToken(ctx context.Context, endpoint string, request *TokenExchangeRequest, authentication ClientAuthentication, headers http.Header, options map[string]interface{}) (*Response, error) {
|
func ExchangeToken(ctx context.Context, endpoint string, request *TokenExchangeRequest, authentication ClientAuthentication, headers http.Header, options map[string]any) (*Response, error) {
|
||||||
data := url.Values{}
|
data := url.Values{}
|
||||||
data.Set("audience", request.Audience)
|
data.Set("audience", request.Audience)
|
||||||
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange")
|
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange")
|
||||||
@ -82,7 +81,7 @@ func makeRequest(ctx context.Context, endpoint string, data url.Values, authenti
|
|||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -7,7 +7,7 @@ package stsexchange
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io/ioutil"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
@ -73,7 +73,7 @@ func TestExchangeToken(t *testing.T) {
|
|||||||
if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want {
|
if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want {
|
||||||
t.Errorf("Unexpected Content-Type header, got %v, want %v", got, want)
|
t.Errorf("Unexpected Content-Type header, got %v, want %v", got, want)
|
||||||
}
|
}
|
||||||
body, err := ioutil.ReadAll(r.Body)
|
body, err := io.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Failed reading request body: %v.", err)
|
t.Errorf("Failed reading request body: %v.", err)
|
||||||
}
|
}
|
||||||
@ -132,7 +132,7 @@ var optsValues = [][]string{{"foo", "bar"}, {"cat", "pan"}}
|
|||||||
|
|
||||||
func TestExchangeToken_Opts(t *testing.T) {
|
func TestExchangeToken_Opts(t *testing.T) {
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
body, err := ioutil.ReadAll(r.Body)
|
body, err := io.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed reading request body: %v.", err)
|
t.Fatalf("Failed reading request body: %v.", err)
|
||||||
}
|
}
|
||||||
@ -142,11 +142,11 @@ func TestExchangeToken_Opts(t *testing.T) {
|
|||||||
}
|
}
|
||||||
strOpts, ok := data["options"]
|
strOpts, ok := data["options"]
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Errorf("Server didn't recieve an \"options\" field.")
|
t.Errorf("Server didn't receive an \"options\" field.")
|
||||||
} else if len(strOpts) < 1 {
|
} else if len(strOpts) < 1 {
|
||||||
t.Errorf("\"options\" field has length 0.")
|
t.Errorf("\"options\" field has length 0.")
|
||||||
}
|
}
|
||||||
var opts map[string]interface{}
|
var opts map[string]any
|
||||||
err = json.Unmarshal([]byte(strOpts[0]), &opts)
|
err = json.Unmarshal([]byte(strOpts[0]), &opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Couldn't parse received \"options\" field.")
|
t.Fatalf("Couldn't parse received \"options\" field.")
|
||||||
@ -159,7 +159,7 @@ func TestExchangeToken_Opts(t *testing.T) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
t.Errorf("Couldn't find first option parameter.")
|
t.Errorf("Couldn't find first option parameter.")
|
||||||
} else {
|
} else {
|
||||||
tOpts1, ok := val.(map[string]interface{})
|
tOpts1, ok := val.(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Errorf("Failed to assert the first option parameter as type testOpts.")
|
t.Errorf("Failed to assert the first option parameter as type testOpts.")
|
||||||
} else {
|
} else {
|
||||||
@ -176,7 +176,7 @@ func TestExchangeToken_Opts(t *testing.T) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
t.Errorf("Couldn't find second option parameter.")
|
t.Errorf("Couldn't find second option parameter.")
|
||||||
} else {
|
} else {
|
||||||
tOpts2, ok := val2.(map[string]interface{})
|
tOpts2, ok := val2.(map[string]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Errorf("Failed to assert the second option parameter as type testOpts.")
|
t.Errorf("Failed to assert the second option parameter as type testOpts.")
|
||||||
} else {
|
} else {
|
||||||
@ -200,7 +200,7 @@ func TestExchangeToken_Opts(t *testing.T) {
|
|||||||
|
|
||||||
firstOption := testOpts{optsValues[0][0], optsValues[0][1]}
|
firstOption := testOpts{optsValues[0][0], optsValues[0][1]}
|
||||||
secondOption := testOpts{optsValues[1][0], optsValues[1][1]}
|
secondOption := testOpts{optsValues[1][0], optsValues[1][1]}
|
||||||
inputOpts := make(map[string]interface{})
|
inputOpts := make(map[string]any)
|
||||||
inputOpts["one"] = firstOption
|
inputOpts["one"] = firstOption
|
||||||
inputOpts["two"] = secondOption
|
inputOpts["two"] = secondOption
|
||||||
ExchangeToken(context.Background(), ts.URL, &exchangeTokenRequest, auth, headers, inputOpts)
|
ExchangeToken(context.Background(), ts.URL, &exchangeTokenRequest, auth, headers, inputOpts)
|
||||||
@ -220,7 +220,7 @@ func TestRefreshToken(t *testing.T) {
|
|||||||
if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want {
|
if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want {
|
||||||
t.Errorf("Unexpected Content-Type header, got %v, want %v", got, want)
|
t.Errorf("Unexpected Content-Type header, got %v, want %v", got, want)
|
||||||
}
|
}
|
||||||
body, err := ioutil.ReadAll(r.Body)
|
body, err := io.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Failed reading request body: %v.", err)
|
t.Errorf("Failed reading request body: %v.", err)
|
||||||
}
|
}
|
||||||
|
@ -2,5 +2,5 @@
|
|||||||
// Use of this source code is governed by a BSD-style
|
// Use of this source code is governed by a BSD-style
|
||||||
// license that can be found in the LICENSE file.
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
// Package internal contains support packages for oauth2 package.
|
// Package internal contains support packages for [golang.org/x/oauth2].
|
||||||
package internal
|
package internal
|
||||||
|
@ -13,7 +13,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// ParseKey converts the binary contents of a private key file
|
// ParseKey converts the binary contents of a private key file
|
||||||
// to an *rsa.PrivateKey. It detects whether the private key is in a
|
// to an [*rsa.PrivateKey]. It detects whether the private key is in a
|
||||||
// PEM container or not. If so, it extracts the private key
|
// PEM container or not. If so, it extracts the private key
|
||||||
// from PEM container before conversion. It only supports PEM
|
// from PEM container before conversion. It only supports PEM
|
||||||
// containers with no passphrase.
|
// containers with no passphrase.
|
||||||
|
@ -10,7 +10,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
|
||||||
"math"
|
"math"
|
||||||
"mime"
|
"mime"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -26,9 +25,9 @@ import (
|
|||||||
// the requests to access protected resources on the OAuth 2.0
|
// the requests to access protected resources on the OAuth 2.0
|
||||||
// provider's backend.
|
// provider's backend.
|
||||||
//
|
//
|
||||||
// This type is a mirror of oauth2.Token and exists to break
|
// This type is a mirror of [golang.org/x/oauth2.Token] and exists to break
|
||||||
// an otherwise-circular dependency. Other internal packages
|
// an otherwise-circular dependency. Other internal packages
|
||||||
// should convert this Token into an oauth2.Token before use.
|
// should convert this Token into an [golang.org/x/oauth2.Token] before use.
|
||||||
type Token struct {
|
type Token struct {
|
||||||
// AccessToken is the token that authorizes and authenticates
|
// AccessToken is the token that authorizes and authenticates
|
||||||
// the requests.
|
// the requests.
|
||||||
@ -50,9 +49,16 @@ type Token struct {
|
|||||||
// mechanisms for that TokenSource will not be used.
|
// mechanisms for that TokenSource will not be used.
|
||||||
Expiry time.Time
|
Expiry time.Time
|
||||||
|
|
||||||
|
// ExpiresIn is the OAuth2 wire format "expires_in" field,
|
||||||
|
// which specifies how many seconds later the token expires,
|
||||||
|
// relative to an unknown time base approximately around "now".
|
||||||
|
// It is the application's responsibility to populate
|
||||||
|
// `Expiry` from `ExpiresIn` when required.
|
||||||
|
ExpiresIn int64 `json:"expires_in,omitempty"`
|
||||||
|
|
||||||
// Raw optionally contains extra metadata from the server
|
// Raw optionally contains extra metadata from the server
|
||||||
// when updating a token.
|
// when updating a token.
|
||||||
Raw interface{}
|
Raw any
|
||||||
}
|
}
|
||||||
|
|
||||||
// tokenJSON is the struct representing the HTTP response from OAuth2
|
// tokenJSON is the struct representing the HTTP response from OAuth2
|
||||||
@ -99,14 +105,6 @@ func (e *expirationTime) UnmarshalJSON(b []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterBrokenAuthHeaderProvider previously did something. It is now a no-op.
|
|
||||||
//
|
|
||||||
// Deprecated: this function no longer does anything. Caller code that
|
|
||||||
// wants to avoid potential extra HTTP requests made during
|
|
||||||
// auto-probing of the provider's auth style should set
|
|
||||||
// Endpoint.AuthStyle.
|
|
||||||
func RegisterBrokenAuthHeaderProvider(tokenURL string) {}
|
|
||||||
|
|
||||||
// AuthStyle is a copy of the golang.org/x/oauth2 package's AuthStyle type.
|
// AuthStyle is a copy of the golang.org/x/oauth2 package's AuthStyle type.
|
||||||
type AuthStyle int
|
type AuthStyle int
|
||||||
|
|
||||||
@ -143,6 +141,11 @@ func (lc *LazyAuthStyleCache) Get() *AuthStyleCache {
|
|||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type authStyleCacheKey struct {
|
||||||
|
url string
|
||||||
|
clientID string
|
||||||
|
}
|
||||||
|
|
||||||
// AuthStyleCache is the set of tokenURLs we've successfully used via
|
// AuthStyleCache is the set of tokenURLs we've successfully used via
|
||||||
// RetrieveToken and which style auth we ended up using.
|
// RetrieveToken and which style auth we ended up using.
|
||||||
// It's called a cache, but it doesn't (yet?) shrink. It's expected that
|
// It's called a cache, but it doesn't (yet?) shrink. It's expected that
|
||||||
@ -150,26 +153,26 @@ func (lc *LazyAuthStyleCache) Get() *AuthStyleCache {
|
|||||||
// small.
|
// small.
|
||||||
type AuthStyleCache struct {
|
type AuthStyleCache struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
m map[string]AuthStyle // keyed by tokenURL
|
m map[authStyleCacheKey]AuthStyle
|
||||||
}
|
}
|
||||||
|
|
||||||
// lookupAuthStyle reports which auth style we last used with tokenURL
|
// lookupAuthStyle reports which auth style we last used with tokenURL
|
||||||
// when calling RetrieveToken and whether we have ever done so.
|
// when calling RetrieveToken and whether we have ever done so.
|
||||||
func (c *AuthStyleCache) lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) {
|
func (c *AuthStyleCache) lookupAuthStyle(tokenURL, clientID string) (style AuthStyle, ok bool) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
style, ok = c.m[tokenURL]
|
style, ok = c.m[authStyleCacheKey{tokenURL, clientID}]
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// setAuthStyle adds an entry to authStyleCache, documented above.
|
// setAuthStyle adds an entry to authStyleCache, documented above.
|
||||||
func (c *AuthStyleCache) setAuthStyle(tokenURL string, v AuthStyle) {
|
func (c *AuthStyleCache) setAuthStyle(tokenURL, clientID string, v AuthStyle) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
if c.m == nil {
|
if c.m == nil {
|
||||||
c.m = make(map[string]AuthStyle)
|
c.m = make(map[authStyleCacheKey]AuthStyle)
|
||||||
}
|
}
|
||||||
c.m[tokenURL] = v
|
c.m[authStyleCacheKey{tokenURL, clientID}] = v
|
||||||
}
|
}
|
||||||
|
|
||||||
// newTokenRequest returns a new *http.Request to retrieve a new token
|
// newTokenRequest returns a new *http.Request to retrieve a new token
|
||||||
@ -210,9 +213,9 @@ func cloneURLValues(v url.Values) url.Values {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle, styleCache *AuthStyleCache) (*Token, error) {
|
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle, styleCache *AuthStyleCache) (*Token, error) {
|
||||||
needsAuthStyleProbe := authStyle == 0
|
needsAuthStyleProbe := authStyle == AuthStyleUnknown
|
||||||
if needsAuthStyleProbe {
|
if needsAuthStyleProbe {
|
||||||
if style, ok := styleCache.lookupAuthStyle(tokenURL); ok {
|
if style, ok := styleCache.lookupAuthStyle(tokenURL, clientID); ok {
|
||||||
authStyle = style
|
authStyle = style
|
||||||
needsAuthStyleProbe = false
|
needsAuthStyleProbe = false
|
||||||
} else {
|
} else {
|
||||||
@ -242,7 +245,7 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string,
|
|||||||
token, err = doTokenRoundTrip(ctx, req)
|
token, err = doTokenRoundTrip(ctx, req)
|
||||||
}
|
}
|
||||||
if needsAuthStyleProbe && err == nil {
|
if needsAuthStyleProbe && err == nil {
|
||||||
styleCache.setAuthStyle(tokenURL, authStyle)
|
styleCache.setAuthStyle(tokenURL, clientID, authStyle)
|
||||||
}
|
}
|
||||||
// Don't overwrite `RefreshToken` with an empty value
|
// Don't overwrite `RefreshToken` with an empty value
|
||||||
// if this was a token refreshing request.
|
// if this was a token refreshing request.
|
||||||
@ -257,7 +260,7 @@ func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20))
|
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
|
||||||
r.Body.Close()
|
r.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
||||||
@ -312,7 +315,8 @@ func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) {
|
|||||||
TokenType: tj.TokenType,
|
TokenType: tj.TokenType,
|
||||||
RefreshToken: tj.RefreshToken,
|
RefreshToken: tj.RefreshToken,
|
||||||
Expiry: tj.expiry(),
|
Expiry: tj.expiry(),
|
||||||
Raw: make(map[string]interface{}),
|
ExpiresIn: int64(tj.ExpiresIn),
|
||||||
|
Raw: make(map[string]any),
|
||||||
}
|
}
|
||||||
json.Unmarshal(body, &token.Raw) // no error checks for optional fields
|
json.Unmarshal(body, &token.Raw) // no error checks for optional fields
|
||||||
}
|
}
|
||||||
|
@ -75,3 +75,48 @@ func TestExpiresInUpperBound(t *testing.T) {
|
|||||||
t.Errorf("expiration time = %v; want %v", e, want)
|
t.Errorf("expiration time = %v; want %v", e, want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAuthStyleCache(t *testing.T) {
|
||||||
|
var c LazyAuthStyleCache
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
url string
|
||||||
|
clientID string
|
||||||
|
style AuthStyle
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
"https://host1.example.com/token",
|
||||||
|
"client_1",
|
||||||
|
AuthStyleInHeader,
|
||||||
|
}, {
|
||||||
|
"https://host2.example.com/token",
|
||||||
|
"client_2",
|
||||||
|
AuthStyleInParams,
|
||||||
|
}, {
|
||||||
|
"https://host1.example.com/token",
|
||||||
|
"client_3",
|
||||||
|
AuthStyleInParams,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.clientID, func(t *testing.T) {
|
||||||
|
cc := c.Get()
|
||||||
|
got, ok := cc.lookupAuthStyle(tt.url, tt.clientID)
|
||||||
|
if ok {
|
||||||
|
t.Fatalf("unexpected auth style found on first request: %v", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
cc.setAuthStyle(tt.url, tt.clientID, tt.style)
|
||||||
|
|
||||||
|
got, ok = cc.lookupAuthStyle(tt.url, tt.clientID)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("auth style not found in cache")
|
||||||
|
}
|
||||||
|
|
||||||
|
if got != tt.style {
|
||||||
|
t.Fatalf("auth style mismatch, got=%v, want=%v", got, tt.style)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -9,8 +9,8 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
// HTTPClient is the context key to use with golang.org/x/net/context's
|
// HTTPClient is the context key to use with [context.WithValue]
|
||||||
// WithValue function to associate an *http.Client value with a context.
|
// to associate an [*http.Client] value with a context.
|
||||||
var HTTPClient ContextKey
|
var HTTPClient ContextKey
|
||||||
|
|
||||||
// ContextKey is just an empty struct. It exists so HTTPClient can be
|
// ContextKey is just an empty struct. It exists so HTTPClient can be
|
||||||
|
@ -13,7 +13,6 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
@ -114,7 +113,7 @@ func (js jwtSource) Token() (*oauth2.Token, error) {
|
|||||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
||||||
}
|
}
|
||||||
@ -123,11 +122,7 @@ func (js jwtSource) Token() (*oauth2.Token, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// tokenRes is the JSON response body.
|
// tokenRes is the JSON response body.
|
||||||
var tokenRes struct {
|
var tokenRes oauth2.Token
|
||||||
AccessToken string `json:"access_token"`
|
|
||||||
TokenType string `json:"token_type"`
|
|
||||||
ExpiresIn int64 `json:"expires_in"` // relative seconds from now
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(body, &tokenRes); err != nil {
|
if err := json.Unmarshal(body, &tokenRes); err != nil {
|
||||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
||||||
}
|
}
|
||||||
|
40
jws/jws.go
40
jws/jws.go
@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
// Package jws provides a partial implementation
|
// Package jws provides a partial implementation
|
||||||
// of JSON Web Signature encoding and decoding.
|
// of JSON Web Signature encoding and decoding.
|
||||||
// It exists to support the golang.org/x/oauth2 package.
|
// It exists to support the [golang.org/x/oauth2] package.
|
||||||
//
|
//
|
||||||
// See RFC 7515.
|
// See RFC 7515.
|
||||||
//
|
//
|
||||||
@ -48,7 +48,7 @@ type ClaimSet struct {
|
|||||||
|
|
||||||
// See http://tools.ietf.org/html/draft-jones-json-web-token-10#section-4.3
|
// See http://tools.ietf.org/html/draft-jones-json-web-token-10#section-4.3
|
||||||
// This array is marshalled using custom code (see (c *ClaimSet) encode()).
|
// This array is marshalled using custom code (see (c *ClaimSet) encode()).
|
||||||
PrivateClaims map[string]interface{} `json:"-"`
|
PrivateClaims map[string]any `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ClaimSet) encode() (string, error) {
|
func (c *ClaimSet) encode() (string, error) {
|
||||||
@ -116,12 +116,12 @@ func (h *Header) encode() (string, error) {
|
|||||||
// Decode decodes a claim set from a JWS payload.
|
// Decode decodes a claim set from a JWS payload.
|
||||||
func Decode(payload string) (*ClaimSet, error) {
|
func Decode(payload string) (*ClaimSet, error) {
|
||||||
// decode returned id token to get expiry
|
// decode returned id token to get expiry
|
||||||
s := strings.Split(payload, ".")
|
_, claims, _, ok := parseToken(payload)
|
||||||
if len(s) < 2 {
|
if !ok {
|
||||||
// TODO(jbd): Provide more context about the error.
|
// TODO(jbd): Provide more context about the error.
|
||||||
return nil, errors.New("jws: invalid token received")
|
return nil, errors.New("jws: invalid token received")
|
||||||
}
|
}
|
||||||
decoded, err := base64.RawURLEncoding.DecodeString(s[1])
|
decoded, err := base64.RawURLEncoding.DecodeString(claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -152,7 +152,7 @@ func EncodeWithSigner(header *Header, c *ClaimSet, sg Signer) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes a signed JWS with provided header and claim set.
|
// Encode encodes a signed JWS with provided header and claim set.
|
||||||
// This invokes EncodeWithSigner using crypto/rsa.SignPKCS1v15 with the given RSA private key.
|
// This invokes [EncodeWithSigner] using [crypto/rsa.SignPKCS1v15] with the given RSA private key.
|
||||||
func Encode(header *Header, c *ClaimSet, key *rsa.PrivateKey) (string, error) {
|
func Encode(header *Header, c *ClaimSet, key *rsa.PrivateKey) (string, error) {
|
||||||
sg := func(data []byte) (sig []byte, err error) {
|
sg := func(data []byte) (sig []byte, err error) {
|
||||||
h := sha256.New()
|
h := sha256.New()
|
||||||
@ -165,18 +165,34 @@ func Encode(header *Header, c *ClaimSet, key *rsa.PrivateKey) (string, error) {
|
|||||||
// Verify tests whether the provided JWT token's signature was produced by the private key
|
// Verify tests whether the provided JWT token's signature was produced by the private key
|
||||||
// associated with the supplied public key.
|
// associated with the supplied public key.
|
||||||
func Verify(token string, key *rsa.PublicKey) error {
|
func Verify(token string, key *rsa.PublicKey) error {
|
||||||
parts := strings.Split(token, ".")
|
header, claims, sig, ok := parseToken(token)
|
||||||
if len(parts) != 3 {
|
if !ok {
|
||||||
return errors.New("jws: invalid token received, token must have 3 parts")
|
return errors.New("jws: invalid token received, token must have 3 parts")
|
||||||
}
|
}
|
||||||
|
signatureString, err := base64.RawURLEncoding.DecodeString(sig)
|
||||||
signedContent := parts[0] + "." + parts[1]
|
|
||||||
signatureString, err := base64.RawURLEncoding.DecodeString(parts[2])
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
h := sha256.New()
|
h := sha256.New()
|
||||||
h.Write([]byte(signedContent))
|
h.Write([]byte(header + tokenDelim + claims))
|
||||||
return rsa.VerifyPKCS1v15(key, crypto.SHA256, h.Sum(nil), signatureString)
|
return rsa.VerifyPKCS1v15(key, crypto.SHA256, h.Sum(nil), signatureString)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseToken(s string) (header, claims, sig string, ok bool) {
|
||||||
|
header, s, ok = strings.Cut(s, tokenDelim)
|
||||||
|
if !ok { // no period found
|
||||||
|
return "", "", "", false
|
||||||
|
}
|
||||||
|
claims, s, ok = strings.Cut(s, tokenDelim)
|
||||||
|
if !ok { // only one period found
|
||||||
|
return "", "", "", false
|
||||||
|
}
|
||||||
|
sig, _, ok = strings.Cut(s, tokenDelim)
|
||||||
|
if ok { // three periods found
|
||||||
|
return "", "", "", false
|
||||||
|
}
|
||||||
|
return header, claims, sig, true
|
||||||
|
}
|
||||||
|
|
||||||
|
const tokenDelim = "."
|
||||||
|
@ -7,6 +7,8 @@ package jws
|
|||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -39,8 +41,57 @@ func TestSignAndVerify(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestVerifyFailsOnMalformedClaim(t *testing.T) {
|
func TestVerifyFailsOnMalformedClaim(t *testing.T) {
|
||||||
err := Verify("abc.def", nil)
|
cases := []struct {
|
||||||
|
desc string
|
||||||
|
token string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "no periods",
|
||||||
|
token: "aa",
|
||||||
|
}, {
|
||||||
|
desc: "only one period",
|
||||||
|
token: "a.a",
|
||||||
|
}, {
|
||||||
|
desc: "more than two periods",
|
||||||
|
token: "a.a.a.a",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tc := range cases {
|
||||||
|
f := func(t *testing.T) {
|
||||||
|
err := Verify(tc.token, nil)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("got no errors; want improperly formed JWT not to be verified")
|
t.Error("got no errors; want improperly formed JWT not to be verified")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
t.Run(tc.desc, f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkVerify(b *testing.B) {
|
||||||
|
cases := []struct {
|
||||||
|
desc string
|
||||||
|
token string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "full of periods",
|
||||||
|
token: strings.Repeat(".", http.DefaultMaxHeaderBytes),
|
||||||
|
}, {
|
||||||
|
desc: "two trailing periods",
|
||||||
|
token: strings.Repeat("a", http.DefaultMaxHeaderBytes-2) + "..",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
for _, bc := range cases {
|
||||||
|
f := func(b *testing.B) {
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for range b.N {
|
||||||
|
Verify(bc.token, &privateKey.PublicKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
b.Run(bc.desc, f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
11
jwt/jwt.go
11
jwt/jwt.go
@ -13,7 +13,6 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
@ -69,7 +68,7 @@ type Config struct {
|
|||||||
|
|
||||||
// PrivateClaims optionally specifies custom private claims in the JWT.
|
// PrivateClaims optionally specifies custom private claims in the JWT.
|
||||||
// See http://tools.ietf.org/html/draft-jones-json-web-token-10#section-4.3
|
// See http://tools.ietf.org/html/draft-jones-json-web-token-10#section-4.3
|
||||||
PrivateClaims map[string]interface{}
|
PrivateClaims map[string]any
|
||||||
|
|
||||||
// UseIDToken optionally specifies whether ID token should be used instead
|
// UseIDToken optionally specifies whether ID token should be used instead
|
||||||
// of access token when the server returns both.
|
// of access token when the server returns both.
|
||||||
@ -136,7 +135,7 @@ func (js jwtSource) Token() (*oauth2.Token, error) {
|
|||||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
||||||
}
|
}
|
||||||
@ -148,10 +147,8 @@ func (js jwtSource) Token() (*oauth2.Token, error) {
|
|||||||
}
|
}
|
||||||
// tokenRes is the JSON response body.
|
// tokenRes is the JSON response body.
|
||||||
var tokenRes struct {
|
var tokenRes struct {
|
||||||
AccessToken string `json:"access_token"`
|
oauth2.Token
|
||||||
TokenType string `json:"token_type"`
|
|
||||||
IDToken string `json:"id_token"`
|
IDToken string `json:"id_token"`
|
||||||
ExpiresIn int64 `json:"expires_in"` // relative seconds from now
|
|
||||||
}
|
}
|
||||||
if err := json.Unmarshal(body, &tokenRes); err != nil {
|
if err := json.Unmarshal(body, &tokenRes); err != nil {
|
||||||
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
||||||
@ -160,7 +157,7 @@ func (js jwtSource) Token() (*oauth2.Token, error) {
|
|||||||
AccessToken: tokenRes.AccessToken,
|
AccessToken: tokenRes.AccessToken,
|
||||||
TokenType: tokenRes.TokenType,
|
TokenType: tokenRes.TokenType,
|
||||||
}
|
}
|
||||||
raw := make(map[string]interface{})
|
raw := make(map[string]any)
|
||||||
json.Unmarshal(body, &raw) // no error checks for optional fields
|
json.Unmarshal(body, &raw) // no error checks for optional fields
|
||||||
token = token.WithExtra(raw)
|
token = token.WithExtra(raw)
|
||||||
|
|
||||||
|
@ -227,7 +227,7 @@ func TestJWTFetch_AssertionPayload(t *testing.T) {
|
|||||||
PrivateKey: dummyPrivateKey,
|
PrivateKey: dummyPrivateKey,
|
||||||
PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
|
PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
|
||||||
TokenURL: ts.URL,
|
TokenURL: ts.URL,
|
||||||
PrivateClaims: map[string]interface{}{
|
PrivateClaims: map[string]any{
|
||||||
"private0": "claim0",
|
"private0": "claim0",
|
||||||
"private1": "claim1",
|
"private1": "claim1",
|
||||||
},
|
},
|
||||||
@ -273,11 +273,11 @@ func TestJWTFetch_AssertionPayload(t *testing.T) {
|
|||||||
t.Errorf("payload prn = %q; want %q", got, want)
|
t.Errorf("payload prn = %q; want %q", got, want)
|
||||||
}
|
}
|
||||||
if len(conf.PrivateClaims) > 0 {
|
if len(conf.PrivateClaims) > 0 {
|
||||||
var got interface{}
|
var got any
|
||||||
if err := json.Unmarshal(gotjson, &got); err != nil {
|
if err := json.Unmarshal(gotjson, &got); err != nil {
|
||||||
t.Errorf("failed to parse payload; err = %q", err)
|
t.Errorf("failed to parse payload; err = %q", err)
|
||||||
}
|
}
|
||||||
m := got.(map[string]interface{})
|
m := got.(map[string]any)
|
||||||
for v, k := range conf.PrivateClaims {
|
for v, k := range conf.PrivateClaims {
|
||||||
if !reflect.DeepEqual(m[v], k) {
|
if !reflect.DeepEqual(m[v], k) {
|
||||||
t.Errorf("payload private claims key = %q: got %#v; want %#v", v, m[v], k)
|
t.Errorf("payload private claims key = %q: got %#v; want %#v", v, m[v], k)
|
||||||
|
63
oauth2.go
63
oauth2.go
@ -22,9 +22,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// NoContext is the default context you should supply if not using
|
// NoContext is the default context you should supply if not using
|
||||||
// your own context.Context (see https://golang.org/x/net/context).
|
// your own [context.Context].
|
||||||
//
|
//
|
||||||
// Deprecated: Use context.Background() or context.TODO() instead.
|
// Deprecated: Use [context.Background] or [context.TODO] instead.
|
||||||
var NoContext = context.TODO()
|
var NoContext = context.TODO()
|
||||||
|
|
||||||
// RegisterBrokenAuthHeaderProvider previously did something. It is now a no-op.
|
// RegisterBrokenAuthHeaderProvider previously did something. It is now a no-op.
|
||||||
@ -37,8 +37,8 @@ func RegisterBrokenAuthHeaderProvider(tokenURL string) {}
|
|||||||
|
|
||||||
// Config describes a typical 3-legged OAuth2 flow, with both the
|
// Config describes a typical 3-legged OAuth2 flow, with both the
|
||||||
// client application information and the server's endpoint URLs.
|
// client application information and the server's endpoint URLs.
|
||||||
// For the client credentials 2-legged OAuth2 flow, see the clientcredentials
|
// For the client credentials 2-legged OAuth2 flow, see the
|
||||||
// package (https://golang.org/x/oauth2/clientcredentials).
|
// [golang.org/x/oauth2/clientcredentials] package.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
// ClientID is the application's ID.
|
// ClientID is the application's ID.
|
||||||
ClientID string
|
ClientID string
|
||||||
@ -46,7 +46,7 @@ type Config struct {
|
|||||||
// ClientSecret is the application's secret.
|
// ClientSecret is the application's secret.
|
||||||
ClientSecret string
|
ClientSecret string
|
||||||
|
|
||||||
// Endpoint contains the resource server's token endpoint
|
// Endpoint contains the authorization server's token endpoint
|
||||||
// URLs. These are constants specific to each server and are
|
// URLs. These are constants specific to each server and are
|
||||||
// often available via site-specific packages, such as
|
// often available via site-specific packages, such as
|
||||||
// google.Endpoint or github.Endpoint.
|
// google.Endpoint or github.Endpoint.
|
||||||
@ -135,7 +135,7 @@ type setParam struct{ k, v string }
|
|||||||
|
|
||||||
func (p setParam) setValue(m url.Values) { m.Set(p.k, p.v) }
|
func (p setParam) setValue(m url.Values) { m.Set(p.k, p.v) }
|
||||||
|
|
||||||
// SetAuthURLParam builds an AuthCodeOption which passes key/value parameters
|
// SetAuthURLParam builds an [AuthCodeOption] which passes key/value parameters
|
||||||
// to a provider's authorization endpoint.
|
// to a provider's authorization endpoint.
|
||||||
func SetAuthURLParam(key, value string) AuthCodeOption {
|
func SetAuthURLParam(key, value string) AuthCodeOption {
|
||||||
return setParam{key, value}
|
return setParam{key, value}
|
||||||
@ -148,8 +148,8 @@ func SetAuthURLParam(key, value string) AuthCodeOption {
|
|||||||
// request and callback. The authorization server includes this value when
|
// request and callback. The authorization server includes this value when
|
||||||
// redirecting the user agent back to the client.
|
// redirecting the user agent back to the client.
|
||||||
//
|
//
|
||||||
// Opts may include AccessTypeOnline or AccessTypeOffline, as well
|
// Opts may include [AccessTypeOnline] or [AccessTypeOffline], as well
|
||||||
// as ApprovalForce.
|
// as [ApprovalForce].
|
||||||
//
|
//
|
||||||
// To protect against CSRF attacks, opts should include a PKCE challenge
|
// To protect against CSRF attacks, opts should include a PKCE challenge
|
||||||
// (S256ChallengeOption). Not all servers support PKCE. An alternative is to
|
// (S256ChallengeOption). Not all servers support PKCE. An alternative is to
|
||||||
@ -194,7 +194,7 @@ func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
|
|||||||
// and when other authorization grant types are not available."
|
// and when other authorization grant types are not available."
|
||||||
// See https://tools.ietf.org/html/rfc6749#section-4.3 for more info.
|
// See https://tools.ietf.org/html/rfc6749#section-4.3 for more info.
|
||||||
//
|
//
|
||||||
// The provided context optionally controls which HTTP client is used. See the HTTPClient variable.
|
// The provided context optionally controls which HTTP client is used. See the [HTTPClient] variable.
|
||||||
func (c *Config) PasswordCredentialsToken(ctx context.Context, username, password string) (*Token, error) {
|
func (c *Config) PasswordCredentialsToken(ctx context.Context, username, password string) (*Token, error) {
|
||||||
v := url.Values{
|
v := url.Values{
|
||||||
"grant_type": {"password"},
|
"grant_type": {"password"},
|
||||||
@ -212,10 +212,10 @@ func (c *Config) PasswordCredentialsToken(ctx context.Context, username, passwor
|
|||||||
// It is used after a resource provider redirects the user back
|
// It is used after a resource provider redirects the user back
|
||||||
// to the Redirect URI (the URL obtained from AuthCodeURL).
|
// to the Redirect URI (the URL obtained from AuthCodeURL).
|
||||||
//
|
//
|
||||||
// The provided context optionally controls which HTTP client is used. See the HTTPClient variable.
|
// The provided context optionally controls which HTTP client is used. See the [HTTPClient] variable.
|
||||||
//
|
//
|
||||||
// The code will be in the *http.Request.FormValue("code"). Before
|
// The code will be in the [http.Request.FormValue]("code"). Before
|
||||||
// calling Exchange, be sure to validate FormValue("state") if you are
|
// calling Exchange, be sure to validate [http.Request.FormValue]("state") if you are
|
||||||
// using it to protect against CSRF attacks.
|
// using it to protect against CSRF attacks.
|
||||||
//
|
//
|
||||||
// If using PKCE to protect against CSRF attacks, opts should include a
|
// If using PKCE to protect against CSRF attacks, opts should include a
|
||||||
@ -242,10 +242,10 @@ func (c *Config) Client(ctx context.Context, t *Token) *http.Client {
|
|||||||
return NewClient(ctx, c.TokenSource(ctx, t))
|
return NewClient(ctx, c.TokenSource(ctx, t))
|
||||||
}
|
}
|
||||||
|
|
||||||
// TokenSource returns a TokenSource that returns t until t expires,
|
// TokenSource returns a [TokenSource] that returns t until t expires,
|
||||||
// automatically refreshing it as necessary using the provided context.
|
// automatically refreshing it as necessary using the provided context.
|
||||||
//
|
//
|
||||||
// Most users will use Config.Client instead.
|
// Most users will use [Config.Client] instead.
|
||||||
func (c *Config) TokenSource(ctx context.Context, t *Token) TokenSource {
|
func (c *Config) TokenSource(ctx context.Context, t *Token) TokenSource {
|
||||||
tkr := &tokenRefresher{
|
tkr := &tokenRefresher{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
@ -260,7 +260,7 @@ func (c *Config) TokenSource(ctx context.Context, t *Token) TokenSource {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// tokenRefresher is a TokenSource that makes "grant_type"=="refresh_token"
|
// tokenRefresher is a TokenSource that makes "grant_type=refresh_token"
|
||||||
// HTTP requests to renew a token using a RefreshToken.
|
// HTTP requests to renew a token using a RefreshToken.
|
||||||
type tokenRefresher struct {
|
type tokenRefresher struct {
|
||||||
ctx context.Context // used to get HTTP requests
|
ctx context.Context // used to get HTTP requests
|
||||||
@ -288,7 +288,7 @@ func (tf *tokenRefresher) Token() (*Token, error) {
|
|||||||
if tf.refreshToken != tk.RefreshToken {
|
if tf.refreshToken != tk.RefreshToken {
|
||||||
tf.refreshToken = tk.RefreshToken
|
tf.refreshToken = tk.RefreshToken
|
||||||
}
|
}
|
||||||
return tk, err
|
return tk, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// reuseTokenSource is a TokenSource that holds a single token in memory
|
// reuseTokenSource is a TokenSource that holds a single token in memory
|
||||||
@ -305,8 +305,7 @@ type reuseTokenSource struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Token returns the current token if it's still valid, else will
|
// Token returns the current token if it's still valid, else will
|
||||||
// refresh the current token (using r.Context for HTTP client
|
// refresh the current token and return the new one.
|
||||||
// information) and return the new one.
|
|
||||||
func (s *reuseTokenSource) Token() (*Token, error) {
|
func (s *reuseTokenSource) Token() (*Token, error) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
@ -322,7 +321,7 @@ func (s *reuseTokenSource) Token() (*Token, error) {
|
|||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// StaticTokenSource returns a TokenSource that always returns the same token.
|
// StaticTokenSource returns a [TokenSource] that always returns the same token.
|
||||||
// Because the provided token t is never refreshed, StaticTokenSource is only
|
// Because the provided token t is never refreshed, StaticTokenSource is only
|
||||||
// useful for tokens that never expire.
|
// useful for tokens that never expire.
|
||||||
func StaticTokenSource(t *Token) TokenSource {
|
func StaticTokenSource(t *Token) TokenSource {
|
||||||
@ -338,16 +337,16 @@ func (s staticTokenSource) Token() (*Token, error) {
|
|||||||
return s.t, nil
|
return s.t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// HTTPClient is the context key to use with golang.org/x/net/context's
|
// HTTPClient is the context key to use with [context.WithValue]
|
||||||
// WithValue function to associate an *http.Client value with a context.
|
// to associate a [*http.Client] value with a context.
|
||||||
var HTTPClient internal.ContextKey
|
var HTTPClient internal.ContextKey
|
||||||
|
|
||||||
// NewClient creates an *http.Client from a Context and TokenSource.
|
// NewClient creates an [*http.Client] from a [context.Context] and [TokenSource].
|
||||||
// The returned client is not valid beyond the lifetime of the context.
|
// The returned client is not valid beyond the lifetime of the context.
|
||||||
//
|
//
|
||||||
// Note that if a custom *http.Client is provided via the Context it
|
// Note that if a custom [*http.Client] is provided via the [context.Context] it
|
||||||
// is used only for token acquisition and is not used to configure the
|
// is used only for token acquisition and is not used to configure the
|
||||||
// *http.Client returned from NewClient.
|
// [*http.Client] returned from NewClient.
|
||||||
//
|
//
|
||||||
// As a special case, if src is nil, a non-OAuth2 client is returned
|
// As a special case, if src is nil, a non-OAuth2 client is returned
|
||||||
// using the provided context. This exists to support related OAuth2
|
// using the provided context. This exists to support related OAuth2
|
||||||
@ -356,15 +355,19 @@ func NewClient(ctx context.Context, src TokenSource) *http.Client {
|
|||||||
if src == nil {
|
if src == nil {
|
||||||
return internal.ContextClient(ctx)
|
return internal.ContextClient(ctx)
|
||||||
}
|
}
|
||||||
|
cc := internal.ContextClient(ctx)
|
||||||
return &http.Client{
|
return &http.Client{
|
||||||
Transport: &Transport{
|
Transport: &Transport{
|
||||||
Base: internal.ContextClient(ctx).Transport,
|
Base: cc.Transport,
|
||||||
Source: ReuseTokenSource(nil, src),
|
Source: ReuseTokenSource(nil, src),
|
||||||
},
|
},
|
||||||
|
CheckRedirect: cc.CheckRedirect,
|
||||||
|
Jar: cc.Jar,
|
||||||
|
Timeout: cc.Timeout,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReuseTokenSource returns a TokenSource which repeatedly returns the
|
// ReuseTokenSource returns a [TokenSource] which repeatedly returns the
|
||||||
// same token as long as it's valid, starting with t.
|
// same token as long as it's valid, starting with t.
|
||||||
// When its cached token is invalid, a new token is obtained from src.
|
// When its cached token is invalid, a new token is obtained from src.
|
||||||
//
|
//
|
||||||
@ -372,10 +375,10 @@ func NewClient(ctx context.Context, src TokenSource) *http.Client {
|
|||||||
// (such as a file on disk) between runs of a program, rather than
|
// (such as a file on disk) between runs of a program, rather than
|
||||||
// obtaining new tokens unnecessarily.
|
// obtaining new tokens unnecessarily.
|
||||||
//
|
//
|
||||||
// The initial token t may be nil, in which case the TokenSource is
|
// The initial token t may be nil, in which case the [TokenSource] is
|
||||||
// wrapped in a caching version if it isn't one already. This also
|
// wrapped in a caching version if it isn't one already. This also
|
||||||
// means it's always safe to wrap ReuseTokenSource around any other
|
// means it's always safe to wrap ReuseTokenSource around any other
|
||||||
// TokenSource without adverse effects.
|
// [TokenSource] without adverse effects.
|
||||||
func ReuseTokenSource(t *Token, src TokenSource) TokenSource {
|
func ReuseTokenSource(t *Token, src TokenSource) TokenSource {
|
||||||
// Don't wrap a reuseTokenSource in itself. That would work,
|
// Don't wrap a reuseTokenSource in itself. That would work,
|
||||||
// but cause an unnecessary number of mutex operations.
|
// but cause an unnecessary number of mutex operations.
|
||||||
@ -393,8 +396,8 @@ func ReuseTokenSource(t *Token, src TokenSource) TokenSource {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReuseTokenSourceWithExpiry returns a TokenSource that acts in the same manner as the
|
// ReuseTokenSourceWithExpiry returns a [TokenSource] that acts in the same manner as the
|
||||||
// TokenSource returned by ReuseTokenSource, except the expiry buffer is
|
// [TokenSource] returned by [ReuseTokenSource], except the expiry buffer is
|
||||||
// configurable. The expiration time of a token is calculated as
|
// configurable. The expiration time of a token is calculated as
|
||||||
// t.Expiry.Add(-earlyExpiry).
|
// t.Expiry.Add(-earlyExpiry).
|
||||||
func ReuseTokenSourceWithExpiry(t *Token, src TokenSource, earlyExpiry time.Duration) TokenSource {
|
func ReuseTokenSourceWithExpiry(t *Token, src TokenSource, earlyExpiry time.Duration) TokenSource {
|
||||||
|
@ -9,7 +9,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
@ -104,7 +103,7 @@ func TestExchangeRequest(t *testing.T) {
|
|||||||
if headerContentType != "application/x-www-form-urlencoded" {
|
if headerContentType != "application/x-www-form-urlencoded" {
|
||||||
t.Errorf("Unexpected Content-Type header %q", headerContentType)
|
t.Errorf("Unexpected Content-Type header %q", headerContentType)
|
||||||
}
|
}
|
||||||
body, err := ioutil.ReadAll(r.Body)
|
body, err := io.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Failed reading request body: %s.", err)
|
t.Errorf("Failed reading request body: %s.", err)
|
||||||
}
|
}
|
||||||
@ -148,7 +147,7 @@ func TestExchangeRequest_CustomParam(t *testing.T) {
|
|||||||
if headerContentType != "application/x-www-form-urlencoded" {
|
if headerContentType != "application/x-www-form-urlencoded" {
|
||||||
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType)
|
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType)
|
||||||
}
|
}
|
||||||
body, err := ioutil.ReadAll(r.Body)
|
body, err := io.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Failed reading request body: %s.", err)
|
t.Errorf("Failed reading request body: %s.", err)
|
||||||
}
|
}
|
||||||
@ -194,7 +193,7 @@ func TestExchangeRequest_JSONResponse(t *testing.T) {
|
|||||||
if headerContentType != "application/x-www-form-urlencoded" {
|
if headerContentType != "application/x-www-form-urlencoded" {
|
||||||
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType)
|
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType)
|
||||||
}
|
}
|
||||||
body, err := ioutil.ReadAll(r.Body)
|
body, err := io.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Failed reading request body: %s.", err)
|
t.Errorf("Failed reading request body: %s.", err)
|
||||||
}
|
}
|
||||||
@ -301,7 +300,7 @@ func testExchangeRequest_JSONResponse_expiry(t *testing.T, exp string, want, nul
|
|||||||
conf := newConf(ts.URL)
|
conf := newConf(ts.URL)
|
||||||
t1 := time.Now().Add(day)
|
t1 := time.Now().Add(day)
|
||||||
tok, err := conf.Exchange(context.Background(), "exchange-code")
|
tok, err := conf.Exchange(context.Background(), "exchange-code")
|
||||||
t2 := t1.Add(day)
|
t2 := time.Now().Add(day)
|
||||||
|
|
||||||
if got := (err == nil); got != want {
|
if got := (err == nil); got != want {
|
||||||
if want {
|
if want {
|
||||||
@ -393,7 +392,7 @@ func TestPasswordCredentialsTokenRequest(t *testing.T) {
|
|||||||
if headerContentType != expected {
|
if headerContentType != expected {
|
||||||
t.Errorf("Content-Type header = %q; want %q", headerContentType, expected)
|
t.Errorf("Content-Type header = %q; want %q", headerContentType, expected)
|
||||||
}
|
}
|
||||||
body, err := ioutil.ReadAll(r.Body)
|
body, err := io.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Failed reading request body: %s.", err)
|
t.Errorf("Failed reading request body: %s.", err)
|
||||||
}
|
}
|
||||||
@ -435,7 +434,7 @@ func TestTokenRefreshRequest(t *testing.T) {
|
|||||||
if headerContentType != "application/x-www-form-urlencoded" {
|
if headerContentType != "application/x-www-form-urlencoded" {
|
||||||
t.Errorf("Unexpected Content-Type header %q", headerContentType)
|
t.Errorf("Unexpected Content-Type header %q", headerContentType)
|
||||||
}
|
}
|
||||||
body, _ := ioutil.ReadAll(r.Body)
|
body, _ := io.ReadAll(r.Body)
|
||||||
if string(body) != "grant_type=refresh_token&refresh_token=REFRESH_TOKEN" {
|
if string(body) != "grant_type=refresh_token&refresh_token=REFRESH_TOKEN" {
|
||||||
t.Errorf("Unexpected refresh token payload %q", body)
|
t.Errorf("Unexpected refresh token payload %q", body)
|
||||||
}
|
}
|
||||||
@ -460,7 +459,7 @@ func TestFetchWithNoRefreshToken(t *testing.T) {
|
|||||||
if headerContentType != "application/x-www-form-urlencoded" {
|
if headerContentType != "application/x-www-form-urlencoded" {
|
||||||
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType)
|
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType)
|
||||||
}
|
}
|
||||||
body, _ := ioutil.ReadAll(r.Body)
|
body, _ := io.ReadAll(r.Body)
|
||||||
if string(body) != "client_id=CLIENT_ID&grant_type=refresh_token&refresh_token=REFRESH_TOKEN" {
|
if string(body) != "client_id=CLIENT_ID&grant_type=refresh_token&refresh_token=REFRESH_TOKEN" {
|
||||||
t.Errorf("Unexpected refresh token payload, %v is found.", string(body))
|
t.Errorf("Unexpected refresh token payload, %v is found.", string(body))
|
||||||
}
|
}
|
||||||
|
15
pkce.go
15
pkce.go
@ -1,6 +1,7 @@
|
|||||||
// Copyright 2023 The Go Authors. All rights reserved.
|
// Copyright 2023 The Go Authors. All rights reserved.
|
||||||
// Use of this source code is governed by a BSD-style
|
// Use of this source code is governed by a BSD-style
|
||||||
// license that can be found in the LICENSE file.
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
package oauth2
|
package oauth2
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@ -20,9 +21,9 @@ const (
|
|||||||
// This follows recommendations in RFC 7636.
|
// This follows recommendations in RFC 7636.
|
||||||
//
|
//
|
||||||
// A fresh verifier should be generated for each authorization.
|
// A fresh verifier should be generated for each authorization.
|
||||||
// S256ChallengeOption(verifier) should then be passed to Config.AuthCodeURL
|
// The resulting verifier should be passed to [Config.AuthCodeURL] or [Config.DeviceAuth]
|
||||||
// (or Config.DeviceAccess) and VerifierOption(verifier) to Config.Exchange
|
// with [S256ChallengeOption], and to [Config.Exchange] or [Config.DeviceAccessToken]
|
||||||
// (or Config.DeviceAccessToken).
|
// with [VerifierOption].
|
||||||
func GenerateVerifier() string {
|
func GenerateVerifier() string {
|
||||||
// "RECOMMENDED that the output of a suitable random number generator be
|
// "RECOMMENDED that the output of a suitable random number generator be
|
||||||
// used to create a 32-octet sequence. The octet sequence is then
|
// used to create a 32-octet sequence. The octet sequence is then
|
||||||
@ -36,22 +37,22 @@ func GenerateVerifier() string {
|
|||||||
return base64.RawURLEncoding.EncodeToString(data)
|
return base64.RawURLEncoding.EncodeToString(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerifierOption returns a PKCE code verifier AuthCodeOption. It should be
|
// VerifierOption returns a PKCE code verifier [AuthCodeOption]. It should only be
|
||||||
// passed to Config.Exchange or Config.DeviceAccessToken only.
|
// passed to [Config.Exchange] or [Config.DeviceAccessToken].
|
||||||
func VerifierOption(verifier string) AuthCodeOption {
|
func VerifierOption(verifier string) AuthCodeOption {
|
||||||
return setParam{k: codeVerifierKey, v: verifier}
|
return setParam{k: codeVerifierKey, v: verifier}
|
||||||
}
|
}
|
||||||
|
|
||||||
// S256ChallengeFromVerifier returns a PKCE code challenge derived from verifier with method S256.
|
// S256ChallengeFromVerifier returns a PKCE code challenge derived from verifier with method S256.
|
||||||
//
|
//
|
||||||
// Prefer to use S256ChallengeOption where possible.
|
// Prefer to use [S256ChallengeOption] where possible.
|
||||||
func S256ChallengeFromVerifier(verifier string) string {
|
func S256ChallengeFromVerifier(verifier string) string {
|
||||||
sha := sha256.Sum256([]byte(verifier))
|
sha := sha256.Sum256([]byte(verifier))
|
||||||
return base64.RawURLEncoding.EncodeToString(sha[:])
|
return base64.RawURLEncoding.EncodeToString(sha[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
// S256ChallengeOption derives a PKCE code challenge derived from verifier with
|
// S256ChallengeOption derives a PKCE code challenge derived from verifier with
|
||||||
// method S256. It should be passed to Config.AuthCodeURL or Config.DeviceAccess
|
// method S256. It should be passed to [Config.AuthCodeURL] or [Config.DeviceAuth]
|
||||||
// only.
|
// only.
|
||||||
func S256ChallengeOption(verifier string) AuthCodeOption {
|
func S256ChallengeOption(verifier string) AuthCodeOption {
|
||||||
return challengeOption{
|
return challengeOption{
|
||||||
|
17
token.go
17
token.go
@ -44,7 +44,7 @@ type Token struct {
|
|||||||
|
|
||||||
// Expiry is the optional expiration time of the access token.
|
// Expiry is the optional expiration time of the access token.
|
||||||
//
|
//
|
||||||
// If zero, TokenSource implementations will reuse the same
|
// If zero, [TokenSource] implementations will reuse the same
|
||||||
// token forever and RefreshToken or equivalent
|
// token forever and RefreshToken or equivalent
|
||||||
// mechanisms for that TokenSource will not be used.
|
// mechanisms for that TokenSource will not be used.
|
||||||
Expiry time.Time `json:"expiry,omitempty"`
|
Expiry time.Time `json:"expiry,omitempty"`
|
||||||
@ -58,7 +58,7 @@ type Token struct {
|
|||||||
|
|
||||||
// raw optionally contains extra metadata from the server
|
// raw optionally contains extra metadata from the server
|
||||||
// when updating a token.
|
// when updating a token.
|
||||||
raw interface{}
|
raw any
|
||||||
|
|
||||||
// expiryDelta is used to calculate when a token is considered
|
// expiryDelta is used to calculate when a token is considered
|
||||||
// expired, by subtracting from Expiry. If zero, defaultExpiryDelta
|
// expired, by subtracting from Expiry. If zero, defaultExpiryDelta
|
||||||
@ -86,16 +86,16 @@ func (t *Token) Type() string {
|
|||||||
// SetAuthHeader sets the Authorization header to r using the access
|
// SetAuthHeader sets the Authorization header to r using the access
|
||||||
// token in t.
|
// token in t.
|
||||||
//
|
//
|
||||||
// This method is unnecessary when using Transport or an HTTP Client
|
// This method is unnecessary when using [Transport] or an HTTP Client
|
||||||
// returned by this package.
|
// returned by this package.
|
||||||
func (t *Token) SetAuthHeader(r *http.Request) {
|
func (t *Token) SetAuthHeader(r *http.Request) {
|
||||||
r.Header.Set("Authorization", t.Type()+" "+t.AccessToken)
|
r.Header.Set("Authorization", t.Type()+" "+t.AccessToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithExtra returns a new Token that's a clone of t, but using the
|
// WithExtra returns a new [Token] that's a clone of t, but using the
|
||||||
// provided raw extra map. This is only intended for use by packages
|
// provided raw extra map. This is only intended for use by packages
|
||||||
// implementing derivative OAuth2 flows.
|
// implementing derivative OAuth2 flows.
|
||||||
func (t *Token) WithExtra(extra interface{}) *Token {
|
func (t *Token) WithExtra(extra any) *Token {
|
||||||
t2 := new(Token)
|
t2 := new(Token)
|
||||||
*t2 = *t
|
*t2 = *t
|
||||||
t2.raw = extra
|
t2.raw = extra
|
||||||
@ -105,8 +105,8 @@ func (t *Token) WithExtra(extra interface{}) *Token {
|
|||||||
// Extra returns an extra field.
|
// Extra returns an extra field.
|
||||||
// Extra fields are key-value pairs returned by the server as a
|
// Extra fields are key-value pairs returned by the server as a
|
||||||
// part of the token retrieval response.
|
// part of the token retrieval response.
|
||||||
func (t *Token) Extra(key string) interface{} {
|
func (t *Token) Extra(key string) any {
|
||||||
if raw, ok := t.raw.(map[string]interface{}); ok {
|
if raw, ok := t.raw.(map[string]any); ok {
|
||||||
return raw[key]
|
return raw[key]
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -163,13 +163,14 @@ func tokenFromInternal(t *internal.Token) *Token {
|
|||||||
TokenType: t.TokenType,
|
TokenType: t.TokenType,
|
||||||
RefreshToken: t.RefreshToken,
|
RefreshToken: t.RefreshToken,
|
||||||
Expiry: t.Expiry,
|
Expiry: t.Expiry,
|
||||||
|
ExpiresIn: t.ExpiresIn,
|
||||||
raw: t.Raw,
|
raw: t.Raw,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// retrieveToken takes a *Config and uses that to retrieve an *internal.Token.
|
// retrieveToken takes a *Config and uses that to retrieve an *internal.Token.
|
||||||
// This token is then mapped from *internal.Token into an *oauth2.Token which is returned along
|
// This token is then mapped from *internal.Token into an *oauth2.Token which is returned along
|
||||||
// with an error..
|
// with an error.
|
||||||
func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) {
|
func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) {
|
||||||
tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle), c.authStyleCache.Get())
|
tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle), c.authStyleCache.Get())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -12,8 +12,8 @@ import (
|
|||||||
func TestTokenExtra(t *testing.T) {
|
func TestTokenExtra(t *testing.T) {
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
key string
|
key string
|
||||||
val interface{}
|
val any
|
||||||
want interface{}
|
want any
|
||||||
}
|
}
|
||||||
const key = "extra-key"
|
const key = "extra-key"
|
||||||
cases := []testCase{
|
cases := []testCase{
|
||||||
@ -23,7 +23,7 @@ func TestTokenExtra(t *testing.T) {
|
|||||||
{key: "other-key", val: "def", want: nil},
|
{key: "other-key", val: "def", want: nil},
|
||||||
}
|
}
|
||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
extra := make(map[string]interface{})
|
extra := make(map[string]any)
|
||||||
extra[tc.key] = tc.val
|
extra[tc.key] = tc.val
|
||||||
tok := &Token{raw: extra}
|
tok := &Token{raw: extra}
|
||||||
if got, want := tok.Extra(key), tc.want; got != want {
|
if got, want := tok.Extra(key), tc.want; got != want {
|
||||||
|
24
transport.go
24
transport.go
@ -11,12 +11,12 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests,
|
// Transport is an [http.RoundTripper] that makes OAuth 2.0 HTTP requests,
|
||||||
// wrapping a base RoundTripper and adding an Authorization header
|
// wrapping a base [http.RoundTripper] and adding an Authorization header
|
||||||
// with a token from the supplied Sources.
|
// with a token from the supplied [TokenSource].
|
||||||
//
|
//
|
||||||
// Transport is a low-level mechanism. Most code will use the
|
// Transport is a low-level mechanism. Most code will use the
|
||||||
// higher-level Config.Client method instead.
|
// higher-level [Config.Client] method instead.
|
||||||
type Transport struct {
|
type Transport struct {
|
||||||
// Source supplies the token to add to outgoing requests'
|
// Source supplies the token to add to outgoing requests'
|
||||||
// Authorization headers.
|
// Authorization headers.
|
||||||
@ -47,7 +47,7 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
req2 := cloneRequest(req) // per RoundTripper contract
|
req2 := req.Clone(req.Context())
|
||||||
token.SetAuthHeader(req2)
|
token.SetAuthHeader(req2)
|
||||||
|
|
||||||
// req.Body is assumed to be closed by the base RoundTripper.
|
// req.Body is assumed to be closed by the base RoundTripper.
|
||||||
@ -73,17 +73,3 @@ func (t *Transport) base() http.RoundTripper {
|
|||||||
}
|
}
|
||||||
return http.DefaultTransport
|
return http.DefaultTransport
|
||||||
}
|
}
|
||||||
|
|
||||||
// cloneRequest returns a clone of the provided *http.Request.
|
|
||||||
// The clone is a shallow copy of the struct and its Header map.
|
|
||||||
func cloneRequest(r *http.Request) *http.Request {
|
|
||||||
// shallow copy of the struct
|
|
||||||
r2 := new(http.Request)
|
|
||||||
*r2 = *r
|
|
||||||
// deep copy of the Header
|
|
||||||
r2.Header = make(http.Header, len(r.Header))
|
|
||||||
for k, s := range r.Header {
|
|
||||||
r2.Header[k] = append([]string(nil), s...)
|
|
||||||
}
|
|
||||||
return r2
|
|
||||||
}
|
|
||||||
|
@ -9,12 +9,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type tokenSource struct{ token *Token }
|
|
||||||
|
|
||||||
func (t *tokenSource) Token() (*Token, error) {
|
|
||||||
return t.token, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTransportNilTokenSource(t *testing.T) {
|
func TestTransportNilTokenSource(t *testing.T) {
|
||||||
tr := &Transport{}
|
tr := &Transport{}
|
||||||
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {})
|
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {})
|
||||||
@ -88,13 +82,10 @@ func TestTransportCloseRequestBodySuccess(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestTransportTokenSource(t *testing.T) {
|
func TestTransportTokenSource(t *testing.T) {
|
||||||
ts := &tokenSource{
|
|
||||||
token: &Token{
|
|
||||||
AccessToken: "abc",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
tr := &Transport{
|
tr := &Transport{
|
||||||
Source: ts,
|
Source: StaticTokenSource(&Token{
|
||||||
|
AccessToken: "abc",
|
||||||
|
}),
|
||||||
}
|
}
|
||||||
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {
|
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if got, want := r.Header.Get("Authorization"), "Bearer abc"; got != want {
|
if got, want := r.Header.Get("Authorization"), "Bearer abc"; got != want {
|
||||||
@ -123,14 +114,11 @@ func TestTransportTokenSourceTypes(t *testing.T) {
|
|||||||
{key: "basic", val: val, want: "Basic abc"},
|
{key: "basic", val: val, want: "Basic abc"},
|
||||||
}
|
}
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
ts := &tokenSource{
|
tr := &Transport{
|
||||||
token: &Token{
|
Source: StaticTokenSource(&Token{
|
||||||
AccessToken: tc.val,
|
AccessToken: tc.val,
|
||||||
TokenType: tc.key,
|
TokenType: tc.key,
|
||||||
},
|
}),
|
||||||
}
|
|
||||||
tr := &Transport{
|
|
||||||
Source: ts,
|
|
||||||
}
|
}
|
||||||
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {
|
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if got, want := r.Header.Get("Authorization"), tc.want; got != want {
|
if got, want := r.Header.Get("Authorization"), tc.want; got != want {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user