pem.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. /*
  2. Copyright © The ESO Authors
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. https://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package template
  14. import (
  15. "bytes"
  16. "crypto/x509"
  17. "encoding/pem"
  18. "errors"
  19. "fmt"
  20. "strings"
  21. )
  // errJunk is the message of the error filterPEM returns when no block of the
// requested type was found and undecodable data remained in the input.
const errJunk = "error filtering pem: found junk"

// Certificate roles understood by filterCertChain's certType argument.
const (
	certTypeLeaf         = "leaf"
	certTypeIntermediate = "intermediate"
	certTypeRoot         = "root"
)
  28. func filterPEM(pemType, input string) (string, error) {
  29. input = trimJunk(input)
  30. data := []byte(input)
  31. var blocks []byte
  32. var block *pem.Block
  33. var rest []byte
  34. for {
  35. block, rest = pem.Decode(data)
  36. data = rest
  37. if block == nil {
  38. break
  39. }
  40. if !strings.EqualFold(block.Type, pemType) {
  41. continue
  42. }
  43. var buf bytes.Buffer
  44. err := pem.Encode(&buf, block)
  45. if err != nil {
  46. return "", err
  47. }
  48. blocks = append(blocks, buf.Bytes()...)
  49. }
  50. if len(blocks) == 0 && len(rest) != 0 {
  51. return "", errors.New(errJunk)
  52. }
  53. return string(blocks), nil
  54. }
  // trimJunk discards everything before the first "-----BEGIN" marker in input.
// It returns input unchanged when the marker does not occur. This reproduces
// the leading-junk skipping that pem.Decode itself performed before
// https://github.com/golang/go/issues/76124; callers run it on their input
// before handing it to pem.Decode.
func trimJunk(input string) string {
	const beginMarker = "-----BEGIN"
	if start := strings.Index(input, beginMarker); start >= 0 {
		return input[start:]
	}
	return input
}
  64. func filterCertChain(certType, input string) (string, error) {
  65. ordered, err := fetchX509CertChains([]byte(input))
  66. if err != nil {
  67. return "", err
  68. }
  69. switch certType {
  70. case certTypeLeaf:
  71. cert := ordered[0]
  72. if cert.AuthorityKeyId != nil && !bytes.Equal(cert.AuthorityKeyId, cert.SubjectKeyId) {
  73. return pemEncode(ordered[0].Raw, pemTypeCertificate)
  74. }
  75. case certTypeIntermediate:
  76. if len(ordered) < 2 {
  77. return "", nil
  78. }
  79. var pemData []byte
  80. for _, cert := range ordered[1:] {
  81. if isRootCertificate(cert) {
  82. break
  83. }
  84. b := &pem.Block{
  85. Type: pemTypeCertificate,
  86. Bytes: cert.Raw,
  87. }
  88. pemData = append(pemData, pem.EncodeToMemory(b)...)
  89. }
  90. return string(pemData), nil
  91. case certTypeRoot:
  92. cert := ordered[len(ordered)-1]
  93. if isRootCertificate(cert) {
  94. return pemEncode(cert.Raw, pemTypeCertificate)
  95. }
  96. }
  97. return "", nil
  98. }
  // isRootCertificate reports whether cert is treated as a chain root: either it
// carries no Authority Key Identifier at all, or its Authority Key Identifier
// is byte-for-byte equal to its own Subject Key Identifier.
func isRootCertificate(cert *x509.Certificate) bool {
	if cert.AuthorityKeyId == nil {
		return true
	}
	return bytes.Equal(cert.AuthorityKeyId, cert.SubjectKeyId)
}
  102. // certSANs extracts Subject Alternative Names (SANs) from a PEM-encoded certificate.
  103. // It returns a list of all SANs including DNS names, IP addresses, email addresses, and URIs.
  104. func certSANs(input string) ([]string, error) {
  105. input = trimJunk(input)
  106. block, _ := pem.Decode([]byte(input))
  107. if block == nil {
  108. return nil, fmt.Errorf("failed to decode PEM block")
  109. }
  110. cert, err := x509.ParseCertificate(block.Bytes)
  111. if err != nil {
  112. return nil, fmt.Errorf("failed to parse certificate: %w", err)
  113. }
  114. sans := make([]string, 0, len(cert.DNSNames)+len(cert.IPAddresses)+len(cert.EmailAddresses)+len(cert.URIs))
  115. sans = append(sans, cert.DNSNames...)
  116. for _, ip := range cert.IPAddresses {
  117. sans = append(sans, ip.String())
  118. }
  119. sans = append(sans, cert.EmailAddresses...)
  120. for _, uri := range cert.URIs {
  121. sans = append(sans, uri.String())
  122. }
  123. return sans, nil
  124. }
  // pemEncode wraps thing in a single header-less PEM block of type kind and
// returns the encoded text. If encoding fails it returns an empty string with
// the error, rather than whatever partial output had been buffered.
func pemEncode(thing []byte, kind string) (string, error) {
	var buf bytes.Buffer
	if err := pem.Encode(&buf, &pem.Block{Type: kind, Bytes: thing}); err != nil {
		return "", err
	}
	return buf.String(), nil
}