Browse Source

feat: added cache in aws secret manager provider

Alexander Chernov 4 years ago
parent
commit
dae7237953

+ 27 - 3
pkg/provider/aws/secretsmanager/fake/fake.go

@@ -11,6 +11,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 See the License for the specific language governing permissions and
 limitations under the License.
 limitations under the License.
 */
 */
+
 package fake
 package fake
 
 
 import (
 import (
@@ -22,15 +23,38 @@ import (
 
 
 // Client implements the aws secretsmanager interface.
 // Client implements the aws secretsmanager interface.
 type Client struct {
 type Client struct {
-	valFn func(*awssm.GetSecretValueInput) (*awssm.GetSecretValueOutput, error)
+	ExecutionCounter int
+	valFn            map[string]func(*awssm.GetSecretValueInput) (*awssm.GetSecretValueOutput, error)
+}
+
+// NewClient init a new fake client.
+func NewClient() *Client {
+	return &Client{
+		valFn: make(map[string]func(*awssm.GetSecretValueInput) (*awssm.GetSecretValueOutput, error)),
+	}
 }
 }
 
 
 func (sm *Client) GetSecretValue(in *awssm.GetSecretValueInput) (*awssm.GetSecretValueOutput, error) {
 func (sm *Client) GetSecretValue(in *awssm.GetSecretValueInput) (*awssm.GetSecretValueOutput, error) {
-	return sm.valFn(in)
+	sm.ExecutionCounter++
+	if entry, found := sm.valFn[sm.cacheKeyForInput(in)]; found {
+		return entry(in)
+	}
+	return nil, fmt.Errorf("test case not found")
+}
+
+func (sm *Client) cacheKeyForInput(in *awssm.GetSecretValueInput) string {
+	var secretID, versionID string
+	if in.SecretId != nil {
+		secretID = *in.SecretId
+	}
+	if in.VersionId != nil {
+		versionID = *in.VersionId
+	}
+	return fmt.Sprintf("%s#%s", secretID, versionID)
 }
 }
 
 
 func (sm *Client) WithValue(in *awssm.GetSecretValueInput, val *awssm.GetSecretValueOutput, err error) {
 func (sm *Client) WithValue(in *awssm.GetSecretValueInput, val *awssm.GetSecretValueOutput, err error) {
-	sm.valFn = func(paramIn *awssm.GetSecretValueInput) (*awssm.GetSecretValueOutput, error) {
+	sm.valFn[sm.cacheKeyForInput(in)] = func(paramIn *awssm.GetSecretValueInput) (*awssm.GetSecretValueOutput, error) {
 		if !cmp.Equal(paramIn, in) {
 		if !cmp.Equal(paramIn, in) {
 			return nil, fmt.Errorf("unexpected test argument")
 			return nil, fmt.Errorf("unexpected test argument")
 		}
 		}

+ 22 - 2
pkg/provider/aws/secretsmanager/secretsmanager.go

@@ -11,6 +11,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 See the License for the specific language governing permissions and
 limitations under the License.
 limitations under the License.
 */
 */
+
 package secretsmanager
 package secretsmanager
 
 
 import (
 import (
@@ -30,6 +31,7 @@ import (
 // SecretsManager is a provider for AWS SecretsManager.
 // SecretsManager is a provider for AWS SecretsManager.
 type SecretsManager struct {
 type SecretsManager struct {
 	client SMInterface
 	client SMInterface
+	cache  map[string]*awssm.GetSecretValueOutput
 }
 }
 
 
 // SMInterface is a subset of the smiface api.
 // SMInterface is a subset of the smiface api.
@@ -44,21 +46,38 @@ var log = ctrl.Log.WithName("provider").WithName("aws").WithName("secretsmanager
 func New(sess client.ConfigProvider) (*SecretsManager, error) {
 func New(sess client.ConfigProvider) (*SecretsManager, error) {
 	return &SecretsManager{
 	return &SecretsManager{
 		client: awssm.New(sess),
 		client: awssm.New(sess),
+		cache:  make(map[string]*awssm.GetSecretValueOutput),
 	}, nil
 	}, nil
 }
 }
 
 
-// GetSecret returns a single secret from the provider.
-func (sm *SecretsManager) GetSecret(ctx context.Context, ref esv1alpha1.ExternalSecretDataRemoteRef) ([]byte, error) {
+func (sm *SecretsManager) fetch(_ context.Context, ref esv1alpha1.ExternalSecretDataRemoteRef) (*awssm.GetSecretValueOutput, error) {
 	ver := "AWSCURRENT"
 	ver := "AWSCURRENT"
 	if ref.Version != "" {
 	if ref.Version != "" {
 		ver = ref.Version
 		ver = ref.Version
 	}
 	}
 	log.Info("fetching secret value", "key", ref.Key, "version", ver)
 	log.Info("fetching secret value", "key", ref.Key, "version", ver)
+
+	cacheKey := fmt.Sprintf("%s#%s", ref.Key, ver)
+	if secretOut, found := sm.cache[cacheKey]; found {
+		log.Info("found secret in cache", "key", ref.Key, "version", ver)
+		return secretOut, nil
+	}
 	secretOut, err := sm.client.GetSecretValue(&awssm.GetSecretValueInput{
 	secretOut, err := sm.client.GetSecretValue(&awssm.GetSecretValueInput{
 		SecretId:     &ref.Key,
 		SecretId:     &ref.Key,
 		VersionStage: &ver,
 		VersionStage: &ver,
 	})
 	})
 	if err != nil {
 	if err != nil {
+		return nil, err
+	}
+	sm.cache[cacheKey] = secretOut
+
+	return secretOut, nil
+}
+
+// GetSecret returns a single secret from the provider.
+func (sm *SecretsManager) GetSecret(ctx context.Context, ref esv1alpha1.ExternalSecretDataRemoteRef) ([]byte, error) {
+	secretOut, err := sm.fetch(ctx, ref)
+	if err != nil {
 		return nil, util.SanitizeErr(err)
 		return nil, util.SanitizeErr(err)
 	}
 	}
 	if ref.Property == "" {
 	if ref.Property == "" {
@@ -77,6 +96,7 @@ func (sm *SecretsManager) GetSecret(ctx context.Context, ref esv1alpha1.External
 	if secretOut.SecretBinary != nil {
 	if secretOut.SecretBinary != nil {
 		payload = string(secretOut.SecretBinary)
 		payload = string(secretOut.SecretBinary)
 	}
 	}
+
 	val := gjson.Get(payload, ref.Property)
 	val := gjson.Get(payload, ref.Property)
 	if !val.Exists() {
 	if !val.Exists() {
 		return nil, fmt.Errorf("key %s does not exist in secret %s", ref.Property, ref.Key)
 		return nil, fmt.Errorf("key %s does not exist in secret %s", ref.Property, ref.Key)

+ 75 - 4
pkg/provider/aws/secretsmanager/secretsmanager_test.go

@@ -11,6 +11,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 See the License for the specific language governing permissions and
 limitations under the License.
 limitations under the License.
 */
 */
+
 package secretsmanager
 package secretsmanager
 
 
 import (
 import (
@@ -37,11 +38,13 @@ type secretsManagerTestCase struct {
 	expectedSecret string
 	expectedSecret string
 	// for testing secretmap
 	// for testing secretmap
 	expectedData map[string][]byte
 	expectedData map[string][]byte
+	// for testing caching
+	expectedCounter *int
 }
 }
 
 
 func makeValidSecretsManagerTestCase() *secretsManagerTestCase {
 func makeValidSecretsManagerTestCase() *secretsManagerTestCase {
 	smtc := secretsManagerTestCase{
 	smtc := secretsManagerTestCase{
-		fakeClient:     &fakesm.Client{},
+		fakeClient:     fakesm.NewClient(),
 		apiInput:       makeValidAPIInput(),
 		apiInput:       makeValidAPIInput(),
 		remoteRef:      makeValidRemoteRef(),
 		remoteRef:      makeValidRemoteRef(),
 		apiOutput:      makeValidAPIOutput(),
 		apiOutput:      makeValidAPIOutput(),
@@ -164,8 +167,59 @@ func TestSecretsManagerGetSecret(t *testing.T) {
 		makeValidSecretsManagerTestCaseCustom(setAPIErr),
 		makeValidSecretsManagerTestCaseCustom(setAPIErr),
 	}
 	}
 
 
-	sm := SecretsManager{}
 	for k, v := range successCases {
 	for k, v := range successCases {
+		sm := SecretsManager{
+			cache:  make(map[string]*awssm.GetSecretValueOutput),
+			client: v.fakeClient,
+		}
+		out, err := sm.GetSecret(context.Background(), *v.remoteRef)
+		if !ErrorContains(err, v.expectError) {
+			t.Errorf("[%d] unexpected error: %s, expected: '%s'", k, err.Error(), v.expectError)
+		}
+		if err == nil && string(out) != v.expectedSecret {
+			t.Errorf("[%d] unexpected secret: expected %s, got %s", k, v.expectedSecret, string(out))
+		}
+	}
+}
+func TestCaching(t *testing.T) {
+	fakeClient := fakesm.NewClient()
+
+	// good case: first call, since we are using the same key, results should be cached and the counter should not go
+	// over 1
+	firstCall := func(smtc *secretsManagerTestCase) {
+		smtc.apiOutput.SecretString = aws.String(`{"foo":"bar", "bar":"vodka"}`)
+		smtc.remoteRef.Property = "foo"
+		smtc.expectedSecret = "bar"
+		smtc.expectedCounter = aws.Int(1)
+		smtc.fakeClient = fakeClient
+	}
+	secondCall := func(smtc *secretsManagerTestCase) {
+		smtc.apiOutput.SecretString = aws.String(`{"foo":"bar", "bar":"vodka"}`)
+		smtc.remoteRef.Property = "bar"
+		smtc.expectedSecret = "vodka"
+		smtc.expectedCounter = aws.Int(1)
+		smtc.fakeClient = fakeClient
+	}
+	notCachedCall := func(smtc *secretsManagerTestCase) {
+		smtc.apiOutput.SecretString = aws.String(`{"sheldon":"bazinga", "bar":"foo"}`)
+		smtc.remoteRef.Property = "sheldon"
+		smtc.expectedSecret = "bazinga"
+		smtc.expectedCounter = aws.Int(2)
+		smtc.fakeClient = fakeClient
+		smtc.apiInput.SecretId = aws.String("xyz")
+		smtc.remoteRef.Key = "xyz" // it should reset the cache since the key is different
+	}
+
+	cachedCases := []*secretsManagerTestCase{
+		makeValidSecretsManagerTestCaseCustom(firstCall),
+		makeValidSecretsManagerTestCaseCustom(firstCall),
+		makeValidSecretsManagerTestCaseCustom(secondCall),
+		makeValidSecretsManagerTestCaseCustom(notCachedCall),
+	}
+	sm := SecretsManager{
+		cache: make(map[string]*awssm.GetSecretValueOutput),
+	}
+	for k, v := range cachedCases {
 		sm.client = v.fakeClient
 		sm.client = v.fakeClient
 		out, err := sm.GetSecret(context.Background(), *v.remoteRef)
 		out, err := sm.GetSecret(context.Background(), *v.remoteRef)
 		if !ErrorContains(err, v.expectError) {
 		if !ErrorContains(err, v.expectError) {
@@ -174,6 +228,9 @@ func TestSecretsManagerGetSecret(t *testing.T) {
 		if err == nil && string(out) != v.expectedSecret {
 		if err == nil && string(out) != v.expectedSecret {
 			t.Errorf("[%d] unexpected secret: expected %s, got %s", k, v.expectedSecret, string(out))
 			t.Errorf("[%d] unexpected secret: expected %s, got %s", k, v.expectedSecret, string(out))
 		}
 		}
+		if v.expectedCounter != nil && v.fakeClient.ExecutionCounter != *v.expectedCounter {
+			t.Errorf("[%d] unexpected counter value: expected %d, got %d", k, v.expectedCounter, v.fakeClient.ExecutionCounter)
+		}
 	}
 	}
 }
 }
 
 
@@ -184,6 +241,14 @@ func TestGetSecretMap(t *testing.T) {
 		smtc.expectedData["foo"] = []byte("bar")
 		smtc.expectedData["foo"] = []byte("bar")
 	}
 	}
 
 
+	// good case: caching
+	cachedMap := func(smtc *secretsManagerTestCase) {
+		smtc.apiOutput.SecretString = aws.String(`{"foo":"bar", "plus": "one"}`)
+		smtc.expectedData["foo"] = []byte("bar")
+		smtc.expectedData["plus"] = []byte("one")
+		smtc.expectedCounter = aws.Int(1)
+	}
+
 	// bad case: invalid json
 	// bad case: invalid json
 	setInvalidJSON := func(smtc *secretsManagerTestCase) {
 	setInvalidJSON := func(smtc *secretsManagerTestCase) {
 		smtc.apiOutput.SecretString = aws.String(`-----------------`)
 		smtc.apiOutput.SecretString = aws.String(`-----------------`)
@@ -194,11 +259,14 @@ func TestGetSecretMap(t *testing.T) {
 		makeValidSecretsManagerTestCaseCustom(setDeserialization),
 		makeValidSecretsManagerTestCaseCustom(setDeserialization),
 		makeValidSecretsManagerTestCaseCustom(setAPIErr),
 		makeValidSecretsManagerTestCaseCustom(setAPIErr),
 		makeValidSecretsManagerTestCaseCustom(setInvalidJSON),
 		makeValidSecretsManagerTestCaseCustom(setInvalidJSON),
+		makeValidSecretsManagerTestCaseCustom(cachedMap),
 	}
 	}
 
 
-	sm := SecretsManager{}
 	for k, v := range successCases {
 	for k, v := range successCases {
-		sm.client = v.fakeClient
+		sm := SecretsManager{
+			cache:  make(map[string]*awssm.GetSecretValueOutput),
+			client: v.fakeClient,
+		}
 		out, err := sm.GetSecretMap(context.Background(), *v.remoteRef)
 		out, err := sm.GetSecretMap(context.Background(), *v.remoteRef)
 		if !ErrorContains(err, v.expectError) {
 		if !ErrorContains(err, v.expectError) {
 			t.Errorf("[%d] unexpected error: %s, expected: '%s'", k, err.Error(), v.expectError)
 			t.Errorf("[%d] unexpected error: %s, expected: '%s'", k, err.Error(), v.expectError)
@@ -206,6 +274,9 @@ func TestGetSecretMap(t *testing.T) {
 		if err == nil && !cmp.Equal(out, v.expectedData) {
 		if err == nil && !cmp.Equal(out, v.expectedData) {
 			t.Errorf("[%d] unexpected secret data: expected %#v, got %#v", k, v.expectedData, out)
 			t.Errorf("[%d] unexpected secret data: expected %#v, got %#v", k, v.expectedData, out)
 		}
 		}
+		if v.expectedCounter != nil && v.fakeClient.ExecutionCounter != *v.expectedCounter {
+			t.Errorf("[%d] unexpected counter value: expected %d, got %d", k, v.expectedCounter, v.fakeClient.ExecutionCounter)
+		}
 	}
 	}
 }
 }