Przeglądaj źródła

oracle: Get secret by name from a specific vault

Elad Gabay 4 lat temu
rodzic
commit
cab49e57f7

+ 4 - 4
pkg/provider/oracle/fake/fake.go

@@ -20,16 +20,16 @@ import (
 )
 
 type OracleMockClient struct {
-	getSecret func(ctx context.Context, request secrets.GetSecretBundleRequest) (response secrets.GetSecretBundleResponse, err error)
+	getSecret func(ctx context.Context, request secrets.GetSecretBundleByNameRequest) (response secrets.GetSecretBundleByNameResponse, err error)
 }
 
-func (mc *OracleMockClient) GetSecretBundle(ctx context.Context, request secrets.GetSecretBundleRequest) (response secrets.GetSecretBundleResponse, err error) {
+func (mc *OracleMockClient) GetSecretBundleByName(ctx context.Context, request secrets.GetSecretBundleByNameRequest) (response secrets.GetSecretBundleByNameResponse, err error) {
 	return mc.getSecret(ctx, request)
 }
 
-func (mc *OracleMockClient) WithValue(input secrets.GetSecretBundleRequest, output secrets.GetSecretBundleResponse, err error) {
+func (mc *OracleMockClient) WithValue(input secrets.GetSecretBundleByNameRequest, output secrets.GetSecretBundleByNameResponse, err error) {
 	if mc != nil {
-		mc.getSecret = func(ctx context.Context, paramReq secrets.GetSecretBundleRequest) (secrets.GetSecretBundleResponse, error) {
+		mc.getSecret = func(ctx context.Context, paramReq secrets.GetSecretBundleByNameRequest) (secrets.GetSecretBundleByNameResponse, error) {
 			return output, err
 		}
 	}

+ 19 - 10
pkg/provider/oracle/oracle.go

@@ -18,7 +18,7 @@ import (
 	"fmt"
 
 	"github.com/oracle/oci-go-sdk/v45/common"
-	secrets "github.com/oracle/oci-go-sdk/v45/secrets"
+	"github.com/oracle/oci-go-sdk/v45/secrets"
 	"github.com/tidwall/gjson"
 	corev1 "k8s.io/api/core/v1"
 	"k8s.io/apimachinery/pkg/types"
@@ -46,6 +46,7 @@ const (
 	errMissingTenancy                        = "missing Tenancy ID"
 	errMissingRegion                         = "missing Region"
 	errMissingFingerprint                    = "missing Fingerprint"
+	errMissingVault                          = "missing Vault"
 	errJSONSecretUnmarshal                   = "unable to unmarshal secret: %w"
 	errMissingKey                            = "missing Key in secret: %s"
 	errUnexpectedContent                     = "unexpected secret bundle content"
@@ -65,10 +66,11 @@ type client struct {
 
 type VaultManagementService struct {
 	Client VMInterface
+	vault  string
 }
 
 type VMInterface interface {
-	GetSecretBundle(ctx context.Context, request secrets.GetSecretBundleRequest) (response secrets.GetSecretBundleResponse, err error)
+	GetSecretBundleByName(ctx context.Context, request secrets.GetSecretBundleByNameRequest) (secrets.GetSecretBundleByNameResponse, error)
 }
 
 func (c *client) setAuth(ctx context.Context) error {
@@ -127,22 +129,22 @@ func (vms *VaultManagementService) GetSecret(ctx context.Context, ref esv1alpha1
 	if utils.IsNil(vms.Client) {
 		return nil, fmt.Errorf(errUninitalizedOracleProvider)
 	}
-	sec, err := vms.Client.GetSecretBundle(ctx, secrets.GetSecretBundleRequest{
-		SecretId: &ref.Key,
-		Stage:    secrets.GetSecretBundleStageEnum(ref.Version),
-	})
 
+	sec, err := vms.Client.GetSecretBundleByName(ctx, secrets.GetSecretBundleByNameRequest{
+		VaultId:    &vms.vault,
+		SecretName: &ref.Key,
+		Stage:      secrets.GetSecretBundleByNameStageEnum(ref.Version),
+	})
 	if err != nil {
 		return nil, util.SanitizeErr(err)
 	}
-	// TODO: should bt.Content be base64 decoded??
+
 	bt, ok := sec.SecretBundleContent.(secrets.Base64SecretBundleContentDetails)
 	if !ok {
 		return nil, fmt.Errorf(errUnexpectedContent)
 	}
 
 	payload, err := base64.StdEncoding.DecodeString(*bt.Content)
-
 	if err != nil {
 		return nil, err
 	}
@@ -182,6 +184,10 @@ func (vms *VaultManagementService) NewClient(ctx context.Context, store esv1alph
 	storeSpec := store.GetSpec()
 	oracleSpec := storeSpec.Provider.Oracle
 
+	if oracleSpec.Vault == "" {
+		return nil, fmt.Errorf(errMissingVault)
+	}
+
 	oracleStore := &client{
 		kube:      kube,
 		store:     oracleSpec,
@@ -204,8 +210,11 @@ func (vms *VaultManagementService) NewClient(ctx context.Context, store esv1alph
 	if err != nil {
 		return nil, fmt.Errorf(errOracleClient, err)
 	}
-	vms.Client = secretManagementService
-	return vms, nil
+
+	return &VaultManagementService{
+		Client: secretManagementService,
+		vault:  oracleSpec.Vault,
+	}, nil
 }
 
 func (vms *VaultManagementService) Close(ctx context.Context) error {

+ 9 - 10
pkg/provider/oracle/oracle_test.go

@@ -28,8 +28,8 @@ import (
 
 type vaultTestCase struct {
 	mockClient     *fakeoracle.OracleMockClient
-	apiInput       *secrets.GetSecretBundleRequest
-	apiOutput      *secrets.GetSecretBundleResponse
+	apiInput       *secrets.GetSecretBundleByNameRequest
+	apiOutput      *secrets.GetSecretBundleByNameResponse
 	ref            *esv1alpha1.ExternalSecretDataRemoteRef
 	apiErr         error
 	expectError    string
@@ -60,15 +60,15 @@ func makeValidRef() *esv1alpha1.ExternalSecretDataRemoteRef {
 	}
 }
 
-func makeValidAPIInput() *secrets.GetSecretBundleRequest {
-	return &secrets.GetSecretBundleRequest{
-		SecretId: utilpointer.StringPtr("test-secret"),
+func makeValidAPIInput() *secrets.GetSecretBundleByNameRequest {
+	return &secrets.GetSecretBundleByNameRequest{
+		SecretName: utilpointer.StringPtr("test-secret"),
+		VaultId: utilpointer.StringPtr("test-vault"),
 	}
 }
 
-func makeValidAPIOutput() *secrets.GetSecretBundleResponse {
-	return &secrets.GetSecretBundleResponse{
-		Etag:         utilpointer.StringPtr("test-name"),
+func makeValidAPIOutput() *secrets.GetSecretBundleByNameResponse {
+	return &secrets.GetSecretBundleByNameResponse{
 		SecretBundle: secrets.SecretBundle{},
 	}
 }
@@ -99,8 +99,7 @@ func TestOracleVaultGetSecret(t *testing.T) {
 	// good case: default version is set
 	// key is passed in, output is sent back
 	setSecretString := func(smtc *vaultTestCase) {
-		smtc.apiOutput = &secrets.GetSecretBundleResponse{
-			Etag: utilpointer.StringPtr("test-name"),
+		smtc.apiOutput = &secrets.GetSecretBundleByNameResponse{
 			SecretBundle: secrets.SecretBundle{
 				SecretId:      utilpointer.StringPtr("test-id"),
 				VersionNumber: utilpointer.Int64(1),