diff --git a/internal/token.go b/internal/token.go index 8a10204..558ce95 100644 --- a/internal/token.go +++ b/internal/token.go @@ -6,6 +6,7 @@ package internal import ( "encoding/json" + "errors" "fmt" "io" "io/ioutil" @@ -250,6 +251,9 @@ func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, if token.RefreshToken == "" { token.RefreshToken = v.Get("refresh_token") } + if token.AccessToken == "" { + return token, errors.New("oauth2: server response missing access_token") + } return token, nil } diff --git a/internal/token_test.go b/internal/token_test.go index 9118d82..7b52e51 100644 --- a/internal/token_test.go +++ b/internal/token_test.go @@ -33,7 +33,8 @@ func TestRetrieveTokenBustedNoSecret(t *testing.T) { if got, want := r.FormValue("client_secret"), ""; got != want { t.Errorf("client_secret = %q; want empty", got) } - io.WriteString(w, "{}") // something non-empty, required to set a Content-Type in Go 1.10 + w.Header().Set("Content-Type", "application/json") + io.WriteString(w, `{"access_token": "ACCESS_TOKEN", "token_type": "bearer"}`) })) defer ts.Close() @@ -85,7 +86,8 @@ func TestRetrieveTokenWithContexts(t *testing.T) { const clientID = "client-id" ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - io.WriteString(w, "{}") // something non-empty, required to set a Content-Type in Go 1.10 + w.Header().Set("Content-Type", "application/json") + io.WriteString(w, `{"access_token": "ACCESS_TOKEN", "token_type": "bearer"}`) })) defer ts.Close() diff --git a/oauth2_test.go b/oauth2_test.go index 4937901..847160f 100644 --- a/oauth2_test.go +++ b/oauth2_test.go @@ -278,12 +278,9 @@ func TestExchangeRequest_BadResponse(t *testing.T) { })) defer ts.Close() conf := newConf(ts.URL) - tok, err := conf.Exchange(context.Background(), "code") - if err != nil { - t.Fatal(err) - } - if tok.AccessToken != "" { - t.Errorf("Unexpected access token, %#v.", tok.AccessToken) + _, err := conf.Exchange(context.Background(), "code") + if err == nil { + t.Error("expected error from missing access_token") } } @@ -296,7 +293,7 @@ func TestExchangeRequest_BadResponseType(t *testing.T) { conf := newConf(ts.URL) _, err := conf.Exchange(context.Background(), "exchange-code") if err == nil { - t.Error("expected error from invalid access_token type") + t.Error("expected error from non-string access_token") } }