Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support sm2 double certs #28

Merged
merged 1 commit into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 81 additions & 8 deletions crypto/cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@ import "C"

import (
"errors"
"fmt"
"io"
"io/ioutil"
"math/big"
"os"
"runtime"
"time"
"unsafe"
Expand All @@ -42,6 +45,7 @@ const (
EVP_SHA256 EVP_MD = iota
EVP_SHA384 EVP_MD = iota
EVP_SHA512 EVP_MD = iota
EVP_SM3 EVP_MD = iota
)

// X509_Version represents a version on an x509 certificate.
Expand Down Expand Up @@ -82,7 +86,7 @@ func NewCertWrapper(x unsafe.Pointer, ref ...interface{}) *Certificate {
}
}

// Allocate and return a new Name object.
// NewName allocate and return a new Name object.
func NewName() (*Name, error) {
n := C.X509_NAME_new()
if n == nil {
Expand Down Expand Up @@ -139,7 +143,9 @@ func NewCertificate(info *CertificateInfo, key PublicKey) (*Certificate, error)
runtime.SetFinalizer(c, func(c *Certificate) {
C.X509_free(c.x)
})

if err := c.SetVersion(X509_V3); err != nil {
return nil, err
}
name, err := c.GetSubjectName()
if err != nil {
return nil, err
Expand Down Expand Up @@ -272,9 +278,10 @@ func (c *Certificate) SetPubKey(pubKey PublicKey) error {
}

// Sign a certificate using a private key and a digest name.
// Accepted digest names are 'sha256', 'sha384', and 'sha512'.
// Accepted digest names are 'sm3', 'sha256', 'sha384', and 'sha512'.
func (c *Certificate) Sign(privKey PrivateKey, digest EVP_MD) error {
switch digest {
case EVP_SM3:
case EVP_SHA256:
case EVP_SHA384:
case EVP_SHA512:
Expand All @@ -293,27 +300,59 @@ func (c *Certificate) insecureSign(privKey PrivateKey, digest EVP_MD) error {
return nil
}

// Add an extension to a certificate.
// AddExtension Add an extension to a certificate.
// Extension constants are NID_* as found in openssl.
func (c *Certificate) AddExtension(nid NID, value string) error {
if c.x == nil {
return errors.New("certificate is nil")
}

issuer := c
if c.Issuer != nil {
if c.Issuer.x == nil {
return errors.New("issuer certificate is nil")
}
issuer = c.Issuer
}

cValue := C.CString(value)
defer C.free(unsafe.Pointer(cValue))

var ctx C.X509V3_CTX
C.X509V3_set_ctx(&ctx, c.x, issuer.x, nil, nil, 0)
ex := C.X509V3_EXT_conf_nid(nil, &ctx, C.int(nid), C.CString(value))

ex := C.X509V3_EXT_conf_nid(nil, &ctx, C.int(nid), cValue)
if ex == nil {
return errors.New("failed to create x509v3 extension")
return fmt.Errorf("failed to create x509v3 extension: %s", getOpenSSLError())
}
defer C.X509_EXTENSION_free(ex)

if C.X509_add_ext(c.x, ex, -1) <= 0 {
return errors.New("failed to add x509v3 extension")
return fmt.Errorf("failed to add x509v3 extension: %s", getOpenSSLError())
}

return nil
}

// getOpenSSLError Get the last error from the OpenSSL error queue.
func getOpenSSLError() string {
var errStrBuf [120]byte
C.ERR_error_string_n(C.ERR_get_error(), (*C.char)(unsafe.Pointer(&errStrBuf[0])), 120)
return string(errStrBuf[:])
}

// helper function to validate extension input
func validateExtensionInput(nid NID, value string) error {
if nid <= 0 {
return errors.New("invalid NID")
}
if value == "" {
return errors.New("empty extension value")
}
return nil
}

// Wraps AddExtension using a map of NID to text extension.
// AddExtensions Wraps AddExtension using a map of NID to text extension.
// Will return without finishing if it encounters an error.
func (c *Certificate) AddExtensions(extensions map[NID]string) error {
for nid, value := range extensions {
Expand Down Expand Up @@ -420,6 +459,40 @@ func getDigestFunction(digest EVP_MD) (md *C.EVP_MD) {
md = C.X_EVP_sha384()
case EVP_SHA512:
md = C.X_EVP_sha512()
case EVP_SM3:
md = C.X_EVP_sm3()
}
return md
}

// LoadPEMFromFile loads a PEM file and returns the []byte format.
func LoadPEMFromFile(filename string) ([]byte, error) {
file, err := os.Open(filename)
if err != nil {
return nil, err
}
defer file.Close()

pemBlock, err := io.ReadAll(file)
if err != nil {
return nil, err
}

return pemBlock, nil
}

// SavePEMToFile saves a PEM block to a file.
func SavePEMToFile(pemBlock []byte, filename string) error {
file, err := os.Create(filename)
if err != nil {
return err
}
defer file.Close()

_, err = file.Write(pemBlock)
if err != nil {
return err
}

return nil
}
145 changes: 145 additions & 0 deletions crypto/cert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package crypto

import (
"math/big"
"os"
"path/filepath"
"testing"
"time"
)
Expand All @@ -42,6 +44,28 @@ func TestCertGenerate(t *testing.T) {
}
}

func TestCertGenerateSM2(t *testing.T) {
key, err := GenerateECKey(Sm2Curve)
if err != nil {
t.Fatal(err)
}
info := &CertificateInfo{
Serial: big.NewInt(int64(1)),
Issued: 0,
Expires: 24 * time.Hour,
Country: "US",
Organization: "Test",
CommonName: "localhost",
}
cert, err := NewCertificate(info, key)
if err != nil {
t.Fatal(err)
}
if err := cert.Sign(key, EVP_SM3); err != nil {
t.Fatal(err)
}
}

func TestCAGenerate(t *testing.T) {
cakey, err := GenerateRSAKey(768)
if err != nil {
Expand Down Expand Up @@ -101,6 +125,127 @@ func TestCAGenerate(t *testing.T) {
}
}

func TestCAGenerateSM2(t *testing.T) {
dirName := filepath.Join("test-runs", "TestCAGenerateSM2")
_, err := os.Stat(dirName)
if os.IsNotExist(err) {
// The directory does not exist, creating it now.
err := os.MkdirAll(dirName, 0755)
if err != nil {
t.Logf("Failed to create the directory: %v\n", err)
}
} else if err != nil {
// other error
t.Logf("Failed to check the directory: %v\n", err)
}

// Helper function: generate and save key
generateAndSaveKey := func(filename string) PrivateKey {
key, err := GenerateECKey(Sm2Curve)
if err != nil {
t.Fatal(err)
}
pem, err := key.MarshalPKCS8PrivateKeyPEM()
if err != nil {
t.Fatal(err)
}
err = SavePEMToFile(pem, filename)
if err != nil {
t.Fatal(err)
}
return key
}

// Helper function: sign and save certificate
signAndSaveCert := func(cert *Certificate, caKey PrivateKey, filename string) {
err := cert.Sign(caKey, EVP_SM3)
if err != nil {
t.Fatal(err)
}
certPem, err := cert.MarshalPEM()
if err != nil {
t.Fatal(err)
}
err = SavePEMToFile(certPem, filename)
if err != nil {
t.Fatal(err)
}
}

// Create CA certificate
caKey, err := GenerateECKey(Sm2Curve)
if err != nil {
t.Fatal(err)
}
caInfo := CertificateInfo{
big.NewInt(1),
0,
87600 * time.Hour, // 10 years
"US",
"Test CA",
"CA",
}
caExtensions := map[NID]string{
NID_basic_constraints: "critical,CA:TRUE",
NID_key_usage: "critical,digitalSignature,keyCertSign,cRLSign",
NID_subject_key_identifier: "hash",
NID_authority_key_identifier: "keyid:always,issuer",
}
ca, err := NewCertificate(&caInfo, caKey)
if err != nil {
t.Fatal(err)
}
err = ca.AddExtensions(caExtensions)
if err != nil {
t.Fatal(err)
}
caFile := filepath.Join(dirName, "chain-ca.crt")
signAndSaveCert(ca, caKey, caFile)

// Define additional certificate information
certInfos := []struct {
name string
keyUsage string
}{
{"server_enc", "keyAgreement, keyEncipherment, dataEncipherment"},
{"server_sign", "nonRepudiation, digitalSignature"},
{"client_sign", "nonRepudiation, digitalSignature"},
{"client_enc", "keyAgreement, keyEncipherment, dataEncipherment"},
}

// Create additional certificates
for _, info := range certInfos {
keyFile := filepath.Join(dirName, info.name+".key")
key := generateAndSaveKey(keyFile)
certInfo := CertificateInfo{
Serial: big.NewInt(1),
Issued: 0,
Expires: 87600 * time.Hour, // 10 years
Country: "US",
Organization: "Test",
CommonName: "localhost",
}
extensions := map[NID]string{
NID_basic_constraints: "critical,CA:FALSE",
NID_key_usage: info.keyUsage,
}
cert, err := NewCertificate(&certInfo, key)
if err != nil {
t.Fatal(err)
}
err = cert.AddExtensions(extensions)
if err != nil {
t.Fatal(err)
}
err = cert.SetIssuer(ca)
if err != nil {
t.Fatal(err)
}
certFile := filepath.Join(dirName, info.name+".crt")
signAndSaveCert(cert, caKey, certFile)
}
}

func TestCertGetNameEntry(t *testing.T) {
key, err := GenerateRSAKey(768)
if err != nil {
Expand Down
29 changes: 29 additions & 0 deletions crypto/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ type PrivateKey interface {
// MarshalPKCS1PrivateKeyDER converts the private key to DER-encoded PKCS1
// format
MarshalPKCS1PrivateKeyDER() (der_block []byte, err error)

// MarshalPKCS8PrivateKeyPEM converts the private key to PEM-encoded PKCS8
// format
MarshalPKCS8PrivateKeyPEM() (pem_block []byte, err error)
}

type pKey struct {
Expand Down Expand Up @@ -239,6 +243,31 @@ func (key *pKey) VerifyPKCS1v15(method Method, data, sig []byte) error {
}
}

func (key *pKey) MarshalPKCS8PrivateKeyPEM() ([]byte, error) {
if key.key == nil {
return nil, errors.New("empty key")
}

bio := C.BIO_new(C.BIO_s_mem())
if bio == nil {
return nil, errors.New("failed to allocate memory")
}
defer C.BIO_free(bio)

if C.PEM_write_bio_PKCS8PrivateKey(bio, key.key, nil, nil, 0, nil, nil) != 1 {
return nil, errors.New("failed to write private key")
}

var ptr *C.char
length := C.X_BIO_get_mem_data(bio, &ptr)
if length <= 0 {
return nil, errors.New("failed to read bio data")
}

result := C.GoBytes(unsafe.Pointer(ptr), C.int(length))
return result, nil
}

func (key *pKey) Encrypt(data []byte) ([]byte, error) {
ctx := C.EVP_PKEY_CTX_new(key.key, nil)
defer C.EVP_PKEY_CTX_free(ctx)
Expand Down
4 changes: 4 additions & 0 deletions crypto/shim.c
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ void* X_BIO_get_data(BIO* bio) {
return BIO_get_data(bio);
}

long X_BIO_get_mem_data(BIO *b, char **pp) {
return BIO_get_mem_data(b, pp);
}

EVP_MD_CTX* X_EVP_MD_CTX_new() {
return EVP_MD_CTX_new();
}
Expand Down
1 change: 1 addition & 0 deletions crypto/shim.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ extern int X_BIO_read(BIO *b, void *buf, int len);
extern int X_BIO_write(BIO *b, const void *buf, int len);
extern BIO *X_BIO_new_write_bio();
extern BIO *X_BIO_new_read_bio();
extern long X_BIO_get_mem_data(BIO *b, char **pp);

extern int X_BN_num_bytes(const BIGNUM *a);

Expand Down
Loading
Loading