Skip to content

Commit

Permalink
webhook: migrate to url endpoint
Browse files Browse the repository at this point in the history
Signed-off-by: l1b0k <[email protected]>
  • Loading branch information
l1b0k committed Dec 25, 2024
1 parent 1c70388 commit 1c4aa0b
Show file tree
Hide file tree
Showing 4 changed files with 409 additions and 101 deletions.
2 changes: 1 addition & 1 deletion cmd/terway-controlplane/terway-controlplane.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func main() {
options := newOption(cfg)

if !cfg.DisableWebhook {
err = cert.SyncCert(ctx, directClient, cfg.ControllerNamespace, cfg.ControllerName, cfg.ClusterDomain, cfg.CertDir)
err = cert.SyncCert(ctx, directClient, cfg.ControllerNamespace, cfg.ControllerName, cfg.ClusterDomain, cfg.CertDir, cfg.ClusterID, cfg.WebhookURLMode)
if err != nil {
panic(err)
}
Expand Down
260 changes: 160 additions & 100 deletions pkg/cert/webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,19 @@ import (
"math/big"
"os"
"path/filepath"
"reflect"
"time"

"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/client-go/util/retry"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/controller/controllerutil"

admissionregistrationv1 "k8s.io/api/admissionregistration/v1"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/wait"

"github.com/AliyunContainerService/terway/pkg/utils"
)

var log = ctrl.Log.WithName("webhook-cert")
Expand All @@ -36,52 +37,31 @@ const (
)

// SyncCert sync cert for webhook
func SyncCert(ctx context.Context, c client.Client, ns, name, domain, certDir string) error {
secretName := fmt.Sprintf("%s-webhook-cert", name)
func SyncCert(ctx context.Context, c client.Client, serviceNamespace, serviceName, clusterDomain, certDir, clusterID string, urlMode bool) error {
// check secret
var serverCertBytes, serverKeyBytes, caCertBytes []byte
var caCertBytes []byte

// get cert from secret or generate it
existSecret := &corev1.Secret{}
err := c.Get(ctx, types.NamespacedName{Namespace: ns, Name: secretName}, existSecret)
var secret *corev1.Secret
err := retry.RetryOnConflict(retry.DefaultRetry, func() error {
var err error
secret, err = createOrUpdateCert(ctx, c, serviceNamespace, serviceName, clusterDomain, clusterID, urlMode)
return err
})
if err != nil {
if !errors.IsNotFound(err) {
return fmt.Errorf("error get cert from secret, %w", err)
}
// create certs
s, err := GenerateCerts(ns, name, domain)
if err != nil {
return fmt.Errorf("error generate cert, %w", err)
}
return err
}
caCertBytes = secret.Data[caCertKey]

serverCertBytes = s.Data[serverCertKey]
serverKeyBytes = s.Data[serverKeyKey]
caCertBytes = s.Data[caCertKey]
s.Name = secretName
s.Namespace = ns
// create secret this make sure one is the leader
err = c.Create(ctx, s)
if err != nil {
if !errors.IsAlreadyExists(err) {
return fmt.Errorf("error create cert to secret, %w", err)
}

secret := &corev1.Secret{}
err = c.Get(ctx, types.NamespacedName{Namespace: ns, Name: secretName}, secret)
if err != nil {
return fmt.Errorf("error get cert from secret, %w", err)
}
serverCertBytes = secret.Data[serverCertKey]
serverKeyBytes = secret.Data[serverKeyKey]
caCertBytes = secret.Data[caCertKey]
}
} else {
serverCertBytes = existSecret.Data[serverCertKey]
serverKeyBytes = existSecret.Data[serverKeyKey]
caCertBytes = existSecret.Data[caCertKey]
urlEndpoint := ""
if urlMode {
urlEndpoint = fmt.Sprintf("https://%s.%s.svc.cluster.local/mutating", serviceName, clusterID)
}
if len(serverCertBytes) == 0 || len(serverKeyBytes) == 0 || len(caCertBytes) == 0 {
return fmt.Errorf("invalid cert")

err = retry.RetryOnConflict(retry.DefaultRetry, func() error {
return createOrUpdateWebhook(ctx, c, serviceName, urlEndpoint, caCertBytes)
})
if err != nil {
return err
}

err = os.MkdirAll(certDir, os.ModeDir)
Expand All @@ -90,91 +70,173 @@ func SyncCert(ctx context.Context, c client.Client, ns, name, domain, certDir st
}

// write cert to file
err = os.WriteFile(filepath.Join(certDir, serverCertKey), serverCertBytes, os.ModePerm)
err = os.WriteFile(filepath.Join(certDir, serverCertKey), secret.Data[serverCertKey], os.ModePerm)
if err != nil {
return fmt.Errorf("error create secret file, %w", err)
}
err = os.WriteFile(filepath.Join(certDir, serverKeyKey), serverKeyBytes, os.ModePerm)
err = os.WriteFile(filepath.Join(certDir, serverKeyKey), secret.Data[serverKeyKey], os.ModePerm)
if err != nil {
return fmt.Errorf("error create secret file, %w", err)
}
err = os.WriteFile(filepath.Join(certDir, caCertKey), caCertBytes, os.ModePerm)
if err != nil {
return fmt.Errorf("error create secret file, %w", err)
}
return nil
}

func updateWebhook(ctx context.Context, c client.Client, oldObj, newObj client.Object) error {
if !reflect.DeepEqual(oldObj, newObj) {
err := c.Update(ctx, newObj)
if err != nil {
return err
}
log.Info("update webhook success")
}
return nil
}

func createOrUpdateCert(ctx context.Context, c client.Client, serviceNamespace, serviceName, clusterDomain, clusterID string, urlMode bool) (*corev1.Secret, error) {
secretName := fmt.Sprintf("%s-webhook-cert", serviceName)
cn := fmt.Sprintf("%s.%s.svc", serviceName, serviceNamespace)
dnsNames := []string{
serviceName,
fmt.Sprintf("%s.%s", serviceName, serviceNamespace),
fmt.Sprintf("%s.%s.svc", serviceName, serviceNamespace),
}
if urlMode {
if clusterID == "" {
return nil, fmt.Errorf("clusterID is required in urlMode")
}
dnsNames = append(dnsNames,
fmt.Sprintf("%s.%s", serviceName, clusterID),
fmt.Sprintf("%s.%s.svc", serviceName, clusterID),
fmt.Sprintf("%s.%s.svc.cluster.local", serviceName, clusterID),
)
}

// check secret
var serverCertBytes, serverKeyBytes, caCertBytes []byte

// get cert from secret or generate it
existSecret := &corev1.Secret{}
err := c.Get(ctx, types.NamespacedName{Namespace: serviceNamespace, Name: secretName}, existSecret)
if err != nil {
if !errors.IsNotFound(err) {
return nil, fmt.Errorf("error get cert from secret, %w", err)
}
}

if needRegenerate(existSecret.Data[serverCertKey], dnsNames) {
// create certs
s, err := GenerateCerts(cn, clusterDomain, dnsNames)
if err != nil {
return nil, fmt.Errorf("error generate cert, %w", err)
}

s.Name = secretName
s.Namespace = serviceNamespace

update := s.DeepCopy()
result, err := controllerutil.CreateOrUpdate(ctx, c, update, func() error {
update.Data = s.Data
return nil
})
if err != nil {
return nil, err
}
existSecret = update
log.Info("update secret", "result", result)
}

serverCertBytes = existSecret.Data[serverCertKey]
serverKeyBytes = existSecret.Data[serverKeyKey]
caCertBytes = existSecret.Data[caCertKey]

if len(serverCertBytes) == 0 || len(serverKeyBytes) == 0 || len(caCertBytes) == 0 {
return nil, fmt.Errorf("invalid cert")
}

return existSecret, nil
}

func createOrUpdateWebhook(ctx context.Context, c client.Client, hookName string, urlEndpoint string, caCertBytes []byte) error {
// update webhook
mutatingWebhook := &admissionregistrationv1.MutatingWebhookConfiguration{}
err = c.Get(ctx, types.NamespacedName{Namespace: ns, Name: name}, mutatingWebhook)
err := c.Get(ctx, types.NamespacedName{Name: hookName}, mutatingWebhook)
if err != nil {
return err
}
if len(mutatingWebhook.Webhooks) == 0 {
return fmt.Errorf("no webhook config found")
}
oldMutatingWebhook := mutatingWebhook.DeepCopy()
changed := false
for i, hook := range mutatingWebhook.Webhooks {
if len(hook.ClientConfig.CABundle) != 0 {
continue
}
changed = true
// patch ca
mutatingWebhook.Webhooks[i].ClientConfig.CABundle = caCertBytes
}
if changed {
err = wait.ExponentialBackoffWithContext(ctx, utils.DefaultPatchBackoff, func(ctx context.Context) (done bool, err error) {
innerErr := c.Patch(ctx, mutatingWebhook, client.StrategicMergeFrom(oldMutatingWebhook))
if innerErr != nil {
log.Error(innerErr, "error patch ca")
return false, nil
}
return true, nil
})

if err != nil {
return err
}
log.Info("update MutatingWebhook ca bundle success")
}

validateWebhook := &admissionregistrationv1.ValidatingWebhookConfiguration{}
err = c.Get(ctx, types.NamespacedName{Namespace: ns, Name: name}, validateWebhook)
err = c.Get(ctx, types.NamespacedName{Name: hookName}, validateWebhook)
if err != nil {
return err
}
if len(validateWebhook.Webhooks) == 0 {
return fmt.Errorf("no webhook config found")
}
changed = false

oldMutatingWebhook := mutatingWebhook.DeepCopy()
oldValidateWebhook := validateWebhook.DeepCopy()
for i, hook := range validateWebhook.Webhooks {
if len(hook.ClientConfig.CABundle) != 0 {
continue

for i, _ := range mutatingWebhook.Webhooks {
if string(mutatingWebhook.Webhooks[i].ClientConfig.CABundle) != string(caCertBytes) {
mutatingWebhook.Webhooks[i].ClientConfig.CABundle = caCertBytes
}
// patch ca
validateWebhook.Webhooks[i].ClientConfig.CABundle = caCertBytes
changed = true
}
if changed {
err = wait.ExponentialBackoffWithContext(ctx, utils.DefaultPatchBackoff, func(ctx context.Context) (done bool, err error) {
innerErr := c.Patch(ctx, validateWebhook, client.StrategicMergeFrom(oldValidateWebhook))
if innerErr != nil {
log.Error(innerErr, "error patch ca")
return false, nil
}
return true, nil
})
if err != nil {
return err
if urlEndpoint != "" {
mutatingWebhook.Webhooks[i].ClientConfig.Service = nil
mutatingWebhook.Webhooks[i].ClientConfig.URL = &urlEndpoint
}
}

for i, _ := range validateWebhook.Webhooks {
if string(validateWebhook.Webhooks[i].ClientConfig.CABundle) != string(caCertBytes) {
validateWebhook.Webhooks[i].ClientConfig.CABundle = caCertBytes
}
if urlEndpoint != "" {
validateWebhook.Webhooks[i].ClientConfig.Service = nil
validateWebhook.Webhooks[i].ClientConfig.URL = &urlEndpoint
}
log.Info("update ValidatingWebhook ca bundle success")
}

err = updateWebhook(ctx, c, oldMutatingWebhook, mutatingWebhook)
if err != nil {
return err
}

err = updateWebhook(ctx, c, oldValidateWebhook, validateWebhook)
if err != nil {
return err
}
return nil
}

func GenerateCerts(serviceNamespace, serviceName, clusterDomain string) (*corev1.Secret, error) {
// needRegenerate check exist cert with dnsNames,
func needRegenerate(certPEM []byte, dnsNames []string) bool {
if certPEM == nil {
return true
}
block, _ := pem.Decode(certPEM)
if block == nil {
log.Info("failed to decode PEM block containing the certificate")
return true
}
// Parse the certificate
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
log.Error(err, "error parse cert")
return true
}
oldSet := sets.New[string](cert.DNSNames...)
newSet := sets.New[string](dnsNames...)
return !oldSet.Equal(newSet)
}

func GenerateCerts(cn, clusterDomain string, dnsNames []string) (*corev1.Secret, error) {
var caPEM, serverCertPEM, serverPrivateKeyPEM *bytes.Buffer
ca := &x509.Certificate{
SerialNumber: big.NewInt(2021),
Expand All @@ -199,6 +261,9 @@ func GenerateCerts(serviceNamespace, serviceName, clusterDomain string) (*corev1
return nil, err
}

if len(dnsNames) == 0 {
return nil, fmt.Errorf("dnsNames is empty")
}
caPEM = new(bytes.Buffer)
err = pem.Encode(caPEM, &pem.Block{
Type: "CERTIFICATE",
Expand All @@ -208,16 +273,11 @@ func GenerateCerts(serviceNamespace, serviceName, clusterDomain string) (*corev1
return nil, err
}

commonName := fmt.Sprintf("%s.%s.svc", serviceName, serviceNamespace)
dnsNames := []string{serviceName,
fmt.Sprintf("%s.%s", serviceName, serviceNamespace),
commonName}

cert := &x509.Certificate{
SerialNumber: big.NewInt(2021),
DNSNames: dnsNames,
Subject: pkix.Name{
CommonName: commonName,
CommonName: cn,
Organization: []string{clusterDomain},
},
NotBefore: time.Now().Add(-time.Hour),
Expand Down
Loading

0 comments on commit 1c4aa0b

Please sign in to comment.