Browse Source

fix: allow using UUID as vault and item name (#4490)

Signed-off-by: Gergely Brautigam <182850+Skarlso@users.noreply.github.com>
Gergely Brautigam 1 year ago
parent
commit
11f059a9a2

+ 35 - 15
pkg/provider/onepassword/fake/fake.go

@@ -48,13 +48,13 @@ func (mockClient *OnePasswordMockClient) GetVaults() ([]onepassword.Vault, error
 }
 
 // GetVault unused fake.
-func (mockClient *OnePasswordMockClient) GetVault(_ string) (*onepassword.Vault, error) {
-	return &onepassword.Vault{}, nil
+func (mockClient *OnePasswordMockClient) GetVault(uuid string) (*onepassword.Vault, error) {
+	return mockClient.GetVaultByTitle(uuid)
 }
 
 // GetVaultByUUID unused fake.
-func (mockClient *OnePasswordMockClient) GetVaultByUUID(_ string) (*onepassword.Vault, error) {
-	return &onepassword.Vault{}, nil
+func (mockClient *OnePasswordMockClient) GetVaultByUUID(uuid string) (*onepassword.Vault, error) {
+	return &mockClient.MockVaults[uuid][0], nil
 }
 
 // GetVaultByTitle returns a vault, you must preload, only one.
@@ -77,16 +77,7 @@ func (mockClient *OnePasswordMockClient) GetItems(vaultUUID string) ([]onepasswo
 
 // GetItem returns a *onepassword.Item, you must preload.
 func (mockClient *OnePasswordMockClient) GetItem(itemUUID, vaultUUID string) (*onepassword.Item, error) {
-	for _, item := range mockClient.MockItems[vaultUUID] {
-		if item.ID == itemUUID {
-			// load the fields that GetItemsByTitle does not
-			item.Fields = mockClient.MockItemFields[vaultUUID][itemUUID]
-
-			return &item, nil
-		}
-	}
-
-	return &onepassword.Item{}, errors.New("status 400: Invalid Item UUID")
+	return mockClient.GetItemByUUID(itemUUID, vaultUUID)
 }
 
 // GetItemByUUID returns a *onepassword.Item, you must preload.
@@ -100,7 +91,7 @@ func (mockClient *OnePasswordMockClient) GetItemByUUID(itemUUID, vaultUUID strin
 		}
 	}
 
-	return &onepassword.Item{}, errors.New("status 400: Invalid Item UUID")
+	return &onepassword.Item{}, errors.New("status 400: Invalid GetItemByUUID")
 }
 
 // GetItemByTitle unused fake.
@@ -233,6 +224,16 @@ func (mockClient *OnePasswordMockClient) AddPredictableVault(name string) *OnePa
 	return mockClient
 }
 
+// AddPredictableVaultUUID adds vaults to the mock client in a predictable way.
+func (mockClient *OnePasswordMockClient) AddPredictableVaultUUID(name string) *OnePasswordMockClient {
+	mockClient.MockVaults[name] = append(mockClient.MockVaults[name], onepassword.Vault{
+		ID:   name,
+		Name: name,
+	})
+
+	return mockClient
+}
+
 // AddPredictableItemWithField adds an item and it's fields to the mock client in a predictable way.
 func (mockClient *OnePasswordMockClient) AddPredictableItemWithField(vaultName, title, label, value string) *OnePasswordMockClient {
 	itemID := fmt.Sprintf("%s-id", title)
@@ -255,6 +256,25 @@ func (mockClient *OnePasswordMockClient) AddPredictableItemWithField(vaultName,
 	return mockClient
 }
 
+// AddPredictableItemWithFieldUUID adds an item and it's fields to the mock client in a predictable way.
+func (mockClient *OnePasswordMockClient) AddPredictableItemWithFieldUUID(vaultName, title, label, value string) *OnePasswordMockClient {
+	mockClient.MockItems[vaultName] = append(mockClient.MockItems[vaultName], onepassword.Item{
+		ID:    title,
+		Title: title,
+		Vault: onepassword.ItemVault{ID: vaultName},
+	})
+
+	if mockClient.MockItemFields[vaultName] == nil {
+		mockClient.MockItemFields[vaultName] = make(map[string][]*onepassword.ItemField)
+	}
+	mockClient.MockItemFields[vaultName][title] = append(mockClient.MockItemFields[vaultName][title], &onepassword.ItemField{
+		Label: label,
+		Value: value,
+	})
+
+	return mockClient
+}
+
 // AppendVault appends a onepassword.Vault to the mock client.
 func (mockClient *OnePasswordMockClient) AppendVault(name string, vault onepassword.Vault) *OnePasswordMockClient {
 	mockClient.MockVaults[name] = append(mockClient.MockVaults[name], vault)

+ 9 - 4
pkg/provider/onepassword/onepassword.go

@@ -27,6 +27,7 @@ import (
 	"github.com/1Password/connect-sdk-go/connect"
 	"github.com/1Password/connect-sdk-go/onepassword"
 	corev1 "k8s.io/api/core/v1"
+	"k8s.io/kube-openapi/pkg/validation/strfmt"
 	kclient "sigs.k8s.io/controller-runtime/pkg/client"
 	"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
 
@@ -410,7 +411,7 @@ func (provider *ProviderOnePassword) GetSecret(_ context.Context, ref esv1beta1.
 // to be able to retrieve secrets from the provider.
 func (provider *ProviderOnePassword) Validate() (esv1beta1.ValidationResult, error) {
 	for vaultName := range provider.vaults {
-		_, err := provider.client.GetVaultByTitle(vaultName)
+		_, err := provider.client.GetVault(vaultName)
 		if err != nil {
 			return esv1beta1.ValidationResultError, err
 		}
@@ -442,10 +443,11 @@ func (provider *ProviderOnePassword) GetAllSecrets(_ context.Context, ref esv1be
 	secretData := make(map[string][]byte)
 	sortedVaults := sortVaults(provider.vaults)
 	for _, vaultName := range sortedVaults {
-		vault, err := provider.client.GetVaultByTitle(vaultName)
+		vault, err := provider.client.GetVault(vaultName)
 		if err != nil {
 			return nil, fmt.Errorf(errGetVault, err)
 		}
+
 		if ref.Tags != nil {
 			err = provider.getAllByTags(vault.ID, ref, secretData)
 			if err != nil {
@@ -470,12 +472,15 @@ func (provider *ProviderOnePassword) Close(_ context.Context) error {
 func (provider *ProviderOnePassword) findItem(name string) (*onepassword.Item, error) {
 	sortedVaults := sortVaults(provider.vaults)
 	for _, vaultName := range sortedVaults {
-		vault, err := provider.client.GetVaultByTitle(vaultName)
+		vault, err := provider.client.GetVault(vaultName)
 		if err != nil {
 			return nil, fmt.Errorf(errGetVault, err)
 		}
 
-		// use GetItemsByTitle instead of GetItemByTitle in order to handle length cases
+		if strfmt.IsUUID(name) {
+			return provider.client.GetItem(name, vault.ID)
+		}
+
 		items, err := provider.client.GetItemsByTitle(name, vault.ID)
 		if err != nil {
 			return nil, fmt.Errorf(errGetItem, err)

+ 52 - 23
pkg/provider/onepassword/onepassword_test.go

@@ -36,8 +36,8 @@ import (
 
 const (
 	// vaults and items.
-	myVault, myVaultID                       = "my-vault", "my-vault-id"
-	myItem, myItemID                         = "my-item", "my-item-id"
+	myVault, myVaultID, myVaultUUID          = "my-vault", "my-vault-id", "39c31136-d086-47e9-a52c-8fe330d2669a"
+	myItem, myItemID, myItemUUID             = "my-item", "my-item-id", "687adbe7-e6d2-4059-9a62-dbb95d291143"
 	mySharedVault, mySharedVaultID           = "my-shared-vault", "my-shared-vault-id"
 	mySharedItem, mySharedItemID             = "my-shared-item", "my-shared-item-id"
 	myOtherVault, myOtherVaultID             = "my-other-vault", "my-other-vault-id"
@@ -118,6 +118,33 @@ func TestFindItem(t *testing.T) {
 			},
 		},
 		{
+			setupNote: "uuid: valid basic: one vault, one item, one field",
+			provider: &ProviderOnePassword{
+				vaults: map[string]int{myVaultUUID: 1},
+				client: fake.NewMockClient().
+					AddPredictableVaultUUID(myVaultUUID).
+					AddPredictableItemWithFieldUUID(myVaultUUID, myItemUUID, key1, value1),
+			},
+			checks: []check{
+				{
+					checkNote:    "pass",
+					findItemName: myItemUUID,
+					expectedErr:  nil,
+					expectedItem: &onepassword.Item{
+						ID:    myItemUUID,
+						Title: myItemUUID,
+						Vault: onepassword.ItemVault{ID: myVaultUUID},
+						Fields: []*onepassword.ItemField{
+							{
+								Label: key1,
+								Value: value1,
+							},
+						},
+					},
+				},
+			},
+		},
+		{
 			setupNote: "multiple vaults, multiple items",
 			provider: &ProviderOnePassword{
 				vaults: map[string]int{myVault: 1, mySharedVault: 2},
@@ -328,29 +355,31 @@ func TestFindItem(t *testing.T) {
 	}
 
 	// run the tests
-	for _, tc := range testCases {
-		for _, check := range tc.checks {
-			got, err := tc.provider.findItem(check.findItemName)
-			notes := fmt.Sprintf(setupCheckFormat, tc.setupNote, check.checkNote)
-			if check.expectedErr == nil && err != nil {
-				// expected no error, got one
-				t.Errorf(findItemErrFormat, notes, nil, err)
-			}
-			if check.expectedErr != nil && err == nil {
-				// expected an error, didn't get one
-				t.Errorf(findItemErrFormat, notes, check.expectedErr.Error(), nil)
-			}
-			if check.expectedErr != nil && err != nil && err.Error() != check.expectedErr.Error() {
-				// expected an error, got the wrong one
-				t.Errorf(findItemErrFormat, notes, check.expectedErr.Error(), err.Error())
-			}
-			if check.expectedItem != nil {
-				if !reflect.DeepEqual(check.expectedItem, got) {
-					// expected a predefined item, got something else
-					t.Errorf(findItemErrFormat, notes, check.expectedItem, got)
+	for num, tc := range testCases {
+		t.Run(fmt.Sprintf("test-%d", num), func(t *testing.T) {
+			for _, check := range tc.checks {
+				got, err := tc.provider.findItem(check.findItemName)
+				notes := fmt.Sprintf(setupCheckFormat, tc.setupNote, check.checkNote)
+				if check.expectedErr == nil && err != nil {
+					// expected no error, got one
+					t.Errorf(findItemErrFormat, notes, nil, err)
+				}
+				if check.expectedErr != nil && err == nil {
+					// expected an error, didn't get one
+					t.Errorf(findItemErrFormat, notes, check.expectedErr.Error(), nil)
+				}
+				if check.expectedErr != nil && err != nil && err.Error() != check.expectedErr.Error() {
+					// expected an error, got the wrong one
+					t.Errorf(findItemErrFormat, notes, check.expectedErr.Error(), err.Error())
+				}
+				if check.expectedItem != nil {
+					if !reflect.DeepEqual(check.expectedItem, got) {
+						// expected a predefined item, got something else
+						t.Errorf(findItemErrFormat, notes, check.expectedItem, got)
+					}
 				}
 			}
-		}
+		})
 	}
 }