Skip to content

Commit

Permalink
Merge branch 'master' into joerger/fix-mfa-leaf-app
Browse files Browse the repository at this point in the history
  • Loading branch information
Joerger authored Jan 30, 2025
2 parents 5c887d6 + ee119ea commit eae69cf
Show file tree
Hide file tree
Showing 236 changed files with 11,665 additions and 8,864 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/doc-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ jobs:
- name: Run the linter
uses: errata-ai/vale-action@d89dee975228ae261d22c15adcd03578634d429c # v2.1.1
with:
version: 2.30.0
version: 3.9.4
# Take the comma-separated list of files returned by the "Check for
# relevant changes" job.
separator: ","
Expand Down
60 changes: 60 additions & 0 deletions api/client/joinservice.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ type RegisterAzureChallengeResponseFunc func(challenge string) (*proto.RegisterU
// error.
type RegisterTPMChallengeResponseFunc func(challenge *proto.TPMEncryptedCredential) (*proto.RegisterUsingTPMMethodChallengeResponse, error)

// RegisterOracleChallengeResponseFunc is a function type meant to be passed to
// RegisterUsingOracleMethod: It must return a
// *proto.OracleSignedRequest for a given challenge, or an error.
type RegisterOracleChallengeResponseFunc func(challenge string) (*proto.OracleSignedRequest, error)

// RegisterUsingIAMMethod registers the caller using the IAM join method and
// returns signed certs to join the cluster.
//
Expand Down Expand Up @@ -202,6 +207,61 @@ func (c *JoinServiceClient) RegisterUsingTPMMethod(
return certs, nil
}

// RegisterUsingOracleMethod registers the caller using the Oracle join method and
// returns signed certs to join the cluster. The caller must provide a
// ChallengeResponseFunc which returns a *proto.OracleSignedRequest
// for a given challenge, or an error.
func (c *JoinServiceClient) RegisterUsingOracleMethod(
ctx context.Context,
tokenReq *types.RegisterUsingTokenRequest,
oracleRequestFromChallenge RegisterOracleChallengeResponseFunc,
) (*proto.Certs, error) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

oracleJoinClient, err := c.grpcClient.RegisterUsingOracleMethod(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
if err := oracleJoinClient.Send(&proto.RegisterUsingOracleMethodRequest{
Request: &proto.RegisterUsingOracleMethodRequest_RegisterUsingTokenRequest{
RegisterUsingTokenRequest: tokenReq,
},
}); err != nil {
return nil, trace.Wrap(err)
}

challengeResp, err := oracleJoinClient.Recv()
if err != nil {
return nil, trace.Wrap(err)
}
challenge := challengeResp.GetChallenge()
if challenge == "" {
return nil, trace.BadParameter("missing challenge")
}
oracleSignedReq, err := oracleRequestFromChallenge(challenge)
if err != nil {
return nil, trace.Wrap(err)
}
if err := oracleJoinClient.Send(&proto.RegisterUsingOracleMethodRequest{
Request: &proto.RegisterUsingOracleMethodRequest_OracleRequest{
OracleRequest: oracleSignedReq,
},
}); err != nil {
return nil, trace.Wrap(err)
}

certsResp, err := oracleJoinClient.Recv()
if err != nil {
return nil, trace.Wrap(err)
}
certs := certsResp.GetCerts()
if certs == nil {
return nil, trace.BadParameter("expected certificate response, got %T", certsResp.Response)
}
return certs, nil
}

// RegisterUsingToken registers the caller using a token and returns signed
// certs.
// This is used where a more specific RPC has not been introduced for the join
Expand Down
113 changes: 111 additions & 2 deletions api/client/joinservice_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,18 @@ import (

type mockJoinServiceServer struct {
proto.UnimplementedJoinServiceServer
registerUsingTPMMethod func(srv proto.JoinService_RegisterUsingTPMMethodServer) error
registerUsingTPMMethod func(srv proto.JoinService_RegisterUsingTPMMethodServer) error
registerUsingOracleMethod func(srv proto.JoinService_RegisterUsingOracleMethodServer) error
}

func (m *mockJoinServiceServer) RegisterUsingTPMMethod(srv proto.JoinService_RegisterUsingTPMMethodServer) error {
return m.registerUsingTPMMethod(srv)
}

func (m *mockJoinServiceServer) RegisterUsingOracleMethod(srv proto.JoinService_RegisterUsingOracleMethodServer) error {
return m.registerUsingOracleMethod(srv)
}

func TestJoinServiceClient_RegisterUsingTPMMethod(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -109,11 +114,11 @@ func TestJoinServiceClient_RegisterUsingTPMMethod(t *testing.T) {
proto.RegisterJoinServiceServer(srv, mockService)

go func() {
defer cancel()
err := srv.Serve(lis)
if err != nil && !errors.Is(err, grpc.ErrServerStopped) {
assert.NoError(t, err)
}
cancel()
}()

// grpc.NewClient attempts to DNS resolve addr, whereas grpc.Dial doesn't.
Expand All @@ -140,3 +145,107 @@ func TestJoinServiceClient_RegisterUsingTPMMethod(t *testing.T) {
assert.Empty(t, cmp.Diff(mockCerts, certs))
}
}

func TestJoinServiceClient_RegisterUsingOracleMethod(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

lis := bufconn.Listen(100)
t.Cleanup(func() {
assert.NoError(t, lis.Close())
})

tokenReq := &types.RegisterUsingTokenRequest{
Token: "token",
}
mockTokenRequest := &proto.RegisterUsingOracleMethodRequest{
Request: &proto.RegisterUsingOracleMethodRequest_RegisterUsingTokenRequest{
RegisterUsingTokenRequest: tokenReq,
},
}
mockChallenge := "challenge"
oracleReq := &proto.OracleSignedRequest{
Headers: map[string]string{
"x-teleport-challenge": mockChallenge,
},
PayloadHeaders: map[string]string{
"x-teleport-challenge": mockChallenge,
},
}

mockOracleRequest := &proto.RegisterUsingOracleMethodRequest{
Request: &proto.RegisterUsingOracleMethodRequest_OracleRequest{
OracleRequest: oracleReq,
},
}
mockCerts := &proto.Certs{
TLS: []byte("cert"),
}
mockService := &mockJoinServiceServer{
registerUsingOracleMethod: func(srv proto.JoinService_RegisterUsingOracleMethodServer) error {
tokenReq, err := srv.Recv()
if !assert.NoError(t, err) {
return err
}
assert.Empty(t, cmp.Diff(mockTokenRequest, tokenReq))
err = srv.Send(&proto.RegisterUsingOracleMethodResponse{
Response: &proto.RegisterUsingOracleMethodResponse_Challenge{
Challenge: mockChallenge,
},
})
if !assert.NoError(t, err) {
return err
}
headerReq, err := srv.Recv()
if !assert.NoError(t, err) {
return err
}
assert.Empty(t, cmp.Diff(mockOracleRequest, headerReq))

err = srv.Send(&proto.RegisterUsingOracleMethodResponse{
Response: &proto.RegisterUsingOracleMethodResponse_Certs{
Certs: mockCerts,
},
})
if !assert.NoError(t, err) {
return err
}
return nil
},
}
srv := grpc.NewServer()
t.Cleanup(srv.Stop)
proto.RegisterJoinServiceServer(srv, mockService)

go func() {
defer cancel()
err := srv.Serve(lis)
if err != nil && !errors.Is(err, grpc.ErrServerStopped) {
assert.NoError(t, err)
}
}()

// grpc.NewClient attempts to DNS resolve addr, whereas grpc.Dial doesn't.
c, err := grpc.Dial(
"bufconn",
grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) {
return lis.DialContext(ctx)
}),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
require.NoError(t, err)

joinClient := NewJoinServiceClient(proto.NewJoinServiceClient(c))
certs, err := joinClient.RegisterUsingOracleMethod(
ctx,
tokenReq,
func(challenge string) (*proto.OracleSignedRequest, error) {
assert.Equal(t, mockChallenge, challenge)
return oracleReq, nil
},
)
if assert.NoError(t, err) {
assert.Empty(t, cmp.Diff(mockCerts, certs))
}
}
19 changes: 18 additions & 1 deletion api/client/proto/joinservice.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ limitations under the License.

package proto

import "github.com/gravitational/trace"
import (
"github.com/gravitational/trace"
)

func (r *RegisterUsingIAMMethodRequest) CheckAndSetDefaults() error {
if len(r.StsIdentityRequest) == 0 {
Expand All @@ -34,3 +36,18 @@ func (r *RegisterUsingAzureMethodRequest) CheckAndSetDefaults() error {
}
return trace.Wrap(r.RegisterUsingTokenRequest.CheckAndSetDefaults())
}

func (r *RegisterUsingOracleMethodRequest) CheckAndSetDefaults() error {
switch req := r.Request.(type) {
case *RegisterUsingOracleMethodRequest_RegisterUsingTokenRequest:
return trace.Wrap(req.RegisterUsingTokenRequest.CheckAndSetDefaults())
case *RegisterUsingOracleMethodRequest_OracleRequest:
if len(req.OracleRequest.Headers) == 0 {
return trace.BadParameter("missing parameter Headers")
}
if len(req.OracleRequest.PayloadHeaders) == 0 {
return trace.BadParameter("missing parameter PayloadHeaders")
}
}
return trace.BadParameter("invalid request type: %T", r.Request)
}
Loading

0 comments on commit eae69cf

Please sign in to comment.