Compare commits

..

No commits in common. "master" and "v0.21.0" have entirely different histories.

46 changed files with 302 additions and 619 deletions

View File

@ -1,4 +1,4 @@
Copyright 2009 The Go Authors.
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@ -10,7 +10,7 @@ notice, this list of conditions and the following disclaimer.
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google LLC nor the names of its
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

View File

@ -5,6 +5,15 @@
oauth2 package contains a client implementation for OAuth 2.0 spec.
## Installation
~~~~
go get golang.org/x/oauth2
~~~~
Or you can manually git clone the repository to
`$(go env GOPATH)/src/golang.org/x/oauth2`.
See pkg.go.dev for further documentation and examples.
* [pkg.go.dev/golang.org/x/oauth2](https://pkg.go.dev/golang.org/x/oauth2)
@ -24,11 +33,7 @@ The main issue tracker for the oauth2 repository is located at
https://github.com/golang/oauth2/issues.
This repository uses Gerrit for code changes. To learn how to submit changes to
this repository, see https://go.dev/doc/contribute.
The git repository is https://go.googlesource.com/oauth2.
Note:
this repository, see https://golang.org/doc/contribute.html. In particular:
* Excluding trivial changes, all contributions should be connected to an existing issue.
* API changes must go through the [change proposal process](https://go.dev/s/proposal-process) before they can be accepted.

View File

@ -34,7 +34,7 @@ type PKCEParams struct {
// and returns an auth code and state upon approval.
type AuthorizationHandler func(authCodeURL string) (code string, state string, err error)
// TokenSourceWithPKCE is an enhanced version of [oauth2.TokenSource] with PKCE support.
// TokenSourceWithPKCE is an enhanced version of TokenSource with PKCE support.
//
// The pkce parameter supports PKCE flow, which uses code challenge and code verifier
// to prevent CSRF attacks. A unique code challenge and code verifier should be generated
@ -43,12 +43,12 @@ func TokenSourceWithPKCE(ctx context.Context, config *oauth2.Config, state strin
return oauth2.ReuseTokenSource(nil, authHandlerSource{config: config, ctx: ctx, authHandler: authHandler, state: state, pkce: pkce})
}
// TokenSource returns an [oauth2.TokenSource] that fetches access tokens
// TokenSource returns an oauth2.TokenSource that fetches access tokens
// using 3-legged-OAuth flow.
//
// The provided [context.Context] is used for oauth2 Exchange operation.
// The provided context.Context is used for oauth2 Exchange operation.
//
// The provided [oauth2.Config] should be a full configuration containing AuthURL,
// The provided oauth2.Config should be a full configuration containing AuthURL,
// TokenURL, and Scope.
//
// An environment-specific AuthorizationHandler is used to obtain user consent.

View File

@ -37,7 +37,7 @@ type Config struct {
// URL. This is a constant specific to each server.
TokenURL string
// Scopes specifies optional requested permissions.
// Scope specifies optional requested permissions.
Scopes []string
// EndpointParams specifies additional parameters for requests to the token endpoint.
@ -55,7 +55,7 @@ type Config struct {
// Token uses client credentials to retrieve a token.
//
// The provided context optionally controls which HTTP client is used. See the [oauth2.HTTPClient] variable.
// The provided context optionally controls which HTTP client is used. See the oauth2.HTTPClient variable.
func (c *Config) Token(ctx context.Context) (*oauth2.Token, error) {
return c.TokenSource(ctx).Token()
}
@ -64,18 +64,18 @@ func (c *Config) Token(ctx context.Context) (*oauth2.Token, error) {
// The token will auto-refresh as necessary.
//
// The provided context optionally controls which HTTP client
// is returned. See the [oauth2.HTTPClient] variable.
// is returned. See the oauth2.HTTPClient variable.
//
// The returned [http.Client] and its Transport should not be modified.
// The returned Client and its Transport should not be modified.
func (c *Config) Client(ctx context.Context) *http.Client {
return oauth2.NewClient(ctx, c.TokenSource(ctx))
}
// TokenSource returns a [oauth2.TokenSource] that returns t until t expires,
// TokenSource returns a TokenSource that returns t until t expires,
// automatically refreshing it as necessary using the provided context and the
// client ID and client secret.
//
// Most users will use [Config.Client] instead.
// Most users will use Config.Client instead.
func (c *Config) TokenSource(ctx context.Context) oauth2.TokenSource {
source := &tokenSource{
ctx: ctx,

View File

@ -7,6 +7,7 @@ package clientcredentials
import (
"context"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
@ -35,9 +36,9 @@ func TestTokenSourceGrantTypeOverride(t *testing.T) {
wantGrantType := "password"
var gotGrantType string
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Errorf("io.ReadAll(r.Body) == %v, %v, want _, <nil>", body, err)
t.Errorf("ioutil.ReadAll(r.Body) == %v, %v, want _, <nil>", body, err)
}
if err := r.Body.Close(); err != nil {
t.Errorf("r.Body.Close() == %v, want <nil>", err)
@ -80,7 +81,7 @@ func TestTokenRequest(t *testing.T) {
if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want {
t.Errorf("Content-Type header = %q; want %q", got, want)
}
body, err := io.ReadAll(r.Body)
body, err := ioutil.ReadAll(r.Body)
if err != nil {
r.Body.Close()
}
@ -122,7 +123,7 @@ func TestTokenRefreshRequest(t *testing.T) {
if got, want := headerContentType, "application/x-www-form-urlencoded"; got != want {
t.Errorf("Content-Type = %q; want %q", got, want)
}
body, _ := io.ReadAll(r.Body)
body, _ := ioutil.ReadAll(r.Body)
const want = "audience=audience1&grant_type=client_credentials&scope=scope1+scope2"
if string(body) != want {
t.Errorf("Unexpected refresh token payload.\n got: %s\nwant: %s\n", body, want)

View File

@ -7,6 +7,9 @@ import (
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
)
func TestDeviceAuthResponseMarshalJson(t *testing.T) {
@ -71,16 +74,7 @@ func TestDeviceAuthResponseUnmarshalJson(t *testing.T) {
if err != nil {
t.Fatal(err)
}
margin := time.Second + time.Since(begin)
timeDiff := got.Expiry.Sub(tc.want.Expiry)
if timeDiff < 0 {
timeDiff *= -1
}
if timeDiff > margin {
t.Errorf("expiry time difference too large, got=%v, want=%v margin=%v", got.Expiry, tc.want.Expiry, margin)
}
got.Expiry, tc.want.Expiry = time.Time{}, time.Time{}
if got != tc.want {
if !cmp.Equal(got, tc.want, cmpopts.IgnoreUnexported(DeviceAuthResponse{}), cmpopts.EquateApproxTime(time.Second+time.Since(begin))) {
t.Errorf("want=%#v, got=%#v", tc.want, got)
}
})

View File

@ -6,7 +6,7 @@
package endpoints
import (
"net/url"
"strings"
"golang.org/x/oauth2"
)
@ -17,30 +17,6 @@ var Amazon = oauth2.Endpoint{
TokenURL: "https://api.amazon.com/auth/o2/token",
}
// Apple is the endpoint for "Sign in with Apple".
//
// Documentation: https://developer.apple.com/documentation/signinwithapplerestapi
var Apple = oauth2.Endpoint{
AuthURL: "https://appleid.apple.com/auth/authorize",
TokenURL: "https://appleid.apple.com/auth/token",
}
// Asana is the endpoint for Asana.
//
// Documentation: https://developers.asana.com/docs/oauth
var Asana = oauth2.Endpoint{
AuthURL: "https://app.asana.com/-/oauth_authorize",
TokenURL: "https://app.asana.com/-/oauth_token",
}
// Badgr is the endpoint for Canvas Badges.
//
// Documentation: https://community.canvaslms.com/t5/Canvas-Badges-Credentials/Developers-Build-an-app-that-integrates-with-the-Canvas-Badges/ta-p/528727
var Badgr = oauth2.Endpoint{
AuthURL: "https://badgr.com/auth/oauth2/authorize",
TokenURL: "https://api.badgr.io/o/token",
}
// Battlenet is the endpoint for Battlenet.
var Battlenet = oauth2.Endpoint{
AuthURL: "https://battle.net/oauth/authorize",
@ -59,44 +35,10 @@ var Cern = oauth2.Endpoint{
TokenURL: "https://oauth.web.cern.ch/OAuth/Token",
}
// Coinbase is the endpoint for Coinbase.
//
// Documentation: https://docs.cdp.coinbase.com/coinbase-app/docs/coinbase-app-reference
var Coinbase = oauth2.Endpoint{
AuthURL: "https://login.coinbase.com/oauth2/auth",
TokenURL: "https://login.coinbase.com/oauth2/token",
}
// Discord is the endpoint for Discord.
//
// Documentation: https://discord.com/developers/docs/topics/oauth2#shared-resources-oauth2-urls
var Discord = oauth2.Endpoint{
AuthURL: "https://discord.com/oauth2/authorize",
TokenURL: "https://discord.com/api/oauth2/token",
}
// Dropbox is the endpoint for Dropbox.
//
// Documentation: https://developers.dropbox.com/oauth-guide
var Dropbox = oauth2.Endpoint{
AuthURL: "https://www.dropbox.com/oauth2/authorize",
TokenURL: "https://api.dropboxapi.com/oauth2/token",
}
// Endpoint is Ebay's OAuth 2.0 endpoint.
//
// Documentation: https://developer.ebay.com/api-docs/static/authorization_guide_landing.html
var Endpoint = oauth2.Endpoint{
AuthURL: "https://auth.ebay.com/oauth2/authorize",
TokenURL: "https://api.ebay.com/identity/v1/oauth2/token",
}
// Facebook is the endpoint for Facebook.
//
// Documentation: https://developers.facebook.com/docs/facebook-login/guides/advanced/manual-flow
var Facebook = oauth2.Endpoint{
AuthURL: "https://www.facebook.com/v22.0/dialog/oauth",
TokenURL: "https://graph.facebook.com/v22.0/oauth/access_token",
AuthURL: "https://www.facebook.com/v3.2/dialog/oauth",
TokenURL: "https://graph.facebook.com/v3.2/oauth/access_token",
}
// Foursquare is the endpoint for Foursquare.
@ -120,9 +62,8 @@ var GitHub = oauth2.Endpoint{
// GitLab is the endpoint for GitLab.
var GitLab = oauth2.Endpoint{
AuthURL: "https://gitlab.com/oauth/authorize",
TokenURL: "https://gitlab.com/oauth/token",
DeviceAuthURL: "https://gitlab.com/oauth/authorize_device",
AuthURL: "https://gitlab.com/oauth/authorize",
TokenURL: "https://gitlab.com/oauth/token",
}
// Google is the endpoint for Google.
@ -156,14 +97,6 @@ var KaKao = oauth2.Endpoint{
TokenURL: "https://kauth.kakao.com/oauth/token",
}
// Line is the endpoint for Line.
//
// Documentation: https://developers.line.biz/en/docs/line-login/integrate-line-login/
var Line = oauth2.Endpoint{
AuthURL: "https://access.line.me/oauth2/v2.1/authorize",
TokenURL: "https://api.line.me/oauth2/v2.1/token",
}
// LinkedIn is the endpoint for LinkedIn.
var LinkedIn = oauth2.Endpoint{
AuthURL: "https://www.linkedin.com/oauth/v2/authorization",
@ -200,17 +133,7 @@ var Microsoft = oauth2.Endpoint{
TokenURL: "https://login.live.com/oauth20_token.srf",
}
// Naver is the endpoint for Naver.
//
// Documentation: https://developers.naver.com/docs/login/devguide/devguide.md
var Naver = oauth2.Endpoint{
AuthURL: "https://nid.naver.com/oauth2/authorize",
TokenURL: "https://nid.naver.com/oauth2/token",
}
// NokiaHealth is the endpoint for Nokia Health.
//
// Deprecated: Nokia Health is now Withings.
var NokiaHealth = oauth2.Endpoint{
AuthURL: "https://account.health.nokia.com/oauth2_user/authorize2",
TokenURL: "https://account.health.nokia.com/oauth2/token",
@ -222,20 +145,6 @@ var Odnoklassniki = oauth2.Endpoint{
TokenURL: "https://api.odnoklassniki.ru/oauth/token.do",
}
// OpenStreetMap is the endpoint for OpenStreetMap.org.
//
// Documentation: https://wiki.openstreetmap.org/wiki/OAuth
var OpenStreetMap = oauth2.Endpoint{
AuthURL: "https://www.openstreetmap.org/oauth2/authorize",
TokenURL: "https://www.openstreetmap.org/oauth2/token",
}
// Patreon is the endpoint for Patreon.
var Patreon = oauth2.Endpoint{
AuthURL: "https://www.patreon.com/oauth2/authorize",
TokenURL: "https://www.patreon.com/api/oauth2/token",
}
// PayPal is the endpoint for PayPal.
var PayPal = oauth2.Endpoint{
AuthURL: "https://www.paypal.com/webapps/auth/protocol/openidconnect/v1/authorize",
@ -248,52 +157,10 @@ var PayPalSandbox = oauth2.Endpoint{
TokenURL: "https://api.sandbox.paypal.com/v1/identity/openidconnect/tokenservice",
}
// Pinterest is the endpoint for Pinterest.
//
// Documentation: https://developers.pinterest.com/docs/getting-started/set-up-authentication-and-authorization/
var Pinterest = oauth2.Endpoint{
AuthURL: "https://www.pinterest.com/oauth",
TokenURL: "https://api.pinterest.com/v5/oauth/token",
}
// Pipedrive is the endpoint for Pipedrive.
//
// Documentation: https://developers.pipedrive.com/docs/api/v1/Oauth
var Pipedrive = oauth2.Endpoint{
AuthURL: "https://oauth.pipedrive.com/oauth/authorize",
TokenURL: "https://oauth.pipedrive.com/oauth/token",
}
// QQ is the endpoint for QQ.
//
// Documentation: https://wiki.connect.qq.com/%e5%bc%80%e5%8f%91%e6%94%bb%e7%95%a5_server-side
var QQ = oauth2.Endpoint{
AuthURL: "https://graph.qq.com/oauth2.0/authorize",
TokenURL: "https://graph.qq.com/oauth2.0/token",
}
// Rakuten is the endpoint for Rakuten.
//
// Documentation: https://webservice.rakuten.co.jp/documentation
var Rakuten = oauth2.Endpoint{
AuthURL: "https://app.rakuten.co.jp/services/authorize",
TokenURL: "https://app.rakuten.co.jp/services/token",
}
// Slack is the endpoint for Slack.
//
// Documentation: https://api.slack.com/authentication/oauth-v2
var Slack = oauth2.Endpoint{
AuthURL: "https://slack.com/oauth/v2/authorize",
TokenURL: "https://slack.com/api/oauth.v2.access",
}
// Splitwise is the endpoint for Splitwise.
//
// Documentation: https://dev.splitwise.com/
var Splitwise = oauth2.Endpoint{
AuthURL: "https://www.splitwise.com/oauth/authorize",
TokenURL: "https://www.splitwise.com/oauth/token",
AuthURL: "https://slack.com/oauth/authorize",
TokenURL: "https://slack.com/api/oauth.access",
}
// Spotify is the endpoint for Spotify.
@ -332,22 +199,6 @@ var Vk = oauth2.Endpoint{
TokenURL: "https://oauth.vk.com/access_token",
}
// Withings is the endpoint for Withings.
//
// Documentation: https://account.withings.com/oauth2_user/authorize2
var Withings = oauth2.Endpoint{
AuthURL: "https://account.withings.com/oauth2_user/authorize2",
TokenURL: "https://account.withings.com/oauth2/token",
}
// X is the endpoint for X (Twitter).
//
// Documentation: https://docs.x.com/resources/fundamentals/authentication/oauth-2-0/user-access-token
var X = oauth2.Endpoint{
AuthURL: "https://x.com/i/oauth2/authorize",
TokenURL: "https://api.x.com/2/oauth2/token",
}
// Yahoo is the endpoint for Yahoo.
var Yahoo = oauth2.Endpoint{
AuthURL: "https://api.login.yahoo.com/oauth2/request_auth",
@ -366,20 +217,6 @@ var Zoom = oauth2.Endpoint{
TokenURL: "https://zoom.us/oauth/token",
}
// Asgardeo returns a new oauth2.Endpoint for the given tenant.
//
// Documentation: https://wso2.com/asgardeo/docs/guides/authentication/oidc/discover-oidc-configs/
func AsgardeoEndpoint(tenant string) oauth2.Endpoint {
u := url.URL{
Scheme: "https",
Host: "api.asgardeo.io",
}
return oauth2.Endpoint{
AuthURL: u.JoinPath("t", tenant, "/oauth2/authorize").String(),
TokenURL: u.JoinPath("t", tenant, "/oauth2/token").String(),
}
}
// AzureAD returns a new oauth2.Endpoint for the given tenant at Azure Active Directory.
// If tenant is empty, it uses the tenant called `common`.
//
@ -389,29 +226,19 @@ func AzureAD(tenant string) oauth2.Endpoint {
if tenant == "" {
tenant = "common"
}
u := url.URL{
Scheme: "https",
Host: "login.microsoftonline.com",
}
return oauth2.Endpoint{
AuthURL: u.JoinPath(tenant, "/oauth2/v2.0/authorize").String(),
TokenURL: u.JoinPath(tenant, "/oauth2/v2.0/token").String(),
DeviceAuthURL: u.JoinPath(tenant, "/oauth2/v2.0/devicecode").String(),
AuthURL: "https://login.microsoftonline.com/" + tenant + "/oauth2/v2.0/authorize",
TokenURL: "https://login.microsoftonline.com/" + tenant + "/oauth2/v2.0/token",
DeviceAuthURL: "https://login.microsoftonline.com/" + tenant + "/oauth2/v2.0/devicecode",
}
}
// AzureADB2CEndpoint returns a new oauth2.Endpoint for the given tenant and policy at Azure Active Directory B2C.
// policy is the Azure B2C User flow name Example: `B2C_1_SignUpSignIn`.
//
// Documentation: https://docs.microsoft.com/en-us/azure/active-directory-b2c/tokens-overview#endpoints
func AzureADB2CEndpoint(tenant string, policy string) oauth2.Endpoint {
u := url.URL{
Scheme: "https",
Host: tenant + ".b2clogin.com",
}
// HipChatServer returns a new oauth2.Endpoint for a HipChat Server instance
// running on the given domain or host.
func HipChatServer(host string) oauth2.Endpoint {
return oauth2.Endpoint{
AuthURL: u.JoinPath(tenant+".onmicrosoft.com", policy, "/oauth2/v2.0/authorize").String(),
TokenURL: u.JoinPath(tenant+".onmicrosoft.com", policy, "/oauth2/v2.0/token").String(),
AuthURL: "https://" + host + "/users/authorize",
TokenURL: "https://" + host + "/v2/oauth/token",
}
}
@ -424,42 +251,9 @@ func AzureADB2CEndpoint(tenant string, policy string) oauth2.Endpoint {
// https://docs.aws.amazon.com/cognito/latest/developerguide/cognito-user-pools-assign-domain.html
// https://docs.aws.amazon.com/cognito/latest/developerguide/cognito-userpools-server-contract-reference.html
func AWSCognito(domain string) oauth2.Endpoint {
u, err := url.Parse(domain)
if err != nil || u.Scheme == "" || u.Host == "" {
panic("endpoints: invalid domain" + domain)
}
domain = strings.TrimRight(domain, "/")
return oauth2.Endpoint{
AuthURL: u.JoinPath("/oauth2/authorize").String(),
TokenURL: u.JoinPath("/oauth2/token").String(),
}
}
// HipChatServer returns a new oauth2.Endpoint for a HipChat Server instance.
// host should be a hostname, without any scheme prefix.
//
// Documentation: https://developer.atlassian.com/server/hipchat/hipchat-rest-api-access-tokens/
func HipChatServer(host string) oauth2.Endpoint {
u := url.URL{
Scheme: "https",
Host: host,
}
return oauth2.Endpoint{
AuthURL: u.JoinPath("/users/authorize").String(),
TokenURL: u.JoinPath("/v2/oauth/token").String(),
}
}
// Shopify returns a new oauth2.Endpoint for the supplied shop domain name.
// host should be a hostname, without any scheme prefix.
//
// Documentation: https://shopify.dev/docs/apps/auth/oauth
func Shopify(host string) oauth2.Endpoint {
u := url.URL{
Scheme: "https",
Host: host,
}
return oauth2.Endpoint{
AuthURL: u.JoinPath("/admin/oauth/authorize").String(),
TokenURL: u.JoinPath("/admin/oauth/access_token").String(),
AuthURL: domain + "/oauth2/authorize",
TokenURL: domain + "/oauth2/token",
}
}

View File

@ -6,8 +6,11 @@
package gitlab // import "golang.org/x/oauth2/gitlab"
import (
"golang.org/x/oauth2/endpoints"
"golang.org/x/oauth2"
)
// Endpoint is GitLab's OAuth 2.0 endpoint.
var Endpoint = endpoints.GitLab
var Endpoint = oauth2.Endpoint{
AuthURL: "https://gitlab.com/oauth/authorize",
TokenURL: "https://gitlab.com/oauth/token",
}

7
go.mod
View File

@ -1,5 +1,8 @@
module golang.org/x/oauth2
go 1.23.0
go 1.18
require cloud.google.com/go/compute/metadata v0.3.0
require (
cloud.google.com/go/compute/metadata v0.3.0
github.com/google/go-cmp v0.5.9
)

2
go.sum
View File

@ -1,2 +1,4 @@
cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=

View File

@ -251,12 +251,6 @@ func FindDefaultCredentials(ctx context.Context, scopes ...string) (*Credentials
// a Google Developers service account key file, a gcloud user credentials file (a.k.a. refresh
// token JSON), or the JSON configuration file for workload identity federation in non-Google cloud
// platforms (see https://cloud.google.com/iam/docs/how-to#using-workload-identity-federation).
//
// Important: If you accept a credential configuration (credential JSON/File/Stream) from an
// external source for authentication to Google Cloud Platform, you must validate it before
// providing it to any Google API or library. Providing an unvalidated credential configuration to
// Google APIs can compromise the security of your systems and data. For more information, refer to
// [Validate credential configurations from external sources](https://cloud.google.com/docs/authentication/external/externally-sourced-credentials).
func CredentialsFromJSONWithParams(ctx context.Context, jsonData []byte, params CredentialsParams) (*Credentials, error) {
// Make defensive copy of the slices in params.
params = params.deepCopy()
@ -300,12 +294,6 @@ func CredentialsFromJSONWithParams(ctx context.Context, jsonData []byte, params
}
// CredentialsFromJSON invokes CredentialsFromJSONWithParams with the specified scopes.
//
// Important: If you accept a credential configuration (credential JSON/File/Stream) from an
// external source for authentication to Google Cloud Platform, you must validate it before
// providing it to any Google API or library. Providing an unvalidated credential configuration to
// Google APIs can compromise the security of your systems and data. For more information, refer to
// [Validate credential configurations from external sources](https://cloud.google.com/docs/authentication/external/externally-sourced-credentials).
func CredentialsFromJSON(ctx context.Context, jsonData []byte, scopes ...string) (*Credentials, error) {
var params CredentialsParams
params.Scopes = scopes

View File

@ -39,7 +39,7 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
@ -198,7 +198,7 @@ func (dts downscopingTokenSource) Token() (*oauth2.Token, error) {
return nil, fmt.Errorf("unable to generate POST Request %v", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
respBody, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("downscope: unable to read response body: %v", err)
}

View File

@ -6,7 +6,7 @@ package downscope
import (
"context"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
@ -27,7 +27,7 @@ func Test_DownscopedTokenSource(t *testing.T) {
if r.URL.String() != "/" {
t.Errorf("Unexpected request URL, %v is found", r.URL)
}
body, err := io.ReadAll(r.Body)
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("Failed to read request body: %v", err)
}

View File

@ -7,9 +7,9 @@ package google_test
import (
"context"
"fmt"
"io/ioutil"
"log"
"net/http"
"os"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
@ -60,7 +60,7 @@ func ExampleJWTConfigFromJSON() {
// To create a service account client, click "Create new Client ID",
// select "Service Account", and click "Create Client ID". A JSON
// key file will then be downloaded to your computer.
data, err := os.ReadFile("/path/to/your-project-key.json")
data, err := ioutil.ReadFile("/path/to/your-project-key.json")
if err != nil {
log.Fatal(err)
}
@ -136,7 +136,7 @@ func ExampleComputeTokenSource() {
func ExampleCredentialsFromJSON() {
ctx := context.Background()
data, err := os.ReadFile("/path/to/key-file.json")
data, err := ioutil.ReadFile("/path/to/key-file.json")
if err != nil {
log.Fatal(err)
}

View File

@ -14,6 +14,7 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"os"
@ -27,7 +28,7 @@ import (
// AwsSecurityCredentials models AWS security credentials.
type AwsSecurityCredentials struct {
// AccessKeyID is the AWS Access Key ID - Required.
// AccessKeyId is the AWS Access Key ID - Required.
AccessKeyID string `json:"AccessKeyID"`
// SecretAccessKey is the AWS Secret Access Key - Required.
SecretAccessKey string `json:"SecretAccessKey"`
@ -169,7 +170,7 @@ func requestDataHash(req *http.Request) (string, error) {
}
defer requestBody.Close()
requestData, err = io.ReadAll(io.LimitReader(requestBody, 1<<20))
requestData, err = ioutil.ReadAll(io.LimitReader(requestBody, 1<<20))
if err != nil {
return "", err
}
@ -418,7 +419,7 @@ func (cs *awsCredentialSource) getAWSSessionToken() (string, error) {
}
defer resp.Body.Close()
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return "", err
}
@ -461,7 +462,7 @@ func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, err
}
defer resp.Body.Close()
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return "", err
}
@ -530,7 +531,7 @@ func (cs *awsCredentialSource) getMetadataSecurityCredentials(roleName string, h
}
defer resp.Body.Close()
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return result, err
}
@ -563,7 +564,7 @@ func (cs *awsCredentialSource) getMetadataRoleName(headers map[string]string) (s
}
defer resp.Body.Close()
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return "", err
}

View File

@ -263,7 +263,7 @@ const (
fileTypeJSON = "json"
)
// Format contains information needed to retrieve a subject token for URL or File sourced credentials.
// Format contains information needed to retireve a subject token for URL or File sourced credentials.
type Format struct {
// Type should be either "text" or "json". This determines whether the file or URL sourced credentials
// expect a simple text subject token or if the subject token will be contained in a JSON object.
@ -278,52 +278,20 @@ type Format struct {
type CredentialSource struct {
// File is the location for file sourced credentials.
// One field amongst File, URL, Executable, or EnvironmentID should be provided, depending on the kind of credential in question.
//
// Important: If you accept a credential configuration (credential
// JSON/File/Stream) from an external source for authentication to Google
// Cloud Platform, you must validate it before providing it to any Google
// API or library. Providing an unvalidated credential configuration to
// Google APIs can compromise the security of your systems and data. For
// more information, refer to [Validate credential configurations from
// external sources](https://cloud.google.com/docs/authentication/external/externally-sourced-credentials).
File string `json:"file"`
// Url is the URL to call for URL sourced credentials.
// One field amongst File, URL, Executable, or EnvironmentID should be provided, depending on the kind of credential in question.
//
// Important: If you accept a credential configuration (credential
// JSON/File/Stream) from an external source for authentication to Google
// Cloud Platform, you must validate it before providing it to any Google
// API or library. Providing an unvalidated credential configuration to
// Google APIs can compromise the security of your systems and data. For
// more information, refer to [Validate credential configurations from
// external sources](https://cloud.google.com/docs/authentication/external/externally-sourced-credentials).
URL string `json:"url"`
// Headers are the headers to attach to the request for URL sourced credentials.
Headers map[string]string `json:"headers"`
// Executable is the configuration object for executable sourced credentials.
// One field amongst File, URL, Executable, or EnvironmentID should be provided, depending on the kind of credential in question.
//
// Important: If you accept a credential configuration (credential
// JSON/File/Stream) from an external source for authentication to Google
// Cloud Platform, you must validate it before providing it to any Google
// API or library. Providing an unvalidated credential configuration to
// Google APIs can compromise the security of your systems and data. For
// more information, refer to [Validate credential configurations from
// external sources](https://cloud.google.com/docs/authentication/external/externally-sourced-credentials).
Executable *ExecutableConfig `json:"executable"`
// EnvironmentID is the EnvironmentID used for AWS sourced credentials. This should start with "AWS".
// One field amongst File, URL, Executable, or EnvironmentID should be provided, depending on the kind of credential in question.
//
// Important: If you accept a credential configuration (credential
// JSON/File/Stream) from an external source for authentication to Google
// Cloud Platform, you must validate it before providing it to any Google
// API or library. Providing an unvalidated credential configuration to
// Google APIs can compromise the security of your systems and data. For
// more information, refer to [Validate credential configurations from
// external sources](https://cloud.google.com/docs/authentication/external/externally-sourced-credentials).
EnvironmentID string `json:"environment_id"`
// RegionURL is the metadata URL to retrieve the region from for EC2 AWS credentials.
RegionURL string `json:"region_url"`
@ -361,7 +329,7 @@ type SubjectTokenSupplier interface {
type AwsSecurityCredentialsSupplier interface {
// AwsRegion should return the AWS region or an error.
AwsRegion(ctx context.Context, options SupplierOptions) (string, error)
// AwsSecurityCredentials should return a valid set of AwsSecurityCredentials or an error.
// GetAwsSecurityCredentials should return a valid set of AwsSecurityCredentials or an error.
// The external account token source does not cache the returned security credentials, so caching
// logic should be implemented in the supplier to prevent multiple requests for the same security credentials.
AwsSecurityCredentials(ctx context.Context, options SupplierOptions) (*AwsSecurityCredentials, error)
@ -486,11 +454,11 @@ func (ts tokenSource) Token() (*oauth2.Token, error) {
ClientID: conf.ClientID,
ClientSecret: conf.ClientSecret,
}
var options map[string]any
var options map[string]interface{}
// Do not pass workforce_pool_user_project when client authentication is used.
// The client ID is sufficient for determining the user project.
if conf.WorkforcePoolUserProject != "" && conf.ClientID == "" {
options = map[string]any{
options = map[string]interface{}{
"userProject": conf.WorkforcePoolUserProject,
}
}

View File

@ -8,7 +8,7 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
@ -77,7 +77,7 @@ func run(t *testing.T, config *Config, tets *testExchangeTokenServer) (*oauth2.T
if got, want := headerMetrics, tets.metricsHeader; got != want {
t.Errorf("got %v but want %v", got, want)
}
body, err := io.ReadAll(r.Body)
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("Failed reading request body: %s.", err)
}
@ -131,7 +131,7 @@ func createImpersonationServer(urlWanted, authWanted, bodyWanted, response strin
if got, want := headerContentType, "application/json"; got != want {
t.Errorf("got %v but want %v", got, want)
}
body, err := io.ReadAll(r.Body)
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("Failed reading request body: %v.", err)
}
@ -160,7 +160,7 @@ func createTargetServer(metricsHeaderWanted string, t *testing.T) *httptest.Serv
if got, want := headerMetrics, metricsHeaderWanted; got != want {
t.Errorf("got %v but want %v", got, want)
}
body, err := io.ReadAll(r.Body)
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("Failed reading request body: %v.", err)
}
@ -347,12 +347,12 @@ func TestNonworkforceWithWorkforcePoolUserProject(t *testing.T) {
t.Fatalf("Expected error but found none")
}
if got, want := err.Error(), "oauth2/google/externalaccount: Workforce pool user project should not be set for non-workforce pool credentials"; got != want {
t.Errorf("Incorrect error received.\nExpected: %s\nReceived: %s", want, got)
t.Errorf("Incorrect error received.\nExpected: %s\nRecieved: %s", want, got)
}
}
func TestWorkforcePoolCreation(t *testing.T) {
var audienceValidityTests = []struct {
var audienceValidatyTests = []struct {
audience string
expectSuccess bool
}{
@ -371,7 +371,7 @@ func TestWorkforcePoolCreation(t *testing.T) {
}
ctx := context.Background()
for _, tt := range audienceValidityTests {
for _, tt := range audienceValidatyTests {
t.Run(" "+tt.audience, func(t *testing.T) { // We prepend a space ahead of the test input when outputting for sake of readability.
config := testConfig
config.TokenURL = "https://sts.googleapis.com" // Setting the most basic acceptable tokenURL

View File

@ -11,6 +11,7 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"os"
"os/exec"
"regexp"
@ -257,7 +258,7 @@ func (cs executableCredentialSource) getTokenFromOutputFile() (token string, err
}
defer file.Close()
data, err := io.ReadAll(io.LimitReader(file, 1<<20))
data, err := ioutil.ReadAll(io.LimitReader(file, 1<<20))
if err != nil || len(data) == 0 {
// Cachefile exists, but no data found. Get new credential.
return "", nil

View File

@ -8,10 +8,13 @@ import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"os"
"slices"
"sort"
"testing"
"time"
"github.com/google/go-cmp/cmp"
)
type testEnvironment struct {
@ -251,12 +254,14 @@ func TestExecutableCredentialGetEnvironment(t *testing.T) {
ecs.env = &tt.environment
got := ecs.executableEnvironment()
slices.Sort(got)
want := tt.expectedEnvironment
slices.Sort(want)
// This Transformer sorts a []string.
sorter := cmp.Transformer("Sort", func(in []string) []string {
out := append([]string(nil), in...) // Copy input to avoid mutating it
sort.Strings(out)
return out
})
if !slices.Equal(got, want) {
if got, want := ecs.executableEnvironment(), tt.expectedEnvironment; !cmp.Equal(got, want, sorter) {
t.Errorf("Incorrect environment received.\nReceived: %s\nExpected: %s", got, want)
}
})
@ -609,7 +614,7 @@ func TestRetrieveExecutableSubjectTokenSuccesses(t *testing.T) {
}
func TestRetrieveOutputFileSubjectTokenNotJSON(t *testing.T) {
outputFile, err := os.CreateTemp("testdata", "result.*.json")
outputFile, err := ioutil.TempFile("testdata", "result.*.json")
if err != nil {
t.Fatalf("Tempfile failed: %v", err)
}
@ -649,7 +654,7 @@ func TestRetrieveOutputFileSubjectTokenNotJSON(t *testing.T) {
if _, err = base.subjectToken(); err == nil {
t.Fatalf("Expected error but found none")
} else if got, want := err.Error(), jsonParsingError(outputFileSource, "tokentokentoken").Error(); got != want {
t.Errorf("Incorrect error received.\nExpected: %s\nReceived: %s", want, got)
t.Errorf("Incorrect error received.\nExpected: %s\nRecieved: %s", want, got)
}
_, deadlineSet := te.getDeadline()
@ -758,7 +763,7 @@ var cacheFailureTests = []struct {
func TestRetrieveOutputFileSubjectTokenFailureTests(t *testing.T) {
for _, tt := range cacheFailureTests {
t.Run(tt.name, func(t *testing.T) {
outputFile, err := os.CreateTemp("testdata", "result.*.json")
outputFile, err := ioutil.TempFile("testdata", "result.*.json")
if err != nil {
t.Fatalf("Tempfile failed: %v", err)
}
@ -796,7 +801,7 @@ func TestRetrieveOutputFileSubjectTokenFailureTests(t *testing.T) {
if _, err = ecs.subjectToken(); err == nil {
t.Errorf("Expected error but found none")
} else if got, want := err.Error(), tt.expectedErr.Error(); got != want {
t.Errorf("Incorrect error received.\nExpected: %s\nReceived: %s", want, got)
t.Errorf("Incorrect error received.\nExpected: %s\nRecieved: %s", want, got)
}
if _, deadlineSet := te.getDeadline(); deadlineSet {
@ -861,7 +866,7 @@ var invalidCacheTests = []struct {
func TestRetrieveOutputFileSubjectTokenInvalidCache(t *testing.T) {
for _, tt := range invalidCacheTests {
t.Run(tt.name, func(t *testing.T) {
outputFile, err := os.CreateTemp("testdata", "result.*.json")
outputFile, err := ioutil.TempFile("testdata", "result.*.json")
if err != nil {
t.Fatalf("Tempfile failed: %v", err)
}
@ -918,7 +923,7 @@ func TestRetrieveOutputFileSubjectTokenInvalidCache(t *testing.T) {
}
if got, want := out, "tokentokentoken"; got != want {
t.Errorf("Incorrect token received.\nExpected: %s\nReceived: %s", want, got)
t.Errorf("Incorrect token received.\nExpected: %s\nRecieved: %s", want, got)
}
})
}
@ -965,7 +970,8 @@ var cacheSuccessTests = []struct {
func TestRetrieveOutputFileSubjectTokenJwt(t *testing.T) {
for _, tt := range cacheSuccessTests {
t.Run(tt.name, func(t *testing.T) {
outputFile, err := os.CreateTemp("testdata", "result.*.json")
outputFile, err := ioutil.TempFile("testdata", "result.*.json")
if err != nil {
t.Fatalf("Tempfile failed: %v", err)
}
@ -1006,7 +1012,7 @@ func TestRetrieveOutputFileSubjectTokenJwt(t *testing.T) {
if out, err := ecs.subjectToken(); err != nil {
t.Errorf("retrieveSubjectToken() failed: %v", err)
} else if got, want := out, "tokentokentoken"; got != want {
t.Errorf("Incorrect token received.\nExpected: %s\nReceived: %s", want, got)
t.Errorf("Incorrect token received.\nExpected: %s\nRecieved: %s", want, got)
}
if _, deadlineSet := te.getDeadline(); deadlineSet {

View File

@ -10,6 +10,7 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"os"
)
@ -28,14 +29,14 @@ func (cs fileCredentialSource) subjectToken() (string, error) {
return "", fmt.Errorf("oauth2/google/externalaccount: failed to open credential file %q", cs.File)
}
defer tokenFile.Close()
tokenBytes, err := io.ReadAll(io.LimitReader(tokenFile, 1<<20))
tokenBytes, err := ioutil.ReadAll(io.LimitReader(tokenFile, 1<<20))
if err != nil {
return "", fmt.Errorf("oauth2/google/externalaccount: failed to read credential file: %v", err)
}
tokenBytes = bytes.TrimSpace(tokenBytes)
switch cs.Format.Type {
case "json":
jsonData := make(map[string]any)
jsonData := make(map[string]interface{})
err = json.Unmarshal(tokenBytes, &jsonData)
if err != nil {
return "", fmt.Errorf("oauth2/google/externalaccount: failed to unmarshal subject token file: %v", err)

View File

@ -7,6 +7,8 @@ package externalaccount
import (
"runtime"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestGoVersion(t *testing.T) {
@ -38,8 +40,8 @@ func TestGoVersion(t *testing.T) {
} {
version = tst.v
got := goVersion()
if got != tst.want {
t.Errorf("go version = %q, want = %q", got, tst.want)
if diff := cmp.Diff(got, tst.want); diff != "" {
t.Errorf("got(-),want(+):\n%s", diff)
}
}
version = runtime.Version

View File

@ -10,6 +10,7 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"golang.org/x/oauth2"
@ -43,7 +44,7 @@ func (cs urlCredentialSource) subjectToken() (string, error) {
}
defer resp.Body.Close()
respBody, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
respBody, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return "", fmt.Errorf("oauth2/google/externalaccount: invalid body in subject token URL query: %v", err)
}
@ -53,7 +54,7 @@ func (cs urlCredentialSource) subjectToken() (string, error) {
switch cs.Format.Type {
case "json":
jsonData := make(map[string]any)
jsonData := make(map[string]interface{})
err = json.Unmarshal(respBody, &jsonData)
if err != nil {
return "", fmt.Errorf("oauth2/google/externalaccount: failed to unmarshal subject token file: %v", err)

View File

@ -285,23 +285,27 @@ func (cs computeSource) Token() (*oauth2.Token, error) {
if err != nil {
return nil, err
}
var res oauth2.Token
var res struct {
AccessToken string `json:"access_token"`
ExpiresInSec int `json:"expires_in"`
TokenType string `json:"token_type"`
}
err = json.NewDecoder(strings.NewReader(tokenJSON)).Decode(&res)
if err != nil {
return nil, fmt.Errorf("oauth2/google: invalid token JSON from metadata: %v", err)
}
if res.ExpiresIn == 0 || res.AccessToken == "" {
if res.ExpiresInSec == 0 || res.AccessToken == "" {
return nil, fmt.Errorf("oauth2/google: incomplete token received from metadata")
}
tok := &oauth2.Token{
AccessToken: res.AccessToken,
TokenType: res.TokenType,
Expiry: time.Now().Add(time.Duration(res.ExpiresIn) * time.Second),
Expiry: time.Now().Add(time.Duration(res.ExpiresInSec) * time.Second),
}
// NOTE(cbro): add hidden metadata about where the token is from.
// This is needed for detection by client libraries to know that credentials come from the metadata server.
// This may be removed in a future version of this library.
return tok.WithExtra(map[string]any{
return tok.WithExtra(map[string]interface{}{
"oauth2.google.tokenSource": "compute-metadata",
"oauth2.google.serviceAccount": acct,
}), nil

View File

@ -72,7 +72,7 @@ func TestConfigFromJSON(t *testing.T) {
t.Errorf("ClientSecret = %q; want %q", got, want)
}
if got, want := conf.RedirectURL, "https://www.example.com/oauth2callback"; got != want {
t.Errorf("RedirectURL = %q; want %q", got, want)
t.Errorf("RedictURL = %q; want %q", got, want)
}
if got, want := strings.Join(conf.Scopes, ","), "scope1,scope2"; got != want {
t.Errorf("Scopes = %q; want %q", got, want)

View File

@ -8,7 +8,7 @@ import (
"context"
"encoding/json"
"errors"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
@ -38,7 +38,7 @@ type testRefreshTokenServer struct {
server *httptest.Server
}
func TestExternalAccountAuthorizedUser_JustToken(t *testing.T) {
func TestExernalAccountAuthorizedUser_JustToken(t *testing.T) {
config := &Config{
Token: "AAAAAAA",
Expiry: now().Add(time.Hour),
@ -57,7 +57,7 @@ func TestExternalAccountAuthorizedUser_JustToken(t *testing.T) {
}
}
func TestExternalAccountAuthorizedUser_TokenRefreshWithRefreshTokenInResponse(t *testing.T) {
func TestExernalAccountAuthorizedUser_TokenRefreshWithRefreshTokenInRespondse(t *testing.T) {
server := &testRefreshTokenServer{
URL: "/",
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
@ -99,7 +99,7 @@ func TestExternalAccountAuthorizedUser_TokenRefreshWithRefreshTokenInResponse(t
}
}
func TestExternalAccountAuthorizedUser_MinimumFieldsRequiredForRefresh(t *testing.T) {
func TestExernalAccountAuthorizedUser_MinimumFieldsRequiredForRefresh(t *testing.T) {
server := &testRefreshTokenServer{
URL: "/",
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
@ -187,7 +187,7 @@ func TestExternalAccountAuthorizedUser_MissingRefreshFields(t *testing.T) {
},
},
{
name: "missing client secret",
name: "missing client secrect",
config: Config{
RefreshToken: "BBBBBBBBB",
TokenURL: url,
@ -227,7 +227,7 @@ func (trts *testRefreshTokenServer) run(t *testing.T) (string, error) {
if got, want := headerContentType, trts.ContentType; got != want {
t.Errorf("got %v but want %v", got, want)
}
body, err := io.ReadAll(r.Body)
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("Failed reading request body: %s.", err)
}

View File

@ -10,6 +10,7 @@ import (
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"time"
@ -80,7 +81,7 @@ func (its ImpersonateTokenSource) Token() (*oauth2.Token, error) {
return nil, fmt.Errorf("oauth2/google: unable to generate access token: %v", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("oauth2/google: unable to read body: %v", err)
}

View File

@ -9,6 +9,7 @@ import (
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"strconv"
@ -27,7 +28,7 @@ func defaultHeader() http.Header {
// The first 4 fields are all mandatory. headers can be used to pass additional
// headers beyond the bare minimum required by the token exchange. options can
// be used to pass additional JSON-structured options to the remote server.
func ExchangeToken(ctx context.Context, endpoint string, request *TokenExchangeRequest, authentication ClientAuthentication, headers http.Header, options map[string]any) (*Response, error) {
func ExchangeToken(ctx context.Context, endpoint string, request *TokenExchangeRequest, authentication ClientAuthentication, headers http.Header, options map[string]interface{}) (*Response, error) {
data := url.Values{}
data.Set("audience", request.Audience)
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange")
@ -81,7 +82,7 @@ func makeRequest(ctx context.Context, endpoint string, data url.Values, authenti
}
defer resp.Body.Close()
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return nil, err
}

View File

@ -7,7 +7,7 @@ package stsexchange
import (
"context"
"encoding/json"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
@ -73,7 +73,7 @@ func TestExchangeToken(t *testing.T) {
if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want {
t.Errorf("Unexpected Content-Type header, got %v, want %v", got, want)
}
body, err := io.ReadAll(r.Body)
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Errorf("Failed reading request body: %v.", err)
}
@ -132,7 +132,7 @@ var optsValues = [][]string{{"foo", "bar"}, {"cat", "pan"}}
func TestExchangeToken_Opts(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("Failed reading request body: %v.", err)
}
@ -142,11 +142,11 @@ func TestExchangeToken_Opts(t *testing.T) {
}
strOpts, ok := data["options"]
if !ok {
t.Errorf("Server didn't receive an \"options\" field.")
t.Errorf("Server didn't recieve an \"options\" field.")
} else if len(strOpts) < 1 {
t.Errorf("\"options\" field has length 0.")
}
var opts map[string]any
var opts map[string]interface{}
err = json.Unmarshal([]byte(strOpts[0]), &opts)
if err != nil {
t.Fatalf("Couldn't parse received \"options\" field.")
@ -159,7 +159,7 @@ func TestExchangeToken_Opts(t *testing.T) {
if !ok {
t.Errorf("Couldn't find first option parameter.")
} else {
tOpts1, ok := val.(map[string]any)
tOpts1, ok := val.(map[string]interface{})
if !ok {
t.Errorf("Failed to assert the first option parameter as type testOpts.")
} else {
@ -176,7 +176,7 @@ func TestExchangeToken_Opts(t *testing.T) {
if !ok {
t.Errorf("Couldn't find second option parameter.")
} else {
tOpts2, ok := val2.(map[string]any)
tOpts2, ok := val2.(map[string]interface{})
if !ok {
t.Errorf("Failed to assert the second option parameter as type testOpts.")
} else {
@ -200,7 +200,7 @@ func TestExchangeToken_Opts(t *testing.T) {
firstOption := testOpts{optsValues[0][0], optsValues[0][1]}
secondOption := testOpts{optsValues[1][0], optsValues[1][1]}
inputOpts := make(map[string]any)
inputOpts := make(map[string]interface{})
inputOpts["one"] = firstOption
inputOpts["two"] = secondOption
ExchangeToken(context.Background(), ts.URL, &exchangeTokenRequest, auth, headers, inputOpts)
@ -220,7 +220,7 @@ func TestRefreshToken(t *testing.T) {
if got, want := r.Header.Get("Content-Type"), "application/x-www-form-urlencoded"; got != want {
t.Errorf("Unexpected Content-Type header, got %v, want %v", got, want)
}
body, err := io.ReadAll(r.Body)
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Errorf("Failed reading request body: %v.", err)
}

View File

@ -2,5 +2,5 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package internal contains support packages for [golang.org/x/oauth2].
// Package internal contains support packages for oauth2 package.
package internal

View File

@ -13,7 +13,7 @@ import (
)
// ParseKey converts the binary contents of a private key file
// to an [*rsa.PrivateKey]. It detects whether the private key is in a
// to an *rsa.PrivateKey. It detects whether the private key is in a
// PEM container or not. If so, it extracts the private key
// from PEM container before conversion. It only supports PEM
// containers with no passphrase.

View File

@ -10,6 +10,7 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"math"
"mime"
"net/http"
@ -25,9 +26,9 @@ import (
// the requests to access protected resources on the OAuth 2.0
// provider's backend.
//
// This type is a mirror of [golang.org/x/oauth2.Token] and exists to break
// This type is a mirror of oauth2.Token and exists to break
// an otherwise-circular dependency. Other internal packages
// should convert this Token into an [golang.org/x/oauth2.Token] before use.
// should convert this Token into an oauth2.Token before use.
type Token struct {
// AccessToken is the token that authorizes and authenticates
// the requests.
@ -49,16 +50,9 @@ type Token struct {
// mechanisms for that TokenSource will not be used.
Expiry time.Time
// ExpiresIn is the OAuth2 wire format "expires_in" field,
// which specifies how many seconds later the token expires,
// relative to an unknown time base approximately around "now".
// It is the application's responsibility to populate
// `Expiry` from `ExpiresIn` when required.
ExpiresIn int64 `json:"expires_in,omitempty"`
// Raw optionally contains extra metadata from the server
// when updating a token.
Raw any
Raw interface{}
}
// tokenJSON is the struct representing the HTTP response from OAuth2
@ -105,6 +99,14 @@ func (e *expirationTime) UnmarshalJSON(b []byte) error {
return nil
}
// RegisterBrokenAuthHeaderProvider previously did something. It is now a no-op.
//
// Deprecated: this function no longer does anything. Caller code that
// wants to avoid potential extra HTTP requests made during
// auto-probing of the provider's auth style should set
// Endpoint.AuthStyle.
func RegisterBrokenAuthHeaderProvider(tokenURL string) {}
// AuthStyle is a copy of the golang.org/x/oauth2 package's AuthStyle type.
type AuthStyle int
@ -141,11 +143,6 @@ func (lc *LazyAuthStyleCache) Get() *AuthStyleCache {
return c
}
type authStyleCacheKey struct {
url string
clientID string
}
// AuthStyleCache is the set of tokenURLs we've successfully used via
// RetrieveToken and which style auth we ended up using.
// It's called a cache, but it doesn't (yet?) shrink. It's expected that
@ -153,26 +150,26 @@ type authStyleCacheKey struct {
// small.
type AuthStyleCache struct {
mu sync.Mutex
m map[authStyleCacheKey]AuthStyle
m map[string]AuthStyle // keyed by tokenURL
}
// lookupAuthStyle reports which auth style we last used with tokenURL
// when calling RetrieveToken and whether we have ever done so.
func (c *AuthStyleCache) lookupAuthStyle(tokenURL, clientID string) (style AuthStyle, ok bool) {
func (c *AuthStyleCache) lookupAuthStyle(tokenURL string) (style AuthStyle, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
style, ok = c.m[authStyleCacheKey{tokenURL, clientID}]
style, ok = c.m[tokenURL]
return
}
// setAuthStyle adds an entry to authStyleCache, documented above.
func (c *AuthStyleCache) setAuthStyle(tokenURL, clientID string, v AuthStyle) {
func (c *AuthStyleCache) setAuthStyle(tokenURL string, v AuthStyle) {
c.mu.Lock()
defer c.mu.Unlock()
if c.m == nil {
c.m = make(map[authStyleCacheKey]AuthStyle)
c.m = make(map[string]AuthStyle)
}
c.m[authStyleCacheKey{tokenURL, clientID}] = v
c.m[tokenURL] = v
}
// newTokenRequest returns a new *http.Request to retrieve a new token
@ -213,9 +210,9 @@ func cloneURLValues(v url.Values) url.Values {
}
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values, authStyle AuthStyle, styleCache *AuthStyleCache) (*Token, error) {
needsAuthStyleProbe := authStyle == AuthStyleUnknown
needsAuthStyleProbe := authStyle == 0
if needsAuthStyleProbe {
if style, ok := styleCache.lookupAuthStyle(tokenURL, clientID); ok {
if style, ok := styleCache.lookupAuthStyle(tokenURL); ok {
authStyle = style
needsAuthStyleProbe = false
} else {
@ -245,7 +242,7 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string,
token, err = doTokenRoundTrip(ctx, req)
}
if needsAuthStyleProbe && err == nil {
styleCache.setAuthStyle(tokenURL, clientID, authStyle)
styleCache.setAuthStyle(tokenURL, authStyle)
}
// Don't overwrite `RefreshToken` with an empty value
// if this was a token refreshing request.
@ -260,7 +257,7 @@ func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) {
if err != nil {
return nil, err
}
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20))
r.Body.Close()
if err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
@ -315,8 +312,7 @@ func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) {
TokenType: tj.TokenType,
RefreshToken: tj.RefreshToken,
Expiry: tj.expiry(),
ExpiresIn: int64(tj.ExpiresIn),
Raw: make(map[string]any),
Raw: make(map[string]interface{}),
}
json.Unmarshal(body, &token.Raw) // no error checks for optional fields
}

View File

@ -75,48 +75,3 @@ func TestExpiresInUpperBound(t *testing.T) {
t.Errorf("expiration time = %v; want %v", e, want)
}
}
func TestAuthStyleCache(t *testing.T) {
var c LazyAuthStyleCache
cases := []struct {
url string
clientID string
style AuthStyle
}{
{
"https://host1.example.com/token",
"client_1",
AuthStyleInHeader,
}, {
"https://host2.example.com/token",
"client_2",
AuthStyleInParams,
}, {
"https://host1.example.com/token",
"client_3",
AuthStyleInParams,
},
}
for _, tt := range cases {
t.Run(tt.clientID, func(t *testing.T) {
cc := c.Get()
got, ok := cc.lookupAuthStyle(tt.url, tt.clientID)
if ok {
t.Fatalf("unexpected auth style found on first request: %v", got)
}
cc.setAuthStyle(tt.url, tt.clientID, tt.style)
got, ok = cc.lookupAuthStyle(tt.url, tt.clientID)
if !ok {
t.Fatalf("auth style not found in cache")
}
if got != tt.style {
t.Fatalf("auth style mismatch, got=%v, want=%v", got, tt.style)
}
})
}
}

View File

@ -9,8 +9,8 @@ import (
"net/http"
)
// HTTPClient is the context key to use with [context.WithValue]
// to associate an [*http.Client] value with a context.
// HTTPClient is the context key to use with golang.org/x/net/context's
// WithValue function to associate an *http.Client value with a context.
var HTTPClient ContextKey
// ContextKey is just an empty struct. It exists so HTTPClient can be

View File

@ -13,6 +13,7 @@ import (
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
@ -113,7 +114,7 @@ func (js jwtSource) Token() (*oauth2.Token, error) {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
}
@ -122,7 +123,11 @@ func (js jwtSource) Token() (*oauth2.Token, error) {
}
// tokenRes is the JSON response body.
var tokenRes oauth2.Token
var tokenRes struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"` // relative seconds from now
}
if err := json.Unmarshal(body, &tokenRes); err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
}

View File

@ -4,7 +4,7 @@
// Package jws provides a partial implementation
// of JSON Web Signature encoding and decoding.
// It exists to support the [golang.org/x/oauth2] package.
// It exists to support the golang.org/x/oauth2 package.
//
// See RFC 7515.
//
@ -48,7 +48,7 @@ type ClaimSet struct {
// See http://tools.ietf.org/html/draft-jones-json-web-token-10#section-4.3
// This array is marshalled using custom code (see (c *ClaimSet) encode()).
PrivateClaims map[string]any `json:"-"`
PrivateClaims map[string]interface{} `json:"-"`
}
func (c *ClaimSet) encode() (string, error) {
@ -116,12 +116,12 @@ func (h *Header) encode() (string, error) {
// Decode decodes a claim set from a JWS payload.
func Decode(payload string) (*ClaimSet, error) {
// decode returned id token to get expiry
_, claims, _, ok := parseToken(payload)
if !ok {
s := strings.Split(payload, ".")
if len(s) < 2 {
// TODO(jbd): Provide more context about the error.
return nil, errors.New("jws: invalid token received")
}
decoded, err := base64.RawURLEncoding.DecodeString(claims)
decoded, err := base64.RawURLEncoding.DecodeString(s[1])
if err != nil {
return nil, err
}
@ -152,7 +152,7 @@ func EncodeWithSigner(header *Header, c *ClaimSet, sg Signer) (string, error) {
}
// Encode encodes a signed JWS with provided header and claim set.
// This invokes [EncodeWithSigner] using [crypto/rsa.SignPKCS1v15] with the given RSA private key.
// This invokes EncodeWithSigner using crypto/rsa.SignPKCS1v15 with the given RSA private key.
func Encode(header *Header, c *ClaimSet, key *rsa.PrivateKey) (string, error) {
sg := func(data []byte) (sig []byte, err error) {
h := sha256.New()
@ -165,34 +165,18 @@ func Encode(header *Header, c *ClaimSet, key *rsa.PrivateKey) (string, error) {
// Verify tests whether the provided JWT token's signature was produced by the private key
// associated with the supplied public key.
func Verify(token string, key *rsa.PublicKey) error {
header, claims, sig, ok := parseToken(token)
if !ok {
parts := strings.Split(token, ".")
if len(parts) != 3 {
return errors.New("jws: invalid token received, token must have 3 parts")
}
signatureString, err := base64.RawURLEncoding.DecodeString(sig)
signedContent := parts[0] + "." + parts[1]
signatureString, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return err
}
h := sha256.New()
h.Write([]byte(header + tokenDelim + claims))
h.Write([]byte(signedContent))
return rsa.VerifyPKCS1v15(key, crypto.SHA256, h.Sum(nil), signatureString)
}
func parseToken(s string) (header, claims, sig string, ok bool) {
header, s, ok = strings.Cut(s, tokenDelim)
if !ok { // no period found
return "", "", "", false
}
claims, s, ok = strings.Cut(s, tokenDelim)
if !ok { // only one period found
return "", "", "", false
}
sig, _, ok = strings.Cut(s, tokenDelim)
if ok { // three periods found
return "", "", "", false
}
return header, claims, sig, true
}
const tokenDelim = "."

View File

@ -7,8 +7,6 @@ package jws
import (
"crypto/rand"
"crypto/rsa"
"net/http"
"strings"
"testing"
)
@ -41,57 +39,8 @@ func TestSignAndVerify(t *testing.T) {
}
func TestVerifyFailsOnMalformedClaim(t *testing.T) {
cases := []struct {
desc string
token string
}{
{
desc: "no periods",
token: "aa",
}, {
desc: "only one period",
token: "a.a",
}, {
desc: "more than two periods",
token: "a.a.a.a",
},
}
for _, tc := range cases {
f := func(t *testing.T) {
err := Verify(tc.token, nil)
if err == nil {
t.Error("got no errors; want improperly formed JWT not to be verified")
}
}
t.Run(tc.desc, f)
}
}
func BenchmarkVerify(b *testing.B) {
cases := []struct {
desc string
token string
}{
{
desc: "full of periods",
token: strings.Repeat(".", http.DefaultMaxHeaderBytes),
}, {
desc: "two trailing periods",
token: strings.Repeat("a", http.DefaultMaxHeaderBytes-2) + "..",
},
}
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
b.Fatal(err)
}
for _, bc := range cases {
f := func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for range b.N {
Verify(bc.token, &privateKey.PublicKey)
}
}
b.Run(bc.desc, f)
err := Verify("abc.def", nil)
if err == nil {
t.Error("got no errors; want improperly formed JWT not to be verified")
}
}

View File

@ -10,7 +10,7 @@ import (
"golang.org/x/oauth2/jwt"
)
func ExampleConfig() {
func ExampleJWTConfig() {
ctx := context.Background()
conf := &jwt.Config{
Email: "xxx@developer.com",

View File

@ -13,6 +13,7 @@ import (
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
@ -68,7 +69,7 @@ type Config struct {
// PrivateClaims optionally specifies custom private claims in the JWT.
// See http://tools.ietf.org/html/draft-jones-json-web-token-10#section-4.3
PrivateClaims map[string]any
PrivateClaims map[string]interface{}
// UseIDToken optionally specifies whether ID token should be used instead
// of access token when the server returns both.
@ -135,7 +136,7 @@ func (js jwtSource) Token() (*oauth2.Token, error) {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
}
@ -147,8 +148,10 @@ func (js jwtSource) Token() (*oauth2.Token, error) {
}
// tokenRes is the JSON response body.
var tokenRes struct {
oauth2.Token
IDToken string `json:"id_token"`
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
IDToken string `json:"id_token"`
ExpiresIn int64 `json:"expires_in"` // relative seconds from now
}
if err := json.Unmarshal(body, &tokenRes); err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
@ -157,7 +160,7 @@ func (js jwtSource) Token() (*oauth2.Token, error) {
AccessToken: tokenRes.AccessToken,
TokenType: tokenRes.TokenType,
}
raw := make(map[string]any)
raw := make(map[string]interface{})
json.Unmarshal(body, &raw) // no error checks for optional fields
token = token.WithExtra(raw)

View File

@ -227,7 +227,7 @@ func TestJWTFetch_AssertionPayload(t *testing.T) {
PrivateKey: dummyPrivateKey,
PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
TokenURL: ts.URL,
PrivateClaims: map[string]any{
PrivateClaims: map[string]interface{}{
"private0": "claim0",
"private1": "claim1",
},
@ -273,11 +273,11 @@ func TestJWTFetch_AssertionPayload(t *testing.T) {
t.Errorf("payload prn = %q; want %q", got, want)
}
if len(conf.PrivateClaims) > 0 {
var got any
var got interface{}
if err := json.Unmarshal(gotjson, &got); err != nil {
t.Errorf("failed to parse payload; err = %q", err)
}
m := got.(map[string]any)
m := got.(map[string]interface{})
for v, k := range conf.PrivateClaims {
if !reflect.DeepEqual(m[v], k) {
t.Errorf("payload private claims key = %q: got %#v; want %#v", v, m[v], k)

View File

@ -22,9 +22,9 @@ import (
)
// NoContext is the default context you should supply if not using
// your own [context.Context].
// your own context.Context (see https://golang.org/x/net/context).
//
// Deprecated: Use [context.Background] or [context.TODO] instead.
// Deprecated: Use context.Background() or context.TODO() instead.
var NoContext = context.TODO()
// RegisterBrokenAuthHeaderProvider previously did something. It is now a no-op.
@ -37,8 +37,8 @@ func RegisterBrokenAuthHeaderProvider(tokenURL string) {}
// Config describes a typical 3-legged OAuth2 flow, with both the
// client application information and the server's endpoint URLs.
// For the client credentials 2-legged OAuth2 flow, see the
// [golang.org/x/oauth2/clientcredentials] package.
// For the client credentials 2-legged OAuth2 flow, see the clientcredentials
// package (https://golang.org/x/oauth2/clientcredentials).
type Config struct {
// ClientID is the application's ID.
ClientID string
@ -46,7 +46,7 @@ type Config struct {
// ClientSecret is the application's secret.
ClientSecret string
// Endpoint contains the authorization server's token endpoint
// Endpoint contains the resource server's token endpoint
// URLs. These are constants specific to each server and are
// often available via site-specific packages, such as
// google.Endpoint or github.Endpoint.
@ -56,7 +56,7 @@ type Config struct {
// the OAuth flow, after the resource owner's URLs.
RedirectURL string
// Scopes specifies optional requested permissions.
// Scope specifies optional requested permissions.
Scopes []string
// authStyleCache caches which auth style to use when Endpoint.AuthStyle is
@ -135,7 +135,7 @@ type setParam struct{ k, v string }
func (p setParam) setValue(m url.Values) { m.Set(p.k, p.v) }
// SetAuthURLParam builds an [AuthCodeOption] which passes key/value parameters
// SetAuthURLParam builds an AuthCodeOption which passes key/value parameters
// to a provider's authorization endpoint.
func SetAuthURLParam(key, value string) AuthCodeOption {
return setParam{key, value}
@ -148,8 +148,8 @@ func SetAuthURLParam(key, value string) AuthCodeOption {
// request and callback. The authorization server includes this value when
// redirecting the user agent back to the client.
//
// Opts may include [AccessTypeOnline] or [AccessTypeOffline], as well
// as [ApprovalForce].
// Opts may include AccessTypeOnline or AccessTypeOffline, as well
// as ApprovalForce.
//
// To protect against CSRF attacks, opts should include a PKCE challenge
// (S256ChallengeOption). Not all servers support PKCE. An alternative is to
@ -194,7 +194,7 @@ func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
// and when other authorization grant types are not available."
// See https://tools.ietf.org/html/rfc6749#section-4.3 for more info.
//
// The provided context optionally controls which HTTP client is used. See the [HTTPClient] variable.
// The provided context optionally controls which HTTP client is used. See the HTTPClient variable.
func (c *Config) PasswordCredentialsToken(ctx context.Context, username, password string) (*Token, error) {
v := url.Values{
"grant_type": {"password"},
@ -212,10 +212,10 @@ func (c *Config) PasswordCredentialsToken(ctx context.Context, username, passwor
// It is used after a resource provider redirects the user back
// to the Redirect URI (the URL obtained from AuthCodeURL).
//
// The provided context optionally controls which HTTP client is used. See the [HTTPClient] variable.
// The provided context optionally controls which HTTP client is used. See the HTTPClient variable.
//
// The code will be in the [http.Request.FormValue]("code"). Before
// calling Exchange, be sure to validate [http.Request.FormValue]("state") if you are
// The code will be in the *http.Request.FormValue("code"). Before
// calling Exchange, be sure to validate FormValue("state") if you are
// using it to protect against CSRF attacks.
//
// If using PKCE to protect against CSRF attacks, opts should include a
@ -242,10 +242,10 @@ func (c *Config) Client(ctx context.Context, t *Token) *http.Client {
return NewClient(ctx, c.TokenSource(ctx, t))
}
// TokenSource returns a [TokenSource] that returns t until t expires,
// TokenSource returns a TokenSource that returns t until t expires,
// automatically refreshing it as necessary using the provided context.
//
// Most users will use [Config.Client] instead.
// Most users will use Config.Client instead.
func (c *Config) TokenSource(ctx context.Context, t *Token) TokenSource {
tkr := &tokenRefresher{
ctx: ctx,
@ -260,7 +260,7 @@ func (c *Config) TokenSource(ctx context.Context, t *Token) TokenSource {
}
}
// tokenRefresher is a TokenSource that makes "grant_type=refresh_token"
// tokenRefresher is a TokenSource that makes "grant_type"=="refresh_token"
// HTTP requests to renew a token using a RefreshToken.
type tokenRefresher struct {
ctx context.Context // used to get HTTP requests
@ -288,7 +288,7 @@ func (tf *tokenRefresher) Token() (*Token, error) {
if tf.refreshToken != tk.RefreshToken {
tf.refreshToken = tk.RefreshToken
}
return tk, nil
return tk, err
}
// reuseTokenSource is a TokenSource that holds a single token in memory
@ -305,7 +305,8 @@ type reuseTokenSource struct {
}
// Token returns the current token if it's still valid, else will
// refresh the current token and return the new one.
// refresh the current token (using r.Context for HTTP client
// information) and return the new one.
func (s *reuseTokenSource) Token() (*Token, error) {
s.mu.Lock()
defer s.mu.Unlock()
@ -321,7 +322,7 @@ func (s *reuseTokenSource) Token() (*Token, error) {
return t, nil
}
// StaticTokenSource returns a [TokenSource] that always returns the same token.
// StaticTokenSource returns a TokenSource that always returns the same token.
// Because the provided token t is never refreshed, StaticTokenSource is only
// useful for tokens that never expire.
func StaticTokenSource(t *Token) TokenSource {
@ -337,16 +338,16 @@ func (s staticTokenSource) Token() (*Token, error) {
return s.t, nil
}
// HTTPClient is the context key to use with [context.WithValue]
// to associate a [*http.Client] value with a context.
// HTTPClient is the context key to use with golang.org/x/net/context's
// WithValue function to associate an *http.Client value with a context.
var HTTPClient internal.ContextKey
// NewClient creates an [*http.Client] from a [context.Context] and [TokenSource].
// NewClient creates an *http.Client from a Context and TokenSource.
// The returned client is not valid beyond the lifetime of the context.
//
// Note that if a custom [*http.Client] is provided via the [context.Context] it
// Note that if a custom *http.Client is provided via the Context it
// is used only for token acquisition and is not used to configure the
// [*http.Client] returned from NewClient.
// *http.Client returned from NewClient.
//
// As a special case, if src is nil, a non-OAuth2 client is returned
// using the provided context. This exists to support related OAuth2
@ -355,19 +356,15 @@ func NewClient(ctx context.Context, src TokenSource) *http.Client {
if src == nil {
return internal.ContextClient(ctx)
}
cc := internal.ContextClient(ctx)
return &http.Client{
Transport: &Transport{
Base: cc.Transport,
Base: internal.ContextClient(ctx).Transport,
Source: ReuseTokenSource(nil, src),
},
CheckRedirect: cc.CheckRedirect,
Jar: cc.Jar,
Timeout: cc.Timeout,
}
}
// ReuseTokenSource returns a [TokenSource] which repeatedly returns the
// ReuseTokenSource returns a TokenSource which repeatedly returns the
// same token as long as it's valid, starting with t.
// When its cached token is invalid, a new token is obtained from src.
//
@ -375,10 +372,10 @@ func NewClient(ctx context.Context, src TokenSource) *http.Client {
// (such as a file on disk) between runs of a program, rather than
// obtaining new tokens unnecessarily.
//
// The initial token t may be nil, in which case the [TokenSource] is
// The initial token t may be nil, in which case the TokenSource is
// wrapped in a caching version if it isn't one already. This also
// means it's always safe to wrap ReuseTokenSource around any other
// [TokenSource] without adverse effects.
// TokenSource without adverse effects.
func ReuseTokenSource(t *Token, src TokenSource) TokenSource {
// Don't wrap a reuseTokenSource in itself. That would work,
// but cause an unnecessary number of mutex operations.
@ -396,8 +393,8 @@ func ReuseTokenSource(t *Token, src TokenSource) TokenSource {
}
}
// ReuseTokenSourceWithExpiry returns a [TokenSource] that acts in the same manner as the
// [TokenSource] returned by [ReuseTokenSource], except the expiry buffer is
// ReuseTokenSourceWithExpiry returns a TokenSource that acts in the same manner as the
// TokenSource returned by ReuseTokenSource, except the expiry buffer is
// configurable. The expiration time of a token is calculated as
// t.Expiry.Add(-earlyExpiry).
func ReuseTokenSourceWithExpiry(t *Token, src TokenSource, earlyExpiry time.Duration) TokenSource {

View File

@ -9,6 +9,7 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
@ -103,7 +104,7 @@ func TestExchangeRequest(t *testing.T) {
if headerContentType != "application/x-www-form-urlencoded" {
t.Errorf("Unexpected Content-Type header %q", headerContentType)
}
body, err := io.ReadAll(r.Body)
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Errorf("Failed reading request body: %s.", err)
}
@ -147,7 +148,7 @@ func TestExchangeRequest_CustomParam(t *testing.T) {
if headerContentType != "application/x-www-form-urlencoded" {
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType)
}
body, err := io.ReadAll(r.Body)
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Errorf("Failed reading request body: %s.", err)
}
@ -193,7 +194,7 @@ func TestExchangeRequest_JSONResponse(t *testing.T) {
if headerContentType != "application/x-www-form-urlencoded" {
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType)
}
body, err := io.ReadAll(r.Body)
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Errorf("Failed reading request body: %s.", err)
}
@ -300,7 +301,7 @@ func testExchangeRequest_JSONResponse_expiry(t *testing.T, exp string, want, nul
conf := newConf(ts.URL)
t1 := time.Now().Add(day)
tok, err := conf.Exchange(context.Background(), "exchange-code")
t2 := time.Now().Add(day)
t2 := t1.Add(day)
if got := (err == nil); got != want {
if want {
@ -392,7 +393,7 @@ func TestPasswordCredentialsTokenRequest(t *testing.T) {
if headerContentType != expected {
t.Errorf("Content-Type header = %q; want %q", headerContentType, expected)
}
body, err := io.ReadAll(r.Body)
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Errorf("Failed reading request body: %s.", err)
}
@ -434,7 +435,7 @@ func TestTokenRefreshRequest(t *testing.T) {
if headerContentType != "application/x-www-form-urlencoded" {
t.Errorf("Unexpected Content-Type header %q", headerContentType)
}
body, _ := io.ReadAll(r.Body)
body, _ := ioutil.ReadAll(r.Body)
if string(body) != "grant_type=refresh_token&refresh_token=REFRESH_TOKEN" {
t.Errorf("Unexpected refresh token payload %q", body)
}
@ -459,7 +460,7 @@ func TestFetchWithNoRefreshToken(t *testing.T) {
if headerContentType != "application/x-www-form-urlencoded" {
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType)
}
body, _ := io.ReadAll(r.Body)
body, _ := ioutil.ReadAll(r.Body)
if string(body) != "client_id=CLIENT_ID&grant_type=refresh_token&refresh_token=REFRESH_TOKEN" {
t.Errorf("Unexpected refresh token payload, %v is found.", string(body))
}

15
pkce.go
View File

@ -1,7 +1,6 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package oauth2
import (
@ -21,9 +20,9 @@ const (
// This follows recommendations in RFC 7636.
//
// A fresh verifier should be generated for each authorization.
// The resulting verifier should be passed to [Config.AuthCodeURL] or [Config.DeviceAuth]
// with [S256ChallengeOption], and to [Config.Exchange] or [Config.DeviceAccessToken]
// with [VerifierOption].
// S256ChallengeOption(verifier) should then be passed to Config.AuthCodeURL
// (or Config.DeviceAccess) and VerifierOption(verifier) to Config.Exchange
// (or Config.DeviceAccessToken).
func GenerateVerifier() string {
// "RECOMMENDED that the output of a suitable random number generator be
// used to create a 32-octet sequence. The octet sequence is then
@ -37,22 +36,22 @@ func GenerateVerifier() string {
return base64.RawURLEncoding.EncodeToString(data)
}
// VerifierOption returns a PKCE code verifier [AuthCodeOption]. It should only be
// passed to [Config.Exchange] or [Config.DeviceAccessToken].
// VerifierOption returns a PKCE code verifier AuthCodeOption. It should be
// passed to Config.Exchange or Config.DeviceAccessToken only.
func VerifierOption(verifier string) AuthCodeOption {
return setParam{k: codeVerifierKey, v: verifier}
}
// S256ChallengeFromVerifier returns a PKCE code challenge derived from verifier with method S256.
//
// Prefer to use [S256ChallengeOption] where possible.
// Prefer to use S256ChallengeOption where possible.
func S256ChallengeFromVerifier(verifier string) string {
sha := sha256.Sum256([]byte(verifier))
return base64.RawURLEncoding.EncodeToString(sha[:])
}
// S256ChallengeOption derives a PKCE code challenge derived from verifier with
// method S256. It should be passed to [Config.AuthCodeURL] or [Config.DeviceAuth]
// method S256. It should be passed to Config.AuthCodeURL or Config.DeviceAccess
// only.
func S256ChallengeOption(verifier string) AuthCodeOption {
return challengeOption{

View File

@ -44,21 +44,14 @@ type Token struct {
// Expiry is the optional expiration time of the access token.
//
// If zero, [TokenSource] implementations will reuse the same
// If zero, TokenSource implementations will reuse the same
// token forever and RefreshToken or equivalent
// mechanisms for that TokenSource will not be used.
Expiry time.Time `json:"expiry,omitempty"`
// ExpiresIn is the OAuth2 wire format "expires_in" field,
// which specifies how many seconds later the token expires,
// relative to an unknown time base approximately around "now".
// It is the application's responsibility to populate
// `Expiry` from `ExpiresIn` when required.
ExpiresIn int64 `json:"expires_in,omitempty"`
// raw optionally contains extra metadata from the server
// when updating a token.
raw any
raw interface{}
// expiryDelta is used to calculate when a token is considered
// expired, by subtracting from Expiry. If zero, defaultExpiryDelta
@ -86,16 +79,16 @@ func (t *Token) Type() string {
// SetAuthHeader sets the Authorization header to r using the access
// token in t.
//
// This method is unnecessary when using [Transport] or an HTTP Client
// This method is unnecessary when using Transport or an HTTP Client
// returned by this package.
func (t *Token) SetAuthHeader(r *http.Request) {
r.Header.Set("Authorization", t.Type()+" "+t.AccessToken)
}
// WithExtra returns a new [Token] that's a clone of t, but using the
// WithExtra returns a new Token that's a clone of t, but using the
// provided raw extra map. This is only intended for use by packages
// implementing derivative OAuth2 flows.
func (t *Token) WithExtra(extra any) *Token {
func (t *Token) WithExtra(extra interface{}) *Token {
t2 := new(Token)
*t2 = *t
t2.raw = extra
@ -105,8 +98,8 @@ func (t *Token) WithExtra(extra any) *Token {
// Extra returns an extra field.
// Extra fields are key-value pairs returned by the server as a
// part of the token retrieval response.
func (t *Token) Extra(key string) any {
if raw, ok := t.raw.(map[string]any); ok {
func (t *Token) Extra(key string) interface{} {
if raw, ok := t.raw.(map[string]interface{}); ok {
return raw[key]
}
@ -163,14 +156,13 @@ func tokenFromInternal(t *internal.Token) *Token {
TokenType: t.TokenType,
RefreshToken: t.RefreshToken,
Expiry: t.Expiry,
ExpiresIn: t.ExpiresIn,
raw: t.Raw,
}
}
// retrieveToken takes a *Config and uses that to retrieve an *internal.Token.
// This token is then mapped from *internal.Token into an *oauth2.Token which is returned along
// with an error.
// with an error..
func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) {
tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v, internal.AuthStyle(c.Endpoint.AuthStyle), c.authStyleCache.Get())
if err != nil {

View File

@ -12,8 +12,8 @@ import (
func TestTokenExtra(t *testing.T) {
type testCase struct {
key string
val any
want any
val interface{}
want interface{}
}
const key = "extra-key"
cases := []testCase{
@ -23,7 +23,7 @@ func TestTokenExtra(t *testing.T) {
{key: "other-key", val: "def", want: nil},
}
for _, tc := range cases {
extra := make(map[string]any)
extra := make(map[string]interface{})
extra[tc.key] = tc.val
tok := &Token{raw: extra}
if got, want := tok.Extra(key), tc.want; got != want {

View File

@ -11,12 +11,12 @@ import (
"sync"
)
// Transport is an [http.RoundTripper] that makes OAuth 2.0 HTTP requests,
// wrapping a base [http.RoundTripper] and adding an Authorization header
// with a token from the supplied [TokenSource].
// Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests,
// wrapping a base RoundTripper and adding an Authorization header
// with a token from the supplied Sources.
//
// Transport is a low-level mechanism. Most code will use the
// higher-level [Config.Client] method instead.
// higher-level Config.Client method instead.
type Transport struct {
// Source supplies the token to add to outgoing requests'
// Authorization headers.
@ -47,7 +47,7 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
return nil, err
}
req2 := req.Clone(req.Context())
req2 := cloneRequest(req) // per RoundTripper contract
token.SetAuthHeader(req2)
// req.Body is assumed to be closed by the base RoundTripper.
@ -73,3 +73,17 @@ func (t *Transport) base() http.RoundTripper {
}
return http.DefaultTransport
}
// cloneRequest returns a clone of the provided *http.Request.
// The clone is a shallow copy of the struct and its Header map.
func cloneRequest(r *http.Request) *http.Request {
// shallow copy of the struct
r2 := new(http.Request)
*r2 = *r
// deep copy of the Header
r2.Header = make(http.Header, len(r.Header))
for k, s := range r.Header {
r2.Header[k] = append([]string(nil), s...)
}
return r2
}

View File

@ -9,6 +9,12 @@ import (
"time"
)
type tokenSource struct{ token *Token }
func (t *tokenSource) Token() (*Token, error) {
return t.token, nil
}
func TestTransportNilTokenSource(t *testing.T) {
tr := &Transport{}
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {})
@ -82,10 +88,13 @@ func TestTransportCloseRequestBodySuccess(t *testing.T) {
}
func TestTransportTokenSource(t *testing.T) {
tr := &Transport{
Source: StaticTokenSource(&Token{
ts := &tokenSource{
token: &Token{
AccessToken: "abc",
}),
},
}
tr := &Transport{
Source: ts,
}
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {
if got, want := r.Header.Get("Authorization"), "Bearer abc"; got != want {
@ -114,11 +123,14 @@ func TestTransportTokenSourceTypes(t *testing.T) {
{key: "basic", val: val, want: "Basic abc"},
}
for _, tc := range tests {
tr := &Transport{
Source: StaticTokenSource(&Token{
ts := &tokenSource{
token: &Token{
AccessToken: tc.val,
TokenType: tc.key,
}),
},
}
tr := &Transport{
Source: ts,
}
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {
if got, want := r.Header.Get("Authorization"), tc.want; got != want {