Skip to content

Commit

Permalink
refactor: simplify store implementation (resolve #5)
Browse files Browse the repository at this point in the history
fix: make store thread-safe (resolve #1)
  • Loading branch information
muety committed Jan 7, 2021
1 parent bc23f48 commit ac9137e
Show file tree
Hide file tree
Showing 13 changed files with 184 additions and 151 deletions.
26 changes: 16 additions & 10 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ import (
"fmt"
"github.com/muety/webhook2telegram/config"
"github.com/muety/webhook2telegram/model"
"github.com/muety/webhook2telegram/services"
"github.com/muety/webhook2telegram/store"
limiter "github.com/n1try/limiter/v3"
"github.com/n1try/limiter/v3"
memst "github.com/n1try/limiter/v3/drivers/store/memory"
uuid "github.com/satori/go.uuid"
"io/ioutil"
Expand All @@ -21,14 +22,20 @@ import (
)

var (
botStore store.Store
botConfig *config.BotConfig
client *http.Client
cmdRateLimiter *limiter.Limiter
userService *services.UserService
)

func init() {
// get config
botConfig = config.Get()
botStore = config.GetStore()

// init services
userService = services.NewUserService(botStore)

// init http client
client = &http.Client{Timeout: (config.PollTimeoutSec + 10) * time.Second}
Expand All @@ -46,11 +53,11 @@ func init() {

func GetUpdate() (*[]model.TelegramUpdate, error) {
offset := 0
if store.Get(config.KeyUpdateID) != nil {
offset = int(store.Get(config.KeyUpdateID).(float64)) + 1
if botStore.Get(config.KeyUpdateID) != nil {
offset = int(botStore.Get(config.KeyUpdateID).(float64)) + 1
}
apiUrl := botConfig.GetApiUrl() + string("/getUpdates?timeout="+strconv.Itoa(config.PollTimeoutSec)+"&offset="+strconv.Itoa(offset))
log.Println("Polling for updates.")
apiUrl := botConfig.GetApiUrl() + "/getUpdates?timeout=" + strconv.Itoa(config.PollTimeoutSec) + "&offset=" + strconv.Itoa(offset)
log.Println("polling for updates")
request, _ := http.NewRequest(http.MethodGet, apiUrl, nil)
request.Close = true

Expand All @@ -76,7 +83,7 @@ func GetUpdate() (*[]model.TelegramUpdate, error) {

if len(update.Result) > 0 {
var latestUpdateId interface{} = float64(update.Result[len(update.Result)-1].UpdateId)
store.Put(config.KeyUpdateID, latestUpdateId)
botStore.Put(config.KeyUpdateID, latestUpdateId)
}

return &update.Result, nil
Expand Down Expand Up @@ -169,11 +176,11 @@ func processUpdate(update model.TelegramUpdate) {
if strings.TrimSpace(update.Message.Text) == config.CmdStart {
// create new token
id := uuid.NewV4()
store.InvalidateToken(chatId)
store.Put(id.String(), model.StoreObject{User: update.Message.From, ChatId: chatId})
userService.InvalidateToken(chatId)
botStore.Put(id.String(), model.StoreObject{User: update.Message.From, ChatId: chatId})

text = fmt.Sprintf(config.MessageTokenResponse, id.String())
log.Printf("Sending new token %s to %d", id.String(), chatId)
log.Printf("sending new token %s to %d", id.String(), chatId)
} else if strings.TrimSpace(update.Message.Text) == config.CmdHelp {
// print help message
text = fmt.Sprintf(config.MessageHelpResponse, botConfig.Version)
Expand All @@ -193,7 +200,6 @@ func processUpdate(update model.TelegramUpdate) {

func checkBlacklist(senderId int) bool {
for _, id := range botConfig.Blacklist {
// TODO: refactor ids to be strings, not numbers!
if sid, err := strconv.Atoi(id); err == nil && sid == senderId {
return true
}
Expand Down
14 changes: 6 additions & 8 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,16 @@ import (
)

const (
BaseURL = "https://api.telegram.org/bot"
StoreFile = "store.gob"
PollTimeoutSec = 60
FlushTimeoutMin = 1
UserIdRegex = "(?m)^\\d+$"
BaseURL = "https://api.telegram.org/bot"
StoreFile = "store.gob"
PollTimeoutSec = 60
UserIdRegex = "(?m)^\\d+$"
)

const (
KeyUpdateID = "latestUpdateId"
KeyMessage = "message"
KeyParams = "message_params"
KeyMessages = "messages"
)

const (
Expand Down Expand Up @@ -126,12 +124,12 @@ func Get() *BotConfig {
flag.Parse()

if *tokenPtr == "" {
log.Fatalln("Token missing.")
log.Fatalln("token missing")
}

proxyUri, err := url.Parse(*proxyPtr)
if err != nil {
log.Println("Failed to parse proxy URI.")
log.Println("failed to parse proxy uri")
}

cfg = &BotConfig{
Expand Down
14 changes: 14 additions & 0 deletions config/store.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package config

import (
"github.com/muety/webhook2telegram/store"
)

var storeInstance store.Store

func GetStore() store.Store {
if storeInstance == nil {
storeInstance = store.NewGobStore(Get().GetStorePath())
}
return storeInstance
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ require (
github.com/leandro-lugaresi/hub v1.1.1
github.com/n1try/limiter/v3 v3.5.0
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect
github.com/orcaman/concurrent-map v0.0.0-20210106121528-16402b402231
github.com/prometheus/client_golang v1.8.0
github.com/satori/go.uuid v1.2.0
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,8 @@ github.com/openzipkin-contrib/zipkin-go-opentracing v0.4.5/go.mod h1:/wsWhb9smxS
github.com/openzipkin/zipkin-go v0.1.6/go.mod h1:QgAqvLzwWbR/WpD4A3cGpPtJrZXNIiJc5AZX7/PBEpw=
github.com/openzipkin/zipkin-go v0.2.1/go.mod h1:NaW6tEwdmWMaCDZzg8sh+IBNOxHMPnhQw8ySjnjRyN4=
github.com/openzipkin/zipkin-go v0.2.2/go.mod h1:NaW6tEwdmWMaCDZzg8sh+IBNOxHMPnhQw8ySjnjRyN4=
github.com/orcaman/concurrent-map v0.0.0-20210106121528-16402b402231 h1:fa50YL1pzKW+1SsBnJDOHppJN9stOEwS+CRWyUtyYGU=
github.com/orcaman/concurrent-map v0.0.0-20210106121528-16402b402231/go.mod h1:Lu3tH6HLW3feq74c2GC+jIMS/K2CFcDWnWD9XkenwhI=
github.com/pact-foundation/pact-go v1.0.4/go.mod h1:uExwJY4kCzNPcHRj+hCR/HBbOOIwwtUjcrb0b5/5kLM=
github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc=
github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k=
Expand Down
14 changes: 8 additions & 6 deletions handlers/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,19 @@ import (
"github.com/muety/webhook2telegram/config"
"github.com/muety/webhook2telegram/model"
"github.com/muety/webhook2telegram/resolvers"
"github.com/muety/webhook2telegram/store"
"github.com/muety/webhook2telegram/services"
"net/http"
)

type MessageHandler struct{}
type MessageHandler struct {
userService *services.UserService
}

func NewMessageHandler() *MessageHandler {
return &MessageHandler{}
func NewMessageHandler(userService *services.UserService) *MessageHandler {
return &MessageHandler{userService: userService}
}

func (h MessageHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (h *MessageHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var m *model.DefaultMessage
var p *model.MessageParams

Expand Down Expand Up @@ -50,7 +52,7 @@ func (h MessageHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}

recipientId := store.ResolveToken(token)
recipientId := h.userService.ResolveToken(token)

if len(recipientId) == 0 {
w.WriteHeader(http.StatusNotFound)
Expand Down
24 changes: 6 additions & 18 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/muety/webhook2telegram/inlets/alertmanager_webhook"
"github.com/muety/webhook2telegram/inlets/bitbucket_webhook"
"github.com/muety/webhook2telegram/inlets/webmentionio_webhook"
"github.com/muety/webhook2telegram/services"
"github.com/prometheus/client_golang/prometheus/promhttp"
"log"
"net"
Expand All @@ -21,27 +22,19 @@ import (
"github.com/muety/webhook2telegram/config"
"github.com/muety/webhook2telegram/inlets/default"
"github.com/muety/webhook2telegram/middleware"
"github.com/muety/webhook2telegram/store"
)

var (
botConfig *config.BotConfig
)
var botConfig *config.BotConfig

func init() {
botConfig = config.Get()
}

func flush() {
for {
time.Sleep(config.FlushTimeoutMin * time.Minute)
store.Flush(botConfig.GetStorePath())
}
}

func registerRoutes() {
indexHandler := handlers.NewIndexHandler()
messageHandler := handlers.NewMessageHandler()
messageHandler := handlers.NewMessageHandler(
services.NewUserService(config.GetStore()),
)
baseChain := alice.New(
middleware.WithEventLogging(),
middleware.WithMethodCheck(),
Expand Down Expand Up @@ -124,15 +117,10 @@ func listen() {

func exitGracefully() {
config.GetHub().Close()
store.Flush(botConfig.GetStorePath())
config.GetStore().Flush()
}

func main() {
store.Read(botConfig.GetStorePath())
store.Automigrate()

go flush()

registerRoutes()
connectApi()
listen()
Expand Down
32 changes: 32 additions & 0 deletions services/user.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package services

import (
"github.com/muety/webhook2telegram/model"
"github.com/muety/webhook2telegram/store"
"strconv"
)

type UserService struct {
store store.Store
}

func NewUserService(store store.Store) *UserService {
return &UserService{store: store}
}

func (s *UserService) InvalidateToken(userChatId int) {
for k, v := range s.store.GetItems() {
entry, ok := v.(model.StoreObject)
if ok && entry.ChatId == userChatId {
s.store.Delete(k)
}
}
}

func (s *UserService) ResolveToken(token string) string {
value := s.store.Get(token)
if value != nil {
return strconv.Itoa((value.(model.StoreObject)).ChatId)
}
return ""
}
92 changes: 92 additions & 0 deletions store/gob.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package store

import (
"encoding/gob"
"github.com/muety/webhook2telegram/model"
"github.com/orcaman/concurrent-map"
"log"
"os"
)

type GobStore struct {
data cmap.ConcurrentMap
filePath string
}

func NewGobStore(filePath string) *GobStore {
//gob.Register(model.StoreObject{})
//gob.Register(model.StoreMessageObject{})

// Backwards compatibility
gob.RegisterName("main.StoreObject", model.StoreObject{})
gob.RegisterName("main.StoreMessageObject", model.StoreMessageObject{})

store := &GobStore{
data: cmap.New(),
filePath: filePath,
}

if err := store.load(); err == nil {
log.Println("read existing gob store from file")
}

return store
}

func (s *GobStore) load() error {
file, err := os.Open(s.filePath)
defer file.Close()
if err != nil {
log.Printf("error: failed to read store from %s\n", s.filePath)
return nil
}

var rawData map[string]interface{}
if err := gob.NewDecoder(file).Decode(&rawData); err != nil {
log.Printf("error: failed to decode store data from %s (%v)\n", s.filePath, err)
return nil
}

s.data = cmap.New()
for k, v := range rawData {
s.data.Set(k, v)
}

return nil
}

func (s *GobStore) dump() error {
file, err := os.Create(s.filePath)
defer file.Close()
if err != nil {
log.Printf("error: failed to dump store to %s (%v)", s.filePath, err)
return err
}

return gob.NewEncoder(file).Encode(s.data.Items())
}

func (s *GobStore) Get(key string) interface{} {
if v, ok := s.data.Get(key); ok {
return v
}
return nil
}

func (s *GobStore) Put(key string, value interface{}) {
s.data.Set(key, value)
go s.dump()
}

func (s *GobStore) Delete(key string) {
s.data.Remove(key)
go s.dump()
}

func (s *GobStore) GetItems() map[string]interface{} {
return s.data.Items()
}

func (s *GobStore) Flush() error {
return s.dump()
}
Loading

0 comments on commit ac9137e

Please sign in to comment.