Skip to content

Commit

Permalink
Add support for computed columns
Browse files Browse the repository at this point in the history
  • Loading branch information
System-Glitch committed Aug 29, 2022
1 parent e305b7f commit 71df4a7
Show file tree
Hide file tree
Showing 14 changed files with 411 additions and 38 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,5 @@ jobs:
- name: Run lint
uses: golangci/golangci-lint-action@v2
with:
version: v1.47
version: v1.49
args: --timeout 5m
4 changes: 3 additions & 1 deletion .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ linters:
- misspell
- bodyclose
- govet
- deadcode
- unused
- errcheck
disable-all: false
fast: false
Expand All @@ -35,3 +35,5 @@ issues:
exclude-use-default: false
max-issues-per-linter: 0
max-same-issues: 0
exclude:
- should have a package comment
29 changes: 29 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,35 @@ Internally, `goyave.dev/filter` uses [Goyave's `Paginator`](https://goyave.dev/g
- If `per_page` isn't given, the default page size will be used. This default value can be overridden by changing `filter.DefaultPageSize`.
- Either way, the result is **always** paginated, even if those two parameters are missing.

## Computed columns

Sometimes you need to work with a "virtual" column that is not stored in your database, but is computed using an SQL expression. A dynamic status depending on a date for example. In order to support the features of this library properly, you will have to add the expression to your model using the `computed` struct tag:

```go
type MyModel struct{
ID uint
// ...
StartDate time.Time
Status string `gorm:"->"` `computed:"CASE WHEN ~~~ct~~~.start_date < NOW() THEN 'pending' ELSE 'started' END"`
}
```

*Note: the `~~~ct~~~` is an indicator for the **c**urrent **t**able. It will be replaced by the correct table or relation name automatically. This allows the usage of computed fields in relations too, where joins are needed.*

**Tip:** you can also use composition to avoid including the virtual column into your model:
```go
type MyModel struct{
ID uint
// ...
StartDate time.Time
}

type MyModelWithStatus struct{
MyModel
Status string `gorm:"->"` `computed:"CASE WHEN ~~~ct~~~.start_date < NOW() THEN 'pending' ELSE 'started' END"`
}
```

## Security

- Inputs are escaped to prevent SQL injections.
Expand Down
16 changes: 14 additions & 2 deletions filter.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package filter

import (
"fmt"
"strings"

"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)

Expand Down Expand Up @@ -32,9 +36,17 @@ func (f *Filter) Scope(settings *Settings, sch *schema.Schema) (func(*gorm.DB) *
return tx
}

computed := field.StructField.Tag.Get("computed")

conditionScope := func(tx *gorm.DB) *gorm.DB {
tableName := tx.Statement.Quote(tableFromJoinName(s.Table, joinName)) + "."
return f.Operator.Function(tx, f, tableName+tx.Statement.Quote(field.DBName), field.DataType)
table := tx.Statement.Quote(tableFromJoinName(s.Table, joinName))
var fieldExpr string
if computed != "" {
fieldExpr = fmt.Sprintf("(%s)", strings.ReplaceAll(computed, clause.CurrentTable, table))
} else {
fieldExpr = table + "." + tx.Statement.Quote(field.DBName)
}
return f.Operator.Function(tx, f, fieldExpr, field.DataType)
}

return joinScope, conditionScope
Expand Down
90 changes: 90 additions & 0 deletions filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -448,3 +448,93 @@ func TestFilterScopeWithAlreadyExistingRawJoin(t *testing.T) {
assert.Equal(t, expected, db.Statement.Clauses)
assert.NotEmpty(t, db.Statement.Joins)
}

type FilterTestModelComputedRelation struct {
Name string
Computed string `computed:"~~~ct~~~.computedcolumnrelation"`
ID uint
ParentID uint
}

type FilterTestModelComputed struct {
Relation *FilterTestModelComputedRelation `gorm:"foreignKey:ParentID"`
Name string
Computed string `computed:"~~~ct~~~.computedcolumn"`
ID uint
}

func TestFilterScopeComputed(t *testing.T) {
db, _ := gorm.Open(&tests.DummyDialector{}, nil)
filter := &Filter{Field: "computed", Args: []string{"val1"}, Operator: Operators["$eq"]}

results := []*FilterTestModelComputed{}
schema, err := parseModel(db, &results)
if !assert.Nil(t, err) {
return
}

db = db.Model(&results).Scopes(filter.Scope(&Settings{}, schema)).Find(&results)
expected := map[string]clause.Clause{
"WHERE": {
Name: "WHERE",
Expression: clause.Where{
Exprs: []clause.Expression{
clause.Expr{SQL: "(`filter_test_model_computeds`.computedcolumn) = ?", Vars: []interface{}{"val1"}},
},
},
},
}
assert.Equal(t, expected, db.Statement.Clauses)
}

func TestFilterScopeComputedRelation(t *testing.T) {
db, _ := gorm.Open(&tests.DummyDialector{}, nil)
filter := &Filter{Field: "Relation.computed", Args: []string{"val1"}, Operator: Operators["$eq"]}

results := []*FilterTestModelComputed{}
schema, err := parseModel(db, &results)
if !assert.Nil(t, err) {
return
}

db = db.Model(&results).Scopes(filter.Scope(&Settings{}, schema)).Find(&results)
expected := map[string]clause.Clause{
"WHERE": {
Name: "WHERE",
Expression: clause.Where{
Exprs: []clause.Expression{
clause.Expr{SQL: "(`Relation`.computedcolumnrelation) = ?", Vars: []interface{}{"val1"}},
},
},
},
"FROM": {
Name: "FROM",
Expression: clause.From{
Joins: []clause.Join{
{
Type: clause.LeftJoin,
Table: clause.Table{
Name: "filter_test_model_computed_relations",
Alias: "Relation",
},
ON: clause.Where{
Exprs: []clause.Expression{
clause.Eq{
Column: clause.Column{
Table: "filter_test_model_computeds",
Name: "id",
},
Value: clause.Column{
Table: "Relation",
Name: "parent_id",
},
},
},
},
},
},
},
},
}
assert.Equal(t, expected, db.Statement.Clauses)
}
10 changes: 5 additions & 5 deletions join.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,16 @@ func joinScope(relationName string, rel *schema.Relationship, fields []string, b
return tx
}
if columns != nil {
for _, k := range rel.FieldSchema.PrimaryFieldDBNames {
if !sliceutil.ContainsStr(columns, k) && (blacklist == nil || !sliceutil.ContainsStr(blacklist.FieldsBlacklist, k)) {
columns = append(columns, k)
for _, primaryField := range rel.FieldSchema.PrimaryFields {
if !columnsContain(columns, primaryField) && (blacklist == nil || !sliceutil.ContainsStr(blacklist.FieldsBlacklist, primaryField.DBName)) {
columns = append(columns, primaryField)
}
}
for _, backwardsRelation := range rel.FieldSchema.Relationships.Relations {
if backwardsRelation.FieldSchema == rel.Schema && backwardsRelation.Type == schema.BelongsTo {
for _, ref := range backwardsRelation.References {
if !sliceutil.ContainsStr(columns, ref.ForeignKey.DBName) && (blacklist == nil || !sliceutil.ContainsStr(blacklist.FieldsBlacklist, ref.ForeignKey.DBName)) {
columns = append(columns, ref.ForeignKey.DBName)
if !columnsContain(columns, ref.ForeignKey) && (blacklist == nil || !sliceutil.ContainsStr(blacklist.FieldsBlacklist, ref.ForeignKey.DBName)) {
columns = append(columns, ref.ForeignKey)
}
}
}
Expand Down
17 changes: 15 additions & 2 deletions search.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package filter

import (
"fmt"
"strings"

"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)

Expand Down Expand Up @@ -43,8 +47,17 @@ func (s *Search) Scope(schema *schema.Schema) func(*gorm.DB) *gorm.DB {
Or: true,
}

tableName := tx.Statement.Quote(tableFromJoinName(sch.Table, joinName)) + "."
searchQuery = s.Operator.Function(searchQuery, filter, tableName+tx.Statement.Quote(f.DBName), f.DataType)
table := tx.Statement.Quote(tableFromJoinName(sch.Table, joinName))

computed := f.StructField.Tag.Get("computed")
var fieldExpr string
if computed != "" {
fieldExpr = fmt.Sprintf("(%s)", strings.ReplaceAll(computed, clause.CurrentTable, table))
} else {
fieldExpr = table + "." + tx.Statement.Quote(f.DBName)
}

searchQuery = s.Operator.Function(searchQuery, filter, fieldExpr, f.DataType)
}

return tx.Where(searchQuery)
Expand Down
91 changes: 91 additions & 0 deletions search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,3 +312,94 @@ func TestSeachScopeWithJoinNestedRelation(t *testing.T) {
}
assert.Equal(t, expected, db.Statement.Clauses)
}

type SearchTestModelComputedRelation struct {
Name string
Computed string `computed:"~~~ct~~~.computedcolumnrelation"`
ID uint
ParentID uint
}

type SearchTestModelComputed struct {
Relation *SearchTestModelComputedRelation `gorm:"foreignKey:ParentID"`
Name string
Computed string `computed:"~~~ct~~~.computedcolumn"`
ID uint
}

func TestSearchScopeComputed(t *testing.T) {
db, _ := gorm.Open(&tests.DummyDialector{}, nil)
search := &Search{
Fields: []string{"computed", "Relation.computed"},
Query: "My Query",
Operator: Operators["$eq"],
}

results := []*SearchTestModelComputed{}
schema, err := parseModel(db, &results)
if !assert.Nil(t, err) {
return
}

db = db.Model(&results).Scopes(search.Scope(schema)).Find(&results)
expected := map[string]clause.Clause{
"WHERE": {
Name: "WHERE",
Expression: clause.Where{
Exprs: []clause.Expression{
clause.AndConditions{
Exprs: []clause.Expression{
clause.OrConditions{
Exprs: []clause.Expression{
clause.Expr{
SQL: "(`search_test_model_computeds`.computedcolumn) = ?",
Vars: []interface{}{"My Query"},
WithoutParentheses: false,
},
},
},
clause.OrConditions{
Exprs: []clause.Expression{
clause.Expr{
SQL: "(`Relation`.computedcolumnrelation) = ?",
Vars: []interface{}{"My Query"},
WithoutParentheses: false,
},
},
},
},
},
},
},
},
"FROM": {
Name: "FROM",
Expression: clause.From{
Joins: []clause.Join{
{
Type: clause.LeftJoin,
Table: clause.Table{
Name: "search_test_model_computed_relations",
Alias: "Relation",
},
ON: clause.Where{
Exprs: []clause.Expression{
clause.Eq{
Column: clause.Column{
Table: "search_test_model_computeds",
Name: "id",
},
Value: clause.Column{
Table: "Relation",
Name: "parent_id",
},
},
},
},
},
},
},
},
}
assert.Equal(t, expected, db.Statement.Clauses)
}
27 changes: 19 additions & 8 deletions settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"sync"

"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"goyave.dev/goyave/v4"
"goyave.dev/goyave/v4/database"
Expand Down Expand Up @@ -217,7 +218,9 @@ func (s *Settings) applySearch(request *goyave.Request, schema *schema.Schema) *
if ok {
fields := s.FieldsSearch
if fields == nil {
fields = s.getSelectableFields(schema.FieldsByDBName)
for _, f := range s.getSelectableFields(schema.FieldsByDBName) {
fields = append(fields, f.DBName)
}
}

operator := s.SearchOperator
Expand All @@ -237,22 +240,22 @@ func (s *Settings) applySearch(request *goyave.Request, schema *schema.Schema) *
return nil
}

func (b *Blacklist) getSelectableFields(fields map[string]*schema.Field) []string {
func (b *Blacklist) getSelectableFields(fields map[string]*schema.Field) []*schema.Field {
blacklist := []string{}
if b.FieldsBlacklist != nil {
blacklist = b.FieldsBlacklist
}
columns := make([]string, 0, len(fields))
for k := range fields {
columns := make([]*schema.Field, 0, len(fields))
for k, f := range fields {
if !sliceutil.ContainsStr(blacklist, k) {
columns = append(columns, k)
columns = append(columns, f)
}
}

return columns
}

func selectScope(table string, fields []string, override bool) func(*gorm.DB) *gorm.DB {
func selectScope(table string, fields []*schema.Field, override bool) func(*gorm.DB) *gorm.DB {
return func(tx *gorm.DB) *gorm.DB {

if fields == nil {
Expand All @@ -264,9 +267,17 @@ func selectScope(table string, fields []string, override bool) func(*gorm.DB) *g
fieldsWithTableName = []string{"1"}
} else {
fieldsWithTableName = make([]string, 0, len(fields))
tableName := tx.Statement.Quote(table) + "."
tableName := tx.Statement.Quote(table)
for _, f := range fields {
fieldsWithTableName = append(fieldsWithTableName, tableName+tx.Statement.Quote(f))
computed := f.StructField.Tag.Get("computed")
var fieldExpr string
if computed != "" {
fieldExpr = fmt.Sprintf("(%s) %s", strings.ReplaceAll(computed, clause.CurrentTable, tableName), tx.Statement.Quote(f.DBName))
} else {
fieldExpr = tableName + "." + tx.Statement.Quote(f.DBName)
}

fieldsWithTableName = append(fieldsWithTableName, fieldExpr)
}
}

Expand Down
Loading

0 comments on commit 71df4a7

Please sign in to comment.