| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152 |
- /*
- Copyright © The ESO Authors
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
- https://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- */
- package template
- import (
- "bytes"
- "crypto/x509"
- "encoding/pem"
- "errors"
- "fmt"
- "strings"
- )
// errJunk is the message filterPEM reports when none of its input decodes to
// a PEM block of the requested type but undecodable data remains.
const errJunk = "error filtering pem: found junk"

// Certificate roles that filterCertChain knows how to extract from a chain.
const (
	certTypeLeaf         = "leaf"
	certTypeIntermediate = "intermediate"
	certTypeRoot         = "root"
)
// filterPEM returns every PEM block in input whose type equals pemType
// (compared case-insensitively), re-encoded and concatenated in input order.
// Leading text is stripped with trimJunk before decoding starts.
//
// Decoding stops at the first point pem.Decode cannot find another block.
// If no matching block was collected and such undecodable data remains, an
// error carrying errJunk is returned. Trailing data after at least one
// matching block is ignored. An input with no matches and nothing left over
// yields "" and a nil error.
func filterPEM(pemType, input string) (string, error) {
	var out strings.Builder
	remaining := []byte(trimJunk(input))
	for {
		blk, next := pem.Decode(remaining)
		// Keep whatever Decode hands back even on failure: after the loop,
		// remaining is the leftover that decides whether this is junk.
		remaining = next
		if blk == nil {
			break
		}
		if !strings.EqualFold(blk.Type, pemType) {
			continue
		}
		if err := pem.Encode(&out, blk); err != nil {
			return "", err
		}
	}
	if out.Len() == 0 && len(remaining) > 0 {
		return "", errors.New(errJunk)
	}
	return out.String(), nil
}
// trimJunk discards everything before the first "-----BEGIN" marker in input,
// mirroring the leading-garbage skipping that pem.Decode used to do on its own
// (see https://github.com/golang/go/issues/76124). It is a plain substring
// search, so the marker need not start a line. Input without the marker is
// returned unchanged.
func trimJunk(input string) string {
	if start := strings.Index(input, "-----BEGIN"); start >= 0 {
		return input[start:]
	}
	return input
}
// filterCertChain extracts one role from the certificate chain in input and
// returns it PEM-encoded with type pemTypeCertificate. The chain comes from
// fetchX509CertChains; it is presumably ordered leaf first, root last.
// NOTE(review): confirm that ordering against fetchX509CertChains.
//
//   - certTypeLeaf: the first certificate, unless it looks like a root
//     (see isRootCertificate).
//   - certTypeIntermediate: every certificate after the first, in order,
//     stopping before the first one that looks like a root.
//   - certTypeRoot: the last certificate, if it looks like a root.
//
// It returns "" with a nil error when the requested role is absent, when the
// chain is empty, or when certType is none of the values above. Errors from
// fetchX509CertChains or from PEM encoding are returned unchanged.
func filterCertChain(certType, input string) (string, error) {
	ordered, err := fetchX509CertChains([]byte(input))
	if err != nil {
		return "", err
	}
	// Every branch below indexes into ordered; an empty chain has no roles.
	if len(ordered) == 0 {
		return "", nil
	}
	switch certType {
	case certTypeLeaf:
		if leaf := ordered[0]; !isRootCertificate(leaf) {
			return pemEncode(leaf.Raw, pemTypeCertificate)
		}
	case certTypeIntermediate:
		var out strings.Builder
		for _, cert := range ordered[1:] {
			if isRootCertificate(cert) {
				break
			}
			encoded, encErr := pemEncode(cert.Raw, pemTypeCertificate)
			if encErr != nil {
				return "", encErr
			}
			out.WriteString(encoded)
		}
		return out.String(), nil
	case certTypeRoot:
		if root := ordered[len(ordered)-1]; isRootCertificate(root) {
			return pemEncode(root.Raw, pemTypeCertificate)
		}
	}
	return "", nil
}
// isRootCertificate reports whether cert appears to be a self-signed root.
// That is the case when it carries no Authority Key Identifier, or when that
// identifier equals its own Subject Key Identifier. This is a key-identifier
// heuristic only: it verifies no signature and does not compare the issuer
// and subject names.
func isRootCertificate(cert *x509.Certificate) bool {
	authorityID := cert.AuthorityKeyId
	if authorityID == nil {
		return true
	}
	return bytes.Equal(authorityID, cert.SubjectKeyId)
}
// certSANs returns the Subject Alternative Names of the certificate held in
// the first PEM block of input. Text before the first "-----BEGIN" marker is
// skipped via trimJunk. The block's type header is not checked; its bytes
// must parse as a DER X.509 certificate.
//
// The result lists DNS names, then IP addresses, then email addresses, then
// URIs. It is empty (not nil) when the certificate has no SANs. An error is
// returned when no PEM block can be decoded or the certificate fails to parse.
func certSANs(input string) ([]string, error) {
	block, _ := pem.Decode([]byte(trimJunk(input)))
	if block == nil {
		return nil, errors.New("failed to decode PEM block")
	}
	cert, err := x509.ParseCertificate(block.Bytes)
	if err != nil {
		return nil, fmt.Errorf("failed to parse certificate: %w", err)
	}

	total := len(cert.DNSNames) + len(cert.IPAddresses) + len(cert.EmailAddresses) + len(cert.URIs)
	names := append(make([]string, 0, total), cert.DNSNames...)
	for i := range cert.IPAddresses {
		names = append(names, cert.IPAddresses[i].String())
	}
	names = append(names, cert.EmailAddresses...)
	for i := range cert.URIs {
		names = append(names, cert.URIs[i].String())
	}
	return names, nil
}
- func pemEncode(thing []byte, kind string) (string, error) {
- buf := bytes.NewBuffer(nil)
- err := pem.Encode(buf, &pem.Block{Type: kind, Bytes: thing})
- return buf.String(), err
- }
|