diff --git a/cmd/terway-controlplane/terway-controlplane.go b/cmd/terway-controlplane/terway-controlplane.go index 573c258f..dbef03ff 100644 --- a/cmd/terway-controlplane/terway-controlplane.go +++ b/cmd/terway-controlplane/terway-controlplane.go @@ -141,7 +141,7 @@ func main() { options := newOption(cfg) if !cfg.DisableWebhook { - err = cert.SyncCert(ctx, directClient, cfg.ControllerNamespace, cfg.ControllerName, cfg.ClusterDomain, cfg.CertDir) + err = cert.SyncCert(ctx, directClient, cfg.ControllerNamespace, cfg.ControllerName, cfg.ClusterDomain, cfg.CertDir, cfg.ClusterID, cfg.WebhookURLMode) if err != nil { panic(err) } diff --git a/pkg/cert/webhook.go b/pkg/cert/webhook.go index 6e85f204..6b7d9efd 100644 --- a/pkg/cert/webhook.go +++ b/pkg/cert/webhook.go @@ -12,8 +12,10 @@ import ( "math/big" "os" "path/filepath" + "reflect" "time" + "k8s.io/apimachinery/pkg/util/sets" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" @@ -36,20 +38,47 @@ const ( ) // SyncCert sync cert for webhook -func SyncCert(ctx context.Context, c client.Client, ns, name, domain, certDir string) error { - secretName := fmt.Sprintf("%s-webhook-cert", name) +func SyncCert(ctx context.Context, c client.Client, serviceNamespace, serviceName, clusterDomain, certDir, clusterID string, urlMode bool) error { + secretName := fmt.Sprintf("%s-webhook-cert", serviceName) // check secret var serverCertBytes, serverKeyBytes, caCertBytes []byte // get cert from secret or generate it existSecret := &corev1.Secret{} - err := c.Get(ctx, types.NamespacedName{Namespace: ns, Name: secretName}, existSecret) + err := c.Get(ctx, types.NamespacedName{Namespace: serviceNamespace, Name: secretName}, existSecret) if err != nil { if !errors.IsNotFound(err) { return fmt.Errorf("error get cert from secret, %w", err) } + } + + serverCertBytes = existSecret.Data[serverCertKey] + serverKeyBytes = existSecret.Data[serverKeyKey] + caCertBytes = existSecret.Data[caCertKey] + urlEndpoint := "" + + cn := fmt.Sprintf("%s.%s.svc", serviceName, serviceNamespace) + dnsNames := []string{ + serviceName, + fmt.Sprintf("%s.%s", serviceName, serviceNamespace), + fmt.Sprintf("%s.%s.svc", serviceName, serviceNamespace), + } + if urlMode { + if clusterID == "" { + return fmt.Errorf("clusterID is required in urlMode") + } + dnsNames = append(dnsNames, + fmt.Sprintf("%s.%s", serviceName, clusterID), + fmt.Sprintf("%s.%s.svc", serviceName, clusterID), + fmt.Sprintf("%s.%s.svc.cluster.local", serviceName, clusterID), + ) + + urlEndpoint = fmt.Sprintf("https://%s.%s.svc.cluster.local/mutating", serviceName, clusterID) + } + + if needRegenerate(serverCertBytes, dnsNames) { // create certs - s, err := GenerateCerts(ns, name, domain) + s, err := GenerateCerts(cn, clusterDomain, dnsNames) if err != nil { return fmt.Errorf("error generate cert, %w", err) } @@ -58,7 +87,7 @@ func SyncCert(ctx context.Context, c client.Client, ns, name, domain, certDir st serverKeyBytes = s.Data[serverKeyKey] caCertBytes = s.Data[caCertKey] s.Name = secretName - s.Namespace = ns + s.Namespace = serviceNamespace // create secret this make sure one is the leader err = c.Create(ctx, s) if err != nil { @@ -67,7 +96,7 @@ func SyncCert(ctx context.Context, c client.Client, ns, name, domain, certDir st } secret := &corev1.Secret{} - err = c.Get(ctx, types.NamespacedName{Namespace: ns, Name: secretName}, secret) + err = c.Get(ctx, types.NamespacedName{Namespace: serviceNamespace, Name: secretName}, secret) if err != nil { return fmt.Errorf("error get cert from secret, %w", err) } @@ -75,11 +104,8 @@ func SyncCert(ctx context.Context, c client.Client, ns, name, domain, certDir st serverKeyBytes = secret.Data[serverKeyKey] caCertBytes = secret.Data[caCertKey] } - } else { - serverCertBytes = existSecret.Data[serverCertKey] - serverKeyBytes = existSecret.Data[serverKeyKey] - caCertBytes = existSecret.Data[caCertKey] } + if len(serverCertBytes) == 0 || len(serverKeyBytes) == 0 || len(caCertBytes) == 0 { return fmt.Errorf("invalid cert") } @@ -105,76 +131,100 @@ func SyncCert(ctx context.Context, c client.Client, ns, name, domain, certDir st // update webhook mutatingWebhook := &admissionregistrationv1.MutatingWebhookConfiguration{} - err = c.Get(ctx, types.NamespacedName{Namespace: ns, Name: name}, mutatingWebhook) + err = c.Get(ctx, types.NamespacedName{Namespace: serviceNamespace, Name: serviceName}, mutatingWebhook) if err != nil { return err } if len(mutatingWebhook.Webhooks) == 0 { return fmt.Errorf("no webhook config found") } - oldMutatingWebhook := mutatingWebhook.DeepCopy() - changed := false - for i, hook := range mutatingWebhook.Webhooks { - if len(hook.ClientConfig.CABundle) != 0 { - continue - } - changed = true - // patch ca - mutatingWebhook.Webhooks[i].ClientConfig.CABundle = caCertBytes - } - if changed { - err = wait.ExponentialBackoffWithContext(ctx, utils.DefaultPatchBackoff, func(ctx context.Context) (done bool, err error) { - innerErr := c.Patch(ctx, mutatingWebhook, client.StrategicMergeFrom(oldMutatingWebhook)) - if innerErr != nil { - log.Error(innerErr, "error patch ca") - return false, nil - } - return true, nil - }) - - if err != nil { - return err - } - log.Info("update MutatingWebhook ca bundle success") - } validateWebhook := &admissionregistrationv1.ValidatingWebhookConfiguration{} - err = c.Get(ctx, types.NamespacedName{Namespace: ns, Name: name}, validateWebhook) + err = c.Get(ctx, types.NamespacedName{Namespace: serviceNamespace, Name: serviceName}, validateWebhook) if err != nil { return err } if len(validateWebhook.Webhooks) == 0 { return fmt.Errorf("no webhook config found") } - changed = false + + oldMutatingWebhook := mutatingWebhook.DeepCopy() oldValidateWebhook := validateWebhook.DeepCopy() - for i, hook := range validateWebhook.Webhooks { - if len(hook.ClientConfig.CABundle) != 0 { - continue + + for i, _ := range mutatingWebhook.Webhooks { + if string(mutatingWebhook.Webhooks[i].ClientConfig.CABundle) != string(caCertBytes) { + mutatingWebhook.Webhooks[i].ClientConfig.CABundle = caCertBytes + } + if urlMode { + mutatingWebhook.Webhooks[i].ClientConfig.Service = nil + mutatingWebhook.Webhooks[i].ClientConfig.URL = &urlEndpoint } - // patch ca - validateWebhook.Webhooks[i].ClientConfig.CABundle = caCertBytes - changed = true } - if changed { - err = wait.ExponentialBackoffWithContext(ctx, utils.DefaultPatchBackoff, func(ctx context.Context) (done bool, err error) { - innerErr := c.Patch(ctx, validateWebhook, client.StrategicMergeFrom(oldValidateWebhook)) + + for i, _ := range validateWebhook.Webhooks { + if string(validateWebhook.Webhooks[i].ClientConfig.CABundle) != string(caCertBytes) { + validateWebhook.Webhooks[i].ClientConfig.CABundle = caCertBytes + } + if urlMode { + validateWebhook.Webhooks[i].ClientConfig.Service = nil + validateWebhook.Webhooks[i].ClientConfig.URL = &urlEndpoint + } + } + + err = updateWebhook(ctx, c, mutatingWebhook, oldMutatingWebhook) + if err != nil { + return err + } + + err = updateWebhook(ctx, c, validateWebhook, oldValidateWebhook) + if err != nil { + return err + } + + return nil +} + +func updateWebhook(ctx context.Context, c client.Client, oldObj, newObj client.Object) error { + if !reflect.DeepEqual(newObj, oldObj) { + err := wait.ExponentialBackoffWithContext(ctx, utils.DefaultPatchBackoff, func(ctx context.Context) (done bool, err error) { + innerErr := c.Patch(ctx, newObj, client.StrategicMergeFrom(oldObj)) if innerErr != nil { log.Error(innerErr, "error patch ca") return false, nil } return true, nil }) + if err != nil { return err } - log.Info("update ValidatingWebhook ca bundle success") + log.Info("update webhook success") } - return nil } -func GenerateCerts(serviceNamespace, serviceName, clusterDomain string) (*corev1.Secret, error) { +// needRegenerate check exist cert with dnsNames, +func needRegenerate(certPEM []byte, dnsNames []string) bool { + if certPEM == nil { + return true + } + block, _ := pem.Decode(certPEM) + if block == nil { + log.Info("failed to decode PEM block containing the certificate") + return true + } + // Parse the certificate + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + log.Error(err, "error parse cert") + return true + } + oldSet := sets.New[string](cert.DNSNames...) + newSet := sets.New[string](dnsNames...) + return !oldSet.Equal(newSet) +} + +func GenerateCerts(cn, clusterDomain string, dnsNames []string) (*corev1.Secret, error) { var caPEM, serverCertPEM, serverPrivateKeyPEM *bytes.Buffer ca := &x509.Certificate{ SerialNumber: big.NewInt(2021), @@ -199,6 +249,9 @@ func GenerateCerts(serviceNamespace, serviceName, clusterDomain string) (*corev1 return nil, err } + if len(dnsNames) == 0 { + return nil, fmt.Errorf("dnsNames is empty") + } caPEM = new(bytes.Buffer) err = pem.Encode(caPEM, &pem.Block{ Type: "CERTIFICATE", @@ -208,16 +261,11 @@ func GenerateCerts(serviceNamespace, serviceName, clusterDomain string) (*corev1 return nil, err } - commonName := fmt.Sprintf("%s.%s.svc", serviceName, serviceNamespace) - dnsNames := []string{serviceName, - fmt.Sprintf("%s.%s", serviceName, serviceNamespace), - commonName} - cert := &x509.Certificate{ SerialNumber: big.NewInt(2021), DNSNames: dnsNames, Subject: pkix.Name{ - CommonName: commonName, + CommonName: cn, Organization: []string{clusterDomain}, }, NotBefore: time.Now().Add(-time.Hour), diff --git a/pkg/cert/webhook_test.go b/pkg/cert/webhook_test.go new file mode 100644 index 00000000..2d616fbe --- /dev/null +++ b/pkg/cert/webhook_test.go @@ -0,0 +1,85 @@ +package cert + +import ( + "bytes" + "context" + "crypto/rand" + "testing" + + "github.com/stretchr/testify/assert" + admissionregistrationv1 "k8s.io/api/admissionregistration/v1" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "sigs.k8s.io/controller-runtime/pkg/client/fake" +) + +func TestGenerateCerts_ValidInput(t *testing.T) { + cn := "test-service" + clusterDomain := "cluster.local" + dnsNames := []string{"test-service", "test-service.namespace", "test-service.namespace.svc"} + + secret, err := GenerateCerts(cn, clusterDomain, dnsNames) + + assert.NoError(t, err) + assert.NotNil(t, secret) + assert.Equal(t, corev1.SecretTypeTLS, secret.Type) + assert.Contains(t, secret.Data, "ca.crt") + assert.Contains(t, secret.Data, "tls.crt") + assert.Contains(t, secret.Data, "tls.key") + + need := needRegenerate(secret.Data["tls.crt"], dnsNames) + assert.False(t, need) +} + +func TestGenerateCerts_EmptyDNSNames(t *testing.T) { + cn := "test-service" + clusterDomain := "cluster.local" + dnsNames := []string{} + + secret, err := GenerateCerts(cn, clusterDomain, dnsNames) + + assert.Error(t, err) + assert.Nil(t, secret) +} + +func TestGenerateCerts_InvalidKeyGeneration(t *testing.T) { + cn := "test-service" + clusterDomain := "cluster.local" + dnsNames := []string{"test-service", "test-service.namespace", "test-service.namespace.svc"} + + originalRandReader := rand.Reader + defer func() { rand.Reader = originalRandReader }() + rand.Reader = bytes.NewReader([]byte{}) + + secret, err := GenerateCerts(cn, clusterDomain, dnsNames) + + assert.Error(t, err) + assert.Nil(t, secret) +} + +func TestWebhookUpdatedSuccessfully(t *testing.T) { + ctx := context.TODO() + oldObj := &admissionregistrationv1.MutatingWebhookConfiguration{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-webhook", + }, + } + c := fake.NewClientBuilder().WithObjects(oldObj).Build() + newObj := oldObj.DeepCopy() + newObj.Webhooks = append(newObj.Webhooks, admissionregistrationv1.MutatingWebhook{}) + + err := updateWebhook(ctx, c, oldObj, newObj) + + assert.NoError(t, err) +} + +func TestWebhookNotUpdatedWhenEqual(t *testing.T) { + ctx := context.TODO() + c := fake.NewClientBuilder().Build() + oldObj := &admissionregistrationv1.MutatingWebhookConfiguration{} + newObj := oldObj.DeepCopy() + + err := updateWebhook(ctx, c, oldObj, newObj) + + assert.NoError(t, err) +} diff --git a/types/controlplane/config_default.go b/types/controlplane/config_default.go index c10c43de..4533e4b1 100644 --- a/types/controlplane/config_default.go +++ b/types/controlplane/config_default.go @@ -41,6 +41,7 @@ type Config struct { DisableWebhook bool `json:"disableWebhook"` WebhookPort int `json:"webhookPort" validate:"gt=0,lte=65535" mod:"default=4443"` CertDir string `json:"certDir" validate:"required" mod:"default=/var/run/webhook-cert"` + WebhookURLMode bool `json:"webhookURLMode"` LeaderElection bool `json:"leaderElection"` RegisterEndpoint bool `json:"registerEndpoint"` // deprecated EnableTrace bool `json:"enableTrace"`