diff --git a/hooks/session.go b/hooks/session.go index 94c9b32..9d0d1f0 100644 --- a/hooks/session.go +++ b/hooks/session.go @@ -22,6 +22,7 @@ func (h *SessionHook) Provides(b byte) bool { return bytes.Contains([]byte{ mqtt.OnConnectAuthenticate, mqtt.OnACLCheck, + mqtt.OnSubscribe, mqtt.OnDisconnect, }, []byte{b}) } @@ -53,10 +54,19 @@ func (h *SessionHook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) func (h *SessionHook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool { h.Log.Debug("ACLCheck", "client", cl.ID, "topic", topic, "write", write) - h.Store.Subscribe(cl.ID, topic) return true } +func (h *SessionHook) OnSubscribe(cl *mqtt.Client, pk packets.Packet) packets.Packet { + h.Log.Debug("Subscribe", "client", cl.ID, "topic", pk.TopicName) + topics := []string{} + for _, sub := range pk.Filters { + topics = append(topics, sub.Filter) + } + h.Store.Subscribe(cl.ID, topics) + return pk +} + func (h *SessionHook) OnDisconnect(cl *mqtt.Client, err error, expire bool) { h.Log.Debug("Disconnect", "client", cl.ID, "expire", expire) h.Store.Leave(cl.ID) diff --git a/lib/clientstore.go b/lib/clientstore.go index bc5d2c1..b2e5233 100644 --- a/lib/clientstore.go +++ b/lib/clientstore.go @@ -40,19 +40,31 @@ func (s *ClientStore) Leave(id string) { s.metrics.sessionGauge.Dec() } -func (s *ClientStore) Subscribe(id string, topic string) { +func (s *ClientStore) Subscribe(id string, topics []string) { s.mutex.Lock() defer s.mutex.Unlock() client, ok := s.clients[id] - if ok { - client.Subscribtions = append(client.Subscribtions, topic) - client.LastActivityAt = time.Now() + if !ok { + return + } + for _, topic := range topics { + found := false + for _, sub := range client.Subscribtions { + if sub == topic { + found = true + break + } + } + if !found { + client.Subscribtions = append(client.Subscribtions, topic) + } + labels := prometheus.Labels{"topic": topic} + s.metrics.subscribeCounter.With(labels).Inc() } - labels := prometheus.Labels{"topic": topic} - s.metrics.subscribeCounter.With(labels).Inc() + client.LastActivityAt = time.Now() } func (s *ClientStore) Publish(id string, topic string) { @@ -60,12 +72,14 @@ func (s *ClientStore) Publish(id string, topic string) { defer s.mutex.Unlock() client, ok := s.clients[id] - if ok { - value := client.Publications[topic] - client.Publications[topic] = value + 1 - client.LastActivityAt = time.Now() + if !ok { + return } + value := client.Publications[topic] + client.Publications[topic] = value + 1 + client.LastActivityAt = time.Now() + labels := prometheus.Labels{"topic": topic} s.metrics.publishCounter.With(labels).Inc() } diff --git a/test/main_test.go b/test/main_test.go index 5341792..7f1914b 100644 --- a/test/main_test.go +++ b/test/main_test.go @@ -70,6 +70,11 @@ func TestPublishIsForwardedToHTTP(t *testing.T) { t.Fatalf("client reports not connected") } + // Subscribe to a topic + if tok := client.Subscribe("topic/test", 0, nil); !tok.WaitTimeout(5*time.Second) || tok.Error() != nil { + t.Fatalf("subscribe failed: %v", tok.Error()) + } + // Publish and assert HTTP saw the payload. payload := []byte(`{"hello":"world"}`) if tok := client.Publish("devices/42/state", 0, false, payload); !tok.WaitTimeout(5*time.Second) || tok.Error() != nil { @@ -95,7 +100,7 @@ func TestPublishIsForwardedToHTTP(t *testing.T) { if err != nil { t.Fatalf("response read failed: %v", err) } - if !strings.Contains(string(content), "\"id\":\"it-test\",\"username\":\"testClient\",\"subscriptions\":[\"devices/42/state\"],\"publications\":{\"devices/42/state\":1}") { - t.Fatalf("unexpected content from the clients endpoint") + if !strings.Contains(string(content), "\"id\":\"it-test\",\"username\":\"testClient\",\"subscriptions\":[\"topic/test\"],\"publications\":{\"devices/42/state\":1}") { + t.Fatalf("unexpected content from the clients endpoint, got %s", content) } }