|
|
@@ -11,6 +11,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
See the License for the specific language governing permissions and
|
|
|
limitations under the License.
|
|
|
*/
|
|
|
+
|
|
|
package secretsmanager
|
|
|
|
|
|
import (
|
|
|
@@ -37,11 +38,13 @@ type secretsManagerTestCase struct {
|
|
|
expectedSecret string
|
|
|
// for testing secretmap
|
|
|
expectedData map[string][]byte
|
|
|
+ // for testing caching
|
|
|
+ expectedCounter *int
|
|
|
}
|
|
|
|
|
|
func makeValidSecretsManagerTestCase() *secretsManagerTestCase {
|
|
|
smtc := secretsManagerTestCase{
|
|
|
- fakeClient: &fakesm.Client{},
|
|
|
+ fakeClient: fakesm.NewClient(),
|
|
|
apiInput: makeValidAPIInput(),
|
|
|
remoteRef: makeValidRemoteRef(),
|
|
|
apiOutput: makeValidAPIOutput(),
|
|
|
@@ -164,8 +167,59 @@ func TestSecretsManagerGetSecret(t *testing.T) {
|
|
|
makeValidSecretsManagerTestCaseCustom(setAPIErr),
|
|
|
}
|
|
|
|
|
|
- sm := SecretsManager{}
|
|
|
for k, v := range successCases {
|
|
|
+ sm := SecretsManager{
|
|
|
+ cache: make(map[string]*awssm.GetSecretValueOutput),
|
|
|
+ client: v.fakeClient,
|
|
|
+ }
|
|
|
+ out, err := sm.GetSecret(context.Background(), *v.remoteRef)
|
|
|
+ if !ErrorContains(err, v.expectError) {
|
|
|
+ t.Errorf("[%d] unexpected error: %s, expected: '%s'", k, err.Error(), v.expectError)
|
|
|
+ }
|
|
|
+ if err == nil && string(out) != v.expectedSecret {
|
|
|
+ t.Errorf("[%d] unexpected secret: expected %s, got %s", k, v.expectedSecret, string(out))
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+func TestCaching(t *testing.T) {
|
|
|
+ fakeClient := fakesm.NewClient()
|
|
|
+
|
|
|
+ // good case: first call, since we are using the same key, results should be cached and the counter should not go
|
|
|
+ // over 1
|
|
|
+ firstCall := func(smtc *secretsManagerTestCase) {
|
|
|
+ smtc.apiOutput.SecretString = aws.String(`{"foo":"bar", "bar":"vodka"}`)
|
|
|
+ smtc.remoteRef.Property = "foo"
|
|
|
+ smtc.expectedSecret = "bar"
|
|
|
+ smtc.expectedCounter = aws.Int(1)
|
|
|
+ smtc.fakeClient = fakeClient
|
|
|
+ }
|
|
|
+ secondCall := func(smtc *secretsManagerTestCase) {
|
|
|
+ smtc.apiOutput.SecretString = aws.String(`{"foo":"bar", "bar":"vodka"}`)
|
|
|
+ smtc.remoteRef.Property = "bar"
|
|
|
+ smtc.expectedSecret = "vodka"
|
|
|
+ smtc.expectedCounter = aws.Int(1)
|
|
|
+ smtc.fakeClient = fakeClient
|
|
|
+ }
|
|
|
+ notCachedCall := func(smtc *secretsManagerTestCase) {
|
|
|
+ smtc.apiOutput.SecretString = aws.String(`{"sheldon":"bazinga", "bar":"foo"}`)
|
|
|
+ smtc.remoteRef.Property = "sheldon"
|
|
|
+ smtc.expectedSecret = "bazinga"
|
|
|
+ smtc.expectedCounter = aws.Int(2)
|
|
|
+ smtc.fakeClient = fakeClient
|
|
|
+ smtc.apiInput.SecretId = aws.String("xyz")
|
|
|
+ smtc.remoteRef.Key = "xyz" // it should reset the cache since the key is different
|
|
|
+ }
|
|
|
+
|
|
|
+ cachedCases := []*secretsManagerTestCase{
|
|
|
+ makeValidSecretsManagerTestCaseCustom(firstCall),
|
|
|
+ makeValidSecretsManagerTestCaseCustom(firstCall),
|
|
|
+ makeValidSecretsManagerTestCaseCustom(secondCall),
|
|
|
+ makeValidSecretsManagerTestCaseCustom(notCachedCall),
|
|
|
+ }
|
|
|
+ sm := SecretsManager{
|
|
|
+ cache: make(map[string]*awssm.GetSecretValueOutput),
|
|
|
+ }
|
|
|
+ for k, v := range cachedCases {
|
|
|
sm.client = v.fakeClient
|
|
|
out, err := sm.GetSecret(context.Background(), *v.remoteRef)
|
|
|
if !ErrorContains(err, v.expectError) {
|
|
|
@@ -174,6 +228,9 @@ func TestSecretsManagerGetSecret(t *testing.T) {
|
|
|
if err == nil && string(out) != v.expectedSecret {
|
|
|
t.Errorf("[%d] unexpected secret: expected %s, got %s", k, v.expectedSecret, string(out))
|
|
|
}
|
|
|
+ if v.expectedCounter != nil && v.fakeClient.ExecutionCounter != *v.expectedCounter {
|
|
|
+ t.Errorf("[%d] unexpected counter value: expected %d, got %d", k, v.expectedCounter, v.fakeClient.ExecutionCounter)
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -184,6 +241,14 @@ func TestGetSecretMap(t *testing.T) {
|
|
|
smtc.expectedData["foo"] = []byte("bar")
|
|
|
}
|
|
|
|
|
|
+ // good case: caching
|
|
|
+ cachedMap := func(smtc *secretsManagerTestCase) {
|
|
|
+ smtc.apiOutput.SecretString = aws.String(`{"foo":"bar", "plus": "one"}`)
|
|
|
+ smtc.expectedData["foo"] = []byte("bar")
|
|
|
+ smtc.expectedData["plus"] = []byte("one")
|
|
|
+ smtc.expectedCounter = aws.Int(1)
|
|
|
+ }
|
|
|
+
|
|
|
// bad case: invalid json
|
|
|
setInvalidJSON := func(smtc *secretsManagerTestCase) {
|
|
|
smtc.apiOutput.SecretString = aws.String(`-----------------`)
|
|
|
@@ -194,11 +259,14 @@ func TestGetSecretMap(t *testing.T) {
|
|
|
makeValidSecretsManagerTestCaseCustom(setDeserialization),
|
|
|
makeValidSecretsManagerTestCaseCustom(setAPIErr),
|
|
|
makeValidSecretsManagerTestCaseCustom(setInvalidJSON),
|
|
|
+ makeValidSecretsManagerTestCaseCustom(cachedMap),
|
|
|
}
|
|
|
|
|
|
- sm := SecretsManager{}
|
|
|
for k, v := range successCases {
|
|
|
- sm.client = v.fakeClient
|
|
|
+ sm := SecretsManager{
|
|
|
+ cache: make(map[string]*awssm.GetSecretValueOutput),
|
|
|
+ client: v.fakeClient,
|
|
|
+ }
|
|
|
out, err := sm.GetSecretMap(context.Background(), *v.remoteRef)
|
|
|
if !ErrorContains(err, v.expectError) {
|
|
|
t.Errorf("[%d] unexpected error: %s, expected: '%s'", k, err.Error(), v.expectError)
|
|
|
@@ -206,6 +274,9 @@ func TestGetSecretMap(t *testing.T) {
|
|
|
if err == nil && !cmp.Equal(out, v.expectedData) {
|
|
|
t.Errorf("[%d] unexpected secret data: expected %#v, got %#v", k, v.expectedData, out)
|
|
|
}
|
|
|
+ if v.expectedCounter != nil && v.fakeClient.ExecutionCounter != *v.expectedCounter {
|
|
|
+ t.Errorf("[%d] unexpected counter value: expected %d, got %d", k, v.expectedCounter, v.fakeClient.ExecutionCounter)
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|