diff --git a/.gitignore b/.gitignore index 94cb000..d941e23 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ .env internal/db/.env tmp/ +.obsidian/ diff --git a/go.mod b/go.mod index 1f087af..914353a 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/swaggo/files v1.0.1 github.com/swaggo/gin-swagger v1.6.1 github.com/swaggo/swag v1.16.6 + golang.org/x/time v0.12.0 gorm.io/driver/mysql v1.6.0 gorm.io/driver/sqlite v1.6.0 gorm.io/gorm v1.30.1 @@ -28,7 +29,6 @@ require ( github.com/docker/docker v27.1.1+incompatible // indirect github.com/docker/go-connections v0.5.0 // indirect github.com/docker/go-units v0.5.0 // indirect - github.com/fatih/color v1.18.0 // indirect github.com/go-openapi/swag/conv v0.25.3 // indirect github.com/go-openapi/swag/jsonname v0.25.3 // indirect github.com/go-openapi/swag/jsonutils v0.25.3 // indirect @@ -40,7 +40,6 @@ require ( github.com/goccy/go-yaml v1.18.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect - github.com/mattn/go-colorable v0.1.14 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/sys/user v0.3.0 // indirect github.com/moby/term v0.5.0 // indirect @@ -50,7 +49,6 @@ require ( github.com/pkg/errors v0.9.1 // indirect github.com/quic-go/qpack v0.5.1 // indirect github.com/quic-go/quic-go v0.56.0 // indirect - github.com/rakyll/gotest v0.0.6 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect diff --git a/go.sum b/go.sum index 68cf021..9a691d7 100644 --- a/go.sum +++ b/go.sum @@ -35,9 +35,6 @@ github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= -github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU= -github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= -github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/gabriel-vasile/mimetype v1.4.11 h1:AQvxbp830wPhHTqc1u7nzoLT+ZFxGY7emj5DR5DYFik= github.com/gabriel-vasile/mimetype v1.4.11/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s= github.com/gin-contrib/gzip v0.0.6 h1:NjcunTcGAj5CO1gn4N8jHOSIeRFHIbn51z6K+xaN4d4= @@ -102,8 +99,6 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= -github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= -github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= @@ -118,11 +113,6 @@ github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= -github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= -github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= -github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= @@ -156,8 +146,6 @@ github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= github.com/quic-go/quic-go v0.56.0 h1:q/TW+OLismmXAehgFLczhCDTYB3bFmua4D9lsNBWxvY= github.com/quic-go/quic-go v0.56.0/go.mod h1:9gx5KsFQtw2oZ6GZTyh+7YEvOxWCL9WZAepnHxgAo6c= -github.com/rakyll/gotest v0.0.6 h1:hBTqkO3jiuwYW/M9gL4bu0oTYcm8J6knQAAPUsJsz1I= -github.com/rakyll/gotest v0.0.6/go.mod h1:SkoesdNCWmiD4R2dljIUcfSnNdVZ12y8qK4ojDkc2Sc= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= @@ -228,9 +216,7 @@ golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/internal/handler/user/user_handler.go b/internal/handler/user/user_handler.go index 25f0e2c..d29b522 100644 --- a/internal/handler/user/user_handler.go +++ b/internal/handler/user/user_handler.go @@ -33,6 +33,7 @@ var _ UserHandlerInterface = (*UserHandler)(nil) // @Success 201 {object} dto.CreateUserResponse // @Failure 400 {object} common.ErrorResponse // @Failure 409 {object} common.ErrorResponse "Email already exists" +// @Failure 429 {object} common.ErrorResponse "Rate limit exceeded" // @Router /auth/register [post] func (h *UserHandler) Register(c *gin.Context) { var req dto.CreateUserRequest @@ -64,6 +65,7 @@ func (h *UserHandler) Register(c *gin.Context) { // @Success 200 {object} dto.AuthResponse // @Failure 400 {object} common.ErrorResponse // @Failure 401 {object} common.ErrorResponse "Invalid credentials" +// @Failure 429 {object} common.ErrorResponse "Rate limit exceeded" // @Router /auth/login [post] func (h *UserHandler) Login(c *gin.Context) { var req dto.AuthRequest diff --git a/internal/middleware/ratelimiter/rate_limiter.go b/internal/middleware/ratelimiter/rate_limiter.go new file mode 100644 index 0000000..f2346e1 --- /dev/null +++ b/internal/middleware/ratelimiter/rate_limiter.go @@ -0,0 +1,94 @@ +package ratelimiter + +import ( + "net/http" + "sync" + "time" + + "taskflow/internal/common" + + "github.com/gin-gonic/gin" + "golang.org/x/time/rate" +) + +// IPRateLimiter holds rate limiters for each IP address +type IPRateLimiter struct { + ips map[string]*rate.Limiter + mu *sync.RWMutex + r rate.Limit + b int +} + +// NewIPRateLimiter creates a new IP-based rate limiter +// r: rate limit (requests per second) +// b: burst size (maximum requests allowed at once) +func NewIPRateLimiter(r rate.Limit, b int) *IPRateLimiter { + return &IPRateLimiter{ + ips: make(map[string]*rate.Limiter), + mu: &sync.RWMutex{}, + r: r, + b: b, + } +} + +// AddIP creates a new rate limiter for an IP address +func (i *IPRateLimiter) AddIP(ip string) *rate.Limiter { + i.mu.Lock() + defer i.mu.Unlock() + + limiter := rate.NewLimiter(i.r, i.b) + i.ips[ip] = limiter + + return limiter +} + +// GetLimiter returns the rate limiter for the provided IP address +func (i *IPRateLimiter) GetLimiter(ip string) *rate.Limiter { + i.mu.Lock() + limiter, exists := i.ips[ip] + + if !exists { + i.mu.Unlock() + return i.AddIP(ip) + } + + i.mu.Unlock() + return limiter +} + +// Middleware returns a Gin middleware function for rate limiting +func (i *IPRateLimiter) Middleware() gin.HandlerFunc { + return func(c *gin.Context) { + limiter := i.GetLimiter(c.ClientIP()) + + if !limiter.Allow() { + c.JSON(http.StatusTooManyRequests, common.ErrorResponse{ + Message: "rate limit exceeded, please try again later", + }) + c.Abort() + return + } + + c.Next() + } +} + +// CleanupOldEntries removes old entries from the rate limiter map +// This should be called periodically to prevent memory leaks +func (i *IPRateLimiter) CleanupOldEntries() { + i.mu.Lock() + defer i.mu.Unlock() + + // Clear all entries - they will be recreated on next request + i.ips = make(map[string]*rate.Limiter) +} + +// StartCleanupRoutine starts a goroutine that periodically cleans up old entries +func (i *IPRateLimiter) StartCleanupRoutine(interval time.Duration) { + ticker := time.NewTicker(interval) + go func() { + for range ticker.C { + i.CleanupOldEntries() + } + }() +} diff --git a/internal/middleware/ratelimiter/rate_limiter_test.go b/internal/middleware/ratelimiter/rate_limiter_test.go new file mode 100644 index 0000000..02aade7 --- /dev/null +++ b/internal/middleware/ratelimiter/rate_limiter_test.go @@ -0,0 +1,212 @@ +package ratelimiter + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "golang.org/x/time/rate" +) + +func TestNewIPRateLimiter(t *testing.T) { + limiter := NewIPRateLimiter(rate.Limit(5), 10) + + assert.NotNil(t, limiter) + assert.NotNil(t, limiter.ips) + assert.NotNil(t, limiter.mu) + assert.Equal(t, rate.Limit(5), limiter.r) + assert.Equal(t, 10, limiter.b) +} + +func TestAddIP(t *testing.T) { + limiter := NewIPRateLimiter(rate.Limit(5), 10) + ip := "192.168.1.1" + + rateLimiter := limiter.AddIP(ip) + + assert.NotNil(t, rateLimiter) + assert.Equal(t, 1, len(limiter.ips)) +} + +func TestGetLimiter(t *testing.T) { + limiter := NewIPRateLimiter(rate.Limit(5), 10) + ip := "192.168.1.1" + + // First call should create a new limiter + rateLimiter1 := limiter.GetLimiter(ip) + assert.NotNil(t, rateLimiter1) + assert.Equal(t, 1, len(limiter.ips)) + + // Second call should return the same limiter + rateLimiter2 := limiter.GetLimiter(ip) + assert.NotNil(t, rateLimiter2) + assert.Equal(t, 1, len(limiter.ips)) + assert.Equal(t, rateLimiter1, rateLimiter2) +} + +func TestMiddleware_AllowsRequests(t *testing.T) { + gin.SetMode(gin.TestMode) + + // Create a rate limiter with generous limits + limiter := NewIPRateLimiter(rate.Limit(100), 100) + + r := gin.New() + r.Use(limiter.Middleware()) + r.GET("/test", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"message": "success"}) + }) + + // Make a request + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + r.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + + var response map[string]string + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, "success", response["message"]) +} + +func TestMiddleware_RateLimitsRequests(t *testing.T) { + gin.SetMode(gin.TestMode) + + // Create a rate limiter with very strict limits (1 request per second, burst of 2) + limiter := NewIPRateLimiter(rate.Limit(1), 2) + + r := gin.New() + r.Use(limiter.Middleware()) + r.GET("/test", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"message": "success"}) + }) + + successCount := 0 + rateLimitedCount := 0 + + // Make 5 rapid requests + for i := 0; i < 5; i++ { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.1:1234" // Same IP for all requests + r.ServeHTTP(w, req) + + if w.Code == http.StatusOK { + successCount++ + } else if w.Code == http.StatusTooManyRequests { + rateLimitedCount++ + } + } + + // First 2 requests should succeed (burst capacity) + // Remaining 3 should be rate limited + assert.Equal(t, 2, successCount, "Expected 2 successful requests") + assert.Equal(t, 3, rateLimitedCount, "Expected 3 rate limited requests") +} + +func TestMiddleware_DifferentIPs(t *testing.T) { + gin.SetMode(gin.TestMode) + + // Create a rate limiter with strict limits + limiter := NewIPRateLimiter(rate.Limit(1), 2) + + r := gin.New() + r.Use(limiter.Middleware()) + r.GET("/test", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"message": "success"}) + }) + + // Test with different IPs + ips := []string{"192.168.1.1:1234", "192.168.1.2:1234", "192.168.1.3:1234"} + + for _, ip := range ips { + // Each IP should get 2 successful requests (burst) + for i := 0; i < 2; i++ { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.RemoteAddr = ip + r.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code, "Request should succeed for different IPs") + } + + // Third request from same IP should be rate limited + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.RemoteAddr = ip + r.ServeHTTP(w, req) + + assert.Equal(t, http.StatusTooManyRequests, w.Code, "Third request should be rate limited") + } +} + +func TestMiddleware_RateLimitResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + + // Create a rate limiter with very strict limits + limiter := NewIPRateLimiter(rate.Limit(1), 1) + + r := gin.New() + r.Use(limiter.Middleware()) + r.GET("/test", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"message": "success"}) + }) + + // First request should succeed + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.1:1234" + r.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + + // Second request should be rate limited + w = httptest.NewRecorder() + req, _ = http.NewRequest("GET", "/test", nil) + req.RemoteAddr = "192.168.1.1:1234" + r.ServeHTTP(w, req) + + assert.Equal(t, http.StatusTooManyRequests, w.Code) + + var response map[string]string + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, "rate limit exceeded, please try again later", response["error"]) +} + +func TestCleanupOldEntries(t *testing.T) { + limiter := NewIPRateLimiter(rate.Limit(5), 10) + + // Add some IPs + limiter.GetLimiter("192.168.1.1") + limiter.GetLimiter("192.168.1.2") + limiter.GetLimiter("192.168.1.3") + + assert.Equal(t, 3, len(limiter.ips)) + + // Clean up entries + limiter.CleanupOldEntries() + + assert.Equal(t, 0, len(limiter.ips)) +} + +func TestStartCleanupRoutine(t *testing.T) { + limiter := NewIPRateLimiter(rate.Limit(5), 10) + + // Add some IPs + limiter.GetLimiter("192.168.1.1") + limiter.GetLimiter("192.168.1.2") + assert.Equal(t, 2, len(limiter.ips)) + + // Start cleanup routine with very short interval + limiter.StartCleanupRoutine(100 * time.Millisecond) + + // Wait for cleanup to run + time.Sleep(200 * time.Millisecond) + + // IPs should be cleaned up + assert.Equal(t, 0, len(limiter.ips)) +} diff --git a/main.go b/main.go index b359643..850ccb3 100644 --- a/main.go +++ b/main.go @@ -2,12 +2,14 @@ package main import ( "log" + "time" "taskflow/internal/auth" "taskflow/internal/domain/task" "taskflow/internal/domain/user" task_handler "taskflow/internal/handler/task" user_handler "taskflow/internal/handler/user" + "taskflow/internal/middleware/ratelimiter" "taskflow/internal/repository/gorm/gorm_task" "taskflow/internal/repository/gorm/gorm_user" task_service "taskflow/internal/service/task" @@ -21,6 +23,7 @@ import ( _ "github.com/go-sql-driver/mysql" swaggerfiles "github.com/swaggo/files" ginSwagger "github.com/swaggo/gin-swagger" + "golang.org/x/time/rate" ) // @title TaskFlow API @@ -54,6 +57,12 @@ func main() { taskHandler := task_handler.NewTaskHandler(taskSvc, userAuth) userHandler := user_handler.NewUserHandler(userSvc, userAuth) + // Rate limiter setup for auth endpoints + // Allows 5 requests per second with a burst of 10 requests + authRateLimiter := ratelimiter.NewIPRateLimiter(rate.Limit(5), 10) + // Clean up old IP entries every hour to prevent memory leaks + authRateLimiter.StartCleanupRoutine(1 * time.Hour) + // Router setup r := gin.Default() docs.SwaggerInfo.BasePath = "/api" @@ -62,8 +71,13 @@ func main() { api := r.Group("/api") { - api.POST("/auth/register", userHandler.Register) - api.POST("/auth/login", userHandler.Login) + // Auth routes with rate limiting + authRoutes := api.Group("/auth") + authRoutes.Use(authRateLimiter.Middleware()) + { + authRoutes.POST("/register", userHandler.Register) + authRoutes.POST("/login", userHandler.Login) + } taskRoutes := api.Group("/tasks") taskRoutes.Use(userAuth.AuthMiddleware())