|
|
@@ -21,6 +21,8 @@ import (
|
|
|
"crypto/x509"
|
|
|
"fmt"
|
|
|
"os"
|
|
|
+ "path/filepath"
|
|
|
+ "strings"
|
|
|
)
|
|
|
|
|
|
const (
|
|
|
@@ -57,8 +59,14 @@ func DefaultTLSConfig() *TLSConfig {
|
|
|
// This enables mTLS, requiring and verifying client certificates.
|
|
|
func LoadTLSConfig(config *TLSConfig) (*tls.Config, error) {
|
|
|
// Load server certificate and key
|
|
|
- certPath := fmt.Sprintf("%s/%s", config.CertDir, config.CertFile)
|
|
|
- keyPath := fmt.Sprintf("%s/%s", config.CertDir, config.KeyFile)
|
|
|
+ certPath, err := resolveCertPath(config.CertDir, config.CertFile)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ keyPath, err := resolveCertPath(config.CertDir, config.KeyFile)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
|
|
|
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
|
|
|
if err != nil {
|
|
|
@@ -66,8 +74,11 @@ func LoadTLSConfig(config *TLSConfig) (*tls.Config, error) {
|
|
|
}
|
|
|
|
|
|
// Load CA certificate for client verification
|
|
|
- caPath := fmt.Sprintf("%s/%s", config.CertDir, config.CACertFile)
|
|
|
- // #nosec G304 -- TLSConfig paths are explicit operator-provided mount points.
|
|
|
+ caPath, err := resolveCertPath(config.CertDir, config.CACertFile)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ // #nosec G304 -- resolveCertPath constrains file names to direct children of CertDir.
|
|
|
caCert, err := os.ReadFile(caPath)
|
|
|
if err != nil {
|
|
|
return nil, fmt.Errorf("failed to load CA certificate: %w", err)
|
|
|
@@ -108,3 +119,23 @@ func getEnvOrDefault(key, defaultValue string) string {
|
|
|
}
|
|
|
return defaultValue
|
|
|
}
|
|
|
+
|
|
|
+func resolveCertPath(certDir, fileName string) (string, error) {
|
|
|
+ cleanDir := filepath.Clean(certDir)
|
|
|
+ cleanFile := filepath.Clean(fileName)
|
|
|
+
|
|
|
+ if filepath.IsAbs(cleanFile) || cleanFile == "." || cleanFile == ".." || cleanFile != filepath.Base(cleanFile) {
|
|
|
+ return "", fmt.Errorf("invalid TLS file name %q", fileName)
|
|
|
+ }
|
|
|
+
|
|
|
+ fullPath := filepath.Join(cleanDir, cleanFile)
|
|
|
+ relPath, err := filepath.Rel(cleanDir, fullPath)
|
|
|
+ if err != nil {
|
|
|
+ return "", fmt.Errorf("resolve TLS file path: %w", err)
|
|
|
+ }
|
|
|
+ if relPath == ".." || strings.HasPrefix(relPath, ".."+string(filepath.Separator)) {
|
|
|
+ return "", fmt.Errorf("invalid TLS file name %q", fileName)
|
|
|
+ }
|
|
|
+
|
|
|
+ return fullPath, nil
|
|
|
+}
|