Skip to content

Commit d74fd81

Browse files
committed
fix: allows to search for exposed nested resources and doesn't expose internal sql columns
1 parent 6b6f631 commit d74fd81

File tree

6 files changed

+281
-87
lines changed

6 files changed

+281
-87
lines changed

pkg/dao/dinosaur.go

+20
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,28 @@ import (
77

88
"github.com/openshift-online/rh-trex/pkg/api"
99
"github.com/openshift-online/rh-trex/pkg/db"
10+
"github.com/openshift-online/rh-trex/pkg/util"
1011
)
1112

13+
var (
14+
dinosaurTableName = util.ToSnakeCase(api.DinosaurTypeName) + "s"
15+
dinosaurColumns = []string{
16+
"id",
17+
"created_at",
18+
"updated_at",
19+
"species",
20+
}
21+
)
22+
23+
func DinosaurApiToModel() TableMappingRelation {
24+
result := map[string]string{}
25+
applyBaseMapping(result, dinosaurColumns, dinosaurTableName)
26+
return TableMappingRelation{
27+
Mapping: result,
28+
relationTableName: dinosaurTableName,
29+
}
30+
}
31+
1232
type DinosaurDao interface {
1333
Get(ctx context.Context, id string) (*api.Dinosaur, error)
1434
Create(ctx context.Context, dinosaur *api.Dinosaur) (*api.Dinosaur, error)

pkg/dao/generic.go

+36
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package dao
22

33
import (
44
"context"
5+
"fmt"
56
"strings"
67

78
"github.com/jinzhu/inflection"
@@ -10,6 +11,41 @@ import (
1011
"github.com/openshift-online/rh-trex/pkg/db"
1112
)
1213

14+
type TableMappingRelation struct {
15+
Mapping map[string]string
16+
relationTableName string
17+
}
18+
19+
type relationMapping func() TableMappingRelation
20+
21+
func applyBaseMapping(result map[string]string, columns []string, tableName string) {
22+
for _, c := range columns {
23+
mappingKey := c
24+
mappingValue := fmt.Sprintf("%s.%s", tableName, c)
25+
columnParts := strings.Split(c, ".")
26+
if len(columnParts) == 1 {
27+
mappingKey = mappingValue
28+
}
29+
if len(columnParts) == 2 {
30+
mappingValue = strings.Split(mappingKey, ".")[1]
31+
}
32+
result[mappingKey] = mappingValue
33+
}
34+
}
35+
36+
func applyRelationMapping(result map[string]string, relations []relationMapping) {
37+
for _, relation := range relations {
38+
tableMappingRelation := relation()
39+
for k, v := range tableMappingRelation.Mapping {
40+
if _, ok := result[k]; ok {
41+
result[tableMappingRelation.relationTableName+"."+k] = v
42+
} else {
43+
result[k] = v
44+
}
45+
}
46+
}
47+
}
48+
1349
type Where struct {
1450
sql string
1551
values []any

pkg/dao/generic_test.go

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package dao
2+
3+
import (
4+
"fmt"
5+
"strings"
6+
7+
. "github.com/onsi/ginkgo/v2/dsl/core"
8+
. "github.com/onsi/gomega"
9+
)
10+
11+
var _ = Describe("applyBaseMapping", func() {
12+
It("generates base mapping", func() {
13+
result := map[string]string{}
14+
applyBaseMapping(result, []string{"id", "created_at", "column1", "nested.field"}, "test_table")
15+
for k, v := range result {
16+
if strings.HasPrefix(k, "test_table") {
17+
Expect(k).To(Equal(v))
18+
continue
19+
}
20+
// nested fields from table
21+
i := strings.Index(k, ".")
22+
Expect(k[i+1:]).To(Equal(v))
23+
}
24+
})
25+
})
26+
27+
var _ = Describe("applyRelationMapping", func() {
28+
It("generates relation mapping", func() {
29+
result := map[string]string{}
30+
applyBaseMapping(result, []string{"id", "created_at", "column1", "nested.field"}, "base_table")
31+
applyRelationMapping(result, []relationMapping{
32+
func() TableMappingRelation {
33+
result := map[string]string{}
34+
applyBaseMapping(result, []string{"id", "created_at", "column1", "nested.field"}, "relation_table")
35+
return TableMappingRelation{
36+
relationTableName: "relation_table",
37+
Mapping: result,
38+
}
39+
},
40+
})
41+
for k, v := range result {
42+
if strings.HasPrefix(k, "base_table") {
43+
Expect(k).To(Equal(v))
44+
continue
45+
}
46+
if strings.HasPrefix(k, "relation_table") {
47+
if c := strings.Count(k, "."); c > 1 {
48+
i := strings.Index(k, ".")
49+
i = strings.Index(k[i+1:], ".") + i
50+
Expect(k[i+2:]).To(Equal(v))
51+
continue
52+
}
53+
Expect(k).To(Equal(v))
54+
continue
55+
}
56+
57+
// nested fields from base table
58+
i := strings.Index(k, ".")
59+
Expect(k[i+1:]).To(Equal(v))
60+
fmt.Println(k, v)
61+
}
62+
})
63+
})

pkg/db/sql_helpers.go

+41-30
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package db
33
import (
44
"fmt"
55
"reflect"
6+
"slices"
67
"strings"
78

89
"github.com/jinzhu/inflection"
@@ -11,6 +12,11 @@ import (
1112
"gorm.io/gorm"
1213
)
1314

15+
const (
16+
invalidFieldNameMsg = "%s is not a valid field name"
17+
disallowedFieldNameMsg = "%s is a disallowed field name"
18+
)
19+
1420
// Check if a field name starts with properties.
1521
func startsWithProperties(s string) bool {
1622
return strings.HasPrefix(s, "properties.")
@@ -33,34 +39,33 @@ func hasProperty(n tsl.Node) bool {
3339
}
3440

3541
// getField gets the sql field associated with a name.
36-
func getField(name string, disallowedFields map[string]string) (field string, err *errors.ServiceError) {
42+
func getField(
43+
name string,
44+
disallowedFields []string,
45+
apiToModel map[string]string,
46+
) (field string, err *errors.ServiceError) {
3747
// We want to accept names with trailing and leading spaces
3848
trimmedName := strings.Trim(name, " ")
3949

40-
// Check for properties ->> '<some field name>'
41-
if strings.HasPrefix(trimmedName, "properties ->>") {
42-
field = trimmedName
43-
return
50+
mappedField, ok := apiToModel[trimmedName]
51+
if !ok {
52+
return "", errors.BadRequest(invalidFieldNameMsg, name)
4453
}
4554

4655
// Check for nested field, e.g., subscription_labels.key
47-
checkName := trimmedName
48-
fieldParts := strings.Split(trimmedName, ".")
56+
checkName := mappedField
57+
fieldParts := strings.Split(checkName, ".")
4958
if len(fieldParts) > 2 {
50-
err = errors.BadRequest("%s is not a valid field name", name)
59+
err = errors.BadRequest(invalidFieldNameMsg, name)
5160
return
5261
}
53-
if len(fieldParts) > 1 {
54-
checkName = fieldParts[1]
55-
}
5662

5763
// Check for allowed fields
58-
_, ok := disallowedFields[checkName]
59-
if ok {
60-
err = errors.BadRequest("%s is not a valid field name", name)
64+
if slices.Contains(disallowedFields, checkName) {
65+
err = errors.BadRequest(disallowedFieldNameMsg, name)
6166
return
6267
}
63-
field = trimmedName
68+
field = checkName
6469
return
6570
}
6671

@@ -102,7 +107,8 @@ func propertiesNodeConverter(n tsl.Node) tsl.Node {
102107
// b. replace the field name with the SQL column name.
103108
func FieldNameWalk(
104109
n tsl.Node,
105-
disallowedFields map[string]string) (newNode tsl.Node, err *errors.ServiceError) {
110+
disallowedFields []string,
111+
apiToModel map[string]string) (newNode tsl.Node, err *errors.ServiceError) {
106112

107113
var field string
108114
var l, r tsl.Node
@@ -124,7 +130,7 @@ func FieldNameWalk(
124130
}
125131

126132
// Check field name in the disallowedFields field names.
127-
field, err = getField(userFieldName, disallowedFields)
133+
field, err = getField(userFieldName, disallowedFields, apiToModel)
128134
if err != nil {
129135
return
130136
}
@@ -137,7 +143,7 @@ func FieldNameWalk(
137143
default:
138144
// o/w continue walking the tree.
139145
if n.Left != nil {
140-
l, err = FieldNameWalk(n.Left.(tsl.Node), disallowedFields)
146+
l, err = FieldNameWalk(n.Left.(tsl.Node), disallowedFields, apiToModel)
141147
if err != nil {
142148
return
143149
}
@@ -148,7 +154,7 @@ func FieldNameWalk(
148154
switch v := n.Right.(type) {
149155
case tsl.Node:
150156
// It's a regular node, just add it.
151-
r, err = FieldNameWalk(v, disallowedFields)
157+
r, err = FieldNameWalk(v, disallowedFields, apiToModel)
152158
if err != nil {
153159
return
154160
}
@@ -162,7 +168,7 @@ func FieldNameWalk(
162168

163169
// Add all nodes in the right side array.
164170
for _, e := range v {
165-
r, err = FieldNameWalk(e, disallowedFields)
171+
r, err = FieldNameWalk(e, disallowedFields, apiToModel)
166172
if err != nil {
167173
return
168174
}
@@ -189,23 +195,26 @@ func FieldNameWalk(
189195
}
190196

191197
// cleanOrderBy takes the orderBy arg and cleans it.
192-
func cleanOrderBy(userArg string, disallowedFields map[string]string) (orderBy string, err *errors.ServiceError) {
198+
func cleanOrderBy(userArg string,
199+
disallowedFields []string,
200+
apiToModel map[string]string,
201+
tableName string) (orderBy string, err *errors.ServiceError) {
193202
var orderField string
194203

195204
// We want to accept user params with trailing and leading spaces
196205
trimedName := strings.Trim(userArg, " ")
197206

198207
// Each OrderBy can be a "<field-name>" or a "<field-name> asc|desc"
199208
order := strings.Split(trimedName, " ")
200-
direction := "none valid"
201-
202-
if len(order) == 1 {
203-
orderField, err = getField(order[0], disallowedFields)
204-
direction = "asc"
205-
} else if len(order) == 2 {
206-
orderField, err = getField(order[0], disallowedFields)
209+
direction := "asc"
210+
if len(order) == 2 {
207211
direction = order[1]
208212
}
213+
field := order[0]
214+
if orderParts := strings.Split(order[0], "."); len(orderParts) == 1 {
215+
field = fmt.Sprintf("%s.%s", tableName, field)
216+
}
217+
orderField, err = getField(field, disallowedFields, apiToModel)
209218
if err != nil || (direction != "asc" && direction != "desc") {
210219
err = errors.BadRequest("bad order value '%s'", userArg)
211220
return
@@ -218,13 +227,15 @@ func cleanOrderBy(userArg string, disallowedFields map[string]string) (orderBy s
218227
// ArgsToOrderBy returns cleaned orderBy list.
219228
func ArgsToOrderBy(
220229
orderByArgs []string,
221-
disallowedFields map[string]string) (orderBy []string, err *errors.ServiceError) {
230+
disallowedFields []string,
231+
apiToModel map[string]string,
232+
tableName string) (orderBy []string, err *errors.ServiceError) {
222233

223234
var order string
224235
if len(orderByArgs) != 0 {
225236
orderBy = []string{}
226237
for _, o := range orderByArgs {
227-
order, err = cleanOrderBy(o, disallowedFields)
238+
order, err = cleanOrderBy(o, disallowedFields, apiToModel, tableName)
228239
if err != nil {
229240
return
230241
}

0 commit comments

Comments
 (0)