token_getter_test.go 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. /*
  2. Copyright © The ESO Authors
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. https://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package mysterybox
  14. import (
  15. "context"
  16. "encoding/json"
  17. "strconv"
  18. "sync"
  19. "testing"
  20. "time"
  21. "github.com/nebius/gosdk/auth"
  22. tassert "github.com/stretchr/testify/assert"
  23. trequire "github.com/stretchr/testify/require"
  24. clocktesting "k8s.io/utils/clock/testing"
  25. "github.com/external-secrets/external-secrets/providers/v1/nebius/common/sdk/iam"
  26. )
  27. type tokenTestEnv struct {
  28. ctx context.Context
  29. clk *clocktesting.FakeClock
  30. fakeTokenExchanger *iam.FakeTokenExchanger
  31. cachedTokenGetter *CachedTokenGetter
  32. }
  33. func newTokenTestEnv(t *testing.T) *tokenTestEnv {
  34. t.Helper()
  35. clk := clocktesting.NewFakeClock(time.Unix(0, 0))
  36. ex := &iam.FakeTokenExchanger{}
  37. svc, err := NewCachedTokenGetter(10, ex, clk)
  38. trequire.NoError(t, err)
  39. return &tokenTestEnv{ctx: context.Background(), clk: clk, fakeTokenExchanger: ex, cachedTokenGetter: svc}
  40. }
  41. func buildSubjectCredsJSON(t *testing.T, privateKey, keyID, subject string) string {
  42. t.Helper()
  43. b, err := json.Marshal(&auth.ServiceAccountCredentials{
  44. SubjectCredentials: auth.SubjectCredentials{
  45. PrivateKey: privateKey,
  46. KeyID: keyID,
  47. Subject: subject,
  48. Issuer: subject,
  49. },
  50. })
  51. trequire.NoError(t, err)
  52. return string(b)
  53. }
  54. func TestGetToken_CachesUntilTenPercentLeft(t *testing.T) {
  55. t.Parallel()
  56. env := newTokenTestEnv(t)
  57. ctx := env.ctx
  58. creds := buildSubjectCredsJSON(t, "priv-A", "kid-A", "sa-A")
  59. token1, err := env.cachedTokenGetter.GetToken(ctx, "api.example", creds, nil)
  60. tassert.NoError(t, err)
  61. tassert.Equal(t, "token-1", token1)
  62. tassert.Equal(t, int64(1), env.fakeTokenExchanger.Calls.Load())
  63. // add 5 seconds (remaining > 10%)
  64. addSecondsToClock(env.clk, 5)
  65. token2, err := env.cachedTokenGetter.GetToken(ctx, "api.example", creds, nil)
  66. tassert.NoError(t, err)
  67. tassert.Equal(t, token1, token2)
  68. tassert.Equal(t, int64(1), env.fakeTokenExchanger.Calls.Load())
  69. // after >90% elapsed -> should refresh
  70. addSecondsToClock(env.clk, 91) // total 96s
  71. token3, err := env.cachedTokenGetter.GetToken(ctx, "api.example", creds, nil)
  72. tassert.NoError(t, err)
  73. tassert.NotEqual(t, token1, token3)
  74. tassert.Equal(t, int64(2), env.fakeTokenExchanger.Calls.Load())
  75. }
  76. func TestGetToken_SeparateCacheEntriesPerSubjectCreds(t *testing.T) {
  77. t.Parallel()
  78. env := newTokenTestEnv(t)
  79. ctx := env.ctx
  80. credsA := buildSubjectCredsJSON(t, "priv-A", "kid-A", "sa-A")
  81. credsB := buildSubjectCredsJSON(t, "priv-B", "kid-B", "sa-B")
  82. tokenA1, err := env.cachedTokenGetter.GetToken(ctx, "api.example", credsA, nil)
  83. tassert.NoError(t, err)
  84. tassert.Equal(t, "token-1", tokenA1)
  85. tokenB1, err := env.cachedTokenGetter.GetToken(ctx, "api.example", credsB, nil)
  86. tassert.NoError(t, err)
  87. tassert.Equal(t, "token-2", tokenB1)
  88. tassert.Equal(t, int64(2), env.fakeTokenExchanger.Calls.Load())
  89. // check token cached
  90. addSecondsToClock(env.clk, 1)
  91. tokA2, err := env.cachedTokenGetter.GetToken(ctx, "api.example", credsA, nil)
  92. tassert.NoError(t, err)
  93. tassert.Equal(t, tokenA1, tokA2)
  94. tassert.Equal(t, int64(2), env.fakeTokenExchanger.Calls.Load())
  95. }
  96. func TestGetToken_InvalidSubjectCreds_ReturnsError(t *testing.T) {
  97. t.Parallel()
  98. env := newTokenTestEnv(t)
  99. _, err := env.cachedTokenGetter.GetToken(env.ctx, "api.example", "not a json", nil)
  100. tassert.Error(t, err)
  101. }
  102. func addSecondsToClock(clk *clocktesting.FakeClock, second time.Duration) {
  103. clk.SetTime(clk.Now().Add(second * time.Second))
  104. }
  105. func TestGetToken_SeparateCacheEntriesPerApiDomain(t *testing.T) {
  106. t.Parallel()
  107. env := newTokenTestEnv(t)
  108. ctx := env.ctx
  109. creds := buildSubjectCredsJSON(t, "priv-A", "kid-A", "sa-A")
  110. tokA1, err := env.cachedTokenGetter.GetToken(ctx, "api.one", creds, nil)
  111. tassert.NoError(t, err)
  112. tassert.Equal(t, "token-1", tokA1)
  113. tokB1, err := env.cachedTokenGetter.GetToken(ctx, "api.two", creds, nil)
  114. tassert.NoError(t, err)
  115. tassert.Equal(t, "token-2", tokB1)
  116. tassert.NotEqual(t, tokA1, tokB1)
  117. tassert.Equal(t, int64(2), env.fakeTokenExchanger.Calls.Load())
  118. tokA2, err := env.cachedTokenGetter.GetToken(ctx, "api.one", creds, nil)
  119. tassert.NoError(t, err)
  120. tassert.Equal(t, tokA1, tokA2)
  121. tassert.Equal(t, int64(2), env.fakeTokenExchanger.Calls.Load())
  122. }
  123. func TestGetToken_LRUEviction(t *testing.T) {
  124. t.Parallel()
  125. clk := clocktesting.NewFakeClock(time.Unix(0, 0))
  126. ex := &iam.FakeTokenExchanger{}
  127. svc, err := NewCachedTokenGetter(2, ex, clk)
  128. tassert.NoError(t, err)
  129. ctx := context.Background()
  130. creds := buildSubjectCredsJSON(t, "priv-A", "kid-A", "sa-A")
  131. tok1, err := svc.GetToken(ctx, "api.first", creds, nil)
  132. tassert.NoError(t, err)
  133. tassert.Equal(t, "token-1", tok1)
  134. tok2, err := svc.GetToken(ctx, "api.second", creds, nil)
  135. tassert.NoError(t, err)
  136. tassert.Equal(t, "token-2", tok2)
  137. tassert.Equal(t, int64(2), ex.Calls.Load())
  138. tok1again, err := svc.GetToken(ctx, "api.first", creds, nil)
  139. tassert.NoError(t, err)
  140. tassert.Equal(t, tok1, tok1again)
  141. tassert.Equal(t, int64(2), ex.Calls.Load())
  142. tok3, err := svc.GetToken(ctx, "api.third", creds, nil)
  143. tassert.NoError(t, err)
  144. tassert.Equal(t, "token-3", tok3)
  145. tassert.Equal(t, int64(3), ex.Calls.Load())
  146. secondAgain, err := svc.GetToken(ctx, "api.second", creds, nil)
  147. tassert.NoError(t, err)
  148. tassert.Equal(t, "token-4", secondAgain)
  149. tassert.Equal(t, int64(4), ex.Calls.Load())
  150. }
  151. func TestGetToken_AfterExpiration_Refreshes(t *testing.T) {
  152. t.Parallel()
  153. env := newTokenTestEnv(t)
  154. ctx := env.ctx
  155. creds := buildSubjectCredsJSON(t, "priv-A", "kid-A", "sa-A")
  156. _, err := env.cachedTokenGetter.GetToken(ctx, "api.example", creds, nil)
  157. tassert.NoError(t, err)
  158. addSecondsToClock(env.clk, 101)
  159. tok2, err := env.cachedTokenGetter.GetToken(ctx, "api.example", creds, nil)
  160. tassert.NoError(t, err)
  161. tassert.Equal(t, int64(2), env.fakeTokenExchanger.Calls.Load())
  162. tassert.Equal(t, "token-2", tok2)
  163. }
  164. func TestGetToken_CacheKeyChangesOnKeyRotation(t *testing.T) {
  165. t.Parallel()
  166. env := newTokenTestEnv(t)
  167. ctx := env.ctx
  168. base := buildSubjectCredsJSON(t, "priv-A", "kid-A", "sa-A")
  169. rotatedKeyID := buildSubjectCredsJSON(t, "priv-A", "kid-B", "sa-A")
  170. rotatedPriv := buildSubjectCredsJSON(t, "priv-B", "kid-A", "sa-A")
  171. rotatedSubject := buildSubjectCredsJSON(t, "priv-A", "kid-A", "sa-B")
  172. t1, _ := env.cachedTokenGetter.GetToken(ctx, "api", base, nil)
  173. t2, _ := env.cachedTokenGetter.GetToken(ctx, "api", rotatedKeyID, nil)
  174. t3, _ := env.cachedTokenGetter.GetToken(ctx, "api", rotatedPriv, nil)
  175. t4, _ := env.cachedTokenGetter.GetToken(ctx, "api", rotatedSubject, nil)
  176. tassert.NotEqual(t, t1, t2)
  177. tassert.NotEqual(t, t1, t3)
  178. tassert.NotEqual(t, t1, t4)
  179. tassert.Equal(t, int64(4), env.fakeTokenExchanger.Calls.Load())
  180. }
  181. func TestGetToken_ExchangerErrorIsWrapped(t *testing.T) {
  182. t.Parallel()
  183. clk := clocktesting.NewFakeClock(time.Unix(0, 0))
  184. svc, err := NewCachedTokenGetter(10, &iam.FakeTokenExchanger{ReturnError: true}, clk)
  185. trequire.NoError(t, err)
  186. _, err = svc.GetToken(context.Background(), "api", buildSubjectCredsJSON(t, "p", "k", "s"), nil)
  187. tassert.Error(t, err)
  188. tassert.Contains(t, err.Error(), "could not exchange creds to iam token:")
  189. }
  190. func TestGetToken_Singleflight_DedupesConcurrentSameKey(t *testing.T) {
  191. t.Parallel()
  192. clk := clocktesting.NewFakeClock(time.Unix(0, 0))
  193. ex := &iam.FakeTokenExchanger{}
  194. svc, err := NewCachedTokenGetter(10, ex, clk)
  195. trequire.NoError(t, err)
  196. creds := buildSubjectCredsJSON(t, "priv-A", "kid-A", "sa-A")
  197. const n = 50
  198. start := make(chan struct{})
  199. var wg sync.WaitGroup
  200. wg.Add(n)
  201. tokens := make([]string, n)
  202. errs := make([]error, n)
  203. for i := range n {
  204. go func() {
  205. defer wg.Done()
  206. <-start
  207. tok, err := svc.GetToken(context.Background(), "api.example", creds, nil)
  208. tokens[i] = tok
  209. errs[i] = err
  210. }()
  211. }
  212. close(start)
  213. wg.Wait()
  214. for i := range n {
  215. tassert.NoError(t, errs[i])
  216. tassert.Equal(t, tokens[0], tokens[i])
  217. }
  218. tassert.Equal(t, int64(1), ex.Calls.Load())
  219. }
  220. func TestGetToken_ConcurrentDifferentKeys_NoRaceAndWorks(t *testing.T) {
  221. t.Parallel()
  222. clk := clocktesting.NewFakeClock(time.Unix(0, 0))
  223. ex := &iam.FakeTokenExchanger{}
  224. svc, err := NewCachedTokenGetter(2, ex, clk)
  225. trequire.NoError(t, err)
  226. creds := buildSubjectCredsJSON(t, "priv-A", "kid-A", "sa-A")
  227. const n = 50
  228. start := make(chan struct{})
  229. var wg sync.WaitGroup
  230. wg.Add(n)
  231. for i := range n {
  232. go func() {
  233. defer wg.Done()
  234. <-start
  235. domain := "api." + strconv.Itoa(i%5)
  236. _, err := svc.GetToken(context.Background(), domain, creds, nil)
  237. tassert.NoError(t, err)
  238. }()
  239. }
  240. close(start)
  241. wg.Wait()
  242. tassert.GreaterOrEqual(t, ex.Calls.Load(), int64(1)) // lru cache is small, no guarantees
  243. }