diff --git a/api/message.go b/api/message.go index 3225c3d9..fc2dd22c 100644 --- a/api/message.go +++ b/api/message.go @@ -16,9 +16,9 @@ import ( // The MessageDatabase interface for encapsulating database access. type MessageDatabase interface { - GetMessagesByApplicationSince(appID uint, limit int, since uint) ([]*model.Message, error) + GetMessagesByApplicationPaginated(appID uint, limit int, since, after uint64, by string) ([]*model.Message, error) GetApplicationByID(id uint) (*model.Application, error) - GetMessagesByUserSince(userID uint, limit int, since uint) ([]*model.Message, error) + GetMessagesByUserPaginated(userID uint, limit int, since, after uint64, by string) ([]*model.Message, error) DeleteMessageByID(id uint) error GetMessageByID(id uint) (*model.Message, error) DeleteMessagesByUser(userID uint) error @@ -41,8 +41,10 @@ type MessageAPI struct { } type pagingParams struct { - Limit int `form:"limit" binding:"min=1,max=200"` - Since uint `form:"since" binding:"min=0"` + Limit int `form:"limit" binding:"min=1,max=200"` + Since uint64 `form:"since" binding:"min=0"` + After uint64 `form:"after" binding:"min=0"` + By string `form:"by" binding:"oneof=id date"` } // GetMessages returns all messages from a user. @@ -69,6 +71,27 @@ type pagingParams struct { // required: false // type: integer // format: int64 +// - name: by +// in: query +// description: the field to order by +// required: false +// type: string +// enum: [id, date] +// default: id +// - name: after +// in: query +// description: return all messages with an cursor value greater than or equal to this value +// minimum: 0 +// required: false +// type: integer +// format: int64 +// - name: since +// in: query +// description: return all messages with an cursor value less than this value +// minimum: 0 +// required: false +// type: integer +// format: int64 // responses: // 200: // description: Ok @@ -90,7 +113,7 @@ func (a *MessageAPI) GetMessages(ctx *gin.Context) { userID := auth.GetUserID(ctx) withPaging(ctx, func(params *pagingParams) { // the +1 is used to check if there are more messages and will be removed on buildWithPaging - messages, err := a.DB.GetMessagesByUserSince(userID, params.Limit+1, params.Since) + messages, err := a.DB.GetMessagesByUserPaginated(userID, params.Limit+1, params.Since, params.After, params.By) if success := successOrAbort(ctx, 500, err); !success { return } @@ -120,7 +143,7 @@ func buildWithPaging(ctx *gin.Context, paging *pagingParams, messages []*model.M } func withPaging(ctx *gin.Context, f func(pagingParams *pagingParams)) { - params := &pagingParams{Limit: 100} + params := &pagingParams{Limit: 100, By: "id"} if err := ctx.MustBindWith(params, binding.Query); err == nil { f(params) } @@ -151,11 +174,25 @@ func withPaging(ctx *gin.Context, f func(pagingParams *pagingParams)) { // type: integer // - name: since // in: query -// description: return all messages with an ID less than this value +// description: return all messages with an cursor value less than this value +// minimum: 0 +// required: false +// type: integer +// format: int64 +// - name: after +// in: query +// description: return all messages with an cursor value greater than or equal to this value // minimum: 0 // required: false // type: integer // format: int64 +// - name: by +// in: query +// description: the field to order by +// required: false +// type: string +// enum: [id, date] +// default: id // responses: // 200: // description: Ok @@ -186,7 +223,7 @@ func (a *MessageAPI) GetMessagesWithApplication(ctx *gin.Context) { } if app != nil && app.UserID == auth.GetUserID(ctx) { // the +1 is used to check if there are more messages and will be removed on buildWithPaging - messages, err := a.DB.GetMessagesByApplicationSince(id, params.Limit+1, params.Since) + messages, err := a.DB.GetMessagesByApplicationPaginated(id, params.Limit+1, params.Since, params.After, params.By) if success := successOrAbort(ctx, 500, err); !success { return } diff --git a/database/message.go b/database/message.go index b8b23175..f06fd6f1 100644 --- a/database/message.go +++ b/database/message.go @@ -1,8 +1,11 @@ package database import ( + "time" + "github.com/gotify/server/v2/model" "gorm.io/gorm" + "gorm.io/gorm/clause" ) // GetMessageByID returns the messages for the given id or nil. @@ -34,14 +37,33 @@ func (d *GormDatabase) GetMessagesByUser(userID uint) ([]*model.Message, error) return messages, err } -// GetMessagesByUserSince returns limited messages from a user. +// GetMessagesByUserPaginated returns limited messages from a user. // If since is 0 it will be ignored. -func (d *GormDatabase) GetMessagesByUserSince(userID uint, limit int, since uint) ([]*model.Message, error) { +func (d *GormDatabase) GetMessagesByUserPaginated(userID uint, limit int, since, after uint64, by string) ([]*model.Message, error) { var messages []*model.Message db := d.DB.Joins("JOIN applications ON applications.user_id = ?", userID). - Where("messages.application_id = applications.id").Order("messages.id desc").Limit(limit) + Where("messages.application_id = applications.id").Order(clause.OrderBy{Columns: []clause.OrderByColumn{ + { + Column: clause.Column{ + Table: "messages", + Name: by, + }, + Desc: since != 0 || after == 0, + }, + }}).Limit(limit) if since != 0 { - db = db.Where("messages.id < ?", since) + sinceVal := any(since) + if by == "date" { + sinceVal = time.Unix(int64(since), 0) + } + db = db.Where(clause.Lt{Column: clause.Column{Table: "messages", Name: by}, Value: sinceVal}) + } + if after != 0 { + afterVal := any(after) + if by == "date" { + afterVal = time.Unix(int64(after), 0) + } + db = db.Where(clause.Gte{Column: clause.Column{Table: "messages", Name: by}, Value: afterVal}) } err := db.Find(&messages).Error if err == gorm.ErrRecordNotFound { @@ -60,13 +82,32 @@ func (d *GormDatabase) GetMessagesByApplication(tokenID uint) ([]*model.Message, return messages, err } -// GetMessagesByApplicationSince returns limited messages from an application. +// GetMessagesByApplicationPaginated returns limited messages from an application. // If since is 0 it will be ignored. -func (d *GormDatabase) GetMessagesByApplicationSince(appID uint, limit int, since uint) ([]*model.Message, error) { +func (d *GormDatabase) GetMessagesByApplicationPaginated(appID uint, limit int, since, after uint64, by string) ([]*model.Message, error) { var messages []*model.Message - db := d.DB.Where("application_id = ?", appID).Order("messages.id desc").Limit(limit) + db := d.DB.Where("application_id = ?", appID).Order(clause.OrderBy{Columns: []clause.OrderByColumn{ + { + Column: clause.Column{ + Table: "messages", + Name: by, + }, + Desc: since != 0 || after == 0, + }, + }}).Limit(limit) if since != 0 { - db = db.Where("messages.id < ?", since) + sinceVal := any(since) + if by == "date" { + sinceVal = time.Unix(int64(since), 0) + } + db = db.Where(clause.Lt{Column: clause.Column{Table: "messages", Name: by}, Value: sinceVal}) + } + if after != 0 { + afterVal := any(after) + if by == "date" { + afterVal = time.Unix(int64(after), 0) + } + db = db.Where(clause.Gte{Column: clause.Column{Table: "messages", Name: by}, Value: afterVal}) } err := db.Find(&messages).Error if err == gorm.ErrRecordNotFound { diff --git a/database/message_test.go b/database/message_test.go index 950567f7..53bf674e 100644 --- a/database/message_test.go +++ b/database/message_test.go @@ -1,6 +1,7 @@ package database import ( + "slices" "testing" "time" @@ -156,71 +157,142 @@ func (s *DatabaseSuite) TestGetMessagesSince() { require.NoError(s.T(), s.db.CreateApplication(app)) require.NoError(s.T(), s.db.CreateApplication(app2)) - curDate := time.Now() + curDate := time.Unix(time.Now().Unix(), 0) for i := 1; i <= 500; i++ { - s.db.CreateMessage(&model.Message{ApplicationID: app.ID, Message: "abc", Date: curDate.Add(time.Duration(i) * time.Second)}) - s.db.CreateMessage(&model.Message{ApplicationID: app2.ID, Message: "abc", Date: curDate.Add(time.Duration(i) * time.Second)}) + s.db.CreateMessage(&model.Message{ApplicationID: app.ID, Message: "abc", Date: curDate.Add(time.Duration(i*2) * time.Second)}) + s.db.CreateMessage(&model.Message{ApplicationID: app2.ID, Message: "abc", Date: curDate.Add(time.Duration(i*2+1) * time.Second)}) } - actual, err := s.db.GetMessagesByUserSince(user.ID, 50, 0) + actual, err := s.db.GetMessagesByUserPaginated(user.ID, 50, 0, 0, "id") require.NoError(s.T(), err) assert.Len(s.T(), actual, 50) hasIDInclusiveBetween(s.T(), actual, 1000, 951, 1) - actual, err = s.db.GetMessagesByUserSince(user.ID, 50, 951) + actual, err = s.db.GetMessagesByUserPaginated(user.ID, 50, 0, 0, "date") require.NoError(s.T(), err) assert.Len(s.T(), actual, 50) + hasIDInclusiveBetween(s.T(), actual, 1000, 951, 1) + + actual, err = s.db.GetMessagesByUserPaginated(user.ID, 50, 951, 0, "id") + require.NoError(s.T(), err) + assert.Len(s.T(), actual, 50) + hasIDInclusiveBetween(s.T(), actual, 950, 901, 1) + + actual, err = s.db.GetMessagesByUserPaginated(user.ID, 50, 0, 901, "id") + require.NoError(s.T(), err) + assert.Len(s.T(), actual, 50) + slices.Reverse(actual) hasIDInclusiveBetween(s.T(), actual, 950, 901, 1) - actual, err = s.db.GetMessagesByUserSince(user.ID, 100, 951) + actual, err = s.db.GetMessagesByUserPaginated(user.ID, 100, 951, 0, "id") require.NoError(s.T(), err) assert.Len(s.T(), actual, 100) hasIDInclusiveBetween(s.T(), actual, 950, 851, 1) - actual, err = s.db.GetMessagesByUserSince(user.ID, 100, 51) + actual, err = s.db.GetMessagesByUserPaginated(user.ID, 100, 51, 0, "id") require.NoError(s.T(), err) assert.Len(s.T(), actual, 50) hasIDInclusiveBetween(s.T(), actual, 50, 1, 1) - actual, err = s.db.GetMessagesByApplicationSince(app.ID, 50, 0) + actual, err = s.db.GetMessagesByApplicationPaginated(app.ID, 50, 0, 0, "id") require.NoError(s.T(), err) assert.Len(s.T(), actual, 50) hasIDInclusiveBetween(s.T(), actual, 999, 901, 2) - actual, err = s.db.GetMessagesByApplicationSince(app.ID, 50, 901) + actual, err = s.db.GetMessagesByApplicationPaginated(app.ID, 50, 0, 0, "date") + require.NoError(s.T(), err) + assert.Len(s.T(), actual, 50) + hasIDInclusiveBetween(s.T(), actual, 999, 901, 2) + + actual, err = s.db.GetMessagesByApplicationPaginated(app.ID, 50, uint64(curDate.Unix()+50), 0, "date") + require.NoError(s.T(), err) + assert.Len(s.T(), actual, 24) + hasIDInclusiveBetween(s.T(), actual, 47, 1, 2) + + actual, err = s.db.GetMessagesByApplicationPaginated(app.ID, 50, uint64(curDate.Unix()+50), uint64(curDate.Unix()+10), "date") + require.NoError(s.T(), err) + assert.Len(s.T(), actual, 20) + hasIDInclusiveBetween(s.T(), actual, 47, 47-20*2+2, 2) + + actual, err = s.db.GetMessagesByApplicationPaginated(app.ID, 50, 0, uint64(curDate.Unix()+950), "date") + require.NoError(s.T(), err) + assert.Len(s.T(), actual, 26) + slices.Reverse(actual) + hasIDInclusiveBetween(s.T(), actual, 999, 949, 2) + + actual, err = s.db.GetMessagesByApplicationPaginated(app.ID, 50, 901, 0, "id") + require.NoError(s.T(), err) + assert.Len(s.T(), actual, 50) + hasIDInclusiveBetween(s.T(), actual, 899, 801, 2) + + actual, err = s.db.GetMessagesByApplicationPaginated(app.ID, 50, 0, 801, "id") require.NoError(s.T(), err) assert.Len(s.T(), actual, 50) + slices.Reverse(actual) hasIDInclusiveBetween(s.T(), actual, 899, 801, 2) - actual, err = s.db.GetMessagesByApplicationSince(app.ID, 100, 666) + actual, err = s.db.GetMessagesByApplicationPaginated(app.ID, 100, 666, 0, "id") require.NoError(s.T(), err) assert.Len(s.T(), actual, 100) hasIDInclusiveBetween(s.T(), actual, 665, 467, 2) - actual, err = s.db.GetMessagesByApplicationSince(app.ID, 100, 101) + actual, err = s.db.GetMessagesByApplicationPaginated(app.ID, 100, 101, 0, "id") require.NoError(s.T(), err) assert.Len(s.T(), actual, 50) hasIDInclusiveBetween(s.T(), actual, 99, 1, 2) - actual, err = s.db.GetMessagesByApplicationSince(app2.ID, 50, 0) + actual, err = s.db.GetMessagesByApplicationPaginated(app2.ID, 50, 0, 0, "id") require.NoError(s.T(), err) assert.Len(s.T(), actual, 50) hasIDInclusiveBetween(s.T(), actual, 1000, 902, 2) - actual, err = s.db.GetMessagesByApplicationSince(app2.ID, 50, 902) + actual, err = s.db.GetMessagesByApplicationPaginated(app2.ID, 50, 0, 0, "date") + require.NoError(s.T(), err) + assert.Len(s.T(), actual, 50) + hasIDInclusiveBetween(s.T(), actual, 1000, 902, 2) + + actual, err = s.db.GetMessagesByApplicationPaginated(app2.ID, 50, uint64(curDate.Unix()+50), 0, "date") + require.NoError(s.T(), err) + assert.Len(s.T(), actual, 24) + hasIDInclusiveBetween(s.T(), actual, 48, 2, 2) + + actual, err = s.db.GetMessagesByApplicationPaginated(app2.ID, 50, uint64(curDate.Unix()+50), uint64(curDate.Unix()+10), "date") + require.NoError(s.T(), err) + assert.Len(s.T(), actual, 20) + hasIDInclusiveBetween(s.T(), actual, 48, 48-20*2+2, 2) + + actual, err = s.db.GetMessagesByApplicationPaginated(app2.ID, 50, 0, uint64(curDate.Unix()+950), "date") + require.NoError(s.T(), err) + assert.Len(s.T(), actual, 26) + slices.Reverse(actual) + hasIDInclusiveBetween(s.T(), actual, 1000, 950, 2) + + actual, err = s.db.GetMessagesByApplicationPaginated(app2.ID, 50, 902, 0, "id") require.NoError(s.T(), err) assert.Len(s.T(), actual, 50) hasIDInclusiveBetween(s.T(), actual, 900, 802, 2) - actual, err = s.db.GetMessagesByApplicationSince(app2.ID, 100, 667) + actual, err = s.db.GetMessagesByApplicationPaginated(app2.ID, 50, 0, 802, "id") + require.NoError(s.T(), err) + assert.Len(s.T(), actual, 50) + slices.Reverse(actual) + hasIDInclusiveBetween(s.T(), actual, 900, 802, 2) + + actual, err = s.db.GetMessagesByApplicationPaginated(app2.ID, 100, 667, 0, "id") require.NoError(s.T(), err) assert.Len(s.T(), actual, 100) hasIDInclusiveBetween(s.T(), actual, 666, 468, 2) - actual, err = s.db.GetMessagesByApplicationSince(app2.ID, 100, 102) + actual, err = s.db.GetMessagesByApplicationPaginated(app2.ID, 100, 102, 0, "id") require.NoError(s.T(), err) assert.Len(s.T(), actual, 50) hasIDInclusiveBetween(s.T(), actual, 100, 2, 2) + + actual, err = s.db.GetMessagesByApplicationPaginated(app2.ID, 50, 0, uint64(curDate.Unix()+950), "date") + require.NoError(s.T(), err) + assert.Len(s.T(), actual, 26) + slices.Reverse(actual) + hasIDInclusiveBetween(s.T(), actual, 1000, 950, 2) } func hasIDInclusiveBetween(t *testing.T, msgs []*model.Message, from, to, decrement int) { diff --git a/docs/spec.json b/docs/spec.json index 450450ec..837a2e68 100644 --- a/docs/spec.json +++ b/docs/spec.json @@ -485,9 +485,28 @@ "minimum": 0, "type": "integer", "format": "int64", - "description": "return all messages with an ID less than this value", + "description": "return all messages with an cursor value less than this value", "name": "since", "in": "query" + }, + { + "minimum": 0, + "type": "integer", + "format": "int64", + "description": "return all messages with an cursor value greater than or equal to this value", + "name": "after", + "in": "query" + }, + { + "enum": [ + "id", + "date" + ], + "type": "string", + "default": "id", + "description": "the field to order by", + "name": "by", + "in": "query" } ], "responses": { @@ -1025,6 +1044,33 @@ "description": "return all messages with an ID less than this value", "name": "since", "in": "query" + }, + { + "enum": [ + "id", + "date" + ], + "type": "string", + "default": "id", + "description": "the field to order by", + "name": "by", + "in": "query" + }, + { + "minimum": 0, + "type": "integer", + "format": "int64", + "description": "return all messages with an cursor value greater than or equal to this value", + "name": "after", + "in": "query" + }, + { + "minimum": 0, + "type": "integer", + "format": "int64", + "description": "return all messages with an cursor value less than this value", + "name": "since", + "in": "query" } ], "responses": { diff --git a/model/message.go b/model/message.go index e00545a2..779d5a15 100644 --- a/model/message.go +++ b/model/message.go @@ -6,13 +6,13 @@ import ( // Message holds information about a message. type Message struct { - ID uint `gorm:"autoIncrement;primaryKey;index"` - ApplicationID uint + ID uint `gorm:"autoIncrement;primaryKey;index"` + ApplicationID uint `gorm:"index:,composite:composite_application_id_date"` Message string `gorm:"type:text"` Title string `gorm:"type:text"` Priority int Extras []byte - Date time.Time + Date time.Time `gorm:"index:,composite:composite_application_id_date"` } // MessageExternal Model