Browse Source

fix(vault): Treat tokens expiring in <60s as expired (#3637)

* fix(vault): Treat tokens expiring in <60s as expired

Without this, it's possible to hit a TOCTOU issue where checkToken()
sees a valid token, but it expires before the actual operation is
performed. This condition is only reachable when the experimental
caching feature is enabled.

60 seconds was chosen as a sane (but arbitrary) value. It should be more
than enough to cover the amount of time between checkToken() and the
actual operation.

Signed-off-by: Andrew Gunnerson <andrew.gunnerson@elastic.co>

* ADOPTERS.md: Add Elastic

Signed-off-by: Andrew Gunnerson <andrew.gunnerson@elastic.co>

---------

Signed-off-by: Andrew Gunnerson <andrew.gunnerson@elastic.co>
Andrew Gunnerson 1 year ago
parent
commit
2053df7b7c
3 changed files with 89 additions and 0 deletions
  1. 1 0
      ADOPTERS.md
  2. 19 0
      pkg/provider/vault/auth.go
  3. 69 0
      pkg/provider/vault/auth_test.go

+ 1 - 0
ADOPTERS.md

@@ -7,6 +7,7 @@
 - [Container Solutions](http://container-solutions.com/)
 - [DaangnPay](https://www.daangnpay.com/)
 - [Epidemic Sound](https://www.epidemicsound.com/)
+- [Elastic](https://www.elastic.co/)
 - [Fivetran](https://www.fivetran.com)
 - [Form3](https://www.form3.tech/)
 - [GoTo](https://www.goto.com/)

+ 19 - 0
pkg/provider/vault/auth.go

@@ -16,6 +16,7 @@ package vault
 
 import (
 	"context"
+	"encoding/json"
 	"errors"
 	"fmt"
 
@@ -160,6 +161,24 @@ func checkToken(ctx context.Context, token util.Token) (bool, error) {
 	if tokenType == "batch" {
 		return false, nil
 	}
+	ttl, ok := resp.Data["ttl"]
+	if !ok {
+		return false, fmt.Errorf("no TTL found in response")
+	}
+	ttlInt, err := ttl.(json.Number).Int64()
+	if err != nil {
+		return false, fmt.Errorf("invalid token TTL: %v: %w", ttl, err)
+	}
+	expireTime, ok := resp.Data["expire_time"]
+	if !ok {
+		return false, fmt.Errorf("no expiration time found in response")
+	}
+	if ttlInt < 60 && expireTime != nil {
+		// Treat expirable tokens that are about to expire as already expired.
+		// This ensures that the token won't expire in between this check and
+		// performing the actual operation.
+		return false, nil
+	}
 	return true, nil
 }
 

+ 69 - 0
pkg/provider/vault/auth_test.go

@@ -16,6 +16,7 @@ package vault
 
 import (
 	"context"
+	"encoding/json"
 	"errors"
 	"testing"
 
@@ -208,3 +209,71 @@ func TestCheckTokenErrors(t *testing.T) {
 		})
 	}
 }
+
+func TestCheckTokenTtl(t *testing.T) {
+	cases := map[string]struct {
+		message string
+		secret  *vault.Secret
+		cache   bool
+	}{
+		"LongTTLExpirable": {
+			message: "should cache if expirable token expires far into the future",
+			secret: &vault.Secret{
+				Data: map[string]interface{}{
+					"expire_time": "2024-01-01T00:00:00.000000000Z",
+					"ttl":         json.Number("3600"),
+					"type":        "service",
+				},
+			},
+			cache: true,
+		},
+		"ShortTTLExpirable": {
+			message: "should not cache if expirable token is about to expire",
+			secret: &vault.Secret{
+				Data: map[string]interface{}{
+					"expire_time": "2024-01-01T00:00:00.000000000Z",
+					"ttl":         json.Number("5"),
+					"type":        "service",
+				},
+			},
+			cache: false,
+		},
+		"ZeroTTLExpirable": {
+			message: "should not cache if expirable token has TTL of 0",
+			secret: &vault.Secret{
+				Data: map[string]interface{}{
+					"expire_time": "2024-01-01T00:00:00.000000000Z",
+					"ttl":         json.Number("0"),
+					"type":        "service",
+				},
+			},
+			cache: false,
+		},
+		"NonExpirable": {
+			message: "should cache if token is non-expirable",
+			secret: &vault.Secret{
+				Data: map[string]interface{}{
+					"expire_time": nil,
+					"ttl":         json.Number("0"),
+					"type":        "service",
+				},
+			},
+			cache: true,
+		},
+	}
+
+	for name, tc := range cases {
+		t.Run(name, func(t *testing.T) {
+			token := fake.Token{
+				LookupSelfWithContextFn: func(ctx context.Context) (*vault.Secret, error) {
+					return tc.secret, nil
+				},
+			}
+
+			cached, err := checkToken(context.Background(), token)
+			if cached != tc.cache || err != nil {
+				t.Errorf("%v: err = %v", tc.message, err)
+			}
+		})
+	}
+}