From df7a239e20de13f7aa03e85b9991243804d5ed67 Mon Sep 17 00:00:00 2001 From: Anuj Khandelwal Date: Tue, 20 Aug 2024 23:46:54 +0530 Subject: [PATCH] chore: unit tests for session service (#725) --- .mockery.yaml | 4 + core/authenticate/session/mocks/repository.go | 288 ++++++++++++++++++ core/authenticate/session/service_test.go | 135 ++++++++ 3 files changed, 427 insertions(+) create mode 100644 core/authenticate/session/mocks/repository.go create mode 100644 core/authenticate/session/service_test.go diff --git a/.mockery.yaml b/.mockery.yaml index c7ce68830..5328a1614 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -30,4 +30,8 @@ packages: github.com/raystack/frontier/core/policy: config: dir: "core/policy/mocks" + all: true + github.com/raystack/frontier/core/authenticate/session: + config: + dir: "core/authenticate/session/mocks" all: true \ No newline at end of file diff --git a/core/authenticate/session/mocks/repository.go b/core/authenticate/session/mocks/repository.go new file mode 100644 index 000000000..45cd0de92 --- /dev/null +++ b/core/authenticate/session/mocks/repository.go @@ -0,0 +1,288 @@ +// Code generated by mockery v2.40.2. DO NOT EDIT. + +package mocks + +import ( + context "context" + + session "github.com/raystack/frontier/core/authenticate/session" + mock "github.com/stretchr/testify/mock" + + time "time" + + uuid "github.com/google/uuid" +) + +// Repository is an autogenerated mock type for the Repository type +type Repository struct { + mock.Mock +} + +type Repository_Expecter struct { + mock *mock.Mock +} + +func (_m *Repository) EXPECT() *Repository_Expecter { + return &Repository_Expecter{mock: &_m.Mock} +} + +// Delete provides a mock function with given fields: ctx, id +func (_m *Repository) Delete(ctx context.Context, id uuid.UUID) error { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for Delete") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID) error); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Repository_Delete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Delete' +type Repository_Delete_Call struct { + *mock.Call +} + +// Delete is a helper method to define mock.On call +// - ctx context.Context +// - id uuid.UUID +func (_e *Repository_Expecter) Delete(ctx interface{}, id interface{}) *Repository_Delete_Call { + return &Repository_Delete_Call{Call: _e.mock.On("Delete", ctx, id)} +} + +func (_c *Repository_Delete_Call) Run(run func(ctx context.Context, id uuid.UUID)) *Repository_Delete_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(uuid.UUID)) + }) + return _c +} + +func (_c *Repository_Delete_Call) Return(_a0 error) *Repository_Delete_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Repository_Delete_Call) RunAndReturn(run func(context.Context, uuid.UUID) error) *Repository_Delete_Call { + _c.Call.Return(run) + return _c +} + +// DeleteExpiredSessions provides a mock function with given fields: ctx +func (_m *Repository) DeleteExpiredSessions(ctx context.Context) error { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for DeleteExpiredSessions") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(ctx) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Repository_DeleteExpiredSessions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteExpiredSessions' +type Repository_DeleteExpiredSessions_Call struct { + *mock.Call +} + +// DeleteExpiredSessions is a helper method to define mock.On call +// - ctx context.Context +func (_e *Repository_Expecter) DeleteExpiredSessions(ctx interface{}) *Repository_DeleteExpiredSessions_Call { + return &Repository_DeleteExpiredSessions_Call{Call: _e.mock.On("DeleteExpiredSessions", ctx)} +} + +func (_c *Repository_DeleteExpiredSessions_Call) Run(run func(ctx context.Context)) *Repository_DeleteExpiredSessions_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *Repository_DeleteExpiredSessions_Call) Return(_a0 error) *Repository_DeleteExpiredSessions_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Repository_DeleteExpiredSessions_Call) RunAndReturn(run func(context.Context) error) *Repository_DeleteExpiredSessions_Call { + _c.Call.Return(run) + return _c +} + +// Get provides a mock function with given fields: ctx, id +func (_m *Repository) Get(ctx context.Context, id uuid.UUID) (*session.Session, error) { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for Get") + } + + var r0 *session.Session + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID) (*session.Session, error)); ok { + return rf(ctx, id) + } + if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID) *session.Session); ok { + r0 = rf(ctx, id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*session.Session) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, uuid.UUID) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Repository_Get_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Get' +type Repository_Get_Call struct { + *mock.Call +} + +// Get is a helper method to define mock.On call +// - ctx context.Context +// - id uuid.UUID +func (_e *Repository_Expecter) Get(ctx interface{}, id interface{}) *Repository_Get_Call { + return &Repository_Get_Call{Call: _e.mock.On("Get", ctx, id)} +} + +func (_c *Repository_Get_Call) Run(run func(ctx context.Context, id uuid.UUID)) *Repository_Get_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(uuid.UUID)) + }) + return _c +} + +func (_c *Repository_Get_Call) Return(_a0 *session.Session, _a1 error) *Repository_Get_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Repository_Get_Call) RunAndReturn(run func(context.Context, uuid.UUID) (*session.Session, error)) *Repository_Get_Call { + _c.Call.Return(run) + return _c +} + +// Set provides a mock function with given fields: ctx, _a1 +func (_m *Repository) Set(ctx context.Context, _a1 *session.Session) error { + ret := _m.Called(ctx, _a1) + + if len(ret) == 0 { + panic("no return value specified for Set") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *session.Session) error); ok { + r0 = rf(ctx, _a1) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Repository_Set_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Set' +type Repository_Set_Call struct { + *mock.Call +} + +// Set is a helper method to define mock.On call +// - ctx context.Context +// - _a1 *session.Session +func (_e *Repository_Expecter) Set(ctx interface{}, _a1 interface{}) *Repository_Set_Call { + return &Repository_Set_Call{Call: _e.mock.On("Set", ctx, _a1)} +} + +func (_c *Repository_Set_Call) Run(run func(ctx context.Context, _a1 *session.Session)) *Repository_Set_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*session.Session)) + }) + return _c +} + +func (_c *Repository_Set_Call) Return(_a0 error) *Repository_Set_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Repository_Set_Call) RunAndReturn(run func(context.Context, *session.Session) error) *Repository_Set_Call { + _c.Call.Return(run) + return _c +} + +// UpdateValidity provides a mock function with given fields: ctx, id, validity +func (_m *Repository) UpdateValidity(ctx context.Context, id uuid.UUID, validity time.Duration) error { + ret := _m.Called(ctx, id, validity) + + if len(ret) == 0 { + panic("no return value specified for UpdateValidity") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, uuid.UUID, time.Duration) error); ok { + r0 = rf(ctx, id, validity) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Repository_UpdateValidity_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateValidity' +type Repository_UpdateValidity_Call struct { + *mock.Call +} + +// UpdateValidity is a helper method to define mock.On call +// - ctx context.Context +// - id uuid.UUID +// - validity time.Duration +func (_e *Repository_Expecter) UpdateValidity(ctx interface{}, id interface{}, validity interface{}) *Repository_UpdateValidity_Call { + return &Repository_UpdateValidity_Call{Call: _e.mock.On("UpdateValidity", ctx, id, validity)} +} + +func (_c *Repository_UpdateValidity_Call) Run(run func(ctx context.Context, id uuid.UUID, validity time.Duration)) *Repository_UpdateValidity_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(uuid.UUID), args[2].(time.Duration)) + }) + return _c +} + +func (_c *Repository_UpdateValidity_Call) Return(_a0 error) *Repository_UpdateValidity_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Repository_UpdateValidity_Call) RunAndReturn(run func(context.Context, uuid.UUID, time.Duration) error) *Repository_UpdateValidity_Call { + _c.Call.Return(run) + return _c +} + +// NewRepository creates a new instance of Repository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewRepository(t interface { + mock.TestingT + Cleanup(func()) +}) *Repository { + mock := &Repository{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/authenticate/session/service_test.go b/core/authenticate/session/service_test.go new file mode 100644 index 000000000..b564e6208 --- /dev/null +++ b/core/authenticate/session/service_test.go @@ -0,0 +1,135 @@ +package session_test + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/google/uuid" + "github.com/raystack/frontier/core/authenticate/session" + "github.com/raystack/frontier/core/authenticate/session/mocks" + "github.com/raystack/frontier/pkg/server/consts" + "github.com/raystack/salt/log" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "google.golang.org/grpc/metadata" +) + +func TestService_Create(t *testing.T) { + t.Run("should create a session when parameters are passed correctly", func(t *testing.T) { + mockRepository := mocks.NewRepository(t) + svc := session.NewService(log.NewLogrus(), mockRepository, 24*time.Hour) + + mockRepository.On("Set", mock.Anything, mock.AnythingOfType("*session.Session")).Run(func(args mock.Arguments) { + arg := args.Get(1) + r := arg.(*session.Session) + assert.Equal(t, r.UserID, "1") + }).Return(nil) + + userID := "1" + sess, err := svc.Create(context.Background(), userID) + + assert.Nil(t, err) + assert.Equal(t, sess.UserID, "1") + }) + + t.Run("should return an error when session is not successfully set", func(t *testing.T) { + mockRepository := mocks.NewRepository(t) + svc := session.NewService(log.NewLogrus(), mockRepository, 24*time.Hour) + + mockRepository.On("Set", mock.Anything, mock.AnythingOfType("*session.Session")).Run(func(args mock.Arguments) { + arg := args.Get(1) + r := arg.(*session.Session) + assert.Equal(t, r.UserID, "1") + }).Return(errors.New("internal-error")) + + userID := "1" + _, err := svc.Create(context.Background(), userID) + + assert.NotNil(t, err) + assert.Equal(t, err.Error(), "internal-error") + }) +} + +func TestService_Refresh(t *testing.T) { + t.Run("should refresh a session successfully", func(t *testing.T) { + mockRepository := mocks.NewRepository(t) + mockSessionID := uuid.New() + svc := session.NewService(log.NewLogrus(), mockRepository, 24*time.Hour) + + mockRepository.On("UpdateValidity", mock.Anything, mockSessionID, 24*time.Hour).Return(nil) + + err := svc.Refresh(context.Background(), mockSessionID) + + assert.Nil(t, err) + }) + + t.Run("should return an error if refresh fails", func(t *testing.T) { + mockRepository := mocks.NewRepository(t) + mockSessionID := uuid.New() + svc := session.NewService(log.NewLogrus(), mockRepository, 24*time.Hour) + + mockRepository.On("UpdateValidity", mock.Anything, mockSessionID, 24*time.Hour).Return(errors.New("internal-error")) + + err := svc.Refresh(context.Background(), mockSessionID) + + assert.NotNil(t, err) + assert.Equal(t, err.Error(), "internal-error") + }) +} + +func TestService_Delete(t *testing.T) { + t.Run("should delete a session successfully", func(t *testing.T) { + mockRepository := mocks.NewRepository(t) + mockSessionID := uuid.New() + svc := session.NewService(log.NewLogrus(), mockRepository, 24*time.Hour) + + mockRepository.On("Delete", mock.Anything, mockSessionID).Return(nil) + + err := svc.Delete(context.Background(), mockSessionID) + + assert.Nil(t, err) + }) + + t.Run("should return an error if deletion fails", func(t *testing.T) { + mockRepository := mocks.NewRepository(t) + mockSessionID := uuid.New() + svc := session.NewService(log.NewLogrus(), mockRepository, 24*time.Hour) + + mockRepository.On("Delete", mock.Anything, mockSessionID).Return(errors.New("internal-error")) + + err := svc.Delete(context.Background(), mockSessionID) + + assert.NotNil(t, err) + assert.Equal(t, err.Error(), "internal-error") + }) +} + +func TestService_ExtractFromContext(t *testing.T) { + t.Run("should be able to extract session from context if it is present", func(t *testing.T) { + mockRepository := mocks.NewRepository(t) + mockSessionID := uuid.New() + svc := session.NewService(log.NewLogrus(), mockRepository, 24*time.Hour) + + md := metadata.New(map[string]string{consts.SessionIDGatewayKey: mockSessionID.String(), "key2": "val2"}) + ctx := metadata.NewIncomingContext(context.Background(), md) + + mockRepository.On("Get", ctx, mockSessionID).Return(&session.Session{ + ID: mockSessionID, + }, nil) + + sess, err := svc.ExtractFromContext(ctx) + assert.Nil(t, err) + assert.Equal(t, sess.ID, mockSessionID) + }) + + t.Run("should return an error if session is not present in context metadata", func(t *testing.T) { + mockRepository := mocks.NewRepository(t) + svc := session.NewService(log.NewLogrus(), mockRepository, 24*time.Hour) + + _, err := svc.ExtractFromContext(context.Background()) + assert.NotNil(t, err) + assert.Equal(t, err, session.ErrNoSession) + }) +}