diff --git a/cmd/mpcium/main.go b/cmd/mpcium/main.go index 6d12286..94ceb92 100644 --- a/cmd/mpcium/main.go +++ b/cmd/mpcium/main.go @@ -467,6 +467,7 @@ func GetNATSConnection(environment string) (*nats.Conn, error) { opts := []nats.Option{ nats.MaxReconnects(-1), // retry forever nats.ReconnectWait(2 * time.Second), + nats.NoEcho(), // Optimization: avoid echoing messages back to the publisher nats.DisconnectHandler(func(nc *nats.Conn) { logger.Warn("Disconnected from NATS") }), diff --git a/examples/generate/main.go b/examples/generate/main.go index 952935d..1b8b316 100644 --- a/examples/generate/main.go +++ b/examples/generate/main.go @@ -29,7 +29,7 @@ func main() { logger.Init(environment, false) natsURL := viper.GetString("nats.url") - natsConn, err := nats.Connect(natsURL) + natsConn, err := nats.Connect(natsURL, nats.NoEcho()) if err != nil { logger.Fatal("Failed to connect to NATS", err) } diff --git a/pkg/eventconsumer/event_consumer.go b/pkg/eventconsumer/event_consumer.go index c84e684..cf982c8 100644 --- a/pkg/eventconsumer/event_consumer.go +++ b/pkg/eventconsumer/event_consumer.go @@ -152,6 +152,8 @@ func (ec *eventConsumer) handleKeyGenEvent(natMsg *nats.Msg) { baseCtx, baseCancel := context.WithTimeout(context.Background(), KeyGenTimeOut) defer baseCancel() + logger.Info("[KEY GEN] Key generation result") + raw := natMsg.Data var msg types.GenerateKeyMessage if err := json.Unmarshal(raw, &msg); err != nil { @@ -167,6 +169,9 @@ func (ec *eventConsumer) handleKeyGenEvent(natMsg *nats.Msg) { } walletID := msg.WalletID + + logger.Info("[KEY GEN] Key generation result", "walletID", walletID) + ecdsaSession, err := ec.node.CreateKeyGenSession(mpc.SessionTypeECDSA, walletID, ec.mpcThreshold, ec.genKeyResultQueue) if err != nil { ec.handleKeygenSessionError(walletID, err, "Failed to create ECDSA key generation session", natMsg) diff --git a/pkg/eventconsumer/keygen_consumer.go b/pkg/eventconsumer/keygen_consumer.go index 98d9f07..d46600c 100644 --- a/pkg/eventconsumer/keygen_consumer.go +++ b/pkg/eventconsumer/keygen_consumer.go @@ -90,7 +90,7 @@ func (sc *keygenConsumer) waitForAllPeersReadyToGenKey(ctx context.Context) erro } } -// Run subscribes to signing events and processes them until the context is canceled. +// Run subscribes to keygen events and processes them until the context is canceled. func (sc *keygenConsumer) Run(ctx context.Context) error { // Wait for sufficient peers before starting to consume messages if err := sc.waitForAllPeersReadyToGenKey(ctx); err != nil { @@ -110,7 +110,7 @@ func (sc *keygenConsumer) Run(ctx context.Context) error { return fmt.Errorf("failed to subscribe to keygen events: %w", err) } sc.jsSub = sub - logger.Info("SigningConsumer: Subscribed to keygen events") + logger.Info("KeygenConsumer: Subscribed to keygen events") // Block until context cancellation. <-ctx.Done() @@ -140,9 +140,11 @@ func (sc *keygenConsumer) handleKeygenEvent(msg jetstream.Msg) { return } - // Create a reply inbox to receive the signing event response. + // Create a reply inbox to receive the keygen event response. replyInbox := nats.NewInbox() + logger.Info("Newreplybox id", "topic", replyInbox) + // Use a synchronous subscription for the reply inbox. replySub, err := sc.natsConn.SubscribeSync(replyInbox) if err != nil { @@ -156,12 +158,12 @@ func (sc *keygenConsumer) handleKeygenEvent(msg jetstream.Msg) { } }() - // Publish the signing event with the reply inbox. + // Publish the keygen event with the reply inbox. headers := map[string]string{ "SessionID": uuid.New().String(), } if err := sc.pubsub.PublishWithReply(MPCGenerateEvent, replyInbox, msg.Data(), headers); err != nil { - logger.Error("KeygenConsumer: Failed to publish signing event with reply", err) + logger.Error("KeygenConsumer: Failed to publish keygen event with reply", err) _ = msg.Nak() return } diff --git a/pkg/messaging/nats_subscription.go b/pkg/messaging/nats_subscription.go new file mode 100644 index 0000000..1b83126 --- /dev/null +++ b/pkg/messaging/nats_subscription.go @@ -0,0 +1,34 @@ +package messaging + +import ( + "fmt" + + "github.com/nats-io/nats.go" +) + +type Subscription interface { + Unsubscribe() error +} + +// a subscription can be made by pubsub or dicrectmessaging +type natsSubscription struct { + subscription *nats.Subscription + topic string + pubSub *natsPubSub + direct *natsDirectMessaging +} + +func (ns *natsSubscription) Unsubscribe() error { + if ns.topic == "" { + return fmt.Errorf("cannot cleanup handlers: topic is empty") + } + + if ns.pubSub != nil { + ns.pubSub.cleanupHandlers(ns.topic) + } + + if ns.direct != nil { + ns.direct.cleanupHandlers(ns.topic) + } + return ns.subscription.Unsubscribe() +} diff --git a/pkg/messaging/point2point.go b/pkg/messaging/point2point.go index 225f090..f59bf9a 100644 --- a/pkg/messaging/point2point.go +++ b/pkg/messaging/point2point.go @@ -126,5 +126,11 @@ func (d *natsDirectMessaging) Listen(topic string, handler func(data []byte)) (S d.handlers[topic] = append(d.handlers[topic], handler) d.mu.Unlock() - return &natsSubscription{subscription: sub}, nil + return &natsSubscription{subscription: sub, topic: topic, direct: d}, nil +} + +func (d *natsDirectMessaging) cleanupHandlers(topic string) { + d.mu.Lock() + defer d.mu.Unlock() + delete(d.handlers, topic) } diff --git a/pkg/messaging/pubsub.go b/pkg/messaging/pubsub.go index 8e4fd0e..0d832eb 100644 --- a/pkg/messaging/pubsub.go +++ b/pkg/messaging/pubsub.go @@ -1,38 +1,49 @@ package messaging import ( + "sync" + "github.com/fystack/mpcium/pkg/logger" "github.com/nats-io/nats.go" ) -type Subscription interface { - Unsubscribe() error -} - type PubSub interface { Publish(topic string, message []byte) error PublishWithReply(topic, reply string, data []byte, headers map[string]string) error - Subscribe(topic string, handler func(msg *nats.Msg)) (Subscription, error) + Subscribe(topic string, handler func(*nats.Msg)) (Subscription, error) } type natsPubSub struct { natsConn *nats.Conn -} - -type natsSubscription struct { - subscription *nats.Subscription -} - -func (ns *natsSubscription) Unsubscribe() error { - return ns.subscription.Unsubscribe() + handlers map[string][]func(*nats.Msg) + mu sync.Mutex } func NewNATSPubSub(natsConn *nats.Conn) PubSub { - return &natsPubSub{natsConn} + return &natsPubSub{ + natsConn: natsConn, + handlers: make(map[string][]func(*nats.Msg)), + } } func (n *natsPubSub) Publish(topic string, message []byte) error { - logger.Debug("[NATS] Publishing message", "topic", topic) + logger.Info("[NATS] Publishing message", "topic", topic) + + // access local handlers for subscribed topics + n.mu.Lock() + defer n.mu.Unlock() + + handlers, ok := n.handlers[topic] + if ok && len(handlers) != 0 { + msgNats := &nats.Msg{ + Subject: topic, // Required: the topic to publish to + Data: message, // The []byte payload + } + for _, handler := range handlers { + handler(msgNats) + } + } + return n.natsConn.Publish(topic, message) } @@ -46,11 +57,23 @@ func (n *natsPubSub) PublishWithReply(topic, reply string, data []byte, headers for k, v := range headers { msg.Header.Set(k, v) } + + // access local handlers for subscribed topics + n.mu.Lock() + defer n.mu.Unlock() + + handlers, ok := n.handlers[topic] + if ok && len(handlers) != 0 { + for _, handler := range handlers { + handler(msg) + } + } + err := n.natsConn.PublishMsg(msg) return err } -func (n *natsPubSub) Subscribe(topic string, handler func(msg *nats.Msg)) (Subscription, error) { +func (n *natsPubSub) Subscribe(topic string, handler func(*nats.Msg)) (Subscription, error) { //Handle subscription: handle more fields in msg sub, err := n.natsConn.Subscribe(topic, func(msg *nats.Msg) { handler(msg) @@ -59,5 +82,15 @@ func (n *natsPubSub) Subscribe(topic string, handler func(msg *nats.Msg)) (Subsc return nil, err } - return &natsSubscription{subscription: sub}, nil + n.mu.Lock() + n.handlers[topic] = append(n.handlers[topic], handler) + n.mu.Unlock() + + return &natsSubscription{subscription: sub, topic: topic, pubSub: n}, nil +} + +func (n *natsPubSub) cleanupHandlers(topic string) { + n.mu.Lock() + defer n.mu.Unlock() + delete(n.handlers, topic) } diff --git a/pkg/mpc/key_exchange_session.go b/pkg/mpc/key_exchange_session.go index 2065f03..9aeaa2f 100644 --- a/pkg/mpc/key_exchange_session.go +++ b/pkg/mpc/key_exchange_session.go @@ -90,6 +90,7 @@ func (e *ecdhSession) ListenKeyExchange() error { } if ecdhMsg.From == e.nodeID { + logger.Info("To self message successfully received", "nodeID", e.nodeID) return } diff --git a/pkg/mpc/session.go b/pkg/mpc/session.go index b1a76b5..fd74c7b 100644 --- a/pkg/mpc/session.go +++ b/pkg/mpc/session.go @@ -228,7 +228,7 @@ func (s *session) receiveTssMessage(msg *types.TssMessage) { s.ErrCh <- errors.Wrap(err, "Broken TSS Share") return } - logger.Debug( + logger.Info( "Received message", "round", round.RoundMsg, @@ -285,17 +285,15 @@ func (s *session) subscribeFromPeersAsync(fromIDs []string) { } func (s *session) subscribeBroadcastAsync() { - go func() { - topic := s.topicComposer.ComposeBroadcastTopic() - sub, err := s.pubSub.Subscribe(topic, func(natMsg *nats.Msg) { - s.receiveBroadcastTssMessage(natMsg.Data) - }) - if err != nil { - s.ErrCh <- fmt.Errorf("Failed to subscribe to broadcast topic %s: %w", topic, err) - return - } - s.broadcastSub = sub - }() + topic := s.topicComposer.ComposeBroadcastTopic() + sub, err := s.pubSub.Subscribe(topic, func(natMsg *nats.Msg) { + go s.receiveBroadcastTssMessage(natMsg.Data) + }) + if err != nil { + s.ErrCh <- fmt.Errorf("Failed to subscribe to broadcast topic %s: %w", topic, err) + return + } + s.broadcastSub = sub } func (s *session) ListenToIncomingMessageAsync() {