Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server: add embedding requirement to ServerTransportStream #7798

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions internal/transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,14 @@
contentSubtype string
}

// StreamDelegate creates a requirement that grpc.ServerTransportStream can use
// to ensure proper delegation to an embedded stream.
type StreamDelegate interface {
mustEmbedDelegateStream()
}

func (s *Stream) mustEmbedDelegateStream() {}

Check warning on line 355 in internal/transport/transport.go

View check run for this annotation

Codecov / codecov/patch

internal/transport/transport.go#L355

Added line #L355 was not covered by tests

// isHeaderSent is only valid on the server-side.
func (s *Stream) isHeaderSent() bool {
return atomic.LoadUint32(&s.headerSent) == 1
Expand Down
16 changes: 3 additions & 13 deletions internal/xds/rbac/rbac_engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1784,7 +1784,8 @@ func (s) TestChainEngine(t *testing.T) {
}
ctx = peer.NewContext(ctx, data.rpcData.peerInfo)
stream := &ServerTransportStreamWithMethod{
method: data.rpcData.fullMethod,
ServerTransportStream: nil, // We have no delegate because this test doesn't actually use the stream.
method: data.rpcData.fullMethod,
}

ctx = grpc.NewContextWithServerTransportStream(ctx, stream)
Expand All @@ -1805,25 +1806,14 @@ func (s) TestChainEngine(t *testing.T) {
}

type ServerTransportStreamWithMethod struct {
grpc.ServerTransportStream
method string
}

func (sts *ServerTransportStreamWithMethod) Method() string {
return sts.method
}

func (sts *ServerTransportStreamWithMethod) SetHeader(metadata.MD) error {
return nil
}

func (sts *ServerTransportStreamWithMethod) SendHeader(metadata.MD) error {
return nil
}

func (sts *ServerTransportStreamWithMethod) SetTrailer(metadata.MD) error {
return nil
}

// An audit logger that will log to the auditEvents slice.
type TestAuditLoggerBuffer struct {
auditEvents *[]*audit.Event
Expand Down
53 changes: 44 additions & 9 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1844,7 +1844,8 @@
type streamKey struct{}

// NewContextWithServerTransportStream creates a new context from ctx and
// attaches stream to it.
// attaches stream to it. stream must embed a delegate stream, typically
// obtained by calling ServerTransportStreamFromContext first.
//
// # Experimental
//
Expand All @@ -1854,6 +1855,42 @@
return context.WithValue(ctx, streamKey{}, stream)
}

// NewContextWithUnstableServerTransportStream creates a new context from ctx
// and attaches stream to it.
//
// # Unstable
//
// Notice: This API is UNSTABLE and is expected to change over time. Use
// NewContextWithServerTransportStream instead for a stable version.
func NewContextWithUnstableServerTransportStream(ctx context.Context, stream UnstableServerTransportStream) context.Context {
return context.WithValue(ctx, streamKey{}, stableStream{UnstableServerTransportStream: stream})

Check warning on line 1866 in server.go

View check run for this annotation

Codecov / codecov/patch

server.go#L1865-L1866

Added lines #L1865 - L1866 were not covered by tests
}

type stableStream struct {
UnstableServerTransportStream
transport.StreamDelegate
}

// UnstableServerTransportStream is a minimal interface that a transport stream
// must implement. This can be used to mock an actual transport stream for tests
// of handler code that use, for example, grpc.SetHeader (which requires some
// stream to be in context).
//
// See also NewContextWithUnstableServerTransportStream.
//
// # Unstable
//
// Notice: This API is UNSTABLE and is expected to change over time. Use
// ServerTransportStream instead for a stable version.
type UnstableServerTransportStream interface {
Method() string
SetHeader(md metadata.MD) error
SendHeader(md metadata.MD) error
SetTrailer(md metadata.MD) error
SetSendCompress(name string) error
ClientAdvertisedCompressors() []string
}

// ServerTransportStream is a minimal interface that a transport stream must
// implement. This can be used to mock an actual transport stream for tests of
// handler code that use, for example, grpc.SetHeader (which requires some
Expand All @@ -1866,10 +1903,8 @@
// Notice: This type is EXPERIMENTAL and may be changed or removed in a
// later release.
type ServerTransportStream interface {
Method() string
SetHeader(md metadata.MD) error
SendHeader(md metadata.MD) error
SetTrailer(md metadata.MD) error
transport.StreamDelegate // Forces embedding a stream for delegating undefined methods.
UnstableServerTransportStream
}

// ServerTransportStreamFromContext returns the ServerTransportStream saved in
Expand Down Expand Up @@ -2103,8 +2138,8 @@
// Notice: This function is EXPERIMENTAL and may be changed or removed in a
// later release.
func SetSendCompressor(ctx context.Context, name string) error {
stream, ok := ServerTransportStreamFromContext(ctx).(*transport.Stream)
if !ok || stream == nil {
stream := ServerTransportStreamFromContext(ctx)
if stream == nil {
return fmt.Errorf("failed to fetch the stream from the given context")
}

Expand All @@ -2125,8 +2160,8 @@
// Notice: This function is EXPERIMENTAL and may be changed or removed in a
// later release.
func ClientSupportedCompressors(ctx context.Context) ([]string, error) {
stream, ok := ServerTransportStreamFromContext(ctx).(*transport.Stream)
if !ok || stream == nil {
stream := ServerTransportStreamFromContext(ctx)
if stream == nil {
return nil, fmt.Errorf("failed to fetch the stream from the given context %v", ctx)
}

Expand Down