mirror of
https://github.com/golang/oauth2.git
synced 2025-07-21 00:00:09 +08:00
With https://go.dev/issue/61417 implemented, we can use the token type directly to unmarshal the JSON fields for the wire format. While here, remove all uses of the deprecated ioutil package as suggested by gopls while making these changes. Change-Id: I79d82374643007a21b5b3d9a8117bed81273eca5 Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/614415 Reviewed-by: Sean Liao <sean@liao.dev> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Dmitri Shuralyov <dmitshur@google.com> Reviewed-by: Michael Pratt <mpratt@google.com>
260 lines
6.7 KiB
Go
260 lines
6.7 KiB
Go
// Copyright 2023 The Go Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package externalaccountauthorizeduser
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"golang.org/x/oauth2"
|
|
"golang.org/x/oauth2/google/internal/stsexchange"
|
|
)
|
|
|
|
const expiryDelta = 10 * time.Second
|
|
|
|
var (
|
|
expiry = time.Unix(234852, 0)
|
|
testNow = func() time.Time { return expiry }
|
|
testValid = func(t oauth2.Token) bool {
|
|
return t.AccessToken != "" && !t.Expiry.Round(0).Add(-expiryDelta).Before(testNow())
|
|
}
|
|
)
|
|
|
|
type testRefreshTokenServer struct {
|
|
URL string
|
|
Authorization string
|
|
ContentType string
|
|
Body string
|
|
ResponsePayload *stsexchange.Response
|
|
Response string
|
|
server *httptest.Server
|
|
}
|
|
|
|
func TestExternalAccountAuthorizedUser_JustToken(t *testing.T) {
|
|
config := &Config{
|
|
Token: "AAAAAAA",
|
|
Expiry: now().Add(time.Hour),
|
|
}
|
|
ts, err := config.TokenSource(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("Error getting token source: %v", err)
|
|
}
|
|
|
|
token, err := ts.Token()
|
|
if err != nil {
|
|
t.Fatalf("Error retrieving Token: %v", err)
|
|
}
|
|
if got, want := token.AccessToken, "AAAAAAA"; got != want {
|
|
t.Fatalf("Unexpected access token, got %v, want %v", got, want)
|
|
}
|
|
}
|
|
|
|
func TestExternalAccountAuthorizedUser_TokenRefreshWithRefreshTokenInResponse(t *testing.T) {
|
|
server := &testRefreshTokenServer{
|
|
URL: "/",
|
|
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
|
|
ContentType: "application/x-www-form-urlencoded",
|
|
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
|
|
ResponsePayload: &stsexchange.Response{
|
|
ExpiresIn: 3600,
|
|
AccessToken: "AAAAAAA",
|
|
RefreshToken: "CCCCCCC",
|
|
},
|
|
}
|
|
|
|
url, err := server.run(t)
|
|
if err != nil {
|
|
t.Fatalf("Error starting server")
|
|
}
|
|
defer server.close(t)
|
|
|
|
config := &Config{
|
|
RefreshToken: "BBBBBBBBB",
|
|
TokenURL: url,
|
|
ClientID: "CLIENT_ID",
|
|
ClientSecret: "CLIENT_SECRET",
|
|
}
|
|
ts, err := config.TokenSource(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("Error getting token source: %v", err)
|
|
}
|
|
|
|
token, err := ts.Token()
|
|
if err != nil {
|
|
t.Fatalf("Error retrieving Token: %v", err)
|
|
}
|
|
if got, want := token.AccessToken, "AAAAAAA"; got != want {
|
|
t.Fatalf("Unexpected access token, got %v, want %v", got, want)
|
|
}
|
|
if config.RefreshToken != "CCCCCCC" {
|
|
t.Fatalf("Refresh token not updated")
|
|
}
|
|
}
|
|
|
|
func TestExternalAccountAuthorizedUser_MinimumFieldsRequiredForRefresh(t *testing.T) {
|
|
server := &testRefreshTokenServer{
|
|
URL: "/",
|
|
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
|
|
ContentType: "application/x-www-form-urlencoded",
|
|
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
|
|
ResponsePayload: &stsexchange.Response{
|
|
ExpiresIn: 3600,
|
|
AccessToken: "AAAAAAA",
|
|
},
|
|
}
|
|
|
|
url, err := server.run(t)
|
|
if err != nil {
|
|
t.Fatalf("Error starting server")
|
|
}
|
|
defer server.close(t)
|
|
|
|
config := &Config{
|
|
RefreshToken: "BBBBBBBBB",
|
|
TokenURL: url,
|
|
ClientID: "CLIENT_ID",
|
|
ClientSecret: "CLIENT_SECRET",
|
|
}
|
|
ts, err := config.TokenSource(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("Error getting token source: %v", err)
|
|
}
|
|
|
|
token, err := ts.Token()
|
|
if err != nil {
|
|
t.Fatalf("Error retrieving Token: %v", err)
|
|
}
|
|
if got, want := token.AccessToken, "AAAAAAA"; got != want {
|
|
t.Fatalf("Unexpected access token, got %v, want %v", got, want)
|
|
}
|
|
}
|
|
|
|
func TestExternalAccountAuthorizedUser_MissingRefreshFields(t *testing.T) {
|
|
server := &testRefreshTokenServer{
|
|
URL: "/",
|
|
Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
|
|
ContentType: "application/x-www-form-urlencoded",
|
|
Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
|
|
ResponsePayload: &stsexchange.Response{
|
|
ExpiresIn: 3600,
|
|
AccessToken: "AAAAAAA",
|
|
},
|
|
}
|
|
|
|
url, err := server.run(t)
|
|
if err != nil {
|
|
t.Fatalf("Error starting server")
|
|
}
|
|
defer server.close(t)
|
|
testCases := []struct {
|
|
name string
|
|
config Config
|
|
}{
|
|
{
|
|
name: "empty config",
|
|
config: Config{},
|
|
},
|
|
{
|
|
name: "missing refresh token",
|
|
config: Config{
|
|
TokenURL: url,
|
|
ClientID: "CLIENT_ID",
|
|
ClientSecret: "CLIENT_SECRET",
|
|
},
|
|
},
|
|
{
|
|
name: "missing token url",
|
|
config: Config{
|
|
RefreshToken: "BBBBBBBBB",
|
|
ClientID: "CLIENT_ID",
|
|
ClientSecret: "CLIENT_SECRET",
|
|
},
|
|
},
|
|
{
|
|
name: "missing client id",
|
|
config: Config{
|
|
RefreshToken: "BBBBBBBBB",
|
|
TokenURL: url,
|
|
ClientSecret: "CLIENT_SECRET",
|
|
},
|
|
},
|
|
{
|
|
name: "missing client secret",
|
|
config: Config{
|
|
RefreshToken: "BBBBBBBBB",
|
|
TokenURL: url,
|
|
ClientID: "CLIENT_ID",
|
|
},
|
|
},
|
|
}
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
|
|
expectErrMsg := "oauth2/google: Token should be created with fields to make it valid (`token` and `expiry`), or fields to allow it to refresh (`refresh_token`, `token_url`, `client_id`, `client_secret`)."
|
|
_, err := tc.config.TokenSource((context.Background()))
|
|
if err == nil {
|
|
t.Fatalf("Expected error, but received none")
|
|
}
|
|
if got := err.Error(); got != expectErrMsg {
|
|
t.Fatalf("Unexpected error, got %v, want %v", got, expectErrMsg)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func (trts *testRefreshTokenServer) run(t *testing.T) (string, error) {
|
|
t.Helper()
|
|
if trts.server != nil {
|
|
return "", errors.New("Server is already running")
|
|
}
|
|
trts.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if got, want := r.URL.String(), trts.URL; got != want {
|
|
t.Errorf("URL.String(): got %v but want %v", got, want)
|
|
}
|
|
headerAuth := r.Header.Get("Authorization")
|
|
if got, want := headerAuth, trts.Authorization; got != want {
|
|
t.Errorf("got %v but want %v", got, want)
|
|
}
|
|
headerContentType := r.Header.Get("Content-Type")
|
|
if got, want := headerContentType, trts.ContentType; got != want {
|
|
t.Errorf("got %v but want %v", got, want)
|
|
}
|
|
body, err := io.ReadAll(r.Body)
|
|
if err != nil {
|
|
t.Fatalf("Failed reading request body: %s.", err)
|
|
}
|
|
if got, want := string(body), trts.Body; got != want {
|
|
t.Errorf("Unexpected exchange payload: got %v but want %v", got, want)
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
if trts.ResponsePayload != nil {
|
|
content, err := json.Marshal(trts.ResponsePayload)
|
|
if err != nil {
|
|
t.Fatalf("unable to marshall response JSON")
|
|
}
|
|
w.Write(content)
|
|
} else {
|
|
w.Write([]byte(trts.Response))
|
|
}
|
|
}))
|
|
return trts.server.URL, nil
|
|
}
|
|
|
|
func (trts *testRefreshTokenServer) close(t *testing.T) error {
|
|
t.Helper()
|
|
if trts.server == nil {
|
|
return errors.New("No server is running")
|
|
}
|
|
trts.server.Close()
|
|
trts.server = nil
|
|
return nil
|
|
}
|