Skip to content

Commit 3790b20

Browse files
committed
feat: check default value type
1 parent e1caf7d commit 3790b20

File tree

3 files changed

+181
-27
lines changed

3 files changed

+181
-27
lines changed

generator/golang/resolver.go

+165-27
Original file line numberDiff line numberDiff line change
@@ -157,20 +157,29 @@ func (r *Resolver) getContainerTypeName(g *Scope, t *parser.Type) (name string,
157157
// getIDValue returns the literal representation of a const value.
158158
// The extra must be associated with g and from a const value that has
159159
// type parser.ConstType_ConstIdentifier.
160-
func (r *Resolver) getIDValue(g *Scope, extra *parser.ConstValueExtra) (v string, ok bool) {
160+
func (r *Resolver) getIDValue(g *Scope, extra *parser.ConstValueExtra) (v string, t *parser.Type, ok bool) {
161161
if extra.Index == -1 {
162162
if extra.IsEnum {
163163
enum, ok := g.ast.GetEnum(extra.Sel)
164164
if !ok {
165-
return "", false
165+
return "", t, false
166166
}
167167
if en := g.Enum(enum.Name); en != nil {
168168
if ev := en.Value(extra.Name); ev != nil {
169169
v = ev.GoName().String()
170+
t = &parser.Type{
171+
Name: enum.Name,
172+
Category: parser.Category_Enum,
173+
}
170174
}
171175
}
172176
} else {
173177
v = g.globals.Get(extra.Name)
178+
con, ok := g.ast.GetConstant(extra.Name)
179+
if !ok {
180+
return "", t, false
181+
}
182+
t = con.Type
174183
}
175184
} else {
176185
g = g.includes[extra.Index].Scope
@@ -186,7 +195,7 @@ func (r *Resolver) getIDValue(g *Scope, extra *parser.ConstValueExtra) (v string
186195
pkg := r.root.includeIDL(r.util, g.ast)
187196
v = pkg + "." + v
188197
}
189-
return v, v != ""
198+
return v, t, v != ""
190199
}
191200

192201
// ResolveConst returns the initialization code for a constant or a default value.
@@ -239,10 +248,14 @@ func (r *Resolver) onBool(g *Scope, name string, t *parser.Type, v *parser.Const
239248
return s, nil
240249
}
241250

242-
if val, ok := r.getIDValue(g, v.Extra); ok {
243-
return val, nil
251+
val, cate, ok := r.getIDValue(g, v.Extra)
252+
if !ok {
253+
return "", fmt.Errorf("undefined value: %q", s)
244254
}
245-
return "", fmt.Errorf("undefined value: %q", s)
255+
if err := r.typeMatch(t, cate, name); err != nil {
256+
return "", err
257+
}
258+
return val, nil
246259
}
247260
return "", errTypeMissMatch(name, t, v)
248261
}
@@ -260,12 +273,16 @@ func (r *Resolver) onInt(g *Scope, name string, t *parser.Type, v *parser.ConstV
260273
if s == "false" {
261274
return "0", nil
262275
}
263-
if val, ok := r.getIDValue(g, v.Extra); ok {
264-
goType, _ := r.getTypeName(g, t)
265-
val = fmt.Sprintf("%s(%s)", goType, val)
266-
return val, nil
276+
val, cate, ok := r.getIDValue(g, v.Extra)
277+
if !ok {
278+
return "", fmt.Errorf("undefined value: %q", s)
267279
}
268-
return "", fmt.Errorf("undefined value: %q", s)
280+
if err := r.typeMatch(t, cate, name); err != nil {
281+
return "", err
282+
}
283+
goType, _ := r.getTypeName(g, t)
284+
val = fmt.Sprintf("%s(%s)", goType, val)
285+
return val, nil
269286
}
270287
return "", errTypeMissMatch(name, t, v)
271288
}
@@ -286,10 +303,14 @@ func (r *Resolver) onDouble(g *Scope, name string, t *parser.Type, v *parser.Con
286303
if s == "false" {
287304
return "0.0", nil
288305
}
289-
if val, ok := r.getIDValue(g, v.Extra); ok {
290-
return val, nil
306+
val, cate, ok := r.getIDValue(g, v.Extra)
307+
if !ok {
308+
return "", fmt.Errorf("undefined value: %q", s)
309+
}
310+
if err := r.typeMatch(t, cate, name); err != nil {
311+
return "", err
291312
}
292-
return "", fmt.Errorf("undefined value: %q", s)
313+
return val, nil
293314
}
294315
return "", errTypeMissMatch(name, t, v)
295316
}
@@ -310,10 +331,14 @@ func (r *Resolver) onStrBin(g *Scope, name string, t *parser.Type, v *parser.Con
310331
break
311332
}
312333

313-
if val, ok := r.getIDValue(g, v.Extra); ok {
314-
return val, nil
334+
val, cate, ok := r.getIDValue(g, v.Extra)
335+
if !ok {
336+
return "", fmt.Errorf("undefined value: %q", s)
315337
}
316-
return "", fmt.Errorf("undefined value: %q", s)
338+
if err := r.typeMatch(t, cate, name); err != nil {
339+
return "", err
340+
}
341+
return val, nil
317342
default:
318343
}
319344
return "", errTypeMissMatch(name, t, v)
@@ -324,10 +349,14 @@ func (r *Resolver) onEnum(g *Scope, name string, t *parser.Type, v *parser.Const
324349
case parser.ConstType_ConstInt:
325350
return fmt.Sprintf("%d", v.TypedValue.GetInt()), nil
326351
case parser.ConstType_ConstIdentifier:
327-
val, ok := r.getIDValue(g, v.Extra)
328-
if ok {
329-
return val, nil
352+
val, cate, ok := r.getIDValue(g, v.Extra)
353+
if !ok {
354+
return "", fmt.Errorf("undefined value: %q", v.TypedValue.GetIdentifier())
355+
}
356+
if err := r.typeMatch(t, cate, name); err != nil {
357+
return "", err
330358
}
359+
return val, nil
331360
}
332361
return "", fmt.Errorf("expect const value for %q is a int or enum, got %+v", name, v)
333362
}
@@ -354,8 +383,14 @@ func (r *Resolver) onSetOrList(g *Scope, name string, t *parser.Type, v *parser.
354383
return fmt.Sprintf("%s{\n%s\n}", goType, strings.Join(ss, "\n")), nil
355384

356385
case parser.ConstType_ConstIdentifier:
357-
val, ok := r.getIDValue(g, v.Extra)
358-
if ok && val != "true" && val != "false" {
386+
val, cate, ok := r.getIDValue(g, v.Extra)
387+
if !ok {
388+
return "", fmt.Errorf("undefined value: %q", v.TypedValue.GetIdentifier())
389+
}
390+
if err := r.typeMatch(t, cate, name); err != nil {
391+
return "", err
392+
}
393+
if val != "true" && val != "false" {
359394
return val, nil
360395
}
361396

@@ -391,8 +426,14 @@ func (r *Resolver) onMap(g *Scope, name string, t *parser.Type, v *parser.ConstV
391426
return fmt.Sprintf("%s{\n%s\n}", goType, strings.Join(kvs, "\n")), nil
392427

393428
case parser.ConstType_ConstIdentifier:
394-
val, ok := r.getIDValue(g, v.Extra)
395-
if ok && val != "true" && val != "false" {
429+
val, cate, ok := r.getIDValue(g, v.Extra)
430+
if !ok {
431+
return "", fmt.Errorf("undefined value: %q", v.TypedValue.GetIdentifier())
432+
}
433+
if err := r.typeMatch(t, cate, name); err != nil {
434+
return "", err
435+
}
436+
if val != "true" && val != "false" {
396437
return val, nil
397438
}
398439
}
@@ -406,8 +447,14 @@ func (r *Resolver) onStructLike(g *Scope, name string, t *parser.Type, v *parser
406447
return "", err
407448
}
408449
if v.Type == parser.ConstType_ConstIdentifier {
409-
val, ok := r.getIDValue(g, v.Extra)
410-
if ok && val != "true" && val != "false" {
450+
val, cate, ok := r.getIDValue(g, v.Extra)
451+
if !ok {
452+
return "", fmt.Errorf("undefined value: %q", v.TypedValue.GetIdentifier())
453+
}
454+
if err := r.typeMatch(t, cate, name); err != nil {
455+
return "", err
456+
}
457+
if val != "true" && val != "false" {
411458
return val, nil
412459
}
413460
}
@@ -450,7 +497,7 @@ func (r *Resolver) onStructLike(g *Scope, name string, t *parser.Type, v *parser
450497
}
451498

452499
if NeedRedirect(f) {
453-
if f.Type.Category.IsBaseType() {
500+
if IsBaseType(f.Type) {
454501
// a trick to create pointers without temporary variables
455502
val = fmt.Sprintf("(&struct{x %s}{%s}).x", typ, val)
456503
}
@@ -493,6 +540,97 @@ func (r *Resolver) getStructLike(g *Scope, t *parser.Type) (f *Scope, s *parser.
493540
return
494541
}
495542

543+
func (r *Resolver) typeMatch(field *parser.Type, value *parser.Type, name string) error {
544+
if field.Category.IsBool() {
545+
if !value.Category.IsBool() {
546+
return fmt.Errorf("type of %s is not bool type", name)
547+
}
548+
return nil
549+
}
550+
if field.Category.IsInteger() {
551+
if !value.Category.IsDigital() {
552+
return fmt.Errorf("type of %s is not digital type", name)
553+
}
554+
return nil
555+
}
556+
if field.Category.IsDouble() {
557+
if !value.Category.IsDouble() {
558+
return fmt.Errorf("type of %s is not double type", name)
559+
}
560+
return nil
561+
}
562+
if field.Category.IsString() {
563+
if !value.Category.IsString() {
564+
return fmt.Errorf("type of %s is not string type", name)
565+
}
566+
return nil
567+
}
568+
if field.Category.IsBinary() {
569+
if !value.Category.IsString() && !value.Category.IsBinary() {
570+
return fmt.Errorf("type of %s is not string or binary type", name)
571+
}
572+
return nil
573+
}
574+
if field.Category.IsEnum() {
575+
if !value.Category.IsEnum() {
576+
return fmt.Errorf("type of %s is not enum type", name)
577+
}
578+
if field.NameWithReference() != value.NameWithReference() {
579+
return fmt.Errorf("enum type of %s is not %s, %s", name, field.NameWithReference(), value.NameWithReference())
580+
}
581+
return nil
582+
}
583+
if field.Category.IsSet() {
584+
if !value.Category.IsSet() {
585+
return fmt.Errorf("type of %s is not set type", name)
586+
}
587+
return r.typeMatch(field.ValueType, value.ValueType, name)
588+
}
589+
if field.Category.IsList() {
590+
if !value.Category.IsList() && !value.Category.IsSet() {
591+
return fmt.Errorf("type of %s is not set or list type", name)
592+
}
593+
return r.typeMatch(field.ValueType, value.ValueType, name)
594+
}
595+
if field.Category.IsMap() {
596+
if !value.Category.IsMap() {
597+
return fmt.Errorf("type of %s is not map type", name)
598+
}
599+
if err := r.typeMatch(field.KeyType, value.KeyType, name); err != nil {
600+
return err
601+
}
602+
return r.typeMatch(field.ValueType, value.ValueType, name)
603+
}
604+
if field.Category.IsStruct() {
605+
if !value.Category.IsStruct() {
606+
return fmt.Errorf("type of %s is not struct type", name)
607+
}
608+
if field.NameWithReference() != value.NameWithReference() {
609+
return fmt.Errorf("type of %s is not %s", name, field.NameWithReference())
610+
}
611+
return nil
612+
}
613+
if field.Category.IsUnion() {
614+
if !value.Category.IsUnion() {
615+
return fmt.Errorf("type of %s is not union type", name)
616+
}
617+
if field.NameWithReference() != value.NameWithReference() {
618+
return fmt.Errorf("type of %s is not %s", name, field.NameWithReference())
619+
}
620+
return nil
621+
}
622+
if field.Category.IsException() {
623+
if !value.Category.IsException() {
624+
return fmt.Errorf("type of %s is not exception type", name)
625+
}
626+
if field.NameWithReference() != value.NameWithReference() {
627+
return fmt.Errorf("type of %s is not %s", name, field.NameWithReference())
628+
}
629+
return nil
630+
}
631+
return fmt.Errorf("type of %s not matched %s", name, field.NameWithReference())
632+
}
633+
496634
func (r *Resolver) bin2str(t *parser.Type) *parser.Type {
497635
if t.Category == parser.Category_Binary {
498636
r := *t

parser/AST-extend-category.go

+9
Original file line numberDiff line numberDiff line change
@@ -118,3 +118,12 @@ func (p Category) IsContainerType() bool {
118118
func (p Category) IsStructLike() bool {
119119
return p == Category_Struct || p == Category_Union || p == Category_Exception
120120
}
121+
122+
func (p Category) IsInteger() bool {
123+
return p == Category_Byte || p == Category_I16 || p == Category_I32 || p == Category_I64
124+
}
125+
126+
func (p Category) IsDigital() bool {
127+
return p == Category_Byte || p == Category_I16 || p == Category_I32 ||
128+
p == Category_I64 || p == Category_Double || p == Category_Enum
129+
}

parser/AST-extend.go

+7
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,13 @@ func (t *Type) String() string {
122122
return t.Name
123123
}
124124

125+
func (t *Type) NameWithReference() string {
126+
if t.Reference != nil && t.Reference.Name != "" {
127+
return t.Reference.Name
128+
}
129+
return t.Name
130+
}
131+
125132
// GetField returns a field of the struct-like that matches the name.
126133
func (s *StructLike) GetField(name string) (*Field, bool) {
127134
for _, fi := range s.Fields {

0 commit comments

Comments
 (0)