Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 103 additions & 41 deletions pkg/grpc/proto_trace_attributes_extractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,10 @@ func NewProtoTraceAttributesExtractor(configuration map[string]*configuration.Tr
}
for methodName, methodConfiguration := range configuration {
pe.methods[methodName] = &methodTraceAttributesExtractor{
errorLogger: errorLogger,
requestAttributes: methodConfiguration.AttributesFromFirstRequestMessage,
responseAttributes: methodConfiguration.AttributesFromFirstResponseMessage,
errorLogger: errorLogger,
requestAttributes: methodConfiguration.AttributesFromRequestMessage,
responseAttributes: methodConfiguration.AttributesFromResponseMessage,
includeFollowupMessages: methodConfiguration.IncludeFollowupMessages,
}
}
return pe
Expand All @@ -66,10 +67,10 @@ func (pe *ProtoTraceAttributesExtractor) InterceptUnaryClient(ctx context.Contex
if span == nil {
return invoker(ctx, method, req, reply, cc, opts...)
}
me.processRequest(span, req)
me.applyAttributes(span, me.requestAttributesFor(req))
err := invoker(ctx, method, req, reply, cc, opts...)
if err == nil {
me.processResponse(span, reply)
me.applyAttributes(span, me.responseAttributesFor(reply))
}
return err
}
Expand Down Expand Up @@ -113,10 +114,10 @@ func (pe *ProtoTraceAttributesExtractor) InterceptUnaryServer(ctx context.Contex
if span == nil {
return handler(ctx, req)
}
me.processRequest(span, req)
me.applyAttributes(span, me.requestAttributesFor(req))
resp, err := handler(ctx, req)
if err == nil {
me.processResponse(span, resp)
me.applyAttributes(span, me.responseAttributesFor(resp))
}
return resp, err
}
Expand Down Expand Up @@ -156,24 +157,26 @@ type methodTraceAttributesExtractor struct {
responseAttributes []string
responseOnce sync.Once
responseExtractor directionTraceAttributesExtractor

includeFollowupMessages bool
}

func (me *methodTraceAttributesExtractor) processRequest(span trace.Span, req interface{}) {
func (me *methodTraceAttributesExtractor) requestAttributesFor(req interface{}) []attribute.KeyValue {
me.requestOnce.Do(func() {
// First time we see an RPC message going from the
// client to the server.
me.requestExtractor.initialize("request", me.requestAttributes, req, me.errorLogger)
})
me.requestExtractor.gatherAttributes(span, req)
return me.requestExtractor.extractAttributes(req)
}

func (me *methodTraceAttributesExtractor) processResponse(span trace.Span, resp interface{}) {
func (me *methodTraceAttributesExtractor) responseAttributesFor(resp interface{}) []attribute.KeyValue {
me.responseOnce.Do(func() {
// First time we see an RPC message going from the
// server to the client.
me.responseExtractor.initialize("response", me.responseAttributes, resp, me.errorLogger)
})
me.responseExtractor.gatherAttributes(span, resp)
return me.responseExtractor.extractAttributes(resp)
}

// methodTraceAttributesExtractor is the bookkeeping that needs to be
Expand All @@ -199,15 +202,16 @@ func (de *directionTraceAttributesExtractor) initialize(attributePrefix string,
}
}

func (de *directionTraceAttributesExtractor) gatherAttributes(span trace.Span, m interface{}) {
if len(de.attributeExtractors) > 0 {
mProtoReflect := m.(proto.Message).ProtoReflect()
attributes := make([]attribute.KeyValue, 0, len(de.attributeExtractors))
for _, attributeExtractor := range de.attributeExtractors {
attributes = attributeExtractor(mProtoReflect, attributes)
}
span.SetAttributes(attributes...)
func (de *directionTraceAttributesExtractor) extractAttributes(m interface{}) []attribute.KeyValue {
if len(de.attributeExtractors) == 0 {
return nil
}
mProtoReflect := m.(proto.Message).ProtoReflect()
attributes := make([]attribute.KeyValue, 0, len(de.attributeExtractors))
for _, attributeExtractor := range de.attributeExtractors {
attributes = attributeExtractor(mProtoReflect, attributes)
}
return attributes
}

// attributeExtractor is a function type that is capable of extracting a
Expand Down Expand Up @@ -392,63 +396,121 @@ func newAttributeExtractor(descriptor protoreflect.MessageDescriptor, remainingF
}

// attributeExtractingClientStream is a decorator for grpc.ClientStream
// that extracts trace span attributes from the first request and
// response message in a streaming RPC.
// that extracts trace span attributes from streaming request and
// response messages.
type attributeExtractingClientStream struct {
grpc.ClientStream
method *methodTraceAttributesExtractor
span trace.Span
gotFirstRequest bool
gotFirstResponse bool
requestIndex uint64
responseIndex uint64
}

func (cs *attributeExtractingClientStream) SendMsg(m interface{}) error {
if !cs.gotFirstRequest {
isFirstMessage := !cs.gotFirstRequest
if isFirstMessage {
cs.gotFirstRequest = true
cs.method.processRequest(cs.span, m)
} else if !cs.method.includeFollowupMessages {
return cs.ClientStream.SendMsg(m)
}
attributes := cs.method.requestAttributesFor(m)
if isFirstMessage {
cs.method.applyAttributes(cs.span, attributes)
}
err := cs.ClientStream.SendMsg(m)
if err == nil {
cs.requestIndex++
addMessageEvent(cs.span, "out", cs.requestIndex, attributes)
}
return cs.ClientStream.SendMsg(m)
return err
}

func (cs *attributeExtractingClientStream) RecvMsg(m interface{}) error {
if !cs.gotFirstResponse {
if err := cs.ClientStream.RecvMsg(m); err != nil {
return err
}
if err := cs.ClientStream.RecvMsg(m); err != nil {
return err
}
isFirstMessage := !cs.gotFirstResponse
if isFirstMessage {
cs.gotFirstResponse = true
cs.method.processResponse(cs.span, m)
} else if !cs.method.includeFollowupMessages {
return nil
}
return cs.ClientStream.RecvMsg(m)
attributes := cs.method.responseAttributesFor(m)
if isFirstMessage {
cs.method.applyAttributes(cs.span, attributes)
}
cs.responseIndex++
addMessageEvent(cs.span, "in", cs.responseIndex, attributes)
return nil
}

// attributeExtractingServerStream is a decorator for grpc.ServerStream
// that extracts trace span attributes from the first request and
// response message in a streaming RPC.
// that extracts trace span attributes from streaming request and
// response messages.
type attributeExtractingServerStream struct {
grpc.ServerStream
method *methodTraceAttributesExtractor
span trace.Span
gotFirstRequest bool
gotFirstResponse bool
requestIndex uint64
responseIndex uint64
}

func (cs *attributeExtractingServerStream) RecvMsg(m interface{}) error {
if !cs.gotFirstRequest {
if err := cs.ServerStream.RecvMsg(m); err != nil {
return err
}
if err := cs.ServerStream.RecvMsg(m); err != nil {
return err
}
isFirstMessage := !cs.gotFirstRequest
if isFirstMessage {
cs.gotFirstRequest = true
cs.method.processRequest(cs.span, m)
} else if !cs.method.includeFollowupMessages {
return nil
}
return cs.ServerStream.RecvMsg(m)
attributes := cs.method.requestAttributesFor(m)
if isFirstMessage {
cs.method.applyAttributes(cs.span, attributes)
}
cs.requestIndex++
addMessageEvent(cs.span, "in", cs.requestIndex, attributes)
return nil
}

func (cs *attributeExtractingServerStream) SendMsg(m interface{}) error {
if !cs.gotFirstResponse {
isFirstMessage := !cs.gotFirstResponse
if isFirstMessage {
cs.gotFirstResponse = true
cs.method.processResponse(cs.span, m)
} else if !cs.method.includeFollowupMessages {
return cs.ServerStream.SendMsg(m)
}
attributes := cs.method.responseAttributesFor(m)
if isFirstMessage {
cs.method.applyAttributes(cs.span, attributes)
}
err := cs.ServerStream.SendMsg(m)
if err == nil {
cs.responseIndex++
addMessageEvent(cs.span, "out", cs.responseIndex, attributes)
}
return err
}

func (methodTraceAttributesExtractor) applyAttributes(span trace.Span, attributes []attribute.KeyValue) {
if span == nil || !span.IsRecording() || len(attributes) == 0 {
return
}
span.SetAttributes(attributes...)
}

func addMessageEvent(span trace.Span, direction string, index uint64, attributes []attribute.KeyValue) {
if span == nil || !span.IsRecording() {
return
}
return cs.ServerStream.SendMsg(m)
attributes = append(attributes,
attribute.String("grpc.message.direction", direction),
attribute.Int64("grpc.message.index", int64(index)),
)
span.AddEvent("grpc.message", trace.WithAttributes(attributes...))
}
Loading