diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 51871e7..b9820ea 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -1,6 +1,7 @@ package auth import ( + "context" "net/http" "strconv" "strings" @@ -158,3 +159,7 @@ func (ua *UserAuth) OptionalAuthMiddleware() gin.HandlerFunc { c.Next() } } + +func SetUserID(ctx context.Context, userID int) context.Context { + return context.WithValue(ctx, "userID", userID) +} diff --git a/internal/domain/task/entity.go b/internal/domain/task/entity.go index d8f3004..ca18ed5 100644 --- a/internal/domain/task/entity.go +++ b/internal/domain/task/entity.go @@ -3,8 +3,9 @@ package task import "time" type Task struct { - ID int `json:"id" binding:"gte=1" example:"1" ` + ID int `json:"id" gorm:"primaryKey"` Task string `json:"task" binding:"required" example:"Buy milk" gorm:"not null"` Status string `json:"status" binding:"required" example:"pending" gorm:"not null"` + UserID int `json:"user_id" gorm:"not null;index"` CreatedAt time.Time `json:"created_at" example:"2025-08-27 10:35:16.263"` } diff --git a/internal/domain/user/entity.go b/internal/domain/user/entity.go index fc447e0..ba8fcc5 100644 --- a/internal/domain/user/entity.go +++ b/internal/domain/user/entity.go @@ -1,6 +1,7 @@ package user import ( + "taskflow/internal/domain/task" "time" "gorm.io/gorm" @@ -10,6 +11,7 @@ type User struct { ID int `gorm:"primaryKey" json:"id"` Email string `gorm:"uniqueIndex;size:255;not null" json:"email"` Password string `gorm:"size:255;not null" json:"password,omitempty"` + Tasks []task.Task `json:"tasks" gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at"` diff --git a/internal/handler/task/task_handler.go b/internal/handler/task/task_handler.go index e06f8ca..f07868e 100644 --- a/internal/handler/task/task_handler.go +++ b/internal/handler/task/task_handler.go @@ -41,7 +41,14 @@ func (h *TaskHandler) CreateTask(c *gin.Context) { c.JSON(http.StatusBadRequest, common.ErrorResponse{Message: err.Error()}) return } - if err := h.service.CreateTask(&req); err != nil { + + userID, exists := c.Get("userID") + if !exists { + c.JSON(http.StatusUnauthorized, common.ErrorResponse{Message: "unauthorized"}) + return + } + + if err := h.service.CreateTask(userID.(int), &req); err != nil { c.JSON(http.StatusBadRequest, common.ErrorResponse{Message: err.Error()}) return } @@ -65,7 +72,13 @@ func (h *TaskHandler) GetTask(c *gin.Context) { return } - resp, err := h.service.GetTask(id) + userID, exists := c.Get("userID") + if !exists { + c.JSON(http.StatusUnauthorized, common.ErrorResponse{Message: "unauthorized"}) + return + } + + resp, err := h.service.GetTask(userID.(int), id) if err != nil { c.JSON(http.StatusNotFound, common.ErrorResponse{Message: "Task not found"}) return @@ -85,7 +98,13 @@ func (h *TaskHandler) GetTask(c *gin.Context) { // @Failure 500 {object} common.ErrorResponse "Internal server error" // @Router /tasks [get] func (h *TaskHandler) ListTasks(c *gin.Context) { - res, err := h.service.ListTasks() + userID, exists := c.Get("userID") + if !exists { + c.JSON(http.StatusUnauthorized, common.ErrorResponse{Message: "unauthorized"}) + return + } + + res, err := h.service.ListTasks(userID.(int)) if err != nil { c.JSON(http.StatusInternalServerError, common.ErrorResponse{Message: err.Error()}) return @@ -106,6 +125,12 @@ func (h *TaskHandler) ListTasks(c *gin.Context) { // @Failure 404 {object} common.ErrorResponse "Task not found" // @Router /tasks/{id}/status [patch] func (h *TaskHandler) UpdateStatus(c *gin.Context) { + userID, exists := c.Get("userID") + if !exists { + c.JSON(http.StatusUnauthorized, common.ErrorResponse{Message: "unauthorized"}) + return + } + id, err := strconv.Atoi(c.Param("id")) if err != nil || id < 1 { c.JSON(http.StatusBadRequest, common.ErrorResponse{Message: "invalid task ID"}) @@ -118,7 +143,7 @@ func (h *TaskHandler) UpdateStatus(c *gin.Context) { return } - if err := h.service.UpdateStatus(id, req.Status); err != nil { + if err := h.service.UpdateStatus(userID.(int), id, req.Status); err != nil { if err == gorm.ErrRecordNotFound { c.JSON(http.StatusNotFound, common.ErrorResponse{Message: "Task not found"}) return @@ -143,6 +168,12 @@ func (h *TaskHandler) UpdateStatus(c *gin.Context) { // @Failure 500 {object} common.ErrorResponse "Internal server error" // @Router /tasks/{id} [delete] func (h *TaskHandler) Delete(c *gin.Context) { + userID, exists := c.Get("userID") + if !exists { + c.JSON(http.StatusUnauthorized, common.ErrorResponse{Message: "unauthorized"}) + return + } + // Parse ID from path id, err := strconv.Atoi(c.Param("id")) if err != nil || id < 1 { @@ -150,7 +181,7 @@ func (h *TaskHandler) Delete(c *gin.Context) { return } - if err := h.service.Delete(id); err != nil { + if err := h.service.Delete(userID.(int), id); err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { c.JSON(http.StatusNotFound, common.ErrorResponse{Message: "Task not found"}) } else { diff --git a/internal/handler/task/task_handler_mock.go b/internal/handler/task/task_handler_mock.go deleted file mode 100644 index 805a361..0000000 --- a/internal/handler/task/task_handler_mock.go +++ /dev/null @@ -1 +0,0 @@ -package task_handler diff --git a/internal/handler/task/task_handler_test.go b/internal/handler/task/task_handler_test.go index 9508c6e..0dbdc1d 100644 --- a/internal/handler/task/task_handler_test.go +++ b/internal/handler/task/task_handler_test.go @@ -22,22 +22,25 @@ func setupGin() *gin.Engine { gin.SetMode(gin.TestMode) return gin.New() } + func TestTaskHandler_CreateTask(t *testing.T) { tests := []struct { name string + userID int requestBody any setupMock func() *task_service.TaskServiceMock expectedStatus int expectedBody any }{ { - name: "success case", + name: "success case", + userID: 1, requestBody: dto.CreateTaskRequest{ Task: "Buy Milk", }, setupMock: func() *task_service.TaskServiceMock { mockService := new(task_service.TaskServiceMock) - mockService.On("CreateTask", mock.MatchedBy(func(req *dto.CreateTaskRequest) bool { + mockService.On("CreateTask", 1, mock.MatchedBy(func(req *dto.CreateTaskRequest) bool { return req.Task == "Buy Milk" })).Return(nil) return mockService @@ -48,37 +51,40 @@ func TestTaskHandler_CreateTask(t *testing.T) { }, }, { - name: "failure case - invalid JSON", - requestBody: `{"task": }`, // malformed JSON + name: "failure case - no userID", + userID: 0, + requestBody: dto.CreateTaskRequest{ + Task: "Buy Milk", + }, setupMock: func() *task_service.TaskServiceMock { return new(task_service.TaskServiceMock) }, - expectedStatus: http.StatusBadRequest, + expectedStatus: http.StatusUnauthorized, expectedBody: common.ErrorResponse{ - Message: "invalid character '}' looking for beginning of value", + Message: "unauthorized", }, }, { - name: "failure case - validation error", - requestBody: dto.CreateTaskRequest{ - Task: "", // empty task - }, + name: "failure case - invalid JSON", + userID: 1, + requestBody: `{"task": }`, // malformed JSON setupMock: func() *task_service.TaskServiceMock { return new(task_service.TaskServiceMock) }, expectedStatus: http.StatusBadRequest, expectedBody: common.ErrorResponse{ - Message: "Key: 'CreateTaskRequest.Task' Error:Tag: 'required' ActualTag: 'required' Namespace: 'CreateTaskRequest.Task' StructNamespace: 'CreateTaskRequest.Task' StructField: 'Task' ActualField: 'Task' Value: '' Param: ''", + Message: "invalid character '}' looking for beginning of value", }, }, { - name: "failure case - service error", + name: "failure case - service error", + userID: 1, requestBody: dto.CreateTaskRequest{ Task: "Buy Milk", }, setupMock: func() *task_service.TaskServiceMock { mockService := new(task_service.TaskServiceMock) - mockService.On("CreateTask", mock.Anything).Return(errors.New("service error")) + mockService.On("CreateTask", 1, mock.Anything).Return(errors.New("service error")) return mockService }, expectedStatus: http.StatusBadRequest, @@ -95,7 +101,12 @@ func TestTaskHandler_CreateTask(t *testing.T) { handler := NewTaskHandler(mockService, mockAuth) router := setupGin() - router.POST("/tasks", handler.CreateTask) + router.POST("/tasks", func(c *gin.Context) { + if tt.userID != 0 { + c.Set("userID", tt.userID) // inject userID to simulate authentication + } + handler.CreateTask(c) + }) var body []byte var err error @@ -139,9 +150,11 @@ func TestTaskHandler_CreateTask(t *testing.T) { }) } } + func TestTaskHandler_GetTask(t *testing.T) { tests := []struct { name string + userID int taskID string setupMock func() *task_service.TaskServiceMock expectedStatus int @@ -149,10 +162,11 @@ func TestTaskHandler_GetTask(t *testing.T) { }{ { name: "success case", + userID: 1, taskID: "1", setupMock: func() *task_service.TaskServiceMock { mockService := new(task_service.TaskServiceMock) - mockService.On("GetTask", 1).Return(dto.GetTaskResponse{ + mockService.On("GetTask", 1, 1).Return(dto.GetTaskResponse{ ID: 1, Task: "Buy Milk", Status: "pending", @@ -168,6 +182,7 @@ func TestTaskHandler_GetTask(t *testing.T) { }, { name: "failure case - invalid ID", + userID: 1, taskID: "invalid", setupMock: func() *task_service.TaskServiceMock { return new(task_service.TaskServiceMock) @@ -179,6 +194,7 @@ func TestTaskHandler_GetTask(t *testing.T) { }, { name: "failure case - ID less than 1", + userID: 1, taskID: "0", setupMock: func() *task_service.TaskServiceMock { return new(task_service.TaskServiceMock) @@ -190,10 +206,11 @@ func TestTaskHandler_GetTask(t *testing.T) { }, { name: "failure case - task not found", + userID: 1, taskID: "999", setupMock: func() *task_service.TaskServiceMock { mockService := new(task_service.TaskServiceMock) - mockService.On("GetTask", 999).Return(dto.GetTaskResponse{}, errors.New("not found")) + mockService.On("GetTask", 1, 999).Return(dto.GetTaskResponse{}, errors.New("not found")) return mockService }, expectedStatus: http.StatusNotFound, @@ -210,7 +227,11 @@ func TestTaskHandler_GetTask(t *testing.T) { handler := NewTaskHandler(mockService, mockAuth) router := setupGin() - router.GET("/tasks/:id", handler.GetTask) + // Wrap handler to inject userID simulating authenticated request + router.GET("/tasks/:id", func(c *gin.Context) { + c.Set("userID", tt.userID) + handler.GetTask(c) + }) req := httptest.NewRequest(http.MethodGet, "/tasks/"+tt.taskID, nil) w := httptest.NewRecorder() @@ -247,7 +268,7 @@ func TestTaskHandler_ListTasks(t *testing.T) { name: "success case - with tasks", setupMock: func() *task_service.TaskServiceMock { mockService := new(task_service.TaskServiceMock) - mockService.On("ListTasks").Return(dto.ListTasksResponse{ + mockService.On("ListTasks", 1).Return(dto.ListTasksResponse{ Tasks: []dto.GetTaskResponse{ {ID: 1, Task: "Buy Milk", Status: "pending"}, {ID: 2, Task: "Buy Eggs", Status: "completed"}, @@ -267,7 +288,7 @@ func TestTaskHandler_ListTasks(t *testing.T) { name: "success case - empty list", setupMock: func() *task_service.TaskServiceMock { mockService := new(task_service.TaskServiceMock) - mockService.On("ListTasks").Return(dto.ListTasksResponse{ + mockService.On("ListTasks", 1).Return(dto.ListTasksResponse{ Tasks: []dto.GetTaskResponse{}, }, nil) return mockService @@ -281,7 +302,7 @@ func TestTaskHandler_ListTasks(t *testing.T) { name: "failure case - service error", setupMock: func() *task_service.TaskServiceMock { mockService := new(task_service.TaskServiceMock) - mockService.On("ListTasks").Return(dto.ListTasksResponse{}, errors.New("database error")) + mockService.On("ListTasks", 1).Return(dto.ListTasksResponse{}, errors.New("database error")) return mockService }, expectedStatus: http.StatusInternalServerError, @@ -289,6 +310,16 @@ func TestTaskHandler_ListTasks(t *testing.T) { Message: "database error", }, }, + { + name: "failure case - no userID", + setupMock: func() *task_service.TaskServiceMock { + return new(task_service.TaskServiceMock) + }, + expectedStatus: http.StatusUnauthorized, + expectedBody: common.ErrorResponse{ + Message: "unauthorized", + }, + }, } for _, tt := range tests { @@ -298,7 +329,13 @@ func TestTaskHandler_ListTasks(t *testing.T) { handler := NewTaskHandler(mockService, mockAuth) router := setupGin() - router.GET("/tasks", handler.ListTasks) + router.GET("/tasks", func(c *gin.Context) { + // Only set userID for cases other than "no userID" + if tt.expectedStatus != http.StatusUnauthorized { + c.Set("userID", 1) + } + handler.ListTasks(c) + }) req := httptest.NewRequest(http.MethodGet, "/tasks", nil) w := httptest.NewRecorder() @@ -327,6 +364,7 @@ func TestTaskHandler_ListTasks(t *testing.T) { func TestTaskHandler_UpdateStatus(t *testing.T) { tests := []struct { name string + userID *int // nil means unauthorized taskID string requestBody any setupMock func() *task_service.TaskServiceMock @@ -335,13 +373,14 @@ func TestTaskHandler_UpdateStatus(t *testing.T) { }{ { name: "success case", + userID: intPtr(123), taskID: "1", requestBody: dto.UpdateStatusRequest{ Status: "completed", }, setupMock: func() *task_service.TaskServiceMock { mockService := new(task_service.TaskServiceMock) - mockService.On("UpdateStatus", 1, "completed").Return(nil) + mockService.On("UpdateStatus", 123, 1, "completed").Return(nil) return mockService }, expectedStatus: http.StatusOK, @@ -349,8 +388,24 @@ func TestTaskHandler_UpdateStatus(t *testing.T) { Message: "status updated", }, }, + { + name: "failure case - unauthorized", + userID: nil, + taskID: "1", + requestBody: dto.UpdateStatusRequest{ + Status: "completed", + }, + setupMock: func() *task_service.TaskServiceMock { + return new(task_service.TaskServiceMock) + }, + expectedStatus: http.StatusUnauthorized, + expectedBody: common.ErrorResponse{ + Message: "unauthorized", + }, + }, { name: "failure case - invalid ID", + userID: intPtr(123), taskID: "invalid", requestBody: dto.UpdateStatusRequest{ Status: "completed", @@ -365,6 +420,7 @@ func TestTaskHandler_UpdateStatus(t *testing.T) { }, { name: "failure case - ID less than 1", + userID: intPtr(123), taskID: "0", requestBody: dto.UpdateStatusRequest{ Status: "completed", @@ -379,6 +435,7 @@ func TestTaskHandler_UpdateStatus(t *testing.T) { }, { name: "failure case - invalid JSON", + userID: intPtr(123), taskID: "1", requestBody: `{"status": }`, // malformed JSON setupMock: func() *task_service.TaskServiceMock { @@ -391,6 +448,7 @@ func TestTaskHandler_UpdateStatus(t *testing.T) { }, { name: "failure case - invalid status", + userID: intPtr(123), taskID: "1", requestBody: dto.UpdateStatusRequest{ Status: "invalid-status", @@ -402,13 +460,14 @@ func TestTaskHandler_UpdateStatus(t *testing.T) { }, { name: "failure case - task not found", + userID: intPtr(123), taskID: "999", requestBody: dto.UpdateStatusRequest{ Status: "completed", }, setupMock: func() *task_service.TaskServiceMock { mockService := new(task_service.TaskServiceMock) - mockService.On("UpdateStatus", 999, "completed").Return(gorm.ErrRecordNotFound) + mockService.On("UpdateStatus", 123, 999, "completed").Return(gorm.ErrRecordNotFound) return mockService }, expectedStatus: http.StatusNotFound, @@ -418,13 +477,14 @@ func TestTaskHandler_UpdateStatus(t *testing.T) { }, { name: "failure case - service error", + userID: intPtr(123), taskID: "1", requestBody: dto.UpdateStatusRequest{ Status: "completed", }, setupMock: func() *task_service.TaskServiceMock { mockService := new(task_service.TaskServiceMock) - mockService.On("UpdateStatus", 1, "completed").Return(errors.New("service error")) + mockService.On("UpdateStatus", 123, 1, "completed").Return(errors.New("service error")) return mockService }, expectedStatus: http.StatusBadRequest, @@ -441,6 +501,15 @@ func TestTaskHandler_UpdateStatus(t *testing.T) { handler := NewTaskHandler(mockService, mockAuth) router := setupGin() + + // Middleware to inject userID if set + router.Use(func(c *gin.Context) { + if tt.userID != nil { + c.Set("userID", *tt.userID) + } + c.Next() + }) + router.PATCH("/tasks/:id/status", handler.UpdateStatus) var body []byte @@ -466,7 +535,6 @@ func TestTaskHandler_UpdateStatus(t *testing.T) { err = json.Unmarshal(w.Body.Bytes(), &responseBody) assert.NoError(t, err) - // Special handling for validation errors if tt.name == "failure case - invalid status" { errorResp := make(map[string]any) json.Unmarshal(w.Body.Bytes(), &errorResp) @@ -488,6 +556,10 @@ func TestTaskHandler_UpdateStatus(t *testing.T) { } } +func intPtr(i int) *int { + return &i +} + func TestTaskHandler_Delete(t *testing.T) { tests := []struct { name string @@ -501,7 +573,8 @@ func TestTaskHandler_Delete(t *testing.T) { taskID: "1", setupMock: func() *task_service.TaskServiceMock { mockService := new(task_service.TaskServiceMock) - mockService.On("Delete", 1).Return(nil) + // Expect both userID and taskID + mockService.On("Delete", 1, 1).Return(nil) return mockService }, expectedStatus: http.StatusOK, @@ -536,7 +609,7 @@ func TestTaskHandler_Delete(t *testing.T) { taskID: "999", setupMock: func() *task_service.TaskServiceMock { mockService := new(task_service.TaskServiceMock) - mockService.On("Delete", 999).Return(gorm.ErrRecordNotFound) + mockService.On("Delete", 1, 999).Return(gorm.ErrRecordNotFound) return mockService }, expectedStatus: http.StatusNotFound, @@ -549,7 +622,7 @@ func TestTaskHandler_Delete(t *testing.T) { taskID: "1", setupMock: func() *task_service.TaskServiceMock { mockService := new(task_service.TaskServiceMock) - mockService.On("Delete", 1).Return(errors.New("database error")) + mockService.On("Delete", 1, 1).Return(errors.New("database error")) return mockService }, expectedStatus: http.StatusInternalServerError, @@ -566,14 +639,21 @@ func TestTaskHandler_Delete(t *testing.T) { handler := NewTaskHandler(mockService, mockAuth) router := setupGin() + // Middleware to inject fake userID + router.Use(func(c *gin.Context) { + c.Set("userID", 1) + c.Next() + }) router.DELETE("/tasks/:id", handler.Delete) req := httptest.NewRequest(http.MethodDelete, "/tasks/"+tt.taskID, nil) w := httptest.NewRecorder() router.ServeHTTP(w, req) + // Assert status assert.Equal(t, tt.expectedStatus, w.Code) + // Assert response body var responseBody any err := json.Unmarshal(w.Body.Bytes(), &responseBody) assert.NoError(t, err) diff --git a/internal/handler/user/user_handler_test.go b/internal/handler/user/user_handler_test.go new file mode 100644 index 0000000..501b777 --- /dev/null +++ b/internal/handler/user/user_handler_test.go @@ -0,0 +1,116 @@ +package user_handler + +import ( + "bytes" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "taskflow/internal/common" + "taskflow/internal/dto" + user_service "taskflow/internal/service/user" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +// Setup Gin for tests +func setupGin() *gin.Engine { + gin.SetMode(gin.TestMode) + return gin.New() +} + +func TestUserHandler_Register(t *testing.T) { + tests := []struct { + name string + requestBody any + setupMock func() *user_service.UserServiceMock + expectedStatus int + expectedBody any + }{ + { + name: "success case", + requestBody: dto.CreateUserRequest{ + Email: "test@example.com", + Password: "password", + }, + setupMock: func() *user_service.UserServiceMock { + mockSvc := new(user_service.UserServiceMock) + mockSvc.On("CreateUser", mock.Anything).Return(&dto.CreateUserResponse{ + ID: 1, + Email: "test@example.com", + }, nil) + return mockSvc + }, + expectedStatus: http.StatusCreated, + expectedBody: dto.CreateUserResponse{ + ID: 1, + Email: "test@example.com", + }, + }, + { + name: "failure case - email exists", + requestBody: dto.CreateUserRequest{ + Email: "exist@example.com", + Password: "password", + }, + setupMock: func() *user_service.UserServiceMock { + mockSvc := new(user_service.UserServiceMock) + mockSvc.On("CreateUser", mock.Anything).Return(nil, errors.New("email already exists")) + return mockSvc + }, + expectedStatus: http.StatusConflict, + expectedBody: common.ErrorResponse{ + Message: "email already exists", + }, + }, + { + name: "failure case - invalid JSON", + requestBody: `{"email":}`, // malformed + setupMock: func() *user_service.UserServiceMock { return new(user_service.UserServiceMock) }, + expectedStatus: http.StatusBadRequest, + expectedBody: common.ErrorResponse{ + Message: "invalid character '}' looking for beginning of value", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockSvc := tt.setupMock() + handler := NewUserHandler(mockSvc, nil) + + router := setupGin() + router.POST("/auth/register", handler.Register) + + var body []byte + var err error + if str, ok := tt.requestBody.(string); ok { + body = []byte(str) + } else { + body, err = json.Marshal(tt.requestBody) + assert.NoError(t, err) + } + + req := httptest.NewRequest(http.MethodPost, "/auth/register", bytes.NewBuffer(body)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + assert.Equal(t, tt.expectedStatus, w.Code) + + var responseBody any + err = json.Unmarshal(w.Body.Bytes(), &responseBody) + assert.NoError(t, err) + + expectedBytes, _ := json.Marshal(tt.expectedBody) + var expectedResponse any + _ = json.Unmarshal(expectedBytes, &expectedResponse) + + assert.Equal(t, expectedResponse, responseBody) + mockSvc.AssertExpectations(t) + }) + } +} diff --git a/internal/repository/gorm/gorm_task/task_repository.go b/internal/repository/gorm/gorm_task/task_repository.go index 9b78738..74bd777 100644 --- a/internal/repository/gorm/gorm_task/task_repository.go +++ b/internal/repository/gorm/gorm_task/task_repository.go @@ -21,17 +21,18 @@ func (r *TaskRepository) Create(t *task.Task) error { return r.db.Create(t).Error } -func (r *TaskRepository) GetByID(id int) (*task.Task, error) { +func (r *TaskRepository) GetByID(userID int, id int) (*task.Task, error) { var t task.Task - if err := r.db.First(&t, id).Error; err != nil { + err := r.db.Where("id = ? AND user_id = ?", id, userID).First(&t).Error + if err != nil { return nil, err } return &t, nil } -func (r *TaskRepository) List() ([]task.Task, error) { +func (r *TaskRepository) List(userID int) ([]task.Task, error) { var tasks []task.Task - if err := r.db.Find(&tasks).Error; err != nil { + if err := r.db.Where("user_id = ?", userID).Find(&tasks).Error; err != nil { return nil, err } return tasks, nil @@ -41,10 +42,12 @@ func (r *TaskRepository) Update(t *task.Task) error { return r.db.Save(t).Error } -func (r *TaskRepository) Delete(id int) error { - return r.db.Delete(&task.Task{}, id).Error +func (r *TaskRepository) Delete(userID int, id int) error { + return r.db. + Where("user_id = ?", userID). + Delete(&task.Task{}, id).Error } -func (r *TaskRepository) UpdateStatus(id int, status string) error { - return r.db.Model(&task.Task{}).Where("id = ?", id).Update("status", status).Error +func (r *TaskRepository) UpdateStatus(userID int, id int, status string) error { + return r.db.Model(&task.Task{}).Where("user_id = ? AND id = ?", userID, id).Update("status", status).Error } diff --git a/internal/repository/gorm/gorm_task/task_repository_interface.go b/internal/repository/gorm/gorm_task/task_repository_interface.go index 742c5fb..76506ee 100644 --- a/internal/repository/gorm/gorm_task/task_repository_interface.go +++ b/internal/repository/gorm/gorm_task/task_repository_interface.go @@ -1,12 +1,14 @@ package gorm_task -import "taskflow/internal/domain/task" +import ( + "taskflow/internal/domain/task" +) type TaskRepositoryInterface interface { Create(task *task.Task) error - GetByID(id int) (*task.Task, error) - List() ([]task.Task, error) + GetByID(userID int, id int) (*task.Task, error) + List(userID int) ([]task.Task, error) Update(task *task.Task) error - Delete(id int) error - UpdateStatus(id int, status string) error + Delete(userID int, id int) error + UpdateStatus(userID int, id int, status string) error } diff --git a/internal/repository/gorm/gorm_task/task_repository_mock.go b/internal/repository/gorm/gorm_task/task_repository_mock.go index a63edc2..ceba1ad 100644 --- a/internal/repository/gorm/gorm_task/task_repository_mock.go +++ b/internal/repository/gorm/gorm_task/task_repository_mock.go @@ -22,13 +22,13 @@ func (m *TaskRepoMock) Create(task *task.Task) error { // if args.Get(0) == nil { // return nil, args.Error(1) // } -func (m *TaskRepoMock) GetByID(id int) (*task.Task, error) { - args := m.Called(id) +func (m *TaskRepoMock) GetByID(userID int, id int) (*task.Task, error) { + args := m.Called(userID, id) return args.Get(0).(*task.Task), args.Error(1) } -func (m *TaskRepoMock) List() ([]task.Task, error) { - args := m.Called() +func (m *TaskRepoMock) List(userID int) ([]task.Task, error) { + args := m.Called(userID) return args.Get(0).([]task.Task), args.Error(1) } @@ -37,14 +37,14 @@ func (m *TaskRepoMock) Update(task *task.Task) error { return args.Error(0) } -func (m *TaskRepoMock) Delete(id int) error { +func (m *TaskRepoMock) Delete(userID int, id int) error { - args := m.Called(id) + args := m.Called(userID, id) return args.Error(0) } -func (m *TaskRepoMock) UpdateStatus(id int, status string) error { +func (m *TaskRepoMock) UpdateStatus(userID int, id int, status string) error { - args := m.Called(id, status) + args := m.Called(userID, id, status) return args.Error(0) } diff --git a/internal/repository/gorm/gorm_task/task_repository_test.go b/internal/repository/gorm/gorm_task/task_repository_test.go index dea5a7f..62a0129 100644 --- a/internal/repository/gorm/gorm_task/task_repository_test.go +++ b/internal/repository/gorm/gorm_task/task_repository_test.go @@ -71,65 +71,76 @@ func TestTaskRepository_GetByID(t *testing.T) { taskToCreate := task.Task{ Task: "Buy Milk", Status: "pending", + UserID: 1, } err := db.Create(&taskToCreate).Error - require.NoError(t, err) // now ID is set + require.NoError(t, err) r := NewTaskRepository(db) - got, gotErr := r.GetByID(taskToCreate.ID) + got, gotErr := r.GetByID(1, taskToCreate.ID) assert.NoError(t, gotErr) assert.NotNil(t, got) assert.Equal(t, taskToCreate.ID, got.ID) assert.Equal(t, taskToCreate.Task, got.Task) assert.Equal(t, taskToCreate.Status, got.Status) + assert.Equal(t, taskToCreate.UserID, got.UserID) }) t.Run("non-existing id", func(t *testing.T) { db := setupTestDB(t) r := NewTaskRepository(db) - got, gotErr := r.GetByID(9999) + got, gotErr := r.GetByID(1, 9999) assert.Error(t, gotErr) assert.Nil(t, got) }) } -func TestTaskRepository_List(t *testing.T) { +func TestTaskRepository_List(t *testing.T) { tasks := []task.Task{ - {Task: "Buy Milk", Status: "pending"}, - {Task: "Buy Milk 2", Status: "pending"}, - {Task: "Buy Milk 3", Status: "pending"}, + {Task: "Buy Milk", Status: "pending", UserID: 1}, + {Task: "Buy Milk 2", Status: "pending", UserID: 1}, + {Task: "Buy Milk 3", Status: "pending", UserID: 2}, // different user } - t.Run("successful response", func(t *testing.T) { + t.Run("successful response with userID filter", func(t *testing.T) { db := setupTestDB(t) r := NewTaskRepository(db) for i := range tasks { - err := db.Create(&tasks[i]).Error - require.NoError(t, err) + require.NoError(t, db.Create(&tasks[i]).Error) } - got, gotErr := r.List() - assert.NoError(t, gotErr) - assert.Len(t, got, len(tasks)) + got, err := r.List(1) + assert.NoError(t, err) + assert.Len(t, got, 2) // only UserID 1 tasks - for i, taskItem := range tasks { - assert.Equal(t, taskItem.Task, got[i].Task) - assert.Equal(t, taskItem.Status, got[i].Status) - assert.NotZero(t, got[i].ID) - assert.NotZero(t, got[i].CreatedAt) + for _, tsk := range got { + assert.Equal(t, 1, tsk.UserID) + assert.NotZero(t, tsk.ID) + assert.NotZero(t, tsk.CreatedAt) } }) - t.Run("Empyt response", func(t *testing.T) { + + t.Run("empty response", func(t *testing.T) { db := setupTestDB(t) r := NewTaskRepository(db) - got, gotErr := r.List() - assert.NoError(t, gotErr) + got, err := r.List(99) // userID with no tasks + assert.NoError(t, err) assert.Len(t, got, 0) + }) + t.Run("database error", func(t *testing.T) { + db := setupTestDB(t) + // drop table to simulate failure + db.Migrator().DropTable(&task.Task{}) + r := NewTaskRepository(db) + + got, err := r.List(1) + assert.Error(t, err) + assert.Nil(t, got) }) } @@ -172,35 +183,36 @@ func TestTaskRepository_Update(t *testing.T) { func TestTaskRepository_Delete(t *testing.T) { t.Run("successful delete", func(t *testing.T) { db := setupTestDB(t) - taskToCreate := task.Task{Task: "Buy Milk", Status: "pending"} + taskToCreate := task.Task{UserID: 1, Task: "Buy Milk", Status: "pending"} require.NoError(t, db.Create(&taskToCreate).Error) r := NewTaskRepository(db) - err := r.Delete(taskToCreate.ID) + err := r.Delete(1, taskToCreate.ID) assert.NoError(t, err) var fetched task.Task err = db.First(&fetched, taskToCreate.ID).Error - assert.Error(t, err) // should not be found + assert.Error(t, err) assert.True(t, errors.Is(err, gorm.ErrRecordNotFound)) }) - t.Run("delete non-existing id", func(t *testing.T) { + t.Run("delete non-existing task", func(t *testing.T) { db := setupTestDB(t) r := NewTaskRepository(db) - err := r.Delete(9999) - assert.NoError(t, err) // GORM Delete does not return error if record not found + err := r.Delete(1, 9999) + assert.NoError(t, err) // GORM does not error if record not found }) } func TestTaskRepository_UpdateStatus(t *testing.T) { t.Run("successful status update", func(t *testing.T) { db := setupTestDB(t) - taskToCreate := task.Task{Task: "Buy Milk", Status: "pending"} + + taskToCreate := task.Task{UserID: 1, Task: "Buy Milk", Status: "pending"} require.NoError(t, db.Create(&taskToCreate).Error) r := NewTaskRepository(db) - err := r.UpdateStatus(taskToCreate.ID, "completed") + err := r.UpdateStatus(1, taskToCreate.ID, "completed") assert.NoError(t, err) var updated task.Task @@ -211,7 +223,9 @@ func TestTaskRepository_UpdateStatus(t *testing.T) { t.Run("update status non-existing task", func(t *testing.T) { db := setupTestDB(t) r := NewTaskRepository(db) - err := r.UpdateStatus(9999, "completed") - assert.NoError(t, err) // GORM does nothing but does not error + + // Non-existing userID + taskID + err := r.UpdateStatus(1, 9999, "completed") + assert.NoError(t, err) // GORM returns nil error if no rows affected }) } diff --git a/internal/service/task/task_service.go b/internal/service/task/task_service.go index 0239d80..c0868c8 100644 --- a/internal/service/task/task_service.go +++ b/internal/service/task/task_service.go @@ -17,20 +17,30 @@ func NewTaskService(repo gorm_task.TaskRepositoryInterface) *TaskService { var _ TaskServiceInterface = (*TaskService)(nil) -func (s *TaskService) CreateTask(taskRequest *dto.CreateTaskRequest) error { +func (s *TaskService) CreateTask(userID int, taskRequest *dto.CreateTaskRequest) error { + if userID == 0 { + return errors.New("invalid user") + } + if taskRequest.Task == "" { return errors.New("task name cannot be empty") } task := task.Task{ + UserID: userID, Task: taskRequest.Task, Status: "pending", } return s.repo.Create(&task) } -func (s *TaskService) GetTask(id int) (dto.GetTaskResponse, error) { - t, err := s.repo.GetByID(id) + +func (s *TaskService) GetTask(userID int, id int) (dto.GetTaskResponse, error) { + if userID == 0 { + return dto.GetTaskResponse{}, errors.New("invalid user") + } + + t, err := s.repo.GetByID(userID, id) if err != nil { return dto.GetTaskResponse{}, err } @@ -41,8 +51,12 @@ func (s *TaskService) GetTask(id int) (dto.GetTaskResponse, error) { }, nil } -func (s *TaskService) ListTasks() (dto.ListTasksResponse, error) { - tasks, err := s.repo.List() // This returns []task.Task +func (s *TaskService) ListTasks(userID int) (dto.ListTasksResponse, error) { + if userID == 0 { + return dto.ListTasksResponse{}, errors.New("invalid user") + } + + tasks, err := s.repo.List(userID) // This returns []task.Task if err != nil { return dto.ListTasksResponse{}, err } @@ -62,13 +76,13 @@ func (s *TaskService) ListTasks() (dto.ListTasksResponse, error) { }, nil } -func (s *TaskService) UpdateStatus(id int, status string) error { +func (s *TaskService) UpdateStatus(userID int, id int, status string) error { if status != "pending" && status != "completed" { return errors.New("invalid status") } - return s.repo.UpdateStatus(id, status) + return s.repo.UpdateStatus(userID, id, status) } -func (s *TaskService) Delete(id int) error { - return s.repo.Delete(id) +func (s *TaskService) Delete(userID int, id int) error { + return s.repo.Delete(userID, id) } diff --git a/internal/service/task/task_service_mock.go b/internal/service/task/task_service_mock.go index e4c4071..b9b867a 100644 --- a/internal/service/task/task_service_mock.go +++ b/internal/service/task/task_service_mock.go @@ -12,24 +12,24 @@ type TaskServiceMock struct { var _ TaskServiceInterface = (*TaskServiceMock)(nil) -func (m *TaskServiceMock) CreateTask(taskRequest *dto.CreateTaskRequest) error { +func (m *TaskServiceMock) CreateTask(userID int, taskRequest *dto.CreateTaskRequest) error { - args := m.Called(taskRequest) + args := m.Called(userID, taskRequest) return args.Error(0) } -func (m *TaskServiceMock) GetTask(id int) (dto.GetTaskResponse, error) { - args := m.Called(id) +func (m *TaskServiceMock) GetTask(userID int, id int) (dto.GetTaskResponse, error) { + args := m.Called(userID, id) return args.Get(0).(dto.GetTaskResponse), args.Error(1) } -func (m *TaskServiceMock) ListTasks() (dto.ListTasksResponse, error) { - args := m.Called() +func (m *TaskServiceMock) ListTasks(userID int) (dto.ListTasksResponse, error) { + args := m.Called(userID) return args.Get(0).(dto.ListTasksResponse), args.Error(1) } -func (m *TaskServiceMock) UpdateStatus(id int, status string) error { - args := m.Called(id, status) +func (m *TaskServiceMock) UpdateStatus(userID int, id int, status string) error { + args := m.Called(userID, id, status) return args.Error(0) } -func (m *TaskServiceMock) Delete(id int) error { - args := m.Called(id) +func (m *TaskServiceMock) Delete(userID int, id int) error { + args := m.Called(userID, id) return args.Error(0) } diff --git a/internal/service/task/task_service_test.go b/internal/service/task/task_service_test.go index c143a33..fb779b5 100644 --- a/internal/service/task/task_service_test.go +++ b/internal/service/task/task_service_test.go @@ -15,47 +15,67 @@ import ( func TestTaskService_CreateTask(t *testing.T) { tests := []struct { name string + userID int taskRequest *dto.CreateTaskRequest setupMock func() *gorm_task.TaskRepoMock wantErr bool + errMessage string }{ { - name: "success case - create task Buy Milk", + name: "success case - create task Buy Milk", + userID: 1, taskRequest: &dto.CreateTaskRequest{ Task: "Buy Milk", }, setupMock: func() *gorm_task.TaskRepoMock { mockRepo := new(gorm_task.TaskRepoMock) mockRepo.On("Create", mock.MatchedBy(func(tk *task.Task) bool { - return tk.Task == "Buy Milk" + return tk.UserID == 1 && tk.Task == "Buy Milk" && tk.Status == "pending" })).Return(nil) return mockRepo }, - wantErr: false, + wantErr: false, + errMessage: "", }, { - name: "failure case - Empty Task", + name: "failure case - empty task", + userID: 1, taskRequest: &dto.CreateTaskRequest{ Task: "", }, setupMock: func() *gorm_task.TaskRepoMock { return new(gorm_task.TaskRepoMock) }, - wantErr: true, + wantErr: true, + errMessage: "task name cannot be empty", }, { - name: "failure case - database error", + name: "failure case - invalid user", + taskRequest: &dto.CreateTaskRequest{ + Task: "Buy Milk", // <-- non-empty task + }, + setupMock: func() *gorm_task.TaskRepoMock { + return new(gorm_task.TaskRepoMock) + }, + userID: 0, // <-- invalid user triggers the error + wantErr: true, + errMessage: "invalid user", + }, + { + name: "failure case - database error", + userID: 2, taskRequest: &dto.CreateTaskRequest{ Task: "Buy Eggs", }, setupMock: func() *gorm_task.TaskRepoMock { mockRepo := new(gorm_task.TaskRepoMock) mockRepo.On("Create", mock.MatchedBy(func(tk *task.Task) bool { - return tk.Task == "Buy Eggs" + return tk.UserID == 2 && tk.Task == "Buy Eggs" && tk.Status == "pending" })).Return(errors.New("db error")) return mockRepo }, - wantErr: true, + wantErr: true, + errMessage: "db error", }, } @@ -64,10 +84,11 @@ func TestTaskService_CreateTask(t *testing.T) { mockRepo := tt.setupMock() service := NewTaskService(mockRepo) - err := service.CreateTask(tt.taskRequest) + err := service.CreateTask(tt.userID, tt.taskRequest) if tt.wantErr { assert.Error(t, err) + assert.EqualError(t, err, tt.errMessage) } else { assert.NoError(t, err) } @@ -81,19 +102,22 @@ func TestTaskService_GetTask(t *testing.T) { tests := []struct { name string // description of this test case id int + userID int setupMock func() *gorm_task.TaskRepoMock want dto.GetTaskResponse wantErr bool }{ { - name: "success case", - id: 1, + name: "success case", + id: 1, + userID: 1, setupMock: func() *gorm_task.TaskRepoMock { mockRepo := new(gorm_task.TaskRepoMock) - mockRepo.On("GetByID", 1).Return(&task.Task{ + mockRepo.On("GetByID", 1, 1).Return(&task.Task{ ID: 1, Task: "Buy Milk", Status: "pending", + UserID: 1, }, nil) return mockRepo }, @@ -105,11 +129,12 @@ func TestTaskService_GetTask(t *testing.T) { wantErr: false, }, { - name: "failure case - task not found", - id: 2, + name: "failure case - task not found", + id: 2, + userID: 1, setupMock: func() *gorm_task.TaskRepoMock { mockRepo := new(gorm_task.TaskRepoMock) - mockRepo.On("GetByID", 2).Return((*task.Task)(nil), errors.New("not found")) + mockRepo.On("GetByID", 1, 2).Return((*task.Task)(nil), errors.New("not found")) return mockRepo }, want: dto.GetTaskResponse{}, @@ -121,13 +146,12 @@ func TestTaskService_GetTask(t *testing.T) { t.Run(tt.name, func(t *testing.T) { mockRepo := tt.setupMock() s := NewTaskService(mockRepo) - got, gotErr := s.GetTask(tt.id) + got, gotErr := s.GetTask(tt.userID, tt.id) if tt.wantErr { assert.Error(t, gotErr) assert.Equal(t, dto.GetTaskResponse{}, got) } else { - assert.NotZero(t, got) assert.NoError(t, gotErr) assert.Equal(t, tt.want, got) } @@ -139,20 +163,19 @@ func TestTaskService_GetTask(t *testing.T) { func TestTaskService_ListTasks(t *testing.T) { tests := []struct { - name string // description of this test case - // Named input parameters for receiver constructor. + name string + userID int setupMock func() *gorm_task.TaskRepoMock want dto.ListTasksResponse wantErr bool }{ { - name: "success", + name: "success - single task", + userID: 1, setupMock: func() *gorm_task.TaskRepoMock { mockRepo := new(gorm_task.TaskRepoMock) - mockRepo.On("List").Return([]task.Task{ - { - ID: 1, Task: "Buy milk", Status: "pending", CreatedAt: time.Now(), - }, + mockRepo.On("List", 1).Return([]task.Task{ + {ID: 1, Task: "Buy milk", Status: "pending", CreatedAt: time.Now()}, }, nil) return mockRepo }, @@ -164,48 +187,42 @@ func TestTaskService_ListTasks(t *testing.T) { wantErr: false, }, { - name: "success - multiple tasks", + name: "failure - invalid userID", + userID: 0, setupMock: func() *gorm_task.TaskRepoMock { - mockRepo := new(gorm_task.TaskRepoMock) - mockRepo.On("List").Return([]task.Task{ - {ID: 1, Task: "Buy Milk", Status: "pending", CreatedAt: time.Now()}, - {ID: 2, Task: "Buy Eggs", Status: "completed", CreatedAt: time.Now()}, - }, nil) - return mockRepo + return new(gorm_task.TaskRepoMock) // repo should not be called }, - want: dto.ListTasksResponse{ - Tasks: []dto.GetTaskResponse{ - {ID: 1, Task: "Buy Milk", Status: "pending"}, - {ID: 2, Task: "Buy Eggs", Status: "completed"}, - }, - }, - wantErr: false, + want: dto.ListTasksResponse{}, + wantErr: true, }, { - name: "failure - db error", + name: "failure - db error", + userID: 1, setupMock: func() *gorm_task.TaskRepoMock { mockRepo := new(gorm_task.TaskRepoMock) - mockRepo.On("List").Return(([]task.Task)(nil), errors.New("db error")) + // return empty slice instead of nil + mockRepo.On("List", 1).Return([]task.Task{}, errors.New("db error")) return mockRepo }, want: dto.ListTasksResponse{}, wantErr: true, }, } + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mockRepo := tt.setupMock() s := NewTaskService(mockRepo) - got, gotErr := s.ListTasks() + got, gotErr := s.ListTasks(tt.userID) if tt.wantErr { assert.Error(t, gotErr) assert.Equal(t, dto.ListTasksResponse{}, got) } else { - assert.NotZero(t, got) assert.NoError(t, gotErr) assert.Equal(t, tt.want, got) } + mockRepo.AssertExpectations(t) }) } @@ -215,19 +232,22 @@ func TestTaskService_UpdateStatus(t *testing.T) { tests := []struct { name string setupMock func() *gorm_task.TaskRepoMock + userID int id int status string wantErr bool }{ { name: "success - pending", + userID: 123, id: 1, status: "pending", setupMock: func() *gorm_task.TaskRepoMock { mockRepo := new(gorm_task.TaskRepoMock) mockRepo.On("UpdateStatus", - mock.MatchedBy(func(id int) bool { return id > 0 }), - mock.MatchedBy(func(status string) bool { return status == "pending" || status == "completed" }), + 123, + 1, + "pending", ).Return(nil) return mockRepo }, @@ -235,20 +255,19 @@ func TestTaskService_UpdateStatus(t *testing.T) { }, { name: "success - completed", + userID: 123, id: 2, status: "completed", setupMock: func() *gorm_task.TaskRepoMock { mockRepo := new(gorm_task.TaskRepoMock) - mockRepo.On("UpdateStatus", - mock.Anything, - mock.Anything, - ).Return(nil) + mockRepo.On("UpdateStatus", 123, 2, "completed").Return(nil) return mockRepo }, wantErr: false, }, { name: "failure - invalid status", + userID: 123, id: 3, status: "invalid", setupMock: func() *gorm_task.TaskRepoMock { @@ -258,11 +277,12 @@ func TestTaskService_UpdateStatus(t *testing.T) { }, { name: "failure - repo error", + userID: 123, id: 4, status: "pending", setupMock: func() *gorm_task.TaskRepoMock { mockRepo := new(gorm_task.TaskRepoMock) - mockRepo.On("UpdateStatus", mock.Anything, mock.Anything).Return(errors.New("db error")) + mockRepo.On("UpdateStatus", 123, 4, "pending").Return(errors.New("db error")) return mockRepo }, wantErr: true, @@ -273,7 +293,8 @@ func TestTaskService_UpdateStatus(t *testing.T) { t.Run(tt.name, func(t *testing.T) { mockRepo := tt.setupMock() s := NewTaskService(mockRepo) - gotErr := s.UpdateStatus(tt.id, tt.status) + + gotErr := s.UpdateStatus(tt.userID, tt.id, tt.status) if tt.wantErr { assert.Error(t, gotErr) @@ -282,10 +303,6 @@ func TestTaskService_UpdateStatus(t *testing.T) { } mockRepo.AssertExpectations(t) - - if tt.status == "pending" || tt.status == "completed" { - mockRepo.AssertExpectations(t) - } }) } } @@ -294,51 +311,59 @@ func TestTaskService_Delete(t *testing.T) { tests := []struct { name string setupMock func() *gorm_task.TaskRepoMock + userID int id int wantErr bool }{ { - name: "success", - id: 1, + name: "success", + userID: 1, + id: 1, setupMock: func() *gorm_task.TaskRepoMock { mockRepo := new(gorm_task.TaskRepoMock) - mockRepo.On("Delete", 1).Return(nil) + mockRepo.On("Delete", 1, 1).Return(nil) return mockRepo }, wantErr: false, }, - { - name: "failure - repo error", - id: 2, + name: "failure - repo error", + userID: 1, + id: 2, setupMock: func() *gorm_task.TaskRepoMock { mockRepo := new(gorm_task.TaskRepoMock) - mockRepo.On("Delete", 2).Return(errors.New("db error")) + mockRepo.On("Delete", 1, 2).Return(errors.New("db error")) return mockRepo }, wantErr: true, }, - { - name: "failure - delete non-existing task", - id: 3, + name: "failure - delete non-existing task", + userID: 1, + id: 3, setupMock: func() *gorm_task.TaskRepoMock { mockRepo := new(gorm_task.TaskRepoMock) - mockRepo.On("Delete", 3).Return(errors.New("not found")) + mockRepo.On("Delete", 1, 3).Return(errors.New("not found")) return mockRepo }, wantErr: true, }, } + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mockRepo := tt.setupMock() s := NewTaskService(mockRepo) - err := s.Delete(tt.id) - assert.Equal(t, tt.wantErr, err != nil, "error mismatch") + err := s.Delete(tt.userID, tt.id) - assert.True(t, mockRepo.AssertExpectations(t)) + if tt.wantErr { + assert.Error(t, err, "expected an error but got nil") + } else { + assert.NoError(t, err, "expected no error but got one") + } + + mockRepo.AssertExpectations(t) }) } } diff --git a/internal/service/task/tasks_service_interface.go b/internal/service/task/tasks_service_interface.go index 0457c01..2234933 100644 --- a/internal/service/task/tasks_service_interface.go +++ b/internal/service/task/tasks_service_interface.go @@ -1,11 +1,13 @@ package task_service -import "taskflow/internal/dto" +import ( + "taskflow/internal/dto" +) type TaskServiceInterface interface { - CreateTask(taskRequest *dto.CreateTaskRequest) error - GetTask(id int) (dto.GetTaskResponse, error) - ListTasks() (dto.ListTasksResponse, error) - UpdateStatus(id int, status string) error - Delete(id int) error + CreateTask(userID int, taskRequest *dto.CreateTaskRequest) error + GetTask(userID int, id int) (dto.GetTaskResponse, error) + ListTasks(userID int) (dto.ListTasksResponse, error) + UpdateStatus(userID int, id int, status string) error + Delete(userID int, id int) error } diff --git a/internal/service/user/user_service_mock.go b/internal/service/user/user_service_mock.go index 369317f..35ae7ed 100644 --- a/internal/service/user/user_service_mock.go +++ b/internal/service/user/user_service_mock.go @@ -14,20 +14,38 @@ var _ UserServiceInterface = (*UserServiceMock)(nil) func (m *UserServiceMock) CreateUser(req *dto.CreateUserRequest) (*dto.CreateUserResponse, error) { args := m.Called(req) - return args.Get(0).(*dto.CreateUserResponse), nil + + var resp *dto.CreateUserResponse + if r := args.Get(0); r != nil { + resp = r.(*dto.CreateUserResponse) + } + + return resp, args.Error(1) } func (m *UserServiceMock) AuthenticateUser(req *dto.AuthRequest) (*dto.AuthResponse, error) { args := m.Called(req) - return args.Get(0).(*dto.AuthResponse), nil + var resp *dto.AuthResponse + if r := args.Get(0); r != nil { + resp = r.(*dto.AuthResponse) + } + return resp, args.Error(1) } func (m *UserServiceMock) UpdatePassword(req *dto.UpdatePasswordRequest) (*dto.UpdatePasswordResponse, error) { args := m.Called(req) - return args.Get(0).(*dto.UpdatePasswordResponse), nil + var resp *dto.UpdatePasswordResponse + if r := args.Get(0); r != nil { + resp = r.(*dto.UpdatePasswordResponse) + } + return resp, args.Error(1) } func (m *UserServiceMock) DeleteUser(req *dto.DeleteUserRequest) (*dto.DeleteUserResponse, error) { args := m.Called(req) - return args.Get(0).(*dto.DeleteUserResponse), nil + var resp *dto.DeleteUserResponse + if r := args.Get(0); r != nil { + resp = r.(*dto.DeleteUserResponse) + } + return resp, args.Error(1) } diff --git a/main.go b/main.go index 1d6c5d5..0976389 100644 --- a/main.go +++ b/main.go @@ -76,7 +76,7 @@ func main() { sqlDB, _ := db.DB() defer sqlDB.Close() - if err := db.AutoMigrate(&task.Task{}, &user.User{}); err != nil { + if err := db.AutoMigrate(&user.User{}, &task.Task{}); err != nil { log.Fatalf("AutoMigrate failed: %v", err) }