google: add Credentials.UniverseDomainProvider

* move MDS universe retrieval within Compute credentials

Change-Id: I847d2075ca11bde998a06220307626e902230c23
Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/575936
Reviewed-by: Cody Oss <codyoss@google.com>
Auto-Submit: Cody Oss <codyoss@google.com>
Run-TryBot: Cody Oss <codyoss@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
This commit is contained in:
Chris Smith 2024-04-02 16:20:57 -06:00 committed by Gopher Robot
parent 3c9c1f6d00
commit d0e617c58c
2 changed files with 58 additions and 41 deletions

View File

@ -42,6 +42,17 @@ type Credentials struct {
// running on Google Cloud Platform. // running on Google Cloud Platform.
JSON []byte JSON []byte
// UniverseDomainProvider returns the default service domain for a given
// Cloud universe. Optional.
//
// On GCE, UniverseDomainProvider should return the universe domain value
// from Google Compute Engine (GCE)'s metadata server. See also [The attached service
// account](https://cloud.google.com/docs/authentication/application-default-credentials#attached-sa).
// If the GCE metadata server returns a 404 error, the default universe
// domain value should be returned. If the GCE metadata server returns an
// error other than 404, the error should be returned.
UniverseDomainProvider func() (string, error)
udMu sync.Mutex // guards universeDomain udMu sync.Mutex // guards universeDomain
// universeDomain is the default service domain for a given Cloud universe. // universeDomain is the default service domain for a given Cloud universe.
universeDomain string universeDomain string
@ -64,54 +75,32 @@ func (c *Credentials) UniverseDomain() string {
} }
// GetUniverseDomain returns the default service domain for a given Cloud // GetUniverseDomain returns the default service domain for a given Cloud
// universe. // universe. If present, UniverseDomainProvider will be invoked and its return
// value will be cached.
// //
// The default value is "googleapis.com". // The default value is "googleapis.com".
//
// It obtains the universe domain from the attached service account on GCE when
// authenticating via the GCE metadata server. See also [The attached service
// account](https://cloud.google.com/docs/authentication/application-default-credentials#attached-sa).
// If the GCE metadata server returns a 404 error, the default value is
// returned. If the GCE metadata server returns an error other than 404, the
// error is returned.
func (c *Credentials) GetUniverseDomain() (string, error) { func (c *Credentials) GetUniverseDomain() (string, error) {
c.udMu.Lock() c.udMu.Lock()
defer c.udMu.Unlock() defer c.udMu.Unlock()
if c.universeDomain == "" && metadata.OnGCE() { if c.universeDomain == "" && c.UniverseDomainProvider != nil {
// If we're on Google Compute Engine, an App Engine standard second // On Google Compute Engine, an App Engine standard second generation
// generation runtime, or App Engine flexible, use the metadata server. // runtime, or App Engine flexible, use an externally provided function
err := c.computeUniverseDomain() // to request the universe domain from the metadata server.
ud, err := c.UniverseDomainProvider()
if err != nil { if err != nil {
return "", err return "", err
} }
c.universeDomain = ud
} }
// If not on Google Compute Engine, or in case of any non-error path in // If no UniverseDomainProvider (meaning not on Google Compute Engine), or
// computeUniverseDomain that did not set universeDomain, set the default // in case of any (non-error) empty return value from
// universe domain. // UniverseDomainProvider, set the default universe domain.
if c.universeDomain == "" { if c.universeDomain == "" {
c.universeDomain = defaultUniverseDomain c.universeDomain = defaultUniverseDomain
} }
return c.universeDomain, nil return c.universeDomain, nil
} }
// computeUniverseDomain fetches the default service domain for a given Cloud
// universe from Google Compute Engine (GCE)'s metadata server. It's only valid
// to use this method if your program is running on a GCE instance.
func (c *Credentials) computeUniverseDomain() error {
var err error
c.universeDomain, err = metadata.Get("universe/universe_domain")
if err != nil {
if _, ok := err.(metadata.NotDefinedError); ok {
// http.StatusNotFound (404)
c.universeDomain = defaultUniverseDomain
return nil
} else {
return err
}
}
return nil
}
// DefaultCredentials is the old name of Credentials. // DefaultCredentials is the old name of Credentials.
// //
// Deprecated: use Credentials instead. // Deprecated: use Credentials instead.
@ -226,10 +215,23 @@ func FindDefaultCredentialsWithParams(ctx context.Context, params CredentialsPar
// or App Engine flexible, use the metadata server. // or App Engine flexible, use the metadata server.
if metadata.OnGCE() { if metadata.OnGCE() {
id, _ := metadata.ProjectID() id, _ := metadata.ProjectID()
universeDomainProvider := func() (string, error) {
universeDomain, err := metadata.Get("universe/universe_domain")
if err != nil {
if _, ok := err.(metadata.NotDefinedError); ok {
// http.StatusNotFound (404)
return defaultUniverseDomain, nil
} else {
return "", err
}
}
return universeDomain, nil
}
return &Credentials{ return &Credentials{
ProjectID: id, ProjectID: id,
TokenSource: computeTokenSource("", params.EarlyTokenRefresh, params.Scopes...), TokenSource: computeTokenSource("", params.EarlyTokenRefresh, params.Scopes...),
universeDomain: params.UniverseDomain, UniverseDomainProvider: universeDomainProvider,
universeDomain: params.UniverseDomain,
}, nil }, nil
} }

View File

@ -10,6 +10,8 @@ import (
"net/http/httptest" "net/http/httptest"
"strings" "strings"
"testing" "testing"
"cloud.google.com/go/compute/metadata"
) )
var saJSONJWT = []byte(`{ var saJSONJWT = []byte(`{
@ -255,9 +257,14 @@ func TestCredentialsFromJSONWithParams_User_UniverseDomain_Params_UniverseDomain
func TestComputeUniverseDomain(t *testing.T) { func TestComputeUniverseDomain(t *testing.T) {
universeDomainPath := "/computeMetadata/v1/universe/universe_domain" universeDomainPath := "/computeMetadata/v1/universe/universe_domain"
universeDomainResponseBody := "example.com" universeDomainResponseBody := "example.com"
var requests int
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requests++
if r.URL.Path != universeDomainPath { if r.URL.Path != universeDomainPath {
t.Errorf("got %s, want %s", r.URL.Path, universeDomainPath) t.Errorf("bad path, got %s, want %s", r.URL.Path, universeDomainPath)
}
if requests > 1 {
t.Errorf("too many requests, got %d, want 1", requests)
} }
w.Write([]byte(universeDomainResponseBody)) w.Write([]byte(universeDomainResponseBody))
})) }))
@ -268,11 +275,19 @@ func TestComputeUniverseDomain(t *testing.T) {
params := CredentialsParams{ params := CredentialsParams{
Scopes: []string{scope}, Scopes: []string{scope},
} }
universeDomainProvider := func() (string, error) {
universeDomain, err := metadata.Get("universe/universe_domain")
if err != nil {
return "", err
}
return universeDomain, nil
}
// Copied from FindDefaultCredentialsWithParams, metadata.OnGCE() = true block // Copied from FindDefaultCredentialsWithParams, metadata.OnGCE() = true block
creds := &Credentials{ creds := &Credentials{
ProjectID: "fake_project", ProjectID: "fake_project",
TokenSource: computeTokenSource("", params.EarlyTokenRefresh, params.Scopes...), TokenSource: computeTokenSource("", params.EarlyTokenRefresh, params.Scopes...),
universeDomain: params.UniverseDomain, // empty UniverseDomainProvider: universeDomainProvider,
universeDomain: params.UniverseDomain, // empty
} }
c := make(chan bool) c := make(chan bool)
go func() { go func() {
@ -285,7 +300,7 @@ func TestComputeUniverseDomain(t *testing.T) {
} }
c <- true c <- true
}() }()
got, err := creds.GetUniverseDomain() // Second conflicting access. got, err := creds.GetUniverseDomain() // Second conflicting (and potentially uncached) access.
<-c <-c
if err != nil { if err != nil {
t.Error(err) t.Error(err)