Skip to content

Commit

Permalink
feat: configurable curve and support secp256k1
Browse files Browse the repository at this point in the history
Signed-off-by: lanford33 <[email protected]>
  • Loading branch information
LanfordCai committed Dec 28, 2024
1 parent 6f32efc commit 581781f
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 14 deletions.
31 changes: 25 additions & 6 deletions mpc/binance/ecdsa/mpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ import (
"github.com/bnb-chain/tss-lib/v2/ecdsa/keygen"
"github.com/bnb-chain/tss-lib/v2/ecdsa/signing"
"github.com/bnb-chain/tss-lib/v2/tss"
"github.com/btcsuite/btcd/btcec/v2"
s256k1 "github.com/btcsuite/btcd/btcec/v2"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes/any"
)
Expand Down Expand Up @@ -106,14 +109,20 @@ type party struct {
in chan tss.Message
shareData *keygen.LocalPartySaveData
closeChan chan struct{}
curve elliptic.Curve
}

func NewParty(id uint16, logger Logger) *party {
func NewParty(id uint16, curve elliptic.Curve, logger Logger) *party {
if curve == nil {
curve = s256k1.S256()
}

return &party{
logger: logger,
id: tss.NewPartyID(fmt.Sprintf("%d", id), "", big.NewInt(int64(id))),
out: make(chan tss.Message, 1000),
in: make(chan tss.Message, 1000),
curve: curve,
}
}

Expand Down Expand Up @@ -190,7 +199,17 @@ func (p *party) ThresholdPK() ([]byte, error) {
if err != nil {
return nil, err
}
return x509.MarshalPKIXPublicKey(pk)

switch p.curve.Params().Name {
case string(tss.Secp256k1):
xFieldVal, yFieldVal := new(secp256k1.FieldVal), new(secp256k1.FieldVal)
xFieldVal.SetByteSlice(pk.X.Bytes())
yFieldVal.SetByteSlice(pk.Y.Bytes())
btcecPubKey := btcec.NewPublicKey(xFieldVal, yFieldVal)
return btcecPubKey.SerializeCompressed(), nil
default:
return x509.MarshalPKIXPublicKey(pk)
}
}

func (p *party) SetShareData(shareData []byte) error {
Expand All @@ -199,9 +218,9 @@ func (p *party) SetShareData(shareData []byte) error {
if err != nil {
return fmt.Errorf("failed deserializing shares: %w", err)
}
localSaveData.ECDSAPub.SetCurve(elliptic.P256())
localSaveData.ECDSAPub.SetCurve(p.curve)
for _, xj := range localSaveData.BigXj {
xj.SetCurve(elliptic.P256())
xj.SetCurve(p.curve)
}
p.shareData = &localSaveData
return nil
Expand All @@ -210,7 +229,7 @@ func (p *party) SetShareData(shareData []byte) error {
func (p *party) Init(parties []uint16, threshold int, sendMsg func(msg []byte, isBroadcast bool, to uint16)) {
partyIDs := partyIDsFromNumbers(parties)
ctx := tss.NewPeerContext(partyIDs)
p.params = tss.NewParameters(elliptic.P256(), ctx, p.id, len(parties), threshold)
p.params = tss.NewParameters(p.curve, ctx, p.id, len(parties), threshold)
p.id.Index = p.locatePartyIndex(p.id)
p.sendMsg = sendMsg
p.closeChan = make(chan struct{})
Expand All @@ -237,7 +256,7 @@ func (p *party) Sign(ctx context.Context, msgHash []byte) ([]byte, error) {

end := make(chan *common.SignatureData, 1)

msgToSign := hashToInt(msgHash, elliptic.P256())
msgToSign := hashToInt(msgHash, p.curve)
party := signing.NewLocalParty(msgToSign, p.params, *p.shareData, p.out, end)

var endWG sync.WaitGroup
Expand Down
7 changes: 4 additions & 3 deletions mpc/binance/ecdsa/mpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package ecdsa
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"fmt"
"math/big"
"sync"
Expand Down Expand Up @@ -108,9 +109,9 @@ func (parties parties) Mapping() map[string]*tss.PartyID {
}

func TestTSS(t *testing.T) {
pA := NewParty(1, logger("pA", t.Name()))
pB := NewParty(2, logger("pB", t.Name()))
pC := NewParty(3, logger("pC", t.Name()))
pA := NewParty(1, elliptic.P256(), logger("pA", t.Name()))
pB := NewParty(2, elliptic.P256(), logger("pB", t.Name()))
pC := NewParty(3, elliptic.P256(), logger("pC", t.Name()))

t.Logf("Created parties")

Expand Down
17 changes: 12 additions & 5 deletions test/binance/ecdsa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@ package binance_test

import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/x509"
"testing"

ecdsa_scheme "github.com/IBM/TSS/mpc/binance/ecdsa"

. "github.com/IBM/TSS/types"

"github.com/stretchr/testify/assert"
)

Expand All @@ -18,7 +21,9 @@ func TestThresholdBinanceECDSA(t *testing.T) {
var signatureAlgorithms func([]*commLogger) (func(uint16) KeyGenerator, func(uint16) Signer)

verifySig = verifySignatureECDSA
signatureAlgorithms = ecdsaKeygenAndSign
signatureAlgorithms = func(loggers []*commLogger) (func(uint16) KeyGenerator, func(uint16) Signer) {
return ecdsaKeygenAndSign(elliptic.P256(), loggers)
}

testScheme(t, n, signatureAlgorithms, verifySig, false)
}
Expand All @@ -31,18 +36,20 @@ func TestFastThresholdBinanceECDSA(t *testing.T) {
var signatureAlgorithms func([]*commLogger) (func(uint16) KeyGenerator, func(uint16) Signer)

verifySig = verifySignatureECDSA
signatureAlgorithms = ecdsaKeygenAndSign
signatureAlgorithms = func(loggers []*commLogger) (func(uint16) KeyGenerator, func(uint16) Signer) {
return ecdsaKeygenAndSign(elliptic.P256(), loggers)
}

testScheme(t, n, signatureAlgorithms, verifySig, true)
}

func ecdsaKeygenAndSign(loggers []*commLogger) (func(id uint16) KeyGenerator, func(id uint16) Signer) {
func ecdsaKeygenAndSign(curve elliptic.Curve, loggers []*commLogger) (func(id uint16) KeyGenerator, func(id uint16) Signer) {
kgf := func(id uint16) KeyGenerator {
return ecdsa_scheme.NewParty(id, loggers[id-1])
return ecdsa_scheme.NewParty(id, curve, loggers[id-1])
}

sf := func(id uint16) Signer {
return ecdsa_scheme.NewParty(id, loggers[id-1])
return ecdsa_scheme.NewParty(id, curve, loggers[id-1])
}
return kgf, sf
}
Expand Down

0 comments on commit 581781f

Please sign in to comment.