diff --git a/crypto/sm2/sm2_crypt.c b/crypto/sm2/sm2_crypt.c index dc2a86ab5..b25412cee 100644 --- a/crypto/sm2/sm2_crypt.c +++ b/crypto/sm2/sm2_crypt.c @@ -46,25 +46,12 @@ IMPLEMENT_STATIC_ASN1_FUNCTIONS(SM2_Ciphertext) static size_t ec_field_size(const EC_GROUP *group) { - /* Is there some simpler way to do this? */ - BIGNUM *p = BN_new(); - BIGNUM *a = BN_new(); - BIGNUM *b = BN_new(); - size_t field_size = 0; + const BIGNUM *p = EC_GROUP_get0_field(group); - if (p == NULL || a == NULL || b == NULL) - goto done; - - if (!EC_GROUP_get_curve(group, p, a, b, NULL)) - goto done; - field_size = (BN_num_bits(p) + 7) / 8; - - done: - BN_free(p); - BN_free(a); - BN_free(b); + if (p == NULL) + return 0; - return field_size; + return BN_num_bytes(p); } int ossl_sm2_plaintext_size(const unsigned char *ct, size_t ct_size, @@ -404,3 +391,72 @@ int ossl_sm2_decrypt(const EC_KEY *key, return rc; } + +int ossl_sm2_ciphertext_decode(const uint8_t *ciphertext, size_t ciphertext_len, + EC_POINT **C1p, uint8_t **C2_data, size_t *C2_len, + uint8_t **C3_data, size_t *C3_len) +{ + int ok = 0; + BN_CTX *ctx = NULL; + EC_GROUP *group = NULL; + EC_POINT *C1 = NULL; + struct SM2_Ciphertext_st *sm2_ctext = NULL; + + sm2_ctext = d2i_SM2_Ciphertext(NULL, &ciphertext, ciphertext_len); + if (sm2_ctext == NULL) { + ERR_raise(ERR_LIB_SM2, SM2_R_ASN1_ERROR); + goto done; + } + + group = EC_GROUP_new_by_curve_name(NID_sm2); + if (group == NULL) + goto done; + + C1 = EC_POINT_new(group); + if (C1 == NULL) + goto done; + + ctx = BN_CTX_new(); + if (ctx == NULL) + goto done; + + if (!EC_POINT_set_affine_coordinates(group, C1, sm2_ctext->C1x, + sm2_ctext->C1y, ctx)) + goto done; + + if (C1p) { + EC_POINT_free(*C1p); + *C1p = C1; + } + + if (C2_data && C2_len) { + OPENSSL_free(*C2_data); + *C2_data = OPENSSL_memdup(sm2_ctext->C2->data, sm2_ctext->C2->length); + if (*C2_data == NULL) { + ERR_raise(ERR_LIB_SM2, ERR_R_MALLOC_FAILURE); + goto done; + } + *C2_len = sm2_ctext->C2->length; + } + + if (C3_data && C3_len) { + OPENSSL_free(*C3_data); + *C3_data = OPENSSL_memdup(sm2_ctext->C3->data, sm2_ctext->C3->length); + if (*C3_data == NULL) { + ERR_raise(ERR_LIB_SM2, ERR_R_MALLOC_FAILURE); + goto done; + } + *C3_len = sm2_ctext->C3->length; + } + + ok = 1; +done: + if (!ok) + EC_POINT_free(C1); + + BN_CTX_free(ctx); + EC_GROUP_free(group); + SM2_Ciphertext_free(sm2_ctext); + + return ok; +} diff --git a/crypto/sm2/sm2_threshold.c b/crypto/sm2/sm2_threshold.c index feae932e5..2ecf35f45 100644 --- a/crypto/sm2/sm2_threshold.c +++ b/crypto/sm2/sm2_threshold.c @@ -483,3 +483,266 @@ int SM2_THRESHOLD_sign3(const EVP_PKEY *key, const EVP_PKEY *temp_key, BN_free(w1); return ret; } + +int SM2_THRESHOLD_decrypt1(const unsigned char *ct, size_t ct_len, BIGNUM **kp, + EC_POINT **T1p) +{ + int ok = 0; + BIGNUM *k = NULL; + BN_CTX *ctx = NULL; + EC_GROUP *group = NULL; + EC_POINT *C1 = NULL, *T1 = NULL; + const BIGNUM *order; + + if (ct == NULL || kp == NULL || T1p == NULL) { + ERR_raise(ERR_LIB_SM2, ERR_R_PASSED_NULL_PARAMETER); + return 0; + } + + if (!ossl_sm2_ciphertext_decode(ct, ct_len, &C1, NULL, NULL, NULL, NULL)) + return 0; + + group = EC_GROUP_new_by_curve_name_ex(NULL, NULL, NID_sm2); + if (group == NULL) + goto end; + + order = EC_GROUP_get0_order(group); + + ctx = BN_CTX_new(); + if (ctx == NULL) + goto end; + + BN_CTX_start(ctx); + k = BN_CTX_get(ctx); + if (k == NULL) + goto end; + + /* Generate a random number k in [1, n-1] */ + do { + if (!BN_priv_rand_range_ex(k, order, 0, ctx)) + goto end; + } while (BN_is_zero(k)); + + T1 = EC_POINT_new(group); + if (T1 == NULL) + goto end; + + /* + * T_1 = [k]C_1 + */ + if (!EC_POINT_mul(group, T1, NULL, C1, k, ctx)) { + ERR_raise(ERR_LIB_EC, ERR_R_EC_LIB); + goto end; + } + + BN_free(*kp); + *kp = k; + EC_POINT_free(*T1p); + *T1p = T1; + + ok = 1; +end: + if (!ok) { + BN_free(k); + EC_POINT_free(T1); + } + + EC_POINT_free(C1); + EC_GROUP_free(group); + BN_CTX_end(ctx); + BN_CTX_free(ctx); + return ok; +} + +int SM2_THRESHOLD_decrypt2(EVP_PKEY *key, const EC_POINT *T1, EC_POINT **T2p) +{ + int ok = 0; + const EC_KEY *eckey; + const EC_GROUP *group; + const BIGNUM *d2; + BIGNUM *d2_inv = NULL; + BN_CTX *ctx = NULL; + EC_POINT *T2 = NULL; + + if (key == NULL || T1 == NULL || T2p == NULL) { + ERR_raise(ERR_LIB_SM2, ERR_R_PASSED_NULL_PARAMETER); + return 0; + } + + eckey = EVP_PKEY_get0_EC_KEY(key); + if (eckey == NULL) + return 0; + + group = EC_KEY_get0_group(eckey); + d2 = EC_KEY_get0_private_key(eckey); + if (d2 == NULL) + return 0; + + ctx = BN_CTX_new_ex(ossl_ec_key_get_libctx(eckey)); + if (ctx == NULL) + return 0; + + BN_CTX_start(ctx); + d2_inv = BN_CTX_get(ctx); + if (d2_inv == NULL) + goto end; + + T2 = EC_POINT_new(group); + if (T2 == NULL) + goto end; + + /* + * T_2 = d_2^(-1) * T_1 + */ + if (!ossl_ec_group_do_inverse_ord(group, d2_inv, d2, ctx) + || !EC_POINT_mul(group, T2, NULL, T1, d2_inv, ctx)) { + ERR_raise(ERR_LIB_EC, ERR_R_EC_LIB); + goto end; + } + + EC_POINT_free(*T2p); + *T2p = T2; + + ok = 1; +end: + if (!ok) + EC_POINT_free(T2); + BN_CTX_end(ctx); + BN_CTX_free(ctx); + return ok; +} + +int SM2_THRESHOLD_decrypt3(EVP_PKEY *key, const unsigned char *ct, + size_t ct_len, const BIGNUM *k, const EC_POINT *T2, + unsigned char **pt, size_t *pt_len) +{ + int ok = 0; + uint8_t *msg_mask = NULL; + const EC_KEY *eckey; + const EC_GROUP *group; + OSSL_LIB_CTX *libctx; + const char *propq; + const BIGNUM *field; + BN_CTX *ctx = NULL; + EC_POINT *C1 = NULL, *kP = NULL, *C1_inv = NULL; + const BIGNUM *d1; + BIGNUM *x2, *y2, *d1_inv, *k_inv; + size_t field_size; + uint8_t *C2 = NULL; + size_t i, msg_len; + unsigned char *x2y2buf = NULL, *msg = NULL; + + if (key == NULL || ct == NULL || k == NULL || T2 == NULL || pt == NULL + || pt_len == NULL) { + ERR_raise(ERR_LIB_SM2, ERR_R_PASSED_NULL_PARAMETER); + return 0; + } + + eckey = EVP_PKEY_get0_EC_KEY(key); + if (eckey == NULL) + return 0; + + group = EC_KEY_get0_group(eckey); + libctx = ossl_ec_key_get_libctx(eckey); + propq = ossl_ec_key_get0_propq(eckey); + + d1 = EC_KEY_get0_private_key(eckey); + if (d1 == NULL) + return 0; + + if ((field = EC_GROUP_get0_field(group)) == NULL + || (field_size = BN_num_bytes(field)) == 0) + return 0; + + if (!ossl_sm2_ciphertext_decode(ct, ct_len, &C1, &C2, &msg_len, NULL, + NULL)) + goto end; + + kP = EC_POINT_new(group); + if (kP == NULL) + goto end; + + C1_inv = EC_POINT_dup(C1, group); + if (C1_inv == NULL) + goto end; + + ctx = BN_CTX_new_ex(libctx); + if (ctx == NULL) + goto end; + + BN_CTX_start(ctx); + x2 = BN_CTX_get(ctx); + y2 = BN_CTX_get(ctx); + d1_inv = BN_CTX_get(ctx); + k_inv = BN_CTX_get(ctx); + if (k_inv == NULL) + goto end; + + /* + * [k]P_B = (x2, y2) = [k^(-1) * d_1^(-1)] * T_2 - C1 + */ + if (!ossl_ec_group_do_inverse_ord(group, d1_inv, d1, ctx) + || !ossl_ec_group_do_inverse_ord(group, k_inv, k, ctx) + || !EC_POINT_mul(group, kP, NULL, T2, k_inv, ctx) + || !EC_POINT_mul(group, kP, NULL, kP, d1_inv, ctx) + || !EC_POINT_invert(group, C1_inv, ctx) + || !EC_POINT_add(group, kP, kP, C1_inv, ctx) + || !EC_POINT_get_affine_coordinates(group, kP, x2, y2, ctx)) { + ERR_raise(ERR_LIB_EC, ERR_R_EC_LIB); + goto end; + } + + msg_mask = OPENSSL_malloc(msg_len); + if (msg_mask == NULL) { + ERR_raise(ERR_LIB_SM2, ERR_R_MALLOC_FAILURE); + goto end; + } + + x2y2buf = OPENSSL_zalloc(2 * field_size); + if (x2y2buf == NULL) { + ERR_raise(ERR_LIB_SM2, ERR_R_MALLOC_FAILURE); + goto end; + } + + if (BN_bn2binpad(x2, x2y2buf, field_size) < 0 + || BN_bn2binpad(y2, x2y2buf + field_size, field_size) < 0) { + ERR_raise(ERR_LIB_SM2, ERR_R_INTERNAL_ERROR); + goto end; + } + + /* X9.63 with no salt happens to match the KDF used in SM2 */ + if (!ossl_ecdh_kdf_X9_63(msg_mask, msg_len, x2y2buf, 2 * field_size, NULL, 0, + EVP_sm3(), libctx, propq)) { + ERR_raise(ERR_LIB_SM2, ERR_R_EVP_LIB); + goto end; + } + + msg = OPENSSL_malloc(msg_len); + if (msg == NULL) { + ERR_raise(ERR_LIB_SM2, ERR_R_MALLOC_FAILURE); + goto end; + } + + for (i = 0; i != msg_len; ++i) + msg[i] = C2[i] ^ msg_mask[i]; + + OPENSSL_free(*pt); + *pt = msg; + *pt_len = msg_len; + + ok = 1; +end: + if (!ok) + OPENSSL_free(msg); + + OPENSSL_free(msg); + OPENSSL_free(x2y2buf); + OPENSSL_free(msg_mask); + BN_CTX_end(ctx); + BN_CTX_free(ctx); + EC_POINT_free(C1_inv); + EC_POINT_free(kP); + OPENSSL_free(C2); + EC_POINT_free(C1); + return ok; +} diff --git a/include/crypto/sm2.h b/include/crypto/sm2.h index d5d85f2f0..108d821fb 100644 --- a/include/crypto/sm2.h +++ b/include/crypto/sm2.h @@ -79,6 +79,10 @@ int ossl_sm2_decrypt(const EC_KEY *key, const uint8_t *ciphertext, size_t ciphertext_len, uint8_t *ptext_buf, size_t *ptext_len); +int ossl_sm2_ciphertext_decode(const uint8_t *ciphertext, size_t ciphertext_len, + EC_POINT **C1p, uint8_t **C2p, size_t *C2_len, + uint8_t **C3p, size_t *C3_len); + const unsigned char *ossl_sm2_algorithmidentifier_encoding(int md_nid, size_t *len); diff --git a/include/openssl/sm2_threshold.h b/include/openssl/sm2_threshold.h index 191469b3f..70d9b3104 100644 --- a/include/openssl/sm2_threshold.h +++ b/include/openssl/sm2_threshold.h @@ -113,6 +113,15 @@ int SM2_THRESHOLD_sign3(const EVP_PKEY *key, const EVP_PKEY *temp_key, const unsigned char *sig2, size_t sig2_len, unsigned char **sig, size_t *siglen); +int SM2_THRESHOLD_decrypt1(const unsigned char *ct, size_t ct_len, BIGNUM **kp, + EC_POINT **T1p); + +int SM2_THRESHOLD_decrypt2(EVP_PKEY *key, const EC_POINT *T1, EC_POINT **T2p); + +int SM2_THRESHOLD_decrypt3(EVP_PKEY *key, const unsigned char *ct, + size_t ct_len, const BIGNUM *k, const EC_POINT *T2, + unsigned char **pt, size_t *pt_len); + # ifdef __cplusplus } # endif diff --git a/test/sm2_threshold_test.c b/test/sm2_threshold_test.c index 2a367bec9..e57bc865c 100644 --- a/test/sm2_threshold_test.c +++ b/test/sm2_threshold_test.c @@ -158,6 +158,82 @@ static int sm2_threshold_sign_test(int id) return ret; } +static int test_sm2_threshold_decrypt(void) +{ + int ret = 0; + const char *msg = "hello sm2 threshold"; + int msg_len = strlen(msg); + EVP_PKEY *key1 = NULL, *key2 = NULL, *pubkey1 = NULL, *pubkey2 = NULL; + EVP_PKEY *complete_key1 = NULL, *complete_key2 = NULL; + EVP_PKEY_CTX *pctx = NULL; + BIGNUM *k = NULL; + EC_POINT *T1 = NULL, *T2 = NULL; + unsigned char *buf = NULL, *pt = NULL; + size_t pt_len, outlen; + + if (!TEST_ptr(key1 = EVP_PKEY_Q_keygen(NULL, NULL, "SM2")) + || !TEST_ptr(key2 = EVP_PKEY_Q_keygen(NULL, NULL, "SM2"))) + goto err; + + if (!TEST_ptr(pubkey1 = SM2_THRESHOLD_derive_partial_pubkey(key1)) + || !TEST_ptr(pubkey2 = SM2_THRESHOLD_derive_partial_pubkey(key2))) + goto err; + + if (!TEST_ptr(complete_key1 = + SM2_THRESHOLD_derive_complete_pubkey(key1, pubkey2)) + || !TEST_ptr(complete_key2 = + SM2_THRESHOLD_derive_complete_pubkey(key2, pubkey1))) + goto err; + + if (!TEST_true(EVP_PKEY_eq(complete_key1, complete_key2))) + goto err; + + if (!TEST_ptr(pctx = EVP_PKEY_CTX_new(complete_key1, NULL))) + goto err; + + if (!TEST_true(EVP_PKEY_encrypt_init(pctx) == 1)) + goto err; + + if (!TEST_true(EVP_PKEY_encrypt(pctx, NULL, &outlen, + (const unsigned char *)msg, msg_len) == 1)) + goto err; + + if (!TEST_ptr(buf = OPENSSL_malloc(outlen))) + goto err; + + if (!TEST_true(EVP_PKEY_encrypt(pctx, buf, &outlen, + (const unsigned char *)msg, msg_len) == 1)) + goto err; + + if (!TEST_true(SM2_THRESHOLD_decrypt1(buf, outlen, &k, &T1))) + goto err; + + if (!TEST_true(SM2_THRESHOLD_decrypt2(key2, T1, &T2))) + goto err; + + if (!TEST_true(SM2_THRESHOLD_decrypt3(key1, buf, outlen, k, T2, &pt, &pt_len))) + goto err; + + if (!TEST_str_eq((const char *)pt, msg)) + goto err; + + ret = 1; +err: + EVP_PKEY_free(key1); + EVP_PKEY_free(key2); + EVP_PKEY_free(pubkey1); + EVP_PKEY_free(pubkey2); + EVP_PKEY_free(complete_key1); + EVP_PKEY_free(complete_key2); + EVP_PKEY_CTX_free(pctx); + OPENSSL_free(buf); + BN_free(k); + EC_POINT_free(T1); + EC_POINT_free(T2); + OPENSSL_free(pt); + + return ret; +} #endif int setup_tests(void) @@ -167,6 +243,7 @@ int setup_tests(void) #else ADD_TEST(sm2_threshold_keygen_test); ADD_ALL_TESTS(sm2_threshold_sign_test, 2); + ADD_TEST(test_sm2_threshold_decrypt); #endif return 1; } diff --git a/util/libcrypto.num b/util/libcrypto.num index 08ac79238..8a538ff70 100644 --- a/util/libcrypto.num +++ b/util/libcrypto.num @@ -5635,3 +5635,6 @@ SM2_THRESHOLD_sign1_final 5950 3_0_3 EXIST::FUNCTION:SM2_THRESHOLD SM2_THRESHOLD_sign1_oneshot 5951 3_0_3 EXIST::FUNCTION:SM2_THRESHOLD SM2_THRESHOLD_sign2 5952 3_0_3 EXIST::FUNCTION:SM2_THRESHOLD SM2_THRESHOLD_sign3 5953 3_0_3 EXIST::FUNCTION:SM2_THRESHOLD +SM2_THRESHOLD_decrypt1 5954 3_0_3 EXIST::FUNCTION:SM2_THRESHOLD +SM2_THRESHOLD_decrypt2 5955 3_0_3 EXIST::FUNCTION:SM2_THRESHOLD +SM2_THRESHOLD_decrypt3 5956 3_0_3 EXIST::FUNCTION:SM2_THRESHOLD