fake.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. /*
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. */
  12. package fake
  13. import (
  14. "context"
  15. "github.com/Azure/azure-sdk-for-go/profiles/latest/keyvault/keyvault"
  16. "github.com/google/uuid"
  17. mock "github.com/stretchr/testify/mock"
  18. )
  19. type secretData struct {
  20. item keyvault.SecretItem
  21. secretVersions map[string]keyvault.SecretBundle
  22. lastVersion string
  23. }
  24. type keyData struct {
  25. item keyvault.KeyItem
  26. keyVersions map[string]keyvault.KeyBundle
  27. lastVersion string
  28. }
  29. type AzureMock struct {
  30. mock.Mock
  31. knownSecrets map[string]map[string]*secretData
  32. knownKeys map[string]map[string]*keyData
  33. }
  34. func (m *AzureMock) AddSecret(vaultBaseURL, secretName, secretContent string, enabled bool) string {
  35. uid := uuid.NewString()
  36. m.AddSecretWithVersion(vaultBaseURL, secretName, uid, secretContent, enabled)
  37. return uid
  38. }
  39. func (m *AzureMock) AddSecretWithVersion(vaultBaseURL, secretName, secretVersion, secretContent string, enabled bool) {
  40. if m.knownSecrets == nil {
  41. m.knownSecrets = make(map[string]map[string]*secretData)
  42. }
  43. if m.knownSecrets[vaultBaseURL] == nil {
  44. m.knownSecrets[vaultBaseURL] = make(map[string]*secretData)
  45. }
  46. secretItemID := vaultBaseURL + secretName
  47. secretBundleID := secretItemID + "/" + secretVersion
  48. if m.knownSecrets[vaultBaseURL][secretName] == nil {
  49. m.knownSecrets[vaultBaseURL][secretName] = &secretData{
  50. item: newValidSecretItem(secretItemID, enabled),
  51. secretVersions: make(map[string]keyvault.SecretBundle),
  52. }
  53. } else {
  54. m.knownSecrets[vaultBaseURL][secretName].item.Attributes.Enabled = &enabled
  55. }
  56. m.knownSecrets[vaultBaseURL][secretName].secretVersions[secretVersion] = newValidSecretBundle(secretBundleID, secretContent)
  57. m.knownSecrets[vaultBaseURL][secretName].lastVersion = secretVersion
  58. }
  59. func newValidSecretBundle(secretBundleID, secretContent string) keyvault.SecretBundle {
  60. return keyvault.SecretBundle{
  61. Value: &secretContent,
  62. ID: &secretBundleID,
  63. }
  64. }
  65. func newValidSecretItem(secretItemID string, enabled bool) keyvault.SecretItem {
  66. return keyvault.SecretItem{
  67. ID: &secretItemID,
  68. Attributes: &keyvault.SecretAttributes{Enabled: &enabled},
  69. }
  70. }
  71. func (m *AzureMock) ExpectsGetSecret(ctx context.Context, vaultBaseURL, secretName, secretVersion string) {
  72. data := m.knownSecrets[vaultBaseURL][secretName]
  73. version := secretVersion
  74. if version == "" {
  75. version = data.lastVersion
  76. }
  77. returnValue := data.secretVersions[version]
  78. m.On("GetSecret", ctx, vaultBaseURL, secretName, secretVersion).Return(returnValue, nil)
  79. }
  80. func (m *AzureMock) ExpectsGetSecretsComplete(ctx context.Context, vaultBaseURL string, maxresults *int32) {
  81. secretMap := m.knownSecrets[vaultBaseURL]
  82. secretItems := make([]keyvault.SecretItem, len(secretMap))
  83. i := 0
  84. for _, value := range secretMap {
  85. secretItems[i] = value.item
  86. i++
  87. }
  88. firstPage := keyvault.SecretListResult{
  89. Value: &secretItems,
  90. NextLink: nil,
  91. }
  92. returnValue := keyvault.NewSecretListResultIterator(keyvault.NewSecretListResultPage(firstPage, func(context.Context, keyvault.SecretListResult) (keyvault.SecretListResult, error) {
  93. return keyvault.SecretListResult{}, nil
  94. }))
  95. m.On("GetSecretsComplete", ctx, vaultBaseURL, maxresults).Return(returnValue, nil)
  96. }
  97. func (m *AzureMock) AddKey(vaultBaseURL, keyName string, key *keyvault.JSONWebKey, enabled bool) string {
  98. uid := uuid.NewString()
  99. m.AddKeyWithVersion(vaultBaseURL, keyName, uid, key, enabled)
  100. return uid
  101. }
  102. func (m *AzureMock) AddKeyWithVersion(vaultBaseURL, keyName, keyVersion string, key *keyvault.JSONWebKey, enabled bool) {
  103. if m.knownKeys == nil {
  104. m.knownKeys = make(map[string]map[string]*keyData)
  105. }
  106. if m.knownKeys[vaultBaseURL] == nil {
  107. m.knownKeys[vaultBaseURL] = make(map[string]*keyData)
  108. }
  109. keyItemID := vaultBaseURL + keyName
  110. if m.knownKeys[vaultBaseURL][keyName] == nil {
  111. m.knownKeys[vaultBaseURL][keyName] = &keyData{
  112. item: newValidKeyItem(keyItemID, enabled),
  113. keyVersions: make(map[string]keyvault.KeyBundle),
  114. }
  115. } else {
  116. m.knownKeys[vaultBaseURL][keyName].item.Attributes.Enabled = &enabled
  117. }
  118. m.knownKeys[vaultBaseURL][keyName].keyVersions[keyVersion] = newValidKeyBundle(key)
  119. m.knownKeys[vaultBaseURL][keyName].lastVersion = keyVersion
  120. }
  121. func newValidKeyBundle(key *keyvault.JSONWebKey) keyvault.KeyBundle {
  122. return keyvault.KeyBundle{
  123. Key: key,
  124. }
  125. }
  126. func newValidKeyItem(keyItemID string, enabled bool) keyvault.KeyItem {
  127. return keyvault.KeyItem{
  128. Kid: &keyItemID,
  129. Attributes: &keyvault.KeyAttributes{Enabled: &enabled},
  130. }
  131. }
  132. func (m *AzureMock) ExpectsGetKey(ctx context.Context, vaultBaseURL, keyName, keyVersion string) {
  133. data := m.knownKeys[vaultBaseURL][keyName]
  134. version := keyVersion
  135. if version == "" {
  136. version = data.lastVersion
  137. }
  138. returnValue := data.keyVersions[version]
  139. m.On("GetKey", ctx, vaultBaseURL, keyName, keyVersion).Return(returnValue, nil)
  140. }
  141. func (m *AzureMock) ExpectsGetKeysComplete(ctx context.Context, vaultBaseURL string, maxresults *int32) {
  142. keyMap := m.knownKeys[vaultBaseURL]
  143. keyItems := make([]keyvault.KeyItem, len(keyMap))
  144. i := 0
  145. for _, value := range keyMap {
  146. keyItems[i] = value.item
  147. i++
  148. }
  149. firstPage := keyvault.KeyListResult{
  150. Value: &keyItems,
  151. NextLink: nil,
  152. }
  153. returnValue := keyvault.NewKeyListResultIterator(keyvault.NewKeyListResultPage(firstPage, func(context.Context, keyvault.KeyListResult) (keyvault.KeyListResult, error) {
  154. return keyvault.KeyListResult{}, nil
  155. }))
  156. m.On("GetKeysComplete", ctx, vaultBaseURL, maxresults).Return(returnValue, nil)
  157. }
  158. func (m *AzureMock) GetKey(ctx context.Context, vaultBaseURL, keyName, keyVersion string) (result keyvault.KeyBundle, err error) {
  159. args := m.Called(ctx, vaultBaseURL, keyName, keyVersion)
  160. return args.Get(0).(keyvault.KeyBundle), args.Error(1)
  161. }
  162. func (m *AzureMock) GetSecret(ctx context.Context, vaultBaseURL, secretName, secretVersion string) (result keyvault.SecretBundle, err error) {
  163. args := m.Called(ctx, vaultBaseURL, secretName, secretVersion)
  164. return args.Get(0).(keyvault.SecretBundle), args.Error(1)
  165. }
  166. func (m *AzureMock) GetCertificate(ctx context.Context, vaultBaseURL, certificateName, certificateVersion string) (result keyvault.CertificateBundle, err error) {
  167. args := m.Called(ctx, vaultBaseURL, certificateName, certificateVersion)
  168. return args.Get(0).(keyvault.CertificateBundle), args.Error(1)
  169. }
  170. func (m *AzureMock) GetSecretsComplete(ctx context.Context, vaultBaseURL string, maxresults *int32) (result keyvault.SecretListResultIterator, err error) {
  171. args := m.Called(ctx, vaultBaseURL, maxresults)
  172. return args.Get(0).(keyvault.SecretListResultIterator), args.Error(1)
  173. }
  174. func (m *AzureMock) GetKeysComplete(ctx context.Context, vaultBaseURL string, maxresults *int32) (result keyvault.KeyListResultIterator, err error) {
  175. args := m.Called(ctx, vaultBaseURL, maxresults)
  176. return args.Get(0).(keyvault.KeyListResultIterator), args.Error(1)
  177. }