Skip to content

Commit

Permalink
Support sm2 double certs
Browse files Browse the repository at this point in the history
  • Loading branch information
ZBCccc committed Aug 12, 2024
1 parent cea6181 commit e462021
Show file tree
Hide file tree
Showing 9 changed files with 684 additions and 25 deletions.
89 changes: 81 additions & 8 deletions crypto/cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@ import "C"

import (
"errors"
"fmt"
"io"
"io/ioutil"
"math/big"
"os"
"runtime"
"time"
"unsafe"
Expand All @@ -42,6 +45,7 @@ const (
EVP_SHA256 EVP_MD = iota
EVP_SHA384 EVP_MD = iota
EVP_SHA512 EVP_MD = iota
EVP_SM3 EVP_MD = iota
)

// X509_Version represents a version on an x509 certificate.
Expand Down Expand Up @@ -82,7 +86,7 @@ func NewCertWrapper(x unsafe.Pointer, ref ...interface{}) *Certificate {
}
}

// Allocate and return a new Name object.
// NewName allocate and return a new Name object.
func NewName() (*Name, error) {
n := C.X509_NAME_new()
if n == nil {
Expand Down Expand Up @@ -139,7 +143,9 @@ func NewCertificate(info *CertificateInfo, key PublicKey) (*Certificate, error)
runtime.SetFinalizer(c, func(c *Certificate) {
C.X509_free(c.x)
})

if err := c.SetVersion(X509_V3); err != nil {
return nil, err
}
name, err := c.GetSubjectName()
if err != nil {
return nil, err
Expand Down Expand Up @@ -272,9 +278,10 @@ func (c *Certificate) SetPubKey(pubKey PublicKey) error {
}

// Sign a certificate using a private key and a digest name.
// Accepted digest names are 'sha256', 'sha384', and 'sha512'.
// Accepted digest names are 'sm3', 'sha256', 'sha384', and 'sha512'.
func (c *Certificate) Sign(privKey PrivateKey, digest EVP_MD) error {
switch digest {
case EVP_SM3:
case EVP_SHA256:
case EVP_SHA384:
case EVP_SHA512:
Expand All @@ -293,27 +300,59 @@ func (c *Certificate) insecureSign(privKey PrivateKey, digest EVP_MD) error {
return nil
}

// Add an extension to a certificate.
// AddExtension Add an extension to a certificate.
// Extension constants are NID_* as found in openssl.
func (c *Certificate) AddExtension(nid NID, value string) error {
if c.x == nil {
return errors.New("certificate is nil")
}

issuer := c
if c.Issuer != nil {
if c.Issuer.x == nil {
return errors.New("issuer certificate is nil")
}
issuer = c.Issuer
}

cValue := C.CString(value)
defer C.free(unsafe.Pointer(cValue))

var ctx C.X509V3_CTX
C.X509V3_set_ctx(&ctx, c.x, issuer.x, nil, nil, 0)
ex := C.X509V3_EXT_conf_nid(nil, &ctx, C.int(nid), C.CString(value))

ex := C.X509V3_EXT_conf_nid(nil, &ctx, C.int(nid), cValue)
if ex == nil {
return errors.New("failed to create x509v3 extension")
return fmt.Errorf("failed to create x509v3 extension: %s", getOpenSSLError())
}
defer C.X509_EXTENSION_free(ex)

if C.X509_add_ext(c.x, ex, -1) <= 0 {
return errors.New("failed to add x509v3 extension")
return fmt.Errorf("failed to add x509v3 extension: %s", getOpenSSLError())
}

return nil
}

// getOpenSSLError Get the last error from the OpenSSL error queue.
func getOpenSSLError() string {
var errStrBuf [120]byte
C.ERR_error_string_n(C.ERR_get_error(), (*C.char)(unsafe.Pointer(&errStrBuf[0])), 120)
return string(errStrBuf[:])
}

// helper function to validate extension input
func validateExtensionInput(nid NID, value string) error {
if nid <= 0 {
return errors.New("invalid NID")
}
if value == "" {
return errors.New("empty extension value")
}
return nil
}

// Wraps AddExtension using a map of NID to text extension.
// AddExtensions Wraps AddExtension using a map of NID to text extension.
// Will return without finishing if it encounters an error.
func (c *Certificate) AddExtensions(extensions map[NID]string) error {
for nid, value := range extensions {
Expand Down Expand Up @@ -420,6 +459,40 @@ func getDigestFunction(digest EVP_MD) (md *C.EVP_MD) {
md = C.X_EVP_sha384()
case EVP_SHA512:
md = C.X_EVP_sha512()
case EVP_SM3:
md = C.X_EVP_sm3()
}
return md
}

// LoadPEMFromFile loads a PEM file and returns the []byte format.
func LoadPEMFromFile(filename string) ([]byte, error) {
file, err := os.Open(filename)
if err != nil {
return nil, err
}
defer file.Close()

pemBlock, err := io.ReadAll(file)
if err != nil {
return nil, err
}

return pemBlock, nil
}

// SavePEMToFile saves a PEM block to a file.
func SavePEMToFile(pemBlock []byte, filename string) error {
file, err := os.Create(filename)
if err != nil {
return err
}
defer file.Close()

_, err = file.Write(pemBlock)
if err != nil {
return err
}

return nil
}
145 changes: 145 additions & 0 deletions crypto/cert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package crypto

import (
"math/big"
"os"
"path/filepath"
"testing"
"time"
)
Expand All @@ -42,6 +44,28 @@ func TestCertGenerate(t *testing.T) {
}
}

func TestCertGenerateSM2(t *testing.T) {
key, err := GenerateECKey(Sm2Curve)
if err != nil {
t.Fatal(err)
}
info := &CertificateInfo{
Serial: big.NewInt(int64(1)),
Issued: 0,
Expires: 24 * time.Hour,
Country: "US",
Organization: "Test",
CommonName: "localhost",
}
cert, err := NewCertificate(info, key)
if err != nil {
t.Fatal(err)
}
if err := cert.Sign(key, EVP_SM3); err != nil {
t.Fatal(err)
}
}

func TestCAGenerate(t *testing.T) {
cakey, err := GenerateRSAKey(768)
if err != nil {
Expand Down Expand Up @@ -101,6 +125,127 @@ func TestCAGenerate(t *testing.T) {
}
}

func TestCAGenerateSM2(t *testing.T) {
dirName := filepath.Join("test-runs", "TestCAGenerateSM2")
_, err := os.Stat(dirName)
if os.IsNotExist(err) {
// The directory does not exist, creating it now.
err := os.MkdirAll(dirName, 0755)
if err != nil {
t.Logf("Failed to create the directory: %v\n", err)
}
} else if err != nil {
// other error
t.Logf("Failed to check the directory: %v\n", err)
}

// Helper function: generate and save key
generateAndSaveKey := func(filename string) PrivateKey {
key, err := GenerateECKey(Sm2Curve)
if err != nil {
t.Fatal(err)
}
pem, err := key.MarshalPKCS8PrivateKeyPEM()
if err != nil {
t.Fatal(err)
}
err = SavePEMToFile(pem, filename)
if err != nil {
t.Fatal(err)
}
return key
}

// Helper function: sign and save certificate
signAndSaveCert := func(cert *Certificate, caKey PrivateKey, filename string) {
err := cert.Sign(caKey, EVP_SM3)
if err != nil {
t.Fatal(err)
}
certPem, err := cert.MarshalPEM()
if err != nil {
t.Fatal(err)
}
err = SavePEMToFile(certPem, filename)
if err != nil {
t.Fatal(err)
}
}

// Create CA certificate
caKey, err := GenerateECKey(Sm2Curve)
if err != nil {
t.Fatal(err)
}
caInfo := CertificateInfo{
big.NewInt(1),
0,
87600 * time.Hour, // 10 years
"US",
"Test CA",
"CA",
}
caExtensions := map[NID]string{
NID_basic_constraints: "critical,CA:TRUE",
NID_key_usage: "critical,digitalSignature,keyCertSign,cRLSign",
NID_subject_key_identifier: "hash",
NID_authority_key_identifier: "keyid:always,issuer",
}
ca, err := NewCertificate(&caInfo, caKey)
if err != nil {
t.Fatal(err)
}
err = ca.AddExtensions(caExtensions)
if err != nil {
t.Fatal(err)
}
caFile := filepath.Join(dirName, "chain-ca.crt")
signAndSaveCert(ca, caKey, caFile)

// Define additional certificate information
certInfos := []struct {
name string
keyUsage string
}{
{"server_enc", "keyAgreement, keyEncipherment, dataEncipherment"},
{"server_sign", "nonRepudiation, digitalSignature"},
{"client_sign", "nonRepudiation, digitalSignature"},
{"client_enc", "keyAgreement, keyEncipherment, dataEncipherment"},
}

// Create additional certificates
for _, info := range certInfos {
keyFile := filepath.Join(dirName, info.name+".key")
key := generateAndSaveKey(keyFile)
certInfo := CertificateInfo{
Serial: big.NewInt(1),
Issued: 0,
Expires: 87600 * time.Hour, // 10 years
Country: "US",
Organization: "Test",
CommonName: "localhost",
}
extensions := map[NID]string{
NID_basic_constraints: "critical,CA:FALSE",
NID_key_usage: info.keyUsage,
}
cert, err := NewCertificate(&certInfo, key)
if err != nil {
t.Fatal(err)
}
err = cert.AddExtensions(extensions)
if err != nil {
t.Fatal(err)
}
err = cert.SetIssuer(ca)
if err != nil {
t.Fatal(err)
}
certFile := filepath.Join(dirName, info.name+".crt")
signAndSaveCert(cert, caKey, certFile)
}
}

func TestCertGetNameEntry(t *testing.T) {
key, err := GenerateRSAKey(768)
if err != nil {
Expand Down
29 changes: 29 additions & 0 deletions crypto/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ type PrivateKey interface {
// MarshalPKCS1PrivateKeyDER converts the private key to DER-encoded PKCS1
// format
MarshalPKCS1PrivateKeyDER() (der_block []byte, err error)

// MarshalPKCS8PrivateKeyPEM converts the private key to PEM-encoded PKCS8
// format
MarshalPKCS8PrivateKeyPEM() (pem_block []byte, err error)
}

type pKey struct {
Expand Down Expand Up @@ -239,6 +243,31 @@ func (key *pKey) VerifyPKCS1v15(method Method, data, sig []byte) error {
}
}

func (key *pKey) MarshalPKCS8PrivateKeyPEM() ([]byte, error) {
if key.key == nil {
return nil, errors.New("empty key")
}

bio := C.BIO_new(C.BIO_s_mem())
if bio == nil {
return nil, errors.New("failed to allocate memory")
}
defer C.BIO_free(bio)

if C.PEM_write_bio_PKCS8PrivateKey(bio, key.key, nil, nil, 0, nil, nil) != 1 {
return nil, errors.New("failed to write private key")
}

var ptr *C.char
length := C.X_BIO_get_mem_data(bio, &ptr)
if length <= 0 {
return nil, errors.New("failed to read bio data")
}

result := C.GoBytes(unsafe.Pointer(ptr), C.int(length))
return result, nil
}

func (key *pKey) Encrypt(data []byte) ([]byte, error) {
ctx := C.EVP_PKEY_CTX_new(key.key, nil)
defer C.EVP_PKEY_CTX_free(ctx)
Expand Down
4 changes: 4 additions & 0 deletions crypto/shim.c
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ void* X_BIO_get_data(BIO* bio) {
return BIO_get_data(bio);
}

long X_BIO_get_mem_data(BIO *b, char **pp) {
return BIO_get_mem_data(b, pp);
}

EVP_MD_CTX* X_EVP_MD_CTX_new() {
return EVP_MD_CTX_new();
}
Expand Down
1 change: 1 addition & 0 deletions crypto/shim.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ extern int X_BIO_read(BIO *b, void *buf, int len);
extern int X_BIO_write(BIO *b, const void *buf, int len);
extern BIO *X_BIO_new_write_bio();
extern BIO *X_BIO_new_read_bio();
extern long X_BIO_get_mem_data(BIO *b, char **pp);

extern int X_BN_num_bytes(const BIGNUM *a);

Expand Down
Loading

0 comments on commit e462021

Please sign in to comment.