diff --git a/cmd/terway-controlplane/terway-controlplane.go b/cmd/terway-controlplane/terway-controlplane.go index 573c258f..dbef03ff 100644 --- a/cmd/terway-controlplane/terway-controlplane.go +++ b/cmd/terway-controlplane/terway-controlplane.go @@ -141,7 +141,7 @@ func main() { options := newOption(cfg) if !cfg.DisableWebhook { - err = cert.SyncCert(ctx, directClient, cfg.ControllerNamespace, cfg.ControllerName, cfg.ClusterDomain, cfg.CertDir) + err = cert.SyncCert(ctx, directClient, cfg.ControllerNamespace, cfg.ControllerName, cfg.ClusterDomain, cfg.CertDir, cfg.ClusterID, cfg.WebhookURLMode) if err != nil { panic(err) } diff --git a/pkg/cert/webhook.go b/pkg/cert/webhook.go index 6e85f204..113e66f0 100644 --- a/pkg/cert/webhook.go +++ b/pkg/cert/webhook.go @@ -12,18 +12,19 @@ import ( "math/big" "os" "path/filepath" + "reflect" "time" + "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/client-go/util/retry" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" admissionregistrationv1 "k8s.io/api/admissionregistration/v1" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/types" - "k8s.io/apimachinery/pkg/util/wait" - - "github.com/AliyunContainerService/terway/pkg/utils" ) var log = ctrl.Log.WithName("webhook-cert") @@ -36,52 +37,31 @@ const ( ) // SyncCert sync cert for webhook -func SyncCert(ctx context.Context, c client.Client, ns, name, domain, certDir string) error { - secretName := fmt.Sprintf("%s-webhook-cert", name) +func SyncCert(ctx context.Context, c client.Client, serviceNamespace, serviceName, clusterDomain, certDir, clusterID string, urlMode bool) error { // check secret - var serverCertBytes, serverKeyBytes, caCertBytes []byte + var caCertBytes []byte - // get cert from secret or generate it - existSecret := &corev1.Secret{} - err := c.Get(ctx, types.NamespacedName{Namespace: ns, Name: secretName}, existSecret) + var secret *corev1.Secret + err := retry.RetryOnConflict(retry.DefaultRetry, func() error { + var err error + secret, err = createOrUpdateCert(ctx, c, serviceNamespace, serviceName, clusterDomain, clusterID, urlMode) + return err + }) if err != nil { - if !errors.IsNotFound(err) { - return fmt.Errorf("error get cert from secret, %w", err) - } - // create certs - s, err := GenerateCerts(ns, name, domain) - if err != nil { - return fmt.Errorf("error generate cert, %w", err) - } + return err + } + caCertBytes = secret.Data[caCertKey] - serverCertBytes = s.Data[serverCertKey] - serverKeyBytes = s.Data[serverKeyKey] - caCertBytes = s.Data[caCertKey] - s.Name = secretName - s.Namespace = ns - // create secret this make sure one is the leader - err = c.Create(ctx, s) - if err != nil { - if !errors.IsAlreadyExists(err) { - return fmt.Errorf("error create cert to secret, %w", err) - } - - secret := &corev1.Secret{} - err = c.Get(ctx, types.NamespacedName{Namespace: ns, Name: secretName}, secret) - if err != nil { - return fmt.Errorf("error get cert from secret, %w", err) - } - serverCertBytes = secret.Data[serverCertKey] - serverKeyBytes = secret.Data[serverKeyKey] - caCertBytes = secret.Data[caCertKey] - } - } else { - serverCertBytes = existSecret.Data[serverCertKey] - serverKeyBytes = existSecret.Data[serverKeyKey] - caCertBytes = existSecret.Data[caCertKey] + urlEndpoint := "" + if urlMode { + urlEndpoint = fmt.Sprintf("https://%s.%s.svc.cluster.local./mutating", serviceName, clusterID) } - if len(serverCertBytes) == 0 || len(serverKeyBytes) == 0 || len(caCertBytes) == 0 { - return fmt.Errorf("invalid cert") + + err = retry.RetryOnConflict(retry.DefaultRetry, func() error { + return createOrUpdateWebhook(ctx, c, serviceName, urlEndpoint, caCertBytes) + }) + if err != nil { + return err } err = os.MkdirAll(certDir, os.ModeDir) @@ -90,11 +70,11 @@ func SyncCert(ctx context.Context, c client.Client, ns, name, domain, certDir st } // write cert to file - err = os.WriteFile(filepath.Join(certDir, serverCertKey), serverCertBytes, os.ModePerm) + err = os.WriteFile(filepath.Join(certDir, serverCertKey), secret.Data[serverCertKey], os.ModePerm) if err != nil { return fmt.Errorf("error create secret file, %w", err) } - err = os.WriteFile(filepath.Join(certDir, serverKeyKey), serverKeyBytes, os.ModePerm) + err = os.WriteFile(filepath.Join(certDir, serverKeyKey), secret.Data[serverKeyKey], os.ModePerm) if err != nil { return fmt.Errorf("error create secret file, %w", err) } @@ -102,79 +82,161 @@ func SyncCert(ctx context.Context, c client.Client, ns, name, domain, certDir st if err != nil { return fmt.Errorf("error create secret file, %w", err) } + return nil +} + +func updateWebhook(ctx context.Context, c client.Client, oldObj, newObj client.Object) error { + if !reflect.DeepEqual(oldObj, newObj) { + err := c.Update(ctx, newObj) + if err != nil { + return err + } + log.Info("update webhook success") + } + return nil +} + +func createOrUpdateCert(ctx context.Context, c client.Client, serviceNamespace, serviceName, clusterDomain, clusterID string, urlMode bool) (*corev1.Secret, error) { + secretName := fmt.Sprintf("%s-webhook-cert", serviceName) + cn := fmt.Sprintf("%s.%s.svc", serviceName, serviceNamespace) + dnsNames := []string{ + serviceName, + fmt.Sprintf("%s.%s", serviceName, serviceNamespace), + fmt.Sprintf("%s.%s.svc", serviceName, serviceNamespace), + } + if urlMode { + if clusterID == "" { + return nil, fmt.Errorf("clusterID is required in urlMode") + } + dnsNames = append(dnsNames, + fmt.Sprintf("%s.%s", serviceName, clusterID), + fmt.Sprintf("%s.%s.svc", serviceName, clusterID), + fmt.Sprintf("%s.%s.svc.cluster.local", serviceName, clusterID), + ) + } + + // check secret + var serverCertBytes, serverKeyBytes, caCertBytes []byte + + // get cert from secret or generate it + existSecret := &corev1.Secret{} + err := c.Get(ctx, types.NamespacedName{Namespace: serviceNamespace, Name: secretName}, existSecret) + if err != nil { + if !errors.IsNotFound(err) { + return nil, fmt.Errorf("error get cert from secret, %w", err) + } + } + + if needRegenerate(existSecret.Data[serverCertKey], dnsNames) { + // create certs + s, err := GenerateCerts(cn, clusterDomain, dnsNames) + if err != nil { + return nil, fmt.Errorf("error generate cert, %w", err) + } + + s.Name = secretName + s.Namespace = serviceNamespace + + update := s.DeepCopy() + result, err := controllerutil.CreateOrUpdate(ctx, c, update, func() error { + update.Data = s.Data + return nil + }) + if err != nil { + return nil, err + } + existSecret = update + log.Info("update secret", "result", result) + } + + serverCertBytes = existSecret.Data[serverCertKey] + serverKeyBytes = existSecret.Data[serverKeyKey] + caCertBytes = existSecret.Data[caCertKey] + + if len(serverCertBytes) == 0 || len(serverKeyBytes) == 0 || len(caCertBytes) == 0 { + return nil, fmt.Errorf("invalid cert") + } + return existSecret, nil +} + +func createOrUpdateWebhook(ctx context.Context, c client.Client, hookName string, urlEndpoint string, caCertBytes []byte) error { // update webhook mutatingWebhook := &admissionregistrationv1.MutatingWebhookConfiguration{} - err = c.Get(ctx, types.NamespacedName{Namespace: ns, Name: name}, mutatingWebhook) + err := c.Get(ctx, types.NamespacedName{Name: hookName}, mutatingWebhook) if err != nil { return err } if len(mutatingWebhook.Webhooks) == 0 { return fmt.Errorf("no webhook config found") } - oldMutatingWebhook := mutatingWebhook.DeepCopy() - changed := false - for i, hook := range mutatingWebhook.Webhooks { - if len(hook.ClientConfig.CABundle) != 0 { - continue - } - changed = true - // patch ca - mutatingWebhook.Webhooks[i].ClientConfig.CABundle = caCertBytes - } - if changed { - err = wait.ExponentialBackoffWithContext(ctx, utils.DefaultPatchBackoff, func(ctx context.Context) (done bool, err error) { - innerErr := c.Patch(ctx, mutatingWebhook, client.StrategicMergeFrom(oldMutatingWebhook)) - if innerErr != nil { - log.Error(innerErr, "error patch ca") - return false, nil - } - return true, nil - }) - - if err != nil { - return err - } - log.Info("update MutatingWebhook ca bundle success") - } validateWebhook := &admissionregistrationv1.ValidatingWebhookConfiguration{} - err = c.Get(ctx, types.NamespacedName{Namespace: ns, Name: name}, validateWebhook) + err = c.Get(ctx, types.NamespacedName{Name: hookName}, validateWebhook) if err != nil { return err } if len(validateWebhook.Webhooks) == 0 { return fmt.Errorf("no webhook config found") } - changed = false + + oldMutatingWebhook := mutatingWebhook.DeepCopy() oldValidateWebhook := validateWebhook.DeepCopy() - for i, hook := range validateWebhook.Webhooks { - if len(hook.ClientConfig.CABundle) != 0 { - continue + + for i, _ := range mutatingWebhook.Webhooks { + if string(mutatingWebhook.Webhooks[i].ClientConfig.CABundle) != string(caCertBytes) { + mutatingWebhook.Webhooks[i].ClientConfig.CABundle = caCertBytes } - // patch ca - validateWebhook.Webhooks[i].ClientConfig.CABundle = caCertBytes - changed = true - } - if changed { - err = wait.ExponentialBackoffWithContext(ctx, utils.DefaultPatchBackoff, func(ctx context.Context) (done bool, err error) { - innerErr := c.Patch(ctx, validateWebhook, client.StrategicMergeFrom(oldValidateWebhook)) - if innerErr != nil { - log.Error(innerErr, "error patch ca") - return false, nil - } - return true, nil - }) - if err != nil { - return err + if urlEndpoint != "" { + mutatingWebhook.Webhooks[i].ClientConfig.Service = nil + mutatingWebhook.Webhooks[i].ClientConfig.URL = &urlEndpoint + } + } + + for i, _ := range validateWebhook.Webhooks { + if string(validateWebhook.Webhooks[i].ClientConfig.CABundle) != string(caCertBytes) { + validateWebhook.Webhooks[i].ClientConfig.CABundle = caCertBytes + } + if urlEndpoint != "" { + validateWebhook.Webhooks[i].ClientConfig.Service = nil + validateWebhook.Webhooks[i].ClientConfig.URL = &urlEndpoint } - log.Info("update ValidatingWebhook ca bundle success") } + err = updateWebhook(ctx, c, oldMutatingWebhook, mutatingWebhook) + if err != nil { + return err + } + + err = updateWebhook(ctx, c, oldValidateWebhook, validateWebhook) + if err != nil { + return err + } return nil } -func GenerateCerts(serviceNamespace, serviceName, clusterDomain string) (*corev1.Secret, error) { +// needRegenerate check exist cert with dnsNames, +func needRegenerate(certPEM []byte, dnsNames []string) bool { + if certPEM == nil { + return true + } + block, _ := pem.Decode(certPEM) + if block == nil { + log.Info("failed to decode PEM block containing the certificate") + return true + } + // Parse the certificate + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + log.Error(err, "error parse cert") + return true + } + oldSet := sets.New[string](cert.DNSNames...) + newSet := sets.New[string](dnsNames...) + return !oldSet.Equal(newSet) +} + +func GenerateCerts(cn, clusterDomain string, dnsNames []string) (*corev1.Secret, error) { var caPEM, serverCertPEM, serverPrivateKeyPEM *bytes.Buffer ca := &x509.Certificate{ SerialNumber: big.NewInt(2021), @@ -199,6 +261,9 @@ func GenerateCerts(serviceNamespace, serviceName, clusterDomain string) (*corev1 return nil, err } + if len(dnsNames) == 0 { + return nil, fmt.Errorf("dnsNames is empty") + } caPEM = new(bytes.Buffer) err = pem.Encode(caPEM, &pem.Block{ Type: "CERTIFICATE", @@ -208,16 +273,11 @@ func GenerateCerts(serviceNamespace, serviceName, clusterDomain string) (*corev1 return nil, err } - commonName := fmt.Sprintf("%s.%s.svc", serviceName, serviceNamespace) - dnsNames := []string{serviceName, - fmt.Sprintf("%s.%s", serviceName, serviceNamespace), - commonName} - cert := &x509.Certificate{ SerialNumber: big.NewInt(2021), DNSNames: dnsNames, Subject: pkix.Name{ - CommonName: commonName, + CommonName: cn, Organization: []string{clusterDomain}, }, NotBefore: time.Now().Add(-time.Hour), diff --git a/pkg/cert/webhook_test.go b/pkg/cert/webhook_test.go new file mode 100644 index 00000000..04f6e134 --- /dev/null +++ b/pkg/cert/webhook_test.go @@ -0,0 +1,247 @@ +package cert + +import ( + "bytes" + "context" + "crypto/rand" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + admissionregistrationv1 "k8s.io/api/admissionregistration/v1" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client/fake" +) + +func TestGenerateCerts_ValidInput(t *testing.T) { + cn := "test-service" + clusterDomain := "cluster.local" + dnsNames := []string{"test-service", "test-service.namespace", "test-service.namespace.svc"} + + secret, err := GenerateCerts(cn, clusterDomain, dnsNames) + + assert.NoError(t, err) + assert.NotNil(t, secret) + assert.Equal(t, corev1.SecretTypeTLS, secret.Type) + assert.Contains(t, secret.Data, "ca.crt") + assert.Contains(t, secret.Data, "tls.crt") + assert.Contains(t, secret.Data, "tls.key") + + need := needRegenerate(secret.Data["tls.crt"], dnsNames) + assert.False(t, need) +} + +func TestGenerateCerts_EmptyDNSNames(t *testing.T) { + cn := "test-service" + clusterDomain := "cluster.local" + dnsNames := []string{} + + secret, err := GenerateCerts(cn, clusterDomain, dnsNames) + + assert.Error(t, err) + assert.Nil(t, secret) +} + +func TestGenerateCerts_InvalidKeyGeneration(t *testing.T) { + cn := "test-service" + clusterDomain := "cluster.local" + dnsNames := []string{"test-service", "test-service.namespace", "test-service.namespace.svc"} + + originalRandReader := rand.Reader + defer func() { rand.Reader = originalRandReader }() + rand.Reader = bytes.NewReader([]byte{}) + + secret, err := GenerateCerts(cn, clusterDomain, dnsNames) + + assert.Error(t, err) + assert.Nil(t, secret) +} + +func TestWebhookUpdatedSuccessfully(t *testing.T) { + ctx := context.TODO() + oldObj := &admissionregistrationv1.MutatingWebhookConfiguration{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-webhook", + }, + } + c := fake.NewClientBuilder().WithObjects(oldObj).Build() + newObj := oldObj.DeepCopy() + newObj.Webhooks = append(newObj.Webhooks, admissionregistrationv1.MutatingWebhook{}) + + err := updateWebhook(ctx, c, oldObj, newObj) + + assert.NoError(t, err) +} + +func TestWebhookNotUpdatedWhenEqual(t *testing.T) { + ctx := context.TODO() + c := fake.NewClientBuilder().Build() + oldObj := &admissionregistrationv1.MutatingWebhookConfiguration{} + newObj := oldObj.DeepCopy() + + err := updateWebhook(ctx, c, oldObj, newObj) + + assert.NoError(t, err) +} + +func TestWebhookUpdatedWithCABundle(t *testing.T) { + ctx := context.TODO() + c := fake.NewClientBuilder().WithObjects(&admissionregistrationv1.MutatingWebhookConfiguration{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-webhook", + }, + Webhooks: []admissionregistrationv1.MutatingWebhook{ + { + ClientConfig: admissionregistrationv1.WebhookClientConfig{ + CABundle: []byte("old-ca-bundle"), + }, + }, + }, + }, &admissionregistrationv1.ValidatingWebhookConfiguration{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-webhook", + }, + Webhooks: []admissionregistrationv1.ValidatingWebhook{ + { + ClientConfig: admissionregistrationv1.WebhookClientConfig{ + CABundle: []byte("old-ca-bundle"), + }, + }, + }, + }).Build() + caCertBytes := []byte("new-ca-bundle") + + err := createOrUpdateWebhook(ctx, c, "test-webhook", "", caCertBytes) + + assert.NoError(t, err) + + updated := &admissionregistrationv1.MutatingWebhookConfiguration{} + err = c.Get(ctx, types.NamespacedName{Name: "test-webhook"}, updated) + assert.NoError(t, err) + + assert.Equal(t, caCertBytes, updated.Webhooks[0].ClientConfig.CABundle) +} + +func TestWebhookUpdatedWithURLEndpoint(t *testing.T) { + ctx := context.TODO() + c := fake.NewClientBuilder().WithObjects(&admissionregistrationv1.MutatingWebhookConfiguration{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-webhook", + }, + Webhooks: []admissionregistrationv1.MutatingWebhook{ + { + ClientConfig: admissionregistrationv1.WebhookClientConfig{ + CABundle: []byte("old-ca-bundle"), + }, + }, + }, + }, &admissionregistrationv1.ValidatingWebhookConfiguration{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-webhook", + }, + Webhooks: []admissionregistrationv1.ValidatingWebhook{ + { + ClientConfig: admissionregistrationv1.WebhookClientConfig{ + CABundle: []byte("old-ca-bundle"), + }, + }, + }, + }).Build() + caCertBytes := []byte("new-ca-bundle") + urlEndpoint := "https://example.com/webhook" + + err := createOrUpdateWebhook(ctx, c, "test-webhook", urlEndpoint, caCertBytes) + + assert.NoError(t, err) + + updated := &admissionregistrationv1.MutatingWebhookConfiguration{} + err = c.Get(ctx, types.NamespacedName{Name: "test-webhook"}, updated) + assert.NoError(t, err) + + assert.Equal(t, urlEndpoint, *updated.Webhooks[0].ClientConfig.URL) +} + +func TestWebhookUpdateFailsWhenNoWebhookConfig(t *testing.T) { + ctx := context.TODO() + c := fake.NewClientBuilder().Build() + caCertBytes := []byte("new-ca-bundle") + + err := createOrUpdateWebhook(ctx, c, "test-webhook", "", caCertBytes) + + assert.Error(t, err) +} + +func TestWebhookUpdateFailsWhenNoWebhooks(t *testing.T) { + ctx := context.TODO() + c := fake.NewClientBuilder().WithObjects(&admissionregistrationv1.MutatingWebhookConfiguration{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-webhook", + }, + }, &admissionregistrationv1.ValidatingWebhookConfiguration{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-webhook", + }, + }).Build() + caCertBytes := []byte("new-ca-bundle") + + err := createOrUpdateWebhook(ctx, c, "test-webhook", "", caCertBytes) + + assert.Error(t, err) +} + +func TestCreateOrUpdateCertSuccessfully(t *testing.T) { + ctx := context.TODO() + c := fake.NewClientBuilder().Build() + serviceNamespace := "default" + serviceName := "test-service" + clusterDomain := "cluster.local" + clusterID := "test-cluster" + urlMode := false + + secret, err := createOrUpdateCert(ctx, c, serviceNamespace, serviceName, clusterDomain, clusterID, urlMode) + + assert.NoError(t, err) + assert.NotNil(t, secret) + assert.Equal(t, fmt.Sprintf("%s-webhook-cert", serviceName), secret.Name) + assert.Equal(t, serviceNamespace, secret.Namespace) + assert.Contains(t, secret.Data, serverCertKey) + assert.Contains(t, secret.Data, serverKeyKey) + assert.Contains(t, secret.Data, caCertKey) +} + +func TestCreateOrUpdateCertWithURLMode(t *testing.T) { + ctx := context.TODO() + c := fake.NewClientBuilder().Build() + serviceNamespace := "default" + serviceName := "test-service" + clusterDomain := "cluster.local" + clusterID := "test-cluster" + urlMode := true + + secret, err := createOrUpdateCert(ctx, c, serviceNamespace, serviceName, clusterDomain, clusterID, urlMode) + + assert.NoError(t, err) + assert.NotNil(t, secret) + assert.Equal(t, fmt.Sprintf("%s-webhook-cert", serviceName), secret.Name) + assert.Equal(t, serviceNamespace, secret.Namespace) + assert.Contains(t, secret.Data, serverCertKey) + assert.Contains(t, secret.Data, serverKeyKey) + assert.Contains(t, secret.Data, caCertKey) +} + +func TestCreateOrUpdateCertFailsWhenClusterIDMissing(t *testing.T) { + ctx := context.TODO() + c := fake.NewClientBuilder().Build() + serviceNamespace := "default" + serviceName := "test-service" + clusterDomain := "cluster.local" + clusterID := "" + urlMode := true + + secret, err := createOrUpdateCert(ctx, c, serviceNamespace, serviceName, clusterDomain, clusterID, urlMode) + + assert.Error(t, err) + assert.Nil(t, secret) +} diff --git a/types/controlplane/config_default.go b/types/controlplane/config_default.go index c10c43de..4533e4b1 100644 --- a/types/controlplane/config_default.go +++ b/types/controlplane/config_default.go @@ -41,6 +41,7 @@ type Config struct { DisableWebhook bool `json:"disableWebhook"` WebhookPort int `json:"webhookPort" validate:"gt=0,lte=65535" mod:"default=4443"` CertDir string `json:"certDir" validate:"required" mod:"default=/var/run/webhook-cert"` + WebhookURLMode bool `json:"webhookURLMode"` LeaderElection bool `json:"leaderElection"` RegisterEndpoint bool `json:"registerEndpoint"` // deprecated EnableTrace bool `json:"enableTrace"`