From 510acbce1f1678162c5fae5bde59f0d03e14cb6d Mon Sep 17 00:00:00 2001 From: aeitzman Date: Wed, 30 Nov 2022 16:37:52 +0000 Subject: [PATCH] google/internal/externalaccount: Added check for aws region and security credential environment variables before aws metadata call Adds check for aws values in environment variables before the metadata server is called to prevent unnecessary off box calls. See https://github.com/googleapis/google-auth-library-java/pull/1100 for same change in java library. Change-Id: Ie86a899be88c38d3fcbbe377f9bf30a7a66530c0 GitHub-Last-Rev: bcab69572cb0dca4c7c6426203d4232e6e89d8db GitHub-Pull-Request: golang/oauth2#612 Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/453715 Reviewed-by: Leo Siracusa TryBot-Result: Gopher Robot Run-TryBot: Cody Oss Auto-Submit: Cody Oss Reviewed-by: Cody Oss --- google/internal/externalaccount/aws.go | 62 +++-- google/internal/externalaccount/aws_test.go | 264 +++++++++++++++++--- 2 files changed, 277 insertions(+), 49 deletions(-) diff --git a/google/internal/externalaccount/aws.go b/google/internal/externalaccount/aws.go index 6318a23..2bf3202 100644 --- a/google/internal/externalaccount/aws.go +++ b/google/internal/externalaccount/aws.go @@ -62,6 +62,13 @@ const ( // The AWS authorization header name for the auto-generated date. awsDateHeader = "x-amz-date" + // Supported AWS configuration environment variables. + awsAccessKeyId = "AWS_ACCESS_KEY_ID" + awsDefaultRegion = "AWS_DEFAULT_REGION" + awsRegion = "AWS_REGION" + awsSecretAccessKey = "AWS_SECRET_ACCESS_KEY" + awsSessionToken = "AWS_SESSION_TOKEN" + awsTimeFormatLong = "20060102T150405Z" awsTimeFormatShort = "20060102" ) @@ -317,16 +324,33 @@ func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, erro return cs.client.Do(req.WithContext(cs.ctx)) } +func canRetrieveRegionFromEnvironment() bool { + // The AWS region can be provided through AWS_REGION or AWS_DEFAULT_REGION. Only one is + // required. + return getenv(awsRegion) != "" || getenv(awsDefaultRegion) != "" +} + +func canRetrieveSecurityCredentialFromEnvironment() bool { + // Check if both AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are available. + return getenv(awsAccessKeyId) != "" && getenv(awsSecretAccessKey) != "" +} + +func shouldUseMetadataServer() bool { + return !canRetrieveRegionFromEnvironment() || !canRetrieveSecurityCredentialFromEnvironment() +} + func (cs awsCredentialSource) subjectToken() (string, error) { if cs.requestSigner == nil { - awsSessionToken, err := cs.getAWSSessionToken() - if err != nil { - return "", err - } - headers := make(map[string]string) - if awsSessionToken != "" { - headers[awsIMDSv2SessionTokenHeader] = awsSessionToken + if shouldUseMetadataServer() { + awsSessionToken, err := cs.getAWSSessionToken() + if err != nil { + return "", err + } + + if awsSessionToken != "" { + headers[awsIMDSv2SessionTokenHeader] = awsSessionToken + } } awsSecurityCredentials, err := cs.getSecurityCredentials(headers) @@ -432,11 +456,11 @@ func (cs *awsCredentialSource) getAWSSessionToken() (string, error) { } func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, error) { - if envAwsRegion := getenv("AWS_REGION"); envAwsRegion != "" { - return envAwsRegion, nil - } - if envAwsRegion := getenv("AWS_DEFAULT_REGION"); envAwsRegion != "" { - return envAwsRegion, nil + if canRetrieveRegionFromEnvironment() { + if envAwsRegion := getenv(awsRegion); envAwsRegion != "" { + return envAwsRegion, nil + } + return getenv("AWS_DEFAULT_REGION"), nil } if cs.RegionURL == "" { @@ -477,14 +501,12 @@ func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, err } func (cs *awsCredentialSource) getSecurityCredentials(headers map[string]string) (result awsSecurityCredentials, err error) { - if accessKeyID := getenv("AWS_ACCESS_KEY_ID"); accessKeyID != "" { - if secretAccessKey := getenv("AWS_SECRET_ACCESS_KEY"); secretAccessKey != "" { - return awsSecurityCredentials{ - AccessKeyID: accessKeyID, - SecretAccessKey: secretAccessKey, - SecurityToken: getenv("AWS_SESSION_TOKEN"), - }, nil - } + if canRetrieveSecurityCredentialFromEnvironment() { + return awsSecurityCredentials{ + AccessKeyID: getenv(awsAccessKeyId), + SecretAccessKey: getenv(awsSecretAccessKey), + SecurityToken: getenv(awsSessionToken), + }, nil } roleName, err := cs.getMetadataRoleName(headers) diff --git a/google/internal/externalaccount/aws_test.go b/google/internal/externalaccount/aws_test.go index 30a003a..058b004 100644 --- a/google/internal/externalaccount/aws_test.go +++ b/google/internal/externalaccount/aws_test.go @@ -474,6 +474,38 @@ func createDefaultAwsTestServer() *testAwsServer { ) } +func createDefaultAwsTestServerWithImdsv2(t *testing.T) *testAwsServer { + validateSessionTokenHeaders := func(r *http.Request) { + if r.URL.Path == "/latest/api/token" { + headerValue := r.Header.Get(awsIMDSv2SessionTtlHeader) + if headerValue != awsIMDSv2SessionTtl { + t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTtlHeader, headerValue, awsIMDSv2SessionTtl) + } + } else { + headerValue := r.Header.Get(awsIMDSv2SessionTokenHeader) + if headerValue != "sessiontoken" { + t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTokenHeader, headerValue, "sessiontoken") + } + } + } + + return createAwsTestServer( + "/latest/meta-data/iam/security-credentials", + "/latest/meta-data/placement/availability-zone", + "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "/latest/api/token", + "gcp-aws-role", + "us-east-2b", + map[string]string{ + "SecretAccessKey": secretAccessKey, + "AccessKeyId": accessKeyID, + "Token": securityToken, + }, + "sessiontoken", + validateSessionTokenHeaders, + ) +} + func (server *testAwsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch p := r.URL.Path; p { case server.url: @@ -597,35 +629,7 @@ func TestAWSCredential_BasicRequest(t *testing.T) { } func TestAWSCredential_IMDSv2(t *testing.T) { - validateSessionTokenHeaders := func(r *http.Request) { - if r.URL.Path == "/latest/api/token" { - headerValue := r.Header.Get(awsIMDSv2SessionTtlHeader) - if headerValue != awsIMDSv2SessionTtl { - t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTtlHeader, headerValue, awsIMDSv2SessionTtl) - } - } else { - headerValue := r.Header.Get(awsIMDSv2SessionTokenHeader) - if headerValue != "sessiontoken" { - t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTokenHeader, headerValue, "sessiontoken") - } - } - } - - server := createAwsTestServer( - "/latest/meta-data/iam/security-credentials", - "/latest/meta-data/placement/availability-zone", - "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", - "/latest/api/token", - "gcp-aws-role", - "us-east-2b", - map[string]string{ - "SecretAccessKey": secretAccessKey, - "AccessKeyId": accessKeyID, - "Token": securityToken, - }, - "sessiontoken", - validateSessionTokenHeaders, - ) + server := createDefaultAwsTestServerWithImdsv2(t) ts := httptest.NewServer(server) tsURL, err := neturl.Parse(ts.URL) if err != nil { @@ -1152,6 +1156,208 @@ func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) { } } +func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing.T) { + server := createDefaultAwsTestServer() + ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } + + metadataTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("Metadata server should not have been called.") + })) + + tfc := testFileConfig + tfc.CredentialSource = server.getCredentialSource(ts.URL) + tfc.CredentialSource.IMDSv2SessionTokenURL = metadataTs.URL + + oldGetenv := getenv + oldNow := now + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() + getenv = setEnvironment(map[string]string{ + "AWS_ACCESS_KEY_ID": "AKIDEXAMPLE", + "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + "AWS_REGION": "us-west-1", + }) + now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} + + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + out, err := base.subjectToken() + if err != nil { + t.Fatalf("retrieveSubjectToken() failed: %v", err) + } + + expected := getExpectedSubjectToken( + "https://sts.us-west-1.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "us-west-1", + "AKIDEXAMPLE", + "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + "", + ) + + if got, want := out, expected; !reflect.DeepEqual(got, want) { + t.Errorf("subjectToken = \n%q\n want \n%q", got, want) + } +} + +func TestAWSCredential_ShouldCallMetadataEndpointWhenNoRegion(t *testing.T) { + server := createDefaultAwsTestServerWithImdsv2(t) + ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } + + tfc := testFileConfig + tfc.CredentialSource = server.getCredentialSource(ts.URL) + + oldGetenv := getenv + oldNow := now + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() + getenv = setEnvironment(map[string]string{ + "AWS_ACCESS_KEY_ID": accessKeyID, + "AWS_SECRET_ACCESS_KEY": secretAccessKey, + }) + now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} + + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + out, err := base.subjectToken() + if err != nil { + t.Fatalf("retrieveSubjectToken() failed: %v", err) + } + + expected := getExpectedSubjectToken( + "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "us-east-2", + accessKeyID, + secretAccessKey, + "", + ) + + if got, want := out, expected; !reflect.DeepEqual(got, want) { + t.Errorf("subjectToken = \n%q\n want \n%q", got, want) + } +} + +func TestAWSCredential_ShouldCallMetadataEndpointWhenNoAccessKey(t *testing.T) { + server := createDefaultAwsTestServerWithImdsv2(t) + ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } + + tfc := testFileConfig + tfc.CredentialSource = server.getCredentialSource(ts.URL) + + oldGetenv := getenv + oldNow := now + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() + getenv = setEnvironment(map[string]string{ + "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + "AWS_REGION": "us-west-1", + }) + now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} + + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + out, err := base.subjectToken() + if err != nil { + t.Fatalf("retrieveSubjectToken() failed: %v", err) + } + + expected := getExpectedSubjectToken( + "https://sts.us-west-1.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "us-west-1", + accessKeyID, + secretAccessKey, + securityToken, + ) + + if got, want := out, expected; !reflect.DeepEqual(got, want) { + t.Errorf("subjectToken = \n%q\n want \n%q", got, want) + } +} + +func TestAWSCredential_ShouldCallMetadataEndpointWhenNoSecretAccessKey(t *testing.T) { + server := createDefaultAwsTestServerWithImdsv2(t) + ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } + + tfc := testFileConfig + tfc.CredentialSource = server.getCredentialSource(ts.URL) + + oldGetenv := getenv + oldNow := now + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() + getenv = setEnvironment(map[string]string{ + "AWS_ACCESS_KEY_ID": "AKIDEXAMPLE", + "AWS_REGION": "us-west-1", + }) + now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} + + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + out, err := base.subjectToken() + if err != nil { + t.Fatalf("retrieveSubjectToken() failed: %v", err) + } + + expected := getExpectedSubjectToken( + "https://sts.us-west-1.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "us-west-1", + accessKeyID, + secretAccessKey, + securityToken, + ) + + if got, want := out, expected; !reflect.DeepEqual(got, want) { + t.Errorf("subjectToken = \n%q\n want \n%q", got, want) + } +} + func TestAWSCredential_Validations(t *testing.T) { var metadataServerValidityTests = []struct { name string