@@ -3,6 +3,7 @@ package db
3
3
import (
4
4
"fmt"
5
5
"reflect"
6
+ "slices"
6
7
"strings"
7
8
8
9
"github.com/jinzhu/inflection"
@@ -11,6 +12,11 @@ import (
11
12
"gorm.io/gorm"
12
13
)
13
14
15
+ const (
16
+ invalidFieldNameMsg = "%s is not a valid field name"
17
+ disallowedFieldNameMsg = "%s is a disallowed field name"
18
+ )
19
+
14
20
// Check if a field name starts with properties.
15
21
func startsWithProperties (s string ) bool {
16
22
return strings .HasPrefix (s , "properties." )
@@ -33,34 +39,33 @@ func hasProperty(n tsl.Node) bool {
33
39
}
34
40
35
41
// getField gets the sql field associated with a name.
36
- func getField (name string , disallowedFields map [string ]string ) (field string , err * errors.ServiceError ) {
42
+ func getField (
43
+ name string ,
44
+ disallowedFields []string ,
45
+ apiToModel map [string ]string ,
46
+ ) (field string , err * errors.ServiceError ) {
37
47
// We want to accept names with trailing and leading spaces
38
48
trimmedName := strings .Trim (name , " " )
39
49
40
- // Check for properties ->> '<some field name>'
41
- if strings .HasPrefix (trimmedName , "properties ->>" ) {
42
- field = trimmedName
43
- return
50
+ mappedField , ok := apiToModel [trimmedName ]
51
+ if ! ok {
52
+ return "" , errors .BadRequest (invalidFieldNameMsg , name )
44
53
}
45
54
46
55
// Check for nested field, e.g., subscription_labels.key
47
- checkName := trimmedName
48
- fieldParts := strings .Split (trimmedName , "." )
56
+ checkName := mappedField
57
+ fieldParts := strings .Split (checkName , "." )
49
58
if len (fieldParts ) > 2 {
50
- err = errors .BadRequest ("%s is not a valid field name" , name )
59
+ err = errors .BadRequest (invalidFieldNameMsg , name )
51
60
return
52
61
}
53
- if len (fieldParts ) > 1 {
54
- checkName = fieldParts [1 ]
55
- }
56
62
57
63
// Check for allowed fields
58
- _ , ok := disallowedFields [checkName ]
59
- if ok {
60
- err = errors .BadRequest ("%s is not a valid field name" , name )
64
+ if slices .Contains (disallowedFields , checkName ) {
65
+ err = errors .BadRequest (disallowedFieldNameMsg , name )
61
66
return
62
67
}
63
- field = trimmedName
68
+ field = checkName
64
69
return
65
70
}
66
71
@@ -102,7 +107,8 @@ func propertiesNodeConverter(n tsl.Node) tsl.Node {
102
107
// b. replace the field name with the SQL column name.
103
108
func FieldNameWalk (
104
109
n tsl.Node ,
105
- disallowedFields map [string ]string ) (newNode tsl.Node , err * errors.ServiceError ) {
110
+ disallowedFields []string ,
111
+ apiToModel map [string ]string ) (newNode tsl.Node , err * errors.ServiceError ) {
106
112
107
113
var field string
108
114
var l , r tsl.Node
@@ -124,7 +130,7 @@ func FieldNameWalk(
124
130
}
125
131
126
132
// Check field name in the disallowedFields field names.
127
- field , err = getField (userFieldName , disallowedFields )
133
+ field , err = getField (userFieldName , disallowedFields , apiToModel )
128
134
if err != nil {
129
135
return
130
136
}
@@ -137,7 +143,7 @@ func FieldNameWalk(
137
143
default :
138
144
// o/w continue walking the tree.
139
145
if n .Left != nil {
140
- l , err = FieldNameWalk (n .Left .(tsl.Node ), disallowedFields )
146
+ l , err = FieldNameWalk (n .Left .(tsl.Node ), disallowedFields , apiToModel )
141
147
if err != nil {
142
148
return
143
149
}
@@ -148,7 +154,7 @@ func FieldNameWalk(
148
154
switch v := n .Right .(type ) {
149
155
case tsl.Node :
150
156
// It's a regular node, just add it.
151
- r , err = FieldNameWalk (v , disallowedFields )
157
+ r , err = FieldNameWalk (v , disallowedFields , apiToModel )
152
158
if err != nil {
153
159
return
154
160
}
@@ -162,7 +168,7 @@ func FieldNameWalk(
162
168
163
169
// Add all nodes in the right side array.
164
170
for _ , e := range v {
165
- r , err = FieldNameWalk (e , disallowedFields )
171
+ r , err = FieldNameWalk (e , disallowedFields , apiToModel )
166
172
if err != nil {
167
173
return
168
174
}
@@ -189,23 +195,26 @@ func FieldNameWalk(
189
195
}
190
196
191
197
// cleanOrderBy takes the orderBy arg and cleans it.
192
- func cleanOrderBy (userArg string , disallowedFields map [string ]string ) (orderBy string , err * errors.ServiceError ) {
198
+ func cleanOrderBy (userArg string ,
199
+ disallowedFields []string ,
200
+ apiToModel map [string ]string ,
201
+ tableName string ) (orderBy string , err * errors.ServiceError ) {
193
202
var orderField string
194
203
195
204
// We want to accept user params with trailing and leading spaces
196
205
trimedName := strings .Trim (userArg , " " )
197
206
198
207
// Each OrderBy can be a "<field-name>" or a "<field-name> asc|desc"
199
208
order := strings .Split (trimedName , " " )
200
- direction := "none valid"
201
-
202
- if len (order ) == 1 {
203
- orderField , err = getField (order [0 ], disallowedFields )
204
- direction = "asc"
205
- } else if len (order ) == 2 {
206
- orderField , err = getField (order [0 ], disallowedFields )
209
+ direction := "asc"
210
+ if len (order ) == 2 {
207
211
direction = order [1 ]
208
212
}
213
+ field := order [0 ]
214
+ if orderParts := strings .Split (order [0 ], "." ); len (orderParts ) == 1 {
215
+ field = fmt .Sprintf ("%s.%s" , tableName , field )
216
+ }
217
+ orderField , err = getField (field , disallowedFields , apiToModel )
209
218
if err != nil || (direction != "asc" && direction != "desc" ) {
210
219
err = errors .BadRequest ("bad order value '%s'" , userArg )
211
220
return
@@ -218,13 +227,15 @@ func cleanOrderBy(userArg string, disallowedFields map[string]string) (orderBy s
218
227
// ArgsToOrderBy returns cleaned orderBy list.
219
228
func ArgsToOrderBy (
220
229
orderByArgs []string ,
221
- disallowedFields map [string ]string ) (orderBy []string , err * errors.ServiceError ) {
230
+ disallowedFields []string ,
231
+ apiToModel map [string ]string ,
232
+ tableName string ) (orderBy []string , err * errors.ServiceError ) {
222
233
223
234
var order string
224
235
if len (orderByArgs ) != 0 {
225
236
orderBy = []string {}
226
237
for _ , o := range orderByArgs {
227
- order , err = cleanOrderBy (o , disallowedFields )
238
+ order , err = cleanOrderBy (o , disallowedFields , apiToModel , tableName )
228
239
if err != nil {
229
240
return
230
241
}
0 commit comments