Browse Source

feat(aws): allow custom endpoints

Moritz Johner 5 years ago
parent
commit
b8ecff54c0
2 changed files with 68 additions and 13 deletions
  1. 1 0
      .gitignore
  2. 67 13
      pkg/provider/aws/provider.go

+ 1 - 0
.gitignore

@@ -13,3 +13,4 @@ cover.out
 deploy/charts/external-secrets/templates/crds/*.yaml
 deploy/charts/external-secrets/templates/crds/*.yaml
 
 
 site/
 site/
+e2e/k8s/deploy

+ 67 - 13
pkg/provider/aws/provider.go

@@ -17,7 +17,9 @@ package aws
 import (
 import (
 	"context"
 	"context"
 	"fmt"
 	"fmt"
+	"os"
 
 
+	"github.com/aws/aws-sdk-go/aws/endpoints"
 	"github.com/aws/aws-sdk-go/aws/session"
 	"github.com/aws/aws-sdk-go/aws/session"
 	v1 "k8s.io/api/core/v1"
 	v1 "k8s.io/api/core/v1"
 	ctrl "sigs.k8s.io/controller-runtime"
 	ctrl "sigs.k8s.io/controller-runtime"
@@ -36,6 +38,25 @@ type Provider struct{}
 
 
 var log = ctrl.Log.WithName("provider").WithName("aws")
 var log = ctrl.Log.WithName("provider").WithName("aws")
 
 
+const (
+	SecretsManagerEndpointEnv = "AWS_SECRETSMANAGER_ENDPOINT"
+	STSEndpointEnv            = "AWS_STS_ENDPOINT"
+	SSMEndpointEnv            = "AWS_SSM_ENDPOINT"
+
+	errUnableCreateSession                     = "unable to create session: %w"
+	errUnknownProviderService                  = "unknown AWS Provider Service: %s"
+	errInvalidClusterStoreMissingAKIDNamespace = "invalid ClusterSecretStore: missing AWS AccessKeyID Namespace"
+	errInvalidClusterStoreMissingSAKNamespace  = "invalid ClusterSecretStore: missing AWS SecretAccessKey Namespace"
+	errFetchAKIDSecret                         = "could not fetch accessKeyID secret: %w"
+	errFetchSAKSecret                          = "could not fetch SecretAccessKey secret: %w"
+	errMissingSAK                              = "missing SecretAccessKey"
+	errMissingAKID                             = "missing AccessKeyID"
+	errNilStore                                = "found nil store"
+	errMissingStoreSpec                        = "store is missing spec"
+	errMissingProvider                         = "storeSpec is missing provider"
+	errInvalidProvider                         = "invalid provider spec. Missing AWS field in store %s"
+)
+
 // NewClient constructs a new secrets client based on the provided store.
 // NewClient constructs a new secrets client based on the provided store.
 func (p *Provider) NewClient(ctx context.Context, store esv1alpha1.GenericStore, kube client.Client, namespace string) (provider.SecretsClient, error) {
 func (p *Provider) NewClient(ctx context.Context, store esv1alpha1.GenericStore, kube client.Client, namespace string) (provider.SecretsClient, error) {
 	return newClient(ctx, store, kube, namespace, awssess.DefaultSTSProvider)
 	return newClient(ctx, store, kube, namespace, awssess.DefaultSTSProvider)
@@ -48,7 +69,7 @@ func newClient(ctx context.Context, store esv1alpha1.GenericStore, kube client.C
 	}
 	}
 	sess, err := newSession(ctx, store, kube, namespace, assumeRoler)
 	sess, err := newSession(ctx, store, kube, namespace, assumeRoler)
 	if err != nil {
 	if err != nil {
-		return nil, fmt.Errorf("unable to create session: %w", err)
+		return nil, fmt.Errorf(errUnableCreateSession, err)
 	}
 	}
 	switch prov.Service {
 	switch prov.Service {
 	case esv1alpha1.AWSServiceSecretsManager:
 	case esv1alpha1.AWSServiceSecretsManager:
@@ -56,7 +77,7 @@ func newClient(ctx context.Context, store esv1alpha1.GenericStore, kube client.C
 	case esv1alpha1.AWSServiceParameterStore:
 	case esv1alpha1.AWSServiceParameterStore:
 		return parameterstore.New(sess)
 		return parameterstore.New(sess)
 	}
 	}
-	return nil, fmt.Errorf("unknown AWS Provider Service: %s", prov.Service)
+	return nil, fmt.Errorf(errUnknownProviderService, prov.Service)
 }
 }
 
 
 // newSession creates a new aws session based on a store
 // newSession creates a new aws session based on a store
@@ -77,14 +98,14 @@ func newSession(ctx context.Context, store esv1alpha1.GenericStore, kube client.
 		// only ClusterStore is allowed to set namespace (and then it's required)
 		// only ClusterStore is allowed to set namespace (and then it's required)
 		if store.GetObjectKind().GroupVersionKind().Kind == esv1alpha1.ClusterSecretStoreKind {
 		if store.GetObjectKind().GroupVersionKind().Kind == esv1alpha1.ClusterSecretStoreKind {
 			if prov.Auth.SecretRef.AccessKeyID.Namespace == nil {
 			if prov.Auth.SecretRef.AccessKeyID.Namespace == nil {
-				return nil, fmt.Errorf("invalid ClusterSecretStore: missing AWS AccessKeyID Namespace")
+				return nil, fmt.Errorf(errInvalidClusterStoreMissingAKIDNamespace)
 			}
 			}
 			ke.Namespace = *prov.Auth.SecretRef.AccessKeyID.Namespace
 			ke.Namespace = *prov.Auth.SecretRef.AccessKeyID.Namespace
 		}
 		}
 		akSecret := v1.Secret{}
 		akSecret := v1.Secret{}
 		err := kube.Get(ctx, ke, &akSecret)
 		err := kube.Get(ctx, ke, &akSecret)
 		if err != nil {
 		if err != nil {
-			return nil, fmt.Errorf("could not fetch accessKeyID secret: %w", err)
+			return nil, fmt.Errorf(errFetchAKIDSecret, err)
 		}
 		}
 		ke = client.ObjectKey{
 		ke = client.ObjectKey{
 			Name:      prov.Auth.SecretRef.SecretAccessKey.Name,
 			Name:      prov.Auth.SecretRef.SecretAccessKey.Name,
@@ -93,47 +114,80 @@ func newSession(ctx context.Context, store esv1alpha1.GenericStore, kube client.
 		// only ClusterStore is allowed to set namespace (and then it's required)
 		// only ClusterStore is allowed to set namespace (and then it's required)
 		if store.GetObjectKind().GroupVersionKind().Kind == esv1alpha1.ClusterSecretStoreKind {
 		if store.GetObjectKind().GroupVersionKind().Kind == esv1alpha1.ClusterSecretStoreKind {
 			if prov.Auth.SecretRef.SecretAccessKey.Namespace == nil {
 			if prov.Auth.SecretRef.SecretAccessKey.Namespace == nil {
-				return nil, fmt.Errorf("invalid ClusterSecretStore: missing AWS SecretAccessKey Namespace")
+				return nil, fmt.Errorf(errInvalidClusterStoreMissingSAKNamespace)
 			}
 			}
 			ke.Namespace = *prov.Auth.SecretRef.SecretAccessKey.Namespace
 			ke.Namespace = *prov.Auth.SecretRef.SecretAccessKey.Namespace
 		}
 		}
 		sakSecret := v1.Secret{}
 		sakSecret := v1.Secret{}
 		err = kube.Get(ctx, ke, &sakSecret)
 		err = kube.Get(ctx, ke, &sakSecret)
 		if err != nil {
 		if err != nil {
-			return nil, fmt.Errorf("could not fetch SecretAccessKey secret: %w", err)
+			return nil, fmt.Errorf(errFetchSAKSecret, err)
 		}
 		}
 		sak = string(sakSecret.Data[prov.Auth.SecretRef.SecretAccessKey.Key])
 		sak = string(sakSecret.Data[prov.Auth.SecretRef.SecretAccessKey.Key])
 		aks = string(akSecret.Data[prov.Auth.SecretRef.AccessKeyID.Key])
 		aks = string(akSecret.Data[prov.Auth.SecretRef.AccessKeyID.Key])
 		if sak == "" {
 		if sak == "" {
-			return nil, fmt.Errorf("missing SecretAccessKey")
+			return nil, fmt.Errorf(errMissingSAK)
 		}
 		}
 		if aks == "" {
 		if aks == "" {
-			return nil, fmt.Errorf("missing AccessKeyID")
+			return nil, fmt.Errorf(errMissingAKID)
 		}
 		}
 	}
 	}
-	return awssess.New(sak, aks, prov.Region, prov.Role, assumeRoler)
+	session, err := awssess.New(sak, aks, prov.Region, prov.Role, assumeRoler)
+	if err != nil {
+		return nil, err
+	}
+	session.Config.EndpointResolver = ResolveEndpoint()
+	return session, nil
 }
 }
 
 
 // getAWSProvider does the necessary nil checks on the generic store
 // getAWSProvider does the necessary nil checks on the generic store
 // it returns the aws provider or an error.
 // it returns the aws provider or an error.
 func getAWSProvider(store esv1alpha1.GenericStore) (*esv1alpha1.AWSProvider, error) {
 func getAWSProvider(store esv1alpha1.GenericStore) (*esv1alpha1.AWSProvider, error) {
 	if store == nil {
 	if store == nil {
-		return nil, fmt.Errorf("found nil store")
+		return nil, fmt.Errorf(errNilStore)
 	}
 	}
 	spc := store.GetSpec()
 	spc := store.GetSpec()
 	if spc == nil {
 	if spc == nil {
-		return nil, fmt.Errorf("store is missing spec")
+		return nil, fmt.Errorf(errMissingStoreSpec)
 	}
 	}
 	if spc.Provider == nil {
 	if spc.Provider == nil {
-		return nil, fmt.Errorf("storeSpec is missing provider")
+		return nil, fmt.Errorf(errMissingProvider)
 	}
 	}
 	prov := spc.Provider.AWS
 	prov := spc.Provider.AWS
 	if prov == nil {
 	if prov == nil {
-		return nil, fmt.Errorf("invalid provider spec. Missing AWS field in store %s", store.GetObjectMeta().String())
+		return nil, fmt.Errorf(errInvalidProvider, store.GetObjectMeta().String())
 	}
 	}
 	return prov, nil
 	return prov, nil
 }
 }
 
 
+// ResolveEndpoint returns a ResolverFunc with
+// customizable endpoints
+func ResolveEndpoint() endpoints.ResolverFunc {
+	customEndpoints := make(map[string]string)
+	if v := os.Getenv(SecretsManagerEndpointEnv); v != "" {
+		customEndpoints["secretsmanager"] = v
+	}
+	if v := os.Getenv(SSMEndpointEnv); v != "" {
+		customEndpoints["ssm"] = v
+	}
+	if v := os.Getenv(STSEndpointEnv); v != "" {
+		customEndpoints["sts"] = v
+	}
+	return ResolveEndpointWithServiceMap(customEndpoints)
+}
+
+func ResolveEndpointWithServiceMap(customEndpoints map[string]string) endpoints.ResolverFunc {
+	defaultResolver := endpoints.DefaultResolver()
+	return func(service, region string, opts ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
+		if ep, ok := customEndpoints[service]; ok {
+			return endpoints.ResolvedEndpoint{
+				URL: ep,
+			}, nil
+		}
+		return defaultResolver.EndpointFor(service, region, opts...)
+	}
+}
+
 func init() {
 func init() {
 	schema.Register(&Provider{}, &esv1alpha1.SecretStoreProvider{
 	schema.Register(&Provider{}, &esv1alpha1.SecretStoreProvider{
 		AWS: &esv1alpha1.AWSProvider{},
 		AWS: &esv1alpha1.AWSProvider{},