Skip to content

Commit

Permalink
Support SM2 two-party threshold decrypt
Browse files Browse the repository at this point in the history
  • Loading branch information
dongbeiouba committed Jan 17, 2024
1 parent f91c17b commit ba6df41
Show file tree
Hide file tree
Showing 6 changed files with 429 additions and 17 deletions.
90 changes: 73 additions & 17 deletions crypto/sm2/sm2_crypt.c
Original file line number Diff line number Diff line change
Expand Up @@ -46,25 +46,12 @@ IMPLEMENT_STATIC_ASN1_FUNCTIONS(SM2_Ciphertext)

static size_t ec_field_size(const EC_GROUP *group)
{
/* Is there some simpler way to do this? */
BIGNUM *p = BN_new();
BIGNUM *a = BN_new();
BIGNUM *b = BN_new();
size_t field_size = 0;
const BIGNUM *p = EC_GROUP_get0_field(group);

if (p == NULL || a == NULL || b == NULL)
goto done;

if (!EC_GROUP_get_curve(group, p, a, b, NULL))
goto done;
field_size = (BN_num_bits(p) + 7) / 8;

done:
BN_free(p);
BN_free(a);
BN_free(b);
if (p == NULL)
return 0;

return field_size;
return BN_num_bytes(p);
}

int ossl_sm2_plaintext_size(const unsigned char *ct, size_t ct_size,
Expand Down Expand Up @@ -404,3 +391,72 @@ int ossl_sm2_decrypt(const EC_KEY *key,

return rc;
}

int ossl_sm2_ciphertext_decode(const uint8_t *ciphertext, size_t ciphertext_len,
EC_POINT **C1p, uint8_t **C2_data, size_t *C2_len,
uint8_t **C3_data, size_t *C3_len)
{
int ok = 0;
BN_CTX *ctx = NULL;
EC_GROUP *group = NULL;
EC_POINT *C1 = NULL;
struct SM2_Ciphertext_st *sm2_ctext = NULL;

sm2_ctext = d2i_SM2_Ciphertext(NULL, &ciphertext, ciphertext_len);
if (sm2_ctext == NULL) {
ERR_raise(ERR_LIB_SM2, SM2_R_ASN1_ERROR);
goto done;
}

group = EC_GROUP_new_by_curve_name(NID_sm2);
if (group == NULL)
goto done;

C1 = EC_POINT_new(group);
if (C1 == NULL)
goto done;

ctx = BN_CTX_new();
if (ctx == NULL)
goto done;

if (!EC_POINT_set_affine_coordinates(group, C1, sm2_ctext->C1x,
sm2_ctext->C1y, ctx))
goto done;

if (C1p) {
EC_POINT_free(*C1p);
*C1p = C1;
}

if (C2_data && C2_len) {
OPENSSL_free(*C2_data);
*C2_data = OPENSSL_memdup(sm2_ctext->C2->data, sm2_ctext->C2->length);
if (*C2_data == NULL) {
ERR_raise(ERR_LIB_SM2, ERR_R_MALLOC_FAILURE);
goto done;
}
*C2_len = sm2_ctext->C2->length;
}

if (C3_data && C3_len) {
OPENSSL_free(*C3_data);
*C3_data = OPENSSL_memdup(sm2_ctext->C3->data, sm2_ctext->C3->length);
if (*C3_data == NULL) {
ERR_raise(ERR_LIB_SM2, ERR_R_MALLOC_FAILURE);
goto done;
}
*C3_len = sm2_ctext->C3->length;
}

ok = 1;
done:
if (!ok)
EC_POINT_free(C1);

BN_CTX_free(ctx);
EC_GROUP_free(group);
SM2_Ciphertext_free(sm2_ctext);

return ok;
}
263 changes: 263 additions & 0 deletions crypto/sm2/sm2_threshold.c
Original file line number Diff line number Diff line change
Expand Up @@ -483,3 +483,266 @@ int SM2_THRESHOLD_sign3(const EVP_PKEY *key, const EVP_PKEY *temp_key,
BN_free(w1);
return ret;
}

int SM2_THRESHOLD_decrypt1(const unsigned char *ct, size_t ct_len, BIGNUM **kp,
EC_POINT **T1p)
{
int ok = 0;
BIGNUM *k = NULL;
BN_CTX *ctx = NULL;
EC_GROUP *group = NULL;
EC_POINT *C1 = NULL, *T1 = NULL;
const BIGNUM *order;

if (ct == NULL || kp == NULL || T1p == NULL) {
ERR_raise(ERR_LIB_SM2, ERR_R_PASSED_NULL_PARAMETER);
return 0;
}

if (!ossl_sm2_ciphertext_decode(ct, ct_len, &C1, NULL, NULL, NULL, NULL))
return 0;

group = EC_GROUP_new_by_curve_name_ex(NULL, NULL, NID_sm2);
if (group == NULL)
goto end;

order = EC_GROUP_get0_order(group);

ctx = BN_CTX_new();
if (ctx == NULL)
goto end;

BN_CTX_start(ctx);
k = BN_CTX_get(ctx);
if (k == NULL)
goto end;

/* Generate a random number k in [1, n-1] */
do {
if (!BN_priv_rand_range_ex(k, order, 0, ctx))
goto end;
} while (BN_is_zero(k));

T1 = EC_POINT_new(group);
if (T1 == NULL)
goto end;

/*
* T_1 = [k]C_1
*/
if (!EC_POINT_mul(group, T1, NULL, C1, k, ctx)) {
ERR_raise(ERR_LIB_EC, ERR_R_EC_LIB);
goto end;
}

BN_free(*kp);
*kp = k;
EC_POINT_free(*T1p);
*T1p = T1;

ok = 1;
end:
if (!ok) {
BN_free(k);
EC_POINT_free(T1);
}

EC_POINT_free(C1);
EC_GROUP_free(group);
BN_CTX_end(ctx);
BN_CTX_free(ctx);
return ok;
}

int SM2_THRESHOLD_decrypt2(EVP_PKEY *key, const EC_POINT *T1, EC_POINT **T2p)
{
int ok = 0;
const EC_KEY *eckey;
const EC_GROUP *group;
const BIGNUM *d2;
BIGNUM *d2_inv = NULL;
BN_CTX *ctx = NULL;
EC_POINT *T2 = NULL;

if (key == NULL || T1 == NULL || T2p == NULL) {
ERR_raise(ERR_LIB_SM2, ERR_R_PASSED_NULL_PARAMETER);
return 0;
}

eckey = EVP_PKEY_get0_EC_KEY(key);
if (eckey == NULL)
return 0;

group = EC_KEY_get0_group(eckey);
d2 = EC_KEY_get0_private_key(eckey);
if (d2 == NULL)
return 0;

ctx = BN_CTX_new_ex(ossl_ec_key_get_libctx(eckey));
if (ctx == NULL)
return 0;

BN_CTX_start(ctx);
d2_inv = BN_CTX_get(ctx);
if (d2_inv == NULL)
goto end;

T2 = EC_POINT_new(group);
if (T2 == NULL)
goto end;

/*
* T_2 = d_2^(-1) * T_1
*/
if (!ossl_ec_group_do_inverse_ord(group, d2_inv, d2, ctx)
|| !EC_POINT_mul(group, T2, NULL, T1, d2_inv, ctx)) {
ERR_raise(ERR_LIB_EC, ERR_R_EC_LIB);
goto end;
}

EC_POINT_free(*T2p);
*T2p = T2;

ok = 1;
end:
if (!ok)
EC_POINT_free(T2);
BN_CTX_end(ctx);
BN_CTX_free(ctx);
return ok;
}

int SM2_THRESHOLD_decrypt3(EVP_PKEY *key, const unsigned char *ct,
size_t ct_len, const BIGNUM *k, const EC_POINT *T2,
unsigned char **pt, size_t *pt_len)
{
int ok = 0;
uint8_t *msg_mask = NULL;
const EC_KEY *eckey;
const EC_GROUP *group;
OSSL_LIB_CTX *libctx;
const char *propq;
const BIGNUM *field;
BN_CTX *ctx = NULL;
EC_POINT *C1 = NULL, *kP = NULL, *C1_inv = NULL;
const BIGNUM *d1;
BIGNUM *x2, *y2, *d1_inv, *k_inv;
size_t field_size;
uint8_t *C2 = NULL;
size_t i, msg_len;
unsigned char *x2y2buf = NULL, *msg = NULL;

if (key == NULL || ct == NULL || k == NULL || T2 == NULL || pt == NULL
|| pt_len == NULL) {
ERR_raise(ERR_LIB_SM2, ERR_R_PASSED_NULL_PARAMETER);
return 0;
}

eckey = EVP_PKEY_get0_EC_KEY(key);
if (eckey == NULL)
return 0;

group = EC_KEY_get0_group(eckey);
libctx = ossl_ec_key_get_libctx(eckey);
propq = ossl_ec_key_get0_propq(eckey);

d1 = EC_KEY_get0_private_key(eckey);
if (d1 == NULL)
return 0;

if ((field = EC_GROUP_get0_field(group)) == NULL
|| (field_size = BN_num_bytes(field)) == 0)
return 0;

if (!ossl_sm2_ciphertext_decode(ct, ct_len, &C1, &C2, &msg_len, NULL,
NULL))
goto end;

kP = EC_POINT_new(group);
if (kP == NULL)
goto end;

C1_inv = EC_POINT_dup(C1, group);
if (C1_inv == NULL)
goto end;

ctx = BN_CTX_new_ex(libctx);
if (ctx == NULL)
goto end;

BN_CTX_start(ctx);
x2 = BN_CTX_get(ctx);
y2 = BN_CTX_get(ctx);
d1_inv = BN_CTX_get(ctx);
k_inv = BN_CTX_get(ctx);
if (k_inv == NULL)
goto end;

/*
* [k]P_B = (x2, y2) = [k^(-1) * d_1^(-1)] * T_2 - C1
*/
if (!ossl_ec_group_do_inverse_ord(group, d1_inv, d1, ctx)
|| !ossl_ec_group_do_inverse_ord(group, k_inv, k, ctx)
|| !EC_POINT_mul(group, kP, NULL, T2, k_inv, ctx)
|| !EC_POINT_mul(group, kP, NULL, kP, d1_inv, ctx)
|| !EC_POINT_invert(group, C1_inv, ctx)
|| !EC_POINT_add(group, kP, kP, C1_inv, ctx)
|| !EC_POINT_get_affine_coordinates(group, kP, x2, y2, ctx)) {
ERR_raise(ERR_LIB_EC, ERR_R_EC_LIB);
goto end;
}

msg_mask = OPENSSL_malloc(msg_len);
if (msg_mask == NULL) {
ERR_raise(ERR_LIB_SM2, ERR_R_MALLOC_FAILURE);
goto end;
}

x2y2buf = OPENSSL_zalloc(2 * field_size);
if (x2y2buf == NULL) {
ERR_raise(ERR_LIB_SM2, ERR_R_MALLOC_FAILURE);
goto end;
}

if (BN_bn2binpad(x2, x2y2buf, field_size) < 0
|| BN_bn2binpad(y2, x2y2buf + field_size, field_size) < 0) {
ERR_raise(ERR_LIB_SM2, ERR_R_INTERNAL_ERROR);
goto end;
}

/* X9.63 with no salt happens to match the KDF used in SM2 */
if (!ossl_ecdh_kdf_X9_63(msg_mask, msg_len, x2y2buf, 2 * field_size, NULL, 0,
EVP_sm3(), libctx, propq)) {
ERR_raise(ERR_LIB_SM2, ERR_R_EVP_LIB);
goto end;
}

msg = OPENSSL_malloc(msg_len);
if (msg == NULL) {
ERR_raise(ERR_LIB_SM2, ERR_R_MALLOC_FAILURE);
goto end;
}

for (i = 0; i != msg_len; ++i)
msg[i] = C2[i] ^ msg_mask[i];

OPENSSL_free(*pt);
*pt = msg;
*pt_len = msg_len;

ok = 1;
end:
if (!ok)
OPENSSL_free(msg);

OPENSSL_free(msg);
OPENSSL_free(x2y2buf);
OPENSSL_free(msg_mask);
BN_CTX_end(ctx);
BN_CTX_free(ctx);
EC_POINT_free(C1_inv);
EC_POINT_free(kP);
OPENSSL_free(C2);
EC_POINT_free(C1);
return ok;
}
4 changes: 4 additions & 0 deletions include/crypto/sm2.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ int ossl_sm2_decrypt(const EC_KEY *key,
const uint8_t *ciphertext, size_t ciphertext_len,
uint8_t *ptext_buf, size_t *ptext_len);

int ossl_sm2_ciphertext_decode(const uint8_t *ciphertext, size_t ciphertext_len,
EC_POINT **C1p, uint8_t **C2p, size_t *C2_len,
uint8_t **C3p, size_t *C3_len);

const unsigned char *ossl_sm2_algorithmidentifier_encoding(int md_nid,
size_t *len);

Expand Down
9 changes: 9 additions & 0 deletions include/openssl/sm2_threshold.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,15 @@ int SM2_THRESHOLD_sign3(const EVP_PKEY *key, const EVP_PKEY *temp_key,
const unsigned char *sig2, size_t sig2_len,
unsigned char **sig, size_t *siglen);

int SM2_THRESHOLD_decrypt1(const unsigned char *ct, size_t ct_len, BIGNUM **kp,
EC_POINT **T1p);

int SM2_THRESHOLD_decrypt2(EVP_PKEY *key, const EC_POINT *T1, EC_POINT **T2p);

int SM2_THRESHOLD_decrypt3(EVP_PKEY *key, const unsigned char *ct,
size_t ct_len, const BIGNUM *k, const EC_POINT *T2,
unsigned char **pt, size_t *pt_len);

# ifdef __cplusplus
}
# endif
Expand Down
Loading

0 comments on commit ba6df41

Please sign in to comment.