Skip to content

Commit 0f3833c

Browse files
authored
Refactor method handling logic (#92)
The v2 semantics regarding methods has the consistent property that only one method is ever called for a particular Go type regardless of the calling context. For example, if T implements both MarshalJSON and MarshalText, then MarshalJSON will always be used and it could very well be the case that MarshalText does not exist. Unfortunately, v1 does not have this property where sometimes MarshalJSON is called, and othertimes MarshalText is called. Method calling today is handled using logic like: switch which := implementsWhich(t, ti1, ti2, ti3); which { case ti1: v = ... // expr1 case ti2: v = ... // expr2 case ti3: v = ... // expr3 } where the highest precedence interface case is taken and all other interfaces are ignored (since they can no longer matter). This programming pattern makes it difficult to fall back on other interface methods when trying to implement v1 semantics. Thus, we refactor the logic to look like: if implements(t, ti3) { v = ... // expr3 } if implements(t, ti2) { v = ... // expr2 } if implements(t, ti1) { v = ... // expr1 } In particular: * In contrast to the prior switch-statement, interfaces are checked individually in the reverse order where the lowest precedence interface is checked first. Unfortunately, switching the order of evaluation means that the diff in this commit looks unreasonably large. * If a higher precedence interface is applicable, it will override the prior marshal/unmarshal function (and thus be applicable with higher precedence). Note that every if-statement sets the same variable. There are no behavorial changes in this commit. In the code rewrite shown above, the ... expression is moved verbatim with no changes whatsoever.
1 parent ece7117 commit 0f3833c

File tree

3 files changed

+113
-101
lines changed

3 files changed

+113
-101
lines changed

arshal_default.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ func makeBytesArshaler(t reflect.Type, fncs *arshaler) *arshaler {
302302
// to forcibly treat []namedByte as a []byte.
303303
marshalArray := fncs.marshal
304304
isNamedByte := t.Elem().PkgPath() != ""
305-
hasMarshaler := implementsWhich(t.Elem(), allMarshalerTypes...) != nil
305+
hasMarshaler := implementsAny(t.Elem(), allMarshalerTypes...)
306306
fncs.marshal = func(enc *jsontext.Encoder, va addressableValue, mo *jsonopts.Struct) error {
307307
if !mo.Flags.Get(jsonflags.FormatBytesWithLegacySemantics) && isNamedByte {
308308
return marshalArray(enc, va, mo) // treat as []T or [N]T

arshal_methods.go

+108-96
Original file line numberDiff line numberDiff line change
@@ -113,54 +113,43 @@ func makeMethodArshaler(fncs *arshaler, t reflect.Type) *arshaler {
113113
return fncs
114114
}
115115

116-
// Handle custom marshaler.
117-
switch which := implementsWhich(t, jsonMarshalerV2Type, jsonMarshalerV1Type, textAppenderType, textMarshalerType); which {
118-
case jsonMarshalerV2Type:
116+
if implements(t, textMarshalerType) {
119117
fncs.nonDefault = true
120118
fncs.marshal = func(enc *jsontext.Encoder, va addressableValue, mo *jsonopts.Struct) error {
121-
xe := export.Encoder(enc)
122-
prevDepth, prevLength := xe.Tokens.DepthLength()
123-
xe.Flags.Set(jsonflags.WithinArshalCall | 1)
124-
err := va.Addr().Interface().(MarshalerV2).MarshalJSONV2(enc, mo)
125-
xe.Flags.Set(jsonflags.WithinArshalCall | 0)
126-
currDepth, currLength := xe.Tokens.DepthLength()
127-
if (prevDepth != currDepth || prevLength+1 != currLength) && err == nil {
128-
err = errNonSingularValue
129-
}
130-
if err != nil {
119+
marshaler := va.Addr().Interface().(encoding.TextMarshaler)
120+
if err := export.Encoder(enc).AppendRaw('"', false, func(b []byte) ([]byte, error) {
121+
b2, err := marshaler.MarshalText()
122+
return append(b, b2...), err
123+
}); err != nil {
131124
err = wrapSkipFunc(err, "marshal method")
132-
if xe.Flags.Get(jsonflags.ReportLegacyErrorValues) {
133-
return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalJSONV2") // unlike unmarshal, always wrapped
125+
if export.Encoder(enc).Flags.Get(jsonflags.ReportLegacyErrorValues) {
126+
return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalText") // unlike unmarshal, always wrapped
134127
}
135-
if !export.IsIOError(err) {
136-
err = newSemanticErrorWithPosition(enc, t, prevDepth, prevLength, err)
128+
if !isSemanticError(err) && !export.IsIOError(err) {
129+
err = newMarshalErrorBefore(enc, t, err)
137130
}
138131
return err
139132
}
140133
return nil
141134
}
142-
case jsonMarshalerV1Type:
143-
fncs.nonDefault = true
144-
fncs.marshal = func(enc *jsontext.Encoder, va addressableValue, mo *jsonopts.Struct) error {
145-
marshaler := va.Addr().Interface().(MarshalerV1)
146-
val, err := marshaler.MarshalJSON()
147-
if err != nil {
148-
err = wrapSkipFunc(err, "marshal method")
149-
if export.Encoder(enc).Flags.Get(jsonflags.ReportLegacyErrorValues) {
150-
return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalJSON") // unlike unmarshal, always wrapped
151-
}
152-
err = newMarshalErrorBefore(enc, t, err)
153-
return collapseSemanticErrors(err)
154-
}
155-
if err := enc.WriteValue(val); err != nil {
156-
if isSyntacticError(err) {
157-
err = newMarshalErrorBefore(enc, t, err)
135+
// TODO(https://go.dev/issue/62384): Rely on encoding.TextAppender instead.
136+
if implements(t, appenderToType) && t.PkgPath() == "net/netip" {
137+
fncs.marshal = func(enc *jsontext.Encoder, va addressableValue, mo *jsonopts.Struct) error {
138+
appender := va.Addr().Interface().(interface{ AppendTo([]byte) []byte })
139+
if err := export.Encoder(enc).AppendRaw('"', false, func(b []byte) ([]byte, error) {
140+
return appender.AppendTo(b), nil
141+
}); err != nil {
142+
if !isSemanticError(err) && !export.IsIOError(err) {
143+
err = newMarshalErrorBefore(enc, t, err)
144+
}
145+
return err
158146
}
159-
return err
147+
return nil
160148
}
161-
return nil
162149
}
163-
case textAppenderType:
150+
}
151+
152+
if implements(t, textAppenderType) {
164153
fncs.nonDefault = true
165154
fncs.marshal = func(enc *jsontext.Encoder, va addressableValue, mo *jsonopts.Struct) (err error) {
166155
appender := va.Addr().Interface().(encodingTextAppender)
@@ -176,69 +165,92 @@ func makeMethodArshaler(fncs *arshaler, t reflect.Type) *arshaler {
176165
}
177166
return nil
178167
}
179-
case textMarshalerType:
168+
}
169+
170+
if implements(t, jsonMarshalerV1Type) {
180171
fncs.nonDefault = true
181172
fncs.marshal = func(enc *jsontext.Encoder, va addressableValue, mo *jsonopts.Struct) error {
182-
marshaler := va.Addr().Interface().(encoding.TextMarshaler)
183-
if err := export.Encoder(enc).AppendRaw('"', false, func(b []byte) ([]byte, error) {
184-
b2, err := marshaler.MarshalText()
185-
return append(b, b2...), err
186-
}); err != nil {
173+
marshaler := va.Addr().Interface().(MarshalerV1)
174+
val, err := marshaler.MarshalJSON()
175+
if err != nil {
187176
err = wrapSkipFunc(err, "marshal method")
188177
if export.Encoder(enc).Flags.Get(jsonflags.ReportLegacyErrorValues) {
189-
return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalText") // unlike unmarshal, always wrapped
178+
return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalJSON") // unlike unmarshal, always wrapped
190179
}
191-
if !isSemanticError(err) && !export.IsIOError(err) {
180+
err = newMarshalErrorBefore(enc, t, err)
181+
return collapseSemanticErrors(err)
182+
}
183+
if err := enc.WriteValue(val); err != nil {
184+
if isSyntacticError(err) {
192185
err = newMarshalErrorBefore(enc, t, err)
193186
}
194187
return err
195188
}
196189
return nil
197190
}
198-
// TODO(https://go.dev/issue/62384): Rely on encoding.TextAppender instead.
199-
if implementsWhich(t, appenderToType) != nil && t.PkgPath() == "net/netip" {
200-
fncs.marshal = func(enc *jsontext.Encoder, va addressableValue, mo *jsonopts.Struct) error {
201-
appender := va.Addr().Interface().(interface{ AppendTo([]byte) []byte })
202-
if err := export.Encoder(enc).AppendRaw('"', false, func(b []byte) ([]byte, error) {
203-
return appender.AppendTo(b), nil
204-
}); err != nil {
205-
if !isSemanticError(err) && !export.IsIOError(err) {
206-
err = newMarshalErrorBefore(enc, t, err)
207-
}
208-
return err
191+
}
192+
193+
if implements(t, jsonMarshalerV2Type) {
194+
fncs.nonDefault = true
195+
fncs.marshal = func(enc *jsontext.Encoder, va addressableValue, mo *jsonopts.Struct) error {
196+
xe := export.Encoder(enc)
197+
prevDepth, prevLength := xe.Tokens.DepthLength()
198+
xe.Flags.Set(jsonflags.WithinArshalCall | 1)
199+
err := va.Addr().Interface().(MarshalerV2).MarshalJSONV2(enc, mo)
200+
xe.Flags.Set(jsonflags.WithinArshalCall | 0)
201+
currDepth, currLength := xe.Tokens.DepthLength()
202+
if (prevDepth != currDepth || prevLength+1 != currLength) && err == nil {
203+
err = errNonSingularValue
204+
}
205+
if err != nil {
206+
err = wrapSkipFunc(err, "marshal method")
207+
if xe.Flags.Get(jsonflags.ReportLegacyErrorValues) {
208+
return internal.NewMarshalerError(va.Addr().Interface(), err, "MarshalJSONV2") // unlike unmarshal, always wrapped
209209
}
210-
return nil
210+
if !export.IsIOError(err) {
211+
err = newSemanticErrorWithPosition(enc, t, prevDepth, prevLength, err)
212+
}
213+
return err
211214
}
215+
return nil
212216
}
213217
}
214218

215-
// Handle custom unmarshaler.
216-
switch which := implementsWhich(t, jsonUnmarshalerV2Type, jsonUnmarshalerV1Type, textUnmarshalerType); which {
217-
case jsonUnmarshalerV2Type:
219+
if implements(t, textUnmarshalerType) {
218220
fncs.nonDefault = true
219221
fncs.unmarshal = func(dec *jsontext.Decoder, va addressableValue, uo *jsonopts.Struct) error {
220222
xd := export.Decoder(dec)
221-
prevDepth, prevLength := xd.Tokens.DepthLength()
222-
xd.Flags.Set(jsonflags.WithinArshalCall | 1)
223-
err := va.Addr().Interface().(UnmarshalerV2).UnmarshalJSONV2(dec, uo)
224-
xd.Flags.Set(jsonflags.WithinArshalCall | 0)
225-
currDepth, currLength := xd.Tokens.DepthLength()
226-
if (prevDepth != currDepth || prevLength+1 != currLength) && err == nil {
227-
err = errNonSingularValue
228-
}
223+
var flags jsonwire.ValueFlags
224+
val, err := xd.ReadValue(&flags)
229225
if err != nil {
226+
return err // must be a syntactic or I/O error
227+
}
228+
if val.Kind() == 'n' {
229+
if !uo.Flags.Get(jsonflags.MergeWithLegacySemantics) {
230+
va.SetZero()
231+
}
232+
return nil
233+
}
234+
if val.Kind() != '"' {
235+
return newUnmarshalErrorAfter(dec, t, errNonStringValue)
236+
}
237+
s := jsonwire.UnquoteMayCopy(val, flags.IsVerbatim())
238+
unmarshaler := va.Addr().Interface().(encoding.TextUnmarshaler)
239+
if err := unmarshaler.UnmarshalText(s); err != nil {
230240
err = wrapSkipFunc(err, "unmarshal method")
231-
if xd.Flags.Get(jsonflags.ReportLegacyErrorValues) {
241+
if export.Decoder(dec).Flags.Get(jsonflags.ReportLegacyErrorValues) {
232242
return err // unlike marshal, never wrapped
233243
}
234-
if !isSyntacticError(err) && !export.IsIOError(err) {
235-
err = newSemanticErrorWithPosition(dec, t, prevDepth, prevLength, err)
244+
if !isSemanticError(err) && !isSyntacticError(err) && !export.IsIOError(err) {
245+
err = newUnmarshalErrorAfter(dec, t, err)
236246
}
237247
return err
238248
}
239249
return nil
240250
}
241-
case jsonUnmarshalerV1Type:
251+
}
252+
253+
if implements(t, jsonUnmarshalerV1Type) {
242254
fncs.nonDefault = true
243255
fncs.unmarshal = func(dec *jsontext.Decoder, va addressableValue, uo *jsonopts.Struct) error {
244256
val, err := dec.ReadValue()
@@ -256,33 +268,27 @@ func makeMethodArshaler(fncs *arshaler, t reflect.Type) *arshaler {
256268
}
257269
return nil
258270
}
259-
case textUnmarshalerType:
271+
}
272+
273+
if implements(t, jsonUnmarshalerV2Type) {
260274
fncs.nonDefault = true
261275
fncs.unmarshal = func(dec *jsontext.Decoder, va addressableValue, uo *jsonopts.Struct) error {
262276
xd := export.Decoder(dec)
263-
var flags jsonwire.ValueFlags
264-
val, err := xd.ReadValue(&flags)
265-
if err != nil {
266-
return err // must be a syntactic or I/O error
267-
}
268-
if val.Kind() == 'n' {
269-
if !uo.Flags.Get(jsonflags.MergeWithLegacySemantics) {
270-
va.SetZero()
271-
}
272-
return nil
273-
}
274-
if val.Kind() != '"' {
275-
return newUnmarshalErrorAfter(dec, t, errNonStringValue)
277+
prevDepth, prevLength := xd.Tokens.DepthLength()
278+
xd.Flags.Set(jsonflags.WithinArshalCall | 1)
279+
err := va.Addr().Interface().(UnmarshalerV2).UnmarshalJSONV2(dec, uo)
280+
xd.Flags.Set(jsonflags.WithinArshalCall | 0)
281+
currDepth, currLength := xd.Tokens.DepthLength()
282+
if (prevDepth != currDepth || prevLength+1 != currLength) && err == nil {
283+
err = errNonSingularValue
276284
}
277-
s := jsonwire.UnquoteMayCopy(val, flags.IsVerbatim())
278-
unmarshaler := va.Addr().Interface().(encoding.TextUnmarshaler)
279-
if err := unmarshaler.UnmarshalText(s); err != nil {
285+
if err != nil {
280286
err = wrapSkipFunc(err, "unmarshal method")
281-
if export.Decoder(dec).Flags.Get(jsonflags.ReportLegacyErrorValues) {
287+
if xd.Flags.Get(jsonflags.ReportLegacyErrorValues) {
282288
return err // unlike marshal, never wrapped
283289
}
284-
if !isSemanticError(err) && !isSyntacticError(err) && !export.IsIOError(err) {
285-
err = newUnmarshalErrorAfter(dec, t, err)
290+
if !isSyntacticError(err) && !export.IsIOError(err) {
291+
err = newSemanticErrorWithPosition(dec, t, prevDepth, prevLength, err)
286292
}
287293
return err
288294
}
@@ -293,13 +299,19 @@ func makeMethodArshaler(fncs *arshaler, t reflect.Type) *arshaler {
293299
return fncs
294300
}
295301

296-
// implementsWhich is like t.Implements(ifaceType) for a list of interfaces,
302+
// implementsAny is like t.Implements(ifaceType) for a list of interfaces,
297303
// but checks whether either t or reflect.PointerTo(t) implements the interface.
298-
func implementsWhich(t reflect.Type, ifaceTypes ...reflect.Type) (which reflect.Type) {
304+
func implementsAny(t reflect.Type, ifaceTypes ...reflect.Type) bool {
299305
for _, ifaceType := range ifaceTypes {
300-
if t.Implements(ifaceType) || reflect.PointerTo(t).Implements(ifaceType) {
301-
return ifaceType
306+
if implements(t, ifaceType) {
307+
return true
302308
}
303309
}
304-
return nil
310+
return false
311+
}
312+
313+
// implements is like t.Implements(ifaceType) but checks whether
314+
// either t or reflect.PointerTo(t) implements the interface.
315+
func implements(t, ifaceType reflect.Type) bool {
316+
return t.Implements(ifaceType) || reflect.PointerTo(t).Implements(ifaceType)
305317
}

fields.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ func makeStructFields(root reflect.Type) (sf structFields, serr *SemanticError)
123123
// Reject any types with custom serialization otherwise
124124
// it becomes impossible to know what sub-fields to inline.
125125
tf := indirectType(f.typ)
126-
if implementsWhich(tf, allMethodTypes...) != nil && tf != jsontextValueType {
126+
if implementsAny(tf, allMethodTypes...) && tf != jsontextValueType {
127127
serr = orErrorf(serr, t, "inlined Go struct field %s of type %s must not implement marshal or unmarshal methods", sf.Name, tf)
128128
continue // invalid inlined field; treat as ignored
129129
}
@@ -151,7 +151,7 @@ func makeStructFields(root reflect.Type) (sf structFields, serr *SemanticError)
151151
case tf == jsontextValueType:
152152
f.fncs = nil // specially handled in arshal_inlined.go
153153
case tf.Kind() == reflect.Map && tf.Key().Kind() == reflect.String:
154-
if implementsWhich(tf.Key(), allMethodTypes...) != nil {
154+
if implementsAny(tf.Key(), allMethodTypes...) {
155155
serr = orErrorf(serr, t, "inlined map field %s of type %s must have a string key that does not implement marshal or unmarshal methods", sf.Name, tf)
156156
continue // invalid inlined field; treat as ignored
157157
}
@@ -185,8 +185,8 @@ func makeStructFields(root reflect.Type) (sf structFields, serr *SemanticError)
185185
}
186186
// Unfortunately, methods on the unexported field
187187
// still cannot be called.
188-
if implementsWhich(tf, allMethodTypes...) != nil ||
189-
(f.omitzero && implementsWhich(tf, isZeroerType) != nil) {
188+
if implementsAny(tf, allMethodTypes...) ||
189+
(f.omitzero && implementsAny(tf, isZeroerType)) {
190190
serr = orErrorf(serr, t, "Go struct field %s is not exported for method calls", sf.Name)
191191
continue
192192
}

0 commit comments

Comments
 (0)