Skip to content

Commit 75ed180

Browse files
authored
feat(core): add argspec util to get corresponding type from argstype (#2850)
1 parent 8d354fc commit 75ed180

File tree

3 files changed

+71
-0
lines changed

3 files changed

+71
-0
lines changed

internal/core/arg_specs.go

+9
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@ package core
33
import (
44
"context"
55
"fmt"
6+
"reflect"
67
"strings"
78

89
"github.com/scaleway/scaleway-sdk-go/scw"
10+
"github.com/scaleway/scaleway-sdk-go/strcase"
911
"github.com/scaleway/scaleway-sdk-go/validation"
1012
)
1113

@@ -119,6 +121,13 @@ func (a *ArgSpec) ConflictWith(b *ArgSpec) bool {
119121
(a.OneOfGroup == b.OneOfGroup)
120122
}
121123

124+
// GetArgsTypeField returns the type of the argument in the given ArgsType
125+
func (a *ArgSpec) GetArgsTypeField(argsType reflect.Type) (reflect.Type, error) {
126+
argSpecGoName := strcase.ToPublicGoName(a.Name)
127+
128+
return getTypeForFieldByName(argsType, strings.Split(argSpecGoName, "."))
129+
}
130+
122131
type DefaultFunc func(ctx context.Context) (value string, doc string)
123132

124133
func ZoneArgSpec(zones ...scw.Zone) *ArgSpec {

internal/core/arg_specs_test.go

+33
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package core
22

33
import (
4+
"reflect"
45
"testing"
56

67
"github.com/alecthomas/assert"
@@ -36,3 +37,35 @@ func TestOneOf(t *testing.T) {
3637
assert.False(t, a.ConflictWith(c))
3738
assert.False(t, e.ConflictWith(e))
3839
}
40+
41+
func TestArgSpecGetArgsTypeField(t *testing.T) {
42+
data := struct {
43+
Field string
44+
FieldStruct struct {
45+
NestedField int
46+
}
47+
FieldSlice []float32
48+
FieldMap map[string]bool
49+
}{}
50+
dataType := reflect.TypeOf(data)
51+
52+
fieldSpec := ArgSpec{Name: "field"}
53+
typ, err := fieldSpec.GetArgsTypeField(dataType)
54+
assert.Nil(t, err)
55+
assert.Equal(t, reflect.TypeOf("string"), typ, "%s is not string", typ.Name())
56+
57+
fieldSpec = ArgSpec{Name: "field-struct.nested-field"}
58+
typ, err = fieldSpec.GetArgsTypeField(dataType)
59+
assert.Nil(t, err)
60+
assert.Equal(t, reflect.TypeOf(int(1)), typ, "%s is not int", typ.Name())
61+
62+
fieldSpec = ArgSpec{Name: "field-slice.{index}"}
63+
typ, err = fieldSpec.GetArgsTypeField(dataType)
64+
assert.Nil(t, err)
65+
assert.Equal(t, reflect.TypeOf(float32(1)), typ, "%s is not float32", typ.Name())
66+
67+
fieldSpec = ArgSpec{Name: "field-map.{key}"}
68+
typ, err = fieldSpec.GetArgsTypeField(dataType)
69+
assert.Nil(t, err)
70+
assert.Equal(t, reflect.TypeOf(true), typ, "%s is not bool", typ.Name())
71+
}

internal/core/reflect.go

+29
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,32 @@ func getValuesForFieldByName(value reflect.Value, parts []string) (values []refl
109109

110110
return nil, fmt.Errorf("case is not handled")
111111
}
112+
113+
// getTypeForFieldByName recursively search for fields in an ArgsType
114+
// The search is based on the name of the field.
115+
func getTypeForFieldByName(value reflect.Type, parts []string) (reflect.Type, error) {
116+
if len(parts) == 0 {
117+
return value, nil
118+
}
119+
120+
switch value.Kind() {
121+
case reflect.Ptr:
122+
return getTypeForFieldByName(value.Elem(), parts)
123+
124+
case reflect.Slice:
125+
return getTypeForFieldByName(value.Elem(), parts[1:])
126+
127+
case reflect.Map:
128+
return getTypeForFieldByName(value.Elem(), parts[1:])
129+
130+
case reflect.Struct:
131+
fieldName := strcase.ToPublicGoName(parts[0])
132+
field, hasField := value.FieldByName(fieldName)
133+
if !hasField {
134+
return nil, fmt.Errorf("field %v does not exist for %v", fieldName, value.Name())
135+
}
136+
return getTypeForFieldByName(field.Type, parts[1:])
137+
}
138+
139+
return nil, fmt.Errorf("type kind %s is not handled", value.Kind().String())
140+
}

0 commit comments

Comments
 (0)