diff --git a/lib/auth/bot_test.go b/lib/auth/bot_test.go index ae4ddb14136b9..2e019ffa7123e 100644 --- a/lib/auth/bot_test.go +++ b/lib/auth/bot_test.go @@ -42,6 +42,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" "google.golang.org/grpc" + "google.golang.org/protobuf/testing/protocmp" "github.com/gravitational/teleport" apiclient "github.com/gravitational/teleport/api/client" @@ -49,10 +50,12 @@ import ( "github.com/gravitational/teleport/api/client/webclient" headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" machineidv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/machineid/v1" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" "github.com/gravitational/teleport/api/metadata" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/api/utils/keys" + "github.com/gravitational/teleport/integrations/lib/testing/fakejoin" "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/auth/join" "github.com/gravitational/teleport/lib/auth/machineid/machineidv1" @@ -216,6 +219,146 @@ func TestRegisterBotCertificateGenerationCheck(t *testing.T) { } } +// TestBotJoinAttrs_Kubernetes validates that a bot can join using the +// Kubernetes join method and that the correct join attributes are encoded in +// the resulting bot cert, and, that when this cert is used to produce role +// certificates, the correct attributes are encoded in the role cert. +// +// Whilst this specifically tests the Kubernetes join method, it tests by proxy +// the implementation for most of the join methods. +func TestBotJoinAttrs_Kubernetes(t *testing.T) { + t.Parallel() + + srv := newTestTLSServer(t) + ctx := context.Background() + + role, err := CreateRole(ctx, srv.Auth(), "example", types.RoleSpecV6{}) + require.NoError(t, err) + + // Create a new bot. + client, err := srv.NewClient(TestAdmin()) + require.NoError(t, err) + bot, err := client.BotServiceClient().CreateBot(ctx, &machineidv1pb.CreateBotRequest{ + Bot: &machineidv1pb.Bot{ + Metadata: &headerv1.Metadata{ + Name: "test", + }, + Spec: &machineidv1pb.BotSpec{ + Roles: []string{"example"}, + }, + }, + }) + require.NoError(t, err) + + k8s, err := fakejoin.NewKubernetesSigner(srv.Clock()) + require.NoError(t, err) + jwks, err := k8s.GetMarshaledJWKS() + require.NoError(t, err) + fakePSAT, err := k8s.SignServiceAccountJWT( + "my-pod", + "my-namespace", + "my-service-account", + srv.ClusterName(), + ) + require.NoError(t, err) + + tok, err := types.NewProvisionTokenFromSpec( + "my-k8s-token", + time.Time{}, + types.ProvisionTokenSpecV2{ + Roles: types.SystemRoles{types.RoleBot}, + JoinMethod: types.JoinMethodKubernetes, + BotName: bot.Metadata.Name, + Kubernetes: &types.ProvisionTokenSpecV2Kubernetes{ + Type: types.KubernetesJoinTypeStaticJWKS, + StaticJWKS: &types.ProvisionTokenSpecV2Kubernetes_StaticJWKSConfig{ + JWKS: jwks, + }, + Allow: []*types.ProvisionTokenSpecV2Kubernetes_Rule{ + { + ServiceAccount: "my-namespace:my-service-account", + }, + }, + }, + }, + ) + require.NoError(t, err) + require.NoError(t, client.CreateToken(ctx, tok)) + + result, err := join.Register(ctx, join.RegisterParams{ + Token: tok.GetName(), + JoinMethod: types.JoinMethodKubernetes, + ID: state.IdentityID{ + Role: types.RoleBot, + }, + AuthServers: []utils.NetAddr{*utils.MustParseAddr(srv.Addr().String())}, + KubernetesReadFileFunc: func(name string) ([]byte, error) { + return []byte(fakePSAT), nil + }, + }) + require.NoError(t, err) + + // Validate correct join attributes are encoded. + cert, err := tlsca.ParseCertificatePEM(result.Certs.TLS) + require.NoError(t, err) + ident, err := tlsca.FromSubject(cert.Subject, cert.NotAfter) + require.NoError(t, err) + wantAttrs := &workloadidentityv1pb.JoinAttrs{ + Meta: &workloadidentityv1pb.JoinAttrsMeta{ + JoinTokenName: tok.GetName(), + JoinMethod: string(types.JoinMethodKubernetes), + }, + Kubernetes: &workloadidentityv1pb.JoinAttrsKubernetes{ + ServiceAccount: &workloadidentityv1pb.JoinAttrsKubernetesServiceAccount{ + Namespace: "my-namespace", + Name: "my-service-account", + }, + Pod: &workloadidentityv1pb.JoinAttrsKubernetesPod{ + Name: "my-pod", + }, + Subject: "system:serviceaccount:my-namespace:my-service-account", + }, + } + require.Empty(t, cmp.Diff( + ident.JoinAttributes, + wantAttrs, + protocmp.Transform(), + )) + + // Now, try to produce a role certificate using the bot cert, to ensure + // that the join attributes are correctly propagated. + privateKeyPEM, err := keys.MarshalPrivateKey(result.PrivateKey) + require.NoError(t, err) + tlsCert, err := tls.X509KeyPair(result.Certs.TLS, privateKeyPEM) + require.NoError(t, err) + sshPub, err := ssh.NewPublicKey(result.PrivateKey.Public()) + require.NoError(t, err) + tlsPub, err := keys.MarshalPublicKey(result.PrivateKey.Public()) + require.NoError(t, err) + botClient := srv.NewClientWithCert(tlsCert) + roleCerts, err := botClient.GenerateUserCerts(ctx, proto.UserCertsRequest{ + SSHPublicKey: ssh.MarshalAuthorizedKey(sshPub), + TLSPublicKey: tlsPub, + Username: bot.Status.UserName, + RoleRequests: []string{ + role.GetName(), + }, + UseRoleRequests: true, + Expires: srv.Clock().Now().Add(time.Hour), + }) + require.NoError(t, err) + + roleCert, err := tlsca.ParseCertificatePEM(roleCerts.TLS) + require.NoError(t, err) + roleIdent, err := tlsca.FromSubject(roleCert.Subject, roleCert.NotAfter) + require.NoError(t, err) + require.Empty(t, cmp.Diff( + roleIdent.JoinAttributes, + wantAttrs, + protocmp.Transform(), + )) +} + // TestRegisterBotInstance tests that bot instances are created properly on join func TestRegisterBotInstance(t *testing.T) { t.Parallel() @@ -282,7 +425,6 @@ func TestRegisterBotInstance(t *testing.T) { require.Equal(t, int32(1), ia.Generation) require.Equal(t, string(types.JoinMethodToken), ia.JoinMethod) require.Equal(t, token.GetSafeName(), ia.JoinToken) - // The latest authentications field should contain the same record (and // only that record.) require.Len(t, botInstance.GetStatus().LatestAuthentications, 1)