Browse Source

test: use `T.Setenv` to set env vars in tests (#1611)

This commit replaces `os.Setenv` with `t.Setenv` in tests. The
environment variable is automatically restored to its original value
when the test and all its subtests complete.

Reference: https://pkg.go.dev/testing#T.Setenv
Signed-off-by: Eng Zer Jun <engzerjun@gmail.com>
Eng Zer Jun 3 years ago
parent
commit
0c9efa67b0

+ 5 - 15
pkg/provider/aws/auth/auth_test.go

@@ -16,7 +16,6 @@ package auth
 
 import (
 	"context"
-	"os"
 	"strings"
 	"testing"
 	"time"
@@ -423,7 +422,7 @@ func testRow(t *testing.T, row TestSessionRow) {
 		assert.Nil(t, err)
 	}
 	for k, v := range row.env {
-		os.Setenv(k, v)
+		t.Setenv(k, v)
 	}
 	if row.sa != nil {
 		err := kc.Create(context.Background(), row.sa)
@@ -436,11 +435,6 @@ func testRow(t *testing.T, row TestSessionRow) {
 		},
 	})
 	assert.Nil(t, err)
-	defer func() {
-		for k := range row.env {
-			os.Unsetenv(k)
-		}
-	}()
 	s, err := New(context.Background(), row.store, kc, row.namespace, row.stsProvider, row.jwtProvider)
 	if !ErrorContains(err, row.expectErr) {
 		t.Errorf("expected error %s but found %s", row.expectErr, err.Error())
@@ -460,10 +454,8 @@ func testRow(t *testing.T, row TestSessionRow) {
 
 func TestSMEnvCredentials(t *testing.T) {
 	k8sClient := clientfake.NewClientBuilder().Build()
-	os.Setenv("AWS_SECRET_ACCESS_KEY", "1111")
-	os.Setenv("AWS_ACCESS_KEY_ID", "2222")
-	defer os.Unsetenv("AWS_SECRET_ACCESS_KEY")
-	defer os.Unsetenv("AWS_ACCESS_KEY_ID")
+	t.Setenv("AWS_SECRET_ACCESS_KEY", "1111")
+	t.Setenv("AWS_ACCESS_KEY_ID", "2222")
 	s, err := New(context.Background(), &esv1beta1.SecretStore{
 		Spec: esv1beta1.SecretStoreSpec{
 			Provider: &esv1beta1.SecretStoreProvider{
@@ -500,10 +492,8 @@ func TestSMAssumeRole(t *testing.T) {
 			}, nil
 		},
 	}
-	os.Setenv("AWS_SECRET_ACCESS_KEY", "1111")
-	os.Setenv("AWS_ACCESS_KEY_ID", "2222")
-	defer os.Unsetenv("AWS_SECRET_ACCESS_KEY")
-	defer os.Unsetenv("AWS_ACCESS_KEY_ID")
+	t.Setenv("AWS_SECRET_ACCESS_KEY", "1111")
+	t.Setenv("AWS_ACCESS_KEY_ID", "2222")
 	s, err := New(context.Background(), &esv1beta1.SecretStore{
 		Spec: esv1beta1.SecretStoreSpec{
 			Provider: &esv1beta1.SecretStoreProvider{

+ 1 - 6
pkg/provider/aws/auth/resolver_test.go

@@ -14,7 +14,6 @@ limitations under the License.
 package auth
 
 import (
-	"os"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
@@ -44,7 +43,7 @@ func TestResolver(t *testing.T) {
 	}
 
 	for _, item := range tbl {
-		os.Setenv(item.env, item.url)
+		t.Setenv(item.env, item.url)
 	}
 
 	f := ResolveEndpoint()
@@ -54,8 +53,4 @@ func TestResolver(t *testing.T) {
 		assert.Nil(t, err)
 		assert.Equal(t, item.url, ep.URL)
 	}
-
-	for _, item := range tbl {
-		os.Unsetenv(item.env)
-	}
 }

+ 2 - 5
pkg/provider/aws/provider_test.go

@@ -17,7 +17,6 @@ package aws
 import (
 	"context"
 	"fmt"
-	"os"
 	"strings"
 	"testing"
 
@@ -45,10 +44,8 @@ func TestProvider(t *testing.T) {
 	// inject fake static credentials because we test
 	// if we are able to get credentials when constructing the client
 	// see #415
-	os.Setenv("AWS_ACCESS_KEY_ID", "1234")
-	os.Setenv("AWS_SECRET_ACCESS_KEY", "1234")
-	defer os.Unsetenv("AWS_ACCESS_KEY_ID")
-	defer os.Unsetenv("AWS_SECRET_ACCESS_KEY")
+	t.Setenv("AWS_ACCESS_KEY_ID", "1234")
+	t.Setenv("AWS_SECRET_ACCESS_KEY", "1234")
 
 	tbl := []struct {
 		test    string

+ 10 - 24
pkg/provider/azure/keyvault/keyvault_auth_test.go

@@ -101,8 +101,7 @@ func TestGetAuthorizorForWorkloadIdentity(t *testing.T) {
 		name       string
 		provider   *esv1beta1.AzureKVProvider
 		k8sObjects []client.Object
-		prep       func()
-		cleanup    func()
+		prep       func(*testing.T)
 		expErr     string
 	}
 
@@ -120,30 +119,20 @@ func TestGetAuthorizorForWorkloadIdentity(t *testing.T) {
 		{
 			name:     "missing workload identity token file",
 			provider: &esv1beta1.AzureKVProvider{},
-			prep: func() {
-				os.Setenv("AZURE_CLIENT_ID", clientID)
-				os.Setenv("AZURE_TENANT_ID", tenantID)
-				os.Setenv("AZURE_FEDERATED_TOKEN_FILE", "invalid file")
-			},
-			cleanup: func() {
-				os.Unsetenv("AZURE_CLIENT_ID")
-				os.Unsetenv("AZURE_TENANT_ID")
-				os.Unsetenv("AZURE_FEDERATED_TOKEN_FILE")
+			prep: func(t *testing.T) {
+				t.Setenv("AZURE_CLIENT_ID", clientID)
+				t.Setenv("AZURE_TENANT_ID", tenantID)
+				t.Setenv("AZURE_FEDERATED_TOKEN_FILE", "invalid file")
 			},
 			expErr: "unable to read token file invalid file: open invalid file: no such file or directory",
 		},
 		{
 			name:     "correct workload identity",
 			provider: &esv1beta1.AzureKVProvider{},
-			prep: func() {
-				os.Setenv("AZURE_CLIENT_ID", clientID)
-				os.Setenv("AZURE_TENANT_ID", tenantID)
-				os.Setenv("AZURE_FEDERATED_TOKEN_FILE", tokenFile)
-			},
-			cleanup: func() {
-				os.Unsetenv("AZURE_CLIENT_ID")
-				os.Unsetenv("AZURE_TENANT_ID")
-				os.Unsetenv("AZURE_FEDERATED_TOKEN_FILE")
+			prep: func(t *testing.T) {
+				t.Setenv("AZURE_CLIENT_ID", clientID)
+				t.Setenv("AZURE_TENANT_ID", tenantID)
+				t.Setenv("AZURE_FEDERATED_TOKEN_FILE", tokenFile)
 			},
 		},
 		{
@@ -200,10 +189,7 @@ func TestGetAuthorizorForWorkloadIdentity(t *testing.T) {
 				return &tokenProvider{accessToken: azAccessToken}, nil
 			}
 			if row.prep != nil {
-				row.prep()
-			}
-			if row.cleanup != nil {
-				defer row.cleanup()
+				row.prep(t)
 			}
 			authorizer, err := az.authorizerForWorkloadIdentity(context.Background(), tokenProvider)
 			if row.expErr == "" {