diff --git a/internal/cryptor/aes256gcm.go b/internal/cryptor/aes256gcm/aes256gcm.go similarity index 75% rename from internal/cryptor/aes256gcm.go rename to internal/cryptor/aes256gcm/aes256gcm.go index 7560d5e..d720f69 100644 --- a/internal/cryptor/aes256gcm.go +++ b/internal/cryptor/aes256gcm/aes256gcm.go @@ -1,4 +1,6 @@ -package cryptor +// Package aes256gcm provides an AES-256-GCM [cryptor.Cryptor] implementation. +// All sensitive data is handled in mlock'd memory via securemem. +package aes256gcm import ( "context" @@ -8,6 +10,7 @@ import ( "errors" "fmt" + "github.com/openkcm/krypton/internal/cryptor" "github.com/openkcm/krypton/internal/securemem" ) @@ -22,19 +25,22 @@ const ( // All key material and intermediate plaintext/ciphertext are handled in mlock'd // memory via securemem to prevent leakage into swap or core dumps. type AES256GCM struct { - info Info + info cryptor.Info } -var _ Cryptor = &AES256GCM{} +var _ cryptor.Cryptor = &AES256GCM{} + +// InfoNameAES256GCM indicates that the Cryptor supports AES-256 in Galois/Counter Mode (GCM). +const InfoNameAES256GCM cryptor.InfoName = "AES256-GCM" // ErrAllocatedDataNotFound indicates that data reserved in the secure memory vault // could not be retrieved after the cryptographic operation completed. var ErrAllocatedDataNotFound = errors.New("allocated data not found in vault") -// NewAES256GCM returns a ready-to-use AES-256-GCM cryptor. -func NewAES256GCM() *AES256GCM { +// New returns a ready-to-use AES-256-GCM cryptor. +func New() *AES256GCM { return &AES256GCM{ - info: Info{ + info: cryptor.Info{ Name: InfoNameAES256GCM, DecryptionSecretRequired: true, }, @@ -42,28 +48,32 @@ func NewAES256GCM() *AES256GCM { } // Info returns metadata about the AES256GCM cryptor. -func (a *AES256GCM) Info() Info { +func (a *AES256GCM) Info() cryptor.Info { return a.info } // Encrypt encrypts the plaintext using AES-256 in GCM mode with the provided key and AAD. -func (a *AES256GCM) Encrypt(ctx context.Context, req EncryptRequest) (*EncryptResponse, error) { +func (a *AES256GCM) Encrypt(ctx context.Context, req cryptor.EncryptRequest) (*cryptor.EncryptResponse, error) { if err := req.Validate(); err != nil { return nil, err } if req.Secret == nil { - return nil, fmt.Errorf("missing encryption secret: %w", ErrRequest) + return nil, fmt.Errorf("missing encryption secret: %w", cryptor.ErrRequest) + } + + if req.Secret.Algorithm != cryptor.KeyAlgorithmAES256 { + return nil, fmt.Errorf("invalid key algorithm: expected %s, got %s: %w", cryptor.KeyAlgorithmAES256, req.Secret.Algorithm, cryptor.ErrRequest) } - secretSize := len(req.Secret.SecureBytes()) + secretSize := len(req.Secret.Data.SecureBytes()) if secretSize != 32 { - return nil, fmt.Errorf("invalid key size: expected 32 bytes, got %d: %w", secretSize, ErrRequest) + return nil, fmt.Errorf("invalid key size: expected 32 bytes, got %d: %w", secretSize, cryptor.ErrRequest) } resp, err := securemem.Run(ctx, func(ctx context.Context, hr *securemem.HandlerRequest) error { // 1. Initialize AES-256 block cipher from the 32-byte key. - block, err := aes.NewCipher(req.Secret.SecureBytes()) + block, err := aes.NewCipher(req.Secret.Data.SecureBytes()) if err != nil { return fmt.Errorf("failed to create AES cipher: %w", err) } @@ -111,29 +121,33 @@ func (a *AES256GCM) Encrypt(ctx context.Context, req EncryptRequest) (*EncryptRe return nil, fmt.Errorf("allocated ciphertext not found in vault after encryption: %w", ErrAllocatedDataNotFound) } - return &EncryptResponse{ + return &cryptor.EncryptResponse{ Ciphertext: cipherText, }, nil } // Decrypt decrypts the ciphertext using AES-256 in GCM mode with the provided key and AAD. -func (a *AES256GCM) Decrypt(ctx context.Context, req DecryptRequest) (*DecryptResponse, error) { +func (a *AES256GCM) Decrypt(ctx context.Context, req cryptor.DecryptRequest) (*cryptor.DecryptResponse, error) { if err := req.Validate(); err != nil { return nil, err } if req.Secret == nil { - return nil, fmt.Errorf("missing decryption secret: %w", ErrRequest) + return nil, fmt.Errorf("missing decryption secret: %w", cryptor.ErrRequest) + } + + if req.Secret.Algorithm != cryptor.KeyAlgorithmAES256 { + return nil, fmt.Errorf("invalid key algorithm: expected %s, got %s: %w", cryptor.KeyAlgorithmAES256, req.Secret.Algorithm, cryptor.ErrRequest) } - secretSize := len(req.Secret.SecureBytes()) + secretSize := len(req.Secret.Data.SecureBytes()) if secretSize != 32 { - return nil, fmt.Errorf("invalid key size: expected 32 bytes, got %d: %w", secretSize, ErrRequest) + return nil, fmt.Errorf("invalid key size: expected 32 bytes, got %d: %w", secretSize, cryptor.ErrRequest) } resp, err := securemem.Run(ctx, func(ctx context.Context, hr *securemem.HandlerRequest) error { // 1. Initialize AES-256 block cipher from the 32-byte key. - block, err := aes.NewCipher(req.Secret.SecureBytes()) + block, err := aes.NewCipher(req.Secret.Data.SecureBytes()) if err != nil { return fmt.Errorf("failed to create AES cipher: %w", err) } @@ -149,7 +163,7 @@ func (a *AES256GCM) Decrypt(ctx context.Context, req DecryptRequest) (*DecryptRe // 3. Verify the ciphertext is at least nonce + tag bytes long. if cipherTextSize < nonceSize+gcm.Overhead() { - return fmt.Errorf("ciphertext too short: %w", ErrRequest) + return fmt.Errorf("ciphertext too short: %w", cryptor.ErrRequest) } // 4. Compute plaintext size: total - nonce - tag. @@ -194,7 +208,7 @@ func (a *AES256GCM) Decrypt(ctx context.Context, req DecryptRequest) (*DecryptRe return nil, fmt.Errorf("allocated plaintext not found in vault after decryption: %w", ErrAllocatedDataNotFound) } - return &DecryptResponse{ + return &cryptor.DecryptResponse{ Plaintext: plainText, }, nil } diff --git a/internal/cryptor/aes256gcm/aes256gcm_test.go b/internal/cryptor/aes256gcm/aes256gcm_test.go new file mode 100644 index 0000000..ee12219 --- /dev/null +++ b/internal/cryptor/aes256gcm/aes256gcm_test.go @@ -0,0 +1,727 @@ +package aes256gcm_test + +import ( + "crypto/rand" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/openkcm/krypton/internal/cryptor" + "github.com/openkcm/krypton/internal/cryptor/aes256gcm" + "github.com/openkcm/krypton/internal/securemem" +) + +func TestAES256GCM_Encrypt(t *testing.T) { + // given + ctx := t.Context() + subj := aes256gcm.New() + + t.Run("should fail if encrypt request validation fails", func(t *testing.T) { + // given + req := cryptor.EncryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Secret: &cryptor.Secret{ + Algorithm: cryptor.KeyAlgorithmAES256, + }, + Plaintext: nil, // missing plaintext + } + + // when + resp, err := subj.Encrypt(ctx, req) + + // then + assert.Error(t, err) + assert.Nil(t, resp) + }) + + t.Run("should fail if encrypt request secret is nil", func(t *testing.T) { + // given + plainText := newSecureMemData(t, []byte("plaintext")) + req := cryptor.EncryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Secret: nil, // missing secret + Plaintext: plainText, + } + + // when + resp, err := subj.Encrypt(ctx, req) + + // then + assert.Error(t, err) + assert.Nil(t, resp) + }) + + t.Run("should fail if encrypt request algorithm is unknown", func(t *testing.T) { + plainText := newSecureMemData(t, []byte("plaintext")) + + req := cryptor.EncryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Secret: &cryptor.Secret{ + Algorithm: "unknown-algorithm", // unsupported algorithm + Data: newSecretKey(t), + }, + Plaintext: plainText, + } + + // when + resp, err := subj.Encrypt(ctx, req) + + // then + assert.Error(t, err) + assert.Nil(t, resp) + }) + + t.Run("should fail if encrypt request secret data is missing", func(t *testing.T) { + // given + plainText := newSecureMemData(t, []byte("plaintext")) + + req := cryptor.EncryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Secret: &cryptor.Secret{ + Algorithm: cryptor.KeyAlgorithmAES256, + Data: nil, // missing secret + }, + Plaintext: plainText, + } + + // when + resp, err := subj.Encrypt(ctx, req) + + // then + assert.Error(t, err) + assert.Nil(t, resp) + }) + + t.Run("should fail to encrypt if plaintext data is destroyed", func(t *testing.T) { + // given + plainText := newSecureMemData(t, []byte("plaintext")) + secret := newSecretKey(t) + + req := cryptor.EncryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Secret: &cryptor.Secret{ + Algorithm: cryptor.KeyAlgorithmAES256, + Data: secret, + }, + Plaintext: plainText, + } + + // destroy plaintext before encryption + require.NoError(t, plainText.Destroy()) + + // when + resp, err := subj.Encrypt(ctx, req) + + // then + assert.Error(t, err) + assert.Nil(t, resp) + }) + + t.Run("should fail if secret data size is invalid", func(t *testing.T) { + // given + plainText := newSecureMemData(t, []byte("plaintext")) + + // invalid secret size (should be 32 bytes for AES-256) + secret, err := securemem.NewData("key", 16) + require.NoError(t, err) + t.Cleanup(func() { _ = secret.Destroy() }) + + req := cryptor.EncryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Secret: &cryptor.Secret{ + Algorithm: cryptor.KeyAlgorithmAES256, + Data: secret, // invalid key size + }, + Plaintext: plainText, + } + + // when + resp, err := subj.Encrypt(ctx, req) + + // then + assert.Error(t, err) + assert.Nil(t, resp) + }) + + t.Run("should generate different ciphertext for same plaintext and key", func(t *testing.T) { + // given + plainText := newSecureMemData(t, []byte("plaintext")) + secret := newSecretKey(t) + + req := cryptor.EncryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Secret: &cryptor.Secret{ + Algorithm: cryptor.KeyAlgorithmAES256, + Data: secret, + }, + Plaintext: plainText, + } + + // when + resp1, err := subj.Encrypt(ctx, req) + require.NoError(t, err) + t.Cleanup(func() { _ = resp1.Ciphertext.Destroy() }) + + resp2, err := subj.Encrypt(ctx, req) + require.NoError(t, err) + t.Cleanup(func() { _ = resp2.Ciphertext.Destroy() }) + + // then + assert.NotEqual(t, resp1.Ciphertext.SecureBytes(), resp2.Ciphertext.SecureBytes()) + }) + + t.Run("should not destroy secret and plaintext after encryption", func(t *testing.T) { + // given + plainText := newSecureMemData(t, []byte("plaintext")) + secret := newSecretKey(t) + + req := cryptor.EncryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Secret: &cryptor.Secret{ + Algorithm: cryptor.KeyAlgorithmAES256, + Data: secret, + }, + Plaintext: plainText, + } + + // when + resp, err := subj.Encrypt(ctx, req) + require.NoError(t, err) + t.Cleanup(func() { _ = resp.Ciphertext.Destroy() }) + + // then + assert.NotNil(t, plainText.SecureBytes()) + assert.NotNil(t, secret.SecureBytes()) + }) +} + +func TestAES256GCM_Decrypt(t *testing.T) { + // given + ctx := t.Context() + subj := aes256gcm.New() + + t.Run("should fail if decrypt request validation fails", func(t *testing.T) { + // given + req := cryptor.DecryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Secret: &cryptor.Secret{ + Algorithm: cryptor.KeyAlgorithmAES256, + }, + Ciphertext: nil, // missing ciphertext + } + + // when + resp, err := subj.Decrypt(ctx, req) + + // then + assert.Error(t, err) + assert.Nil(t, resp) + }) + + t.Run("should fail if decrypt request secret is nil", func(t *testing.T) { + // given + plainText := newSecureMemData(t, []byte("ciphertext")) + + req := cryptor.DecryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Secret: nil, + Ciphertext: plainText, + } + + // when + resp, err := subj.Decrypt(ctx, req) + + // then + assert.Error(t, err) + assert.Nil(t, resp) + }) + + t.Run("should fail if decrypt request algorithm is unknown", func(t *testing.T) { + // given + plainText := newSecureMemData(t, []byte("ciphertext")) + secret := newSecretKey(t) + + decResp, err := subj.Encrypt(ctx, cryptor.EncryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Secret: &cryptor.Secret{ + Algorithm: cryptor.KeyAlgorithmAES256, + Data: secret, + }, + Plaintext: plainText, + }) + require.NoError(t, err) + t.Cleanup(func() { _ = decResp.Ciphertext.Destroy() }) + + req := cryptor.DecryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Secret: &cryptor.Secret{ + Algorithm: "unknown-algorithm", // unsupported algorithm + Data: secret, + }, + Ciphertext: decResp.Ciphertext, + } + + // when + resp, err := subj.Decrypt(ctx, req) + + // then + assert.Error(t, err) + assert.Nil(t, resp) + }) + + t.Run("should fail if decrypt request secret data is missing", func(t *testing.T) { + // given + cipherText := newSecureMemData(t, []byte("ciphertext")) + + req := cryptor.DecryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Secret: &cryptor.Secret{ + Algorithm: cryptor.KeyAlgorithmAES256, + Data: nil, // missing secret + }, + Ciphertext: cipherText, + } + + // when + resp, err := subj.Decrypt(ctx, req) + + // then + assert.Error(t, err) + assert.Nil(t, resp) + }) + + t.Run("should fail if secret data size is invalid", func(t *testing.T) { + // given + cipherText := newSecureMemData(t, []byte("ciphertext")) + + // invalid secret size (should be 32 bytes for AES-256) + secret, err := securemem.NewData("key", 16) + require.NoError(t, err) + t.Cleanup(func() { _ = secret.Destroy() }) + + req := cryptor.DecryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Secret: &cryptor.Secret{ + Algorithm: cryptor.KeyAlgorithmAES256, + Data: secret, // invalid key size + }, + Ciphertext: cipherText, + } + + // when + resp, err := subj.Decrypt(ctx, req) + + // then + assert.Error(t, err) + assert.Nil(t, resp) + }) + + t.Run("should fail to decrypt if ciphertext data is destroyed", func(t *testing.T) { + // given + cipherText := newSecureMemData(t, []byte("ciphertext")) + secret := newSecretKey(t) + + req := cryptor.DecryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Secret: &cryptor.Secret{ + Algorithm: cryptor.KeyAlgorithmAES256, + Data: secret, + }, + Ciphertext: cipherText, + } + + // destroy ciphertext before decryption + require.NoError(t, cipherText.Destroy()) + + // when + resp, err := subj.Decrypt(ctx, req) + + // then + assert.Error(t, err) + assert.Nil(t, resp) + }) +} + +func TestAES256GCM_EncryptDecrypt(t *testing.T) { + // given + ctx := t.Context() + subj := aes256gcm.New() + + t.Run("should encrypt and decrypt plaintext successfully", func(t *testing.T) { + // given + secret := newSecretKey(t) + text := []byte("hello, secure world!") + plainText := newSecureMemData(t, text) + + // when + // encrypt + encResp, err := subj.Encrypt(ctx, cryptor.EncryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Secret: &cryptor.Secret{ + Algorithm: cryptor.KeyAlgorithmAES256, + Data: secret, + }, + Plaintext: plainText, + }) + + // then + require.NoError(t, err) + t.Cleanup(func() { _ = encResp.Ciphertext.Destroy() }) + + // ciphertext must differ from plaintext + assert.NotEqual(t, text, []byte(encResp.Ciphertext.SecureBytes())) + + // when + // decrypt + decResp, err := subj.Decrypt(ctx, cryptor.DecryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Secret: &cryptor.Secret{ + Algorithm: cryptor.KeyAlgorithmAES256, + Data: secret, + }, + Ciphertext: encResp.Ciphertext, + }) + + // then + require.NoError(t, err) + t.Cleanup(func() { _ = decResp.Plaintext.Destroy() }) + + // recovered plaintext must match original + assert.Equal(t, text, []byte(decResp.Plaintext.SecureBytes())) + }) + + t.Run("should encrypt and decrypt with AAD successfully", func(t *testing.T) { + // given + secret := newSecretKey(t) + text := []byte("authenticated payload") + plainText := newSecureMemData(t, text) + aad := []byte("context-binding-data") + + // when + // encrypt with AAD + encResp, err := subj.Encrypt(ctx, cryptor.EncryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Secret: &cryptor.Secret{ + Algorithm: cryptor.KeyAlgorithmAES256, + Data: secret, + }, + Plaintext: plainText, + AAD: aad, + }) + + // then + require.NoError(t, err) + t.Cleanup(func() { _ = encResp.Ciphertext.Destroy() }) + + // when + // decrypt with same AAD succeeds + decResp, err := subj.Decrypt(ctx, cryptor.DecryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Secret: &cryptor.Secret{ + Algorithm: cryptor.KeyAlgorithmAES256, + Data: secret, + }, + Ciphertext: encResp.Ciphertext, + AAD: aad, + }) + + // then + require.NoError(t, err) + t.Cleanup(func() { _ = decResp.Plaintext.Destroy() }) + + assert.Equal(t, text, []byte(decResp.Plaintext.SecureBytes())) + }) + + t.Run("should fail to decrypt with wrong AAD", func(t *testing.T) { + // given + secretKey := newSecretKey(t) + plainText := newSecureMemData(t, []byte("secret")) + + // when + encResp, err := subj.Encrypt(ctx, cryptor.EncryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Secret: &cryptor.Secret{ + Algorithm: cryptor.KeyAlgorithmAES256, + Data: secretKey, + }, + Plaintext: plainText, + AAD: []byte("correct-aad"), + }) + + // then + require.NoError(t, err) + t.Cleanup(func() { _ = encResp.Ciphertext.Destroy() }) + + // when + // decrypt with wrong AAD must fail + decResp, err := subj.Decrypt(ctx, cryptor.DecryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Secret: &cryptor.Secret{ + Algorithm: cryptor.KeyAlgorithmAES256, + Data: secretKey, + }, + Ciphertext: encResp.Ciphertext, + AAD: []byte("wrong-aad"), + }) + + // then + assert.Error(t, err) + assert.Nil(t, decResp) + }) + + t.Run("should fail to decrypt with wrong key", func(t *testing.T) { + // given + secret1 := newSecretKey(t) + secret2 := newSecretKey(t) + plainText := newSecureMemData(t, []byte("secret")) + + // when + encResp, err := subj.Encrypt(ctx, cryptor.EncryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Secret: &cryptor.Secret{ + Algorithm: cryptor.KeyAlgorithmAES256, + Data: secret1, + }, + Plaintext: plainText, + }) + + // then + require.NoError(t, err) + t.Cleanup(func() { _ = encResp.Ciphertext.Destroy() }) + + // when + // decrypt with different key must fail + decResp, err := subj.Decrypt(ctx, cryptor.DecryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Secret: &cryptor.Secret{ + Algorithm: cryptor.KeyAlgorithmAES256, + Data: secret2, + }, + Ciphertext: encResp.Ciphertext, + }) + + // then + assert.Error(t, err) + assert.Nil(t, decResp) + }) + + t.Run("should fail to decrypt tampered ciphertext", func(t *testing.T) { + // given + secret := newSecretKey(t) + plainText := newSecureMemData(t, []byte("do not tamper")) + + // when + encResp, err := subj.Encrypt(ctx, cryptor.EncryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Secret: &cryptor.Secret{ + Algorithm: cryptor.KeyAlgorithmAES256, + Data: secret, + }, + Plaintext: plainText, + }) + + // then + require.NoError(t, err) + t.Cleanup(func() { _ = encResp.Ciphertext.Destroy() }) + + // copy ciphertext into writable secure memory and flip a byte + ct := encResp.Ciphertext.SecureBytes() + tampered, err := securemem.NewData("tampered", len(ct)) + require.NoError(t, err) + + t.Cleanup(func() { _ = tampered.Destroy() }) + + copy(tampered.SecureBytes(), ct) + tampered.SecureBytes()[len(ct)-1] ^= 0xFF + + // when + decResp, err := subj.Decrypt(ctx, cryptor.DecryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Secret: &cryptor.Secret{ + Algorithm: cryptor.KeyAlgorithmAES256, + Data: secret, + }, + Ciphertext: tampered, + }) + + // then + assert.Error(t, err) + assert.Nil(t, decResp) + }) + + t.Run("should fail to decrypt if ciphertext is too short to contain nonce and tag", func(t *testing.T) { + // given + secret := newSecretKey(t) + plainText := newSecureMemData(t, []byte("short cipher")) + + // when + encResp, err := subj.Encrypt(ctx, cryptor.EncryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Secret: &cryptor.Secret{ + Algorithm: cryptor.KeyAlgorithmAES256, + Data: secret, + }, + Plaintext: plainText, + }) + + // then + require.NoError(t, err) + t.Cleanup(func() { _ = encResp.Ciphertext.Destroy() }) + + // create a truncated ciphertext that is too short to contain nonce and tag + ct := encResp.Ciphertext.SecureBytes() + truncated, err := securemem.NewData("truncated", 8) // too short for nonce+tag + require.NoError(t, err) + + t.Cleanup(func() { _ = truncated.Destroy() }) + copy(truncated.SecureBytes(), ct[:8]) + + // when + decResp, err := subj.Decrypt(ctx, cryptor.DecryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Secret: &cryptor.Secret{ + Algorithm: cryptor.KeyAlgorithmAES256, + Data: secret, + }, + Ciphertext: truncated, + }) + + // then + assert.Error(t, err) + assert.Nil(t, decResp) + }) + + t.Run("should not destroy secret and ciphertext after decryption", func(t *testing.T) { + // given + secret := newSecretKey(t) + plainText := newSecureMemData(t, []byte("plaintext")) + + // when + encResp, err := subj.Encrypt(ctx, cryptor.EncryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Secret: &cryptor.Secret{ + Algorithm: cryptor.KeyAlgorithmAES256, + Data: secret, + }, + Plaintext: plainText, + }) + require.NoError(t, err) + t.Cleanup(func() { _ = encResp.Ciphertext.Destroy() }) + + decResp, err := subj.Decrypt(ctx, cryptor.DecryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Secret: &cryptor.Secret{ + Algorithm: cryptor.KeyAlgorithmAES256, + Data: secret, + }, + Ciphertext: encResp.Ciphertext, + }) + require.NoError(t, err) + + t.Cleanup(func() { _ = decResp.Plaintext.Destroy() }) + + // then + assert.NotNil(t, secret.SecureBytes()) + assert.NotNil(t, encResp.Ciphertext.SecureBytes()) + }) +} + +func TestAES256GCM_Info(t *testing.T) { + // given + subj := aes256gcm.New() + + // when + info := subj.Info() + + // then + assert.Equal(t, aes256gcm.InfoNameAES256GCM, info.Name) + assert.True(t, info.DecryptionSecretRequired) +} + +// newSecretKey allocate a 32-byte AES key in secure memory +func newSecretKey(t *testing.T) *securemem.Data { + t.Helper() + + key, err := securemem.NewData("test-key", 32) + require.NoError(t, err) + + _, err = rand.Read(key.SecureBytes()) + require.NoError(t, err) + + t.Cleanup(func() { _ = key.Destroy() }) + + return key +} + +// newSecureMemData allocate in secure memory +func newSecureMemData(t *testing.T, content []byte) *securemem.Data { + t.Helper() + + pt, err := securemem.NewData("test-plaintext", len(content)) + require.NoError(t, err) + + copy(pt.SecureBytes(), content) + + t.Cleanup(func() { _ = pt.Destroy() }) + + return pt +} diff --git a/internal/cryptor/cryptor.go b/internal/cryptor/cryptor.go index eab8464..2d98a17 100644 --- a/internal/cryptor/cryptor.go +++ b/internal/cryptor/cryptor.go @@ -19,8 +19,13 @@ const KeyAlgorithmAES256 KeyAlgorithm = "AES256" // InfoName identifies the type of information returned by the Info() method of a Cryptor. type InfoName string -// InfoNameAES256GCM indicates that the Cryptor supports AES-256 in Galois/Counter Mode (GCM). -const InfoNameAES256GCM InfoName = "AES256-GCM" +// Secret pairs a key algorithm with its key material stored in secure memory. +type Secret struct { + // Algorithm is the encryption algorithm to apply. + Algorithm KeyAlgorithm + // Data is the raw key bytes in mlock'd memory. Nil if the Cryptor manages its own secrets (e.g., HSM). + Data *securemem.Data +} // EncryptRequest contains parameters for an encryption operation. type EncryptRequest struct { @@ -30,11 +35,8 @@ type EncryptRequest struct { KeyID string // KeyVersion specifies which version of the key to use. KeyVersion int - // Algorithm is the encryption algorithm to apply. - Algorithm KeyAlgorithm - // Secret is the key material used for encryption. - // The Secret is nil if Cryptor manages its own secrets (e.g., HSM). - Secret *securemem.Data + // Secret holds the key material for encryption. Nil when the Cryptor manages its own secrets. + Secret *Secret // Plaintext is the data to encrypt. // The Plaintext field should not be nil. Plaintext *securemem.Data @@ -56,11 +58,8 @@ type DecryptRequest struct { KeyID string // KeyVersion specifies which version of the key to use. KeyVersion int - // Algorithm is the encryption algorithm that was used. - Algorithm KeyAlgorithm - // Secret is the key material used for decryption. - // The Secret is nil if Cryptor manages its own secrets (e.g., HSM). - Secret *securemem.Data + // Secret holds the key material for decryption. Nil when the Cryptor manages its own secrets. + Secret *Secret // Ciphertext is the data to decrypt. // The Ciphertext field should not be nil. Ciphertext *securemem.Data @@ -114,7 +113,7 @@ func (req EncryptRequest) Validate() error { if req.Plaintext == nil || len(req.Plaintext.SecureBytes()) == 0 { return fmt.Errorf("invalid plaintext: %w", ErrRequest) } - return nil + return req.Secret.Validate() } func (req DecryptRequest) Validate() error { @@ -130,5 +129,17 @@ func (req DecryptRequest) Validate() error { if req.Ciphertext == nil || len(req.Ciphertext.SecureBytes()) == 0 { return fmt.Errorf("invalid ciphertext: %w", ErrRequest) } + return req.Secret.Validate() +} + +func (cs *Secret) Validate() error { + if cs != nil { + if cs.Algorithm == "" { + return fmt.Errorf("invalid secret algorithm: %w", ErrRequest) + } + if cs.Data == nil || len(cs.Data.SecureBytes()) == 0 { + return fmt.Errorf("invalid secret data: %w", ErrRequest) + } + } return nil } diff --git a/internal/cryptor/cryptor_test.go b/internal/cryptor/cryptor_test.go index fd2498c..dc604c5 100644 --- a/internal/cryptor/cryptor_test.go +++ b/internal/cryptor/cryptor_test.go @@ -31,7 +31,7 @@ func TestDecryptRequestValidate(t *testing.T) { wantErr error }{ { - name: "valid request", + name: "valid request without secret", req: cryptor.DecryptRequest{ TenantID: "tenant1", KeyID: "key1", @@ -40,6 +40,20 @@ func TestDecryptRequestValidate(t *testing.T) { }, wantErr: nil, }, + { + name: "valid request with secret", + req: cryptor.DecryptRequest{ + TenantID: "tenant1", + KeyID: "key1", + KeyVersion: 1, + Ciphertext: validData, + Secret: &cryptor.Secret{ + Algorithm: cryptor.KeyAlgorithmAES256, + Data: validData, + }, + }, + wantErr: nil, + }, { name: "valid request with negative key version", req: cryptor.DecryptRequest{ @@ -98,7 +112,7 @@ func TestDecryptRequestValidate(t *testing.T) { wantErr: cryptor.ErrRequest, }, { - name: "empty securememe data", + name: "empty ciphertext data", req: cryptor.DecryptRequest{ TenantID: "tenant1", KeyID: "key1", @@ -107,6 +121,61 @@ func TestDecryptRequestValidate(t *testing.T) { }, wantErr: cryptor.ErrRequest, }, + { + name: "empty algorithm in secret", + req: cryptor.DecryptRequest{ + TenantID: "tenant1", + KeyID: "key1", + KeyVersion: 1, + Ciphertext: validData, + Secret: &cryptor.Secret{ + Data: validData, + }, + }, + wantErr: cryptor.ErrRequest, + }, + { + name: "empty data in secret", + req: cryptor.DecryptRequest{ + TenantID: "tenant1", + KeyID: "key1", + KeyVersion: 1, + Ciphertext: validData, + Secret: &cryptor.Secret{ + Algorithm: cryptor.KeyAlgorithmAES256, + Data: &securemem.Data{}, + }, + }, + wantErr: cryptor.ErrRequest, + }, + { + name: "destroyed data in secret", + req: cryptor.DecryptRequest{ + TenantID: "tenant1", + KeyID: "key1", + KeyVersion: 1, + Ciphertext: validData, + Secret: &cryptor.Secret{ + Algorithm: cryptor.KeyAlgorithmAES256, + Data: destroyedData, + }, + }, + wantErr: cryptor.ErrRequest, + }, + { + name: "nil data in secret", + req: cryptor.DecryptRequest{ + TenantID: "tenant1", + KeyID: "key1", + KeyVersion: 1, + Ciphertext: validData, + Secret: &cryptor.Secret{ + Algorithm: cryptor.KeyAlgorithmAES256, + Data: nil, + }, + }, + wantErr: cryptor.ErrRequest, + }, } for _, tt := range tts { @@ -141,7 +210,7 @@ func TestEncryptRequestValidate(t *testing.T) { wantErr error }{ { - name: "valid request", + name: "valid request without secret", req: cryptor.EncryptRequest{ TenantID: "tenant1", KeyID: "key1", @@ -150,6 +219,20 @@ func TestEncryptRequestValidate(t *testing.T) { }, wantErr: nil, }, + { + name: "valid request with secret", + req: cryptor.EncryptRequest{ + TenantID: "tenant1", + KeyID: "key1", + KeyVersion: 1, + Plaintext: validData, + Secret: &cryptor.Secret{ + Algorithm: cryptor.KeyAlgorithmAES256, + Data: validData, + }, + }, + wantErr: nil, + }, { name: "valid request with negative key version", req: cryptor.EncryptRequest{ @@ -208,7 +291,7 @@ func TestEncryptRequestValidate(t *testing.T) { wantErr: cryptor.ErrRequest, }, { - name: "empty securemem data", + name: "empty ciphertext data", req: cryptor.EncryptRequest{ TenantID: "tenant1", KeyID: "key1", @@ -217,6 +300,61 @@ func TestEncryptRequestValidate(t *testing.T) { }, wantErr: cryptor.ErrRequest, }, + { + name: "empty algorithm in secret", + req: cryptor.EncryptRequest{ + TenantID: "tenant1", + KeyID: "key1", + KeyVersion: 1, + Plaintext: validData, + Secret: &cryptor.Secret{ + Data: validData, + }, + }, + wantErr: cryptor.ErrRequest, + }, + { + name: "empty data in secret", + req: cryptor.EncryptRequest{ + TenantID: "tenant1", + KeyID: "key1", + KeyVersion: 1, + Plaintext: validData, + Secret: &cryptor.Secret{ + Algorithm: cryptor.KeyAlgorithmAES256, + Data: &securemem.Data{}, + }, + }, + wantErr: cryptor.ErrRequest, + }, + { + name: "destroyed data in secret", + req: cryptor.EncryptRequest{ + TenantID: "tenant1", + KeyID: "key1", + KeyVersion: 1, + Plaintext: validData, + Secret: &cryptor.Secret{ + Algorithm: cryptor.KeyAlgorithmAES256, + Data: destroyedData, + }, + }, + wantErr: cryptor.ErrRequest, + }, + { + name: "nil data in secret", + req: cryptor.EncryptRequest{ + TenantID: "tenant1", + KeyID: "key1", + KeyVersion: 1, + Plaintext: validData, + Secret: &cryptor.Secret{ + Algorithm: cryptor.KeyAlgorithmAES256, + Data: nil, + }, + }, + wantErr: cryptor.ErrRequest, + }, } for _, tt := range tts { diff --git a/internal/cryptor/staticsecret/staticsecret.go b/internal/cryptor/staticsecret/staticsecret.go new file mode 100644 index 0000000..cb57789 --- /dev/null +++ b/internal/cryptor/staticsecret/staticsecret.go @@ -0,0 +1,98 @@ +// Package staticsecret provides a [cryptor.Cryptor] that embeds a fixed AES-256-GCM key, +// removing the need for callers to supply secrets per-request. +package staticsecret + +import ( + "context" + "errors" + "fmt" + + "github.com/openkcm/krypton/internal/cryptor" + "github.com/openkcm/krypton/internal/cryptor/aes256gcm" + "github.com/openkcm/krypton/internal/securemem" +) + +// StaticSecret is a [Cryptor] that wraps [AES256GCM] with a pre-configured key, +// so callers don't need to supply secrets per-request. It rejects requests that +// include a secret, and injects its own before delegating to the underlying cipher. +// +// The caller retains ownership of the secret and must not destroy it while +// StaticSecret is in use. +type StaticSecret struct { + secret *securemem.Data + aes256gcm *aes256gcm.AES256GCM + info cryptor.Info +} + +// ErrInitializationFailed indicates that New could not be constructed +// due to an unsupported algorithm or invalid key material. +var ErrInitializationFailed = errors.New("static secret initialization failed") + +var _ cryptor.Cryptor = &StaticSecret{} + +// InfoNameStaticSecret indicates a Cryptor that manages its own static key material. +const InfoNameStaticSecret cryptor.InfoName = "AES256-GCM-STATIC-SECRET" + +// New returns a StaticSecret for the given algorithm name and key material. +// Currently only [InfoNameStaticSecret] is supported. The secret must be non-nil and non-empty. +func New(name cryptor.InfoName, secret *securemem.Data) (*StaticSecret, error) { + if name != InfoNameStaticSecret { + return nil, fmt.Errorf("unsupported algorithm name: %s: %w", name, ErrInitializationFailed) + } + + if secret == nil || len(secret.SecureBytes()) != 32 { + return nil, fmt.Errorf("invalid secret: %w", ErrInitializationFailed) + } + + return &StaticSecret{ + secret: secret, + aes256gcm: aes256gcm.New(), + info: cryptor.Info{ + Name: name, + DecryptionSecretRequired: false, + }, + }, nil +} + +// Decrypt implements [Cryptor]. It returns an error if the request contains a secret. +func (s *StaticSecret) Decrypt(ctx context.Context, req cryptor.DecryptRequest) (*cryptor.DecryptResponse, error) { + err := req.Validate() + if err != nil { + return nil, err + } + + if req.Secret != nil { + return nil, fmt.Errorf("decryption secret should not be provided in the request: %w", cryptor.ErrRequest) + } + // replacing the secret in the request with the static secret + req.Secret = &cryptor.Secret{ + Data: s.secret, + Algorithm: cryptor.KeyAlgorithmAES256, + } + + return s.aes256gcm.Decrypt(ctx, req) +} + +// Encrypt implements [Cryptor]. It returns an error if the request contains a secret. +func (s *StaticSecret) Encrypt(ctx context.Context, req cryptor.EncryptRequest) (*cryptor.EncryptResponse, error) { + err := req.Validate() + if err != nil { + return nil, err + } + + if req.Secret != nil { + return nil, fmt.Errorf("encryption secret should not be provided in the request: %w", cryptor.ErrRequest) + } + // replacing the secret in the request with the static secret + req.Secret = &cryptor.Secret{ + Data: s.secret, + Algorithm: cryptor.KeyAlgorithmAES256, + } + + return s.aes256gcm.Encrypt(ctx, req) +} + +// Info implements [Cryptor]. +func (s *StaticSecret) Info() cryptor.Info { + return s.info +} diff --git a/internal/cryptor/aes256gcm_test.go b/internal/cryptor/staticsecret/staticsecret_test.go similarity index 59% rename from internal/cryptor/aes256gcm_test.go rename to internal/cryptor/staticsecret/staticsecret_test.go index 96ed4ba..6ea80d5 100644 --- a/internal/cryptor/aes256gcm_test.go +++ b/internal/cryptor/staticsecret/staticsecret_test.go @@ -1,4 +1,4 @@ -package cryptor_test +package staticsecret_test import ( "crypto/rand" @@ -8,13 +8,83 @@ import ( "github.com/stretchr/testify/require" "github.com/openkcm/krypton/internal/cryptor" + "github.com/openkcm/krypton/internal/cryptor/staticsecret" "github.com/openkcm/krypton/internal/securemem" ) -func TestAES256GCM_Encrypt(t *testing.T) { +func TestStaticSecret_New(t *testing.T) { + // given + secretWithInvalidLen, err := securemem.NewData("key", 16) + require.NoError(t, err) + t.Cleanup(func() { _ = secretWithInvalidLen.Destroy() }) + + tts := []struct { + name string + nameArg cryptor.InfoName + keyArg *securemem.Data + wantErr bool + }{ + { + name: "should not return error for valid name and key", + nameArg: staticsecret.InfoNameStaticSecret, + keyArg: newSecretKey(t), + wantErr: false, + }, + { + name: "should return error for empty name", + nameArg: "", + keyArg: newSecretKey(t), + wantErr: true, + }, + { + name: "should return error for nil secret", + nameArg: staticsecret.InfoNameStaticSecret, + keyArg: nil, + wantErr: true, + }, + { + name: "should return error for empty secret", + nameArg: staticsecret.InfoNameStaticSecret, + keyArg: &securemem.Data{}, + wantErr: true, + }, + { + name: "should return error for invalid secret size", + nameArg: staticsecret.InfoNameStaticSecret, + keyArg: secretWithInvalidLen, + wantErr: true, + }, + { + name: "should return error if name is unknown", + nameArg: "unknown-name", + keyArg: newSecretKey(t), + wantErr: true, + }, + } + + for _, tt := range tts { + t.Run(tt.name, func(t *testing.T) { + // when + subj, err := staticsecret.New(tt.nameArg, tt.keyArg) + + // then + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, subj) + } else { + assert.NoError(t, err) + assert.NotNil(t, subj) + } + }) + } +} + +func TestStaticSecret_Encrypt(t *testing.T) { // given ctx := t.Context() - subj := cryptor.NewAES256GCM() + + subj, err := staticsecret.New(staticsecret.InfoNameStaticSecret, newSecretKey(t)) + require.NoError(t, err) t.Run("should fail if encrypt request validation fails", func(t *testing.T) { // given @@ -22,7 +92,6 @@ func TestAES256GCM_Encrypt(t *testing.T) { TenantID: "tenant-1", KeyID: "key-1", KeyVersion: 1, - Algorithm: cryptor.KeyAlgorithmAES256, Plaintext: nil, // missing plaintext } @@ -34,19 +103,19 @@ func TestAES256GCM_Encrypt(t *testing.T) { assert.Nil(t, resp) }) - t.Run("should fail if encryption secret is missing", func(t *testing.T) { + t.Run("should fail if encrypt request contains a secret", func(t *testing.T) { // given - plainText, err := securemem.NewData("plaintext", 1) - require.NoError(t, err) - t.Cleanup(func() { _ = plainText.Destroy() }) + plainText := newSecureMemData(t, []byte("plaintext")) req := cryptor.EncryptRequest{ TenantID: "tenant-1", KeyID: "key-1", KeyVersion: 1, - Algorithm: cryptor.KeyAlgorithmAES256, Plaintext: plainText, - Secret: nil, // missing secret + Secret: &cryptor.Secret{ + Data: newSecretKey(t), + Algorithm: cryptor.KeyAlgorithmAES256, + }, } // when @@ -57,31 +126,18 @@ func TestAES256GCM_Encrypt(t *testing.T) { assert.Nil(t, resp) }) - t.Run("should fail to encrypt if plaintext data is destroyed", func(t *testing.T) { + t.Run("should fail if encrypt request contains empty secret", func(t *testing.T) { // given - plainText, err := securemem.NewData("plaintext", 1) - require.NoError(t, err) - t.Cleanup(func() { _ = plainText.Destroy() }) - - secret, err := securemem.NewData("key", 32) - require.NoError(t, err) - t.Cleanup(func() { _ = secret.Destroy() }) - - _, err = rand.Read(secret.SecureBytes()) - require.NoError(t, err) + plainText := newSecureMemData(t, []byte("plaintext")) req := cryptor.EncryptRequest{ TenantID: "tenant-1", KeyID: "key-1", KeyVersion: 1, - Algorithm: cryptor.KeyAlgorithmAES256, Plaintext: plainText, - Secret: secret, + Secret: &cryptor.Secret{}, } - // destroy plaintext before encryption - require.NoError(t, plainText.Destroy()) - // when resp, err := subj.Encrypt(ctx, req) @@ -90,26 +146,41 @@ func TestAES256GCM_Encrypt(t *testing.T) { assert.Nil(t, resp) }) - t.Run("should fail if secret key size is invalid", func(t *testing.T) { + t.Run("should not fail if encrypt request is missing secret", func(t *testing.T) { // given - plainText, err := securemem.NewData("plaintext", 1) - require.NoError(t, err) - t.Cleanup(func() { _ = plainText.Destroy() }) + plainText := newSecureMemData(t, []byte("plaintext")) - // invalid secret size (should be 32 bytes for AES-256) - secret, err := securemem.NewData("key", 16) - require.NoError(t, err) - t.Cleanup(func() { _ = secret.Destroy() }) + req := cryptor.EncryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Plaintext: plainText, + Secret: nil, // missing secret + } + + // when + resp, err := subj.Encrypt(ctx, req) + + // then + assert.NoError(t, err) + assert.NotNil(t, resp) + t.Cleanup(func() { _ = resp.Ciphertext.Destroy() }) + }) + + t.Run("should fail to encrypt if plaintext data is destroyed", func(t *testing.T) { + // given + plainText := newSecureMemData(t, []byte("plaintext")) req := cryptor.EncryptRequest{ TenantID: "tenant-1", KeyID: "key-1", KeyVersion: 1, - Algorithm: cryptor.KeyAlgorithmAES256, Plaintext: plainText, - Secret: secret, // invalid key size } + // destroy plaintext before encryption + require.NoError(t, plainText.Destroy()) + // when resp, err := subj.Encrypt(ctx, req) @@ -120,24 +191,13 @@ func TestAES256GCM_Encrypt(t *testing.T) { t.Run("should generate different ciphertext for same plaintext and key", func(t *testing.T) { // given - plainText, err := securemem.NewData("same-plaintext", 1) - require.NoError(t, err) - t.Cleanup(func() { _ = plainText.Destroy() }) - - secret, err := securemem.NewData("same-key", 32) - require.NoError(t, err) - t.Cleanup(func() { _ = secret.Destroy() }) - - _, err = rand.Read(secret.SecureBytes()) - require.NoError(t, err) + plainText := newSecureMemData(t, []byte("plaintext")) req := cryptor.EncryptRequest{ TenantID: "tenant-1", KeyID: "key-1", KeyVersion: 1, - Algorithm: cryptor.KeyAlgorithmAES256, Plaintext: plainText, - Secret: secret, } // when @@ -153,26 +213,15 @@ func TestAES256GCM_Encrypt(t *testing.T) { assert.NotEqual(t, resp1.Ciphertext.SecureBytes(), resp2.Ciphertext.SecureBytes()) }) - t.Run("should not destroy secret and plaintext after encryption", func(t *testing.T) { + t.Run("should not destroy plaintext after encryption", func(t *testing.T) { // given - plainText, err := securemem.NewData("plaintext", 1) - require.NoError(t, err) - t.Cleanup(func() { _ = plainText.Destroy() }) - - secret, err := securemem.NewData("key", 32) - require.NoError(t, err) - t.Cleanup(func() { _ = secret.Destroy() }) - - _, err = rand.Read(secret.SecureBytes()) - require.NoError(t, err) + plainText := newSecureMemData(t, []byte("plaintext")) req := cryptor.EncryptRequest{ TenantID: "tenant-1", KeyID: "key-1", KeyVersion: 1, - Algorithm: cryptor.KeyAlgorithmAES256, Plaintext: plainText, - Secret: secret, } // when @@ -182,14 +231,14 @@ func TestAES256GCM_Encrypt(t *testing.T) { // then assert.NotNil(t, plainText.SecureBytes()) - assert.NotNil(t, secret.SecureBytes()) }) } -func TestAES256GCM_Decrypt(t *testing.T) { +func TestStaticSecret_Decrypt(t *testing.T) { // given ctx := t.Context() - subj := cryptor.NewAES256GCM() + subj, err := staticsecret.New(staticsecret.InfoNameStaticSecret, newSecretKey(t)) + require.NoError(t, err) t.Run("should fail if decrypt request validation fails", func(t *testing.T) { // given @@ -197,7 +246,6 @@ func TestAES256GCM_Decrypt(t *testing.T) { TenantID: "tenant-1", KeyID: "key-1", KeyVersion: 1, - Algorithm: cryptor.KeyAlgorithmAES256, Ciphertext: nil, // missing ciphertext } @@ -209,19 +257,26 @@ func TestAES256GCM_Decrypt(t *testing.T) { assert.Nil(t, resp) }) - t.Run("should fail if encryption secret is missing", func(t *testing.T) { + t.Run("should fail if decrypt request contains a secret", func(t *testing.T) { // given - cipherText, err := securemem.NewData("plaintext", 1) + encResp, err := subj.Encrypt(ctx, cryptor.EncryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Plaintext: newSecureMemData(t, []byte("ciphertext")), + }) require.NoError(t, err) - t.Cleanup(func() { _ = cipherText.Destroy() }) + t.Cleanup(func() { _ = encResp.Ciphertext.Destroy() }) req := cryptor.DecryptRequest{ TenantID: "tenant-1", KeyID: "key-1", KeyVersion: 1, - Algorithm: cryptor.KeyAlgorithmAES256, - Ciphertext: cipherText, - Secret: nil, // missing secret + Ciphertext: encResp.Ciphertext, + Secret: &cryptor.Secret{ + Data: newSecretKey(t), + Algorithm: cryptor.KeyAlgorithmAES256, + }, } // when @@ -232,24 +287,23 @@ func TestAES256GCM_Decrypt(t *testing.T) { assert.Nil(t, resp) }) - t.Run("should fail if secret key size is invalid", func(t *testing.T) { + t.Run("should fail if decrypt request contains empty secret", func(t *testing.T) { // given - cipherText, err := securemem.NewData("plaintext", 1) - require.NoError(t, err) - t.Cleanup(func() { _ = cipherText.Destroy() }) - - // invalid secret size (should be 32 bytes for AES-256) - secret, err := securemem.NewData("key", 16) + encResp, err := subj.Encrypt(ctx, cryptor.EncryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Plaintext: newSecureMemData(t, []byte("ciphertext")), + }) require.NoError(t, err) - t.Cleanup(func() { _ = secret.Destroy() }) + t.Cleanup(func() { _ = encResp.Ciphertext.Destroy() }) req := cryptor.DecryptRequest{ TenantID: "tenant-1", KeyID: "key-1", KeyVersion: 1, - Algorithm: cryptor.KeyAlgorithmAES256, - Ciphertext: cipherText, - Secret: secret, // invalid key size + Ciphertext: encResp.Ciphertext, + Secret: &cryptor.Secret{}, } // when @@ -260,26 +314,43 @@ func TestAES256GCM_Decrypt(t *testing.T) { assert.Nil(t, resp) }) - t.Run("should fail to decrypt if ciphertext data is destroyed", func(t *testing.T) { + t.Run("should not fail if decrypt request is missing secret", func(t *testing.T) { // given - cipherText, err := securemem.NewData("ciphertext", 1) + encResp, err := subj.Encrypt(ctx, cryptor.EncryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Plaintext: newSecureMemData(t, []byte("ciphertext")), + }) require.NoError(t, err) - t.Cleanup(func() { _ = cipherText.Destroy() }) + t.Cleanup(func() { _ = encResp.Ciphertext.Destroy() }) - secret, err := securemem.NewData("key", 32) - require.NoError(t, err) - t.Cleanup(func() { _ = secret.Destroy() }) + req := cryptor.DecryptRequest{ + TenantID: "tenant-1", + KeyID: "key-1", + KeyVersion: 1, + Ciphertext: encResp.Ciphertext, + Secret: nil, // missing secret + } - _, err = rand.Read(secret.SecureBytes()) - require.NoError(t, err) + // when + resp, err := subj.Decrypt(ctx, req) + + // then + assert.NoError(t, err) + assert.NotNil(t, resp) + t.Cleanup(func() { _ = resp.Plaintext.Destroy() }) + }) + + t.Run("should fail to decrypt if ciphertext data is destroyed", func(t *testing.T) { + // given + cipherText := newSecureMemData(t, []byte("ciphertext")) req := cryptor.DecryptRequest{ TenantID: "tenant-1", KeyID: "key-1", KeyVersion: 1, - Algorithm: cryptor.KeyAlgorithmAES256, Ciphertext: cipherText, - Secret: secret, } // destroy ciphertext before decryption @@ -294,46 +365,17 @@ func TestAES256GCM_Decrypt(t *testing.T) { }) } -func TestAES256GCM_EncryptDecrypt(t *testing.T) { +func TestStaticSecret_EncryptDecrypt(t *testing.T) { // given ctx := t.Context() - subj := cryptor.NewAES256GCM() - - // given - // allocate a 32-byte AES key in secure memory - newSecretKey := func(t *testing.T) *securemem.Data { - t.Helper() - key, err := securemem.NewData("test-key", 32) - require.NoError(t, err) - - _, err = rand.Read(key.SecureBytes()) - require.NoError(t, err) - - t.Cleanup(func() { _ = key.Destroy() }) - - return key - } - - // allocate plaintext in secure memory - newPlaintext := func(t *testing.T, content []byte) *securemem.Data { - t.Helper() - - pt, err := securemem.NewData("test-plaintext", len(content)) - require.NoError(t, err) - - copy(pt.SecureBytes(), content) - - t.Cleanup(func() { _ = pt.Destroy() }) - - return pt - } + subj, err := staticsecret.New(staticsecret.InfoNameStaticSecret, newSecretKey(t)) + require.NoError(t, err) t.Run("should encrypt and decrypt plaintext successfully", func(t *testing.T) { // given - secret := newSecretKey(t) text := []byte("hello, secure world!") - plainText := newPlaintext(t, text) + plainText := newSecureMemData(t, text) // when // encrypt @@ -341,8 +383,6 @@ func TestAES256GCM_EncryptDecrypt(t *testing.T) { TenantID: "tenant-1", KeyID: "key-1", KeyVersion: 1, - Algorithm: cryptor.KeyAlgorithmAES256, - Secret: secret, Plaintext: plainText, }) @@ -359,8 +399,6 @@ func TestAES256GCM_EncryptDecrypt(t *testing.T) { TenantID: "tenant-1", KeyID: "key-1", KeyVersion: 1, - Algorithm: cryptor.KeyAlgorithmAES256, - Secret: secret, Ciphertext: encResp.Ciphertext, }) @@ -374,9 +412,8 @@ func TestAES256GCM_EncryptDecrypt(t *testing.T) { t.Run("should encrypt and decrypt with AAD successfully", func(t *testing.T) { // given - secret := newSecretKey(t) text := []byte("authenticated payload") - plainText := newPlaintext(t, text) + plainText := newSecureMemData(t, text) aad := []byte("context-binding-data") // when @@ -385,8 +422,6 @@ func TestAES256GCM_EncryptDecrypt(t *testing.T) { TenantID: "tenant-1", KeyID: "key-1", KeyVersion: 1, - Algorithm: cryptor.KeyAlgorithmAES256, - Secret: secret, Plaintext: plainText, AAD: aad, }) @@ -401,8 +436,6 @@ func TestAES256GCM_EncryptDecrypt(t *testing.T) { TenantID: "tenant-1", KeyID: "key-1", KeyVersion: 1, - Algorithm: cryptor.KeyAlgorithmAES256, - Secret: secret, Ciphertext: encResp.Ciphertext, AAD: aad, }) @@ -416,16 +449,13 @@ func TestAES256GCM_EncryptDecrypt(t *testing.T) { t.Run("should fail to decrypt with wrong AAD", func(t *testing.T) { // given - secretKey := newSecretKey(t) - plainText := newPlaintext(t, []byte("secret")) + plainText := newSecureMemData(t, []byte("secret")) // when encResp, err := subj.Encrypt(ctx, cryptor.EncryptRequest{ TenantID: "tenant-1", KeyID: "key-1", KeyVersion: 1, - Algorithm: cryptor.KeyAlgorithmAES256, - Secret: secretKey, Plaintext: plainText, AAD: []byte("correct-aad"), }) @@ -440,8 +470,6 @@ func TestAES256GCM_EncryptDecrypt(t *testing.T) { TenantID: "tenant-1", KeyID: "key-1", KeyVersion: 1, - Algorithm: cryptor.KeyAlgorithmAES256, - Secret: secretKey, Ciphertext: encResp.Ciphertext, AAD: []byte("wrong-aad"), }) @@ -453,17 +481,13 @@ func TestAES256GCM_EncryptDecrypt(t *testing.T) { t.Run("should fail to decrypt with wrong key", func(t *testing.T) { // given - secret1 := newSecretKey(t) - secret2 := newSecretKey(t) - plainText := newPlaintext(t, []byte("secret")) + plainText := newSecureMemData(t, []byte("secret")) // when encResp, err := subj.Encrypt(ctx, cryptor.EncryptRequest{ TenantID: "tenant-1", KeyID: "key-1", KeyVersion: 1, - Algorithm: cryptor.KeyAlgorithmAES256, - Secret: secret1, Plaintext: plainText, }) @@ -472,13 +496,14 @@ func TestAES256GCM_EncryptDecrypt(t *testing.T) { t.Cleanup(func() { _ = encResp.Ciphertext.Destroy() }) // when - // decrypt with different key must fail - decResp, err := subj.Decrypt(ctx, cryptor.DecryptRequest{ + // decrypt with different staticsecret must fail + subj1, err := staticsecret.New(staticsecret.InfoNameStaticSecret, newSecretKey(t)) + require.NoError(t, err) + + decResp, err := subj1.Decrypt(ctx, cryptor.DecryptRequest{ TenantID: "tenant-1", KeyID: "key-1", KeyVersion: 1, - Algorithm: cryptor.KeyAlgorithmAES256, - Secret: secret2, Ciphertext: encResp.Ciphertext, }) @@ -489,17 +514,14 @@ func TestAES256GCM_EncryptDecrypt(t *testing.T) { t.Run("should fail to decrypt tampered ciphertext", func(t *testing.T) { // given - secret := newSecretKey(t) - plaintText := newPlaintext(t, []byte("do not tamper")) + plainText := newSecureMemData(t, []byte("do not tamper")) // when encResp, err := subj.Encrypt(ctx, cryptor.EncryptRequest{ TenantID: "tenant-1", KeyID: "key-1", KeyVersion: 1, - Algorithm: cryptor.KeyAlgorithmAES256, - Secret: secret, - Plaintext: plaintText, + Plaintext: plainText, }) // then @@ -521,8 +543,6 @@ func TestAES256GCM_EncryptDecrypt(t *testing.T) { TenantID: "tenant-1", KeyID: "key-1", KeyVersion: 1, - Algorithm: cryptor.KeyAlgorithmAES256, - Secret: secret, Ciphertext: tampered, }) @@ -531,18 +551,15 @@ func TestAES256GCM_EncryptDecrypt(t *testing.T) { assert.Nil(t, decResp) }) - t.Run("should fail to decrypt if cipher is too short to contain nonce and tag", func(t *testing.T) { + t.Run("should fail to decrypt if ciphertext is too short to contain nonce and tag", func(t *testing.T) { // given - secret := newSecretKey(t) - plainText := newPlaintext(t, []byte("short cipher")) + plainText := newSecureMemData(t, []byte("short cipher")) // when encResp, err := subj.Encrypt(ctx, cryptor.EncryptRequest{ TenantID: "tenant-1", KeyID: "key-1", KeyVersion: 1, - Algorithm: cryptor.KeyAlgorithmAES256, - Secret: secret, Plaintext: plainText, }) @@ -563,8 +580,6 @@ func TestAES256GCM_EncryptDecrypt(t *testing.T) { TenantID: "tenant-1", KeyID: "key-1", KeyVersion: 1, - Algorithm: cryptor.KeyAlgorithmAES256, - Secret: secret, Ciphertext: truncated, }) @@ -573,48 +588,72 @@ func TestAES256GCM_EncryptDecrypt(t *testing.T) { assert.Nil(t, decResp) }) - t.Run("should not destroy secret and ciphertext after decryption attempt", func(t *testing.T) { + t.Run("should not destroy ciphertext after decryption", func(t *testing.T) { // given - secret := newSecretKey(t) - plainText := newPlaintext(t, []byte("plaintext")) + plainText := newSecureMemData(t, []byte("plaintext")) // when encResp, err := subj.Encrypt(ctx, cryptor.EncryptRequest{ TenantID: "tenant-1", KeyID: "key-1", KeyVersion: 1, - Algorithm: cryptor.KeyAlgorithmAES256, - Secret: secret, Plaintext: plainText, }) require.NoError(t, err) t.Cleanup(func() { _ = encResp.Ciphertext.Destroy() }) - decResp, _ := subj.Decrypt(ctx, cryptor.DecryptRequest{ + decResp, err := subj.Decrypt(ctx, cryptor.DecryptRequest{ TenantID: "tenant-1", KeyID: "key-1", KeyVersion: 1, - Algorithm: cryptor.KeyAlgorithmAES256, - Secret: secret, Ciphertext: encResp.Ciphertext, }) - + require.NoError(t, err) t.Cleanup(func() { _ = decResp.Plaintext.Destroy() }) // then - assert.NotNil(t, secret.SecureBytes()) assert.NotNil(t, encResp.Ciphertext.SecureBytes()) }) } -func TestAES256GCM_Info(t *testing.T) { +func TestStaticSecret_Info(t *testing.T) { // given - subj := cryptor.NewAES256GCM() + subj, err := staticsecret.New(staticsecret.InfoNameStaticSecret, newSecretKey(t)) + require.NoError(t, err) // when info := subj.Info() // then - assert.Equal(t, cryptor.InfoNameAES256GCM, info.Name) - assert.True(t, info.DecryptionSecretRequired) + assert.Equal(t, staticsecret.InfoNameStaticSecret, info.Name) + assert.False(t, info.DecryptionSecretRequired) +} + +// newSecretKey allocate a 32-byte AES key in secure memory +func newSecretKey(t *testing.T) *securemem.Data { + t.Helper() + + key, err := securemem.NewData("test-key", 32) + require.NoError(t, err) + + _, err = rand.Read(key.SecureBytes()) + require.NoError(t, err) + + t.Cleanup(func() { _ = key.Destroy() }) + + return key +} + +// newSecureMemData allocate in secure memory +func newSecureMemData(t *testing.T, content []byte) *securemem.Data { + t.Helper() + + pt, err := securemem.NewData("test-plaintext", len(content)) + require.NoError(t, err) + + copy(pt.SecureBytes(), content) + + t.Cleanup(func() { _ = pt.Destroy() }) + + return pt }