Browse Source

feat: added cache in aws secret manager provider

Alexander Chernov 4 years ago
parent
commit
dae7237953

+ 27 - 3
pkg/provider/aws/secretsmanager/fake/fake.go

@@ -11,6 +11,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
+
 package fake
 
 import (
@@ -22,15 +23,38 @@ import (
 
 // Client implements the aws secretsmanager interface.
 type Client struct {
-	valFn func(*awssm.GetSecretValueInput) (*awssm.GetSecretValueOutput, error)
+	ExecutionCounter int
+	valFn            map[string]func(*awssm.GetSecretValueInput) (*awssm.GetSecretValueOutput, error)
+}
+
+// NewClient init a new fake client.
+func NewClient() *Client {
+	return &Client{
+		valFn: make(map[string]func(*awssm.GetSecretValueInput) (*awssm.GetSecretValueOutput, error)),
+	}
 }
 
 func (sm *Client) GetSecretValue(in *awssm.GetSecretValueInput) (*awssm.GetSecretValueOutput, error) {
-	return sm.valFn(in)
+	sm.ExecutionCounter++
+	if entry, found := sm.valFn[sm.cacheKeyForInput(in)]; found {
+		return entry(in)
+	}
+	return nil, fmt.Errorf("test case not found")
+}
+
+func (sm *Client) cacheKeyForInput(in *awssm.GetSecretValueInput) string {
+	var secretID, versionID string
+	if in.SecretId != nil {
+		secretID = *in.SecretId
+	}
+	if in.VersionId != nil {
+		versionID = *in.VersionId
+	}
+	return fmt.Sprintf("%s#%s", secretID, versionID)
 }
 
 func (sm *Client) WithValue(in *awssm.GetSecretValueInput, val *awssm.GetSecretValueOutput, err error) {
-	sm.valFn = func(paramIn *awssm.GetSecretValueInput) (*awssm.GetSecretValueOutput, error) {
+	sm.valFn[sm.cacheKeyForInput(in)] = func(paramIn *awssm.GetSecretValueInput) (*awssm.GetSecretValueOutput, error) {
 		if !cmp.Equal(paramIn, in) {
 			return nil, fmt.Errorf("unexpected test argument")
 		}

+ 22 - 2
pkg/provider/aws/secretsmanager/secretsmanager.go

@@ -11,6 +11,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
+
 package secretsmanager
 
 import (
@@ -30,6 +31,7 @@ import (
 // SecretsManager is a provider for AWS SecretsManager.
 type SecretsManager struct {
 	client SMInterface
+	cache  map[string]*awssm.GetSecretValueOutput
 }
 
 // SMInterface is a subset of the smiface api.
@@ -44,21 +46,38 @@ var log = ctrl.Log.WithName("provider").WithName("aws").WithName("secretsmanager
 func New(sess client.ConfigProvider) (*SecretsManager, error) {
 	return &SecretsManager{
 		client: awssm.New(sess),
+		cache:  make(map[string]*awssm.GetSecretValueOutput),
 	}, nil
 }
 
-// GetSecret returns a single secret from the provider.
-func (sm *SecretsManager) GetSecret(ctx context.Context, ref esv1alpha1.ExternalSecretDataRemoteRef) ([]byte, error) {
+func (sm *SecretsManager) fetch(_ context.Context, ref esv1alpha1.ExternalSecretDataRemoteRef) (*awssm.GetSecretValueOutput, error) {
 	ver := "AWSCURRENT"
 	if ref.Version != "" {
 		ver = ref.Version
 	}
 	log.Info("fetching secret value", "key", ref.Key, "version", ver)
+
+	cacheKey := fmt.Sprintf("%s#%s", ref.Key, ver)
+	if secretOut, found := sm.cache[cacheKey]; found {
+		log.Info("found secret in cache", "key", ref.Key, "version", ver)
+		return secretOut, nil
+	}
 	secretOut, err := sm.client.GetSecretValue(&awssm.GetSecretValueInput{
 		SecretId:     &ref.Key,
 		VersionStage: &ver,
 	})
 	if err != nil {
+		return nil, err
+	}
+	sm.cache[cacheKey] = secretOut
+
+	return secretOut, nil
+}
+
+// GetSecret returns a single secret from the provider.
+func (sm *SecretsManager) GetSecret(ctx context.Context, ref esv1alpha1.ExternalSecretDataRemoteRef) ([]byte, error) {
+	secretOut, err := sm.fetch(ctx, ref)
+	if err != nil {
 		return nil, util.SanitizeErr(err)
 	}
 	if ref.Property == "" {
@@ -77,6 +96,7 @@ func (sm *SecretsManager) GetSecret(ctx context.Context, ref esv1alpha1.External
 	if secretOut.SecretBinary != nil {
 		payload = string(secretOut.SecretBinary)
 	}
+
 	val := gjson.Get(payload, ref.Property)
 	if !val.Exists() {
 		return nil, fmt.Errorf("key %s does not exist in secret %s", ref.Property, ref.Key)

+ 75 - 4
pkg/provider/aws/secretsmanager/secretsmanager_test.go

@@ -11,6 +11,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
+
 package secretsmanager
 
 import (
@@ -37,11 +38,13 @@ type secretsManagerTestCase struct {
 	expectedSecret string
 	// for testing secretmap
 	expectedData map[string][]byte
+	// for testing caching
+	expectedCounter *int
 }
 
 func makeValidSecretsManagerTestCase() *secretsManagerTestCase {
 	smtc := secretsManagerTestCase{
-		fakeClient:     &fakesm.Client{},
+		fakeClient:     fakesm.NewClient(),
 		apiInput:       makeValidAPIInput(),
 		remoteRef:      makeValidRemoteRef(),
 		apiOutput:      makeValidAPIOutput(),
@@ -164,8 +167,59 @@ func TestSecretsManagerGetSecret(t *testing.T) {
 		makeValidSecretsManagerTestCaseCustom(setAPIErr),
 	}
 
-	sm := SecretsManager{}
 	for k, v := range successCases {
+		sm := SecretsManager{
+			cache:  make(map[string]*awssm.GetSecretValueOutput),
+			client: v.fakeClient,
+		}
+		out, err := sm.GetSecret(context.Background(), *v.remoteRef)
+		if !ErrorContains(err, v.expectError) {
+			t.Errorf("[%d] unexpected error: %s, expected: '%s'", k, err.Error(), v.expectError)
+		}
+		if err == nil && string(out) != v.expectedSecret {
+			t.Errorf("[%d] unexpected secret: expected %s, got %s", k, v.expectedSecret, string(out))
+		}
+	}
+}
+func TestCaching(t *testing.T) {
+	fakeClient := fakesm.NewClient()
+
+	// good case: first call, since we are using the same key, results should be cached and the counter should not go
+	// over 1
+	firstCall := func(smtc *secretsManagerTestCase) {
+		smtc.apiOutput.SecretString = aws.String(`{"foo":"bar", "bar":"vodka"}`)
+		smtc.remoteRef.Property = "foo"
+		smtc.expectedSecret = "bar"
+		smtc.expectedCounter = aws.Int(1)
+		smtc.fakeClient = fakeClient
+	}
+	secondCall := func(smtc *secretsManagerTestCase) {
+		smtc.apiOutput.SecretString = aws.String(`{"foo":"bar", "bar":"vodka"}`)
+		smtc.remoteRef.Property = "bar"
+		smtc.expectedSecret = "vodka"
+		smtc.expectedCounter = aws.Int(1)
+		smtc.fakeClient = fakeClient
+	}
+	notCachedCall := func(smtc *secretsManagerTestCase) {
+		smtc.apiOutput.SecretString = aws.String(`{"sheldon":"bazinga", "bar":"foo"}`)
+		smtc.remoteRef.Property = "sheldon"
+		smtc.expectedSecret = "bazinga"
+		smtc.expectedCounter = aws.Int(2)
+		smtc.fakeClient = fakeClient
+		smtc.apiInput.SecretId = aws.String("xyz")
+		smtc.remoteRef.Key = "xyz" // it should reset the cache since the key is different
+	}
+
+	cachedCases := []*secretsManagerTestCase{
+		makeValidSecretsManagerTestCaseCustom(firstCall),
+		makeValidSecretsManagerTestCaseCustom(firstCall),
+		makeValidSecretsManagerTestCaseCustom(secondCall),
+		makeValidSecretsManagerTestCaseCustom(notCachedCall),
+	}
+	sm := SecretsManager{
+		cache: make(map[string]*awssm.GetSecretValueOutput),
+	}
+	for k, v := range cachedCases {
 		sm.client = v.fakeClient
 		out, err := sm.GetSecret(context.Background(), *v.remoteRef)
 		if !ErrorContains(err, v.expectError) {
@@ -174,6 +228,9 @@ func TestSecretsManagerGetSecret(t *testing.T) {
 		if err == nil && string(out) != v.expectedSecret {
 			t.Errorf("[%d] unexpected secret: expected %s, got %s", k, v.expectedSecret, string(out))
 		}
+		if v.expectedCounter != nil && v.fakeClient.ExecutionCounter != *v.expectedCounter {
+			t.Errorf("[%d] unexpected counter value: expected %d, got %d", k, v.expectedCounter, v.fakeClient.ExecutionCounter)
+		}
 	}
 }
 
@@ -184,6 +241,14 @@ func TestGetSecretMap(t *testing.T) {
 		smtc.expectedData["foo"] = []byte("bar")
 	}
 
+	// good case: caching
+	cachedMap := func(smtc *secretsManagerTestCase) {
+		smtc.apiOutput.SecretString = aws.String(`{"foo":"bar", "plus": "one"}`)
+		smtc.expectedData["foo"] = []byte("bar")
+		smtc.expectedData["plus"] = []byte("one")
+		smtc.expectedCounter = aws.Int(1)
+	}
+
 	// bad case: invalid json
 	setInvalidJSON := func(smtc *secretsManagerTestCase) {
 		smtc.apiOutput.SecretString = aws.String(`-----------------`)
@@ -194,11 +259,14 @@ func TestGetSecretMap(t *testing.T) {
 		makeValidSecretsManagerTestCaseCustom(setDeserialization),
 		makeValidSecretsManagerTestCaseCustom(setAPIErr),
 		makeValidSecretsManagerTestCaseCustom(setInvalidJSON),
+		makeValidSecretsManagerTestCaseCustom(cachedMap),
 	}
 
-	sm := SecretsManager{}
 	for k, v := range successCases {
-		sm.client = v.fakeClient
+		sm := SecretsManager{
+			cache:  make(map[string]*awssm.GetSecretValueOutput),
+			client: v.fakeClient,
+		}
 		out, err := sm.GetSecretMap(context.Background(), *v.remoteRef)
 		if !ErrorContains(err, v.expectError) {
 			t.Errorf("[%d] unexpected error: %s, expected: '%s'", k, err.Error(), v.expectError)
@@ -206,6 +274,9 @@ func TestGetSecretMap(t *testing.T) {
 		if err == nil && !cmp.Equal(out, v.expectedData) {
 			t.Errorf("[%d] unexpected secret data: expected %#v, got %#v", k, v.expectedData, out)
 		}
+		if v.expectedCounter != nil && v.fakeClient.ExecutionCounter != *v.expectedCounter {
+			t.Errorf("[%d] unexpected counter value: expected %d, got %d", k, v.expectedCounter, v.fakeClient.ExecutionCounter)
+		}
 	}
 }