Browse Source

feat: update to use Batch value get instead of List and Fetch all secrets for AWS provider (#4181)

* feat: update to use Batch value get instead of List and Fetch all secrets for AWS provider

Signed-off-by: Gergely Brautigam <182850+Skarlso@users.noreply.github.com>

* update the documentation and point to the right blog post

Signed-off-by: Gergely Brautigam <182850+Skarlso@users.noreply.github.com>

* update to fall back to ListSecrets in case path is not defined

Signed-off-by: Gergely Brautigam <182850+Skarlso@users.noreply.github.com>

---------

Signed-off-by: Gergely Brautigam <182850+Skarlso@users.noreply.github.com>
Gergely Brautigam 1 year ago
parent
commit
c684e8c360

+ 8 - 0
docs/provider/aws-secrets-manager.md

@@ -12,10 +12,17 @@ way users of the `SecretStore` can only access the secrets necessary.
 {% include 'aws-sm-store.yaml' %}
 ```
 **NOTE:** In case of a `ClusterSecretStore`, Be sure to provide `namespace` in `accessKeyIDSecretRef` and `secretAccessKeySecretRef`  with the namespaces where the secrets reside.
+
+**NOTE:** When using `dataFrom` without a `path` defined, the provider will fall back to using `ListSecrets`. `ListSecrets`
+then proceeds to fetch each individual secret in turn. To use `BatchGetSecretValue` and avoid excessive API calls define
+a `path` prefix or use `Tags` filter.
+
 ### IAM Policy
 
 Create a IAM Policy to pin down access to secrets matching `dev-*`.
 
+For Batch permissions read the following post https://aws.amazon.com/about-aws/whats-new/2023/11/aws-secrets-manager-batch-retrieval-secrets/.
+
 ``` json
 {
   "Version": "2012-10-17",
@@ -27,6 +34,7 @@ Create a IAM Policy to pin down access to secrets matching `dev-*`.
         "secretsmanager:GetSecretValue",
         "secretsmanager:DescribeSecret",
         "secretsmanager:ListSecretVersionIds"
+        "secretsmanager:BatchGetSecretValue"
       ],
       "Resource": [
         "arn:aws:secretsmanager:us-west-2:111122223333:secret:dev-*"

+ 3 - 4
e2e/suites/provider/cases/aws/secretsmanager/secretsmanager.go

@@ -26,10 +26,9 @@ import (
 )
 
 const (
-	withStaticAuth         = "with static auth"
-	withExtID              = "with externalID"
-	withSessionTags        = "with session tags"
-	withReferentStaticAuth = "with static referent auth"
+	withStaticAuth  = "with static auth"
+	withExtID       = "with externalID"
+	withSessionTags = "with session tags"
 )
 
 var _ = Describe("[aws] ", Label("aws", "secretsmanager"), func() {

+ 1 - 0
pkg/constants/constants.go

@@ -23,6 +23,7 @@ const (
 	CallAWSSMCreateSecret        = "CreateSecret"
 	CallAWSSMPutSecretValue      = "PutSecretValue"
 	CallAWSSMListSecrets         = "ListSecrets"
+	CallAWSSMBatchGetSecretValue = "BatchGetSecretValue"
 
 	ProviderAWSPS                = "AWS/ParameterStore"
 	CallAWSPSGetParameter        = "GetParameter"

+ 1 - 3
pkg/controllers/externalsecret/externalsecret_controller_secret.go

@@ -150,7 +150,6 @@ func (r *Reconciler) handleGenerateSecrets(ctx context.Context, namespace string
 	return secretMap, err
 }
 
-//nolint:dupl
 func (r *Reconciler) handleExtractSecrets(ctx context.Context, externalSecret *esv1beta1.ExternalSecret, remoteRef esv1beta1.ExternalSecretDataFromRemoteRef, cmgr *secretstore.Manager) (map[string][]byte, error) {
 	client, err := cmgr.Get(ctx, externalSecret.Spec.SecretStoreRef, externalSecret.Namespace, remoteRef.SourceRef)
 	if err != nil {
@@ -190,7 +189,6 @@ func (r *Reconciler) handleExtractSecrets(ctx context.Context, externalSecret *e
 	return secretMap, err
 }
 
-//nolint:dupl
 func (r *Reconciler) handleFindAllSecrets(ctx context.Context, externalSecret *esv1beta1.ExternalSecret, remoteRef esv1beta1.ExternalSecretDataFromRemoteRef, cmgr *secretstore.Manager) (map[string][]byte, error) {
 	client, err := cmgr.Get(ctx, externalSecret.Spec.SecretStoreRef, externalSecret.Namespace, remoteRef.SourceRef)
 	if err != nil {
@@ -200,7 +198,7 @@ func (r *Reconciler) handleFindAllSecrets(ctx context.Context, externalSecret *e
 	// get all secrets from the store that match the selector
 	secretMap, err := client.GetAllSecrets(ctx, *remoteRef.Find)
 	if err != nil {
-		return nil, err
+		return nil, fmt.Errorf("error getting all secrets: %w", err)
 	}
 
 	// rewrite the keys if needed

+ 14 - 8
pkg/provider/aws/secretsmanager/fake/fake.go

@@ -28,14 +28,15 @@ import (
 
 // Client implements the aws secretsmanager interface.
 type Client struct {
-	ExecutionCounter            int
-	valFn                       map[string]func(*awssm.GetSecretValueInput) (*awssm.GetSecretValueOutput, error)
-	CreateSecretWithContextFn   CreateSecretWithContextFn
-	GetSecretValueWithContextFn GetSecretValueWithContextFn
-	PutSecretValueWithContextFn PutSecretValueWithContextFn
-	DescribeSecretWithContextFn DescribeSecretWithContextFn
-	DeleteSecretWithContextFn   DeleteSecretWithContextFn
-	ListSecretsFn               ListSecretsFn
+	ExecutionCounter                 int
+	valFn                            map[string]func(*awssm.GetSecretValueInput) (*awssm.GetSecretValueOutput, error)
+	CreateSecretWithContextFn        CreateSecretWithContextFn
+	GetSecretValueWithContextFn      GetSecretValueWithContextFn
+	PutSecretValueWithContextFn      PutSecretValueWithContextFn
+	DescribeSecretWithContextFn      DescribeSecretWithContextFn
+	DeleteSecretWithContextFn        DeleteSecretWithContextFn
+	ListSecretsFn                    ListSecretsFn
+	BatchGetSecretValueWithContextFn BatchGetSecretValueWithContextFn
 }
 
 type CreateSecretWithContextFn func(aws.Context, *awssm.CreateSecretInput, ...request.Option) (*awssm.CreateSecretOutput, error)
@@ -44,6 +45,7 @@ type PutSecretValueWithContextFn func(aws.Context, *awssm.PutSecretValueInput, .
 type DescribeSecretWithContextFn func(aws.Context, *awssm.DescribeSecretInput, ...request.Option) (*awssm.DescribeSecretOutput, error)
 type DeleteSecretWithContextFn func(ctx aws.Context, input *awssm.DeleteSecretInput, opts ...request.Option) (*awssm.DeleteSecretOutput, error)
 type ListSecretsFn func(ctx aws.Context, input *awssm.ListSecretsInput, opts ...request.Option) (*awssm.ListSecretsOutput, error)
+type BatchGetSecretValueWithContextFn func(aws.Context, *awssm.BatchGetSecretValueInput, ...request.Option) (*awssm.BatchGetSecretValueOutput, error)
 
 func (sm Client) CreateSecretWithContext(ctx aws.Context, input *awssm.CreateSecretInput, options ...request.Option) (*awssm.CreateSecretOutput, error) {
 	return sm.CreateSecretWithContextFn(ctx, input, options...)
@@ -164,6 +166,10 @@ func (sm *Client) ListSecrets(input *awssm.ListSecretsInput) (*awssm.ListSecrets
 	return sm.ListSecretsFn(nil, input)
 }
 
+func (sm *Client) BatchGetSecretValueWithContext(_ aws.Context, in *awssm.BatchGetSecretValueInput, _ ...request.Option) (*awssm.BatchGetSecretValueOutput, error) {
+	return sm.BatchGetSecretValueWithContextFn(nil, in)
+}
+
 func (sm *Client) cacheKeyForInput(in *awssm.GetSecretValueInput) string {
 	var secretID, versionID string
 	if in.SecretId != nil {

+ 47 - 27
pkg/provider/aws/secretsmanager/secretsmanager.go

@@ -65,6 +65,7 @@ type SecretsManager struct {
 // SMInterface is a subset of the smiface api.
 // see: https://docs.aws.amazon.com/sdk-for-go/api/service/secretsmanager/secretsmanageriface/
 type SMInterface interface {
+	BatchGetSecretValueWithContext(aws.Context, *awssm.BatchGetSecretValueInput, ...request.Option) (*awssm.BatchGetSecretValueOutput, error)
 	ListSecrets(*awssm.ListSecretsInput) (*awssm.ListSecretsOutput, error)
 	GetSecretValue(*awssm.GetSecretValueInput) (*awssm.GetSecretValueOutput, error)
 	CreateSecretWithContext(aws.Context, *awssm.CreateSecretInput, ...request.Option) (*awssm.CreateSecretOutput, error)
@@ -348,12 +349,16 @@ func (sm *SecretsManager) findByName(ctx context.Context, ref esv1beta1.External
 				ref.Path,
 			},
 		})
+
+		return sm.fetchWithBatch(ctx, filters, matcher)
 	}
 
 	data := make(map[string][]byte)
 	var nextToken *string
 
 	for {
+		// I put this into the for loop on purpose.
+		log.V(0).Info("using ListSecret to fetch all secrets; this is a costly operations, please use batching by defining a _path_")
 		it, err := sm.client.ListSecrets(&awssm.ListSecretsInput{
 			Filters:   filters,
 			NextToken: nextToken,
@@ -368,8 +373,7 @@ func (sm *SecretsManager) findByName(ctx context.Context, ref esv1beta1.External
 				continue
 			}
 			log.V(1).Info("aws sm findByName matches", "name", *secret.Name)
-			err = sm.fetchAndSet(ctx, data, *secret.Name)
-			if err != nil {
+			if err := sm.fetchAndSet(ctx, data, *secret.Name); err != nil {
 				return nil, err
 			}
 		}
@@ -406,31 +410,7 @@ func (sm *SecretsManager) findByTags(ctx context.Context, ref esv1beta1.External
 		})
 	}
 
-	data := make(map[string][]byte)
-	var nextToken *string
-	for {
-		log.V(1).Info("aws sm findByTag", "nextToken", nextToken)
-		it, err := sm.client.ListSecrets(&awssm.ListSecretsInput{
-			Filters:   filters,
-			NextToken: nextToken,
-		})
-		metrics.ObserveAPICall(constants.ProviderAWSSM, constants.CallAWSSMListSecrets, err)
-		if err != nil {
-			return nil, err
-		}
-		log.V(1).Info("aws sm findByTag found", "secrets", len(it.SecretList))
-		for _, secret := range it.SecretList {
-			err = sm.fetchAndSet(ctx, data, *secret.Name)
-			if err != nil {
-				return nil, err
-			}
-		}
-		nextToken = it.NextToken
-		if nextToken == nil {
-			break
-		}
-	}
-	return data, nil
+	return sm.fetchWithBatch(ctx, filters, nil)
 }
 
 func (sm *SecretsManager) fetchAndSet(ctx context.Context, data map[string][]byte, name string) error {
@@ -614,3 +594,43 @@ func (sm *SecretsManager) putSecretValueWithContext(ctx context.Context, secretI
 
 	return err
 }
+
+func (sm *SecretsManager) fetchWithBatch(ctx context.Context, filters []*awssm.Filter, matcher *find.Matcher) (map[string][]byte, error) {
+	data := make(map[string][]byte)
+	var nextToken *string
+
+	for {
+		it, err := sm.client.BatchGetSecretValueWithContext(ctx, &awssm.BatchGetSecretValueInput{
+			Filters:   filters,
+			NextToken: nextToken,
+		})
+		metrics.ObserveAPICall(constants.ProviderAWSSM, constants.CallAWSSMBatchGetSecretValue, err)
+		if err != nil {
+			return nil, err
+		}
+		log.V(1).Info("aws sm findByName found", "secrets", len(it.SecretValues))
+		for _, secret := range it.SecretValues {
+			if matcher != nil && !matcher.MatchName(*secret.Name) {
+				continue
+			}
+			log.V(1).Info("aws sm findByName matches", "name", *secret.Name)
+
+			sm.setSecretValues(secret, data)
+		}
+		nextToken = it.NextToken
+		if nextToken == nil {
+			break
+		}
+	}
+
+	return data, nil
+}
+
+func (sm *SecretsManager) setSecretValues(secret *awssm.SecretValueEntry, data map[string][]byte) {
+	if secret.SecretString != nil {
+		data[*secret.Name] = []byte(*secret.SecretString)
+	}
+	if secret.SecretBinary != nil {
+		data[*secret.Name] = secret.SecretBinary
+	}
+}

+ 43 - 39
pkg/provider/aws/secretsmanager/secretsmanager_test.go

@@ -1064,17 +1064,15 @@ func TestSecretsManagerGetAllSecrets(t *testing.T) {
 	}
 	// Test cases
 	testCases := []struct {
-		name string
-		ref  esv1beta1.ExternalSecretFind
-
-		secretName    string
-		secretVersion string
-		secretValue   string
-		fetchError    error
-		listSecretsFn func(ctx context.Context, input *awssm.ListSecretsInput, opts ...request.Option) (*awssm.ListSecretsOutput, error)
-
-		expectedData  map[string][]byte
-		expectedError string
+		name                             string
+		ref                              esv1beta1.ExternalSecretFind
+		secretName                       string
+		secretVersion                    string
+		secretValue                      string
+		batchGetSecretValueWithContextFn func(aws.Context, *awssm.BatchGetSecretValueInput, ...request.Option) (*awssm.BatchGetSecretValueOutput, error)
+		listSecretsFn                    func(ctx context.Context, input *awssm.ListSecretsInput, opts ...request.Option) (*awssm.ListSecretsOutput, error)
+		expectedData                     map[string][]byte
+		expectedError                    string
 	}{
 		{
 			name: "Matching secrets found",
@@ -1087,14 +1085,16 @@ func TestSecretsManagerGetAllSecrets(t *testing.T) {
 			secretName:    secretName,
 			secretVersion: secretVersion,
 			secretValue:   secretValue,
-			listSecretsFn: func(ctx context.Context, input *awssm.ListSecretsInput, opts ...request.Option) (*awssm.ListSecretsOutput, error) {
+			batchGetSecretValueWithContextFn: func(_ aws.Context, input *awssm.BatchGetSecretValueInput, _ ...request.Option) (*awssm.BatchGetSecretValueOutput, error) {
 				assert.Len(t, input.Filters, 1)
 				assert.Equal(t, "name", *input.Filters[0].Key)
 				assert.Equal(t, secretPath, *input.Filters[0].Values[0])
-				return &awssm.ListSecretsOutput{
-					SecretList: []*awssm.SecretListEntry{
+				return &awssm.BatchGetSecretValueOutput{
+					SecretValues: []*awssm.SecretValueEntry{
 						{
-							Name: ptr.To(secretName),
+							Name:          ptr.To(secretName),
+							VersionStages: []*string{ptr.To(secretVersion)},
+							SecretBinary:  []byte(secretValue),
 						},
 					},
 				}, nil
@@ -1115,15 +1115,14 @@ func TestSecretsManagerGetAllSecrets(t *testing.T) {
 			secretName:    secretName,
 			secretVersion: secretVersion,
 			secretValue:   secretValue,
-			fetchError:    errBoom,
-			listSecretsFn: func(ctx context.Context, input *awssm.ListSecretsInput, opts ...request.Option) (*awssm.ListSecretsOutput, error) {
-				return &awssm.ListSecretsOutput{
-					SecretList: []*awssm.SecretListEntry{
+			batchGetSecretValueWithContextFn: func(aws.Context, *awssm.BatchGetSecretValueInput, ...request.Option) (*awssm.BatchGetSecretValueOutput, error) {
+				return &awssm.BatchGetSecretValueOutput{
+					SecretValues: []*awssm.SecretValueEntry{
 						{
 							Name: ptr.To(secretName),
 						},
 					},
-				}, nil
+				}, errBoom
 			},
 			expectedData:  nil,
 			expectedError: errBoom.Error(),
@@ -1157,6 +1156,15 @@ func TestSecretsManagerGetAllSecrets(t *testing.T) {
 					},
 				}, nil
 			},
+			batchGetSecretValueWithContextFn: func(aws.Context, *awssm.BatchGetSecretValueInput, ...request.Option) (*awssm.BatchGetSecretValueOutput, error) {
+				return &awssm.BatchGetSecretValueOutput{
+					SecretValues: []*awssm.SecretValueEntry{
+						{
+							Name: ptr.To("other-secret"),
+						},
+					},
+				}, nil
+			},
 			expectedData:  make(map[string][]byte),
 			expectedError: "",
 		},
@@ -1179,16 +1187,18 @@ func TestSecretsManagerGetAllSecrets(t *testing.T) {
 			secretName:    secretName,
 			secretVersion: secretVersion,
 			secretValue:   secretValue,
-			listSecretsFn: func(ctx context.Context, input *awssm.ListSecretsInput, opts ...request.Option) (*awssm.ListSecretsOutput, error) {
+			batchGetSecretValueWithContextFn: func(_ aws.Context, input *awssm.BatchGetSecretValueInput, _ ...request.Option) (*awssm.BatchGetSecretValueOutput, error) {
 				assert.Len(t, input.Filters, 2)
 				assert.Equal(t, "tag-key", *input.Filters[0].Key)
 				assert.Equal(t, "foo", *input.Filters[0].Values[0])
 				assert.Equal(t, "tag-value", *input.Filters[1].Key)
 				assert.Equal(t, "bar", *input.Filters[1].Values[0])
-				return &awssm.ListSecretsOutput{
-					SecretList: []*awssm.SecretListEntry{
+				return &awssm.BatchGetSecretValueOutput{
+					SecretValues: []*awssm.SecretValueEntry{
 						{
-							Name: ptr.To(secretName),
+							Name:          ptr.To(secretName),
+							VersionStages: []*string{ptr.To(secretVersion)},
+							SecretBinary:  []byte(secretValue),
 						},
 					},
 				}, nil
@@ -1206,15 +1216,16 @@ func TestSecretsManagerGetAllSecrets(t *testing.T) {
 			secretName:    secretName,
 			secretVersion: secretVersion,
 			secretValue:   secretValue,
-			fetchError:    errBoom,
-			listSecretsFn: func(ctx context.Context, input *awssm.ListSecretsInput, opts ...request.Option) (*awssm.ListSecretsOutput, error) {
-				return &awssm.ListSecretsOutput{
-					SecretList: []*awssm.SecretListEntry{
+			batchGetSecretValueWithContextFn: func(aws.Context, *awssm.BatchGetSecretValueInput, ...request.Option) (*awssm.BatchGetSecretValueOutput, error) {
+				return &awssm.BatchGetSecretValueOutput{
+					SecretValues: []*awssm.SecretValueEntry{
 						{
-							Name: ptr.To(secretName),
+							Name:          ptr.To(secretName),
+							VersionStages: []*string{ptr.To(secretVersion)},
+							SecretBinary:  []byte(secretValue),
 						},
 					},
-				}, nil
+				}, errBoom
 			},
 			expectedData:  nil,
 			expectedError: errBoom.Error(),
@@ -1224,7 +1235,7 @@ func TestSecretsManagerGetAllSecrets(t *testing.T) {
 			ref: esv1beta1.ExternalSecretFind{
 				Tags: secretTags,
 			},
-			listSecretsFn: func(ctx context.Context, input *awssm.ListSecretsInput, opts ...request.Option) (*awssm.ListSecretsOutput, error) {
+			batchGetSecretValueWithContextFn: func(aws.Context, *awssm.BatchGetSecretValueInput, ...request.Option) (*awssm.BatchGetSecretValueOutput, error) {
 				return nil, errBoom
 			},
 			expectedData:  nil,
@@ -1235,15 +1246,8 @@ func TestSecretsManagerGetAllSecrets(t *testing.T) {
 	for _, tc := range testCases {
 		t.Run(tc.name, func(t *testing.T) {
 			fc := fakesm.NewClient()
+			fc.BatchGetSecretValueWithContextFn = tc.batchGetSecretValueWithContextFn
 			fc.ListSecretsFn = tc.listSecretsFn
-			fc.WithValue(&awssm.GetSecretValueInput{
-				SecretId:     ptr.To(tc.secretName),
-				VersionStage: ptr.To(tc.secretVersion),
-			}, &awssm.GetSecretValueOutput{
-				Name:          ptr.To(tc.secretName),
-				VersionStages: []*string{ptr.To(tc.secretVersion)},
-				SecretBinary:  []byte(tc.secretValue),
-			}, tc.fetchError)
 			sm := SecretsManager{
 				client: fc,
 				cache:  make(map[string]*awssm.GetSecretValueOutput),