Skip to content

Commit

Permalink
feat: Support key versionless
Browse files Browse the repository at this point in the history
Retrieve latest key version from akv and
put key version into annotation for decryption.

Signed-off-by: Zhecheng Li <[email protected]>
  • Loading branch information
lzhecheng committed Dec 17, 2024
1 parent 2b68d2f commit e6d4654
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 35 deletions.
7 changes: 4 additions & 3 deletions cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ import (
)

var (
listenAddr = flag.String("listen-addr", "unix:///opt/azurekms.socket", "gRPC listen address")
keyvaultName = flag.String("keyvault-name", "", "Azure Key Vault name")
keyName = flag.String("key-name", "", "Azure Key Vault KMS key name")
listenAddr = flag.String("listen-addr", "unix:///opt/azurekms.socket", "gRPC listen address")
keyvaultName = flag.String("keyvault-name", "", "Azure Key Vault name")
keyName = flag.String("key-name", "", "Azure Key Vault KMS key name")
// If --key-version not set or is empty, the plugin will use the latest version of the key.
keyVersion = flag.String("key-version", "", "Azure Key Vault KMS key version")
managedHSM = flag.Bool("managed-hsm", false, "Azure Key Vault Managed HSM. Refer to https://docs.microsoft.com/en-us/azure/key-vault/managed-hsm/overview for more details.")
logFormatJSON = flag.Bool("log-format-json", false, "set log formatter to json")
Expand Down
93 changes: 71 additions & 22 deletions pkg/plugin/keyvault.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ const (
keyvaultRegionAnnotationKey = "x-ms-keyvault-region.azure.akv.io"
versionAnnotationKey = "version.azure.akv.io"
algorithmAnnotationKey = "algorithm.azure.akv.io"
keyVersionAnnotationKey = "keyversion.azure.akv.io"
dateAnnotationValue = "Date"
requestIDAnnotationValue = "X-Ms-Request-Id"
keyvaultRegionAnnotationValue = "X-Ms-Keyvault-Region"
Expand Down Expand Up @@ -70,7 +71,7 @@ type KeyVaultClient struct {
keyName string
keyVersion string
vaultURL string
keyIDHash string
keyIDHash string // keyIDHash is used when key version-less is disabled
azureEnvironment *azure.Environment
}

Expand All @@ -90,9 +91,10 @@ func NewKeyVaultClient(

// this should be the case for bring your own key, clusters bootstrapped with
// aks-engine or aks and standalone kms plugin deployments
if len(vaultName) == 0 || len(keyName) == 0 || len(keyVersion) == 0 {
return nil, fmt.Errorf("key vault name, key name and key version are required")
if len(vaultName) == 0 || len(keyName) == 0 {
return nil, fmt.Errorf("key vault name and key name are required")
}

kvClient := kv.New()
err := kvClient.AddToUserAgent(version.GetUserAgent())
if err != nil {
Expand Down Expand Up @@ -121,9 +123,12 @@ func NewKeyVaultClient(
return nil, fmt.Errorf("failed to get vault url, error: %+v", err)
}

keyIDHash, err := getKeyIDHash(*vaultURL, keyName, keyVersion)
if err != nil {
return nil, fmt.Errorf("failed to get key id hash, error: %w", err)
keyIDHash := ""
if len(keyVersion) != 0 {
keyIDHash, err = getKeyIDHash(*vaultURL, keyName, keyVersion)
if err != nil {
return nil, fmt.Errorf("failed to get key id hash, error: %w", err)
}
}

if proxyMode {
Expand Down Expand Up @@ -158,17 +163,39 @@ func (kvc *KeyVaultClient) Encrypt(
Algorithm: encryptionAlgorithm,
Value: &value,
}
result, err := kvc.baseClient.Encrypt(ctx, kvc.vaultURL, kvc.keyName, kvc.keyVersion, params)

keyVersion := kvc.keyVersion
result, err := kvc.baseClient.Encrypt(ctx, kvc.vaultURL, kvc.keyName, keyVersion, params)
if err != nil {
return nil, fmt.Errorf("failed to encrypt, error: %+v", err)
}

if kvc.keyIDHash != fmt.Sprintf("%x", sha256.Sum256([]byte(*result.Kid))) {
return nil, fmt.Errorf(
"key id initialized does not match with the key id from encryption result, expected: %s, got: %s",
kvc.keyIDHash,
*result.Kid,
)
keyIDHash := ""
if result.Kid == nil {
return nil, fmt.Errorf("key id is nil in encryption result")
}
if len(keyVersion) == 0 {
keyVersion = path.Base(strings.TrimSuffix(*result.Kid, "/"))
keyIDHash, err = getKeyIDHash(kvc.vaultURL, kvc.keyName, keyVersion)
if err != nil {
return nil, fmt.Errorf("failed to get key id hash, error: %w", err)
}
if keyIDHash != fmt.Sprintf("%x", sha256.Sum256([]byte(*result.Kid))) {
return nil, fmt.Errorf(
"key id initialized does not match with the key id from encryption result, expected: %s, got: %s",
keyIDHash,
*result.Kid,
)
}
} else {
if kvc.keyIDHash != fmt.Sprintf("%x", sha256.Sum256([]byte(*result.Kid))) {
return nil, fmt.Errorf(
"key id initialized does not match with the key id from encryption result, expected: %s, got: %s",
kvc.keyIDHash,
*result.Kid,
)
}
keyIDHash = kvc.keyIDHash
}

annotations := map[string][]byte{
Expand All @@ -177,11 +204,13 @@ func (kvc *KeyVaultClient) Encrypt(
keyvaultRegionAnnotationKey: []byte(result.Header.Get(keyvaultRegionAnnotationValue)),
versionAnnotationKey: []byte(encryptionResponseVersion),
algorithmAnnotationKey: []byte(encryptionAlgorithm),
keyVersionAnnotationKey: []byte(keyVersion),
}

mlog.Info("Encryption succeeded", "vaultName", kvc.vaultName, "keyName", kvc.keyName, "keyVersion", keyVersion)
return &service.EncryptResponse{
Ciphertext: []byte(*result.Result),
KeyID: kvc.keyIDHash,
KeyID: keyIDHash,
Annotations: annotations,
}, nil
}
Expand All @@ -208,7 +237,12 @@ func (kvc *KeyVaultClient) Decrypt(
Value: &value,
}

result, err := kvc.baseClient.Decrypt(ctx, kvc.vaultURL, kvc.keyName, kvc.keyVersion, params)
keyVersion := kvc.keyVersion
if len(annotations[keyVersionAnnotationKey]) != 0 {
keyVersion = string(annotations[keyVersionAnnotationKey])
}

result, err := kvc.baseClient.Decrypt(ctx, kvc.vaultURL, kvc.keyName, keyVersion, params)
if err != nil {
return nil, fmt.Errorf("failed to decrypt, error: %+v", err)
}
Expand All @@ -217,6 +251,7 @@ func (kvc *KeyVaultClient) Decrypt(
return nil, fmt.Errorf("failed to base64 decode result, error: %+v", err)
}

mlog.Info("Decryption succeeded", "vaultName", kvc.vaultName, "keyName", kvc.keyName, "keyVersion", keyVersion)
return bytes, nil
}

Expand All @@ -234,19 +269,33 @@ func (kvc *KeyVaultClient) GetVaultURL() string {
// It also validates keyID that the API server checks.
func (kvc *KeyVaultClient) validateAnnotations(
annotations map[string][]byte,
keyID string,
keyIDHash string,
encryptionAlgorithm kv.JSONWebKeyEncryptionAlgorithm,
) error {
if len(annotations) == 0 {
return fmt.Errorf("invalid annotations, annotations cannot be empty")
}

if keyID != kvc.keyIDHash {
return fmt.Errorf(
"key id %s does not match expected key id %s used for encryption",
keyID,
kvc.keyIDHash,
)
if len(annotations[keyVersionAnnotationKey]) == 0 {
if keyIDHash != kvc.keyIDHash {
return fmt.Errorf(
"key id %s does not match expected key id %s used for encryption",
keyIDHash,
kvc.keyIDHash,
)
}
} else {
keyIDHashLocal, err := getKeyIDHash(kvc.vaultURL, kvc.keyName, string(annotations[keyVersionAnnotationKey]))
if err != nil {
return fmt.Errorf("failed to get key id hash, error: %w", err)
}
if keyIDHashLocal != keyIDHash {
return fmt.Errorf(
"key id %s does not match expected key id %s used for encryption",
keyIDHash,
keyIDHashLocal,
)
}
}

algorithm := string(annotations[algorithmAnnotationKey])
Expand Down
30 changes: 20 additions & 10 deletions pkg/plugin/keyvault_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,16 @@ var (

func TestNewKeyVaultClientError(t *testing.T) {
tests := []struct {
desc string
config *config.AzureConfig
vaultName string
keyName string
keyVersion string
proxyMode bool
proxyAddress string
proxyPort int
managedHSM bool
desc string
config *config.AzureConfig
vaultName string
keyName string
keyVersion string
keyVersionlessEnabled bool
proxyMode bool
proxyAddress string
proxyPort int
managedHSM bool
}{
{
desc: "vault name not provided",
Expand All @@ -43,7 +44,7 @@ func TestNewKeyVaultClientError(t *testing.T) {
proxyMode: false,
},
{
desc: "key version not provided",
desc: "key version not provided when not keyVersionlessEnabled",
config: &config.AzureConfig{},
vaultName: "testkv",
keyName: "k8s",
Expand Down Expand Up @@ -127,6 +128,15 @@ func TestNewKeyVaultClient(t *testing.T) {
proxyMode: false,
expectedVaultURL: "https://testkv.managedhsm.azure.net/",
},
{
desc: "no error when no key version (version-less)",
config: &config.AzureConfig{ClientID: "clientid", ClientSecret: "clientsecret"},
vaultName: "testkv",
keyName: "key1",
keyVersion: "",
proxyMode: false,
expectedVaultURL: "https://testkv.vault.azure.net/",
},
}

for _, test := range tests {
Expand Down

0 comments on commit e6d4654

Please sign in to comment.