Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion core/helper/errors/other.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,7 @@ package errs
import "errors"

var (
ErrInvalidPage = errors.New("entered page is invalid")
ErrInvalidPage = errors.New("entered page is invalid")
ErrInvalidSort = errors.New("entered sort parameter is invalid")
ErrInvalidInclude = errors.New("entered include parameter is invalid")
)
49 changes: 41 additions & 8 deletions infrastructure/query/generic.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package query

import (
"fmt"
"math"
"strings"

Expand All @@ -10,27 +11,59 @@ import (
"gorm.io/gorm"
)

func applySorting(stmt *gorm.DB, allowedSorts []string, sort string) *gorm.DB {
func applySorting(stmt *gorm.DB, allowedSorts []string, sort string) (*gorm.DB, error) {
col := sort
direction := " ASC"

if sort == "" {
col = allowedSorts[0]
return stmt.Order(col + direction), nil
}

if strings.HasPrefix(sort, "-") {
col = sort[1:]
direction = " DESC"
}

if !slices.Contains(allowedSorts, col) {
col = allowedSorts[0]
direction = " ASC"
return nil, fmt.Errorf("%w: column '%s' (allowed values: %s)",
errs.ErrInvalidSort, col, strings.Join(allowedSorts, ", "))
}

stmt = stmt.Order(col + direction)
return stmt
return stmt.Order(col + direction), nil
}

func applyIncludes(stmt *gorm.DB, allowedIncludes []string, includes string) (*gorm.DB, error) {
for _, include := range strings.Split(includes, ",") {
if include == "" {
continue
}

allowedValues := "-"
if len(allowedIncludes) > 0 {
allowedValues = strings.Join(allowedIncludes, ", ")
}
if !slices.Contains(allowedIncludes, include) {
return nil, fmt.Errorf("%w: column '%s' (allowed values: %s)",
errs.ErrInvalidInclude, include, allowedValues)
}
stmt = stmt.Preload(include)
}
return stmt, nil
}

func GetWithPagination[T any](stmt *gorm.DB, req base.PaginationRequest, allowedSorts []string,
func GetWithPagination[T any](
stmt *gorm.DB, req base.PaginationRequest,
allowedSorts []string, allowedIncludes []string,
) (data []T, paginationResp base.PaginationResponse, err error) {
stmt = applySorting(stmt, allowedSorts, req.Sort)
stmt, err = applySorting(stmt, allowedSorts, req.Sort)
if err != nil {
return nil, paginationResp, err
}

stmt, err = applyIncludes(stmt, allowedIncludes, req.Includes)
if err != nil {
return nil, paginationResp, err
}

if req.PerPage == 0 {
err = stmt.Find(&data).Error
Expand Down
6 changes: 4 additions & 2 deletions infrastructure/query/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ import (
"gorm.io/gorm"
)

var allowedSorts = []string{"id", "name", "email", "created_at", "updated_at"}
var userAllowedSorts = []string{"id", "name", "email", "created_at", "updated_at"}
var userAllowedIncludes = []string{}

type userQuery struct {
db *gorm.DB
Expand All @@ -36,7 +37,8 @@ func (qr *userQuery) GetAllUsers(ctx context.Context, req dto.UserGetsRequest,
stmt = stmt.Where("name ILIKE ? OR email ILIKE ?", search, search)
}

users, pageResp, err := GetWithPagination[entity.User](stmt, req.PaginationRequest, allowedSorts)
users, pageResp, err := GetWithPagination[entity.User](stmt,
req.PaginationRequest, userAllowedSorts, userAllowedIncludes)
if err != nil {
return nil, pageResp, err
}
Expand Down
7 changes: 4 additions & 3 deletions support/base/request.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package base

type PaginationRequest struct {
Sort string `json:"sort" form:"sort"`
Page int `json:"page" form:"page" binding:"omitempty,min=1"`
PerPage int `json:"per_page" form:"per_page" binding:"omitempty,min=1"`
Sort string `json:"sort" form:"sort"`
Includes string `json:"includes" form:"includes"`
Page int `json:"page" form:"page" binding:"omitempty,min=1"`
PerPage int `json:"per_page" form:"per_page" binding:"omitempty,min=1"`
}
Loading