Skip to content

Commit

Permalink
webhook: migrate to url endpoint
Browse files Browse the repository at this point in the history
Signed-off-by: l1b0k <[email protected]>
  • Loading branch information
l1b0k committed Dec 25, 2024
1 parent 1c70388 commit 9c6817e
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 57 deletions.
2 changes: 1 addition & 1 deletion cmd/terway-controlplane/terway-controlplane.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func main() {
options := newOption(cfg)

if !cfg.DisableWebhook {
err = cert.SyncCert(ctx, directClient, cfg.ControllerNamespace, cfg.ControllerName, cfg.ClusterDomain, cfg.CertDir)
err = cert.SyncCert(ctx, directClient, cfg.ControllerNamespace, cfg.ControllerName, cfg.ClusterDomain, cfg.CertDir, cfg.ClusterID, cfg.WebhookURLMode)
if err != nil {
panic(err)
}
Expand Down
160 changes: 104 additions & 56 deletions pkg/cert/webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@ import (
"math/big"
"os"
"path/filepath"
"reflect"
"time"

"k8s.io/apimachinery/pkg/util/sets"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"

Expand All @@ -36,20 +38,47 @@ const (
)

// SyncCert sync cert for webhook
func SyncCert(ctx context.Context, c client.Client, ns, name, domain, certDir string) error {
secretName := fmt.Sprintf("%s-webhook-cert", name)
func SyncCert(ctx context.Context, c client.Client, serviceNamespace, serviceName, clusterDomain, certDir, clusterID string, urlMode bool) error {
secretName := fmt.Sprintf("%s-webhook-cert", serviceName)
// check secret
var serverCertBytes, serverKeyBytes, caCertBytes []byte

// get cert from secret or generate it
existSecret := &corev1.Secret{}
err := c.Get(ctx, types.NamespacedName{Namespace: ns, Name: secretName}, existSecret)
err := c.Get(ctx, types.NamespacedName{Namespace: serviceNamespace, Name: secretName}, existSecret)
if err != nil {
if !errors.IsNotFound(err) {
return fmt.Errorf("error get cert from secret, %w", err)
}
}

serverCertBytes = existSecret.Data[serverCertKey]
serverKeyBytes = existSecret.Data[serverKeyKey]
caCertBytes = existSecret.Data[caCertKey]
urlEndpoint := ""

cn := fmt.Sprintf("%s.%s.svc", serviceName, serviceNamespace)
dnsNames := []string{
serviceName,
fmt.Sprintf("%s.%s", serviceName, serviceNamespace),
fmt.Sprintf("%s.%s.svc", serviceName, serviceNamespace),
}
if urlMode {
if clusterID == "" {
return fmt.Errorf("clusterID is required in urlMode")
}
dnsNames = append(dnsNames,
fmt.Sprintf("%s.%s", serviceName, clusterID),
fmt.Sprintf("%s.%s.svc", serviceName, clusterID),
fmt.Sprintf("%s.%s.svc.cluster.local", serviceName, clusterID),
)

urlEndpoint = fmt.Sprintf("https://%s.%s.svc.cluster.local/mutating", serviceName, clusterID)
}

if needRegenerate(serverCertBytes, dnsNames) {
// create certs
s, err := GenerateCerts(ns, name, domain)
s, err := GenerateCerts(cn, clusterDomain, dnsNames)
if err != nil {
return fmt.Errorf("error generate cert, %w", err)
}
Expand All @@ -58,7 +87,7 @@ func SyncCert(ctx context.Context, c client.Client, ns, name, domain, certDir st
serverKeyBytes = s.Data[serverKeyKey]
caCertBytes = s.Data[caCertKey]
s.Name = secretName
s.Namespace = ns
s.Namespace = serviceNamespace
// create secret this make sure one is the leader
err = c.Create(ctx, s)
if err != nil {
Expand All @@ -67,19 +96,16 @@ func SyncCert(ctx context.Context, c client.Client, ns, name, domain, certDir st
}

secret := &corev1.Secret{}
err = c.Get(ctx, types.NamespacedName{Namespace: ns, Name: secretName}, secret)
err = c.Get(ctx, types.NamespacedName{Namespace: serviceNamespace, Name: secretName}, secret)
if err != nil {
return fmt.Errorf("error get cert from secret, %w", err)
}
serverCertBytes = secret.Data[serverCertKey]
serverKeyBytes = secret.Data[serverKeyKey]
caCertBytes = secret.Data[caCertKey]
}
} else {
serverCertBytes = existSecret.Data[serverCertKey]
serverKeyBytes = existSecret.Data[serverKeyKey]
caCertBytes = existSecret.Data[caCertKey]
}

if len(serverCertBytes) == 0 || len(serverKeyBytes) == 0 || len(caCertBytes) == 0 {
return fmt.Errorf("invalid cert")
}
Expand All @@ -105,76 +131,100 @@ func SyncCert(ctx context.Context, c client.Client, ns, name, domain, certDir st

// update webhook
mutatingWebhook := &admissionregistrationv1.MutatingWebhookConfiguration{}
err = c.Get(ctx, types.NamespacedName{Namespace: ns, Name: name}, mutatingWebhook)
err = c.Get(ctx, types.NamespacedName{Namespace: serviceNamespace, Name: serviceName}, mutatingWebhook)
if err != nil {
return err
}
if len(mutatingWebhook.Webhooks) == 0 {
return fmt.Errorf("no webhook config found")
}
oldMutatingWebhook := mutatingWebhook.DeepCopy()
changed := false
for i, hook := range mutatingWebhook.Webhooks {
if len(hook.ClientConfig.CABundle) != 0 {
continue
}
changed = true
// patch ca
mutatingWebhook.Webhooks[i].ClientConfig.CABundle = caCertBytes
}
if changed {
err = wait.ExponentialBackoffWithContext(ctx, utils.DefaultPatchBackoff, func(ctx context.Context) (done bool, err error) {
innerErr := c.Patch(ctx, mutatingWebhook, client.StrategicMergeFrom(oldMutatingWebhook))
if innerErr != nil {
log.Error(innerErr, "error patch ca")
return false, nil
}
return true, nil
})

if err != nil {
return err
}
log.Info("update MutatingWebhook ca bundle success")
}

validateWebhook := &admissionregistrationv1.ValidatingWebhookConfiguration{}
err = c.Get(ctx, types.NamespacedName{Namespace: ns, Name: name}, validateWebhook)
err = c.Get(ctx, types.NamespacedName{Namespace: serviceNamespace, Name: serviceName}, validateWebhook)
if err != nil {
return err
}
if len(validateWebhook.Webhooks) == 0 {
return fmt.Errorf("no webhook config found")
}
changed = false

oldMutatingWebhook := mutatingWebhook.DeepCopy()
oldValidateWebhook := validateWebhook.DeepCopy()
for i, hook := range validateWebhook.Webhooks {
if len(hook.ClientConfig.CABundle) != 0 {
continue

for i, _ := range mutatingWebhook.Webhooks {
if string(mutatingWebhook.Webhooks[i].ClientConfig.CABundle) != string(caCertBytes) {
mutatingWebhook.Webhooks[i].ClientConfig.CABundle = caCertBytes
}
if urlMode {
mutatingWebhook.Webhooks[i].ClientConfig.Service = nil
mutatingWebhook.Webhooks[i].ClientConfig.URL = &urlEndpoint
}
// patch ca
validateWebhook.Webhooks[i].ClientConfig.CABundle = caCertBytes
changed = true
}
if changed {
err = wait.ExponentialBackoffWithContext(ctx, utils.DefaultPatchBackoff, func(ctx context.Context) (done bool, err error) {
innerErr := c.Patch(ctx, validateWebhook, client.StrategicMergeFrom(oldValidateWebhook))

for i, _ := range validateWebhook.Webhooks {
if string(validateWebhook.Webhooks[i].ClientConfig.CABundle) != string(caCertBytes) {
validateWebhook.Webhooks[i].ClientConfig.CABundle = caCertBytes
}
if urlMode {
validateWebhook.Webhooks[i].ClientConfig.Service = nil
validateWebhook.Webhooks[i].ClientConfig.URL = &urlEndpoint
}
}

err = updateWebhook(ctx, c, mutatingWebhook, oldMutatingWebhook)
if err != nil {
return err
}

err = updateWebhook(ctx, c, validateWebhook, oldValidateWebhook)
if err != nil {
return err
}

return nil
}

func updateWebhook(ctx context.Context, c client.Client, oldObj, newObj client.Object) error {
if !reflect.DeepEqual(newObj, oldObj) {
err := wait.ExponentialBackoffWithContext(ctx, utils.DefaultPatchBackoff, func(ctx context.Context) (done bool, err error) {
innerErr := c.Patch(ctx, newObj, client.StrategicMergeFrom(oldObj))
if innerErr != nil {
log.Error(innerErr, "error patch ca")
return false, nil
}
return true, nil
})

if err != nil {
return err
}
log.Info("update ValidatingWebhook ca bundle success")
log.Info("update webhook success")
}

return nil
}

func GenerateCerts(serviceNamespace, serviceName, clusterDomain string) (*corev1.Secret, error) {
// needRegenerate check exist cert with dnsNames,
func needRegenerate(certPEM []byte, dnsNames []string) bool {
if certPEM == nil {
return true
}
block, _ := pem.Decode(certPEM)
if block == nil {
log.Info("failed to decode PEM block containing the certificate")
return true
}
// Parse the certificate
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
log.Error(err, "error parse cert")
return true
}
oldSet := sets.New[string](cert.DNSNames...)
newSet := sets.New[string](dnsNames...)
return !oldSet.Equal(newSet)
}

func GenerateCerts(cn, clusterDomain string, dnsNames []string) (*corev1.Secret, error) {
var caPEM, serverCertPEM, serverPrivateKeyPEM *bytes.Buffer
ca := &x509.Certificate{
SerialNumber: big.NewInt(2021),
Expand All @@ -199,6 +249,9 @@ func GenerateCerts(serviceNamespace, serviceName, clusterDomain string) (*corev1
return nil, err
}

if len(dnsNames) == 0 {
return nil, fmt.Errorf("dnsNames is empty")
}
caPEM = new(bytes.Buffer)
err = pem.Encode(caPEM, &pem.Block{
Type: "CERTIFICATE",
Expand All @@ -208,16 +261,11 @@ func GenerateCerts(serviceNamespace, serviceName, clusterDomain string) (*corev1
return nil, err
}

commonName := fmt.Sprintf("%s.%s.svc", serviceName, serviceNamespace)
dnsNames := []string{serviceName,
fmt.Sprintf("%s.%s", serviceName, serviceNamespace),
commonName}

cert := &x509.Certificate{
SerialNumber: big.NewInt(2021),
DNSNames: dnsNames,
Subject: pkix.Name{
CommonName: commonName,
CommonName: cn,
Organization: []string{clusterDomain},
},
NotBefore: time.Now().Add(-time.Hour),
Expand Down
85 changes: 85 additions & 0 deletions pkg/cert/webhook_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package cert

import (
"bytes"
"context"
"crypto/rand"
"testing"

"github.com/stretchr/testify/assert"
admissionregistrationv1 "k8s.io/api/admissionregistration/v1"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"sigs.k8s.io/controller-runtime/pkg/client/fake"
)

func TestGenerateCerts_ValidInput(t *testing.T) {
cn := "test-service"
clusterDomain := "cluster.local"
dnsNames := []string{"test-service", "test-service.namespace", "test-service.namespace.svc"}

secret, err := GenerateCerts(cn, clusterDomain, dnsNames)

assert.NoError(t, err)
assert.NotNil(t, secret)
assert.Equal(t, corev1.SecretTypeTLS, secret.Type)
assert.Contains(t, secret.Data, "ca.crt")
assert.Contains(t, secret.Data, "tls.crt")
assert.Contains(t, secret.Data, "tls.key")

need := needRegenerate(secret.Data["tls.crt"], dnsNames)
assert.False(t, need)
}

func TestGenerateCerts_EmptyDNSNames(t *testing.T) {
cn := "test-service"
clusterDomain := "cluster.local"
dnsNames := []string{}

secret, err := GenerateCerts(cn, clusterDomain, dnsNames)

assert.Error(t, err)
assert.Nil(t, secret)
}

func TestGenerateCerts_InvalidKeyGeneration(t *testing.T) {
cn := "test-service"
clusterDomain := "cluster.local"
dnsNames := []string{"test-service", "test-service.namespace", "test-service.namespace.svc"}

originalRandReader := rand.Reader
defer func() { rand.Reader = originalRandReader }()
rand.Reader = bytes.NewReader([]byte{})

secret, err := GenerateCerts(cn, clusterDomain, dnsNames)

assert.Error(t, err)
assert.Nil(t, secret)
}

func TestWebhookUpdatedSuccessfully(t *testing.T) {
ctx := context.TODO()
oldObj := &admissionregistrationv1.MutatingWebhookConfiguration{
ObjectMeta: metav1.ObjectMeta{
Name: "test-webhook",
},
}
c := fake.NewClientBuilder().WithObjects(oldObj).Build()
newObj := oldObj.DeepCopy()
newObj.Webhooks = append(newObj.Webhooks, admissionregistrationv1.MutatingWebhook{})

err := updateWebhook(ctx, c, oldObj, newObj)

assert.NoError(t, err)
}

func TestWebhookNotUpdatedWhenEqual(t *testing.T) {
ctx := context.TODO()
c := fake.NewClientBuilder().Build()
oldObj := &admissionregistrationv1.MutatingWebhookConfiguration{}
newObj := oldObj.DeepCopy()

err := updateWebhook(ctx, c, oldObj, newObj)

assert.NoError(t, err)
}
1 change: 1 addition & 0 deletions types/controlplane/config_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ type Config struct {
DisableWebhook bool `json:"disableWebhook"`
WebhookPort int `json:"webhookPort" validate:"gt=0,lte=65535" mod:"default=4443"`
CertDir string `json:"certDir" validate:"required" mod:"default=/var/run/webhook-cert"`
WebhookURLMode bool `json:"webhookURLMode"`
LeaderElection bool `json:"leaderElection"`
RegisterEndpoint bool `json:"registerEndpoint"` // deprecated
EnableTrace bool `json:"enableTrace"`
Expand Down

0 comments on commit 9c6817e

Please sign in to comment.