Compare commits

..

No commits in common. "master" and "v0.28.0" have entirely different histories.

39 changed files with 243 additions and 512 deletions

View File

@ -34,7 +34,7 @@ type PKCEParams struct {
// and returns an auth code and state upon approval. // and returns an auth code and state upon approval.
type AuthorizationHandler func(authCodeURL string) (code string, state string, err error) type AuthorizationHandler func(authCodeURL string) (code string, state string, err error)
// TokenSourceWithPKCE is an enhanced version of [oauth2.TokenSource] with PKCE support. // TokenSourceWithPKCE is an enhanced version of TokenSource with PKCE support.
// //
// The pkce parameter supports PKCE flow, which uses code challenge and code verifier // The pkce parameter supports PKCE flow, which uses code challenge and code verifier
// to prevent CSRF attacks. A unique code challenge and code verifier should be generated // to prevent CSRF attacks. A unique code challenge and code verifier should be generated
@ -43,12 +43,12 @@ func TokenSourceWithPKCE(ctx context.Context, config *oauth2.Config, state strin
return oauth2.ReuseTokenSource(nil, authHandlerSource{config: config, ctx: ctx, authHandler: authHandler, state: state, pkce: pkce}) return oauth2.ReuseTokenSource(nil, authHandlerSource{config: config, ctx: ctx, authHandler: authHandler, state: state, pkce: pkce})
} }
// TokenSource returns an [oauth2.TokenSource] that fetches access tokens // TokenSource returns an oauth2.TokenSource that fetches access tokens
// using 3-legged-OAuth flow. // using 3-legged-OAuth flow.
// //
// The provided [context.Context] is used for oauth2 Exchange operation. // The provided context.Context is used for oauth2 Exchange operation.
// //
// The provided [oauth2.Config] should be a full configuration containing AuthURL, // The provided oauth2.Config should be a full configuration containing AuthURL,
// TokenURL, and Scope. // TokenURL, and Scope.
// //
// An environment-specific AuthorizationHandler is used to obtain user consent. // An environment-specific AuthorizationHandler is used to obtain user consent.

View File

@ -55,7 +55,7 @@ type Config struct {
// Token uses client credentials to retrieve a token. // Token uses client credentials to retrieve a token.
// //
// The provided context optionally controls which HTTP client is used. See the [oauth2.HTTPClient] variable. // The provided context optionally controls which HTTP client is used. See the oauth2.HTTPClient variable.
func (c *Config) Token(ctx context.Context) (*oauth2.Token, error) { func (c *Config) Token(ctx context.Context) (*oauth2.Token, error) {
return c.TokenSource(ctx).Token() return c.TokenSource(ctx).Token()
} }
@ -64,18 +64,18 @@ func (c *Config) Token(ctx context.Context) (*oauth2.Token, error) {
// The token will auto-refresh as necessary. // The token will auto-refresh as necessary.
// //
// The provided context optionally controls which HTTP client // The provided context optionally controls which HTTP client
// is returned. See the [oauth2.HTTPClient] variable. // is returned. See the oauth2.HTTPClient variable.
// //
// The returned [http.Client] and its Transport should not be modified. // The returned Client and its Transport should not be modified.
func (c *Config) Client(ctx context.Context) *http.Client { func (c *Config) Client(ctx context.Context) *http.Client {
return oauth2.NewClient(ctx, c.TokenSource(ctx)) return oauth2.NewClient(ctx, c.TokenSource(ctx))
} }
// TokenSource returns a [oauth2.TokenSource] that returns t until t expires, // TokenSource returns a TokenSource that returns t until t expires,
// automatically refreshing it as necessary using the provided context and the // automatically refreshing it as necessary using the provided context and the
// client ID and client secret. // client ID and client secret.
// //
// Most users will use [Config.Client] instead. // Most users will use Config.Client instead.
func (c *Config) TokenSource(ctx context.Context) oauth2.TokenSource { func (c *Config) TokenSource(ctx context.Context) oauth2.TokenSource {
source := &tokenSource{ source := &tokenSource{
ctx: ctx, ctx: ctx,

View File

@ -7,6 +7,7 @@ package clientcredentials
import ( import (
"context" "context"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@ -35,9 +36,9 @@ func TestTokenSourceGrantTypeOverride(t *testing.T) {
wantGrantType := "password" wantGrantType := "password"
var gotGrantType string var gotGrantType string
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
t.Errorf("io.ReadAll(r.Body) == %v, %v, want _, <nil>", body, err) t.Errorf("ioutil.ReadAll(r.Body) == %v, %v, want _, <nil>", body, err)
} }
if err := r.Body.Close(); err != nil { if err := r.Body.Close(); err != nil {
t.Errorf("r.Body.Close() == %v, want <nil>", err) t.Errorf("r.Body.Close() == %v, want <nil>", err)
@ -80,7 +81,7 @@ func TestTokenRequest(t *testing.T) {
if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want { if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want {
t.Errorf("Content-Type header = %q; want %q", got, want) t.Errorf("Content-Type header = %q; want %q", got, want)
} }
body, err := io.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
r.Body.Close() r.Body.Close()
} }
@ -122,7 +123,7 @@ func TestTokenRefreshRequest(t *testing.T) {
if got, want := headerContentType, "application/x-www-form-urlencoded"; got != want { if got, want := headerContentType, "application/x-www-form-urlencoded"; got != want {
t.Errorf("Content-Type = %q; want %q", got, want) t.Errorf("Content-Type = %q; want %q", got, want)
} }
body, _ := io.ReadAll(r.Body) body, _ := ioutil.ReadAll(r.Body)
const want = "audience=audience1&grant_type=client_credentials&scope=scope1+scope2" const want = "audience=audience1&grant_type=client_credentials&scope=scope1+scope2"
if string(body) != want { if string(body) != want {
t.Errorf("Unexpected refresh token payload.\n got: %s\nwant: %s\n", body, want) t.Errorf("Unexpected refresh token payload.\n got: %s\nwant: %s\n", body, want)

View File

@ -7,6 +7,9 @@ import (
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
) )
func TestDeviceAuthResponseMarshalJson(t *testing.T) { func TestDeviceAuthResponseMarshalJson(t *testing.T) {
@ -71,16 +74,7 @@ func TestDeviceAuthResponseUnmarshalJson(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
margin := time.Second + time.Since(begin) if !cmp.Equal(got, tc.want, cmpopts.IgnoreUnexported(DeviceAuthResponse{}), cmpopts.EquateApproxTime(time.Second+time.Since(begin))) {
timeDiff := got.Expiry.Sub(tc.want.Expiry)
if timeDiff < 0 {
timeDiff *= -1
}
if timeDiff > margin {
t.Errorf("expiry time difference too large, got=%v, want=%v margin=%v", got.Expiry, tc.want.Expiry, margin)
}
got.Expiry, tc.want.Expiry = time.Time{}, time.Time{}
if got != tc.want {
t.Errorf("want=%#v, got=%#v", tc.want, got) t.Errorf("want=%#v, got=%#v", tc.want, got)
} }
}) })

View File

@ -6,7 +6,7 @@
package endpoints package endpoints
import ( import (
"net/url" "strings"
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
@ -17,30 +17,6 @@ var Amazon = oauth2.Endpoint{
TokenURL: "https://api.amazon.com/auth/o2/token", TokenURL: "https://api.amazon.com/auth/o2/token",
} }
// Apple is the endpoint for "Sign in with Apple".
//
// Documentation: https://developer.apple.com/documentation/signinwithapplerestapi
var Apple = oauth2.Endpoint{
AuthURL: "https://appleid.apple.com/auth/authorize",
TokenURL: "https://appleid.apple.com/auth/token",
}
// Asana is the endpoint for Asana.
//
// Documentation: https://developers.asana.com/docs/oauth
var Asana = oauth2.Endpoint{
AuthURL: "https://app.asana.com/-/oauth_authorize",
TokenURL: "https://app.asana.com/-/oauth_token",
}
// Badgr is the endpoint for Canvas Badges.
//
// Documentation: https://community.canvaslms.com/t5/Canvas-Badges-Credentials/Developers-Build-an-app-that-integrates-with-the-Canvas-Badges/ta-p/528727
var Badgr = oauth2.Endpoint{
AuthURL: "https://badgr.com/auth/oauth2/authorize",
TokenURL: "https://api.badgr.io/o/token",
}
// Battlenet is the endpoint for Battlenet. // Battlenet is the endpoint for Battlenet.
var Battlenet = oauth2.Endpoint{ var Battlenet = oauth2.Endpoint{
AuthURL: "https://battle.net/oauth/authorize", AuthURL: "https://battle.net/oauth/authorize",
@ -59,44 +35,16 @@ var Cern = oauth2.Endpoint{
TokenURL: "https://oauth.web.cern.ch/OAuth/Token", TokenURL: "https://oauth.web.cern.ch/OAuth/Token",
} }
// Coinbase is the endpoint for Coinbase.
//
// Documentation: https://docs.cdp.coinbase.com/coinbase-app/docs/coinbase-app-reference
var Coinbase = oauth2.Endpoint{
AuthURL: "https://login.coinbase.com/oauth2/auth",
TokenURL: "https://login.coinbase.com/oauth2/token",
}
// Discord is the endpoint for Discord. // Discord is the endpoint for Discord.
//
// Documentation: https://discord.com/developers/docs/topics/oauth2#shared-resources-oauth2-urls
var Discord = oauth2.Endpoint{ var Discord = oauth2.Endpoint{
AuthURL: "https://discord.com/oauth2/authorize", AuthURL: "https://discord.com/oauth2/authorize",
TokenURL: "https://discord.com/api/oauth2/token", TokenURL: "https://discord.com/api/oauth2/token",
} }
// Dropbox is the endpoint for Dropbox.
//
// Documentation: https://developers.dropbox.com/oauth-guide
var Dropbox = oauth2.Endpoint{
AuthURL: "https://www.dropbox.com/oauth2/authorize",
TokenURL: "https://api.dropboxapi.com/oauth2/token",
}
// Endpoint is Ebay's OAuth 2.0 endpoint.
//
// Documentation: https://developer.ebay.com/api-docs/static/authorization_guide_landing.html
var Endpoint = oauth2.Endpoint{
AuthURL: "https://auth.ebay.com/oauth2/authorize",
TokenURL: "https://api.ebay.com/identity/v1/oauth2/token",
}
// Facebook is the endpoint for Facebook. // Facebook is the endpoint for Facebook.
//
// Documentation: https://developers.facebook.com/docs/facebook-login/guides/advanced/manual-flow
var Facebook = oauth2.Endpoint{ var Facebook = oauth2.Endpoint{
AuthURL: "https://www.facebook.com/v22.0/dialog/oauth", AuthURL: "https://www.facebook.com/v3.2/dialog/oauth",
TokenURL: "https://graph.facebook.com/v22.0/oauth/access_token", TokenURL: "https://graph.facebook.com/v3.2/oauth/access_token",
} }
// Foursquare is the endpoint for Foursquare. // Foursquare is the endpoint for Foursquare.
@ -156,14 +104,6 @@ var KaKao = oauth2.Endpoint{
TokenURL: "https://kauth.kakao.com/oauth/token", TokenURL: "https://kauth.kakao.com/oauth/token",
} }
// Line is the endpoint for Line.
//
// Documentation: https://developers.line.biz/en/docs/line-login/integrate-line-login/
var Line = oauth2.Endpoint{
AuthURL: "https://access.line.me/oauth2/v2.1/authorize",
TokenURL: "https://api.line.me/oauth2/v2.1/token",
}
// LinkedIn is the endpoint for LinkedIn. // LinkedIn is the endpoint for LinkedIn.
var LinkedIn = oauth2.Endpoint{ var LinkedIn = oauth2.Endpoint{
AuthURL: "https://www.linkedin.com/oauth/v2/authorization", AuthURL: "https://www.linkedin.com/oauth/v2/authorization",
@ -200,17 +140,7 @@ var Microsoft = oauth2.Endpoint{
TokenURL: "https://login.live.com/oauth20_token.srf", TokenURL: "https://login.live.com/oauth20_token.srf",
} }
// Naver is the endpoint for Naver.
//
// Documentation: https://developers.naver.com/docs/login/devguide/devguide.md
var Naver = oauth2.Endpoint{
AuthURL: "https://nid.naver.com/oauth2/authorize",
TokenURL: "https://nid.naver.com/oauth2/token",
}
// NokiaHealth is the endpoint for Nokia Health. // NokiaHealth is the endpoint for Nokia Health.
//
// Deprecated: Nokia Health is now Withings.
var NokiaHealth = oauth2.Endpoint{ var NokiaHealth = oauth2.Endpoint{
AuthURL: "https://account.health.nokia.com/oauth2_user/authorize2", AuthURL: "https://account.health.nokia.com/oauth2_user/authorize2",
TokenURL: "https://account.health.nokia.com/oauth2/token", TokenURL: "https://account.health.nokia.com/oauth2/token",
@ -222,14 +152,6 @@ var Odnoklassniki = oauth2.Endpoint{
TokenURL: "https://api.odnoklassniki.ru/oauth/token.do", TokenURL: "https://api.odnoklassniki.ru/oauth/token.do",
} }
// OpenStreetMap is the endpoint for OpenStreetMap.org.
//
// Documentation: https://wiki.openstreetmap.org/wiki/OAuth
var OpenStreetMap = oauth2.Endpoint{
AuthURL: "https://www.openstreetmap.org/oauth2/authorize",
TokenURL: "https://www.openstreetmap.org/oauth2/token",
}
// Patreon is the endpoint for Patreon. // Patreon is the endpoint for Patreon.
var Patreon = oauth2.Endpoint{ var Patreon = oauth2.Endpoint{
AuthURL: "https://www.patreon.com/oauth2/authorize", AuthURL: "https://www.patreon.com/oauth2/authorize",
@ -248,52 +170,10 @@ var PayPalSandbox = oauth2.Endpoint{
TokenURL: "https://api.sandbox.paypal.com/v1/identity/openidconnect/tokenservice", TokenURL: "https://api.sandbox.paypal.com/v1/identity/openidconnect/tokenservice",
} }
// Pinterest is the endpoint for Pinterest.
//
// Documentation: https://developers.pinterest.com/docs/getting-started/set-up-authentication-and-authorization/
var Pinterest = oauth2.Endpoint{
AuthURL: "https://www.pinterest.com/oauth",
TokenURL: "https://api.pinterest.com/v5/oauth/token",
}
// Pipedrive is the endpoint for Pipedrive.
//
// Documentation: https://developers.pipedrive.com/docs/api/v1/Oauth
var Pipedrive = oauth2.Endpoint{
AuthURL: "https://oauth.pipedrive.com/oauth/authorize",
TokenURL: "https://oauth.pipedrive.com/oauth/token",
}
// QQ is the endpoint for QQ.
//
// Documentation: https://wiki.connect.qq.com/%e5%bc%80%e5%8f%91%e6%94%bb%e7%95%a5_server-side
var QQ = oauth2.Endpoint{
AuthURL: "https://graph.qq.com/oauth2.0/authorize",
TokenURL: "https://graph.qq.com/oauth2.0/token",
}
// Rakuten is the endpoint for Rakuten.
//
// Documentation: https://webservice.rakuten.co.jp/documentation
var Rakuten = oauth2.Endpoint{
AuthURL: "https://app.rakuten.co.jp/services/authorize",
TokenURL: "https://app.rakuten.co.jp/services/token",
}
// Slack is the endpoint for Slack. // Slack is the endpoint for Slack.
//
// Documentation: https://api.slack.com/authentication/oauth-v2
var Slack = oauth2.Endpoint{ var Slack = oauth2.Endpoint{
AuthURL: "https://slack.com/oauth/v2/authorize", AuthURL: "https://slack.com/oauth/authorize",
TokenURL: "https://slack.com/api/oauth.v2.access", TokenURL: "https://slack.com/api/oauth.access",
}
// Splitwise is the endpoint for Splitwise.
//
// Documentation: https://dev.splitwise.com/
var Splitwise = oauth2.Endpoint{
AuthURL: "https://www.splitwise.com/oauth/authorize",
TokenURL: "https://www.splitwise.com/oauth/token",
} }
// Spotify is the endpoint for Spotify. // Spotify is the endpoint for Spotify.
@ -332,22 +212,6 @@ var Vk = oauth2.Endpoint{
TokenURL: "https://oauth.vk.com/access_token", TokenURL: "https://oauth.vk.com/access_token",
} }
// Withings is the endpoint for Withings.
//
// Documentation: https://account.withings.com/oauth2_user/authorize2
var Withings = oauth2.Endpoint{
AuthURL: "https://account.withings.com/oauth2_user/authorize2",
TokenURL: "https://account.withings.com/oauth2/token",
}
// X is the endpoint for X (Twitter).
//
// Documentation: https://docs.x.com/resources/fundamentals/authentication/oauth-2-0/user-access-token
var X = oauth2.Endpoint{
AuthURL: "https://x.com/i/oauth2/authorize",
TokenURL: "https://api.x.com/2/oauth2/token",
}
// Yahoo is the endpoint for Yahoo. // Yahoo is the endpoint for Yahoo.
var Yahoo = oauth2.Endpoint{ var Yahoo = oauth2.Endpoint{
AuthURL: "https://api.login.yahoo.com/oauth2/request_auth", AuthURL: "https://api.login.yahoo.com/oauth2/request_auth",
@ -366,20 +230,6 @@ var Zoom = oauth2.Endpoint{
TokenURL: "https://zoom.us/oauth/token", TokenURL: "https://zoom.us/oauth/token",
} }
// Asgardeo returns a new oauth2.Endpoint for the given tenant.
//
// Documentation: https://wso2.com/asgardeo/docs/guides/authentication/oidc/discover-oidc-configs/
func AsgardeoEndpoint(tenant string) oauth2.Endpoint {
u := url.URL{
Scheme: "https",
Host: "api.asgardeo.io",
}
return oauth2.Endpoint{
AuthURL: u.JoinPath("t", tenant, "/oauth2/authorize").String(),
TokenURL: u.JoinPath("t", tenant, "/oauth2/token").String(),
}
}
// AzureAD returns a new oauth2.Endpoint for the given tenant at Azure Active Directory. // AzureAD returns a new oauth2.Endpoint for the given tenant at Azure Active Directory.
// If tenant is empty, it uses the tenant called `common`. // If tenant is empty, it uses the tenant called `common`.
// //
@ -389,29 +239,19 @@ func AzureAD(tenant string) oauth2.Endpoint {
if tenant == "" { if tenant == "" {
tenant = "common" tenant = "common"
} }
u := url.URL{
Scheme: "https",
Host: "login.microsoftonline.com",
}
return oauth2.Endpoint{ return oauth2.Endpoint{
AuthURL: u.JoinPath(tenant, "/oauth2/v2.0/authorize").String(), AuthURL: "https://login.microsoftonline.com/" + tenant + "/oauth2/v2.0/authorize",
TokenURL: u.JoinPath(tenant, "/oauth2/v2.0/token").String(), TokenURL: "https://login.microsoftonline.com/" + tenant + "/oauth2/v2.0/token",
DeviceAuthURL: u.JoinPath(tenant, "/oauth2/v2.0/devicecode").String(), DeviceAuthURL: "https://login.microsoftonline.com/" + tenant + "/oauth2/v2.0/devicecode",
} }
} }
// AzureADB2CEndpoint returns a new oauth2.Endpoint for the given tenant and policy at Azure Active Directory B2C. // HipChatServer returns a new oauth2.Endpoint for a HipChat Server instance
// policy is the Azure B2C User flow name Example: `B2C_1_SignUpSignIn`. // running on the given domain or host.
// func HipChatServer(host string) oauth2.Endpoint {
// Documentation: https://docs.microsoft.com/en-us/azure/active-directory-b2c/tokens-overview#endpoints
func AzureADB2CEndpoint(tenant string, policy string) oauth2.Endpoint {
u := url.URL{
Scheme: "https",
Host: tenant + ".b2clogin.com",
}
return oauth2.Endpoint{ return oauth2.Endpoint{
AuthURL: u.JoinPath(tenant+".onmicrosoft.com", policy, "/oauth2/v2.0/authorize").String(), AuthURL: "https://" + host + "/users/authorize",
TokenURL: u.JoinPath(tenant+".onmicrosoft.com", policy, "/oauth2/v2.0/token").String(), TokenURL: "https://" + host + "/v2/oauth/token",
} }
} }
@ -424,42 +264,9 @@ func AzureADB2CEndpoint(tenant string, policy string) oauth2.Endpoint {
// https://docs.aws.amazon.com/cognito/latest/developerguide/cognito-user-pools-assign-domain.html // https://docs.aws.amazon.com/cognito/latest/developerguide/cognito-user-pools-assign-domain.html
// https://docs.aws.amazon.com/cognito/latest/developerguide/cognito-userpools-server-contract-reference.html // https://docs.aws.amazon.com/cognito/latest/developerguide/cognito-userpools-server-contract-reference.html
func AWSCognito(domain string) oauth2.Endpoint { func AWSCognito(domain string) oauth2.Endpoint {
u, err := url.Parse(domain) domain = strings.TrimRight(domain, "/")
if err != nil || u.Scheme == "" || u.Host == "" {
panic("endpoints: invalid domain" + domain)
}
return oauth2.Endpoint{ return oauth2.Endpoint{
AuthURL: u.JoinPath("/oauth2/authorize").String(), AuthURL: domain + "/oauth2/authorize",
TokenURL: u.JoinPath("/oauth2/token").String(), TokenURL: domain + "/oauth2/token",
}
}
// HipChatServer returns a new oauth2.Endpoint for a HipChat Server instance.
// host should be a hostname, without any scheme prefix.
//
// Documentation: https://developer.atlassian.com/server/hipchat/hipchat-rest-api-access-tokens/
func HipChatServer(host string) oauth2.Endpoint {
u := url.URL{
Scheme: "https",
Host: host,
}
return oauth2.Endpoint{
AuthURL: u.JoinPath("/users/authorize").String(),
TokenURL: u.JoinPath("/v2/oauth/token").String(),
}
}
// Shopify returns a new oauth2.Endpoint for the supplied shop domain name.
// host should be a hostname, without any scheme prefix.
//
// Documentation: https://shopify.dev/docs/apps/auth/oauth
func Shopify(host string) oauth2.Endpoint {
u := url.URL{
Scheme: "https",
Host: host,
}
return oauth2.Endpoint{
AuthURL: u.JoinPath("/admin/oauth/authorize").String(),
TokenURL: u.JoinPath("/admin/oauth/access_token").String(),
} }
} }

5
go.mod
View File

@ -2,4 +2,7 @@ module golang.org/x/oauth2
go 1.23.0 go 1.23.0
require cloud.google.com/go/compute/metadata v0.3.0 require (
cloud.google.com/go/compute/metadata v0.3.0
github.com/google/go-cmp v0.5.9
)

2
go.sum
View File

@ -1,2 +1,4 @@
cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=

View File

@ -39,7 +39,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@ -198,7 +198,7 @@ func (dts downscopingTokenSource) Token() (*oauth2.Token, error) {
return nil, fmt.Errorf("unable to generate POST Request %v", err) return nil, fmt.Errorf("unable to generate POST Request %v", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body) respBody, err := ioutil.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, fmt.Errorf("downscope: unable to read response body: %v", err) return nil, fmt.Errorf("downscope: unable to read response body: %v", err)
} }

View File

@ -6,7 +6,7 @@ package downscope
import ( import (
"context" "context"
"io" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -27,7 +27,7 @@ func Test_DownscopedTokenSource(t *testing.T) {
if r.URL.String() != "/" { if r.URL.String() != "/" {
t.Errorf("Unexpected request URL, %v is found", r.URL) t.Errorf("Unexpected request URL, %v is found", r.URL)
} }
body, err := io.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
t.Fatalf("Failed to read request body: %v", err) t.Fatalf("Failed to read request body: %v", err)
} }

View File

@ -7,9 +7,9 @@ package google_test
import ( import (
"context" "context"
"fmt" "fmt"
"io/ioutil"
"log" "log"
"net/http" "net/http"
"os"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"golang.org/x/oauth2/google" "golang.org/x/oauth2/google"
@ -60,7 +60,7 @@ func ExampleJWTConfigFromJSON() {
// To create a service account client, click "Create new Client ID", // To create a service account client, click "Create new Client ID",
// select "Service Account", and click "Create Client ID". A JSON // select "Service Account", and click "Create Client ID". A JSON
// key file will then be downloaded to your computer. // key file will then be downloaded to your computer.
data, err := os.ReadFile("/path/to/your-project-key.json") data, err := ioutil.ReadFile("/path/to/your-project-key.json")
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
@ -136,7 +136,7 @@ func ExampleComputeTokenSource() {
func ExampleCredentialsFromJSON() { func ExampleCredentialsFromJSON() {
ctx := context.Background() ctx := context.Background()
data, err := os.ReadFile("/path/to/key-file.json") data, err := ioutil.ReadFile("/path/to/key-file.json")
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }

View File

@ -14,6 +14,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
@ -169,7 +170,7 @@ func requestDataHash(req *http.Request) (string, error) {
} }
defer requestBody.Close() defer requestBody.Close()
requestData, err = io.ReadAll(io.LimitReader(requestBody, 1<<20)) requestData, err = ioutil.ReadAll(io.LimitReader(requestBody, 1<<20))
if err != nil { if err != nil {
return "", err return "", err
} }
@ -418,7 +419,7 @@ func (cs *awsCredentialSource) getAWSSessionToken() (string, error) {
} }
defer resp.Body.Close() defer resp.Body.Close()
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil { if err != nil {
return "", err return "", err
} }
@ -461,7 +462,7 @@ func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, err
} }
defer resp.Body.Close() defer resp.Body.Close()
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil { if err != nil {
return "", err return "", err
} }
@ -530,7 +531,7 @@ func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string, h
} }
defer resp.Body.Close() defer resp.Body.Close()
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil { if err != nil {
return result, err return result, err
} }
@ -563,7 +564,7 @@ func (cs *awsCredentialSource) getMetadataRoleName(headers map[string]string) (s
} }
defer resp.Body.Close() defer resp.Body.Close()
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@ -486,11 +486,11 @@ func (ts tokenSource) Token() (*oauth2.Token, error) {
ClientID: conf.ClientID, ClientID: conf.ClientID,
ClientSecret: conf.ClientSecret, ClientSecret: conf.ClientSecret,
} }
var options map[string]any var options map[string]interface{}
// Do not pass workforce_pool_user_project when client authentication is used. // Do not pass workforce_pool_user_project when client authentication is used.
// The client ID is sufficient for determining the user project. // The client ID is sufficient for determining the user project.
if conf.WorkforcePoolUserProject != "" && conf.ClientID == "" { if conf.WorkforcePoolUserProject != "" && conf.ClientID == "" {
options = map[string]any{ options = map[string]interface{}{
"userProject": conf.WorkforcePoolUserProject, "userProject": conf.WorkforcePoolUserProject,
} }
} }

View File

@ -8,7 +8,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -77,7 +77,7 @@ func run(t *testing.T, config *Config, tets *testExchangeTokenServer) (*oauth2.T
if got, want := headerMetrics, tets.metricsHeader; got != want { if got, want := headerMetrics, tets.metricsHeader; got != want {
t.Errorf("got %v but want %v", got, want) t.Errorf("got %v but want %v", got, want)
} }
body, err := io.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
t.Fatalf("Failed reading request body: %s.", err) t.Fatalf("Failed reading request body: %s.", err)
} }
@ -131,7 +131,7 @@ func createImpersonationServer(urlWanted, authWanted, bodyWanted, response strin
if got, want := headerContentType, "application/json"; got != want { if got, want := headerContentType, "application/json"; got != want {
t.Errorf("got %v but want %v", got, want) t.Errorf("got %v but want %v", got, want)
} }
body, err := io.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
t.Fatalf("Failed reading request body: %v.", err) t.Fatalf("Failed reading request body: %v.", err)
} }
@ -160,7 +160,7 @@ func createTargetServer(metricsHeaderWanted string, t *testing.T) *httptest.Serv
if got, want := headerMetrics, metricsHeaderWanted; got != want { if got, want := headerMetrics, metricsHeaderWanted; got != want {
t.Errorf("got %v but want %v", got, want) t.Errorf("got %v but want %v", got, want)
} }
body, err := io.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
t.Fatalf("Failed reading request body: %v.", err) t.Fatalf("Failed reading request body: %v.", err)
} }

View File

@ -11,6 +11,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"os" "os"
"os/exec" "os/exec"
"regexp" "regexp"
@ -257,7 +258,7 @@ func (cs executableCredentialSource) getTokenFromOutputFile() (token string, err
} }
defer file.Close() defer file.Close()
data, err := io.ReadAll(io.LimitReader(file, 1<<20)) data, err := ioutil.ReadAll(io.LimitReader(file, 1<<20))
if err != nil || len(data) == 0 { if err != nil || len(data) == 0 {
// Cachefile exists, but no data found. Get new credential. // Cachefile exists, but no data found. Get new credential.
return "", nil return "", nil

View File

@ -8,10 +8,13 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil"
"os" "os"
"slices" "sort"
"testing" "testing"
"time" "time"
"github.com/google/go-cmp/cmp"
) )
type testEnvironment struct { type testEnvironment struct {
@ -251,12 +254,14 @@ func TestExecutableCredentialGetEnvironment(t *testing.T) {
ecs.env = &tt.environment ecs.env = &tt.environment
got := ecs.executableEnvironment() // This Transformer sorts a []string.
slices.Sort(got) sorter := cmp.Transformer("Sort", func(in []string) []string {
want := tt.expectedEnvironment out := append([]string(nil), in...) // Copy input to avoid mutating it
slices.Sort(want) sort.Strings(out)
return out
})
if !slices.Equal(got, want) { if got, want := ecs.executableEnvironment(), tt.expectedEnvironment; !cmp.Equal(got, want, sorter) {
t.Errorf("Incorrect environment received.\nReceived: %s\nExpected: %s", got, want) t.Errorf("Incorrect environment received.\nReceived: %s\nExpected: %s", got, want)
} }
}) })
@ -609,7 +614,7 @@ func TestRetrieveExecutableSubjectTokenSuccesses(t *testing.T) {
} }
func TestRetrieveOutputFileSubjectTokenNotJSON(t *testing.T) { func TestRetrieveOutputFileSubjectTokenNotJSON(t *testing.T) {
outputFile, err := os.CreateTemp("testdata", "result.*.json") outputFile, err := ioutil.TempFile("testdata", "result.*.json")
if err != nil { if err != nil {
t.Fatalf("Tempfile failed: %v", err) t.Fatalf("Tempfile failed: %v", err)
} }
@ -758,7 +763,7 @@ var cacheFailureTests = []struct {
func TestRetrieveOutputFileSubjectTokenFailureTests(t *testing.T) { func TestRetrieveOutputFileSubjectTokenFailureTests(t *testing.T) {
for _, tt := range cacheFailureTests { for _, tt := range cacheFailureTests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
outputFile, err := os.CreateTemp("testdata", "result.*.json") outputFile, err := ioutil.TempFile("testdata", "result.*.json")
if err != nil { if err != nil {
t.Fatalf("Tempfile failed: %v", err) t.Fatalf("Tempfile failed: %v", err)
} }
@ -861,7 +866,7 @@ var invalidCacheTests = []struct {
func TestRetrieveOutputFileSubjectTokenInvalidCache(t *testing.T) { func TestRetrieveOutputFileSubjectTokenInvalidCache(t *testing.T) {
for _, tt := range invalidCacheTests { for _, tt := range invalidCacheTests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
outputFile, err := os.CreateTemp("testdata", "result.*.json") outputFile, err := ioutil.TempFile("testdata", "result.*.json")
if err != nil { if err != nil {
t.Fatalf("Tempfile failed: %v", err) t.Fatalf("Tempfile failed: %v", err)
} }
@ -965,7 +970,8 @@ var cacheSuccessTests = []struct {
func TestRetrieveOutputFileSubjectTokenJwt(t *testing.T) { func TestRetrieveOutputFileSubjectTokenJwt(t *testing.T) {
for _, tt := range cacheSuccessTests { for _, tt := range cacheSuccessTests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
outputFile, err := os.CreateTemp("testdata", "result.*.json")
outputFile, err := ioutil.TempFile("testdata", "result.*.json")
if err != nil { if err != nil {
t.Fatalf("Tempfile failed: %v", err) t.Fatalf("Tempfile failed: %v", err)
} }

View File

@ -10,6 +10,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"os" "os"
) )
@ -28,14 +29,14 @@ func (cs fileCredentialSource) subjectToken() (string, error) {
return "", fmt.Errorf("oauth2/google/externalaccount: failed to open credential file %q", cs.File) return "", fmt.Errorf("oauth2/google/externalaccount: failed to open credential file %q", cs.File)
} }
defer tokenFile.Close() defer tokenFile.Close()
tokenBytes, err := io.ReadAll(io.LimitReader(tokenFile, 1<<20)) tokenBytes, err := ioutil.ReadAll(io.LimitReader(tokenFile, 1<<20))
if err != nil { if err != nil {
return "", fmt.Errorf("oauth2/google/externalaccount: failed to read credential file: %v", err) return "", fmt.Errorf("oauth2/google/externalaccount: failed to read credential file: %v", err)
} }
tokenBytes = bytes.TrimSpace(tokenBytes) tokenBytes = bytes.TrimSpace(tokenBytes)
switch cs.Format.Type { switch cs.Format.Type {
case "json": case "json":
jsonData := make(map[string]any) jsonData := make(map[string]interface{})
err = json.Unmarshal(tokenBytes, &jsonData) err = json.Unmarshal(tokenBytes, &jsonData)
if err != nil { if err != nil {
return "", fmt.Errorf("oauth2/google/externalaccount: failed to unmarshal subject token file: %v", err) return "", fmt.Errorf("oauth2/google/externalaccount: failed to unmarshal subject token file: %v", err)

View File

@ -7,6 +7,8 @@ package externalaccount
import ( import (
"runtime" "runtime"
"testing" "testing"
"github.com/google/go-cmp/cmp"
) )
func TestGoVersion(t *testing.T) { func TestGoVersion(t *testing.T) {
@ -38,8 +40,8 @@ func TestGoVersion(t *testing.T) {
} { } {
version = tst.v version = tst.v
got := goVersion() got := goVersion()
if got != tst.want { if diff := cmp.Diff(got, tst.want); diff != "" {
t.Errorf("go version = %q, want = %q", got, tst.want) t.Errorf("got(-),want(+):\n%s", diff)
} }
} }
version = runtime.Version version = runtime.Version

View File

@ -10,6 +10,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@ -43,7 +44,7 @@ func (cs urlCredentialSource) subjectToken() (string, error) {
} }
defer resp.Body.Close() defer resp.Body.Close()
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil { if err != nil {
return "", fmt.Errorf("oauth2/google/externalaccount: invalid body in subject token URL query: %v", err) return "", fmt.Errorf("oauth2/google/externalaccount: invalid body in subject token URL query: %v", err)
} }
@ -53,7 +54,7 @@ func (cs urlCredentialSource) subjectToken() (string, error) {
switch cs.Format.Type { switch cs.Format.Type {
case "json": case "json":
jsonData := make(map[string]any) jsonData := make(map[string]interface{})
err = json.Unmarshal(respBody, &jsonData) err = json.Unmarshal(respBody, &jsonData)
if err != nil { if err != nil {
return "", fmt.Errorf("oauth2/google/externalaccount: failed to unmarshal subject token file: %v", err) return "", fmt.Errorf("oauth2/google/externalaccount: failed to unmarshal subject token file: %v", err)

View File

@ -285,23 +285,27 @@ func (cs computeSource) Token() (*oauth2.Token, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
var res oauth2.Token var res struct {
AccessToken string `json:"access_token"`
ExpiresInSec int `json:"expires_in"`
TokenType string `json:"token_type"`
}
err = json.NewDecoder(strings.NewReader(tokenJSON)).Decode(&res) err = json.NewDecoder(strings.NewReader(tokenJSON)).Decode(&res)
if err != nil { if err != nil {
return nil, fmt.Errorf("oauth2/google: invalid token JSON from metadata: %v", err) return nil, fmt.Errorf("oauth2/google: invalid token JSON from metadata: %v", err)
} }
if res.ExpiresIn == 0 || res.AccessToken == "" { if res.ExpiresInSec == 0 || res.AccessToken == "" {
return nil, fmt.Errorf("oauth2/google: incomplete token received from metadata") return nil, fmt.Errorf("oauth2/google: incomplete token received from metadata")
} }
tok := &oauth2.Token{ tok := &oauth2.Token{
AccessToken: res.AccessToken, AccessToken: res.AccessToken,
TokenType: res.TokenType, TokenType: res.TokenType,
Expiry: time.Now().Add(time.Duration(res.ExpiresIn) * time.Second), Expiry: time.Now().Add(time.Duration(res.ExpiresInSec) * time.Second),
} }
// NOTE(cbro): add hidden metadata about where the token is from. // NOTE(cbro): add hidden metadata about where the token is from.
// This is needed for detection by client libraries to know that credentials come from the metadata server. // This is needed for detection by client libraries to know that credentials come from the metadata server.
// This may be removed in a future version of this library. // This may be removed in a future version of this library.
return tok.WithExtra(map[string]any{ return tok.WithExtra(map[string]interface{}{
"oauth2.google.tokenSource": "compute-metadata", "oauth2.google.tokenSource": "compute-metadata",
"oauth2.google.serviceAccount": acct, "oauth2.google.serviceAccount": acct,
}), nil }), nil

View File

@ -8,7 +8,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"io" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -227,7 +227,7 @@ func (trts *testRefreshTokenServer) run(t *testing.T) (string, error) {
if got, want := headerContentType, trts.ContentType; got != want { if got, want := headerContentType, trts.ContentType; got != want {
t.Errorf("got %v but want %v", got, want) t.Errorf("got %v but want %v", got, want)
} }
body, err := io.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
t.Fatalf("Failed reading request body: %s.", err) t.Fatalf("Failed reading request body: %s.", err)
} }

View File

@ -10,6 +10,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"time" "time"
@ -80,7 +81,7 @@ func (its ImpersonateTokenSource) Token() (*oauth2.Token, error) {
return nil, fmt.Errorf("oauth2/google: unable to generate access token: %v", err) return nil, fmt.Errorf("oauth2/google: unable to generate access token: %v", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil { if err != nil {
return nil, fmt.Errorf("oauth2/google: unable to read body: %v", err) return nil, fmt.Errorf("oauth2/google: unable to read body: %v", err)
} }

View File

@ -9,6 +9,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"strconv" "strconv"
@ -27,7 +28,7 @@ func defaultHeader() http.Header {
// The first 4 fields are all mandatory. headers can be used to pass additional // The first 4 fields are all mandatory. headers can be used to pass additional
// headers beyond the bare minimum required by the token exchange. options can // headers beyond the bare minimum required by the token exchange. options can
// be used to pass additional JSON-structured options to the remote server. // be used to pass additional JSON-structured options to the remote server.
func ExchangeToken(ctx context.Context, endpoint string, request *TokenExchangeRequest, authentication ClientAuthentication, headers http.Header, options map[string]any) (*Response, error) { func ExchangeToken(ctx context.Context, endpoint string, request *TokenExchangeRequest, authentication ClientAuthentication, headers http.Header, options map[string]interface{}) (*Response, error) {
data := url.Values{} data := url.Values{}
data.Set("audience", request.Audience) data.Set("audience", request.Audience)
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange") data.Set("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange")
@ -81,7 +82,7 @@ func makeRequest(ctx context.Context, endpoint string, data url.Values, authenti
} }
defer resp.Body.Close() defer resp.Body.Close()
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -7,7 +7,7 @@ package stsexchange
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"io" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@ -73,7 +73,7 @@ func TestExchangeToken(t *testing.T) {
if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want { if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want {
t.Errorf("Unexpected Content-Type header, got %v, want %v", got, want) t.Errorf("Unexpected Content-Type header, got %v, want %v", got, want)
} }
body, err := io.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
t.Errorf("Failed reading request body: %v.", err) t.Errorf("Failed reading request body: %v.", err)
} }
@ -132,7 +132,7 @@ var optsValues = [][]string{{"foo", "bar"}, {"cat", "pan"}}
func TestExchangeToken_Opts(t *testing.T) { func TestExchangeToken_Opts(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
t.Fatalf("Failed reading request body: %v.", err) t.Fatalf("Failed reading request body: %v.", err)
} }
@ -146,7 +146,7 @@ func TestExchangeToken_Opts(t *testing.T) {
} else if len(strOpts) < 1 { } else if len(strOpts) < 1 {
t.Errorf("\"options\" field has length 0.") t.Errorf("\"options\" field has length 0.")
} }
var opts map[string]any var opts map[string]interface{}
err = json.Unmarshal([]byte(strOpts[0]), &opts) err = json.Unmarshal([]byte(strOpts[0]), &opts)
if err != nil { if err != nil {
t.Fatalf("Couldn't parse received \"options\" field.") t.Fatalf("Couldn't parse received \"options\" field.")
@ -159,7 +159,7 @@ func TestExchangeToken_Opts(t *testing.T) {
if !ok { if !ok {
t.Errorf("Couldn't find first option parameter.") t.Errorf("Couldn't find first option parameter.")
} else { } else {
tOpts1, ok := val.(map[string]any) tOpts1, ok := val.(map[string]interface{})
if !ok { if !ok {
t.Errorf("Failed to assert the first option parameter as type testOpts.") t.Errorf("Failed to assert the first option parameter as type testOpts.")
} else { } else {
@ -176,7 +176,7 @@ func TestExchangeToken_Opts(t *testing.T) {
if !ok { if !ok {
t.Errorf("Couldn't find second option parameter.") t.Errorf("Couldn't find second option parameter.")
} else { } else {
tOpts2, ok := val2.(map[string]any) tOpts2, ok := val2.(map[string]interface{})
if !ok { if !ok {
t.Errorf("Failed to assert the second option parameter as type testOpts.") t.Errorf("Failed to assert the second option parameter as type testOpts.")
} else { } else {
@ -200,7 +200,7 @@ func TestExchangeToken_Opts(t *testing.T) {
firstOption := testOpts{optsValues[0][0], optsValues[0][1]} firstOption := testOpts{optsValues[0][0], optsValues[0][1]}
secondOption := testOpts{optsValues[1][0], optsValues[1][1]} secondOption := testOpts{optsValues[1][0], optsValues[1][1]}
inputOpts := make(map[string]any) inputOpts := make(map[string]interface{})
inputOpts["one"] = firstOption inputOpts["one"] = firstOption
inputOpts["two"] = secondOption inputOpts["two"] = secondOption
ExchangeToken(context.Background(), ts.URL, &exchangeTokenRequest, auth, headers, inputOpts) ExchangeToken(context.Background(), ts.URL, &exchangeTokenRequest, auth, headers, inputOpts)
@ -220,7 +220,7 @@ func TestRefreshToken(t *testing.T) {
if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want { if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want {
t.Errorf("Unexpected Content-Type header, got %v, want %v", got, want) t.Errorf("Unexpected Content-Type header, got %v, want %v", got, want)
} }
body, err := io.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
t.Errorf("Failed reading request body: %v.", err) t.Errorf("Failed reading request body: %v.", err)
} }

View File

@ -2,5 +2,5 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Package internal contains support packages for [golang.org/x/oauth2]. // Package internal contains support packages for oauth2 package.
package internal package internal

View File

@ -13,7 +13,7 @@ import (
) )
// ParseKey converts the binary contents of a private key file // ParseKey converts the binary contents of a private key file
// to an [*rsa.PrivateKey]. It detects whether the private key is in a // to an *rsa.PrivateKey. It detects whether the private key is in a
// PEM container or not. If so, it extracts the private key // PEM container or not. If so, it extracts the private key
// from PEM container before conversion. It only supports PEM // from PEM container before conversion. It only supports PEM
// containers with no passphrase. // containers with no passphrase.

View File

@ -10,6 +10,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"math" "math"
"mime" "mime"
"net/http" "net/http"
@ -25,9 +26,9 @@ import (
// the requests to access protected resources on the OAuth 2.0 // the requests to access protected resources on the OAuth 2.0
// provider's backend. // provider's backend.
// //
// This type is a mirror of [golang.org/x/oauth2.Token] and exists to break // This type is a mirror of oauth2.Token and exists to break
// an otherwise-circular dependency. Other internal packages // an otherwise-circular dependency. Other internal packages
// should convert this Token into an [golang.org/x/oauth2.Token] before use. // should convert this Token into an oauth2.Token before use.
type Token struct { type Token struct {
// AccessToken is the token that authorizes and authenticates // AccessToken is the token that authorizes and authenticates
// the requests. // the requests.
@ -49,16 +50,9 @@ type Token struct {
// mechanisms for that TokenSource will not be used. // mechanisms for that TokenSource will not be used.
Expiry time.Time Expiry time.Time
// ExpiresIn is the OAuth2 wire format "expires_in" field,
// which specifies how many seconds later the token expires,
// relative to an unknown time base approximately around "now".
// It is the application's responsibility to populate
// `Expiry` from `ExpiresIn` when required.
ExpiresIn int64 `json:"expires_in,omitempty"`
// Raw optionally contains extra metadata from the server // Raw optionally contains extra metadata from the server
// when updating a token. // when updating a token.
Raw any Raw interface{}
} }
// tokenJSON is the struct representing the HTTP response from OAuth2 // tokenJSON is the struct representing the HTTP response from OAuth2
@ -105,6 +99,14 @@ func (e *expirationTime) UnmarshalJSON(b []byte) error {
return nil return nil
} }
// RegisterBrokenAuthHeaderProvider previously did something. It is now a no-op.
//
// Deprecated: this function no longer does anything. Caller code that
// wants to avoid potential extra HTTP requests made during
// auto-probing of the provider's auth style should set
// Endpoint.AuthStyle.
func RegisterBrokenAuthHeaderProvider(tokenURL string) {}
// AuthStyle is a copy of the golang.org/x/oauth2 package's AuthStyle type. // AuthStyle is a copy of the golang.org/x/oauth2 package's AuthStyle type.
type AuthStyle int type AuthStyle int
@ -141,11 +143,6 @@ func (lc *LazyAuthStyleCache) Get() *AuthStyleCache {
return c return c
} }
type authStyleCacheKey struct {
url string
clientID string
}
// AuthStyleCache is the set of tokenURLs we've successfully used via // AuthStyleCache is the set of tokenURLs we've successfully used via
// RetrieveToken and which style auth we ended up using. // RetrieveToken and which style auth we ended up using.
// It's called a cache, but it doesn't (yet?) shrink. It's expected that // It's called a cache, but it doesn't (yet?) shrink. It's expected that
@ -153,26 +150,26 @@ type authStyleCacheKey struct {
// small. // small.
type AuthStyleCache struct { type AuthStyleCache struct {
mu sync.Mutex mu sync.Mutex
m map[authStyleCacheKey]AuthStyle m map[string]AuthStyle // keyed by tokenURL
} }
// lookupAuthStyle reports which auth style we last used with tokenURL // lookupAuthStyle reports which auth style we last used with tokenURL
// when calling RetrieveToken and whether we have ever done so. // when calling RetrieveToken and whether we have ever done so.
func (c *AuthStyleCache) lookupAuthStyle(tokenURL, clientID string) (style AuthStyle, ok bool) { func (c *AuthStyleCache) lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
style, ok = c.m[authStyleCacheKey{tokenURL, clientID}] style, ok = c.m[tokenURL]
return return
} }
// setAuthStyle adds an entry to authStyleCache, documented above. // setAuthStyle adds an entry to authStyleCache, documented above.
func (c *AuthStyleCache) setAuthStyle(tokenURL, clientID string, v AuthStyle) { func (c *AuthStyleCache) setAuthStyle(tokenURL string, v AuthStyle) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
if c.m == nil { if c.m == nil {
c.m = make(map[authStyleCacheKey]AuthStyle) c.m = make(map[string]AuthStyle)
} }
c.m[authStyleCacheKey{tokenURL, clientID}] = v c.m[tokenURL] = v
} }
// newTokenRequest returns a new *http.Request to retrieve a new token // newTokenRequest returns a new *http.Request to retrieve a new token
@ -213,9 +210,9 @@ func cloneURLValues(v url.Values) url.Values {
} }
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle, styleCache *AuthStyleCache) (*Token, error) { func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle, styleCache *AuthStyleCache) (*Token, error) {
needsAuthStyleProbe := authStyle == AuthStyleUnknown needsAuthStyleProbe := authStyle == 0
if needsAuthStyleProbe { if needsAuthStyleProbe {
if style, ok := styleCache.lookupAuthStyle(tokenURL, clientID); ok { if style, ok := styleCache.lookupAuthStyle(tokenURL); ok {
authStyle = style authStyle = style
needsAuthStyleProbe = false needsAuthStyleProbe = false
} else { } else {
@ -245,7 +242,7 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string,
token, err = doTokenRoundTrip(ctx, req) token, err = doTokenRoundTrip(ctx, req)
} }
if needsAuthStyleProbe && err == nil { if needsAuthStyleProbe && err == nil {
styleCache.setAuthStyle(tokenURL, clientID, authStyle) styleCache.setAuthStyle(tokenURL, authStyle)
} }
// Don't overwrite `RefreshToken` with an empty value // Don't overwrite `RefreshToken` with an empty value
// if this was a token refreshing request. // if this was a token refreshing request.
@ -260,7 +257,7 @@ func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20)) body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20))
r.Body.Close() r.Body.Close()
if err != nil { if err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
@ -315,8 +312,7 @@ func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) {
TokenType: tj.TokenType, TokenType: tj.TokenType,
RefreshToken: tj.RefreshToken, RefreshToken: tj.RefreshToken,
Expiry: tj.expiry(), Expiry: tj.expiry(),
ExpiresIn: int64(tj.ExpiresIn), Raw: make(map[string]interface{}),
Raw: make(map[string]any),
} }
json.Unmarshal(body, &token.Raw) // no error checks for optional fields json.Unmarshal(body, &token.Raw) // no error checks for optional fields
} }

View File

@ -75,48 +75,3 @@ func TestExpiresInUpperBound(t *testing.T) {
t.Errorf("expiration time = %v; want %v", e, want) t.Errorf("expiration time = %v; want %v", e, want)
} }
} }
func TestAuthStyleCache(t *testing.T) {
var c LazyAuthStyleCache
cases := []struct {
url string
clientID string
style AuthStyle
}{
{
"https://host1.example.com/token",
"client_1",
AuthStyleInHeader,
}, {
"https://host2.example.com/token",
"client_2",
AuthStyleInParams,
}, {
"https://host1.example.com/token",
"client_3",
AuthStyleInParams,
},
}
for _, tt := range cases {
t.Run(tt.clientID, func(t *testing.T) {
cc := c.Get()
got, ok := cc.lookupAuthStyle(tt.url, tt.clientID)
if ok {
t.Fatalf("unexpected auth style found on first request: %v", got)
}
cc.setAuthStyle(tt.url, tt.clientID, tt.style)
got, ok = cc.lookupAuthStyle(tt.url, tt.clientID)
if !ok {
t.Fatalf("auth style not found in cache")
}
if got != tt.style {
t.Fatalf("auth style mismatch, got=%v, want=%v", got, tt.style)
}
})
}
}

View File

@ -9,8 +9,8 @@ import (
"net/http" "net/http"
) )
// HTTPClient is the context key to use with [context.WithValue] // HTTPClient is the context key to use with golang.org/x/net/context's
// to associate an [*http.Client] value with a context. // WithValue function to associate an *http.Client value with a context.
var HTTPClient ContextKey var HTTPClient ContextKey
// ContextKey is just an empty struct. It exists so HTTPClient can be // ContextKey is just an empty struct. It exists so HTTPClient can be

View File

@ -13,6 +13,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@ -113,7 +114,7 @@ func (js jwtSource) Token() (*oauth2.Token, error) {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil { if err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
} }
@ -122,7 +123,11 @@ func (js jwtSource) Token() (*oauth2.Token, error) {
} }
// tokenRes is the JSON response body. // tokenRes is the JSON response body.
var tokenRes oauth2.Token var tokenRes struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"` // relative seconds from now
}
if err := json.Unmarshal(body, &tokenRes); err != nil { if err := json.Unmarshal(body, &tokenRes); err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
} }

View File

@ -4,7 +4,7 @@
// Package jws provides a partial implementation // Package jws provides a partial implementation
// of JSON Web Signature encoding and decoding. // of JSON Web Signature encoding and decoding.
// It exists to support the [golang.org/x/oauth2] package. // It exists to support the golang.org/x/oauth2 package.
// //
// See RFC 7515. // See RFC 7515.
// //
@ -48,7 +48,7 @@ type ClaimSet struct {
// See http://tools.ietf.org/html/draft-jones-json-web-token-10#section-4.3 // See http://tools.ietf.org/html/draft-jones-json-web-token-10#section-4.3
// This array is marshalled using custom code (see (c *ClaimSet) encode()). // This array is marshalled using custom code (see (c *ClaimSet) encode()).
PrivateClaims map[string]any `json:"-"` PrivateClaims map[string]interface{} `json:"-"`
} }
func (c *ClaimSet) encode() (string, error) { func (c *ClaimSet) encode() (string, error) {
@ -116,12 +116,12 @@ func (h *Header) encode() (string, error) {
// Decode decodes a claim set from a JWS payload. // Decode decodes a claim set from a JWS payload.
func Decode(payload string) (*ClaimSet, error) { func Decode(payload string) (*ClaimSet, error) {
// decode returned id token to get expiry // decode returned id token to get expiry
_, claims, _, ok := parseToken(payload) s := strings.Split(payload, ".")
if !ok { if len(s) < 2 {
// TODO(jbd): Provide more context about the error. // TODO(jbd): Provide more context about the error.
return nil, errors.New("jws: invalid token received") return nil, errors.New("jws: invalid token received")
} }
decoded, err := base64.RawURLEncoding.DecodeString(claims) decoded, err := base64.RawURLEncoding.DecodeString(s[1])
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -152,7 +152,7 @@ func EncodeWithSigner(header *Header, c *ClaimSet, sg Signer) (string, error) {
} }
// Encode encodes a signed JWS with provided header and claim set. // Encode encodes a signed JWS with provided header and claim set.
// This invokes [EncodeWithSigner] using [crypto/rsa.SignPKCS1v15] with the given RSA private key. // This invokes EncodeWithSigner using crypto/rsa.SignPKCS1v15 with the given RSA private key.
func Encode(header *Header, c *ClaimSet, key *rsa.PrivateKey) (string, error) { func Encode(header *Header, c *ClaimSet, key *rsa.PrivateKey) (string, error) {
sg := func(data []byte) (sig []byte, err error) { sg := func(data []byte) (sig []byte, err error) {
h := sha256.New() h := sha256.New()
@ -165,34 +165,18 @@ func Encode(header *Header, c *ClaimSet, key *rsa.PrivateKey) (string, error) {
// Verify tests whether the provided JWT token's signature was produced by the private key // Verify tests whether the provided JWT token's signature was produced by the private key
// associated with the supplied public key. // associated with the supplied public key.
func Verify(token string, key *rsa.PublicKey) error { func Verify(token string, key *rsa.PublicKey) error {
header, claims, sig, ok := parseToken(token) if strings.Count(token, ".") != 2 {
if !ok {
return errors.New("jws: invalid token received, token must have 3 parts") return errors.New("jws: invalid token received, token must have 3 parts")
} }
signatureString, err := base64.RawURLEncoding.DecodeString(sig)
parts := strings.SplitN(token, ".", 3)
signedContent := parts[0] + "." + parts[1]
signatureString, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil { if err != nil {
return err return err
} }
h := sha256.New() h := sha256.New()
h.Write([]byte(header + tokenDelim + claims)) h.Write([]byte(signedContent))
return rsa.VerifyPKCS1v15(key, crypto.SHA256, h.Sum(nil), signatureString) return rsa.VerifyPKCS1v15(key, crypto.SHA256, h.Sum(nil), signatureString)
} }
func parseToken(s string) (header, claims, sig string, ok bool) {
header, s, ok = strings.Cut(s, tokenDelim)
if !ok { // no period found
return "", "", "", false
}
claims, s, ok = strings.Cut(s, tokenDelim)
if !ok { // only one period found
return "", "", "", false
}
sig, _, ok = strings.Cut(s, tokenDelim)
if ok { // three periods found
return "", "", "", false
}
return header, claims, sig, true
}
const tokenDelim = "."

View File

@ -7,8 +7,6 @@ package jws
import ( import (
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"net/http"
"strings"
"testing" "testing"
) )
@ -41,57 +39,8 @@ func TestSignAndVerify(t *testing.T) {
} }
func TestVerifyFailsOnMalformedClaim(t *testing.T) { func TestVerifyFailsOnMalformedClaim(t *testing.T) {
cases := []struct { err := Verify("abc.def", nil)
desc string
token string
}{
{
desc: "no periods",
token: "aa",
}, {
desc: "only one period",
token: "a.a",
}, {
desc: "more than two periods",
token: "a.a.a.a",
},
}
for _, tc := range cases {
f := func(t *testing.T) {
err := Verify(tc.token, nil)
if err == nil { if err == nil {
t.Error("got no errors; want improperly formed JWT not to be verified") t.Error("got no errors; want improperly formed JWT not to be verified")
} }
}
t.Run(tc.desc, f)
}
}
func BenchmarkVerify(b *testing.B) {
cases := []struct {
desc string
token string
}{
{
desc: "full of periods",
token: strings.Repeat(".", http.DefaultMaxHeaderBytes),
}, {
desc: "two trailing periods",
token: strings.Repeat("a", http.DefaultMaxHeaderBytes-2) + "..",
},
}
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
b.Fatal(err)
}
for _, bc := range cases {
f := func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for range b.N {
Verify(bc.token, &privateKey.PublicKey)
}
}
b.Run(bc.desc, f)
}
} }

View File

@ -13,6 +13,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@ -68,7 +69,7 @@ type Config struct {
// PrivateClaims optionally specifies custom private claims in the JWT. // PrivateClaims optionally specifies custom private claims in the JWT.
// See http://tools.ietf.org/html/draft-jones-json-web-token-10#section-4.3 // See http://tools.ietf.org/html/draft-jones-json-web-token-10#section-4.3
PrivateClaims map[string]any PrivateClaims map[string]interface{}
// UseIDToken optionally specifies whether ID token should be used instead // UseIDToken optionally specifies whether ID token should be used instead
// of access token when the server returns both. // of access token when the server returns both.
@ -135,7 +136,7 @@ func (js jwtSource) Token() (*oauth2.Token, error) {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil { if err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
} }
@ -147,8 +148,10 @@ func (js jwtSource) Token() (*oauth2.Token, error) {
} }
// tokenRes is the JSON response body. // tokenRes is the JSON response body.
var tokenRes struct { var tokenRes struct {
oauth2.Token AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
IDToken string `json:"id_token"` IDToken string `json:"id_token"`
ExpiresIn int64 `json:"expires_in"` // relative seconds from now
} }
if err := json.Unmarshal(body, &tokenRes); err != nil { if err := json.Unmarshal(body, &tokenRes); err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err) return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
@ -157,7 +160,7 @@ func (js jwtSource) Token() (*oauth2.Token, error) {
AccessToken: tokenRes.AccessToken, AccessToken: tokenRes.AccessToken,
TokenType: tokenRes.TokenType, TokenType: tokenRes.TokenType,
} }
raw := make(map[string]any) raw := make(map[string]interface{})
json.Unmarshal(body, &raw) // no error checks for optional fields json.Unmarshal(body, &raw) // no error checks for optional fields
token = token.WithExtra(raw) token = token.WithExtra(raw)

View File

@ -227,7 +227,7 @@ func TestJWTFetch_AssertionPayload(t *testing.T) {
PrivateKey: dummyPrivateKey, PrivateKey: dummyPrivateKey,
PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ", PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
TokenURL: ts.URL, TokenURL: ts.URL,
PrivateClaims: map[string]any{ PrivateClaims: map[string]interface{}{
"private0": "claim0", "private0": "claim0",
"private1": "claim1", "private1": "claim1",
}, },
@ -273,11 +273,11 @@ func TestJWTFetch_AssertionPayload(t *testing.T) {
t.Errorf("payload prn = %q; want %q", got, want) t.Errorf("payload prn = %q; want %q", got, want)
} }
if len(conf.PrivateClaims) > 0 { if len(conf.PrivateClaims) > 0 {
var got any var got interface{}
if err := json.Unmarshal(gotjson, &got); err != nil { if err := json.Unmarshal(gotjson, &got); err != nil {
t.Errorf("failed to parse payload; err = %q", err) t.Errorf("failed to parse payload; err = %q", err)
} }
m := got.(map[string]any) m := got.(map[string]interface{})
for v, k := range conf.PrivateClaims { for v, k := range conf.PrivateClaims {
if !reflect.DeepEqual(m[v], k) { if !reflect.DeepEqual(m[v], k) {
t.Errorf("payload private claims key = %q: got %#v; want %#v", v, m[v], k) t.Errorf("payload private claims key = %q: got %#v; want %#v", v, m[v], k)

View File

@ -22,9 +22,9 @@ import (
) )
// NoContext is the default context you should supply if not using // NoContext is the default context you should supply if not using
// your own [context.Context]. // your own context.Context (see https://golang.org/x/net/context).
// //
// Deprecated: Use [context.Background] or [context.TODO] instead. // Deprecated: Use context.Background() or context.TODO() instead.
var NoContext = context.TODO() var NoContext = context.TODO()
// RegisterBrokenAuthHeaderProvider previously did something. It is now a no-op. // RegisterBrokenAuthHeaderProvider previously did something. It is now a no-op.
@ -37,8 +37,8 @@ func RegisterBrokenAuthHeaderProvider(tokenURL string) {}
// Config describes a typical 3-legged OAuth2 flow, with both the // Config describes a typical 3-legged OAuth2 flow, with both the
// client application information and the server's endpoint URLs. // client application information and the server's endpoint URLs.
// For the client credentials 2-legged OAuth2 flow, see the // For the client credentials 2-legged OAuth2 flow, see the clientcredentials
// [golang.org/x/oauth2/clientcredentials] package. // package (https://golang.org/x/oauth2/clientcredentials).
type Config struct { type Config struct {
// ClientID is the application's ID. // ClientID is the application's ID.
ClientID string ClientID string
@ -46,7 +46,7 @@ type Config struct {
// ClientSecret is the application's secret. // ClientSecret is the application's secret.
ClientSecret string ClientSecret string
// Endpoint contains the authorization server's token endpoint // Endpoint contains the resource server's token endpoint
// URLs. These are constants specific to each server and are // URLs. These are constants specific to each server and are
// often available via site-specific packages, such as // often available via site-specific packages, such as
// google.Endpoint or github.Endpoint. // google.Endpoint or github.Endpoint.
@ -135,7 +135,7 @@ type setParam struct{ k, v string }
func (p setParam) setValue(m url.Values) { m.Set(p.k, p.v) } func (p setParam) setValue(m url.Values) { m.Set(p.k, p.v) }
// SetAuthURLParam builds an [AuthCodeOption] which passes key/value parameters // SetAuthURLParam builds an AuthCodeOption which passes key/value parameters
// to a provider's authorization endpoint. // to a provider's authorization endpoint.
func SetAuthURLParam(key, value string) AuthCodeOption { func SetAuthURLParam(key, value string) AuthCodeOption {
return setParam{key, value} return setParam{key, value}
@ -148,8 +148,8 @@ func SetAuthURLParam(key, value string) AuthCodeOption {
// request and callback. The authorization server includes this value when // request and callback. The authorization server includes this value when
// redirecting the user agent back to the client. // redirecting the user agent back to the client.
// //
// Opts may include [AccessTypeOnline] or [AccessTypeOffline], as well // Opts may include AccessTypeOnline or AccessTypeOffline, as well
// as [ApprovalForce]. // as ApprovalForce.
// //
// To protect against CSRF attacks, opts should include a PKCE challenge // To protect against CSRF attacks, opts should include a PKCE challenge
// (S256ChallengeOption). Not all servers support PKCE. An alternative is to // (S256ChallengeOption). Not all servers support PKCE. An alternative is to
@ -194,7 +194,7 @@ func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
// and when other authorization grant types are not available." // and when other authorization grant types are not available."
// See https://tools.ietf.org/html/rfc6749#section-4.3 for more info. // See https://tools.ietf.org/html/rfc6749#section-4.3 for more info.
// //
// The provided context optionally controls which HTTP client is used. See the [HTTPClient] variable. // The provided context optionally controls which HTTP client is used. See the HTTPClient variable.
func (c *Config) PasswordCredentialsToken(ctx context.Context, username, password string) (*Token, error) { func (c *Config) PasswordCredentialsToken(ctx context.Context, username, password string) (*Token, error) {
v := url.Values{ v := url.Values{
"grant_type": {"password"}, "grant_type": {"password"},
@ -212,10 +212,10 @@ func (c *Config) PasswordCredentialsToken(ctx context.Context, username, passwor
// It is used after a resource provider redirects the user back // It is used after a resource provider redirects the user back
// to the Redirect URI (the URL obtained from AuthCodeURL). // to the Redirect URI (the URL obtained from AuthCodeURL).
// //
// The provided context optionally controls which HTTP client is used. See the [HTTPClient] variable. // The provided context optionally controls which HTTP client is used. See the HTTPClient variable.
// //
// The code will be in the [http.Request.FormValue]("code"). Before // The code will be in the *http.Request.FormValue("code"). Before
// calling Exchange, be sure to validate [http.Request.FormValue]("state") if you are // calling Exchange, be sure to validate FormValue("state") if you are
// using it to protect against CSRF attacks. // using it to protect against CSRF attacks.
// //
// If using PKCE to protect against CSRF attacks, opts should include a // If using PKCE to protect against CSRF attacks, opts should include a
@ -242,10 +242,10 @@ func (c *Config) Client(ctx context.Context, t *Token) *http.Client {
return NewClient(ctx, c.TokenSource(ctx, t)) return NewClient(ctx, c.TokenSource(ctx, t))
} }
// TokenSource returns a [TokenSource] that returns t until t expires, // TokenSource returns a TokenSource that returns t until t expires,
// automatically refreshing it as necessary using the provided context. // automatically refreshing it as necessary using the provided context.
// //
// Most users will use [Config.Client] instead. // Most users will use Config.Client instead.
func (c *Config) TokenSource(ctx context.Context, t *Token) TokenSource { func (c *Config) TokenSource(ctx context.Context, t *Token) TokenSource {
tkr := &tokenRefresher{ tkr := &tokenRefresher{
ctx: ctx, ctx: ctx,
@ -260,7 +260,7 @@ func (c *Config) TokenSource(ctx context.Context, t *Token) TokenSource {
} }
} }
// tokenRefresher is a TokenSource that makes "grant_type=refresh_token" // tokenRefresher is a TokenSource that makes "grant_type"=="refresh_token"
// HTTP requests to renew a token using a RefreshToken. // HTTP requests to renew a token using a RefreshToken.
type tokenRefresher struct { type tokenRefresher struct {
ctx context.Context // used to get HTTP requests ctx context.Context // used to get HTTP requests
@ -305,7 +305,8 @@ type reuseTokenSource struct {
} }
// Token returns the current token if it's still valid, else will // Token returns the current token if it's still valid, else will
// refresh the current token and return the new one. // refresh the current token (using r.Context for HTTP client
// information) and return the new one.
func (s *reuseTokenSource) Token() (*Token, error) { func (s *reuseTokenSource) Token() (*Token, error) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@ -321,7 +322,7 @@ func (s *reuseTokenSource) Token() (*Token, error) {
return t, nil return t, nil
} }
// StaticTokenSource returns a [TokenSource] that always returns the same token. // StaticTokenSource returns a TokenSource that always returns the same token.
// Because the provided token t is never refreshed, StaticTokenSource is only // Because the provided token t is never refreshed, StaticTokenSource is only
// useful for tokens that never expire. // useful for tokens that never expire.
func StaticTokenSource(t *Token) TokenSource { func StaticTokenSource(t *Token) TokenSource {
@ -337,16 +338,16 @@ func (s staticTokenSource) Token() (*Token, error) {
return s.t, nil return s.t, nil
} }
// HTTPClient is the context key to use with [context.WithValue] // HTTPClient is the context key to use with golang.org/x/net/context's
// to associate a [*http.Client] value with a context. // WithValue function to associate an *http.Client value with a context.
var HTTPClient internal.ContextKey var HTTPClient internal.ContextKey
// NewClient creates an [*http.Client] from a [context.Context] and [TokenSource]. // NewClient creates an *http.Client from a Context and TokenSource.
// The returned client is not valid beyond the lifetime of the context. // The returned client is not valid beyond the lifetime of the context.
// //
// Note that if a custom [*http.Client] is provided via the [context.Context] it // Note that if a custom *http.Client is provided via the Context it
// is used only for token acquisition and is not used to configure the // is used only for token acquisition and is not used to configure the
// [*http.Client] returned from NewClient. // *http.Client returned from NewClient.
// //
// As a special case, if src is nil, a non-OAuth2 client is returned // As a special case, if src is nil, a non-OAuth2 client is returned
// using the provided context. This exists to support related OAuth2 // using the provided context. This exists to support related OAuth2
@ -367,7 +368,7 @@ func NewClient(ctx context.Context, src TokenSource) *http.Client {
} }
} }
// ReuseTokenSource returns a [TokenSource] which repeatedly returns the // ReuseTokenSource returns a TokenSource which repeatedly returns the
// same token as long as it's valid, starting with t. // same token as long as it's valid, starting with t.
// When its cached token is invalid, a new token is obtained from src. // When its cached token is invalid, a new token is obtained from src.
// //
@ -375,10 +376,10 @@ func NewClient(ctx context.Context, src TokenSource) *http.Client {
// (such as a file on disk) between runs of a program, rather than // (such as a file on disk) between runs of a program, rather than
// obtaining new tokens unnecessarily. // obtaining new tokens unnecessarily.
// //
// The initial token t may be nil, in which case the [TokenSource] is // The initial token t may be nil, in which case the TokenSource is
// wrapped in a caching version if it isn't one already. This also // wrapped in a caching version if it isn't one already. This also
// means it's always safe to wrap ReuseTokenSource around any other // means it's always safe to wrap ReuseTokenSource around any other
// [TokenSource] without adverse effects. // TokenSource without adverse effects.
func ReuseTokenSource(t *Token, src TokenSource) TokenSource { func ReuseTokenSource(t *Token, src TokenSource) TokenSource {
// Don't wrap a reuseTokenSource in itself. That would work, // Don't wrap a reuseTokenSource in itself. That would work,
// but cause an unnecessary number of mutex operations. // but cause an unnecessary number of mutex operations.
@ -396,8 +397,8 @@ func ReuseTokenSource(t *Token, src TokenSource) TokenSource {
} }
} }
// ReuseTokenSourceWithExpiry returns a [TokenSource] that acts in the same manner as the // ReuseTokenSourceWithExpiry returns a TokenSource that acts in the same manner as the
// [TokenSource] returned by [ReuseTokenSource], except the expiry buffer is // TokenSource returned by ReuseTokenSource, except the expiry buffer is
// configurable. The expiration time of a token is calculated as // configurable. The expiration time of a token is calculated as
// t.Expiry.Add(-earlyExpiry). // t.Expiry.Add(-earlyExpiry).
func ReuseTokenSourceWithExpiry(t *Token, src TokenSource, earlyExpiry time.Duration) TokenSource { func ReuseTokenSourceWithExpiry(t *Token, src TokenSource, earlyExpiry time.Duration) TokenSource {

View File

@ -9,6 +9,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@ -103,7 +104,7 @@ func TestExchangeRequest(t *testing.T) {
if headerContentType != "application/x-www-form-urlencoded" { if headerContentType != "application/x-www-form-urlencoded" {
t.Errorf("Unexpected Content-Type header %q", headerContentType) t.Errorf("Unexpected Content-Type header %q", headerContentType)
} }
body, err := io.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
t.Errorf("Failed reading request body: %s.", err) t.Errorf("Failed reading request body: %s.", err)
} }
@ -147,7 +148,7 @@ func TestExchangeRequest_CustomParam(t *testing.T) {
if headerContentType != "application/x-www-form-urlencoded" { if headerContentType != "application/x-www-form-urlencoded" {
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType)
} }
body, err := io.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
t.Errorf("Failed reading request body: %s.", err) t.Errorf("Failed reading request body: %s.", err)
} }
@ -193,7 +194,7 @@ func TestExchangeRequest_JSONResponse(t *testing.T) {
if headerContentType != "application/x-www-form-urlencoded" { if headerContentType != "application/x-www-form-urlencoded" {
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType)
} }
body, err := io.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
t.Errorf("Failed reading request body: %s.", err) t.Errorf("Failed reading request body: %s.", err)
} }
@ -300,7 +301,7 @@ func testExchangeRequest_JSONResponse_expiry(t *testing.T, exp string, want, nul
conf := newConf(ts.URL) conf := newConf(ts.URL)
t1 := time.Now().Add(day) t1 := time.Now().Add(day)
tok, err := conf.Exchange(context.Background(), "exchange-code") tok, err := conf.Exchange(context.Background(), "exchange-code")
t2 := time.Now().Add(day) t2 := t1.Add(day)
if got := (err == nil); got != want { if got := (err == nil); got != want {
if want { if want {
@ -392,7 +393,7 @@ func TestPasswordCredentialsTokenRequest(t *testing.T) {
if headerContentType != expected { if headerContentType != expected {
t.Errorf("Content-Type header = %q; want %q", headerContentType, expected) t.Errorf("Content-Type header = %q; want %q", headerContentType, expected)
} }
body, err := io.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
t.Errorf("Failed reading request body: %s.", err) t.Errorf("Failed reading request body: %s.", err)
} }
@ -434,7 +435,7 @@ func TestTokenRefreshRequest(t *testing.T) {
if headerContentType != "application/x-www-form-urlencoded" { if headerContentType != "application/x-www-form-urlencoded" {
t.Errorf("Unexpected Content-Type header %q", headerContentType) t.Errorf("Unexpected Content-Type header %q", headerContentType)
} }
body, _ := io.ReadAll(r.Body) body, _ := ioutil.ReadAll(r.Body)
if string(body) != "grant_type=refresh_token&refresh_token=REFRESH_TOKEN" { if string(body) != "grant_type=refresh_token&refresh_token=REFRESH_TOKEN" {
t.Errorf("Unexpected refresh token payload %q", body) t.Errorf("Unexpected refresh token payload %q", body)
} }
@ -459,7 +460,7 @@ func TestFetchWithNoRefreshToken(t *testing.T) {
if headerContentType != "application/x-www-form-urlencoded" { if headerContentType != "application/x-www-form-urlencoded" {
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType)
} }
body, _ := io.ReadAll(r.Body) body, _ := ioutil.ReadAll(r.Body)
if string(body) != "client_id=CLIENT_ID&grant_type=refresh_token&refresh_token=REFRESH_TOKEN" { if string(body) != "client_id=CLIENT_ID&grant_type=refresh_token&refresh_token=REFRESH_TOKEN" {
t.Errorf("Unexpected refresh token payload, %v is found.", string(body)) t.Errorf("Unexpected refresh token payload, %v is found.", string(body))
} }

15
pkce.go
View File

@ -1,7 +1,6 @@
// Copyright 2023 The Go Authors. All rights reserved. // Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package oauth2 package oauth2
import ( import (
@ -21,9 +20,9 @@ const (
// This follows recommendations in RFC 7636. // This follows recommendations in RFC 7636.
// //
// A fresh verifier should be generated for each authorization. // A fresh verifier should be generated for each authorization.
// The resulting verifier should be passed to [Config.AuthCodeURL] or [Config.DeviceAuth] // S256ChallengeOption(verifier) should then be passed to Config.AuthCodeURL
// with [S256ChallengeOption], and to [Config.Exchange] or [Config.DeviceAccessToken] // (or Config.DeviceAuth) and VerifierOption(verifier) to Config.Exchange
// with [VerifierOption]. // (or Config.DeviceAccessToken).
func GenerateVerifier() string { func GenerateVerifier() string {
// "RECOMMENDED that the output of a suitable random number generator be // "RECOMMENDED that the output of a suitable random number generator be
// used to create a 32-octet sequence. The octet sequence is then // used to create a 32-octet sequence. The octet sequence is then
@ -37,22 +36,22 @@ func GenerateVerifier() string {
return base64.RawURLEncoding.EncodeToString(data) return base64.RawURLEncoding.EncodeToString(data)
} }
// VerifierOption returns a PKCE code verifier [AuthCodeOption]. It should only be // VerifierOption returns a PKCE code verifier AuthCodeOption. It should be
// passed to [Config.Exchange] or [Config.DeviceAccessToken]. // passed to Config.Exchange or Config.DeviceAccessToken only.
func VerifierOption(verifier string) AuthCodeOption { func VerifierOption(verifier string) AuthCodeOption {
return setParam{k: codeVerifierKey, v: verifier} return setParam{k: codeVerifierKey, v: verifier}
} }
// S256ChallengeFromVerifier returns a PKCE code challenge derived from verifier with method S256. // S256ChallengeFromVerifier returns a PKCE code challenge derived from verifier with method S256.
// //
// Prefer to use [S256ChallengeOption] where possible. // Prefer to use S256ChallengeOption where possible.
func S256ChallengeFromVerifier(verifier string) string { func S256ChallengeFromVerifier(verifier string) string {
sha := sha256.Sum256([]byte(verifier)) sha := sha256.Sum256([]byte(verifier))
return base64.RawURLEncoding.EncodeToString(sha[:]) return base64.RawURLEncoding.EncodeToString(sha[:])
} }
// S256ChallengeOption derives a PKCE code challenge derived from verifier with // S256ChallengeOption derives a PKCE code challenge derived from verifier with
// method S256. It should be passed to [Config.AuthCodeURL] or [Config.DeviceAuth] // method S256. It should be passed to Config.AuthCodeURL or Config.DeviceAuth
// only. // only.
func S256ChallengeOption(verifier string) AuthCodeOption { func S256ChallengeOption(verifier string) AuthCodeOption {
return challengeOption{ return challengeOption{

View File

@ -44,7 +44,7 @@ type Token struct {
// Expiry is the optional expiration time of the access token. // Expiry is the optional expiration time of the access token.
// //
// If zero, [TokenSource] implementations will reuse the same // If zero, TokenSource implementations will reuse the same
// token forever and RefreshToken or equivalent // token forever and RefreshToken or equivalent
// mechanisms for that TokenSource will not be used. // mechanisms for that TokenSource will not be used.
Expiry time.Time `json:"expiry,omitempty"` Expiry time.Time `json:"expiry,omitempty"`
@ -58,7 +58,7 @@ type Token struct {
// raw optionally contains extra metadata from the server // raw optionally contains extra metadata from the server
// when updating a token. // when updating a token.
raw any raw interface{}
// expiryDelta is used to calculate when a token is considered // expiryDelta is used to calculate when a token is considered
// expired, by subtracting from Expiry. If zero, defaultExpiryDelta // expired, by subtracting from Expiry. If zero, defaultExpiryDelta
@ -86,16 +86,16 @@ func (t *Token) Type() string {
// SetAuthHeader sets the Authorization header to r using the access // SetAuthHeader sets the Authorization header to r using the access
// token in t. // token in t.
// //
// This method is unnecessary when using [Transport] or an HTTP Client // This method is unnecessary when using Transport or an HTTP Client
// returned by this package. // returned by this package.
func (t *Token) SetAuthHeader(r *http.Request) { func (t *Token) SetAuthHeader(r *http.Request) {
r.Header.Set("Authorization", t.Type()+" "+t.AccessToken) r.Header.Set("Authorization", t.Type()+" "+t.AccessToken)
} }
// WithExtra returns a new [Token] that's a clone of t, but using the // WithExtra returns a new Token that's a clone of t, but using the
// provided raw extra map. This is only intended for use by packages // provided raw extra map. This is only intended for use by packages
// implementing derivative OAuth2 flows. // implementing derivative OAuth2 flows.
func (t *Token) WithExtra(extra any) *Token { func (t *Token) WithExtra(extra interface{}) *Token {
t2 := new(Token) t2 := new(Token)
*t2 = *t *t2 = *t
t2.raw = extra t2.raw = extra
@ -105,8 +105,8 @@ func (t *Token) WithExtra(extra any) *Token {
// Extra returns an extra field. // Extra returns an extra field.
// Extra fields are key-value pairs returned by the server as a // Extra fields are key-value pairs returned by the server as a
// part of the token retrieval response. // part of the token retrieval response.
func (t *Token) Extra(key string) any { func (t *Token) Extra(key string) interface{} {
if raw, ok := t.raw.(map[string]any); ok { if raw, ok := t.raw.(map[string]interface{}); ok {
return raw[key] return raw[key]
} }
@ -163,14 +163,13 @@ func tokenFromInternal(t *internal.Token) *Token {
TokenType: t.TokenType, TokenType: t.TokenType,
RefreshToken: t.RefreshToken, RefreshToken: t.RefreshToken,
Expiry: t.Expiry, Expiry: t.Expiry,
ExpiresIn: t.ExpiresIn,
raw: t.Raw, raw: t.Raw,
} }
} }
// retrieveToken takes a *Config and uses that to retrieve an *internal.Token. // retrieveToken takes a *Config and uses that to retrieve an *internal.Token.
// This token is then mapped from *internal.Token into an *oauth2.Token which is returned along // This token is then mapped from *internal.Token into an *oauth2.Token which is returned along
// with an error. // with an error..
func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) { func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) {
tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle), c.authStyleCache.Get()) tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle), c.authStyleCache.Get())
if err != nil { if err != nil {

View File

@ -12,8 +12,8 @@ import (
func TestTokenExtra(t *testing.T) { func TestTokenExtra(t *testing.T) {
type testCase struct { type testCase struct {
key string key string
val any val interface{}
want any want interface{}
} }
const key = "extra-key" const key = "extra-key"
cases := []testCase{ cases := []testCase{
@ -23,7 +23,7 @@ func TestTokenExtra(t *testing.T) {
{key: "other-key", val: "def", want: nil}, {key: "other-key", val: "def", want: nil},
} }
for _, tc := range cases { for _, tc := range cases {
extra := make(map[string]any) extra := make(map[string]interface{})
extra[tc.key] = tc.val extra[tc.key] = tc.val
tok := &Token{raw: extra} tok := &Token{raw: extra}
if got, want := tok.Extra(key), tc.want; got != want { if got, want := tok.Extra(key), tc.want; got != want {

View File

@ -11,12 +11,12 @@ import (
"sync" "sync"
) )
// Transport is an [http.RoundTripper] that makes OAuth 2.0 HTTP requests, // Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests,
// wrapping a base [http.RoundTripper] and adding an Authorization header // wrapping a base RoundTripper and adding an Authorization header
// with a token from the supplied [TokenSource]. // with a token from the supplied Sources.
// //
// Transport is a low-level mechanism. Most code will use the // Transport is a low-level mechanism. Most code will use the
// higher-level [Config.Client] method instead. // higher-level Config.Client method instead.
type Transport struct { type Transport struct {
// Source supplies the token to add to outgoing requests' // Source supplies the token to add to outgoing requests'
// Authorization headers. // Authorization headers.
@ -47,7 +47,7 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
return nil, err return nil, err
} }
req2 := req.Clone(req.Context()) req2 := cloneRequest(req) // per RoundTripper contract
token.SetAuthHeader(req2) token.SetAuthHeader(req2)
// req.Body is assumed to be closed by the base RoundTripper. // req.Body is assumed to be closed by the base RoundTripper.
@ -73,3 +73,17 @@ func (t *Transport) base() http.RoundTripper {
} }
return http.DefaultTransport return http.DefaultTransport
} }
// cloneRequest returns a clone of the provided *http.Request.
// The clone is a shallow copy of the struct and its Header map.
func cloneRequest(r *http.Request) *http.Request {
// shallow copy of the struct
r2 := new(http.Request)
*r2 = *r
// deep copy of the Header
r2.Header = make(http.Header, len(r.Header))
for k, s := range r.Header {
r2.Header[k] = append([]string(nil), s...)
}
return r2
}