Skip to content

Commit

Permalink
Merge pull request #183 from mercari/fix-msg-args-type-conversion
Browse files Browse the repository at this point in the history
Fix message argument type conversion
  • Loading branch information
goccy authored May 27, 2024
2 parents a8eb59c + eef5a65 commit 737b37b
Show file tree
Hide file tree
Showing 9 changed files with 432 additions and 116 deletions.
50 changes: 37 additions & 13 deletions generator/code_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -2803,21 +2803,25 @@ func toValue(file *File, typ *resolver.Type, value *resolver.Value) string {
func arguments(file *File, expr *resolver.VariableExpr) []*Argument {
var (
isRequestArgument bool
msgArg *resolver.Message
args []*resolver.Argument
)
switch {
case expr.Call != nil:
isRequestArgument = true
args = expr.Call.Request.Args
case expr.Message != nil:
msg := expr.Message.Message
if msg.Rule != nil {
msgArg = msg.Rule.MessageArgument
}
args = expr.Message.Args
case expr.By != nil:
return nil
}

var generateArgs []*Argument
for _, arg := range args {
for _, generatedArg := range argument(file, arg) {
for _, generatedArg := range argument(file, msgArg, arg) {
protofmt := arg.ProtoFormat(resolver.DefaultProtoFormatOption, isRequestArgument)
if protofmt != "" {
generatedArg.ProtoComment = "// " + protofmt
Expand All @@ -2828,7 +2832,7 @@ func arguments(file *File, expr *resolver.VariableExpr) []*Argument {
return generateArgs
}

func argument(file *File, arg *resolver.Argument) []*Argument {
func argument(file *File, msgArg *resolver.Message, arg *resolver.Argument) []*Argument {
if arg.Value.Const != nil {
return []*Argument{
{
Expand Down Expand Up @@ -2861,7 +2865,7 @@ func argument(file *File, arg *resolver.Argument) []*Argument {
})
}
}
fromType := arg.Value.CEL.Out
fromType := arg.Value.Type()
var toType *resolver.Type
if arg.Type != nil {
toType = arg.Type
Expand All @@ -2873,25 +2877,37 @@ func argument(file *File, arg *resolver.Argument) []*Argument {
fromText := file.toTypeText(fromType)

var (
argValue = "v"
zeroValue string
argType string
requiredCast = requiredCast(fromType, toType)
argValue = "v"
zeroValue string
argType string
isRequiredCast bool
)
switch fromType.Kind {
case types.Message:
zeroValue = toMakeZeroValue(file, fromType)
argType = fromText
if requiredCast {
isRequiredCast = requiredCast(fromType, toType)
if isRequiredCast {
castFuncName := castFuncName(fromType, toType)
argValue = fmt.Sprintf("s.%s(%s)", castFuncName, argValue)
}
case types.Enum:
zeroValue = toMakeZeroValue(file, fromType)
argType = fromText
if requiredCast {
castFuncName := castFuncName(fromType, toType)
argValue = fmt.Sprintf("s.%s(%s)", castFuncName, argValue)
if msgArg != nil && arg.Name != "" {
msgArgField := msgArg.Field(arg.Name)
isRequiredCast = msgArgField != nil && msgArgField.Type.Kind != toType.Kind
if isRequiredCast {
castFuncName := castFuncName(fromType, msgArgField.Type)
argValue = fmt.Sprintf("s.%s(%s)", castFuncName, argValue)
}
}
if !isRequiredCast {
isRequiredCast = requiredCast(fromType, toType)
if isRequiredCast {
castFuncName := castFuncName(fromType, toType)
argValue = fmt.Sprintf("s.%s(%s)", castFuncName, argValue)
}
}
default:
// Since fromType is a primitive type, type conversion is possible on the CEL side.
Expand All @@ -2903,6 +2919,14 @@ func argument(file *File, arg *resolver.Argument) []*Argument {
} else {
argType = toText
}
if msgArg != nil && arg.Name != "" {
msgArgField := msgArg.Field(arg.Name)
isRequiredCast = msgArgField != nil && msgArgField.Type.Kind != toType.Kind
if isRequiredCast {
castFuncName := castFuncName(toType, msgArgField.Type)
argValue = fmt.Sprintf("s.%s(%s)", castFuncName, argValue)
}
}
}
return []*Argument{
{
Expand All @@ -2915,7 +2939,7 @@ func argument(file *File, arg *resolver.Argument) []*Argument {
OneofName: oneofName,
OneofFieldName: oneofFieldName,
If: arg.If,
RequiredCast: requiredCast,
RequiredCast: isRequiredCast,
},
}
}
Expand Down
Loading

0 comments on commit 737b37b

Please sign in to comment.