diff --git a/google/internal/externalaccount/aws.go b/google/internal/externalaccount/aws.go index 6318a23..2bf3202 100644 --- a/google/internal/externalaccount/aws.go +++ b/google/internal/externalaccount/aws.go @@ -62,6 +62,13 @@ const ( // The AWS authorization header name for the auto-generated date. awsDateHeader = "x-amz-date" + // Supported AWS configuration environment variables. + awsAccessKeyId = "AWS_ACCESS_KEY_ID" + awsDefaultRegion = "AWS_DEFAULT_REGION" + awsRegion = "AWS_REGION" + awsSecretAccessKey = "AWS_SECRET_ACCESS_KEY" + awsSessionToken = "AWS_SESSION_TOKEN" + awsTimeFormatLong = "20060102T150405Z" awsTimeFormatShort = "20060102" ) @@ -317,16 +324,33 @@ func (cs awsCredentialSource) doRequest(req *http.Request) (*http.Response, erro return cs.client.Do(req.WithContext(cs.ctx)) } +func canRetrieveRegionFromEnvironment() bool { + // The AWS region can be provided through AWS_REGION or AWS_DEFAULT_REGION. Only one is + // required. + return getenv(awsRegion) != "" || getenv(awsDefaultRegion) != "" +} + +func canRetrieveSecurityCredentialFromEnvironment() bool { + // Check if both AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are available. + return getenv(awsAccessKeyId) != "" && getenv(awsSecretAccessKey) != "" +} + +func shouldUseMetadataServer() bool { + return !canRetrieveRegionFromEnvironment() || !canRetrieveSecurityCredentialFromEnvironment() +} + func (cs awsCredentialSource) subjectToken() (string, error) { if cs.requestSigner == nil { - awsSessionToken, err := cs.getAWSSessionToken() - if err != nil { - return "", err - } - headers := make(map[string]string) - if awsSessionToken != "" { - headers[awsIMDSv2SessionTokenHeader] = awsSessionToken + if shouldUseMetadataServer() { + awsSessionToken, err := cs.getAWSSessionToken() + if err != nil { + return "", err + } + + if awsSessionToken != "" { + headers[awsIMDSv2SessionTokenHeader] = awsSessionToken + } } awsSecurityCredentials, err := cs.getSecurityCredentials(headers) @@ -432,11 +456,11 @@ func (cs *awsCredentialSource) getAWSSessionToken() (string, error) { } func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, error) { - if envAwsRegion := getenv("AWS_REGION"); envAwsRegion != "" { - return envAwsRegion, nil - } - if envAwsRegion := getenv("AWS_DEFAULT_REGION"); envAwsRegion != "" { - return envAwsRegion, nil + if canRetrieveRegionFromEnvironment() { + if envAwsRegion := getenv(awsRegion); envAwsRegion != "" { + return envAwsRegion, nil + } + return getenv("AWS_DEFAULT_REGION"), nil } if cs.RegionURL == "" { @@ -477,14 +501,12 @@ func (cs *awsCredentialSource) getRegion(headers map[string]string) (string, err } func (cs *awsCredentialSource) getSecurityCredentials(headers map[string]string) (result awsSecurityCredentials, err error) { - if accessKeyID := getenv("AWS_ACCESS_KEY_ID"); accessKeyID != "" { - if secretAccessKey := getenv("AWS_SECRET_ACCESS_KEY"); secretAccessKey != "" { - return awsSecurityCredentials{ - AccessKeyID: accessKeyID, - SecretAccessKey: secretAccessKey, - SecurityToken: getenv("AWS_SESSION_TOKEN"), - }, nil - } + if canRetrieveSecurityCredentialFromEnvironment() { + return awsSecurityCredentials{ + AccessKeyID: getenv(awsAccessKeyId), + SecretAccessKey: getenv(awsSecretAccessKey), + SecurityToken: getenv(awsSessionToken), + }, nil } roleName, err := cs.getMetadataRoleName(headers) diff --git a/google/internal/externalaccount/aws_test.go b/google/internal/externalaccount/aws_test.go index 30a003a..058b004 100644 --- a/google/internal/externalaccount/aws_test.go +++ b/google/internal/externalaccount/aws_test.go @@ -474,6 +474,38 @@ func createDefaultAwsTestServer() *testAwsServer { ) } +func createDefaultAwsTestServerWithImdsv2(t *testing.T) *testAwsServer { + validateSessionTokenHeaders := func(r *http.Request) { + if r.URL.Path == "/latest/api/token" { + headerValue := r.Header.Get(awsIMDSv2SessionTtlHeader) + if headerValue != awsIMDSv2SessionTtl { + t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTtlHeader, headerValue, awsIMDSv2SessionTtl) + } + } else { + headerValue := r.Header.Get(awsIMDSv2SessionTokenHeader) + if headerValue != "sessiontoken" { + t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTokenHeader, headerValue, "sessiontoken") + } + } + } + + return createAwsTestServer( + "/latest/meta-data/iam/security-credentials", + "/latest/meta-data/placement/availability-zone", + "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "/latest/api/token", + "gcp-aws-role", + "us-east-2b", + map[string]string{ + "SecretAccessKey": secretAccessKey, + "AccessKeyId": accessKeyID, + "Token": securityToken, + }, + "sessiontoken", + validateSessionTokenHeaders, + ) +} + func (server *testAwsServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch p := r.URL.Path; p { case server.url: @@ -597,35 +629,7 @@ func TestAWSCredential_BasicRequest(t *testing.T) { } func TestAWSCredential_IMDSv2(t *testing.T) { - validateSessionTokenHeaders := func(r *http.Request) { - if r.URL.Path == "/latest/api/token" { - headerValue := r.Header.Get(awsIMDSv2SessionTtlHeader) - if headerValue != awsIMDSv2SessionTtl { - t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTtlHeader, headerValue, awsIMDSv2SessionTtl) - } - } else { - headerValue := r.Header.Get(awsIMDSv2SessionTokenHeader) - if headerValue != "sessiontoken" { - t.Errorf("%q = \n%q\n want \n%q", awsIMDSv2SessionTokenHeader, headerValue, "sessiontoken") - } - } - } - - server := createAwsTestServer( - "/latest/meta-data/iam/security-credentials", - "/latest/meta-data/placement/availability-zone", - "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", - "/latest/api/token", - "gcp-aws-role", - "us-east-2b", - map[string]string{ - "SecretAccessKey": secretAccessKey, - "AccessKeyId": accessKeyID, - "Token": securityToken, - }, - "sessiontoken", - validateSessionTokenHeaders, - ) + server := createDefaultAwsTestServerWithImdsv2(t) ts := httptest.NewServer(server) tsURL, err := neturl.Parse(ts.URL) if err != nil { @@ -1152,6 +1156,208 @@ func TestAWSCredential_RequestWithBadFinalCredentialURL(t *testing.T) { } } +func TestAWSCredential_ShouldNotCallMetadataEndpointWhenCredsAreInEnv(t *testing.T) { + server := createDefaultAwsTestServer() + ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } + + metadataTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("Metadata server should not have been called.") + })) + + tfc := testFileConfig + tfc.CredentialSource = server.getCredentialSource(ts.URL) + tfc.CredentialSource.IMDSv2SessionTokenURL = metadataTs.URL + + oldGetenv := getenv + oldNow := now + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() + getenv = setEnvironment(map[string]string{ + "AWS_ACCESS_KEY_ID": "AKIDEXAMPLE", + "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + "AWS_REGION": "us-west-1", + }) + now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} + + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + out, err := base.subjectToken() + if err != nil { + t.Fatalf("retrieveSubjectToken() failed: %v", err) + } + + expected := getExpectedSubjectToken( + "https://sts.us-west-1.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "us-west-1", + "AKIDEXAMPLE", + "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + "", + ) + + if got, want := out, expected; !reflect.DeepEqual(got, want) { + t.Errorf("subjectToken = \n%q\n want \n%q", got, want) + } +} + +func TestAWSCredential_ShouldCallMetadataEndpointWhenNoRegion(t *testing.T) { + server := createDefaultAwsTestServerWithImdsv2(t) + ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } + + tfc := testFileConfig + tfc.CredentialSource = server.getCredentialSource(ts.URL) + + oldGetenv := getenv + oldNow := now + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() + getenv = setEnvironment(map[string]string{ + "AWS_ACCESS_KEY_ID": accessKeyID, + "AWS_SECRET_ACCESS_KEY": secretAccessKey, + }) + now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} + + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + out, err := base.subjectToken() + if err != nil { + t.Fatalf("retrieveSubjectToken() failed: %v", err) + } + + expected := getExpectedSubjectToken( + "https://sts.us-east-2.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "us-east-2", + accessKeyID, + secretAccessKey, + "", + ) + + if got, want := out, expected; !reflect.DeepEqual(got, want) { + t.Errorf("subjectToken = \n%q\n want \n%q", got, want) + } +} + +func TestAWSCredential_ShouldCallMetadataEndpointWhenNoAccessKey(t *testing.T) { + server := createDefaultAwsTestServerWithImdsv2(t) + ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } + + tfc := testFileConfig + tfc.CredentialSource = server.getCredentialSource(ts.URL) + + oldGetenv := getenv + oldNow := now + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() + getenv = setEnvironment(map[string]string{ + "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + "AWS_REGION": "us-west-1", + }) + now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} + + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + out, err := base.subjectToken() + if err != nil { + t.Fatalf("retrieveSubjectToken() failed: %v", err) + } + + expected := getExpectedSubjectToken( + "https://sts.us-west-1.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "us-west-1", + accessKeyID, + secretAccessKey, + securityToken, + ) + + if got, want := out, expected; !reflect.DeepEqual(got, want) { + t.Errorf("subjectToken = \n%q\n want \n%q", got, want) + } +} + +func TestAWSCredential_ShouldCallMetadataEndpointWhenNoSecretAccessKey(t *testing.T) { + server := createDefaultAwsTestServerWithImdsv2(t) + ts := httptest.NewServer(server) + tsURL, err := neturl.Parse(ts.URL) + if err != nil { + t.Fatalf("couldn't parse httptest servername") + } + + tfc := testFileConfig + tfc.CredentialSource = server.getCredentialSource(ts.URL) + + oldGetenv := getenv + oldNow := now + oldValidHostnames := validHostnames + defer func() { + getenv = oldGetenv + now = oldNow + validHostnames = oldValidHostnames + }() + getenv = setEnvironment(map[string]string{ + "AWS_ACCESS_KEY_ID": "AKIDEXAMPLE", + "AWS_REGION": "us-west-1", + }) + now = setTime(defaultTime) + validHostnames = []string{tsURL.Hostname()} + + base, err := tfc.parse(context.Background()) + if err != nil { + t.Fatalf("parse() failed %v", err) + } + + out, err := base.subjectToken() + if err != nil { + t.Fatalf("retrieveSubjectToken() failed: %v", err) + } + + expected := getExpectedSubjectToken( + "https://sts.us-west-1.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15", + "us-west-1", + accessKeyID, + secretAccessKey, + securityToken, + ) + + if got, want := out, expected; !reflect.DeepEqual(got, want) { + t.Errorf("subjectToken = \n%q\n want \n%q", got, want) + } +} + func TestAWSCredential_Validations(t *testing.T) { var metadataServerValidityTests = []struct { name string