Skip to content

Commit

Permalink
fix: cmd multiplexing
Browse files Browse the repository at this point in the history
Signed-off-by: Toma Puljak <[email protected]>
  • Loading branch information
Tpuljak committed Jan 20, 2025
1 parent f0a3415 commit 138a9da
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 50 deletions.
83 changes: 41 additions & 42 deletions pkg/agent/toolbox/process/session/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
package session

import (
"bytes"
"context"
"errors"
"fmt"
"net/http"
Expand Down Expand Up @@ -37,74 +39,71 @@ func SessionExecuteCommand(configDir string) func(c *gin.Context) {
var cmdId *string
var logFile *os.File

if request.Async {
cmdId = util.Pointer(uuid.NewString())
cmdId = util.Pointer(uuid.NewString())

command := &Command{
Id: *cmdId,
Command: request.Command,
}
session.commands[*cmdId] = command
command := &Command{
Id: *cmdId,
Command: request.Command,
}
session.commands[*cmdId] = command

err := os.MkdirAll(filepath.Join(configDir, "sessions", sessionId, *cmdId), 0755)
if err != nil {
c.AbortWithError(http.StatusInternalServerError, err)
return
}
err := os.MkdirAll(filepath.Join(configDir, "sessions", sessionId, *cmdId), 0755)
if err != nil {
c.AbortWithError(http.StatusInternalServerError, err)
return
}

logFile, err = os.Create(filepath.Join(configDir, "sessions", sessionId, *cmdId, "output.log"))
if err != nil {
c.AbortWithError(http.StatusInternalServerError, err)
return
}
logFile, err = os.Create(filepath.Join(configDir, "sessions", sessionId, *cmdId, "output.log"))
if err != nil {
c.AbortWithError(http.StatusInternalServerError, err)
return
}

cmdToExec := fmt.Sprintf("\n%s ; echo \"DAYTONA_CMD_EXIT_CODE: $?\"\n", request.Command)
cmdToExec := fmt.Sprintf("%s > %s 2>&1 ; echo \"DTN_EXIT: $?\" >> %s\n", request.Command, logFile.Name(), logFile.Name())

type execResult struct {
out string
err error
exitCode *int
}

resultChan := make(chan execResult)

go func() {
out := ""
var exitCode *int
defer close(resultChan)

for session.outReader.Scan() {
line := session.outReader.Text()
line = line + "\n"
logChan := make(chan []byte)
errChan := make(chan error)

exitCode, line = extractExitCode(line)
go util.ReadLog(context.Background(), logFile, true, logChan, errChan)

if request.Async {
_, err := logFile.Write([]byte(line))
if err != nil {
resultChan <- execResult{err: err}
return
defer logFile.Close()

for {
select {
case logEntry := <-logChan:
logEntry = bytes.Trim(logEntry, "\x00")
if len(logEntry) == 0 {
continue
}
} else {
exitCode, line := extractExitCode(string(logEntry))
out += line
}

if exitCode != nil {
if request.Async {
if exitCode != nil {
sessions[sessionId].commands[*cmdId].ExitCode = exitCode
resultChan <- execResult{out: out, exitCode: exitCode, err: nil}
return
}
case err := <-errChan:
if err != nil {
resultChan <- execResult{out: out, exitCode: nil, err: err}
return
}
break
}
}

if logFile != nil {
logFile.Close()
}
resultChan <- execResult{out: out, exitCode: exitCode, err: session.outReader.Err()}
}()

_, err := session.stdinWriter.Write([]byte(cmdToExec))
_, err = session.stdinWriter.Write([]byte(cmdToExec))
if err != nil {
c.AbortWithError(http.StatusBadRequest, err)
return
Expand Down Expand Up @@ -134,7 +133,7 @@ func SessionExecuteCommand(configDir string) func(c *gin.Context) {
func extractExitCode(output string) (*int, string) {
var exitCode *int

regex := regexp.MustCompile(`DAYTONA_CMD_EXIT_CODE: (\d+)\n`)
regex := regexp.MustCompile(`DTN_EXIT: (\d+)\n`)
matches := regex.FindStringSubmatch(output)
if len(matches) > 1 {
code, err := strconv.Atoi(matches[1])
Expand All @@ -145,7 +144,7 @@ func extractExitCode(output string) (*int, string) {
}

if exitCode != nil {
output = strings.Replace(output, fmt.Sprintf("DAYTONA_CMD_EXIT_CODE: %d\n", *exitCode), "", 1)
output = strings.Replace(output, fmt.Sprintf("DTN_EXIT: %d\n", *exitCode), "", 1)
}

return exitCode, output
Expand Down
22 changes: 14 additions & 8 deletions pkg/agent/toolbox/process/session/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,14 @@ import (
"github.com/daytonaio/daytona/internal/util"
"github.com/daytonaio/daytona/pkg/api/controllers/log"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)

func GetSessionCommandLogs(configDir string) func(c *gin.Context) {
return func(c *gin.Context) {
sessionId := c.Param("sessionId")
cmdId := c.Param("commandId")

if cmdId == "" {
c.AbortWithError(http.StatusBadRequest, errors.New("commandId is required"))
return
}

_, ok := sessions[sessionId]
if !ok {
c.AbortWithError(http.StatusNotFound, errors.New("session not found"))
Expand All @@ -49,8 +45,17 @@ func GetSessionCommandLogs(configDir string) func(c *gin.Context) {
return
}
defer logFile.Close()

log.ReadLog(c, logFile, util.ReadLog, log.WriteToWs)
log.ReadLog(c, logFile, util.ReadLog, func(conn *websocket.Conn, messages chan []byte, errors chan error) {
for {
msg := <-messages
_, output := extractExitCode(string(msg))
err := conn.WriteMessage(websocket.TextMessage, []byte(output))
if err != nil {
errors <- err
break
}
}
})
return
}

Expand All @@ -64,6 +69,7 @@ func GetSessionCommandLogs(configDir string) func(c *gin.Context) {
return
}

c.String(http.StatusOK, string(content))
_, output := extractExitCode(string(content))
c.String(http.StatusOK, output)
}
}

0 comments on commit 138a9da

Please sign in to comment.