| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299 |
- /*
- Copyright © The ESO Authors
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- https://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
- package mysterybox
- import (
- "context"
- "encoding/json"
- "strconv"
- "sync"
- "testing"
- "time"
- "github.com/nebius/gosdk/auth"
- tassert "github.com/stretchr/testify/assert"
- trequire "github.com/stretchr/testify/require"
- clocktesting "k8s.io/utils/clock/testing"
- "github.com/external-secrets/external-secrets/providers/v1/nebius/common/sdk/iam"
- )
- type tokenTestEnv struct {
- ctx context.Context
- clk *clocktesting.FakeClock
- fakeTokenExchanger *iam.FakeTokenExchanger
- cachedTokenGetter *CachedTokenGetter
- }
- func newTokenTestEnv(t *testing.T) *tokenTestEnv {
- t.Helper()
- clk := clocktesting.NewFakeClock(time.Unix(0, 0))
- ex := &iam.FakeTokenExchanger{}
- svc, err := NewCachedTokenGetter(10, ex, clk)
- trequire.NoError(t, err)
- return &tokenTestEnv{ctx: context.Background(), clk: clk, fakeTokenExchanger: ex, cachedTokenGetter: svc}
- }
- func buildSubjectCredsJSON(t *testing.T, privateKey, keyID, subject string) string {
- t.Helper()
- b, err := json.Marshal(&auth.ServiceAccountCredentials{
- SubjectCredentials: auth.SubjectCredentials{
- PrivateKey: privateKey,
- KeyID: keyID,
- Subject: subject,
- Issuer: subject,
- },
- })
- trequire.NoError(t, err)
- return string(b)
- }
- func TestGetToken_CachesUntilTenPercentLeft(t *testing.T) {
- t.Parallel()
- env := newTokenTestEnv(t)
- ctx := env.ctx
- creds := buildSubjectCredsJSON(t, "priv-A", "kid-A", "sa-A")
- token1, err := env.cachedTokenGetter.GetToken(ctx, "api.example", creds, nil)
- tassert.NoError(t, err)
- tassert.Equal(t, "token-1", token1)
- tassert.Equal(t, int64(1), env.fakeTokenExchanger.Calls.Load())
- // add 5 seconds (remaining > 10%)
- addSecondsToClock(env.clk, 5)
- token2, err := env.cachedTokenGetter.GetToken(ctx, "api.example", creds, nil)
- tassert.NoError(t, err)
- tassert.Equal(t, token1, token2)
- tassert.Equal(t, int64(1), env.fakeTokenExchanger.Calls.Load())
- // after >90% elapsed -> should refresh
- addSecondsToClock(env.clk, 91) // total 96s
- token3, err := env.cachedTokenGetter.GetToken(ctx, "api.example", creds, nil)
- tassert.NoError(t, err)
- tassert.NotEqual(t, token1, token3)
- tassert.Equal(t, int64(2), env.fakeTokenExchanger.Calls.Load())
- }
- func TestGetToken_SeparateCacheEntriesPerSubjectCreds(t *testing.T) {
- t.Parallel()
- env := newTokenTestEnv(t)
- ctx := env.ctx
- credsA := buildSubjectCredsJSON(t, "priv-A", "kid-A", "sa-A")
- credsB := buildSubjectCredsJSON(t, "priv-B", "kid-B", "sa-B")
- tokenA1, err := env.cachedTokenGetter.GetToken(ctx, "api.example", credsA, nil)
- tassert.NoError(t, err)
- tassert.Equal(t, "token-1", tokenA1)
- tokenB1, err := env.cachedTokenGetter.GetToken(ctx, "api.example", credsB, nil)
- tassert.NoError(t, err)
- tassert.Equal(t, "token-2", tokenB1)
- tassert.Equal(t, int64(2), env.fakeTokenExchanger.Calls.Load())
- // check token cached
- addSecondsToClock(env.clk, 1)
- tokA2, err := env.cachedTokenGetter.GetToken(ctx, "api.example", credsA, nil)
- tassert.NoError(t, err)
- tassert.Equal(t, tokenA1, tokA2)
- tassert.Equal(t, int64(2), env.fakeTokenExchanger.Calls.Load())
- }
- func TestGetToken_InvalidSubjectCreds_ReturnsError(t *testing.T) {
- t.Parallel()
- env := newTokenTestEnv(t)
- _, err := env.cachedTokenGetter.GetToken(env.ctx, "api.example", "not a json", nil)
- tassert.Error(t, err)
- }
- func addSecondsToClock(clk *clocktesting.FakeClock, second time.Duration) {
- clk.SetTime(clk.Now().Add(second * time.Second))
- }
- func TestGetToken_SeparateCacheEntriesPerApiDomain(t *testing.T) {
- t.Parallel()
- env := newTokenTestEnv(t)
- ctx := env.ctx
- creds := buildSubjectCredsJSON(t, "priv-A", "kid-A", "sa-A")
- tokA1, err := env.cachedTokenGetter.GetToken(ctx, "api.one", creds, nil)
- tassert.NoError(t, err)
- tassert.Equal(t, "token-1", tokA1)
- tokB1, err := env.cachedTokenGetter.GetToken(ctx, "api.two", creds, nil)
- tassert.NoError(t, err)
- tassert.Equal(t, "token-2", tokB1)
- tassert.NotEqual(t, tokA1, tokB1)
- tassert.Equal(t, int64(2), env.fakeTokenExchanger.Calls.Load())
- tokA2, err := env.cachedTokenGetter.GetToken(ctx, "api.one", creds, nil)
- tassert.NoError(t, err)
- tassert.Equal(t, tokA1, tokA2)
- tassert.Equal(t, int64(2), env.fakeTokenExchanger.Calls.Load())
- }
- func TestGetToken_LRUEviction(t *testing.T) {
- t.Parallel()
- clk := clocktesting.NewFakeClock(time.Unix(0, 0))
- ex := &iam.FakeTokenExchanger{}
- svc, err := NewCachedTokenGetter(2, ex, clk)
- tassert.NoError(t, err)
- ctx := context.Background()
- creds := buildSubjectCredsJSON(t, "priv-A", "kid-A", "sa-A")
- tok1, err := svc.GetToken(ctx, "api.first", creds, nil)
- tassert.NoError(t, err)
- tassert.Equal(t, "token-1", tok1)
- tok2, err := svc.GetToken(ctx, "api.second", creds, nil)
- tassert.NoError(t, err)
- tassert.Equal(t, "token-2", tok2)
- tassert.Equal(t, int64(2), ex.Calls.Load())
- tok1again, err := svc.GetToken(ctx, "api.first", creds, nil)
- tassert.NoError(t, err)
- tassert.Equal(t, tok1, tok1again)
- tassert.Equal(t, int64(2), ex.Calls.Load())
- tok3, err := svc.GetToken(ctx, "api.third", creds, nil)
- tassert.NoError(t, err)
- tassert.Equal(t, "token-3", tok3)
- tassert.Equal(t, int64(3), ex.Calls.Load())
- secondAgain, err := svc.GetToken(ctx, "api.second", creds, nil)
- tassert.NoError(t, err)
- tassert.Equal(t, "token-4", secondAgain)
- tassert.Equal(t, int64(4), ex.Calls.Load())
- }
- func TestGetToken_AfterExpiration_Refreshes(t *testing.T) {
- t.Parallel()
- env := newTokenTestEnv(t)
- ctx := env.ctx
- creds := buildSubjectCredsJSON(t, "priv-A", "kid-A", "sa-A")
- _, err := env.cachedTokenGetter.GetToken(ctx, "api.example", creds, nil)
- tassert.NoError(t, err)
- addSecondsToClock(env.clk, 101)
- tok2, err := env.cachedTokenGetter.GetToken(ctx, "api.example", creds, nil)
- tassert.NoError(t, err)
- tassert.Equal(t, int64(2), env.fakeTokenExchanger.Calls.Load())
- tassert.Equal(t, "token-2", tok2)
- }
- func TestGetToken_CacheKeyChangesOnKeyRotation(t *testing.T) {
- t.Parallel()
- env := newTokenTestEnv(t)
- ctx := env.ctx
- base := buildSubjectCredsJSON(t, "priv-A", "kid-A", "sa-A")
- rotatedKeyID := buildSubjectCredsJSON(t, "priv-A", "kid-B", "sa-A")
- rotatedPriv := buildSubjectCredsJSON(t, "priv-B", "kid-A", "sa-A")
- rotatedSubject := buildSubjectCredsJSON(t, "priv-A", "kid-A", "sa-B")
- t1, _ := env.cachedTokenGetter.GetToken(ctx, "api", base, nil)
- t2, _ := env.cachedTokenGetter.GetToken(ctx, "api", rotatedKeyID, nil)
- t3, _ := env.cachedTokenGetter.GetToken(ctx, "api", rotatedPriv, nil)
- t4, _ := env.cachedTokenGetter.GetToken(ctx, "api", rotatedSubject, nil)
- tassert.NotEqual(t, t1, t2)
- tassert.NotEqual(t, t1, t3)
- tassert.NotEqual(t, t1, t4)
- tassert.Equal(t, int64(4), env.fakeTokenExchanger.Calls.Load())
- }
- func TestGetToken_ExchangerErrorIsWrapped(t *testing.T) {
- t.Parallel()
- clk := clocktesting.NewFakeClock(time.Unix(0, 0))
- svc, err := NewCachedTokenGetter(10, &iam.FakeTokenExchanger{ReturnError: true}, clk)
- trequire.NoError(t, err)
- _, err = svc.GetToken(context.Background(), "api", buildSubjectCredsJSON(t, "p", "k", "s"), nil)
- tassert.Error(t, err)
- tassert.Contains(t, err.Error(), "could not exchange creds to iam token:")
- }
- func TestGetToken_Singleflight_DedupesConcurrentSameKey(t *testing.T) {
- t.Parallel()
- clk := clocktesting.NewFakeClock(time.Unix(0, 0))
- ex := &iam.FakeTokenExchanger{}
- svc, err := NewCachedTokenGetter(10, ex, clk)
- trequire.NoError(t, err)
- creds := buildSubjectCredsJSON(t, "priv-A", "kid-A", "sa-A")
- const n = 50
- start := make(chan struct{})
- var wg sync.WaitGroup
- wg.Add(n)
- tokens := make([]string, n)
- errs := make([]error, n)
- for i := range n {
- go func() {
- defer wg.Done()
- <-start
- tok, err := svc.GetToken(context.Background(), "api.example", creds, nil)
- tokens[i] = tok
- errs[i] = err
- }()
- }
- close(start)
- wg.Wait()
- for i := range n {
- tassert.NoError(t, errs[i])
- tassert.Equal(t, tokens[0], tokens[i])
- }
- tassert.Equal(t, int64(1), ex.Calls.Load())
- }
- func TestGetToken_ConcurrentDifferentKeys_NoRaceAndWorks(t *testing.T) {
- t.Parallel()
- clk := clocktesting.NewFakeClock(time.Unix(0, 0))
- ex := &iam.FakeTokenExchanger{}
- svc, err := NewCachedTokenGetter(2, ex, clk)
- trequire.NoError(t, err)
- creds := buildSubjectCredsJSON(t, "priv-A", "kid-A", "sa-A")
- const n = 50
- start := make(chan struct{})
- var wg sync.WaitGroup
- wg.Add(n)
- for i := range n {
- go func() {
- defer wg.Done()
- <-start
- domain := "api." + strconv.Itoa(i%5)
- _, err := svc.GetToken(context.Background(), domain, creds, nil)
- tassert.NoError(t, err)
- }()
- }
- close(start)
- wg.Wait()
- tassert.GreaterOrEqual(t, ex.Calls.Load(), int64(1)) // lru cache is small, no guarantees
- }
|