From 987995de4e7e6a6991818b8ff9a943d8386ec1e9 Mon Sep 17 00:00:00 2001 From: Jakub Nyckowski Date: Thu, 25 Apr 2024 14:08:07 -0400 Subject: [PATCH] Remove Assist command execution and chat feature (#40806) * Remove Assist command execution and chat feature The changes remove tests and related code for command execution in the assist package and web command handling. * Remove Teleport Assist chat from CHANGELOG.md The changes in this commit reflect the removal of the Teleport Assist chat from the CHANGELOG documentation of Teleport version 16. The Assist feature continues to be available in the SSH Web Terminal and Audit Monitoring. * Remove assist chat from the test plan --- .github/ISSUE_TEMPLATE/testplan.md | 11 - CHANGELOG.md | 5 + integration/assist/command_test.go | 468 ------------------ lib/web/apiserver.go | 76 --- lib/web/assistant.go | 375 +------------- lib/web/assistant_test.go | 292 ----------- lib/web/command.go | 768 ----------------------------- lib/web/command_test.go | 422 ---------------- lib/web/command_utils.go | 207 -------- lib/web/command_utils_test.go | 145 ------ lib/web/terminal.go | 47 +- 11 files changed, 56 insertions(+), 2760 deletions(-) delete mode 100644 integration/assist/command_test.go delete mode 100644 lib/web/command.go delete mode 100644 lib/web/command_test.go delete mode 100644 lib/web/command_utils.go delete mode 100644 lib/web/command_utils_test.go diff --git a/.github/ISSUE_TEMPLATE/testplan.md b/.github/ISSUE_TEMPLATE/testplan.md index 5206101eee8a5..58180b57fde30 100644 --- a/.github/ISSUE_TEMPLATE/testplan.md +++ b/.github/ISSUE_TEMPLATE/testplan.md @@ -1504,17 +1504,6 @@ Assist test plan is in the core section instead of WebUI as most functionality i - [ ] Assist is enabled by default in the Cloud Team plan. - [ ] Assist is always disabled when etcd is used as a backend. -- Conversations - - [ ] A new conversation can be started. - - [ ] SSH command can be executed on one server. - - [ ] SSH command can be executed on multiple servers. - - [ ] SSH command can be executed on a node with per session MFA enabled. - - [ ] Execution output is explained when it fits the context window. - - [ ] Assist can list all nodes/execute a command on all nodes (using embeddings). - - [ ] Access request can be created. - - [ ] Access request is created when approved. - - [ ] Conversation title is set after the first message. - - SSH integration - [ ] Assist icon is visible in WebUI's Terminal - [ ] A Bash command can be generated in the above window. diff --git a/CHANGELOG.md b/CHANGELOG.md index ce97cf90a7147..66856d39cd467 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,11 @@ Opsgenie plugin users, role annotations must now contain See [the Opsgenie plugin documentation](docs/pages/access-controls/access-request-plugins/opsgenie.mdx) for setup instructions. +#### Teleport Assist chat has been remove. + +Teleport Assist chat has been removed from Teleport 16. Assist is still available +in the SSH Web Terminal and Audit Monitoring. + ## 15.0.0 (xx/xx/24) ### New features diff --git a/integration/assist/command_test.go b/integration/assist/command_test.go deleted file mode 100644 index 1ddddb668b137..0000000000000 --- a/integration/assist/command_test.go +++ /dev/null @@ -1,468 +0,0 @@ -/* - * Teleport - * Copyright (C) 2023 Gravitational, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package assist - -import ( - "context" - "crypto/tls" - "crypto/x509" - "encoding/base64" - "encoding/json" - "encoding/pem" - "fmt" - "io/fs" - "net" - "net/http/httptest" - "net/url" - "os" - "path/filepath" - "testing" - "time" - - "github.com/gogo/protobuf/proto" - "github.com/google/uuid" - "github.com/gorilla/websocket" - "github.com/gravitational/trace" - "github.com/sashabaranov/go-openai" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.org/x/crypto/ssh" - "google.golang.org/protobuf/types/known/timestamppb" - - "github.com/gravitational/teleport" - "github.com/gravitational/teleport/api/breaker" - "github.com/gravitational/teleport/api/client" - "github.com/gravitational/teleport/api/constants" - apidefaults "github.com/gravitational/teleport/api/defaults" - "github.com/gravitational/teleport/api/gen/proto/go/assist/v1" - "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/api/utils/keys" - apisshutils "github.com/gravitational/teleport/api/utils/sshutils" - "github.com/gravitational/teleport/integration/helpers" - "github.com/gravitational/teleport/lib/ai/testutils" - libauth "github.com/gravitational/teleport/lib/auth" - "github.com/gravitational/teleport/lib/auth/native" - "github.com/gravitational/teleport/lib/config" - "github.com/gravitational/teleport/lib/defaults" - "github.com/gravitational/teleport/lib/service" - "github.com/gravitational/teleport/lib/service/servicecfg" - "github.com/gravitational/teleport/lib/session" - "github.com/gravitational/teleport/lib/sshutils" - "github.com/gravitational/teleport/lib/utils" - "github.com/gravitational/teleport/lib/web" -) - -const ( - testUser = "testUser" - testCommandOutput = "teleport1234" - testToken = "token" - testClusterName = "teleport.example.com" -) - -// TestAssistCommandOpenSSH tests that command output is properly recorded when -// executing commands through assist on OpenSSH nodes. -func TestAssistCommandOpenSSH(t *testing.T) { - // Setup section: starting Teleport, creating the user and starting a mock SSH server - testDir := t.TempDir() - ctx, cancel := context.WithCancel(context.Background()) - t.Cleanup(cancel) - - openAIMock := mockOpenAI(t) - - rc := setupTeleport(t, testDir, openAIMock.URL) - auth := rc.Process.GetAuthServer() - proxyAddr, err := rc.Process.ProxyWebAddr() - require.NoError(t, err) - - userClient, userPassword := setupTestUser(t, ctx, rc) - - node := registerAndSetupMockSSHNode(t, ctx, testDir, rc) - - // Test section: We're checking that when a user executes a command through - // Assist on an agentless node, a session recording gets created and - // contains the command output. - - // Create a new conversation - conversation, err := userClient.CreateAssistantConversation(context.Background(), &assist.CreateAssistantConversationRequest{ - Username: testUser, - CreatedTime: timestamppb.Now(), - }) - require.NoError(t, err) - - // Login and execute the command - webPack := helpers.LoginWebClient(t, proxyAddr.String(), testUser, userPassword) - endpoint, err := url.JoinPath("command", "$site", "execute") - require.NoError(t, err) - - req := web.CommandRequest{ - Query: fmt.Sprintf("name == \"%s\"", node.GetHostname()), - Login: testUser, - ConversationID: conversation.Id, - ExecutionID: uuid.New().String(), - Command: "echo teleport", - } - - ws, resp, err := webPack.OpenWebsocket(t, endpoint, req) - require.NoError(t, err) - require.NoError(t, resp.Body.Close()) - - // Processing the execution websocket messages: - // - the first message is the session metadata (including the session ID we need) - // - the second message is the streamed command output - // - the third message is a session close - - execSocket := executionWebsocketReader{ws} - - // First message: session metadata - envelope, err := execSocket.Read() - require.NoError(t, err) - var sessionMetadata sessionMetadataResponse - require.NoError(t, json.Unmarshal([]byte(envelope.Payload), &sessionMetadata)) - - // Second message: command output - envelope, err = execSocket.Read() - require.NoError(t, err) - perNodeEnvelope := struct { - ID string `json:"node_id"` - MsgType string `json:"type"` - Payload string `json:"payload"` - }{} - require.NoError(t, json.Unmarshal([]byte(envelope.Payload), &perNodeEnvelope)) - // Assert the command executed properly. If the execution failed, we will - // receive a web.envelopeTypeError message instead - require.Equal(t, web.EnvelopeTypeStdout, perNodeEnvelope.MsgType) - output, err := base64.StdEncoding.DecodeString(perNodeEnvelope.Payload) - require.NoError(t, err) - // Assert the streamed command output content is the one expected - require.Equal(t, testCommandOutput, string(output)) - - // Third message: session close - envelope, err = execSocket.Read() - require.NoError(t, err) - require.Equal(t, defaults.WebsocketClose, envelope.Type) - // Now the execution is finished - - // Waiting for the session recording to be uploaded and available - require.Eventually(t, func() bool { - chunk, err := auth.GetSessionChunk(apidefaults.Namespace, sessionMetadata.Session.ID, 0, 4096) - if err != nil { - if trace.IsNotFound(err) { - return false - } - assert.Fail(t, "error should be nil or not found, is %s", err) - } - assert.NotNil(t, chunk) - return true - }, 10*time.Second, 200*time.Millisecond) - - // Validating the session recording contains the SSH server output - chunk, err := auth.GetSessionChunk(apidefaults.Namespace, sessionMetadata.Session.ID, 0, 4096) - require.NoError(t, err) - require.Equal(t, testCommandOutput, string(chunk)) -} - -// mockOpenAI starts an OpenAI mock server that answers one completion request -// successfully (the output is a plain text command summary, it cannot be used -// for an agent thinking step. -// The server returns errors for embeddings requests from the auth, but -// this should not affect the test. -func mockOpenAI(t *testing.T) *httptest.Server { - responses := []string{"This is the summary of the command."} - server := httptest.NewServer(testutils.GetTestHandlerFn(t, responses)) - t.Cleanup(server.Close) - return server -} - -// setupTeleport starts a Teleport instance running the Auth and Proxy service, -// with Assist and the web service enabled. The instance supports Node joining -// with the static token testToken. -func setupTeleport(t *testing.T, testDir, openaiMockURL string) *helpers.TeleInstance { - cfg := helpers.InstanceConfig{ - ClusterName: testClusterName, - HostID: uuid.New().String(), - NodeName: helpers.Loopback, - Log: utils.NewLoggerForTests(), - } - cfg.Listeners = helpers.SingleProxyPortSetup(t, &cfg.Fds) - rc := helpers.NewInstance(t, cfg) - - var err error - rcConf := servicecfg.MakeDefaultConfig() - rcConf.DataDir = testDir - rcConf.Auth.Enabled = true - rcConf.Auth.NetworkingConfig.SetProxyListenerMode(types.ProxyListenerMode_Multiplex) - rcConf.Auth.Preference.SetSecondFactor("off") - rcConf.Proxy.Enabled = true - rcConf.Proxy.DisableWebService = false - rcConf.Proxy.DisableWebInterface = true - rcConf.SSH.Enabled = false - rcConf.Version = "v3" - rcConf.Auth.StaticTokens, err = types.NewStaticTokens(types.StaticTokensSpecV2{ - StaticTokens: []types.ProvisionTokenV1{ - { - Roles: []types.SystemRole{types.RoleNode}, - Token: testToken, - }, - }, - }) - rcConf.Proxy.AssistAPIKey = "test" - rcConf.Auth.AssistAPIKey = "test" - openAIConfig := openai.DefaultConfig("test") - openAIConfig.BaseURL = openaiMockURL + "/v1" - rcConf.Testing.OpenAIConfig = &openAIConfig - require.NoError(t, err) - rcConf.CircuitBreakerConfig = breaker.NoopBreakerConfig() - - err = rc.CreateEx(t, nil, rcConf) - require.NoError(t, err) - err = rc.Start() - require.NoError(t, err) - t.Cleanup(func() { - _ = rc.StopAll() - }) - - return rc -} - -// setupTestUser creates a user with the access, editor and auditor roles. This -// user must be able to execute commands on the test SSH node, and query the -// session recordings. -// The function also sets a password for the user (this is needed to log in and -// call the web endpoints). -// Finally, it also builds and returns a Teleport client logged in as the user. -func setupTestUser(t *testing.T, ctx context.Context, rc *helpers.TeleInstance) (*client.Client, string) { - auth := rc.Process.GetAuthServer() - // Create user - user, err := types.NewUser(testUser) - require.NoError(t, err) - user.SetLogins([]string{testUser}) - user.AddRole(teleport.PresetEditorRoleName) - user.AddRole(teleport.PresetAccessRoleName) - user.AddRole(teleport.PresetAuditorRoleName) - user, err = auth.UpsertUser(ctx, user) - require.NoError(t, err) - - userPassword := uuid.NewString() - require.NoError(t, auth.UpsertPassword(testUser, []byte(userPassword))) - - creds, err := newTestCredentials(t, rc, user) - require.NoError(t, err) - clientConfig := client.Config{ - Addrs: []string{rc.Auth}, - Credentials: []client.Credentials{creds}, - } - userClient, err := client.New(ctx, clientConfig) - require.NoError(t, err) - _, err = userClient.Ping(ctx) - require.NoError(t, err) - - return userClient, userPassword -} - -// newTestCredentials builds Teleport credentials for the testUser. -// Those credentials can only be used for auth connection. -func newTestCredentials(t *testing.T, rc *helpers.TeleInstance, user types.User) (client.Credentials, error) { - auth := rc.Process.GetAuthServer() - - // Get user certs - userKey, err := native.GenerateRSAPrivateKey() - require.NoError(t, err) - userPubKey, err := ssh.NewPublicKey(&userKey.PublicKey) - require.NoError(t, err) - testCertsReq := libauth.GenerateUserTestCertsRequest{ - Key: ssh.MarshalAuthorizedKey(userPubKey), - Username: user.GetName(), - TTL: time.Hour, - Compatibility: constants.CertificateFormatStandard, - RouteToCluster: testClusterName, - } - _, tlsCert, err := auth.GenerateUserTestCerts(testCertsReq) - require.NoError(t, err) - - // Build credentials from the certs - pemKey := pem.EncodeToMemory( - &pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(userKey), - }, - ) - cert, err := keys.X509KeyPair(tlsCert, pemKey) - if err != nil { - return nil, trace.Wrap(err) - } - - pool := x509.NewCertPool() - pool.AppendCertsFromPEM(rc.Secrets.TLSCACert) - - tlsConf := &tls.Config{ - Certificates: []tls.Certificate{cert}, - RootCAs: pool, - } - return client.LoadTLS(tlsConf), nil -} - -// registerAndSetupMockSSHNode registers an agentless SSH node in Teleport and -// starts a mock SSH server. -func registerAndSetupMockSSHNode(t *testing.T, ctx context.Context, testDir string, rc *helpers.TeleInstance) types.Server { - // Reserve the listener for the SSH server. We can't start the SSH server - // right now because we need to get a valid certificate. The certificate - // needs proper principals, which implies knowing the node ID. This only - // happens after the node has joined. - var sshListenerFds []*servicecfg.FileDescriptor - sshAddr := helpers.NewListenerOn(t, "localhost", service.ListenerNodeSSH, &sshListenerFds) - - node := registerMockSSHNode(t, ctx, sshAddr, testDir, rc) - - sshListener, err := sshListenerFds[0].ToListener() - require.NoError(t, err) - - setupMockSSHNode(t, ctx, sshListener, node.GetName(), rc) - - return node -} - -func registerMockSSHNode(t *testing.T, ctx context.Context, sshAddr, testDir string, rc *helpers.TeleInstance) types.Server { - // Setup: running a one-shot Teleport instance to register our mock SSH node - // into the cluster and allow agentless execution. - opensshConfigPath := filepath.Join(testDir, "sshd_config") - require.NoError(t, os.WriteFile(opensshConfigPath, []byte{}, fs.FileMode(0644))) - teleportDataDir := filepath.Join(testDir, "teleport_openssh") - - openSSHCfg := servicecfg.MakeDefaultConfig() - - openSSHCfg.OpenSSH.Enabled = true - err := config.ConfigureOpenSSH(&config.CommandLineFlags{ - DataDir: teleportDataDir, - ProxyServer: rc.Web, - AuthToken: testToken, - JoinMethod: string(types.JoinMethodToken), - OpenSSHConfigPath: opensshConfigPath, - RestartOpenSSH: false, - CheckCommand: "echo okay", - Labels: "hello=true", - Address: sshAddr, - InsecureMode: true, - Debug: true, - }, openSSHCfg) - require.NoError(t, err) - - err = service.Run(ctx, *openSSHCfg, nil) - require.NoError(t, err) - - // Wait for node propagation - require.Eventually(t, helpers.FindNodeWithLabel(t, ctx, rc.Process.GetAuthServer(), "hello", "true"), time.Second*2, time.Millisecond*50) - nodes, err := rc.Process.GetAuthServer().GetNodes(ctx, apidefaults.Namespace) - require.NoError(t, err) - require.Len(t, nodes, 1) - return nodes[0] -} - -func setupMockSSHNode(t *testing.T, ctx context.Context, sshListener net.Listener, nodeName string, rc *helpers.TeleInstance) { - // Setup: creating and starting openssh mock server - ca, err := rc.Process.GetAuthServer().GetCertAuthority(ctx, types.CertAuthID{ - Type: types.HostCA, - DomainName: testClusterName, - }, true) - require.NoError(t, err) - - signers, err := sshutils.GetSigners(ca) - require.NoError(t, err) - require.Len(t, signers, 1) - - cert, err := apisshutils.MakeRealHostCertWithPrincipals(signers[0], nodeName) - require.NoError(t, err) - handler := sshutils.NewChanHandlerFunc(handlerSSH) - sshServer, err := sshutils.NewServer( - "test", - utils.NetAddr{AddrNetwork: "tcp", Addr: sshListener.Addr().String()}, - handler, - []ssh.Signer{cert}, - sshutils.AuthMethods{ - NoClient: true, - }, - sshutils.SetInsecureSkipHostValidation(), - sshutils.SetLogger(utils.NewLoggerForTests().WithField("component", "mocksshserver")), - ) - require.NoError(t, err) - require.NoError(t, sshServer.SetListener(sshListener)) - require.NoError(t, sshServer.Start()) -} - -// this is a dummy SSH handler. It only supports "exec" requests. All other -// requests are happily acknowledged and discarded. Receieving an "exec" request -// sends testCommandOutput in the main channel and closes all channels. -// This is not strictly following the SSH RFC as request processing is blocked -// as soon as an exec request is received, but is good enough for our use-case. -func handlerSSH(_ context.Context, ccx *sshutils.ConnectionContext, nch ssh.NewChannel) { - ch, requests, err := nch.Accept() - if err != nil { - return - } - // Sessions have out-of-band requests such as "shell", - // "pty-req", "env" and "exec". Here we don't output anything and start a - // routine consuming requests and waiting for the "exec" one. - go func(in <-chan *ssh.Request) { - for { - select { - case req := <-in: - if req.Type == "exec" { - req.Reply(true, nil) - _, err = ch.Write([]byte(testCommandOutput)) - msg := struct { - Status uint32 - }{ - Status: 0, - } - ch.SendRequest("exit-status", false, ssh.Marshal(&msg)) - ch.Close() - ccx.Close() - return - } else { - req.Reply(true, nil) - } - // If it's been 10 seconds we have not received any message, we exit - case <-time.After(10 * time.Second): - ch.Close() - ccx.Close() - return - } - - } - }(requests) -} - -// Small helper that wraps a websocket and unmarshalls messages as Teleport -// websocket ones. -type executionWebsocketReader struct { - *websocket.Conn -} - -func (r executionWebsocketReader) Read() (web.Envelope, error) { - _, data, err := r.ReadMessage() - if err != nil { - return web.Envelope{}, trace.Wrap(err) - } - var envelope web.Envelope - return envelope, trace.Wrap(proto.Unmarshal(data, &envelope)) -} - -// This is used for unmarshalling -type sessionMetadataResponse struct { - Session session.Session `json:"session"` -} diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 503667b0731d8..943164e9c16df 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -951,36 +951,9 @@ func (h *Handler) bindDefaultEndpoints() { h.GET("/webapi/sites/:site/user-groups", h.WithClusterAuth(h.getUserGroups)) - // WebSocket endpoint for the chat conversation - // Deprecated: The connect/ws variant should be used instead. - // TODO(lxea): DELETE in v16 - h.GET("/webapi/sites/:site/assistant", h.WithClusterAuthWebSocket(false, h.assistant)) // WebSocket endpoint for the chat conversation, websocket auth h.GET("/webapi/sites/:site/assistant/ws", h.WithClusterAuthWebSocket(true, h.assistant)) - // Sets the title for the conversation. - h.POST("/webapi/assistant/conversations/:conversation_id/title", h.WithAuth(h.setAssistantTitle)) - h.POST("/webapi/assistant/title/summary", h.WithAuth(h.generateAssistantTitle)) - - // Creates a new conversation - the conversation ID is returned in the response. - h.POST("/webapi/assistant/conversations", h.WithAuth(h.createAssistantConversation)) - - // Deletes the given conversation. - h.DELETE("/webapi/assistant/conversations/:conversation_id", h.WithAuth(h.deleteAssistantConversation)) - - // Returns all conversations for the given user. - h.GET("/webapi/assistant/conversations", h.WithAuth(h.getAssistantConversations)) - - // Returns all messages in the given conversation. - h.GET("/webapi/assistant/conversations/:conversation_id", h.WithAuth(h.getAssistantConversationByID)) - - // Allows executing an arbitrary command on multiple nodes. - // Deprecated: The execute/ws variant should be used instead. - // TODO(lxea): DELETE in v16 - h.GET("/webapi/command/:site/execute", h.WithClusterAuthWebSocket(false, h.executeCommand)) - // Allows executing an arbitrary command on multiple nodes, websocket auth. - h.GET("/webapi/command/:site/execute/ws", h.WithClusterAuthWebSocket(true, h.executeCommand)) - // Fetches the user's preferences h.GET("/webapi/user/preferences", h.WithAuth(h.getUserPreferences)) @@ -3250,55 +3223,6 @@ func (h *Handler) generateSession(req *TerminalRequest, clusterName string, scx }, nil } -func (h *Handler) generateCommandSession(host *hostInfo, login, clusterName, owner string) (session.Session, error) { - h.log.Infof("Generating new session for %s in %s\n", host.hostName, clusterName) - - return session.Session{ - Login: login, - ServerID: host.id, - ClusterName: clusterName, - ServerHostname: host.hostName, - ServerHostPort: host.port, - ID: session.NewID(), - Created: time.Now().UTC(), - LastActive: time.Now().UTC(), - Namespace: apidefaults.Namespace, - Owner: owner, - }, nil -} - -// hostInfo is a helper struct used to store host information. -type hostInfo struct { - id string - hostName string - port int -} - -// findByQuery returns all hosts matching the given query/predicate. -// The query is a predicate expression that can be used to filter hosts. -func findByQuery(ctx context.Context, clt auth.ClientI, query string) ([]hostInfo, error) { - servers, err := apiclient.GetAllResources[types.Server](ctx, clt, &proto.ListResourcesRequest{ - ResourceType: types.KindNode, - Namespace: apidefaults.Namespace, - PredicateExpression: query, - }) - if err != nil { - return nil, trace.Wrap(err) - } - - hosts := make([]hostInfo, 0, len(servers)) - for _, server := range servers { - h := hostInfo{ - hostName: server.GetHostname(), - id: server.GetName(), - port: defaultPort, - } - hosts = append(hosts, h) - } - - return hosts, nil -} - // fetchExistingSession fetches an active or pending SSH session by the SessionID passed in the TerminalRequest. func (h *Handler) fetchExistingSession(ctx context.Context, clt auth.ClientI, req *TerminalRequest, siteName string) (session.Session, types.SessionTracker, error) { sessionID, err := session.ParseID(req.SessionID.String()) diff --git a/lib/web/assistant.go b/lib/web/assistant.go index c366443547410..db01d8dd571c3 100644 --- a/lib/web/assistant.go +++ b/lib/web/assistant.go @@ -20,7 +20,6 @@ package web import ( "context" - "encoding/json" "errors" "io" "net/http" @@ -29,7 +28,6 @@ import ( "github.com/gorilla/websocket" "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" - "google.golang.org/protobuf/types/known/timestamppb" "github.com/gravitational/teleport/api/client/proto" assistpb "github.com/gravitational/teleport/api/gen/proto/go/assist/v1" @@ -38,8 +36,6 @@ import ( "github.com/gravitational/teleport/lib/ai/tokens" "github.com/gravitational/teleport/lib/assist" "github.com/gravitational/teleport/lib/auth" - "github.com/gravitational/teleport/lib/httplib" - "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/reversetunnelclient" ) @@ -50,70 +46,11 @@ const ( actionSSHExplainCommand = "ssh-explain" // actionGenerateAuditQuery is the name of the action for generating audit queries. actionGenerateAuditQuery = "audit-query" - // We can not know how many tokens we will consume in advance. - // Try to consume a small amount of tokens first. + // We cannot know how many tokens we will consume in advance. + // Try to consume a small number of tokens first. lookaheadTokens = 100 ) -// createAssistantConversationResponse is a response for POST /webapi/assistant/conversations. -type createdAssistantConversationResponse struct { - // ID is a conversation ID. - ID string `json:"id"` -} - -// createAssistantConversation is a handler for POST /webapi/assistant/conversations. -func (h *Handler) createAssistantConversation(_ http.ResponseWriter, r *http.Request, - _ httprouter.Params, sctx *SessionContext, -) (any, error) { - authClient, err := sctx.GetClient() - if err != nil { - return nil, trace.Wrap(err) - } - - if err := checkAssistEnabled(authClient, r.Context()); err != nil { - return nil, trace.Wrap(err) - } - - req := &assistpb.CreateAssistantConversationRequest{ - CreatedTime: timestamppb.New(h.clock.Now().UTC()), - Username: sctx.GetUser(), - } - - resp, err := authClient.CreateAssistantConversation(r.Context(), req) - if err != nil { - return nil, err - } - - return &createdAssistantConversationResponse{ - ID: resp.Id, - }, nil -} - -// deleteAssistantConversation is a handler for DELETE /webapi/assistant/conversations/:conversation_id. -func (h *Handler) deleteAssistantConversation(_ http.ResponseWriter, r *http.Request, - p httprouter.Params, sctx *SessionContext, -) (any, error) { - authClient, err := sctx.GetClient() - if err != nil { - return nil, trace.Wrap(err) - } - - if err := checkAssistEnabled(authClient, r.Context()); err != nil { - return nil, trace.Wrap(err) - } - - conversationID := p.ByName("conversation_id") - - if err := authClient.DeleteAssistantConversation(r.Context(), &assistpb.DeleteAssistantConversationRequest{ - ConversationId: conversationID, - Username: sctx.GetUser(), - }); err != nil { - return nil, trace.Wrap(err) - } - - return OK(), nil -} - // assistantMessage is an assistant message that is sent to the client. type assistantMessage struct { // Type is a type of the message. @@ -124,214 +61,9 @@ type assistantMessage struct { Payload string `json:"payload"` } -// getAssistantConversation is a handler for GET /webapi/assistant/conversations/:conversation_id. -func (h *Handler) getAssistantConversationByID(_ http.ResponseWriter, r *http.Request, - p httprouter.Params, sctx *SessionContext, -) (any, error) { - authClient, err := sctx.GetClient() - if err != nil { - return nil, trace.Wrap(err) - } - - if err := checkAssistEnabled(authClient, r.Context()); err != nil { - return nil, trace.Wrap(err) - } - - conversationID := p.ByName("conversation_id") - - resp, err := authClient.GetAssistantMessages(r.Context(), &assistpb.GetAssistantMessagesRequest{ - ConversationId: conversationID, - Username: sctx.GetUser(), - }) - if err != nil { - return nil, trace.Wrap(err) - } - - return conversationResponse(resp), nil -} - -// conversationResponse creates a response for GET conversation response. -func conversationResponse(resp *assistpb.GetAssistantMessagesResponse) any { - type response struct { - Messages []assistantMessage `json:"messages"` - } - - jsonResp := &response{ - Messages: make([]assistantMessage, 0, len(resp.Messages)), - } - - for _, message := range resp.Messages { - jsonResp.Messages = append(jsonResp.Messages, assistantMessage{ - Type: assist.MessageType(message.Type), - CreatedTime: message.CreatedTime.AsTime().Format(time.RFC3339), - Payload: message.Payload, - }) - } - - return jsonResp -} - -// conversationInfo is a response for GET conversation response. -type conversationInfo struct { - // ID is a conversation ID. - ID string `json:"id"` - // Title is a conversation title. - Title string `json:"title,omitempty"` - // CreatedTime is a time when the conversation was created in RFC3339 format. - CreatedTime string `json:"created_time"` -} - -// conversationsResponse is a response for GET conversation response. -type conversationsResponse struct { - Conversations []conversationInfo `json:"conversations"` -} - -// getAssistantConversations is a handler for GET /webapi/assistant/conversations. -func (h *Handler) getAssistantConversations(_ http.ResponseWriter, r *http.Request, - _ httprouter.Params, sctx *SessionContext, -) (any, error) { - authClient, err := sctx.GetClient() - if err != nil { - return nil, trace.Wrap(err) - } - - if err := checkAssistEnabled(authClient, r.Context()); err != nil { - return nil, trace.Wrap(err) - } - - resp, err := authClient.GetAssistantConversations(r.Context(), &assistpb.GetAssistantConversationsRequest{ - Username: sctx.GetUser(), - }) - if err != nil { - return nil, trace.Wrap(err) - } - - return genConversationsResponse(resp), nil -} - -func genConversationsResponse(resp *assistpb.GetAssistantConversationsResponse) *conversationsResponse { - jsonResp := &conversationsResponse{ - Conversations: make([]conversationInfo, 0, len(resp.Conversations)), - } - - for _, conversation := range resp.Conversations { - jsonResp.Conversations = append(jsonResp.Conversations, conversationInfo{ - ID: conversation.Id, - Title: conversation.Title, - CreatedTime: conversation.CreatedTime.AsTime().Format(time.RFC3339), - }) - } - - return jsonResp -} - -// setAssistantTitle is a handler for POST /webapi/assistant/conversations/:conversation_id/title. -func (h *Handler) setAssistantTitle(_ http.ResponseWriter, r *http.Request, - p httprouter.Params, sctx *SessionContext, -) (any, error) { - req := struct { - Title string `json:"title"` - }{} - - if err := httplib.ReadJSON(r, &req); err != nil { - return nil, trace.Wrap(err) - } - - authClient, err := sctx.GetClient() - if err != nil { - return nil, trace.Wrap(err) - } - - if err := checkAssistEnabled(authClient, r.Context()); err != nil { - return nil, trace.Wrap(err) - } - - conversationID := p.ByName("conversation_id") - - conversationInfo := &assistpb.UpdateAssistantConversationInfoRequest{ - ConversationId: conversationID, - Username: sctx.GetUser(), - Title: req.Title, - } - - if err := authClient.UpdateAssistantConversationInfo(r.Context(), conversationInfo); err != nil { - return nil, trace.Wrap(err) - } - - return OK(), nil -} - -// generateAssistantTitleRequest is a request for POST /webapi/assistant/title/summary. -type generateAssistantTitleRequest struct { - Message string `json:"message"` -} - -// generateAssistantTitle is a handler for POST /webapi/assistant/title/summary. -func (h *Handler) generateAssistantTitle(_ http.ResponseWriter, r *http.Request, - _ httprouter.Params, sctx *SessionContext, -) (any, error) { - var req generateAssistantTitleRequest - if err := httplib.ReadJSON(r, &req); err != nil { - return nil, trace.Wrap(err) - } - - authClient, err := sctx.GetClient() - if err != nil { - return nil, trace.Wrap(err) - } - - if err := checkAssistEnabled(authClient, r.Context()); err != nil { - return nil, trace.Wrap(err) - } - - client, err := assist.NewClient(r.Context(), h.cfg.ProxyClient, - h.cfg.ProxySettings, h.cfg.OpenAIConfig) - if err != nil { - return nil, trace.Wrap(err) - } - - titleSummary, err := client.GenerateSummary(r.Context(), req.Message) - if err != nil { - return nil, trace.Wrap(err) - } - - conversationInfo := &conversationInfo{ - Title: titleSummary, - } - - // We only want to emmit - if modules.GetModules().Features().Cloud { - go func() { - ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) - defer cancel() - class, err := client.ClassifyMessage(ctx, req.Message, assist.MessageClasses) - if err != nil { - return - } - h.log.Debugf("message classified as '%s'", class) - // TODO(shaka): emit event here to report the message class - usageEventReq := &proto.SubmitUsageEventRequest{ - Event: &usageeventsv1.UsageEventOneOf{ - Event: &usageeventsv1.UsageEventOneOf_AssistNewConversation{ - AssistNewConversation: &usageeventsv1.AssistNewConversationEvent{ - Category: class, - }, - }, - }, - } - if err := authClient.SubmitUsageEvent(ctx, usageEventReq); err != nil { - h.log.WithError(err).Warn("Failed to emit usage event") - } - }() - - } - - return conversationInfo, nil -} - // assistant is a handler for GET /webapi/sites/:site/assistant. // This handler covers the main chat conversation as well as the -// SSH completition (SSH command generation and output explanation). +// SSH competition (SSH command generation and output explanation). func (h *Handler) assistant(w http.ResponseWriter, r *http.Request, _ httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite, ws *websocket.Conn, ) (any, error) { @@ -357,31 +89,6 @@ func (h *Handler) reserveTokens(usedTokens *tokens.TokenCount) (int, int) { return promptTokens, completionTokens } -// reportTokenUsage sends a token usage event for a conversation. -func (h *Handler) reportConversationTokenUsage(authClient auth.ClientI, usedTokens *tokens.TokenCount, conversationID string) { - // Create a new context to not be bounded by the request timeout. - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - promptTokens, completionTokens := h.reserveTokens(usedTokens) - usageEventReq := &proto.SubmitUsageEventRequest{ - Event: &usageeventsv1.UsageEventOneOf{ - Event: &usageeventsv1.UsageEventOneOf_AssistCompletion{ - AssistCompletion: &usageeventsv1.AssistCompletionEvent{ - ConversationId: conversationID, - TotalTokens: int64(promptTokens + completionTokens), - PromptTokens: int64(promptTokens), - CompletionTokens: int64(completionTokens), - }, - }, - }, - } - - if err := authClient.SubmitUsageEvent(ctx, usageEventReq); err != nil { - h.log.WithError(err).Warn("Failed to emit usage event") - } -} - // reportTokenUsage sends a token usage event for an action. func (h *Handler) reportActionTokenUsage(authClient auth.ClientI, usedTokens *tokens.TokenCount, action string) { // Create a new context to not be bounded by the request timeout. @@ -515,7 +222,7 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, case actionGenerateAuditQuery: err = h.assistGenAuditQueryLoop(ctx, assistClient, authClient, ws, sctx.GetUser()) default: - err = h.assistChatLoop(ctx, assistClient, authClient, conversationID, sctx, ws) + err = trace.Errorf("Teleport Assist Chat has been remove in v16") } return trace.Wrap(err) @@ -627,79 +334,7 @@ func (h *Handler) assistGenSSHCommandLoop(ctx context.Context, assistClient *ass return nil } -// assistChatLoop is the main chat loop for the assistant. -// It reads the user's input from provided WS and generates a response. -func (h *Handler) assistChatLoop(ctx context.Context, assistClient *assist.Assist, authClient auth.ClientI, - conversationID string, sctx *SessionContext, ws *websocket.Conn, -) error { - ac, err := sctx.GetUserAccessChecker() - if err != nil { - return trace.Wrap(err) - } - - toolContext := &tools.ToolContext{ - AssistEmbeddingServiceClient: authClient.EmbeddingClient(), - AccessRequestClient: authClient, - AccessChecker: ac, - NodeWatcher: h.nodeWatcher, - ClusterName: sctx.cfg.Parent.clusterName, - User: sctx.GetUser(), - } - - chat, err := assistClient.NewChat(ctx, authClient, toolContext, conversationID) - if err != nil { - return trace.Wrap(err) - } - - onMessage := func(kind assist.MessageType, payload []byte, createdTime time.Time) error { - return trace.Wrap(onMessageFn(ws, kind, payload, createdTime)) - } - - if chat.IsNewConversation() { - // new conversation, generate a hello message - if _, err := chat.ProcessComplete(ctx, onMessage, ""); err != nil { - return trace.Wrap(err) - } - } - - for { - _, payload, err := ws.ReadMessage() - if err != nil { - if wsIsClosed(err) { - break - } - return trace.Wrap(err) - } - - var wsIncoming assistantMessage - if err := json.Unmarshal(payload, &wsIncoming); err != nil { - return trace.Wrap(err) - } - - if wsIncoming.Type == assist.MessageKindAccessRequestCreated { - chat.RecordMesssage(ctx, wsIncoming.Type, wsIncoming.Payload) - } - - if err := h.preliminaryRateLimitGuard(onMessage); err != nil { - return trace.Wrap(err) - } - - //TODO(jakule): Should we sanitize the payload? - usedTokens, err := chat.ProcessComplete(ctx, onMessage, wsIncoming.Payload) - if err != nil { - return trace.Wrap(err) - } - - // Token usage reporting is asynchronous as we might still be streaming - // a message, and we don't want to block everything. - go h.reportConversationTokenUsage(authClient, usedTokens, conversationID) - } - - h.log.Debug("end assistant conversation loop") - return nil -} - -// preliminaryRateLimitGuard checks that some small amount of tokens are still available and the ratelimit is not exceeded. +// preliminaryRateLimitGuard checks that some small number of tokens is still available and the ratelimit is not exceeded. // This is done because the changed quantity within the limiter is not known until after a request is processed. func (h *Handler) preliminaryRateLimitGuard(onMessageFn func(kind assist.MessageType, payload []byte, createdTime time.Time) error) error { const errorMsg = "You have reached the rate limit. Please try again later." diff --git a/lib/web/assistant_test.go b/lib/web/assistant_test.go index 15d925dadb723..bcc823e0e95d5 100644 --- a/lib/web/assistant_test.go +++ b/lib/web/assistant_test.go @@ -19,25 +19,19 @@ package web import ( - "context" "crypto/tls" "encoding/json" "fmt" - "io" "net/http" "net/http/httptest" "net/url" - "strings" "testing" "github.com/gorilla/websocket" "github.com/gravitational/trace" "github.com/sashabaranov/go-openai" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/time/rate" - authproto "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" aitest "github.com/gravitational/teleport/lib/ai/testutils" "github.com/gravitational/teleport/lib/assist" @@ -45,231 +39,6 @@ import ( "github.com/gravitational/teleport/lib/services" ) -func Test_runAssistant(t *testing.T) { - t.Parallel() - - readStreamResponse := func(t *testing.T, ws *websocket.Conn) string { - var sb strings.Builder - for { - var msg assistantMessage - _, payload, err := ws.ReadMessage() - require.NoError(t, err) - - err = json.Unmarshal(payload, &msg) - require.NoError(t, err) - - if msg.Type == assist.MessageKindAssistantPartialFinalize { - break - } - - require.Equal(t, assist.MessageKindAssistantPartialMessage, msg.Type) - sb.WriteString(msg.Payload) - } - - return sb.String() - } - - readRateLimitedMessage := func(t *testing.T, ws *websocket.Conn) { - var msg assistantMessage - _, payload, err := ws.ReadMessage() - require.NoError(t, err) - - err = json.Unmarshal(payload, &msg) - require.NoError(t, err) - - require.Equal(t, assist.MessageKindError, msg.Type) - require.Equal(t, "You have reached the rate limit. Please try again later.", msg.Payload) - } - - testCases := []struct { - name string - responses []string - cfg webSuiteConfig - setup func(*testing.T, *WebSuite) - act func(*testing.T, *websocket.Conn) - }{ - { - name: "normal", - responses: []string{ - generateTextResponse(), - }, - act: func(t *testing.T, ws *websocket.Conn) { - err := ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "show free disk space"}`)) - require.NoError(t, err) - - const expectedMsg = "Which node do you want to use?" - require.Contains(t, readStreamResponse(t, ws), expectedMsg) - }, - }, - { - name: "rate limited", - responses: []string{ - generateTextResponse(), - }, - cfg: webSuiteConfig{ - ClusterFeatures: &authproto.Features{ - Cloud: true, - }, - }, - setup: func(t *testing.T, s *WebSuite) { - // Assert that rate limiter is set up when Cloud feature is active, - // before replacing with a lower capacity rate-limiter for test purposes - require.InEpsilon(t, float64(assistantLimiterRate), float64(s.webHandler.handler.assistantLimiter.Limit()), 0.0) - - // 101 token capacity (lookaheadTokens+1) and a slow replenish rate - // to let the first completion request succeed, but not the second one - s.webHandler.handler.assistantLimiter = rate.NewLimiter(rate.Limit(0.001), 101) - - }, - act: func(t *testing.T, ws *websocket.Conn) { - err := ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "show free disk space"}`)) - require.NoError(t, err) - - const expectedMsg = "Which node do you want to use?" - require.Contains(t, readStreamResponse(t, ws), expectedMsg) - - err = ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "all nodes, please"}`)) - require.NoError(t, err) - - readRateLimitedMessage(t, ws) - }, - }, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - responses := tc.responses - server := httptest.NewServer(aitest.GetTestHandlerFn(t, responses)) - t.Cleanup(server.Close) - - openaiCfg := openai.DefaultConfig("test-token") - openaiCfg.BaseURL = server.URL - tc.cfg.OpenAIConfig = &openaiCfg - s := newWebSuiteWithConfig(t, tc.cfg) - - if tc.setup != nil { - tc.setup(t, s) - } - assistRole := allowAssistAccess(t, s) - - ctx := context.Background() - authPack := s.authPack(t, "foo", assistRole.GetName()) - // Create the conversation - conversationID := s.makeAssistConversation(t, ctx, authPack) - - // Make WS client and start the conversation - ws, err := s.makeAssistant(t, authPack, conversationID, "") - require.NoError(t, err) - t.Cleanup(func() { require.NoError(t, ws.Close()) }) - - _, payload, err := ws.ReadMessage() - require.NoError(t, err) - - var msg assistantMessage - err = json.Unmarshal(payload, &msg) - require.NoError(t, err) - - // Expect "hello" message - require.Equal(t, assist.MessageKindAssistantMessage, msg.Type) - require.Contains(t, msg.Payload, "Hey, I'm Teleport") - - tc.act(t, ws) - }) - } -} - -// Test_runAssistError tests that the assistant returns an error message -// when the OpenAI API returns an error. -func Test_runAssistError(t *testing.T) { - t.Parallel() - - readHelloMsg := func(ws *websocket.Conn) { - _, payload, err := ws.ReadMessage() - require.NoError(t, err) - - var msg assistantMessage - err = json.Unmarshal(payload, &msg) - require.NoError(t, err) - - // Expect "hello" message - require.Equal(t, assist.MessageKindAssistantMessage, msg.Type) - require.Contains(t, msg.Payload, "Hey, I'm Teleport") - } - - readErrorMsg := func(ws *websocket.Conn) { - err := ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "show free disk space"}`)) - require.NoError(t, err) - - _, payload, err := ws.ReadMessage() - require.NoError(t, err, "expected error message, payload: %s", payload) - - var msg assistantMessage - err = json.Unmarshal(payload, &msg) - require.NoError(t, err) - - // Expect OpenAI error message - require.Equal(t, assist.MessageKindError, msg.Type) - require.Contains(t, msg.Payload, "An error has occurred. Please try again later.") - } - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - // Simulate rate limit error - w.WriteHeader(429) - - errMsg := openai.ErrorResponse{ - Error: &openai.APIError{ - Code: "rate_limit_reached", - Message: "You are sending requests too quickly.", - Param: nil, - Type: "rate_limit_reached", - HTTPStatusCode: 429, - }, - } - - dataBytes, err := json.Marshal(errMsg) - // Use assert as require doesn't work when called from a goroutine - assert.NoError(t, err, "Marshal error") - - _, err = w.Write(dataBytes) - assert.NoError(t, err, "Write error") - })) - t.Cleanup(server.Close) - - openaiCfg := openai.DefaultConfig("test-token") - openaiCfg.BaseURL = server.URL - s := newWebSuiteWithConfig(t, webSuiteConfig{OpenAIConfig: &openaiCfg}) - - assistRole := allowAssistAccess(t, s) - authPack := s.authPack(t, "foo", assistRole.GetName()) - - ctx := context.Background() - // Create the conversation - conversationID := s.makeAssistConversation(t, ctx, authPack) - - // Make WS client and start the conversation - ws, err := s.makeAssistant(t, authPack, conversationID, "") - require.NoError(t, err) - t.Cleanup(func() { - // The TLS connection might or might not be closed, this is an implementation detail. - // We want to check whether the websocket gets appropriately closed, not the underlying TLS connection. - // The connection will eventually be closed and reclaimed by the server. We can skip checking the error. - _ = ws.Close() - }) - - // verify responses - readHelloMsg(ws) - readErrorMsg(ws) - - // Check for the close message - _, _, err = ws.ReadMessage() - var closeErr *websocket.CloseError - require.ErrorAs(t, err, &closeErr, "Expected close error") - require.Equal(t, websocket.CloseInternalServerErr, closeErr.Code, "Expected abnormal closure") -} - func Test_SSHCommandGeneration(t *testing.T) { t.Parallel() @@ -354,46 +123,6 @@ func Test_SSHCommandExplain(t *testing.T) { assertResponse(ws) } -func Test_generateAssistantTitle(t *testing.T) { - // Test setup - t.Parallel() - ctx := context.Background() - - responses := []string{"This is the message summary.", "troubleshooting"} - server := httptest.NewServer(aitest.GetTestHandlerFn(t, responses)) - t.Cleanup(server.Close) - - openaiCfg := openai.DefaultConfig("test-token") - openaiCfg.BaseURL = server.URL - s := newWebSuiteWithConfig(t, webSuiteConfig{ - ClusterFeatures: &authproto.Features{ - Cloud: true, - }, - OpenAIConfig: &openaiCfg, - }) - - assistRole := allowAssistAccess(t, s) - assistRole, err := s.server.Auth().UpsertRole(s.ctx, assistRole) - require.NoError(t, err) - - pack := s.authPack(t, "foo", assistRole.GetName()) - - // Real test: we craft a request asking for a summary - endpoint := pack.clt.Endpoint("webapi", "assistant", "title", "summary") - req := generateAssistantTitleRequest{Message: "This is a test user message asking Teleport assist to do something."} - - // Executing the request and validating the output is as expected - resp, err := pack.clt.PostJSON(ctx, endpoint, &req) - require.NoError(t, err) - - var info conversationInfo - body, err := io.ReadAll(resp.Reader()) - require.NoError(t, err) - err = json.Unmarshal(body, &info) - require.NoError(t, err) - require.NotEmpty(t, info.Title) -} - func allowAssistAccess(t *testing.T, s *WebSuite) types.Role { assistRole, err := types.NewRole("assist-access", types.RoleSpecV6{ Allow: types.RoleConditions{ @@ -409,22 +138,6 @@ func allowAssistAccess(t *testing.T, s *WebSuite) types.Role { return assistRole } -// makeAssistConversation creates a new assist conversation and returns its ID -func (s *WebSuite) makeAssistConversation(t *testing.T, ctx context.Context, authPack *authPack) string { - clt := authPack.clt - - resp, err := clt.PostJSON(ctx, clt.Endpoint("webapi", "assistant", "conversations"), nil) - require.NoError(t, err) - - convResp := struct { - ConversationID string `json:"id"` - }{} - err = json.Unmarshal(resp.Bytes(), &convResp) - require.NoError(t, err) - - return convResp.ConversationID -} - // makeAssistant creates a new assistant websocket connection. func (s *WebSuite) makeAssistant(_ *testing.T, pack *authPack, conversationID, action string) (*websocket.Conn, error) { if action == "" && conversationID == "" { @@ -476,11 +189,6 @@ func (s *WebSuite) makeAssistant(_ *testing.T, pack *authPack, conversationID, a return ws, nil } -// generateTextResponse generates a response for a text completion -func generateTextResponse() string { - return "\nWhich node do you want to use?" -} - func generateCommandResponse() string { return "```" + `json { diff --git a/lib/web/command.go b/lib/web/command.go deleted file mode 100644 index af285c66a46b2..0000000000000 --- a/lib/web/command.go +++ /dev/null @@ -1,768 +0,0 @@ -/* - * Teleport - * Copyright (C) 2023 Gravitational, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package web - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "net/http" - "strings" - "sync" - "time" - - "github.com/gogo/protobuf/proto" - "github.com/gorilla/websocket" - "github.com/gravitational/trace" - "github.com/julienschmidt/httprouter" - "github.com/sirupsen/logrus" - oteltrace "go.opentelemetry.io/otel/trace" - "golang.org/x/crypto/ssh" - "golang.org/x/sync/errgroup" - "google.golang.org/protobuf/types/known/timestamppb" - - "github.com/gravitational/teleport" - clientproto "github.com/gravitational/teleport/api/client/proto" - apidefaults "github.com/gravitational/teleport/api/defaults" - "github.com/gravitational/teleport/api/gen/proto/go/assist/v1" - usageeventsv1 "github.com/gravitational/teleport/api/gen/proto/go/usageevents/v1" - "github.com/gravitational/teleport/api/observability/tracing" - "github.com/gravitational/teleport/lib/agentless" - "github.com/gravitational/teleport/lib/ai/tokens" - assistlib "github.com/gravitational/teleport/lib/assist" - "github.com/gravitational/teleport/lib/auth" - "github.com/gravitational/teleport/lib/client" - "github.com/gravitational/teleport/lib/defaults" - "github.com/gravitational/teleport/lib/httplib" - "github.com/gravitational/teleport/lib/proxy" - "github.com/gravitational/teleport/lib/reversetunnelclient" - "github.com/gravitational/teleport/lib/services" - "github.com/gravitational/teleport/lib/session" - "github.com/gravitational/teleport/lib/teleagent" -) - -// summaryBufferCapacity is the summary buffer size in bytes. The summary buffer -// is shared across all nodes the command is running on and stores the command -// output. If the command output exceeds the buffer capacity, the summary won't -// be computed. -const summaryBufferCapacity = 2000 - -// CommandRequest is a request to execute a command on all nodes that match the query. -type CommandRequest struct { - // Command is the command to be executed on all nodes. - Command string `json:"command"` - // Login is a Linux username to connect as. - Login string `json:"login"` - // Query is the predicate query to filter nodes where the command will be executed. - Query string `json:"query"` - // ConversationID is the conversation context that was used to execute the command. - ConversationID string `json:"conversation_id"` - // ExecutionID is a unique ID used to identify the command execution. - ExecutionID string `json:"execution_id"` -} - -// commandExecResult is a result of a command execution. -type commandExecResult struct { - // NodeID is the ID of the node where the command was executed. - NodeID string `json:"node_id"` - // NodeName is the name of the node where the command was executed. - NodeName string `json:"node_name"` - // ExecutionID is a unique ID used to identify the command execution. - ExecutionID string `json:"execution_id"` - // SessionID is the ID of the session where the command was executed. - SessionID string `json:"session_id"` -} - -// sessionEndEvent is an event that is sent when a session ends. -type sessionEndEvent struct { - // NodeID is the ID of the server where the session was created. - NodeID string `json:"node_id"` -} - -// Check checks if the request is valid. -func (c *CommandRequest) Check() error { - if c.Command == "" { - return trace.BadParameter("missing command") - } - - if c.Query == "" { - return trace.BadParameter("missing query") - } - - if c.Login == "" { - return trace.BadParameter("missing login") - } - - if c.ConversationID == "" { - return trace.BadParameter("missing conversation ID") - } - - if c.ExecutionID == "" { - return trace.BadParameter("missing execution ID") - } - - return nil -} - -// executeCommand executes a command on all nodes that match the query. -func (h *Handler) executeCommand( - w http.ResponseWriter, - r *http.Request, - _ httprouter.Params, - sessionCtx *SessionContext, - site reversetunnelclient.RemoteSite, - rawWS *websocket.Conn, -) (any, error) { - q := r.URL.Query() - params := q.Get("params") - if params == "" { - return nil, trace.BadParameter("missing params") - } - var req CommandRequest - if err := json.Unmarshal([]byte(params), &req); err != nil { - return nil, trace.BadParameter("failed to read JSON message: %v", err) - } - - if err := req.Check(); err != nil { - return nil, trace.BadParameter("invalid payload: %v", err) - } - - clt, err := sessionCtx.GetUserClient(r.Context(), site) - if err != nil { - return nil, trace.Wrap(err) - } - - if err := checkAssistEnabled(clt, r.Context()); err != nil { - return nil, trace.Wrap(err) - } - - ctx, err := h.cfg.SessionControl.AcquireSessionContext(r.Context(), sessionCtx, req.Login, h.cfg.ProxyWebAddr.Addr, r.RemoteAddr) - if err != nil { - return nil, trace.Wrap(err) - } - - authAccessPoint, err := site.CachingAccessPoint() - if err != nil { - h.log.WithError(err).Debug("Unable to get auth access point.") - return nil, trace.Wrap(err) - } - - netConfig, err := authAccessPoint.GetClusterNetworkingConfig(ctx) - if err != nil { - h.log.WithError(err).Debug("Unable to fetch cluster networking config.") - return nil, trace.Wrap(err) - } - - clusterName := site.GetName() - - defer func() { - rawWS.WriteMessage(websocket.CloseMessage, nil) - rawWS.Close() - }() - - keepAliveInterval := netConfig.GetKeepAliveInterval() - err = rawWS.SetReadDeadline(deadlineForInterval(keepAliveInterval)) - if err != nil { - h.log.WithError(err).Error("Error setting websocket readline") - return nil, trace.Wrap(err) - } - // Update the read deadline upon receiving a pong message. - rawWS.SetPongHandler(func(_ string) error { - // This is intentonally called without a lock as this callback is - // called from the same goroutine as the read loop which is already locked. - return trace.Wrap(rawWS.SetReadDeadline(deadlineForInterval(keepAliveInterval))) - }) - - // Wrap the raw websocket connection in a syncRWWSConn so that we can - // safely read and write to the single websocket connection from - // multiple goroutines/execution nodes. - ws := &syncRWWSConn{WSConn: rawWS} - - hosts, err := findByQuery(ctx, clt, req.Query) - if err != nil { - log.WithError(err).Warn("Failed to find nodes by labels") - return nil, trace.Wrap(err) - } - - if len(hosts) == 0 { - const errMsg = "no servers found" - h.log.Error(errMsg) - return nil, trace.Errorf(errMsg) - } - - h.log.Debugf("Found %d hosts to run Assist command %q on.", len(hosts), req.Command) - - mfaCacheFn := getMFACacheFn() - interactiveCommand := strings.Split(req.Command, " ") - - buffer := newSummaryBuffer(summaryBufferCapacity) - - runCmd := func(host *hostInfo) error { - sessionData, err := h.generateCommandSession(host, req.Login, clusterName, sessionCtx.cfg.User) - if err != nil { - h.log.WithError(err).Debug("Unable to generate new ssh session.") - return trace.Wrap(err) - } - - h.log.Debugf("New command request for server=%s, id=%v, login=%s, sid=%s, websid=%s.", - host.hostName, host.id, req.Login, sessionData.ID, sessionCtx.GetSessionID()) - - commandHandlerConfig := CommandHandlerConfig{ - SessionCtx: sessionCtx, - UserAuthClient: clt, - SessionData: sessionData, - KeepAliveInterval: keepAliveInterval, - ProxyHostPort: h.ProxyHostPort(), - InteractiveCommand: interactiveCommand, - Router: h.cfg.Router, - TracerProvider: h.cfg.TracerProvider, - LocalAccessPoint: h.auth.accessPoint, - mfaFuncCache: mfaCacheFn, - buffer: buffer, - HostNameResolver: func(serverID string) (string, error) { - return serverID, nil - }, - } - - handler, err := newCommandHandler(ctx, commandHandlerConfig) - if err != nil { - h.log.WithError(err).Error("Unable to create terminal.") - return trace.Wrap(err) - } - handler.ws = &noopCloserWS{ws} - - h.userConns.Add(1) - defer h.userConns.Add(-1) - - h.log.Infof("Executing command: %#v.", req) - httplib.MakeTracingHandler(handler, teleport.ComponentProxy).ServeHTTP(w, r) - - msgPayload, err := json.Marshal(&commandExecResult{ - NodeID: host.id, - NodeName: host.hostName, - ExecutionID: req.ExecutionID, - SessionID: string(sessionData.ID), - }) - - if err != nil { - return trace.Wrap(err) - } - - err = clt.CreateAssistantMessage(ctx, &assist.CreateAssistantMessageRequest{ - ConversationId: req.ConversationID, - Username: sessionCtx.GetUser(), - Message: &assist.AssistantMessage{ - Type: string(assistlib.MessageKindCommandResult), - CreatedTime: timestamppb.New(time.Now().UTC()), - Payload: string(msgPayload), - }, - }) - - return trace.Wrap(err) - } - - runCommands(hosts, runCmd, int(netConfig.GetAssistCommandExecutionWorkers()), h.log) - - var tokenCount *tokens.TokenCount - // Optionally, try to compute the command summary. - if output, valid := buffer.Export(); valid { - summaryReq := summaryRequest{ - hosts: hosts, - output: output, - authClient: clt, - username: sessionCtx.GetUser(), - executionID: req.ExecutionID, - conversationID: req.ConversationID, - command: req.Command, - } - tokenCount, err = h.computeAndSendSummary(ctx, &summaryReq, ws) - if err != nil { - return nil, trace.Wrap(err) - } - } - - prompt, completion := tokens.CountTokens(tokenCount) - - usageEventReq := &clientproto.SubmitUsageEventRequest{ - Event: &usageeventsv1.UsageEventOneOf{ - Event: &usageeventsv1.UsageEventOneOf_AssistExecution{ - AssistExecution: &usageeventsv1.AssistExecutionEvent{ - ConversationId: req.ConversationID, - NodeCount: int64(len(hosts)), - TotalTokens: int64(completion + prompt), - PromptTokens: int64(prompt), - CompletionTokens: int64(completion), - }, - }, - }, - } - if err := clt.SubmitUsageEvent(ctx, usageEventReq); err != nil { - h.log.WithError(err).Warn("Failed to emit usage event") - } - - return nil, nil -} - -type summaryRequest struct { - hosts []hostInfo - output map[string][]byte - authClient auth.ClientI - username string - executionID string - conversationID string - command string -} - -func (h *Handler) computeAndSendSummary( - ctx context.Context, - req *summaryRequest, - ws WSConn, -) (*tokens.TokenCount, error) { - // Convert the map nodeId->output into a map nodeName->output - namedOutput := outputByName(req.hosts, req.output) - - history, err := req.authClient.GetAssistantMessages(ctx, &assist.GetAssistantMessagesRequest{ - ConversationId: req.conversationID, - Username: req.username, - }) - if err != nil { - return nil, trace.Wrap(err) - } - - assistClient, err := assistlib.NewClient(ctx, req.authClient, h.cfg.ProxySettings, h.cfg.OpenAIConfig) - if err != nil { - return nil, trace.Wrap(err) - } - - summary, tokenCount, err := assistClient.GenerateCommandSummary(ctx, history.GetMessages(), namedOutput) - if err != nil { - return nil, trace.Wrap(err) - } - - // Add the summary message to the backend, so it is persisted on chat - // reload. - messagePayload, err := json.Marshal(&assistlib.CommandExecSummary{ - ExecutionID: req.executionID, - Command: req.command, - Summary: summary, - }) - if err != nil { - return nil, trace.Wrap(err) - } - summaryMessage := &assist.CreateAssistantMessageRequest{ - ConversationId: req.conversationID, - Username: req.username, - Message: &assist.AssistantMessage{ - Type: string(assistlib.MessageKindCommandResultSummary), - CreatedTime: timestamppb.New(time.Now().UTC()), - Payload: string(messagePayload), - }, - } - - err = req.authClient.CreateAssistantMessage(ctx, summaryMessage) - if err != nil { - return nil, trace.Wrap(err) - } - - // Send the summary over the execution websocket to provide instant - // feedback to the user. - out := &outEnvelope{ - Type: envelopeTypeSummary, - Payload: []byte(summary), - } - data, err := json.Marshal(out) - if err != nil { - return nil, trace.Wrap(err) - } - stream := NewWStream(ctx, ws, log, nil) - _, err = stream.Write(data) - return tokenCount, trace.Wrap(err) -} - -func outputByName(hosts []hostInfo, output map[string][]byte) map[string][]byte { - hostIDToName := make(map[string]string, len(hosts)) - for _, host := range hosts { - hostIDToName[host.id] = host.hostName - } - namedOutput := make(map[string][]byte, len(output)) - for id, data := range output { - namedOutput[hostIDToName[id]] = data - } - return namedOutput -} - -// runCommands runs the given command on the given hosts. -func runCommands(hosts []hostInfo, runCmd func(host *hostInfo) error, numParallel int, log logrus.FieldLogger) { - var group errgroup.Group - group.SetLimit(numParallel) - - for _, host := range hosts { - host := host - group.Go(func() error { - return trace.Wrap(runCmd(&host), "failed to start session on %v", host.hostName) - }) - } - - // Wait for all commands to finish. - if err := group.Wait(); err != nil { - log.WithError(err).Debug("Assist command execution failed") - } -} - -// getMFACacheFn returns a function that caches the result of the given -// get function. The cache is protected by a mutex, so it is safe to call -// the returned function from multiple goroutines. -func getMFACacheFn() mfaFuncCache { - var mutex sync.Mutex - var authMethods []ssh.AuthMethod - - return func(issueMfaAuthFn func() ([]ssh.AuthMethod, error)) ([]ssh.AuthMethod, error) { - mutex.Lock() - defer mutex.Unlock() - - if authMethods != nil { - return authMethods, nil - } - - authMethods, err := issueMfaAuthFn() - return authMethods, trace.Wrap(err) - } -} - -func newCommandHandler(ctx context.Context, cfg CommandHandlerConfig) (*commandHandler, error) { - err := cfg.CheckAndSetDefaults() - if err != nil { - return nil, trace.Wrap(err) - } - - _, span := cfg.tracer.Start(ctx, "NewCommand") - defer span.End() - - return &commandHandler{ - sshBaseHandler: sshBaseHandler{ - log: logrus.WithFields(logrus.Fields{ - teleport.ComponentKey: teleport.ComponentWebsocket, - "session_id": cfg.SessionData.ID.String(), - }), - ctx: cfg.SessionCtx, - userAuthClient: cfg.UserAuthClient, - sessionData: cfg.SessionData, - keepAliveInterval: cfg.KeepAliveInterval, - proxyHostPort: cfg.ProxyHostPort, - interactiveCommand: cfg.InteractiveCommand, - router: cfg.Router, - localAccessPoint: cfg.LocalAccessPoint, - tracer: cfg.tracer, - resolver: cfg.HostNameResolver, - }, - mfaAuthCache: cfg.mfaFuncCache, - buffer: cfg.buffer, - }, nil -} - -// CommandHandlerConfig is the configuration for the command handler. -type CommandHandlerConfig struct { - // SessionCtx is the context for the user's web session. - SessionCtx *SessionContext - // UserAuthClient is used to fetch nodes and sessions from the backend via the users' identity. - UserAuthClient UserAuthClient - // SessionData is the data to send to the client on the initial session creation. - SessionData session.Session - // KeepAliveInterval is the interval for sending ping frames to a web client. - // This value is pulled from the cluster network config and - // guaranteed to be set to a nonzero value as it's enforced by the configuration. - KeepAliveInterval time.Duration - // ProxyHostPort is the address of the server to connect to. - ProxyHostPort string - // InteractiveCommand is a command to execute. - InteractiveCommand []string - // Router determines how connections to nodes are created - Router *proxy.Router - // TracerProvider is used to create the tracer - TracerProvider oteltrace.TracerProvider - // LocalAccessPoint is the subset of the Proxy cache required to - // look up information from the local cluster. This should not - // be used for anything that requires RBAC on behalf of the user. - // Anything requests that should be made on behalf of the user should - // use [UserAuthClient]. - LocalAccessPoint localAccessPoint - // HostNameResolver allows the hostname to be determined from a server UUID - // so that a friendly name can be displayed in the console tab. - HostNameResolver func(serverID string) (hostname string, err error) - // tracer is used to create spans - tracer oteltrace.Tracer - // mfaFuncCache is used to cache the MFA auth method - mfaFuncCache mfaFuncCache - // buffer shared across multiple commandHandlers that saves the command - // output in order to generate a summary of the executed commands. - buffer *summaryBuffer -} - -// CheckAndSetDefaults checks and sets default values. -func (t *CommandHandlerConfig) CheckAndSetDefaults() error { - // Make sure whatever session is requested is a valid session id. - _, err := session.ParseID(t.SessionData.ID.String()) - if err != nil { - return trace.BadParameter("sid: invalid session id") - } - - if t.SessionData.Login == "" { - return trace.BadParameter("login: missing login") - } - - if t.SessionData.ServerID == "" { - return trace.BadParameter("server: missing server") - } - - if t.UserAuthClient == nil { - return trace.BadParameter("UserAuthClient must be provided") - } - - if t.SessionCtx == nil { - return trace.BadParameter("SessionCtx must be provided") - } - - if t.Router == nil { - return trace.BadParameter("Router must be provided") - } - - if t.TracerProvider == nil { - t.TracerProvider = tracing.DefaultProvider() - } - - if t.LocalAccessPoint == nil { - return trace.BadParameter("localAccessPoint must be provided") - } - - if t.mfaFuncCache == nil { - return trace.BadParameter("mfaFuncCache must be provided") - } - - t.tracer = t.TracerProvider.Tracer("webcommand") - - return nil -} - -// mfaFuncCache is a function type that caches the result of a function that -// returns a list of ssh.AuthMethods. -type mfaFuncCache func(func() ([]ssh.AuthMethod, error)) ([]ssh.AuthMethod, error) - -// commandHandler is a handler for executing commands on a remote node. -type commandHandler struct { - sshBaseHandler - - // stream is the websocket stream to the client. - stream *WSStream - - // ws a raw websocket connection to the client. - ws WSConn - - // mfaAuthCache is a function that caches the result of a function that - // returns a list of ssh.AuthMethods. It is used to cache the result of - // the MFA challenge. - mfaAuthCache mfaFuncCache - - // buffer shared across multiple commandHandlers that saves the command - // output in order to generate a summary of the executed commands. - buffer *summaryBuffer -} - -// sendError sends an error message to the client using the provided websocket. -func (t *sshBaseHandler) sendError(errMsg string, err error, ws WSConn) { - envelope := &Envelope{ - Version: defaults.WebsocketVersion, - Type: defaults.WebsocketError, - Payload: fmt.Sprintf("%s: %s", errMsg, err.Error()), - } - - envelopeBytes, err := proto.Marshal(envelope) - if err != nil { - t.log.WithError(err).Error("failed to marshal error message") - } - if err := ws.WriteMessage(websocket.BinaryMessage, envelopeBytes); err != nil { - t.log.WithError(err).Error("failed to send error message") - } -} - -func (t *commandHandler) ServeHTTP(_ http.ResponseWriter, r *http.Request) { - // Allow closing websocket if the user logs out before exiting - // the session. - t.ctx.AddClosers(t) - defer t.ctx.RemoveCloser(t) - - sessionMetadataResponse, err := json.Marshal(siteSessionGenerateResponse{Session: t.sessionData}) - if err != nil { - t.sendError("unable to marshal session response", err, t.ws) - return - } - - envelope := &Envelope{ - Version: defaults.WebsocketVersion, - Type: defaults.WebsocketSessionMetadata, - Payload: string(sessionMetadataResponse), - } - - envelopeBytes, err := proto.Marshal(envelope) - if err != nil { - t.sendError("unable to marshal session data event for web client", err, t.ws) - return - } - - err = t.ws.WriteMessage(websocket.BinaryMessage, envelopeBytes) - if err != nil { - t.sendError("unable to write message to socket", err, t.ws) - return - } - - t.handler(r) -} - -func (t *commandHandler) handler(r *http.Request) { - t.stream = NewWStream(r.Context(), t.ws, t.log, nil) - - // Create a Teleport client, if not able to, show the reason to the user in - // the terminal. - tc, err := t.makeClient(r.Context(), t.ws) - if err != nil { - t.log.WithError(err).Info("Failed creating a client for session") - t.writeError(err) - return - } - - t.log.Debug("Creating websocket stream") - - // Start sending ping frames through websocket to the client. - go startPingLoop(r.Context(), t.ws, t.keepAliveInterval, t.log, t.Close) - - // Pump raw terminal in/out and audit events into the websocket. - t.streamOutput(r.Context(), tc) -} - -// streamOutput opens an SSH connection to the remote host and streams -// events back to the web client. -func (t *commandHandler) streamOutput(ctx context.Context, tc *client.TeleportClient) { - ctx, span := t.tracer.Start(ctx, "commandHandler/streamOutput") - defer span.End() - - nc, err := t.connectToHost(ctx, t.ws, tc, t.connectToNodeWithMFA) - if err != nil { - t.log.WithError(err).Warn("Unable to stream terminal - failure connecting to host") - t.writeError(err) - return - } - - defer nc.Close() - - // Enable session recording - nc.AddEnv(teleport.EnableNonInteractiveSessionRecording, "true") - - // Establish SSH connection to the server. This function will block until - // either an error occurs or it completes successfully. - if err = nc.RunCommand(ctx, t.interactiveCommand); err != nil { - t.log.WithError(err).Warn("Unable to stream terminal - failure running shell") - t.writeError(err) - return - } - - if err := t.stream.SendCloseMessage(sessionEndEvent{NodeID: t.sessionData.ServerID}); err != nil { - t.log.WithError(err).Error("Unable to send close event to web client.") - return - } - - t.log.Debug("Sent close event to web client.") -} - -// connectToNodeWithMFA attempts to perform the mfa ceremony and then dial the -// host with the retrieved single use certs. -// If called multiple times, the mfa ceremony will only be performed once. -func (t *commandHandler) connectToNodeWithMFA(ctx context.Context, ws WSConn, tc *client.TeleportClient, accessChecker services.AccessChecker, getAgent teleagent.Getter, signer agentless.SignerCreator) (*client.NodeClient, error) { - authMethods, err := t.mfaAuthCache(func() ([]ssh.AuthMethod, error) { - // perform mfa ceremony and retrieve new certs - authMethods, err := t.issueSessionMFACerts(ctx, tc, t.stream) - if err != nil { - return nil, trace.Wrap(err) - } - - return authMethods, nil - }) - if err != nil { - return nil, trace.Wrap(err) - } - - return t.connectToNodeWithMFABase(ctx, ws, tc, accessChecker, getAgent, signer, authMethods) -} - -// Close is no-op as we never want to close the connection to the client. -// Connection should be closed in the handler when it was created. -func (t *commandHandler) Close() error { - return nil -} - -// makeClient builds a *client.TeleportClient for the connection. -func (t *commandHandler) makeClient(ctx context.Context, ws WSConn) (*client.TeleportClient, error) { - ctx, span := tracing.DefaultProvider().Tracer("command").Start(ctx, "commandHandler/makeClient") - defer span.End() - - clientConfig, err := makeTeleportClientConfig(ctx, t.ctx) - if err != nil { - return nil, trace.Wrap(err) - } - - clientConfig.HostLogin = t.sessionData.Login - clientConfig.ForwardAgent = client.ForwardAgentLocal - clientConfig.Namespace = apidefaults.Namespace - clientConfig.Stdout = newBufferedPayloadWriter(newPayloadWriter(t.sessionData.ServerID, EnvelopeTypeStdout, t.stream), t.buffer) - clientConfig.Stderr = newBufferedPayloadWriter(newPayloadWriter(t.sessionData.ServerID, envelopeTypeStderr, t.stream), t.buffer) - clientConfig.Stdin = &bytes.Buffer{} // set stdin to a dummy buffer - clientConfig.SiteName = t.sessionData.ClusterName - if err := clientConfig.ParseProxyHost(t.proxyHostPort); err != nil { - return nil, trace.BadParameter("failed to parse proxy address: %v", err) - } - clientConfig.Host = t.sessionData.ServerHostname - clientConfig.HostPort = t.sessionData.ServerHostPort - clientConfig.SessionID = t.sessionData.ID.String() - clientConfig.ClientAddr = ws.RemoteAddr().String() - clientConfig.Tracer = t.tracer - - tc, err := client.NewClient(clientConfig) - if err != nil { - return nil, trace.BadParameter("failed to create client: %v", err) - } - - return tc, nil -} - -// writeError displays an error in the terminal window. -func (t *commandHandler) writeError(err error) { - out := &outEnvelope{ - NodeID: t.sessionData.ServerID, - Type: envelopeTypeError, - Payload: []byte(err.Error()), - } - data, err := json.Marshal(out) - if err != nil { - t.log.WithError(err).Error("failed to marshal error message") - return - } - - if _, writeErr := t.stream.Write(data); writeErr != nil { - t.log.WithError(writeErr).Warnf("Unable to send error to terminal: %v", err) - } -} diff --git a/lib/web/command_test.go b/lib/web/command_test.go deleted file mode 100644 index 3dcc18232092f..0000000000000 --- a/lib/web/command_test.go +++ /dev/null @@ -1,422 +0,0 @@ -/* - * Teleport - * Copyright (C) 2023 Gravitational, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package web - -import ( - "context" - "crypto/tls" - "encoding/json" - "fmt" - "io" - "net/http" - "net/http/httptest" - "net/url" - "strings" - "sync/atomic" - "testing" - "time" - - "github.com/gogo/protobuf/proto" - "github.com/google/uuid" - "github.com/gorilla/websocket" - "github.com/gravitational/trace" - "github.com/sashabaranov/go-openai" - "github.com/sirupsen/logrus" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "google.golang.org/protobuf/types/known/timestamppb" - - "github.com/gravitational/teleport/api/gen/proto/go/assist/v1" - "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/ai/testutils" - assistlib "github.com/gravitational/teleport/lib/assist" - "github.com/gravitational/teleport/lib/auth" - "github.com/gravitational/teleport/lib/client" - "github.com/gravitational/teleport/lib/services" - "github.com/gravitational/teleport/lib/session" - "github.com/gravitational/teleport/lib/utils" -) - -const ( - testCommand = "echo txlxport | sed 's/x/e/g'" - testUser = "foo" -) - -func TestExecuteCommand(t *testing.T) { - t.Parallel() - openAIMock := mockOpenAISummary(t) - openAIConfig := openai.DefaultConfig("test-token") - openAIConfig.BaseURL = openAIMock.URL - s := newWebSuiteWithConfig(t, webSuiteConfig{ - disableDiskBasedRecording: true, - OpenAIConfig: &openAIConfig, - }) - - assistRole, err := types.NewRole("assist-access", types.RoleSpecV6{ - Allow: types.RoleConditions{ - Rules: []types.Rule{ - types.NewRule(types.KindAssistant, services.RW()), - }, - }, - }) - require.NoError(t, err) - assistRole, err = s.server.Auth().UpsertRole(s.ctx, assistRole) - require.NoError(t, err) - - ws, _, err := s.makeCommand(t, s.authPack(t, testUser, assistRole.GetName()), uuid.New()) - require.NoError(t, err) - t.Cleanup(func() { require.NoError(t, ws.Close()) }) - - stream := NewWStream(context.Background(), ws, utils.NewLoggerForTests(), nil) - - require.NoError(t, waitForCommandOutput(stream, "teleport")) -} - -func TestExecuteCommandHistory(t *testing.T) { - t.Parallel() - - openAIMock := mockOpenAISummary(t) - openAIConfig := openai.DefaultConfig("test-token") - openAIConfig.BaseURL = openAIMock.URL - s := newWebSuiteWithConfig(t, webSuiteConfig{ - disableDiskBasedRecording: true, - OpenAIConfig: &openAIConfig, - }) - - assistRole, err := types.NewRole("assist-access", types.RoleSpecV6{ - Allow: types.RoleConditions{ - Rules: []types.Rule{ - types.NewRule(types.KindAssistant, services.RW()), - }, - }, - }) - require.NoError(t, err) - assistRole, err = s.server.Auth().UpsertRole(s.ctx, assistRole) - require.NoError(t, err) - - authPack := s.authPack(t, testUser, assistRole.GetName()) - - ctx := context.Background() - clt, err := s.server.NewClient(auth.TestUser(testUser)) - require.NoError(t, err) - - // Create conversation, otherwise the command execution will not be saved - conversation, err := clt.CreateAssistantConversation(context.Background(), &assist.CreateAssistantConversationRequest{ - Username: testUser, - CreatedTime: timestamppb.Now(), - }) - require.NoError(t, err) - - require.NotEmpty(t, conversation.GetId()) - - conversationID, err := uuid.Parse(conversation.GetId()) - require.NoError(t, err) - - ws, _, err := s.makeCommand(t, authPack, conversationID) - require.NoError(t, err) - - stream := NewWStream(ctx, ws, utils.NewLoggerForTests(), nil) - - // When command executes - require.NoError(t, waitForCommandOutput(stream, "teleport")) - - // Close the stream if not already closed - _ = stream.Close() - - // Then command execution history is saved - var messages *assist.GetAssistantMessagesResponse - // Command execution history is saved in asynchronously, so we need to wait for it. - require.Eventually(t, func() bool { - messages, err = clt.GetAssistantMessages(ctx, &assist.GetAssistantMessagesRequest{ - ConversationId: conversationID.String(), - Username: testUser, - }) - require.NoError(t, err) - - return len(messagesByType(messages.GetMessages())[assistlib.MessageKindCommandResult]) == 1 - }, 5*time.Second, 100*time.Millisecond) - - // Assert the returned message - resultMessages, ok := messagesByType(messages.GetMessages())[assistlib.MessageKindCommandResult] - require.True(t, ok, "Message must be of type COMMAND_RESULT") - msg := resultMessages[0] - require.NotZero(t, msg.CreatedTime) - - var result commandExecResult - err = json.Unmarshal([]byte(msg.GetPayload()), &result) - require.NoError(t, err) - - require.NotEmpty(t, result.ExecutionID) - require.NotEmpty(t, result.SessionID) - require.Equal(t, "node", result.NodeName) - require.Equal(t, "node", result.NodeID) -} - -func TestExecuteCommandSummary(t *testing.T) { - t.Parallel() - - openAIMock := mockOpenAISummary(t) - openAIConfig := openai.DefaultConfig("test-token") - openAIConfig.BaseURL = openAIMock.URL - s := newWebSuiteWithConfig(t, webSuiteConfig{ - disableDiskBasedRecording: true, - OpenAIConfig: &openAIConfig, - }) - - assistRole, err := types.NewRole("assist-access", types.RoleSpecV6{ - Allow: types.RoleConditions{ - Rules: []types.Rule{ - types.NewRule(types.KindAssistant, services.RW()), - }, - }, - }) - require.NoError(t, err) - assistRole, err = s.server.Auth().UpsertRole(s.ctx, assistRole) - require.NoError(t, err) - - authPack := s.authPack(t, testUser, assistRole.GetName()) - - ctx := context.Background() - clt, err := s.server.NewClient(auth.TestUser(testUser)) - require.NoError(t, err) - - // Create conversation, otherwise the command execution will not be saved - conversation, err := clt.CreateAssistantConversation(context.Background(), &assist.CreateAssistantConversationRequest{ - Username: testUser, - CreatedTime: timestamppb.Now(), - }) - require.NoError(t, err) - - require.NotEmpty(t, conversation.GetId()) - - conversationID, err := uuid.Parse(conversation.GetId()) - require.NoError(t, err) - - ws, _, err := s.makeCommand(t, authPack, conversationID) - require.NoError(t, err) - - // For simplicity, use simple WS to io.Reader adapter - stream := &wsReader{conn: ws} - - // Wait for command execution to complete - require.NoError(t, waitForCommandOutput(stream, "teleport")) - - dec := json.NewDecoder(stream) - - // Consume the close message - var sessionMetadata sessionEndEvent - err = dec.Decode(&sessionMetadata) - require.NoError(t, err) - require.Equal(t, "node", sessionMetadata.NodeID) - - // Consume the summary message - var env outEnvelope - err = dec.Decode(&env) - require.NoError(t, err) - require.Equal(t, envelopeTypeSummary, env.Type) - require.NotEmpty(t, env.Payload) - - // Wait for the command execution history to be saved - var messages *assist.GetAssistantMessagesResponse - // Command execution history is saved in asynchronously, so we need to wait for it. - require.Eventually(t, func() bool { - messages, err = clt.GetAssistantMessages(ctx, &assist.GetAssistantMessagesRequest{ - ConversationId: conversationID.String(), - Username: testUser, - }) - assert.NoError(t, err) - - return len(messagesByType(messages.GetMessages())[assistlib.MessageKindCommandResultSummary]) == 1 - }, 5*time.Second, 100*time.Millisecond) - - // Check the returned summary message - summaryMessages, ok := messagesByType(messages.GetMessages())[assistlib.MessageKindCommandResultSummary] - require.True(t, ok, "At least one summary message is expected") - msg := summaryMessages[0] - require.NotZero(t, msg.CreatedTime) - - var result assistlib.CommandExecSummary - err = json.Unmarshal([]byte(msg.GetPayload()), &result) - require.NoError(t, err) - - require.NotEmpty(t, result.ExecutionID) - require.Equal(t, testCommand, result.Command) - require.NotEmpty(t, result.Summary) -} - -func (s *WebSuite) makeCommand(t *testing.T, pack *authPack, conversationID uuid.UUID) (*websocket.Conn, *session.Session, error) { - req := CommandRequest{ - Query: fmt.Sprintf("name == \"%s\"", s.srvID), - Login: pack.login, - ConversationID: conversationID.String(), - ExecutionID: uuid.New().String(), - Command: testCommand, - } - - u := url.URL{ - Host: s.url().Host, - Scheme: client.WSS, - Path: fmt.Sprintf("/v1/webapi/command/%v/execute/ws", currentSiteShortcut), - } - data, err := json.Marshal(req) - if err != nil { - return nil, nil, err - } - - q := u.Query() - q.Set("params", string(data)) - u.RawQuery = q.Encode() - - dialer := websocket.Dialer{} - dialer.TLSClientConfig = &tls.Config{ - InsecureSkipVerify: true, - } - - header := http.Header{} - header.Add("Origin", "http://localhost") - for _, cookie := range pack.cookies { - header.Add("Cookie", cookie.String()) - } - - ws, resp, err := dialer.Dial(u.String(), header) - if err != nil { - return nil, nil, trace.Wrap(err) - } - - if err := makeAuthReqOverWS(ws, pack.session.Token); err != nil { - return nil, nil, trace.Wrap(err) - } - - ty, raw, err := ws.ReadMessage() - if err != nil { - return nil, nil, trace.Wrap(err) - } - require.Equal(t, websocket.BinaryMessage, ty) - var env Envelope - - err = proto.Unmarshal(raw, &env) - if err != nil { - return nil, nil, trace.Wrap(err) - } - - var sessResp siteSessionGenerateResponse - - err = json.Unmarshal([]byte(env.Payload), &sessResp) - if err != nil { - return nil, nil, trace.Wrap(err) - } - - err = resp.Body.Close() - if err != nil { - return nil, nil, trace.Wrap(err) - } - - return ws, &sessResp.Session, nil -} - -func waitForCommandOutput(stream io.Reader, substr string) error { - timeoutCh := time.After(10 * time.Second) - - for { - select { - case <-timeoutCh: - return trace.BadParameter("timeout waiting on terminal for output: %v", substr) - default: - } - - var env outEnvelope - dec := json.NewDecoder(stream) - if err := dec.Decode(&env); err != nil { - return trace.Wrap(err, "decoding envelope JSON from stream") - } - - data := removeSpace(string(env.Payload)) - if strings.Contains(data, substr) { - return nil - } - } -} - -// Test_runCommands tests that runCommands runs the given command on all hosts. -// The commands should run in parallel, but we don't have a deterministic way to -// test that (sleep with checking the execution time in not deterministic). -func Test_runCommands(t *testing.T) { - const numWorkers = 30 - counter := atomic.Int32{} - - runCmd := func(host *hostInfo) error { - counter.Add(1) - return nil - } - - hosts := make([]hostInfo, 0, 100) - for i := 0; i < 100; i++ { - hosts = append(hosts, hostInfo{ - hostName: fmt.Sprintf("localhost%d", i), - }) - } - - logger := logrus.New() - logger.Out = io.Discard - - runCommands(hosts, runCmd, numWorkers, logger) - - require.Equal(t, int32(100), counter.Load()) -} - -func mockOpenAISummary(t *testing.T) *httptest.Server { - responses := []string{"This is the summary of the command."} - server := httptest.NewServer(testutils.GetTestHandlerFn(t, responses)) - t.Cleanup(server.Close) - return server -} - -func messagesByType(messages []*assist.AssistantMessage) map[assistlib.MessageType][]*assist.AssistantMessage { - byType := make(map[assistlib.MessageType][]*assist.AssistantMessage) - for _, message := range messages { - byType[assistlib.MessageType(message.GetType())] = append(byType[assistlib.MessageType(message.GetType())], message) - } - return byType -} - -// wsReader implements io.Reader interface over websocket connection -type wsReader struct { - conn *websocket.Conn -} - -// Read reads data from websocket connection. -// The message should be in web.Envelope format and only the payload will be returned. -// It expects that the passed buffer is big enough to fit the whole message. -func (r *wsReader) Read(p []byte) (int, error) { - _, data, err := r.conn.ReadMessage() - if err != nil { - return 0, trace.Wrap(err) - } - - var envelope Envelope - if err := proto.Unmarshal(data, &envelope); err != nil { - return 0, trace.Errorf("Unable to parse message payload %v", err) - } - - if len(envelope.Payload) > len(p) { - return 0, trace.BadParameter("buffer too small") - } - - return copy(p, envelope.Payload), nil -} diff --git a/lib/web/command_utils.go b/lib/web/command_utils.go deleted file mode 100644 index 04c5545bee059..0000000000000 --- a/lib/web/command_utils.go +++ /dev/null @@ -1,207 +0,0 @@ -/* - * Teleport - * Copyright (C) 2023 Gravitational, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package web - -import ( - "encoding/json" - "io" - "net" - "sync" - "time" - - "github.com/gravitational/trace" -) - -// WSConn is a gorilla/websocket minimal interface used by our web implementation. -// This interface exists to override the default websocket.Conn implementation, -// currently used by noopCloserWS to prevent WS being closed by wrapping stream. -type WSConn interface { - Close() error - - LocalAddr() net.Addr - RemoteAddr() net.Addr - - WriteControl(messageType int, data []byte, deadline time.Time) error - WriteMessage(messageType int, data []byte) error - ReadMessage() (messageType int, p []byte, err error) - SetReadLimit(limit int64) - SetReadDeadline(t time.Time) error - SetWriteDeadline(t time.Time) error - - PongHandler() func(appData string) error - SetPongHandler(h func(appData string) error) - CloseHandler() func(code int, text string) error - SetCloseHandler(h func(code int, text string) error) -} - -const ( - EnvelopeTypeStdout = "stdout" - envelopeTypeStderr = "stderr" - envelopeTypeError = "teleport-error" - envelopeTypeSummary = "summary" -) - -// outEnvelope is an envelope used to wrap messages send back to the client connected over WS. -type outEnvelope struct { - NodeID string `json:"node_id"` - Type string `json:"type"` - Payload []byte `json:"payload"` -} - -// payloadWriter is a wrapper around io.Writer, which wraps the given bytes into -// outEnvelope and writes it to the underlying stream. -type payloadWriter struct { - nodeID string - // output name, can be stdout, stderr, teleport-error or summary. - outputName string - // stream is the underlying stream. - stream io.Writer -} - -// Write writes the given bytes to the underlying stream. -func (p *payloadWriter) Write(b []byte) (n int, err error) { - out := &outEnvelope{ - NodeID: p.nodeID, - Type: p.outputName, - Payload: b, - } - data, err := json.Marshal(out) - if err != nil { - return 0, trace.Wrap(err) - } - - _, err = p.stream.Write(data) - // return the size of the original message as a message send over stream - // is larger due to json marshaling and envelope. - return len(b), trace.Wrap(err) -} - -func newPayloadWriter(nodeID, outputName string, stream io.Writer) *payloadWriter { - return &payloadWriter{ - nodeID: nodeID, - outputName: outputName, - stream: stream, - } -} - -// noopCloserWS is a wrapper around websocket.Conn, which does nothing on Close(). -// This struct is used to prevent WS being closed by wrapping stream. -// Currently, it is being used in Command web handler to prevent WS being closed -// by any underlying code as we want to keep the connection open until the command -// is executed on all nodes and a single failure should not close the connection. -type noopCloserWS struct { - WSConn -} - -// Close does nothing. -func (ws *noopCloserWS) Close() error { - return nil -} - -// syncRWWSConn is a wrapper around websocket.Conn, which serializes -// read and write to a web socket connection. This is needed to prevent -// a race conditions and panics in gorilla/websocket. -// Details https://pkg.go.dev/github.com/gorilla/websocket#hdr-Concurrency -// This struct does not lock SetReadDeadline() as the SetReadDeadline() -// is called from the pong handler, which is interanlly called on ReadMessage() -// according to https://pkg.go.dev/github.com/gorilla/websocket#hdr-Control_Messages -// This would prevent the pong handler from being called. -type syncRWWSConn struct { - // WSConn the underlying websocket connection. - WSConn - // rmtx is a mutex used to serialize reads. - rmtx sync.Mutex - // wmtx is a mutex used to serialize writes. - wmtx sync.Mutex -} - -func (s *syncRWWSConn) WriteMessage(messageType int, data []byte) error { - s.wmtx.Lock() - defer s.wmtx.Unlock() - return s.WSConn.WriteMessage(messageType, data) -} - -func (s *syncRWWSConn) ReadMessage() (messageType int, p []byte, err error) { - s.rmtx.Lock() - defer s.rmtx.Unlock() - return s.WSConn.ReadMessage() -} - -func newBufferedPayloadWriter(pw *payloadWriter, buffer *summaryBuffer) *bufferedPayloadWriter { - return &bufferedPayloadWriter{ - payloadWriter: pw, - buffer: buffer, - } -} - -type bufferedPayloadWriter struct { - *payloadWriter - buffer *summaryBuffer -} - -func (bp *bufferedPayloadWriter) Write(data []byte) (int, error) { - bp.buffer.Write(bp.nodeID, data) - return bp.payloadWriter.Write(data) -} - -func newSummaryBuffer(capacity int) *summaryBuffer { - return &summaryBuffer{ - buffer: make(map[string][]byte), - remainingCapacity: capacity, - invalid: false, - mutex: sync.Mutex{}, - } -} - -type summaryBuffer struct { - buffer map[string][]byte - remainingCapacity int - invalid bool - // mutex protects all members of the struct and must be acquired before - // performing any read or write operation - mutex sync.Mutex -} - -func (b *summaryBuffer) Write(node string, data []byte) { - b.mutex.Lock() - defer b.mutex.Unlock() - if b.invalid { - return - } - if len(data) > b.remainingCapacity { - // We're out of capacity, not all content will be written to the buffer - // it should not be used anymore - b.invalid = true - return - } - b.buffer[node] = append(b.buffer[node], data...) - b.remainingCapacity -= len(data) -} - -// Export returns the buffer content and a whether the Export is valid. -// Exporting the buffer can only happen once. -func (b *summaryBuffer) Export() (map[string][]byte, bool) { - b.mutex.Lock() - defer b.mutex.Unlock() - if b.invalid { - return nil, false - } - b.invalid = true - return b.buffer, len(b.buffer) != 0 -} diff --git a/lib/web/command_utils_test.go b/lib/web/command_utils_test.go deleted file mode 100644 index acb74690b735c..0000000000000 --- a/lib/web/command_utils_test.go +++ /dev/null @@ -1,145 +0,0 @@ -/* - * Teleport - * Copyright (C) 2023 Gravitational, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package web - -import ( - "sync" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestSummaryBuffer(t *testing.T) { - t.Parallel() - tests := []struct { - name string - outputs map[string][][]byte - capacity int - expectedOutput map[string][]byte - assertValidity require.BoolAssertionFunc - }{ - { - name: "Single node", - outputs: map[string][][]byte{ - "node": { - []byte("foo"), - []byte("bar"), - []byte("baz"), - }, - }, - capacity: 9, - expectedOutput: map[string][]byte{ - "node": []byte("foobarbaz"), - }, - assertValidity: require.True, - }, - { - name: "Single node overflow", - outputs: map[string][][]byte{ - "node": { - []byte("foo"), - []byte("bar"), - []byte("baz"), - }, - }, - capacity: 8, - expectedOutput: nil, - assertValidity: require.False, - }, - { - name: "Multiple nodes", - outputs: map[string][][]byte{ - "node1": { - []byte("foo"), - []byte("bar"), - []byte("baz"), - }, - "node2": { - []byte("baz"), - []byte("bar"), - []byte("foo"), - }, - "node3": { - []byte("baz"), - []byte("baz"), - []byte("baz"), - }, - }, - capacity: 30, - expectedOutput: map[string][]byte{ - "node1": []byte("foobarbaz"), - "node2": []byte("bazbarfoo"), - "node3": []byte("bazbazbaz"), - }, - assertValidity: require.True, - }, - { - name: "Multiple nodes overflow", - outputs: map[string][][]byte{ - "node1": { - []byte("foo"), - []byte("bar"), - []byte("baz"), - }, - "node2": { - []byte("baz"), - []byte("bar"), - []byte("foo"), - }, - "node3": { - []byte("baz"), - []byte("baz"), - []byte("baz"), - }, - }, - capacity: 25, - expectedOutput: nil, - assertValidity: require.False, - }, - { - name: "No output", - outputs: nil, - capacity: 10, - expectedOutput: map[string][]byte{}, - assertValidity: require.False, - }, - } - for _, tc := range tests { - tc := tc - t.Run(tc.name, func(t *testing.T) { - buffer := newSummaryBuffer(tc.capacity) - var wg sync.WaitGroup - for node, output := range tc.outputs { - node := node - output := output - wg.Add(1) - go func() { - defer wg.Done() - for _, chunk := range output { - buffer.Write(node, chunk) - } - }() - } - wg.Wait() - output, isValid := buffer.Export() - require.Equal(t, tc.expectedOutput, output) - tc.assertValidity(t, isValid) - }) - } -} diff --git a/lib/web/terminal.go b/lib/web/terminal.go index 5bcb2d38c3e22..34f2eef8e1d62 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -108,6 +108,28 @@ type UserAuthClient interface { MaintainSessionPresence(ctx context.Context) (authproto.AuthService_MaintainSessionPresenceClient, error) } +// WSConn is a gorilla/websocket minimal interface used by our web implementation. +// This interface exists to override the default websocket.Conn implementation, +// currently used by noopCloserWS to prevent WS being closed by wrapping stream. +type WSConn interface { + Close() error + + LocalAddr() net.Addr + RemoteAddr() net.Addr + + WriteControl(messageType int, data []byte, deadline time.Time) error + WriteMessage(messageType int, data []byte) error + ReadMessage() (messageType int, p []byte, err error) + SetReadLimit(limit int64) + SetReadDeadline(t time.Time) error + SetWriteDeadline(t time.Time) error + + PongHandler() func(appData string) error + SetPongHandler(h func(appData string) error) + CloseHandler() func(code int, text string) error + SetCloseHandler(h func(code int, text string) error) +} + // NewTerminal creates a web-based terminal based on WebSockets and returns a // new TerminalHandler. func NewTerminal(ctx context.Context, cfg TerminalHandlerConfig) (*TerminalHandler, error) { @@ -771,7 +793,13 @@ func monitorSessionLatency(ctx context.Context, clock clockwork.Clock, stream *W return nil } -// streamTerminal opens a SSH connection to the remote host and streams +// sessionEndEvent is an event sent when a session ends. +type sessionEndEvent struct { + // NodeID is the ID of the server where the session was created. + NodeID string `json:"node_id"` +} + +// streamTerminal opens an SSH connection to the remote host and streams // events back to the web client. func (t *TerminalHandler) streamTerminal(ctx context.Context, tc *client.TeleportClient) { ctx, span := t.tracer.Start(ctx, "terminal/streamTerminal") @@ -913,6 +941,23 @@ func (t *sshBaseHandler) connectToNodeWithMFABase(ctx context.Context, ws WSConn return nc, nil } +// sendError sends an error message to the client using the provided websocket. +func (t *sshBaseHandler) sendError(errMsg string, err error, ws WSConn) { + envelope := &Envelope{ + Version: defaults.WebsocketVersion, + Type: defaults.WebsocketError, + Payload: fmt.Sprintf("%s: %s", errMsg, err.Error()), + } + + envelopeBytes, err := proto.Marshal(envelope) + if err != nil { + t.log.WithError(err).Error("failed to marshal error message") + } + if err := ws.WriteMessage(websocket.BinaryMessage, envelopeBytes); err != nil { + t.log.WithError(err).Error("failed to send error message") + } +} + // streamEvents receives events over the SSH connection and forwards them to // the web client. func (t *TerminalHandler) streamEvents(ctx context.Context, tc *client.TeleportClient) {