From a718918de8413d0a3dfd6ff5992a651284dc9c20 Mon Sep 17 00:00:00 2001 From: "STeve (Xin) Huang" Date: Tue, 14 Jan 2025 16:48:04 -0500 Subject: [PATCH] GitHub proxy part 6: proxing Git using SSH transport (#49980) * GitHub proxy part 6: proxing Git using SSH transport * better command parsing and update suite * refactor * revert unnecearrty files * address review comments * ut fix * revert localsite_test.go * change special suffix to teleport-github-org for routing * fix routing ut * minor typo edit * fix ut after sshca change * add UT to sshutils * minor review comments * fix api ut because of special suffix change * GitServerReadOnlyClient * downgrade error to warning * run go mod tidy. not sure why it's needed * rename mock.go to mock_test.go --- api/client/client.go | 7 +- api/client/gitserver/gitserver.go | 8 + api/types/constants.go | 2 +- api/types/server.go | 3 + api/types/server_test.go | 5 +- constants.go | 3 + integrations/terraform/go.sum | 2 + lib/auth/authclient/api.go | 7 + lib/auth/authclient/clt.go | 3 + lib/cache/cache.go | 1 + lib/cache/git_server.go | 10 + lib/cryptosuites/suites.go | 11 +- lib/proxy/router.go | 39 +- lib/proxy/router_test.go | 53 ++- lib/reversetunnel/localsite.go | 79 +++- lib/reversetunnel/peer.go | 12 + lib/reversetunnel/remotesite.go | 5 + lib/reversetunnel/srv.go | 7 +- lib/reversetunnelclient/api.go | 2 + lib/service/service.go | 14 + lib/services/git_server.go | 8 +- lib/services/readonly/readonly.go | 3 + lib/services/role.go | 27 ++ lib/services/watcher.go | 37 ++ lib/services/watcher_test.go | 63 +++ lib/srv/authhandlers.go | 20 +- lib/srv/authhandlers_test.go | 57 +-- lib/srv/git/forward.go | 599 +++++++++++++++++++++++++ lib/srv/git/forward_test.go | 363 +++++++++++++++ lib/srv/git/github.go | 160 +++++++ lib/srv/git/github_test.go | 145 ++++++ lib/srv/regular/sshserver_test.go | 18 + lib/sshutils/exec.go | 70 +++ lib/sshutils/exec_test.go | 98 ++++ lib/sshutils/{mock.go => mock_test.go} | 53 +++ lib/sshutils/reply.go | 106 +++++ lib/sshutils/reply_test.go | 93 ++++ lib/sshutils/server.go | 7 + lib/sshutils/utils.go | 20 + lib/web/apiserver_test.go | 22 + tool/tsh/common/git_list_test.go | 8 +- 41 files changed, 2176 insertions(+), 74 deletions(-) create mode 100644 lib/srv/git/forward.go create mode 100644 lib/srv/git/forward_test.go create mode 100644 lib/srv/git/github.go create mode 100644 lib/srv/git/github_test.go create mode 100644 lib/sshutils/exec.go create mode 100644 lib/sshutils/exec_test.go rename lib/sshutils/{mock.go => mock_test.go} (57%) create mode 100644 lib/sshutils/reply.go create mode 100644 lib/sshutils/reply_test.go diff --git a/api/client/client.go b/api/client/client.go index edffe12d00ff2..63c44fb940ad3 100644 --- a/api/client/client.go +++ b/api/client/client.go @@ -4942,11 +4942,16 @@ func (c *Client) UserTasksServiceClient() *usertaskapi.Client { return usertaskapi.NewClient(usertaskv1.NewUserTaskServiceClient(c.conn)) } -// GitServerClient returns a client for managing git servers +// GitServerClient returns a client for managing Git servers func (c *Client) GitServerClient() *gitserverclient.Client { return gitserverclient.NewClient(gitserverpb.NewGitServerServiceClient(c.conn)) } +// GitServerReadOnlyClient returns the read-only client for Git servers. +func (c *Client) GitServerReadOnlyClient() gitserverclient.ReadOnlyClient { + return c.GitServerClient() +} + // GetCertAuthority retrieves a CA by type and domain. func (c *Client) GetCertAuthority(ctx context.Context, id types.CertAuthID, loadKeys bool) (types.CertAuthority, error) { ca, err := c.TrustClient().GetCertAuthority(ctx, &trustpb.GetCertAuthorityRequest{ diff --git a/api/client/gitserver/gitserver.go b/api/client/gitserver/gitserver.go index 3b799657ac2cb..bfb259c80cc82 100644 --- a/api/client/gitserver/gitserver.go +++ b/api/client/gitserver/gitserver.go @@ -22,6 +22,14 @@ import ( "github.com/gravitational/teleport/api/types" ) +// ReadOnlyClient defines getter functions for Git servers. +type ReadOnlyClient interface { + // ListGitServers returns a paginated list of Git servers. + ListGitServers(ctx context.Context, pageSize int, pageToken string) ([]types.Server, string, error) + // GetGitServer returns a Git server by name. + GetGitServer(ctx context.Context, name string) (types.Server, error) +} + // Client is an Git servers client. type Client struct { grpcClient gitserverv1.GitServerServiceClient diff --git a/api/types/constants.go b/api/types/constants.go index b274b1871de4f..10aa2322998d3 100644 --- a/api/types/constants.go +++ b/api/types/constants.go @@ -1496,5 +1496,5 @@ const ( const ( // GitHubOrgServerDomain is the sub domain used in the hostname of a // types.Server to indicate the GitHub organization of a Git server. - GitHubOrgServerDomain = "github-org" + GitHubOrgServerDomain = "teleport-github-org" ) diff --git a/api/types/server.go b/api/types/server.go index 98b61911c8415..ed84089ad3bfd 100644 --- a/api/types/server.go +++ b/api/types/server.go @@ -626,6 +626,9 @@ func (s *ServerV2) githubCheckAndSetDefaults() error { return trace.Wrap(err, "invalid GitHub organization name") } + // Set SSH host port for connection and "fake" hostname for routing. These + // values are hard-coded and cannot be customized. + s.Spec.Addr = "github.com:22" s.Spec.Hostname = MakeGitHubOrgServerDomain(s.Spec.GitHub.Organization) if s.Metadata.Labels == nil { s.Metadata.Labels = make(map[string]string) diff --git a/api/types/server_test.go b/api/types/server_test.go index 312c59dd5b72d..4e1476e9cbf38 100644 --- a/api/types/server_test.go +++ b/api/types/server_test.go @@ -623,7 +623,8 @@ func TestServerCheckAndSetDefaults(t *testing.T) { }, }, Spec: ServerSpecV2{ - Hostname: "my-org.github-org", + Addr: "github.com:22", + Hostname: "my-org.teleport-github-org", GitHub: &GitHubServerMetadata{ Integration: "my-org", Organization: "my-org", @@ -807,7 +808,7 @@ func TestGetCloudMetadataAWS(t *testing.T) { func TestGitServerOrgDomain(t *testing.T) { domain := MakeGitHubOrgServerDomain("my-org") - require.Equal(t, "my-org.github-org", domain) + require.Equal(t, "my-org.teleport-github-org", domain) githubNodeAddr := domain + ":22" org, ok := GetGitHubOrgFromNodeAddr(githubNodeAddr) diff --git a/constants.go b/constants.go index 79f97ae24bfaf..a6dc009414772 100644 --- a/constants.go +++ b/constants.go @@ -288,6 +288,9 @@ const ( // ComponentRolloutController represents the autoupdate_agent_rollout controller. ComponentRolloutController = "rollout-controller" + // ComponentForwardingGit represents the SSH proxy that forwards Git commands. + ComponentForwardingGit = "git:forward" + // VerboseLogsEnvVar forces all logs to be verbose (down to DEBUG level) VerboseLogsEnvVar = "TELEPORT_DEBUG" diff --git a/integrations/terraform/go.sum b/integrations/terraform/go.sum index 6c0be667fb1a2..0f257bbcaf2b7 100644 --- a/integrations/terraform/go.sum +++ b/integrations/terraform/go.sum @@ -1596,6 +1596,8 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D github.com/mattn/go-runewidth v0.0.4/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mattn/go-shellwords v1.0.12 h1:M2zGm7EW6UQJvDeQxo4T51eKPurbeFbe8WtebGE2xrk= +github.com/mattn/go-shellwords v1.0.12/go.mod h1:EZzvwXDESEeg03EKmM+RmDnNOPKG4lLtQsUlTZDWQ8Y= github.com/mattn/go-sqlite3 v1.14.14/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= diff --git a/lib/auth/authclient/api.go b/lib/auth/authclient/api.go index 742ea527ab248..c59bb595a2afd 100644 --- a/lib/auth/authclient/api.go +++ b/lib/auth/authclient/api.go @@ -26,6 +26,7 @@ import ( "github.com/gravitational/trace" "google.golang.org/grpc" + "github.com/gravitational/teleport/api/client/gitserver" "github.com/gravitational/teleport/api/client/proto" accessmonitoringrules "github.com/gravitational/teleport/api/gen/proto/go/teleport/accessmonitoringrules/v1" "github.com/gravitational/teleport/api/gen/proto/go/teleport/autoupdate/v1" @@ -320,6 +321,9 @@ type ReadProxyAccessPoint interface { // GetAutoUpdateAgentRollout gets the AutoUpdateAgentRollout from the backend. GetAutoUpdateAgentRollout(ctx context.Context) (*autoupdate.AutoUpdateAgentRollout, error) + + // GitServerReadOnlyClient returns the read-only client for Git servers. + GitServerReadOnlyClient() gitserver.ReadOnlyClient } // SnowflakeSessionWatcher is watcher interface used by Snowflake web session watcher. @@ -1264,6 +1268,9 @@ type Cache interface { // GetPluginStaticCredentialsByLabels will get a list of plugin static credentials resource by matching labels. GetPluginStaticCredentialsByLabels(ctx context.Context, labels map[string]string) ([]types.PluginStaticCredentials, error) + + // GitServerGetter defines methods for fetching Git servers. + services.GitServerGetter } type NodeWrapper struct { diff --git a/lib/auth/authclient/clt.go b/lib/auth/authclient/clt.go index 09e4caff54d29..4f17263feaab7 100644 --- a/lib/auth/authclient/clt.go +++ b/lib/auth/authclient/clt.go @@ -1900,4 +1900,7 @@ type ClientI interface { // GitServerClient returns git server client. GitServerClient() *gitserver.Client + + // GitServerReadOnlyClient returns the read-only client for Git servers. + GitServerReadOnlyClient() gitserver.ReadOnlyClient } diff --git a/lib/cache/cache.go b/lib/cache/cache.go index 06fd5c59098c5..1ae68be971568 100644 --- a/lib/cache/cache.go +++ b/lib/cache/cache.go @@ -288,6 +288,7 @@ func ForRemoteProxy(cfg Config) Config { {Kind: types.KindDatabaseServer}, {Kind: types.KindDatabaseService}, {Kind: types.KindKubeServer}, + {Kind: types.KindGitServer}, } cfg.QueueSize = defaults.ProxyQueueSize return cfg diff --git a/lib/cache/git_server.go b/lib/cache/git_server.go index 849a160757eee..b585b0b169817 100644 --- a/lib/cache/git_server.go +++ b/lib/cache/git_server.go @@ -23,11 +23,21 @@ import ( "github.com/gravitational/trace" + "github.com/gravitational/teleport/api/client/gitserver" apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/services" ) +// GitServerReadOnlyClient returns the read-only client for Git servers. +// +// Note that Cache implements GitServerReadOnlyClient to satisfy +// auth.ProxyAccessPoint but also has the getter functions at top level to +// satisfy auth.Cache. +func (c *Cache) GitServerReadOnlyClient() gitserver.ReadOnlyClient { + return c +} + func (c *Cache) GetGitServer(ctx context.Context, name string) (types.Server, error) { ctx, span := c.Tracer.Start(ctx, "cache/GetGitServer") defer span.End() diff --git a/lib/cryptosuites/suites.go b/lib/cryptosuites/suites.go index d43d50ef35237..8a06d18998c2c 100644 --- a/lib/cryptosuites/suites.go +++ b/lib/cryptosuites/suites.go @@ -114,6 +114,10 @@ const ( // GitHubProxyCASSH represents the SSH key for GitHub proxy CAs. GitHubProxyCASSH + // GitClient represents a key used to forward Git commands to Git services + // like GitHub. + GitClient + // keyPurposeMax is 1 greater than the last valid key purpose, used to test that all values less than this // are valid for each suite. keyPurposeMax @@ -187,8 +191,8 @@ var ( ProxyKubeClient: RSA2048, // EC2InstanceConnect has always used Ed25519 by default. EC2InstanceConnect: Ed25519, - // GitHubProxyCASSH uses same algorithms as UserCASSH. - GitHubProxyCASSH: RSA2048, + GitHubProxyCASSH: Ed25519, + GitClient: Ed25519, } // balancedV1 strikes a balance between security, compatibility, and @@ -220,6 +224,7 @@ var ( ProxyKubeClient: ECDSAP256, EC2InstanceConnect: Ed25519, GitHubProxyCASSH: Ed25519, + GitClient: Ed25519, } // fipsv1 is an algorithm suite tailored for FIPS compliance. It is based on @@ -251,6 +256,7 @@ var ( ProxyKubeClient: ECDSAP256, EC2InstanceConnect: ECDSAP256, GitHubProxyCASSH: ECDSAP256, + GitClient: ECDSAP256, } // hsmv1 in an algorithm suite tailored for clusters using an HSM or KMS @@ -284,6 +290,7 @@ var ( ProxyKubeClient: ECDSAP256, EC2InstanceConnect: Ed25519, GitHubProxyCASSH: ECDSAP256, + GitClient: Ed25519, } allSuites = map[types.SignatureAlgorithmSuite]suite{ diff --git a/lib/proxy/router.go b/lib/proxy/router.go index 61ba7f466c5de..b157c3a30f719 100644 --- a/lib/proxy/router.go +++ b/lib/proxy/router.go @@ -23,6 +23,7 @@ import ( "context" "errors" "fmt" + "math/rand/v2" "net" "os" "sync" @@ -277,7 +278,6 @@ func (r *Router) DialHost(ctx context.Context, clientSrcAddr, clientDstAddr net. } } } - } else { return nil, trace.ConnectionProblem(errors.New("connection problem"), "direct dialing to nodes not found in inventory is not supported") } @@ -377,6 +377,7 @@ func (r *Router) getRemoteCluster(ctx context.Context, clusterName string, check type site interface { GetNodes(ctx context.Context, fn func(n readonly.Server) bool) ([]types.Server, error) GetClusterNetworkingConfig(ctx context.Context) (types.ClusterNetworkingConfig, error) + GetGitServers(context.Context, func(readonly.Server) bool) ([]types.Server, error) } // remoteSite is a site implementation that wraps @@ -392,6 +393,17 @@ func (r remoteSite) GetNodes(ctx context.Context, fn func(n readonly.Server) boo return nil, trace.Wrap(err) } + servers, err := watcher.CurrentResourcesWithFilter(ctx, fn) + return servers, trace.Wrap(err) +} + +// GetGitServers uses the wrapped sites GitServerWatcher to filter git servers. +func (r remoteSite) GetGitServers(ctx context.Context, fn func(n readonly.Server) bool) ([]types.Server, error) { + watcher, err := r.site.GitServerWatcher() + if err != nil { + return nil, trace.Wrap(err) + } + return watcher.CurrentResourcesWithFilter(ctx, fn) } @@ -409,6 +421,9 @@ func (r remoteSite) GetClusterNetworkingConfig(ctx context.Context) (types.Clust // getServer attempts to locate a node matching the provided host and port in // the provided site. func getServer(ctx context.Context, host, port string, site site) (types.Server, error) { + if org, ok := types.GetGitHubOrgFromNodeAddr(host); ok { + return getGitHubServer(ctx, org, site) + } return getServerWithResolver(ctx, host, port, site, nil /* use default resolver */) } @@ -562,3 +577,25 @@ func (r *Router) GetSiteClient(ctx context.Context, clusterName string) (authcli } return site.GetClient() } + +func getGitHubServer(ctx context.Context, gitHubOrg string, site site) (types.Server, error) { + servers, err := site.GetGitServers(ctx, func(s readonly.Server) bool { + github := s.GetGitHub() + return github != nil && github.Organization == gitHubOrg + }) + if err != nil { + return nil, trace.Wrap(err) + } + + switch len(servers) { + case 0: + return nil, trace.NotFound("unable to locate Git server for GitHub organization %s", gitHubOrg) + case 1: + return servers[0], nil + default: + // It's unusual but possible to have multiple servers per organization + // (e.g. possibly a second Git server for a manual CA rotation). Pick a + // random one. + return servers[rand.N(len(servers))], nil + } +} diff --git a/lib/proxy/router_test.go b/lib/proxy/router_test.go index d18b3ce11663a..3a3faf2252480 100644 --- a/lib/proxy/router_test.go +++ b/lib/proxy/router_test.go @@ -27,6 +27,7 @@ import ( "github.com/google/uuid" "github.com/gravitational/trace" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" @@ -43,8 +44,9 @@ import ( ) type testSite struct { - cfg types.ClusterNetworkingConfig - nodes []types.Server + cfg types.ClusterNetworkingConfig + nodes []types.Server + gitServers []types.Server } func (t testSite) GetClusterNetworkingConfig(ctx context.Context) (types.ClusterNetworkingConfig, error) { @@ -61,6 +63,16 @@ func (t testSite) GetNodes(ctx context.Context, fn func(n readonly.Server) bool) return out, nil } +func (t testSite) GetGitServers(ctx context.Context, fn func(n readonly.Server) bool) ([]types.Server, error) { + var out []types.Server + for _, s := range t.gitServers { + if fn(s) { + out = append(out, s) + } + } + + return out, nil +} type server struct { name string @@ -351,6 +363,11 @@ func TestGetServers(t *testing.T) { }, ) + gitServers := []types.Server{ + makeGitHubServer(t, "org1"), + makeGitHubServer(t, "org2"), + } + // ensure tests don't have order-dependence rand.Shuffle(len(servers), func(i, j int) { servers[i], servers[j] = servers[j], servers[i] @@ -489,6 +506,28 @@ func TestGetServers(t *testing.T) { require.True(t, srv.IsOpenSSHNode()) }, }, + { + name: "git server", + site: testSite{cfg: &unambiguousCfg, gitServers: gitServers}, + host: "org2.teleport-github-org", + errAssertion: require.NoError, + serverAssertion: func(t *testing.T, srv types.Server) { + require.NotNil(t, srv) + require.NotNil(t, srv.GetGitHub()) + assert.Equal(t, "org2", srv.GetGitHub().Organization) + }, + }, + { + name: "git server not found", + site: testSite{cfg: &unambiguousCfg, gitServers: gitServers}, + host: "org-not-found.teleport-github-org", + errAssertion: func(t require.TestingT, err error, i ...interface{}) { + require.True(t, trace.IsNotFound(err), i...) + }, + serverAssertion: func(t *testing.T, srv types.Server) { + require.Nil(t, srv) + }, + }, } ctx := context.Background() @@ -891,3 +930,13 @@ func TestRouter_DialSite(t *testing.T) { }) } } + +func makeGitHubServer(t *testing.T, org string) types.Server { + t.Helper() + server, err := types.NewGitHubServer(types.GitHubServerMetadata{ + Integration: org, + Organization: org, + }) + require.NoError(t, err) + return server +} diff --git a/lib/reversetunnel/localsite.go b/lib/reversetunnel/localsite.go index 3446a882cec23..61b7b429e135e 100644 --- a/lib/reversetunnel/localsite.go +++ b/lib/reversetunnel/localsite.go @@ -46,6 +46,7 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/services/readonly" "github.com/gravitational/teleport/lib/srv/forward" + "github.com/gravitational/teleport/lib/srv/git" "github.com/gravitational/teleport/lib/teleagent" "github.com/gravitational/teleport/lib/utils" logutils "github.com/gravitational/teleport/lib/utils/log" @@ -184,6 +185,11 @@ func (s *localSite) NodeWatcher() (*services.GenericWatcher[types.Server, readon return s.srv.NodeWatcher, nil } +// GitServerWatcher returns a Git server watcher for this cluster. +func (s *localSite) GitServerWatcher() (*services.GenericWatcher[types.Server, readonly.Server], error) { + return s.srv.GitServerWatcher, nil +} + // GetClient returns a client to the full Auth Server API. func (s *localSite) GetClient() (authclient.ClientI, error) { return s.client, nil @@ -248,6 +254,10 @@ func shouldDialAndForward(params reversetunnelclient.DialParams, recConfig types } func (s *localSite) Dial(params reversetunnelclient.DialParams) (net.Conn, error) { + if params.TargetServer != nil && params.TargetServer.GetKind() == types.KindGitServer { + return s.dialAndForwardGit(params) + } + recConfig, err := s.accessPoint.GetSessionRecordingConfig(s.srv.Context) if err != nil { return nil, trace.Wrap(err) @@ -259,7 +269,6 @@ func (s *localSite) Dial(params reversetunnelclient.DialParams) (net.Conn, error if shouldDialAndForward(params, recConfig) { return s.dialAndForward(params) } - // Attempt to perform a direct TCP dial. return s.DialTCP(params) } @@ -346,6 +355,51 @@ func (s *localSite) adviseReconnect(ctx context.Context) { } } +func (s *localSite) dialAndForwardGit(params reversetunnelclient.DialParams) (_ net.Conn, retErr error) { + s.logger.DebugContext(s.srv.ctx, "Dialing and forwarding git", "from", params.From, "to", params.To) + + dialStart := s.srv.Clock.Now() + targetConn, err := s.dialDirect(params) + if err != nil { + return nil, trace.ConnectionProblem(err, "failed to connect to git server") + } + + // Get a host certificate for the forwarding node from the cache. + hostCertificate, err := s.certificateCache.getHostCertificate(context.TODO(), params.Address, params.Principals) + if err != nil { + return nil, trace.Wrap(err) + } + + // Create a forwarding server that serves a single SSH connection on it. This + // server does not need to close, it will close and release all resources + // once conn is closed. + serverConfig := &git.ForwardServerConfig{ + AuthClient: s.client, + AccessPoint: s.accessPoint, + TargetConn: newMetricConn(targetConn, dialTypeDirect, dialStart, s.srv.Clock), + SrcAddr: params.From, + DstAddr: params.To, + HostCertificate: hostCertificate, + Ciphers: s.srv.Config.Ciphers, + KEXAlgorithms: s.srv.Config.KEXAlgorithms, + MACAlgorithms: s.srv.Config.MACAlgorithms, + Emitter: s.srv.Config.Emitter, + ParentContext: s.srv.Context, + LockWatcher: s.srv.LockWatcher, + HostUUID: s.srv.ID, + TargetServer: params.TargetServer, + Clock: s.clock, + } + remoteServer, err := git.NewForwardServer(serverConfig) + if err != nil { + s.logger.ErrorContext(s.srv.ctx, "Failed to create git forward server", "error", err) + return nil, trace.Wrap(err) + } + go remoteServer.Serve() + + return remoteServer.Dial() +} + func (s *localSite) dialAndForward(params reversetunnelclient.DialParams) (_ net.Conn, retErr error) { ctx := s.srv.ctx @@ -457,6 +511,18 @@ func (s *localSite) dialTunnel(dreq *sshutils.DialReq) (net.Conn, error) { return conn, nil } +func (s *localSite) dialDirect(params reversetunnelclient.DialParams) (net.Conn, error) { + dialer := proxyutils.DialerFromEnvironment(params.To.String()) + + dialTimeout := apidefaults.DefaultIOTimeout + if cnc, err := s.accessPoint.GetClusterNetworkingConfig(s.srv.Context); err != nil { + s.logger.WarnContext(s.srv.ctx, "Failed to get cluster networking config - using default dial timeout", "error", err) + } else { + dialTimeout = cnc.GetSSHDialTimeout() + } + return dialer.DialTimeout(s.srv.Context, params.To.Network(), params.To.String(), dialTimeout) +} + // tryProxyPeering determines whether the node should try to be reached over // a peer proxy. func (s *localSite) tryProxyPeering(params reversetunnelclient.DialParams) bool { @@ -650,16 +716,7 @@ func (s *localSite) getConn(params reversetunnelclient.DialParams) (conn net.Con } // If no tunnel connection was found, dial to the target host. - dialer := proxyutils.DialerFromEnvironment(params.To.String()) - - dialTimeout := apidefaults.DefaultIOTimeout - if cnc, err := s.accessPoint.GetClusterNetworkingConfig(s.srv.Context); err != nil { - s.logger.WarnContext(s.srv.ctx, "Failed to get cluster networking config - using default dial timeout", "error", err) - } else { - dialTimeout = cnc.GetSSHDialTimeout() - } - - conn, directErr = dialer.DialTimeout(s.srv.Context, params.To.Network(), params.To.String(), dialTimeout) + conn, directErr = s.dialDirect(params) if directErr != nil { directMsg := getTunnelErrorMessage(params, "direct dial", directErr) s.logger.DebugContext(s.srv.ctx, "All attempted dial methods failed", diff --git a/lib/reversetunnel/peer.go b/lib/reversetunnel/peer.go index 675ad71e4522a..9c922a20ddd78 100644 --- a/lib/reversetunnel/peer.go +++ b/lib/reversetunnel/peer.go @@ -98,6 +98,14 @@ func (p *clusterPeers) NodeWatcher() (*services.GenericWatcher[types.Server, rea return peer.NodeWatcher() } +func (p *clusterPeers) GitServerWatcher() (*services.GenericWatcher[types.Server, readonly.Server], error) { + peer, err := p.pickPeer() + if err != nil { + return nil, trace.Wrap(err) + } + return peer.GitServerWatcher() +} + func (p *clusterPeers) GetClient() (authclient.ClientI, error) { peer, err := p.pickPeer() if err != nil { @@ -198,6 +206,10 @@ func (s *clusterPeer) NodeWatcher() (*services.GenericWatcher[types.Server, read return nil, trace.ConnectionProblem(nil, "unable to fetch node watcher, this proxy %v has not been discovered yet, try again later", s) } +func (s *clusterPeer) GitServerWatcher() (*services.GenericWatcher[types.Server, readonly.Server], error) { + return nil, trace.ConnectionProblem(nil, "unable to fetch git server watcher, this proxy %v has not been discovered yet, try again later", s) +} + func (s *clusterPeer) GetClient() (authclient.ClientI, error) { return nil, trace.ConnectionProblem(nil, "unable to fetch client, this proxy %v has not been discovered yet, try again later", s) } diff --git a/lib/reversetunnel/remotesite.go b/lib/reversetunnel/remotesite.go index bfb3fa91412b4..9ba165ab47942 100644 --- a/lib/reversetunnel/remotesite.go +++ b/lib/reversetunnel/remotesite.go @@ -170,6 +170,11 @@ func (s *remoteSite) NodeWatcher() (*services.GenericWatcher[types.Server, reado return s.nodeWatcher, nil } +// GitServerWatcher returns the Git server watcher for the remote cluster. +func (s *remoteSite) GitServerWatcher() (*services.GenericWatcher[types.Server, readonly.Server], error) { + return nil, trace.NotImplemented("GitServerWatcher not implemented for remoteSite") +} + func (s *remoteSite) GetClient() (authclient.ClientI, error) { return s.remoteClient, nil } diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index e83efccf31166..c441698c20821 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -205,6 +205,9 @@ type Config struct { // NodeWatcher is a node watcher. NodeWatcher *services.GenericWatcher[types.Server, readonly.Server] + // GitServerWatcher is a Git server watcher. + GitServerWatcher *services.GenericWatcher[types.Server, readonly.Server] + // CertAuthorityWatcher is a cert authority watcher. CertAuthorityWatcher *services.CertAuthorityWatcher @@ -273,6 +276,9 @@ func (cfg *Config) CheckAndSetDefaults() error { if cfg.NodeWatcher == nil { return trace.BadParameter("missing parameter NodeWatcher") } + if cfg.GitServerWatcher == nil { + return trace.BadParameter("missing parameter GitServerWatcher") + } if cfg.CertAuthorityWatcher == nil { return trace.BadParameter("missing parameter CertAuthorityWatcher") } @@ -1271,7 +1277,6 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, } go remoteSite.updateLocks(lockRetry) - return remoteSite, nil } diff --git a/lib/reversetunnelclient/api.go b/lib/reversetunnelclient/api.go index e044bf4beb012..1be0980c03e2d 100644 --- a/lib/reversetunnelclient/api.go +++ b/lib/reversetunnelclient/api.go @@ -125,6 +125,8 @@ type RemoteSite interface { CachingAccessPoint() (authclient.RemoteProxyAccessPoint, error) // NodeWatcher returns the node watcher that maintains the node set for the site NodeWatcher() (*services.GenericWatcher[types.Server, readonly.Server], error) + // GitServerWatcher returns the Git server watcher for the site + GitServerWatcher() (*services.GenericWatcher[types.Server, readonly.Server], error) // GetTunnelsCount returns the amount of active inbound tunnels // from the remote cluster GetTunnelsCount() int diff --git a/lib/service/service.go b/lib/service/service.go index 7fd997e7234f0..26de170be5d6d 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -4390,6 +4390,19 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { return trace.Wrap(err) } + gitServerWatcher, err := services.NewGitServerWatcher(process.ExitContext(), services.GitServerWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Logger: process.logger.With(teleport.ComponentKey, teleport.ComponentProxy), + Client: accessPoint, + MaxStaleness: time.Minute, + }, + GitServerGetter: accessPoint.GitServerReadOnlyClient(), + }) + if err != nil { + return trace.Wrap(err) + } + caWatcher, err := services.NewCertAuthorityWatcher(process.ExitContext(), services.CertAuthorityWatcherConfig{ ResourceWatcherConfig: services.ResourceWatcherConfig{ Component: teleport.ComponentProxy, @@ -4655,6 +4668,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { LockWatcher: lockWatcher, PeerClient: peerClient, NodeWatcher: nodeWatcher, + GitServerWatcher: gitServerWatcher, CertAuthorityWatcher: caWatcher, CircuitBreakerConfig: process.Config.CircuitBreakerConfig, LocalAuthAddresses: utils.NetAddrsToStrings(process.Config.AuthServerAddresses()), diff --git a/lib/services/git_server.go b/lib/services/git_server.go index 17b3d5b93f0ee..aa23dae56ee5d 100644 --- a/lib/services/git_server.go +++ b/lib/services/git_server.go @@ -21,16 +21,12 @@ package services import ( "context" + "github.com/gravitational/teleport/api/client/gitserver" "github.com/gravitational/teleport/api/types" ) // GitServerGetter defines interface for fetching git servers. -type GitServerGetter interface { - // GetGitServer returns Git servers by name. - GetGitServer(ctx context.Context, name string) (types.Server, error) - // ListGitServers returns all Git servers matching filter. - ListGitServers(ctx context.Context, pageSize int, pageToken string) ([]types.Server, string, error) -} +type GitServerGetter gitserver.ReadOnlyClient // GitServers defines an interface for managing git servers. type GitServers interface { diff --git a/lib/services/readonly/readonly.go b/lib/services/readonly/readonly.go index 744f2b4cd3a5c..d2fba33205479 100644 --- a/lib/services/readonly/readonly.go +++ b/lib/services/readonly/readonly.go @@ -432,6 +432,9 @@ type Server interface { GetAWSInstanceID() string // GetAWSAccountID returns the AWS Account ID if this node comes from an EC2 instance. GetAWSAccountID() string + + // GetGitHub returns the GitHub server spec. + GetGitHub() *types.GitHubServerMetadata } // DynamicWindowsDesktop represents a Windows desktop host that is automatically discovered by Windows Desktop Service. diff --git a/lib/services/role.go b/lib/services/role.go index b9e3e04d83816..d9691dc29ec27 100644 --- a/lib/services/role.go +++ b/lib/services/role.go @@ -3581,3 +3581,30 @@ func MarshalRole(role types.Role, opts ...MarshalOption) ([]byte, error) { return nil, trace.BadParameter("unrecognized role version %T", role) } } + +// AuthPreferenceGetter defines an interface for getting the authentication +// preferences. +type AuthPreferenceGetter interface { + // GetAuthPreference fetches the cluster authentication preferences. + GetAuthPreference(ctx context.Context) (types.AuthPreference, error) +} + +// AccessStateFromSSHCertificate populates access state based on user's SSH +// certificate and auth preference. +func AccessStateFromSSHCertificate(ctx context.Context, cert *ssh.Certificate, checker AccessChecker, authPrefGetter AuthPreferenceGetter) (AccessState, error) { + authPref, err := authPrefGetter.GetAuthPreference(ctx) + if err != nil { + return AccessState{}, trace.Wrap(err) + } + state := checker.GetAccessState(authPref) + _, state.MFAVerified = cert.Extensions[teleport.CertExtensionMFAVerified] + // Certain hardware-key based private key policies are treated as MFA verification. + if policyString, ok := cert.Extensions[teleport.CertExtensionPrivateKeyPolicy]; ok { + if keys.PrivateKeyPolicy(policyString).MFAVerified() { + state.MFAVerified = true + } + } + state.EnableDeviceVerification = true + state.DeviceVerified = dtauthz.IsSSHDeviceVerified(cert) + return state, nil +} diff --git a/lib/services/watcher.go b/lib/services/watcher.go index 6bbd71bee0993..62b9e882c68d6 100644 --- a/lib/services/watcher.go +++ b/lib/services/watcher.go @@ -1705,3 +1705,40 @@ func (c *oktaAssignmentCollector) processEventsAndUpdateCurrent(ctx context.Cont } func (*oktaAssignmentCollector) notifyStale() {} + +// GitServerWatcherConfig is the config for Git server watcher. +type GitServerWatcherConfig struct { + GitServerGetter + ResourceWatcherConfig +} + +// NewGitServerWatcher returns a new instance of Git server watcher. +func NewGitServerWatcher(ctx context.Context, cfg GitServerWatcherConfig) (*GenericWatcher[types.Server, readonly.Server], error) { + if cfg.GitServerGetter == nil { + return nil, trace.BadParameter("NodesGetter must be provided") + } + + w, err := NewGenericResourceWatcher(ctx, GenericWatcherConfig[types.Server, readonly.Server]{ + ResourceWatcherConfig: cfg.ResourceWatcherConfig, + ResourceKind: types.KindGitServer, + ResourceGetter: func(ctx context.Context) (all []types.Server, err error) { + var page []types.Server + var token string + for { + page, token, err = cfg.GitServerGetter.ListGitServers(ctx, apidefaults.DefaultChunkSize, token) + if err != nil { + return nil, trace.Wrap(err) + } + all = append(all, page...) + if token == "" { + break + } + } + return all, nil + }, + ResourceKey: types.Server.GetName, + DisableUpdateBroadcast: true, + CloneFunc: types.Server.DeepCopy, + }) + return w, trace.Wrap(err) +} diff --git a/lib/services/watcher_test.go b/lib/services/watcher_test.go index 52988beae8355..c369f5efda054 100644 --- a/lib/services/watcher_test.go +++ b/lib/services/watcher_test.go @@ -1403,3 +1403,66 @@ func newOktaAssignment(t *testing.T, name string) types.OktaAssignment { require.NoError(t, err) return assignment } + +func TestGitServerWatcher(t *testing.T) { + t.Parallel() + + ctx := context.Background() + bk, err := memory.New(memory.Config{}) + require.NoError(t, err) + + gitServerService, err := local.NewGitServerService(bk) + require.NoError(t, err) + w, err := services.NewGitServerWatcher(ctx, services.GitServerWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: "test", + Client: local.NewEventsService(bk), + MaxStaleness: time.Minute, + }, + GitServerGetter: gitServerService, + }) + require.NoError(t, err) + t.Cleanup(w.Close) + require.NoError(t, w.WaitInitialization()) + + // Add some git servers. + servers := make([]types.Server, 0, 5) + for i := 0; i < 5; i++ { + server := newGitServer(t, fmt.Sprintf("org%v", i+1)) + _, err = gitServerService.CreateGitServer(ctx, server) + require.NoError(t, err) + servers = append(servers, server) + } + + require.EventuallyWithT(t, func(t *assert.CollectT) { + filtered, err := w.CurrentResources(ctx) + assert.NoError(t, err) + assert.Len(t, filtered, len(servers)) + }, time.Second, time.Millisecond, "Timeout waiting for watcher to receive nodes.") + + filtered, err := w.CurrentResourcesWithFilter(ctx, func(s readonly.Server) bool { + if github := s.GetGitHub(); github != nil { + return github.Organization == "org1" || github.Organization == "org2" + } + return false + }) + require.NoError(t, err) + require.Len(t, filtered, 2) + + // Delete a server. + require.NoError(t, gitServerService.DeleteGitServer(ctx, servers[0].GetName())) + require.EventuallyWithT(t, func(t *assert.CollectT) { + filtered, err := w.CurrentResources(ctx) + assert.NoError(t, err) + assert.Len(t, filtered, len(servers)-1) + }, time.Second, time.Millisecond, "Timeout waiting for watcher to receive nodes.") + + filtered, err = w.CurrentResourcesWithFilter(ctx, func(s readonly.Server) bool { + if github := s.GetGitHub(); github != nil { + return github.Organization == "org1" + } + return false + }) + require.NoError(t, err) + require.Empty(t, filtered) +} diff --git a/lib/srv/authhandlers.go b/lib/srv/authhandlers.go index 5d6d12ad05dff..6a80ce5c83f3a 100644 --- a/lib/srv/authhandlers.go +++ b/lib/srv/authhandlers.go @@ -36,12 +36,10 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" - "github.com/gravitational/teleport/api/utils/keys" apisshutils "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib/auditd" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/connectmycomputer" - dtauthz "github.com/gravitational/teleport/lib/devicetrust/authz" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/observability/metrics" "github.com/gravitational/teleport/lib/services" @@ -470,7 +468,9 @@ func (h *AuthHandlers) UserKeyAuth(conn ssh.ConnMetadata, key ssh.PublicKey) (*s log.WarnContext(ctx, "Received unexpected cert type", "cert_type", cert.CertType) } - if h.isProxy() { + // Skip RBAC check for proxy or git servers. RBAC check on git servers are + // performed outside this handler. + if h.isProxy() || h.c.Component == teleport.ComponentForwardingGit { return permissions, nil } @@ -645,19 +645,10 @@ func (a *ahLoginChecker) canLoginWithRBAC(cert *ssh.Certificate, ca types.CertAu return trace.Wrap(err) } - authPref, err := a.c.AccessPoint.GetAuthPreference(ctx) + state, err := services.AccessStateFromSSHCertificate(ctx, cert, accessChecker, a.c.AccessPoint) if err != nil { return trace.Wrap(err) } - state := accessChecker.GetAccessState(authPref) - _, state.MFAVerified = cert.Extensions[teleport.CertExtensionMFAVerified] - - // Certain hardware-key based private key policies are treated as MFA verification. - if policyString, ok := cert.Extensions[teleport.CertExtensionPrivateKeyPolicy]; ok { - if keys.PrivateKeyPolicy(policyString).MFAVerified() { - state.MFAVerified = true - } - } // we don't need to check the RBAC for the node if they are only allowed to join sessions if osUser == teleport.SSHSessionJoinPrincipal && @@ -675,9 +666,6 @@ func (a *ahLoginChecker) canLoginWithRBAC(cert *ssh.Certificate, ca types.CertAu } } - state.EnableDeviceVerification = true - state.DeviceVerified = dtauthz.IsSSHDeviceVerified(cert) - // check if roles allow access to server if err := accessChecker.CheckAccess( target, diff --git a/lib/srv/authhandlers_test.go b/lib/srv/authhandlers_test.go index 907a3db97b786..8e009819e2108 100644 --- a/lib/srv/authhandlers_test.go +++ b/lib/srv/authhandlers_test.go @@ -102,38 +102,58 @@ func (m mockConnMetadata) RemoteAddr() net.Addr { func TestRBAC(t *testing.T) { t.Parallel() + node, err := types.NewNode("testie_node", types.SubKindTeleportNode, types.ServerSpecV2{ + Addr: "1.2.3.4:22", + Hostname: "testie", + }, nil) + require.NoError(t, err) + + openSSHNode, err := types.NewNode("openssh", types.SubKindOpenSSHNode, types.ServerSpecV2{ + Addr: "1.2.3.4:22", + Hostname: "openssh", + }, nil) + require.NoError(t, err) + + gitServer, err := types.NewGitHubServer(types.GitHubServerMetadata{ + Integration: "org", + Organization: "org", + }) + require.NoError(t, err) + tests := []struct { name string component string - nodeExists bool - openSSHNode bool + targetServer types.Server assertRBACCheck require.BoolAssertionFunc }{ { name: "teleport node, regular server", component: teleport.ComponentNode, - nodeExists: true, - openSSHNode: false, + targetServer: node, assertRBACCheck: require.True, }, { name: "teleport node, forwarding server", component: teleport.ComponentForwardingNode, - nodeExists: true, - openSSHNode: false, + targetServer: node, assertRBACCheck: require.False, }, { name: "registered openssh node, forwarding server", component: teleport.ComponentForwardingNode, - nodeExists: true, - openSSHNode: true, + targetServer: openSSHNode, assertRBACCheck: require.True, }, { name: "unregistered openssh node, forwarding server", component: teleport.ComponentForwardingNode, - nodeExists: false, + targetServer: nil, + assertRBACCheck: require.False, + }, + { + name: "forwarding git", + component: teleport.ComponentForwardingGit, + targetServer: gitServer, assertRBACCheck: require.False, }, } @@ -176,29 +196,12 @@ func TestRBAC(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // create node resource - var target types.Server - if tt.nodeExists { - n, err := types.NewServer("testie_node", types.KindNode, types.ServerSpecV2{ - Addr: "1.2.3.4:22", - Hostname: "testie", - Version: types.V2, - }) - require.NoError(t, err) - server, ok := n.(*types.ServerV2) - require.True(t, ok) - if tt.openSSHNode { - server.SubKind = types.SubKindOpenSSHNode - } - target = server - } - config := &AuthHandlerConfig{ Server: server, Component: tt.component, Emitter: &eventstest.MockRecorderEmitter{}, AccessPoint: accessPoint, - TargetServer: target, + TargetServer: tt.targetServer, } ah, err := NewAuthHandlers(config) require.NoError(t, err) diff --git a/lib/srv/git/forward.go b/lib/srv/git/forward.go new file mode 100644 index 0000000000000..1e7dae1ada6b3 --- /dev/null +++ b/lib/srv/git/forward.go @@ -0,0 +1,599 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package git + +import ( + "context" + "io" + "log/slog" + "net" + + "github.com/google/uuid" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "golang.org/x/crypto/ssh" + + "github.com/gravitational/teleport" + tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" + "github.com/gravitational/teleport/api/types" + apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/lib/auth/authclient" + "github.com/gravitational/teleport/lib/bpf" + "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/service/servicecfg" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/srv" + "github.com/gravitational/teleport/lib/sshutils" + "github.com/gravitational/teleport/lib/utils" + logutils "github.com/gravitational/teleport/lib/utils/log" +) + +// ForwardServerConfig is the configuration for the ForwardServer. +type ForwardServerConfig struct { + // ParentContext is a parent context, used to signal global + // closure + ParentContext context.Context + // TargetServer is the target server that represents the git-hosting + // service. + TargetServer types.Server + // TargetConn is the TCP connection to the remote host. + TargetConn net.Conn + // AuthClient is a client connected to the Auth server of this local cluster. + AuthClient authclient.ClientI + // AccessPoint is a caching client that provides access to this local cluster. + AccessPoint srv.AccessPoint + // Emitter is audit events emitter + Emitter events.StreamEmitter + // LockWatcher is a lock watcher. + LockWatcher *services.LockWatcher + // HostCertificate is the SSH host certificate this in-memory server presents + // to the client. + HostCertificate ssh.Signer + // SrcAddr is the source address + SrcAddr net.Addr + // DstAddr is the destination address + DstAddr net.Addr + // HostUUID is the UUID of the underlying proxy that the forwarding server + // is running in. + HostUUID string + + // Ciphers is a list of ciphers that the server supports. If omitted, + // the defaults will be used. + Ciphers []string + // KEXAlgorithms is a list of key exchange (KEX) algorithms that the + // server supports. If omitted, the defaults will be used. + KEXAlgorithms []string + // MACAlgorithms is a list of message authentication codes (MAC) that + // the server supports. If omitted the defaults will be used. + MACAlgorithms []string + // FIPS mode means Teleport started in a FedRAMP/FIPS 140-2 compliant + // configuration. + FIPS bool + + // Clock is an optoinal clock to override default real time clock + Clock clockwork.Clock +} + +// CheckAndSetDefaults checks and sets default values for any missing fields. +func (c *ForwardServerConfig) CheckAndSetDefaults() error { + if c.TargetServer == nil { + return trace.BadParameter("missing parameter TargetServer") + } + if c.TargetConn == nil { + return trace.BadParameter("missing parameter TargetConn") + } + if c.AuthClient == nil { + return trace.BadParameter("missing parameter AuthClient") + } + if c.AccessPoint == nil { + return trace.BadParameter("missing parameter AccessPoint") + } + if c.Emitter == nil { + return trace.BadParameter("missing parameter Emitter") + } + if c.HostCertificate == nil { + return trace.BadParameter("missing parameter HostCertificate") + } + if c.ParentContext == nil { + return trace.BadParameter("missing parameter ParentContext") + } + if c.LockWatcher == nil { + return trace.BadParameter("missing parameter LockWatcher") + } + if c.SrcAddr == nil { + return trace.BadParameter("source address required to identify client") + } + if c.DstAddr == nil { + return trace.BadParameter("destination address required to identify client") + } + if c.Clock == nil { + c.Clock = clockwork.NewRealClock() + } + return nil +} + +// ForwardServer is an in-memory SSH server that forwards git commands to remote +// git-hosting services like "github.com". +type ForwardServer struct { + events.StreamEmitter + cfg *ForwardServerConfig + logger *slog.Logger + auth *srv.AuthHandlers + reply *sshutils.Reply + id string + + // serverConn is the server side of the pipe to the client connection. + serverConn net.Conn + // clientConn is the client side of the pipe to the client connection. + clientConn net.Conn + // remoteClient is the client connected to the git-hosting service. + remoteClient *tracessh.Client + + // verifyRemoteHost is a callback to verify remote host like "github.com". + // Can be overridden for tests. Defaults to verifyRemoteHost. + verifyRemoteHost ssh.HostKeyCallback + // makeRemoteSigner generates the client certificate for connecting to the + // remote server. Can be overridden for tests. Defaults to makeRemoteSigner. + makeRemoteSigner func(context.Context, *ForwardServerConfig, srv.IdentityContext) (ssh.Signer, error) +} + +// Dial returns the client connection of the pipe +func (s *ForwardServer) Dial() (net.Conn, error) { + return s.clientConn, nil +} + +// NewForwardServer creates a new in-memory SSH server that forwards git +// commands to remote git-hosting services like "github.com". +func NewForwardServer(cfg *ForwardServerConfig) (*ForwardServer, error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + + serverConn, clientConn, err := utils.DualPipeNetConn(cfg.SrcAddr, cfg.DstAddr) + if err != nil { + return nil, trace.Wrap(err) + } + + logger := slog.With(teleport.ComponentKey, teleport.ComponentForwardingGit, + "src_addr", cfg.SrcAddr.String(), + "dst_addr", cfg.DstAddr.String(), + ) + s := &ForwardServer{ + StreamEmitter: cfg.Emitter, + cfg: cfg, + serverConn: serverConn, + clientConn: clientConn, + logger: logger, + reply: sshutils.NewReply(logger), + id: uuid.NewString(), + verifyRemoteHost: verifyRemoteHost(cfg.TargetServer), + makeRemoteSigner: makeRemoteSigner, + } + // TODO(greedy52) extract common parts from srv.NewAuthHandlers like + // CreateIdentityContext and UserKeyAuth to a common package. + s.auth, err = srv.NewAuthHandlers(&srv.AuthHandlerConfig{ + Server: s, + Component: teleport.ComponentForwardingGit, + Emitter: s.cfg.Emitter, + AccessPoint: cfg.AccessPoint, + TargetServer: cfg.TargetServer, + FIPS: cfg.FIPS, + Clock: cfg.Clock, + }) + if err != nil { + return nil, trace.Wrap(err) + } + return s, nil + +} + +// Serve starts an SSH server that forwards git commands. +func (s *ForwardServer) Serve() { + defer s.close() + s.logger.DebugContext(s.cfg.ParentContext, "Starting forwarding git") + defer s.logger.DebugContext(s.cfg.ParentContext, "Finished forwarding git") + server, err := sshutils.NewServer( + teleport.ComponentForwardingGit, + utils.NetAddr{}, /* empty addr, this is one time use so no use for listener*/ + sshutils.NewChanHandlerFunc(s.onChannel), + sshutils.StaticHostSigners(s.cfg.HostCertificate), + sshutils.AuthMethods{ + PublicKey: s.userKeyAuth, + }, + sshutils.SetFIPS(s.cfg.FIPS), + sshutils.SetCiphers(s.cfg.Ciphers), + sshutils.SetKEXAlgorithms(s.cfg.KEXAlgorithms), + sshutils.SetMACAlgorithms(s.cfg.MACAlgorithms), + sshutils.SetClock(s.cfg.Clock), + sshutils.SetNewConnHandler(sshutils.NewConnHandlerFunc(s.onConnection)), + ) + if err != nil { + s.logger.ErrorContext(s.cfg.ParentContext, "Failed to create git forward server", "error", err) + return + } + server.HandleConnection(s.serverConn) +} + +func (s *ForwardServer) close() { + if err := s.serverConn.Close(); err != nil && !utils.IsOKNetworkError(err) { + s.logger.WarnContext(s.cfg.ParentContext, "Failed to close server conn", "error", err) + } + if err := s.clientConn.Close(); err != nil && !utils.IsOKNetworkError(err) { + s.logger.WarnContext(s.cfg.ParentContext, "Failed to close client conn", "error", err) + } + if err := s.cfg.TargetConn.Close(); err != nil && !utils.IsOKNetworkError(err) { + s.logger.WarnContext(s.cfg.ParentContext, "Failed to close target conn", "error", err) + } +} + +func (s *ForwardServer) userKeyAuth(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + cert, ok := key.(*ssh.Certificate) + if !ok { + return nil, trace.BadParameter("unsupported key type") + } + if len(cert.Extensions[teleport.CertExtensionGitHubUserID]) == 0 { + return nil, trace.BadParameter("missing GitHub user ID") + } + + // Verify incoming user is "git" and override it with any valid principle + // to bypass principle validation. + if conn.User() != gitUser { + return nil, trace.BadParameter("only git is expected as user for git connections") + } + if len(cert.ValidPrincipals) > 0 { + conn = sshutils.NewSSHConnMetadataWithUser(conn, cert.ValidPrincipals[0]) + } + + // Use auth.UserKeyAuth to verify user cert is signed by UserCA. + permissions, err := s.auth.UserKeyAuth(conn, key) + if err != nil { + return nil, trace.Wrap(err) + } + + // Check RBAC on the git server resource (aka s.cfg.TargetServer). + if err := s.checkUserAccess(cert); err != nil { + s.logger.ErrorContext(s.Context(), "Permission denied", + "error", err, + "local_addr", logutils.StringerAttr(conn.LocalAddr()), + "remote_addr", logutils.StringerAttr(conn.RemoteAddr()), + "key", key.Type(), + "fingerprint", sshutils.Fingerprint(key), + "user", cert.KeyId, + ) + return nil, trace.Wrap(err) + } + return permissions, nil +} + +func (s *ForwardServer) checkUserAccess(cert *ssh.Certificate) error { + clusterName, err := s.cfg.AccessPoint.GetClusterName() + if err != nil { + return trace.Wrap(err) + } + accessInfo, err := services.AccessInfoFromLocalCertificate(cert) + if err != nil { + return trace.Wrap(err) + } + accessChecker, err := services.NewAccessChecker(accessInfo, clusterName.GetClusterName(), s.cfg.AccessPoint) + if err != nil { + return trace.Wrap(err) + } + state, err := services.AccessStateFromSSHCertificate(s.Context(), cert, accessChecker, s.cfg.AccessPoint) + if err != nil { + return trace.Wrap(err) + } + return trace.Wrap(accessChecker.CheckAccess(s.cfg.TargetServer, state)) +} + +func (s *ForwardServer) onConnection(ctx context.Context, ccx *sshutils.ConnectionContext) (context.Context, error) { + s.logger.Log(ctx, logutils.TraceLevel, "Handling new connection") + + identityCtx, err := s.auth.CreateIdentityContext(ccx.ServerConn) + if err != nil { + return nil, trace.Wrap(err) + } + + if err := s.initRemoteConn(ctx, ccx, identityCtx); err != nil { + s.logger.DebugContext(ctx, "onConnection failed", "error", err) + return ctx, trace.Wrap(err) + } + + // TODO(greedy52) decouple from srv.NewServerContext. We only need + // connection monitoring. + serverCtx, err := srv.NewServerContext(ctx, ccx, s, identityCtx) + if err != nil { + return nil, trace.Wrap(err) + } + + s.logger.Log(ctx, logutils.TraceLevel, "New connection accepted") + ccx.AddCloser(serverCtx) + return ctx, nil +} + +func (s *ForwardServer) onChannel(ctx context.Context, ccx *sshutils.ConnectionContext, nch ssh.NewChannel) { + s.logger.DebugContext(ctx, "Handling channel request", "channel", nch.ChannelType()) + + // Only expecting a session to execute a command. + if nch.ChannelType() != teleport.ChanSession { + s.reply.RejectUnknownChannel(ctx, nch) + return + } + + if s.remoteClient == nil { + s.reply.RejectWithNewRemoteSessionError(ctx, nch, trace.NotFound("missing remote client")) + return + } + remoteSession, err := s.remoteClient.NewSession(ctx) + if err != nil { + s.reply.RejectWithNewRemoteSessionError(ctx, nch, err) + return + } + defer remoteSession.Close() + + ch, in, err := nch.Accept() + if err != nil { + s.reply.RejectWithAcceptError(ctx, nch, err) + return + } + defer ch.Close() + + sctx := newSessionContext(ch, remoteSession) + for { + select { + case req := <-in: + if req == nil { + s.logger.DebugContext(ctx, "Client disconnected", "remote_addr", ccx.ServerConn.RemoteAddr()) + return + } + + ok, err := s.dispatch(ctx, sctx, req) + if err != nil { + s.reply.ReplyError(ctx, req, err) + return + } + s.reply.ReplyRequest(ctx, req, ok, nil) + + case execErr := <-sctx.waitExec: + code := sshutils.ExitCodeFromExecError(execErr) + s.logger.DebugContext(ctx, "Exec request complete", "code", code) + s.reply.SendExitStatus(ctx, ch, code) + return + + case <-ctx.Done(): + return + } + } +} + +type sessionContext struct { + channel ssh.Channel + remoteSession *tracessh.Session + waitExec chan error +} + +func newSessionContext(ch ssh.Channel, remoteSession *tracessh.Session) *sessionContext { + return &sessionContext{ + channel: ch, + remoteSession: remoteSession, + waitExec: make(chan error, 1), + } +} + +// dispatch executes an incoming request. If successful, it returns the ok value +// for the reply. Otherwise, it returns the error it encountered. +func (s *ForwardServer) dispatch(ctx context.Context, sctx *sessionContext, req *ssh.Request) (bool, error) { + s.logger.DebugContext(ctx, "Dispatching client request", "request_type", req.Type) + + switch req.Type { + case tracessh.EnvsRequest: + s.logger.DebugContext(ctx, "Ignored request", "request_type", req.Type) + return true, nil + case sshutils.ExecRequest: + return true, trace.Wrap(s.handleExec(ctx, sctx, req)) + case sshutils.EnvRequest: + return true, trace.Wrap(s.handleEnv(ctx, sctx, req)) + default: + s.logger.WarnContext(ctx, "Received unsupported SSH request", "request_type", req.Type) + return false, nil + } +} + +// handleExec proxies the Git command between client and the target server. +func (s *ForwardServer) handleExec(ctx context.Context, sctx *sessionContext, req *ssh.Request) error { + var r sshutils.ExecReq + if err := ssh.Unmarshal(req.Payload, &r); err != nil { + return trace.Wrap(err, "failed to unmarshal exec request") + } + + // TODO(greedy52) enable command recorder for audit log + sctx.remoteSession.Stdout = sctx.channel + sctx.remoteSession.Stderr = sctx.channel.Stderr() + remoteStdin, err := sctx.remoteSession.StdinPipe() + if err != nil { + return trace.Wrap(err, "failed to open remote session") + } + go func() { + defer remoteStdin.Close() + if _, err := io.Copy(remoteStdin, sctx.channel); err != nil { + s.logger.WarnContext(ctx, "Failed to copy git command stdin", "error", err) + } + }() + + if err := sctx.remoteSession.Start(ctx, r.Command); err != nil { + return trace.Wrap(err, "failed to start git command") + } + + go func() { + execErr := sctx.remoteSession.Wait() + sctx.waitExec <- execErr + }() + return nil +} + +// handleEnv sets env on the target server. +func (s *ForwardServer) handleEnv(ctx context.Context, sctx *sessionContext, req *ssh.Request) error { + var e sshutils.EnvReqParams + if err := ssh.Unmarshal(req.Payload, &e); err != nil { + return trace.Wrap(err) + } + s.logger.DebugContext(ctx, "Setting env on remote Git server", "name", e.Name, "value", e.Value) + err := sctx.remoteSession.Setenv(ctx, e.Name, e.Value) + if err != nil { + s.logger.WarnContext(ctx, "Failed to set env on remote session", "error", err, "request", e) + } + return nil +} + +func (s *ForwardServer) initRemoteConn(ctx context.Context, ccx *sshutils.ConnectionContext, identityCtx srv.IdentityContext) error { + netConfig, err := s.cfg.AccessPoint.GetClusterNetworkingConfig(s.cfg.ParentContext) + if err != nil { + return trace.Wrap(err) + } + signer, err := s.makeRemoteSigner(ctx, s.cfg, identityCtx) + if err != nil { + return trace.Wrap(err) + } + clientConfig := &ssh.ClientConfig{ + User: gitUser, + Auth: []ssh.AuthMethod{ + ssh.PublicKeys(signer), + }, + HostKeyCallback: s.verifyRemoteHost, + Timeout: netConfig.GetSSHDialTimeout(), + } + clientConfig.Ciphers = s.cfg.Ciphers + clientConfig.KeyExchanges = s.cfg.KEXAlgorithms + clientConfig.MACs = s.cfg.MACAlgorithms + + s.remoteClient, err = tracessh.NewClientConnWithDeadline( + s.cfg.ParentContext, + s.cfg.TargetConn, + s.cfg.DstAddr.String(), + clientConfig, + ) + if err != nil { + return trace.Wrap(err) + } + ccx.AddCloser(s.remoteClient) + return nil +} + +func makeRemoteSigner(ctx context.Context, cfg *ForwardServerConfig, identityCtx srv.IdentityContext) (ssh.Signer, error) { + switch cfg.TargetServer.GetSubKind() { + case types.SubKindGitHub: + return MakeGitHubSigner(ctx, GitHubSignerConfig{ + Server: cfg.TargetServer, + TeleportUser: identityCtx.TeleportUser, + IdentityExpires: identityCtx.CertValidBefore, + GitHubUserID: identityCtx.Certificate.Extensions[teleport.CertExtensionGitHubUserID], + AuthPreferenceGetter: cfg.AccessPoint, + GitHubUserCertGenerator: cfg.AuthClient.IntegrationsClient(), + Clock: cfg.Clock, + }) + default: + return nil, trace.BadParameter("unsupported subkind %q", cfg.TargetServer.GetSubKind()) + } +} + +func verifyRemoteHost(targetServer types.Server) ssh.HostKeyCallback { + return func(hostname string, remote net.Addr, key ssh.PublicKey) error { + switch targetServer.GetSubKind() { + case types.SubKindGitHub: + return VerifyGitHubHostKey(hostname, remote, key) + default: + return trace.BadParameter("unsupported subkind %q", targetServer.GetSubKind()) + } + } +} + +// Below functions implement srv.Server so git.ForwardServer can be used for +// srv.NewServerContext and srv.NewAuthHandlers. +// TODO(greedy52) decouple from srv.Server. + +func (s *ForwardServer) Context() context.Context { + return s.cfg.ParentContext +} +func (s *ForwardServer) TargetMetadata() apievents.ServerMetadata { + return apievents.ServerMetadata{ + ServerVersion: teleport.Version, + ServerNamespace: s.cfg.TargetServer.GetNamespace(), + ServerAddr: s.cfg.DstAddr.String(), + ServerHostname: s.cfg.TargetServer.GetHostname(), + ForwardedBy: s.cfg.HostUUID, + ServerSubKind: s.cfg.TargetServer.GetSubKind(), + } +} +func (s *ForwardServer) GetInfo() types.Server { + return s.cfg.TargetServer +} +func (s *ForwardServer) ID() string { + return s.id +} +func (s *ForwardServer) HostUUID() string { + return s.cfg.HostUUID +} +func (s *ForwardServer) GetNamespace() string { + return s.cfg.TargetServer.GetNamespace() +} +func (s *ForwardServer) AdvertiseAddr() string { + return s.clientConn.RemoteAddr().String() +} +func (s *ForwardServer) Component() string { + return teleport.ComponentForwardingGit +} +func (s *ForwardServer) PermitUserEnvironment() bool { + return false +} +func (s *ForwardServer) GetAccessPoint() srv.AccessPoint { + return s.cfg.AccessPoint +} +func (s *ForwardServer) GetDataDir() string { + return "" +} +func (s *ForwardServer) GetPAM() (*servicecfg.PAMConfig, error) { + return nil, trace.NotImplemented("not supported for git forward server") +} +func (s *ForwardServer) GetClock() clockwork.Clock { + return s.cfg.Clock +} +func (s *ForwardServer) UseTunnel() bool { + return false +} +func (s *ForwardServer) GetBPF() bpf.BPF { + return nil +} +func (s *ForwardServer) GetUserAccountingPaths() (utmp, wtmp, btmp string) { + return +} +func (s *ForwardServer) GetLockWatcher() *services.LockWatcher { + return s.cfg.LockWatcher +} +func (s *ForwardServer) GetCreateHostUser() bool { + return false +} +func (s *ForwardServer) GetHostUsers() srv.HostUsers { + return nil +} +func (s *ForwardServer) GetHostSudoers() srv.HostSudoers { + return nil +} + +const ( + gitUser = "git" +) diff --git a/lib/srv/git/forward_test.go b/lib/srv/git/forward_test.go new file mode 100644 index 0000000000000..c9e2ca0fcc1c6 --- /dev/null +++ b/lib/srv/git/forward_test.go @@ -0,0 +1,363 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package git + +import ( + "context" + "io" + "log/slog" + "net" + "os" + "testing" + "time" + + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" + + "github.com/gravitational/teleport/api/constants" + tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/types/wrappers" + apisshutils "github.com/gravitational/teleport/api/utils/sshutils" + "github.com/gravitational/teleport/lib/auth/authclient" + "github.com/gravitational/teleport/lib/auth/testauthority" + "github.com/gravitational/teleport/lib/backend/memory" + "github.com/gravitational/teleport/lib/cryptosuites" + "github.com/gravitational/teleport/lib/events/eventstest" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/local" + "github.com/gravitational/teleport/lib/srv" + "github.com/gravitational/teleport/lib/sshca" + "github.com/gravitational/teleport/lib/sshutils" + "github.com/gravitational/teleport/lib/utils" +) + +func TestMain(m *testing.M) { + utils.InitLoggerForTests() + os.Exit(m.Run()) +} + +func TestForwardServer(t *testing.T) { + caSigner, err := apisshutils.MakeTestSSHCA() + require.NoError(t, err) + userCert := makeUserCert(t, caSigner) + + tests := []struct { + name string + allowedGitHubOrg string + clientLogin string + verifyRemoteHost ssh.HostKeyCallback + wantNewClientError bool + verifyWithClient func(t *testing.T, ctx context.Context, client *tracessh.Client, m *mockGitHostingService) + }{ + { + name: "success", + allowedGitHubOrg: "*", + clientLogin: "git", + verifyRemoteHost: ssh.InsecureIgnoreHostKey(), + wantNewClientError: false, + verifyWithClient: func(t *testing.T, ctx context.Context, client *tracessh.Client, m *mockGitHostingService) { + session, err := client.NewSession(ctx) + require.NoError(t, err) + defer session.Close() + + gitCommand := "git-upload-pack 'org/my-repo.git'" + session.Stderr = io.Discard + session.Stdout = io.Discard + err = session.Run(ctx, gitCommand) + require.NoError(t, err) + require.Equal(t, gitCommand, m.receivedExec.Command) + }, + }, + { + name: "failed RBAC", + allowedGitHubOrg: "no-org-allowed", + clientLogin: "git", + verifyRemoteHost: ssh.InsecureIgnoreHostKey(), + wantNewClientError: true, + }, + { + name: "failed client login check", + allowedGitHubOrg: "*", + clientLogin: "not-git", + verifyRemoteHost: ssh.InsecureIgnoreHostKey(), + wantNewClientError: true, + }, + { + name: "failed remote host check", + allowedGitHubOrg: "*", + clientLogin: "git", + verifyRemoteHost: func(string, net.Addr, ssh.PublicKey) error { + return trace.AccessDenied("fake a remote host check error") + }, + verifyWithClient: func(t *testing.T, ctx context.Context, client *tracessh.Client, m *mockGitHostingService) { + // Connection is accepted but anything following fails. + _, err := client.NewSession(ctx) + require.Error(t, err) + }, + }, + { + name: "invalid channel type", + allowedGitHubOrg: "*", + clientLogin: "git", + verifyRemoteHost: ssh.InsecureIgnoreHostKey(), + verifyWithClient: func(t *testing.T, ctx context.Context, client *tracessh.Client, m *mockGitHostingService) { + _, _, err := client.OpenChannel(ctx, "unknown", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "unknown channel type") + }, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + mockEmitter := &eventstest.MockRecorderEmitter{} + mockGitService := newMockGitHostingService(t, caSigner) + hostCert, err := apisshutils.MakeRealHostCert(caSigner) + require.NoError(t, err) + targetConn, err := net.Dial("tcp", mockGitService.Addr()) + require.NoError(t, err) + + s, err := NewForwardServer(&ForwardServerConfig{ + TargetServer: makeGitServer(t, "org"), + TargetConn: targetConn, + AuthClient: mockAuthClient{}, + AccessPoint: mockAccessPoint{ + ca: caSigner, + allowedGitHubOrg: test.allowedGitHubOrg, + }, + Emitter: mockEmitter, + HostCertificate: hostCert, + ParentContext: ctx, + LockWatcher: makeLockWatcher(t), + SrcAddr: utils.MustParseAddr("127.0.0.1:12345"), + DstAddr: utils.MustParseAddr("127.0.0.1:2222"), + }) + require.NoError(t, err) + + s.verifyRemoteHost = test.verifyRemoteHost + s.makeRemoteSigner = func(context.Context, *ForwardServerConfig, srv.IdentityContext) (ssh.Signer, error) { + // mock server does not validate this, just put whatever. + return userCert, nil + } + go s.Serve() + + clientDialConn, err := s.Dial() + require.NoError(t, err) + + conn, chCh, reqCh, err := ssh.NewClientConn( + clientDialConn, + "127.0.0.1:222", + &ssh.ClientConfig{ + User: test.clientLogin, + Auth: []ssh.AuthMethod{ + ssh.PublicKeys(userCert), + }, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 5 * time.Second, + }, + ) + if test.wantNewClientError { + require.Error(t, err) + return + } + require.NoError(t, err) + client := tracessh.NewClient(conn, chCh, reqCh) + defer client.Close() + + test.verifyWithClient(t, ctx, client, mockGitService) + }) + } + +} + +func makeUserCert(t *testing.T, caSigner ssh.Signer) ssh.Signer { + t.Helper() + keygen := testauthority.New() + clientPrivateKey, err := cryptosuites.GeneratePrivateKeyWithAlgorithm(cryptosuites.ECDSAP256) + require.NoError(t, err) + clientCertBytes, err := keygen.GenerateUserCert(sshca.UserCertificateRequest{ + CASigner: caSigner, + PublicUserKey: clientPrivateKey.MarshalSSHPublicKey(), + CertificateFormat: constants.CertificateFormatStandard, + Identity: sshca.Identity{ + Username: "alice", + AllowedLogins: []string{"does-not-matter"}, + GitHubUserID: "1234567", + Traits: wrappers.Traits{}, + Roles: []string{"editor"}, + }, + }) + require.NoError(t, err) + clientAuthorizedCert, _, _, _, err := ssh.ParseAuthorizedKey(clientCertBytes) + require.NoError(t, err) + clientSigner, err := apisshutils.SSHSigner(clientAuthorizedCert.(*ssh.Certificate), clientPrivateKey) + require.NoError(t, err) + return clientSigner +} + +func makeLockWatcher(t *testing.T) *services.LockWatcher { + t.Helper() + backend, err := memory.New(memory.Config{}) + require.NoError(t, err) + lockWatcher, err := services.NewLockWatcher(context.Background(), services.LockWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: "git.test", + Client: local.NewEventsService(backend), + }, + LockGetter: local.NewAccessService(backend), + }) + require.NoError(t, err) + return lockWatcher +} + +func makeGitServer(t *testing.T, org string) types.Server { + t.Helper() + server, err := types.NewGitHubServer(types.GitHubServerMetadata{ + Integration: org, + Organization: org, + }) + require.NoError(t, err) + return server +} + +type mockGitHostingService struct { + *sshutils.Server + *sshutils.Reply + receivedExec sshutils.ExecReq +} + +func newMockGitHostingService(t *testing.T, caSigner ssh.Signer) *mockGitHostingService { + t.Helper() + hostCert, err := apisshutils.MakeRealHostCert(caSigner) + require.NoError(t, err) + m := &mockGitHostingService{ + Reply: &sshutils.Reply{}, + } + server, err := sshutils.NewServer( + "git.test", + utils.NetAddr{AddrNetwork: "tcp", Addr: "localhost:0"}, + m, + sshutils.StaticHostSigners(hostCert), + sshutils.AuthMethods{NoClient: true}, + sshutils.SetNewConnHandler(m), + ) + require.NoError(t, err) + require.NoError(t, server.Start()) + t.Cleanup(func() { + server.Close() + }) + m.Server = server + return m +} +func (m *mockGitHostingService) HandleNewConn(ctx context.Context, ccx *sshutils.ConnectionContext) (context.Context, error) { + slog.DebugContext(ctx, "mock git service receives new connection") + return ctx, nil +} +func (m *mockGitHostingService) HandleNewChan(ctx context.Context, ccx *sshutils.ConnectionContext, nch ssh.NewChannel) { + slog.DebugContext(ctx, "mock git service receives new chan") + ch, in, err := nch.Accept() + if err != nil { + m.RejectWithAcceptError(ctx, nch, err) + return + } + defer ch.Close() + for { + select { + case req := <-in: + if req == nil { + return + } + + if err := ssh.Unmarshal(req.Payload, &m.receivedExec); err != nil { + m.ReplyError(ctx, req, err) + return + } + if req.WantReply { + m.ReplyRequest(ctx, req, true, nil) + } + slog.DebugContext(ctx, "mock git service receives new exec request", "req", m.receivedExec) + m.SendExitStatus(ctx, ch, 0) + return + + case <-ctx.Done(): + return + } + } +} + +type mockAuthClient struct { + authclient.ClientI +} + +type mockAccessPoint struct { + srv.AccessPoint + ca ssh.Signer + allowedGitHubOrg string +} + +func (m mockAccessPoint) GetClusterName(...services.MarshalOption) (types.ClusterName, error) { + return types.NewClusterName(types.ClusterNameSpecV2{ + ClusterName: "git.test", + ClusterID: "git.test", + }) +} +func (m mockAccessPoint) GetClusterNetworkingConfig(context.Context) (types.ClusterNetworkingConfig, error) { + return types.DefaultClusterNetworkingConfig(), nil +} +func (m mockAccessPoint) GetSessionRecordingConfig(context.Context) (types.SessionRecordingConfig, error) { + return types.DefaultSessionRecordingConfig(), nil +} +func (m mockAccessPoint) GetAuthPreference(context.Context) (types.AuthPreference, error) { + return types.DefaultAuthPreference(), nil +} +func (m mockAccessPoint) GetRole(_ context.Context, name string) (types.Role, error) { + return types.NewRole(name, types.RoleSpecV6{ + Allow: types.RoleConditions{ + GitHubPermissions: []types.GitHubPermission{{ + Organizations: []string{m.allowedGitHubOrg}, + }}, + }, + }) +} +func (m mockAccessPoint) GetCertAuthorities(_ context.Context, caType types.CertAuthType, _ bool) ([]types.CertAuthority, error) { + if m.ca == nil { + return nil, trace.NotFound("no certificate authority found") + } + ca, err := types.NewCertAuthority(types.CertAuthoritySpecV2{ + Type: caType, + ClusterName: "git.test", + ActiveKeys: types.CAKeySet{ + SSH: []*types.SSHKeyPair{{ + PublicKey: ssh.MarshalAuthorizedKey(m.ca.PublicKey()), + }}, + }, + }) + if err != nil { + return nil, trace.Wrap(err) + } + return []types.CertAuthority{ca}, nil +} diff --git a/lib/srv/git/github.go b/lib/srv/git/github.go new file mode 100644 index 0000000000000..416d47b356c5c --- /dev/null +++ b/lib/srv/git/github.go @@ -0,0 +1,160 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package git + +import ( + "context" + "net" + "slices" + "time" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "golang.org/x/crypto/ssh" + "google.golang.org/grpc" + "google.golang.org/protobuf/types/known/durationpb" + + integrationv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/integration/v1" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/cryptosuites" + "github.com/gravitational/teleport/lib/sshutils" +) + +// knownGithubDotComFingerprints contains a list of known GitHub fingerprints. +// +// https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/githubs-ssh-key-fingerprints +// +// TODO(greedy52) these fingerprints can change (e.g. GitHub changed its RSA +// key in 2023 because of an incident). Instead of hard-coding the values, we +// should try to periodically (e.g. once per day) poll them from the API. +var knownGithubDotComFingerprints = []string{ + "SHA256:uNiVztksCsDhcc0u9e8BujQXVUpKZIDTMczCvj3tD2s", + "SHA256:p2QAMXNIC1TJYWeIOttrVc98/R1BUFWu3/LiyKgUfQM", + "SHA256:+DiY3wvvV6TuJJhbpZisF/zLDA0zPMSvHdkr4UvCOqU", +} + +// VerifyGitHubHostKey is an ssh.HostKeyCallback that verifies the host key +// belongs to "github.com". +func VerifyGitHubHostKey(_ string, _ net.Addr, key ssh.PublicKey) error { + actualFingerprint := ssh.FingerprintSHA256(key) + if slices.Contains(knownGithubDotComFingerprints, actualFingerprint) { + return nil + } + return trace.BadParameter("cannot verify github.com: unknown fingerprint %v algo %v", actualFingerprint, key.Type()) +} + +// AuthPreferenceGetter is an interface for retrieving the current configured +// cluster auth preference. +type AuthPreferenceGetter interface { + // GetAuthPreference returns the current cluster auth preference. + GetAuthPreference(context.Context) (types.AuthPreference, error) +} + +// GitHubUserCertGenerator is an interface to generating user certs for +// connecting to GitHub. +type GitHubUserCertGenerator interface { + // GenerateGitHubUserCert signs an SSH certificate for GitHub integration. + GenerateGitHubUserCert(context.Context, *integrationv1.GenerateGitHubUserCertRequest, ...grpc.CallOption) (*integrationv1.GenerateGitHubUserCertResponse, error) +} + +// GitHubSignerConfig is the config for MakeGitHubSigner. +type GitHubSignerConfig struct { + // Server is the target Git server. + Server types.Server + // GitHubUserID is the ID of the GitHub user to impersonate. + GitHubUserID string + // TeleportUser is the Teleport username + TeleportUser string + // AuthPreferenceGetter is used to get auth preference. + AuthPreferenceGetter AuthPreferenceGetter + // GitHubUserCertGenerator generate + GitHubUserCertGenerator GitHubUserCertGenerator + // IdentityExpires is the time that the identity should expire. + IdentityExpires time.Time + // Clock is used to control time. + Clock clockwork.Clock +} + +func (c *GitHubSignerConfig) CheckAndSetDefaults() error { + if c.Server == nil { + return trace.BadParameter("missing target server") + } + if c.Server.GetGitHub() == nil { + return trace.BadParameter("missing GitHub spec") + } + if c.GitHubUserID == "" { + return trace.BadParameter("missing GitHubUserID") + } + if c.TeleportUser == "" { + return trace.BadParameter("missing TeleportUser") + } + if c.AuthPreferenceGetter == nil { + return trace.BadParameter("missing AuthPreferenceGetter") + } + if c.GitHubUserCertGenerator == nil { + return trace.BadParameter("missing GitHubUserCertGenerator") + } + if c.IdentityExpires.IsZero() { + return trace.BadParameter("missing IdentityExpires") + } + if c.Clock == nil { + c.Clock = clockwork.NewRealClock() + } + return nil +} + +func (c *GitHubSignerConfig) certTTL() time.Duration { + userTTL := c.IdentityExpires.Sub(c.Clock.Now()) + return min(userTTL, defaultGitHubUserCertTTL) +} + +// MakeGitHubSigner generates an ssh.Signer that can impersonate a GitHub user +// to connect to GitHub. +func MakeGitHubSigner(ctx context.Context, config GitHubSignerConfig) (ssh.Signer, error) { + if err := config.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + + algo, err := cryptosuites.AlgorithmForKey(ctx, + cryptosuites.GetCurrentSuiteFromAuthPreference(config.AuthPreferenceGetter), + cryptosuites.GitClient) + if err != nil { + return nil, trace.Wrap(err, "getting signing algorithm") + } + sshKey, err := cryptosuites.GeneratePrivateKeyWithAlgorithm(algo) + if err != nil { + return nil, trace.Wrap(err, "generating SSH key") + } + resp, err := config.GitHubUserCertGenerator.GenerateGitHubUserCert(ctx, &integrationv1.GenerateGitHubUserCertRequest{ + Integration: config.Server.GetGitHub().Integration, + PublicKey: sshKey.MarshalSSHPublicKey(), + UserId: config.GitHubUserID, + KeyId: config.TeleportUser, + Ttl: durationpb.New(config.certTTL()), + }) + if err != nil { + return nil, trace.Wrap(err) + } + + // TODO(greedy52) cache it for TTL. + signer, err := sshutils.NewSigner(sshKey.PrivateKeyPEM(), resp.AuthorizedKey) + return signer, trace.Wrap(err) +} + +const defaultGitHubUserCertTTL = 10 * time.Minute diff --git a/lib/srv/git/github_test.go b/lib/srv/git/github_test.go new file mode 100644 index 0000000000000..916c87afb17e9 --- /dev/null +++ b/lib/srv/git/github_test.go @@ -0,0 +1,145 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package git + +import ( + "context" + "crypto/rand" + "testing" + "time" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" + "google.golang.org/grpc" + + integrationv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/integration/v1" + "github.com/gravitational/teleport/api/types" + apisshutils "github.com/gravitational/teleport/api/utils/sshutils" +) + +type fakeAuthPreferenceGetter struct { +} + +func (f fakeAuthPreferenceGetter) GetAuthPreference(context.Context) (types.AuthPreference, error) { + return types.DefaultAuthPreference(), nil +} + +type fakeGitHubUserCertGenerator struct { + clock clockwork.Clock + checkTTL time.Duration +} + +func (f fakeGitHubUserCertGenerator) GenerateGitHubUserCert(_ context.Context, input *integrationv1.GenerateGitHubUserCertRequest, _ ...grpc.CallOption) (*integrationv1.GenerateGitHubUserCertResponse, error) { + if f.checkTTL != 0 && f.checkTTL != input.Ttl.AsDuration() { + return nil, trace.CompareFailed("expect ttl %v but got %v", f.checkTTL, input.Ttl.AsDuration()) + } + + caSigner, err := apisshutils.MakeTestSSHCA() + if err != nil { + return nil, trace.Wrap(err) + } + pubKey, _, _, _, err := ssh.ParseAuthorizedKey(input.PublicKey) + if err != nil { + return nil, trace.Wrap(err) + } + cert := &ssh.Certificate{ + // we have to use key id to identify teleport user + KeyId: input.KeyId, + Key: pubKey, + ValidAfter: uint64(f.clock.Now().Add(-time.Minute).Unix()), + ValidBefore: uint64(f.clock.Now().Add(input.Ttl.AsDuration()).Unix()), + CertType: ssh.UserCert, + } + if err := cert.SignCert(rand.Reader, caSigner); err != nil { + return nil, trace.Wrap(err) + } + return &integrationv1.GenerateGitHubUserCertResponse{ + AuthorizedKey: ssh.MarshalAuthorizedKey(cert), + }, nil +} + +func TestMakeGitHubSigner(t *testing.T) { + clock := clockwork.NewFakeClock() + server := makeGitServer(t, "org") + + tests := []struct { + name string + config GitHubSignerConfig + checkError require.ErrorAssertionFunc + }{ + { + name: "success", + config: GitHubSignerConfig{ + Server: server, + GitHubUserID: "1234567", + TeleportUser: "alice", + AuthPreferenceGetter: fakeAuthPreferenceGetter{}, + GitHubUserCertGenerator: fakeGitHubUserCertGenerator{ + clock: clock, + checkTTL: defaultGitHubUserCertTTL, + }, + IdentityExpires: clock.Now().Add(time.Hour), + Clock: clock, + }, + checkError: require.NoError, + }, + { + name: "success short ttl", + config: GitHubSignerConfig{ + Server: server, + GitHubUserID: "1234567", + TeleportUser: "alice", + AuthPreferenceGetter: fakeAuthPreferenceGetter{}, + GitHubUserCertGenerator: fakeGitHubUserCertGenerator{ + clock: clock, + checkTTL: time.Minute, + }, + IdentityExpires: clock.Now().Add(time.Minute), + Clock: clock, + }, + checkError: require.NoError, + }, + { + name: "no GitHubUserID", + config: GitHubSignerConfig{ + Server: server, + TeleportUser: "alice", + AuthPreferenceGetter: fakeAuthPreferenceGetter{}, + GitHubUserCertGenerator: fakeGitHubUserCertGenerator{ + clock: clock, + checkTTL: time.Minute, + }, + IdentityExpires: clock.Now().Add(time.Minute), + Clock: clock, + }, + checkError: func(t require.TestingT, err error, i ...interface{}) { + require.True(t, trace.IsBadParameter(err), i...) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + _, err := MakeGitHubSigner(context.Background(), test.config) + test.checkError(t, err) + }) + } +} diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index eb7d3283b4508..029c5a0b1dc0d 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -1678,6 +1678,7 @@ func TestProxyRoundRobin(t *testing.T) { Emitter: proxyClient, LockWatcher: lockWatcher, NodeWatcher: nodeWatcher, + GitServerWatcher: newGitServerWatcher(ctx, t, proxyClient), CertAuthorityWatcher: caWatcher, CircuitBreakerConfig: breaker.NoopBreakerConfig(), }) @@ -1813,6 +1814,7 @@ func TestProxyDirectAccess(t *testing.T) { Emitter: proxyClient, LockWatcher: lockWatcher, NodeWatcher: nodeWatcher, + GitServerWatcher: newGitServerWatcher(ctx, t, proxyClient), CertAuthorityWatcher: caWatcher, CircuitBreakerConfig: breaker.NoopBreakerConfig(), }) @@ -2499,6 +2501,7 @@ func TestParseSubsystemRequest(t *testing.T) { Emitter: proxyClient, LockWatcher: lockWatcher, NodeWatcher: nodeWatcher, + GitServerWatcher: newGitServerWatcher(ctx, t, proxyClient), CertAuthorityWatcher: caWatcher, }) require.NoError(t, err) @@ -2760,6 +2763,7 @@ func TestIgnorePuTTYSimpleChannel(t *testing.T) { Emitter: proxyClient, LockWatcher: lockWatcher, NodeWatcher: nodeWatcher, + GitServerWatcher: newGitServerWatcher(ctx, t, proxyClient), CertAuthorityWatcher: caWatcher, }) require.NoError(t, err) @@ -3099,6 +3103,19 @@ func newNodeWatcher(ctx context.Context, t *testing.T, client *authclient.Client return nodeWatcher } +func newGitServerWatcher(ctx context.Context, t *testing.T, client *authclient.Client) *services.GenericWatcher[types.Server, readonly.Server] { + watcher, err := services.NewGitServerWatcher(ctx, services.GitServerWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: "test", + Client: client, + }, + GitServerGetter: client.GitServerReadOnlyClient(), + }) + require.NoError(t, err) + t.Cleanup(watcher.Close) + return watcher +} + func newCertAuthorityWatcher(ctx context.Context, t *testing.T, client types.Events) *services.CertAuthorityWatcher { caWatcher, err := services.NewCertAuthorityWatcher(ctx, services.CertAuthorityWatcherConfig{ ResourceWatcherConfig: services.ResourceWatcherConfig{ @@ -3180,6 +3197,7 @@ func TestHostUserCreationProxy(t *testing.T) { Emitter: proxyClient, LockWatcher: lockWatcher, NodeWatcher: nodeWatcher, + GitServerWatcher: newGitServerWatcher(ctx, t, proxyClient), CertAuthorityWatcher: caWatcher, CircuitBreakerConfig: breaker.NoopBreakerConfig(), }) diff --git a/lib/sshutils/exec.go b/lib/sshutils/exec.go new file mode 100644 index 0000000000000..e8476121e86c7 --- /dev/null +++ b/lib/sshutils/exec.go @@ -0,0 +1,70 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package sshutils + +import ( + "context" + "errors" + "log/slog" + "syscall" + + "github.com/gravitational/teleport" +) + +// errorWithExitStatus defines an interface that provides an ExitStatus +// function to get the exit code of the process execution. +// +// This interface is introduced so ssh.ExitError can be mocked in unit test. +type errorWithExitStatus interface { + ExitStatus() int +} + +// execExitError defines an interface that provides a Sys function to get exit +// status from the process execution. +// +// This interface is introduced so exec.ExitError can be mocked in unit test. +type execExitError interface { + Sys() any +} + +// ExitCodeFromExecError extracts and returns the exit code from the +// error. +func ExitCodeFromExecError(err error) int { + // If no error occurred, return 0 (success). + if err == nil { + return teleport.RemoteCommandSuccess + } + + var execExitErr execExitError + var exitErr errorWithExitStatus + switch { + case errors.As(err, &execExitErr): + waitStatus, ok := execExitErr.Sys().(syscall.WaitStatus) + if !ok { + return teleport.RemoteCommandFailure + } + return waitStatus.ExitStatus() + case errors.As(err, &exitErr): + return exitErr.ExitStatus() + // An error occurred, but the type is unknown, return a generic 255 code. + default: + slog.DebugContext(context.Background(), "Unknown error returned when executing command", "error", err) + return teleport.RemoteCommandFailure + } +} diff --git a/lib/sshutils/exec_test.go b/lib/sshutils/exec_test.go new file mode 100644 index 0000000000000..fd24eac6a42c5 --- /dev/null +++ b/lib/sshutils/exec_test.go @@ -0,0 +1,98 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package sshutils + +import ( + "errors" + "os/exec" + "syscall" + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" + + "github.com/gravitational/teleport" +) + +type mockErrorWithExitStatus struct { +} + +func (e mockErrorWithExitStatus) ExitStatus() int { + return 2 +} +func (e mockErrorWithExitStatus) Error() string { + return "mockErrorWithExitStatus" +} + +type mockExecExitError struct { + sys any +} + +func (e mockExecExitError) Sys() any { + return e.sys +} +func (e mockExecExitError) Error() string { + return "mockExecExitError" +} + +func TestExitCodeFromExecError(t *testing.T) { + // These struct types cannot be mocked. Implementation uses interfaces + // instead of these types. Double check if these types satisfy the + // interfaces. + require.ErrorAs(t, &ssh.ExitError{}, new(errorWithExitStatus)) + require.ErrorAs(t, &exec.ExitError{}, new(execExitError)) + + tests := []struct { + name string + input error + want int + }{ + { + name: "success", + input: nil, + want: teleport.RemoteCommandSuccess, + }, + { + name: "exec exit error", + input: mockExecExitError{sys: syscall.WaitStatus(1 << 8)}, + want: 1, + }, + { + name: "exec exit error with unknown sys", + input: mockExecExitError{sys: "unknown"}, + want: teleport.RemoteCommandFailure, + }, + { + name: "ssh exit error", + input: mockErrorWithExitStatus{}, + want: 2, + }, + { + name: "unknown error", + input: errors.New("unknown error"), + want: teleport.RemoteCommandFailure, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, ExitCodeFromExecError(tt.input)) + }) + } +} diff --git a/lib/sshutils/mock.go b/lib/sshutils/mock_test.go similarity index 57% rename from lib/sshutils/mock.go rename to lib/sshutils/mock_test.go index 37cdc2b819f1b..43cfe4c543eb9 100644 --- a/lib/sshutils/mock.go +++ b/lib/sshutils/mock_test.go @@ -22,6 +22,7 @@ import ( "errors" "io" + "github.com/stretchr/testify/mock" "golang.org/x/crypto/ssh" ) @@ -67,3 +68,55 @@ type mockSSHConn struct { func (mc *mockSSHConn) OpenChannel(name string, data []byte) (ssh.Channel, <-chan *ssh.Request, error) { return mc.mockChan, make(<-chan *ssh.Request), nil } + +type mockSSHNewChannel struct { + mock.Mock + ssh.NewChannel +} + +func newMockSSHNewChannel(channelType string) *mockSSHNewChannel { + m := new(mockSSHNewChannel) + m.On("ChannelType").Return(channelType) + m.On("Reject", mock.Anything, mock.Anything).Return(nil) + return m +} + +func (m *mockSSHNewChannel) ChannelType() string { + return m.Called().String(0) +} + +func (m *mockSSHNewChannel) Reject(reason ssh.RejectionReason, message string) error { + args := m.Called(reason, message) + return args.Error(0) +} + +type mockSSHChannel struct { + mock.Mock + ssh.Channel +} + +func newMockSSHChannel() *mockSSHChannel { + m := new(mockSSHChannel) + m.On("SendRequest", mock.Anything, mock.Anything, mock.Anything).Return(false, nil) + return m +} + +func (m *mockSSHChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { + args := m.Called(name, wantReply, payload) + return args.Bool(0), args.Error(1) +} + +type mockSSHRequest struct { + mock.Mock +} + +func newMockSSHRequest() *mockSSHRequest { + m := new(mockSSHRequest) + m.On("Reply", mock.Anything, mock.Anything).Return(nil) + return m +} + +func (m *mockSSHRequest) Reply(ok bool, payload []byte) error { + args := m.Called(ok, payload) + return args.Error(0) +} diff --git a/lib/sshutils/reply.go b/lib/sshutils/reply.go new file mode 100644 index 0000000000000..a96d66efd79f8 --- /dev/null +++ b/lib/sshutils/reply.go @@ -0,0 +1,106 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package sshutils + +import ( + "context" + "errors" + "fmt" + "log/slog" + + "golang.org/x/crypto/ssh" +) + +// SSHRequest defines an interface for ssh.Request. +type SSHRequest interface { + // Reply sends a response to a request. + Reply(ok bool, payload []byte) error +} + +func sshRequestType(r SSHRequest) string { + if sshReq, ok := r.(*ssh.Request); ok { + return sshReq.Type + } + return "unknown" +} + +// Reply is a helper to handle replying/rejecting and log messages when needed. +type Reply struct { + log *slog.Logger +} + +// NewReply creates a new reply helper for SSH servers. +func NewReply(log *slog.Logger) *Reply { + return &Reply{log: log} +} + +// RejectChannel rejects the channel with provided message. +func (r *Reply) RejectChannel(ctx context.Context, nch ssh.NewChannel, reason ssh.RejectionReason, msg string) { + if err := nch.Reject(reason, msg); err != nil { + r.log.WarnContext(ctx, "Failed to reject channel", "error", err) + } +} + +// RejectUnknownChannel rejects the channel with reason ssh.UnknownChannelType. +func (r *Reply) RejectUnknownChannel(ctx context.Context, nch ssh.NewChannel) { + channelType := nch.ChannelType() + r.log.WarnContext(ctx, "Unknown channel type", "channel", channelType) + r.RejectChannel(ctx, nch, ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %v", channelType)) +} + +// RejectWithAcceptError rejects the channel when ssh.NewChannel.Accept fails. +func (r *Reply) RejectWithAcceptError(ctx context.Context, nch ssh.NewChannel, err error) { + r.log.WarnContext(ctx, "Unable to accept channel", "channel", nch.ChannelType(), "error", err) + r.RejectChannel(ctx, nch, ssh.ConnectionFailed, fmt.Sprintf("unable to accept channel: %v", err)) +} + +// RejectWithNewRemoteSessionError rejects the channel when the corresponding +// remote session fails to create. +func (r *Reply) RejectWithNewRemoteSessionError(ctx context.Context, nch ssh.NewChannel, remoteError error) { + r.log.WarnContext(ctx, "Remote session open failed", "error", remoteError) + reason, msg := ssh.ConnectionFailed, fmt.Sprintf("remote session open failed: %v", remoteError) + var e *ssh.OpenChannelError + if errors.As(remoteError, &e) { + reason, msg = e.Reason, e.Message + } + r.RejectChannel(ctx, nch, reason, msg) +} + +// ReplyError replies an error to an ssh.Request. +func (r *Reply) ReplyError(ctx context.Context, req SSHRequest, err error) { + r.log.WarnContext(ctx, "failure handling SSH request", "request_type", sshRequestType(req), "error", err) + if err := req.Reply(false, []byte(err.Error())); err != nil { + r.log.WarnContext(ctx, "failed sending error Reply on SSH channel", "error", err) + } +} + +// ReplyRequest replies to an ssh.Request with provided ok and payload. +func (r *Reply) ReplyRequest(ctx context.Context, req SSHRequest, ok bool, payload []byte) { + if err := req.Reply(ok, payload); err != nil { + r.log.WarnContext(ctx, "failed replying OK to SSH request", "request_type", sshRequestType(req), "error", err) + } +} + +// SendExitStatus sends an exit-status. +func (r *Reply) SendExitStatus(ctx context.Context, ch ssh.Channel, code int) { + _, err := ch.SendRequest("exit-status", false, ssh.Marshal(struct{ C uint32 }{C: uint32(code)})) + if err != nil { + r.log.InfoContext(ctx, "Failed to send exit status", "error", err) + } +} diff --git a/lib/sshutils/reply_test.go b/lib/sshutils/reply_test.go new file mode 100644 index 0000000000000..68427142dbffc --- /dev/null +++ b/lib/sshutils/reply_test.go @@ -0,0 +1,93 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package sshutils + +import ( + "context" + "errors" + "log/slog" + "testing" + + "golang.org/x/crypto/ssh" + + "github.com/gravitational/teleport" +) + +func TestReply(t *testing.T) { + r := NewReply(slog.With(teleport.Component, "test")) + + t.Run("RejectChannel", func(t *testing.T) { + m := newMockSSHNewChannel("session") + r.RejectChannel(context.Background(), m, ssh.ResourceShortage, "test error") + m.AssertCalled(t, "Reject", ssh.ResourceShortage, "test error") + }) + + t.Run("RejectUnknownChannel", func(t *testing.T) { + m := newMockSSHNewChannel("unknown_channel") + r.RejectUnknownChannel(context.Background(), m) + m.AssertCalled(t, "Reject", ssh.UnknownChannelType, "unknown channel type: unknown_channel") + }) + + t.Run("RejectWithAcceptError", func(t *testing.T) { + m := newMockSSHNewChannel("session") + r.RejectWithAcceptError(context.Background(), m, errors.New("test error")) + m.AssertCalled(t, "Reject", ssh.ConnectionFailed, "unable to accept channel: test error") + }) + + t.Run("RejectWithNewRemoteSessionError", func(t *testing.T) { + t.Run("internal error", func(t *testing.T) { + m := newMockSSHNewChannel("session") + r.RejectWithNewRemoteSessionError(context.Background(), m, errors.New("test error")) + m.AssertCalled(t, "Reject", ssh.ConnectionFailed, "remote session open failed: test error") + }) + t.Run("remote error", func(t *testing.T) { + m := newMockSSHNewChannel("session") + r.RejectWithNewRemoteSessionError(context.Background(), m, &ssh.OpenChannelError{ + Reason: ssh.ResourceShortage, + Message: "test error", + }) + m.AssertCalled(t, "Reject", ssh.ResourceShortage, "test error") + }) + }) + + t.Run("ReplyError", func(t *testing.T) { + m := newMockSSHRequest() + r.ReplyError(context.Background(), m, errors.New("test error")) + m.AssertCalled(t, "Reply", false, []byte("test error")) + }) + + t.Run("ReplyRequest", func(t *testing.T) { + t.Run("ok true", func(t *testing.T) { + m := newMockSSHRequest() + r.ReplyRequest(context.Background(), m, true, []byte("ok true")) + m.AssertCalled(t, "Reply", true, []byte("ok true")) + }) + t.Run("ok false", func(t *testing.T) { + m := newMockSSHRequest() + r.ReplyRequest(context.Background(), m, false, []byte("ok false")) + m.AssertCalled(t, "Reply", false, []byte("ok false")) + }) + }) + + t.Run("SendExitStatus", func(t *testing.T) { + m := newMockSSHChannel() + r.SendExitStatus(context.Background(), m, 1) + m.AssertCalled(t, "SendRequest", "exit-status", false, []byte{0, 0, 0, 1}) + }) +} diff --git a/lib/sshutils/server.go b/lib/sshutils/server.go index 7020c302342c6..2d8f7cb85d4bc 100644 --- a/lib/sshutils/server.go +++ b/lib/sshutils/server.go @@ -715,6 +715,13 @@ type NewConnHandler interface { HandleNewConn(ctx context.Context, ccx *ConnectionContext) (context.Context, error) } +// NewConnHandlerFunc wraps a function to satisfy NewConnHandler interface. +type NewConnHandlerFunc func(ctx context.Context, ccx *ConnectionContext) (context.Context, error) + +func (f NewConnHandlerFunc) HandleNewConn(ctx context.Context, ccx *ConnectionContext) (context.Context, error) { + return f(ctx, ccx) +} + type AuthMethods struct { PublicKey PublicKeyFunc Password PasswordFunc diff --git a/lib/sshutils/utils.go b/lib/sshutils/utils.go index 5f08a40748fbe..7dcd762629e74 100644 --- a/lib/sshutils/utils.go +++ b/lib/sshutils/utils.go @@ -23,6 +23,7 @@ import ( "strconv" "github.com/gravitational/trace" + "golang.org/x/crypto/ssh" "github.com/gravitational/teleport/lib/utils" ) @@ -42,3 +43,22 @@ func SplitHostPort(addrString string) (string, uint32, error) { } return addr.Host(), uint32(addr.Port(0)), nil } + +// SSHConnMetadataWithUser overrides an ssh.ConnMetadata with provided user. +type SSHConnMetadataWithUser struct { + ssh.ConnMetadata + user string +} + +// NewSSHConnMetadataWithUser overrides an ssh.ConnMetadata with provided user. +func NewSSHConnMetadataWithUser(conn ssh.ConnMetadata, user string) SSHConnMetadataWithUser { + return SSHConnMetadataWithUser{ + ConnMetadata: conn, + user: user, + } +} + +// User returns the user ID for this connection. +func (s SSHConnMetadataWithUser) User() string { + return s.user +} diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 71b51568c5610..c6b204ce129fe 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -400,6 +400,16 @@ func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite { require.NoError(t, err) defer caWatcher.Close() + proxyGitServerWatcher, err := services.NewGitServerWatcher(ctx, services.GitServerWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Client: s.proxyClient, + }, + GitServerGetter: s.proxyClient.GitServerReadOnlyClient(), + }) + require.NoError(t, err) + t.Cleanup(proxyGitServerWatcher.Close) + revTunServer, err := reversetunnel.NewServer(reversetunnel.Config{ ID: node.ID(), Listener: revTunListener, @@ -415,6 +425,7 @@ func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite { DataDir: t.TempDir(), LockWatcher: proxyLockWatcher, NodeWatcher: proxyNodeWatcher, + GitServerWatcher: proxyGitServerWatcher, CertAuthorityWatcher: caWatcher, CircuitBreakerConfig: breaker.NoopBreakerConfig(), LocalAuthAddresses: []string{s.server.TLS.Listener.Addr().String()}, @@ -8270,6 +8281,16 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula require.NoError(t, err) t.Cleanup(proxyNodeWatcher.Close) + proxyGitServerWatcher, err := services.NewGitServerWatcher(ctx, services.GitServerWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Client: client, + }, + GitServerGetter: client.GitServerReadOnlyClient(), + }) + require.NoError(t, err) + t.Cleanup(proxyGitServerWatcher.Close) + revTunServer, err := reversetunnel.NewServer(reversetunnel.Config{ ID: node.ID(), Listener: revTunListener, @@ -8285,6 +8306,7 @@ func createProxy(ctx context.Context, t *testing.T, proxyID string, node *regula DataDir: t.TempDir(), LockWatcher: proxyLockWatcher, NodeWatcher: proxyNodeWatcher, + GitServerWatcher: proxyGitServerWatcher, CertAuthorityWatcher: proxyCAWatcher, CircuitBreakerConfig: breaker.NoopBreakerConfig(), LocalAuthAddresses: []string{authServer.Listener.Addr().String()}, diff --git a/tool/tsh/common/git_list_test.go b/tool/tsh/common/git_list_test.go index d2f038db13775..cf4d52e318043 100644 --- a/tool/tsh/common/git_list_test.go +++ b/tool/tsh/common/git_list_test.go @@ -98,8 +98,8 @@ func TestGitListCommand(t *testing.T) { }, containsOutput: []string{ `"kind": "git_server"`, - `"hostname": "org1.github-org"`, - `"hostname": "org2.github-org"`, + `"hostname": "org1.teleport-github-org"`, + `"hostname": "org2.teleport-github-org"`, }, }, { @@ -110,8 +110,8 @@ func TestGitListCommand(t *testing.T) { }, containsOutput: []string{ "- kind: git_server", - "hostname: org1.github-org", - "hostname: org2.github-org", + "hostname: org1.teleport-github-org", + "hostname: org2.teleport-github-org", }, }, }