Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 144 additions & 0 deletions pkg/picod/picod_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"testing"
"time"

"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -454,3 +455,146 @@ func TestPicoD_SetWorkspace(t *testing.T) {
// but we still resolve /var aliasing for the parent directory part.
assert.Equal(t, resolve(absLinkPath), resolve(server.workspaceDir))
}

// TestParseFileMode tests filesystem utility for file mode parsing
func TestParseFileMode(t *testing.T) {
tests := []struct {
name string
modeStr string
expected os.FileMode
desc string
}{
{
name: "Valid octal mode",
modeStr: "0644",
expected: 0644,
desc: "Should parse valid octal mode",
},
{
name: "Empty mode defaults to 0644",
modeStr: "",
expected: 0644,
desc: "Should default to 0644",
},
{
name: "Invalid mode defaults to 0644",
modeStr: "invalid",
expected: 0644,
desc: "Should default on invalid input",
},
{
name: "Mode exceeding max defaults to 0644",
modeStr: "10000",
expected: 0644,
desc: "Should default when exceeding 0777",
},
{
name: "Valid executable mode",
modeStr: "0755",
expected: 0755,
desc: "Should parse executable mode",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := parseFileMode(tt.modeStr)
assert.Equal(t, tt.expected, result, tt.desc)
})
}
}

// TestLoadBootstrapKey tests auth helper error paths
func TestLoadBootstrapKey(t *testing.T) {
tests := []struct {
name string
keyData []byte
expectErr bool
desc string
}{
{
name: "Empty key data",
keyData: []byte{},
expectErr: true,
desc: "Should reject empty key",
},
{
name: "Invalid PEM format",
keyData: []byte("not a pem"),
expectErr: true,
desc: "Should reject invalid PEM",
},
{
name: "Valid RSA public key",
keyData: func() []byte { _, pub := generateRSAKeys(t); return []byte(pub) }(),
expectErr: false,
desc: "Should accept valid RSA key",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
am := NewAuthManager()
err := am.LoadBootstrapKey(tt.keyData)
if tt.expectErr {
assert.Error(t, err, tt.desc)
} else {
assert.NoError(t, err, tt.desc)
}
})
}
}

// TestExecuteHandler_ErrorPaths tests execution pipeline error paths
func TestExecuteHandler_ErrorPaths(t *testing.T) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there no test case for the normal command?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

tmpDir, err := os.MkdirTemp("", "picod_execute_test")
require.NoError(t, err)
defer os.RemoveAll(tmpDir)

_, bootstrapPubStr := generateRSAKeys(t)
server := NewServer(Config{
BootstrapKey: []byte(bootstrapPubStr),
Workspace: tmpDir,
})

tests := []struct {
name string
request string
statusCode int
desc string
}{
{
name: "Empty command",
request: `{"command": []}`,
statusCode: http.StatusBadRequest,
desc: "Should reject empty command",
},
{
name: "Invalid JSON",
request: `{"command": invalid}`,
statusCode: http.StatusBadRequest,
desc: "Should reject invalid JSON",
},
{
name: "Invalid timeout format",
request: `{"command": ["echo", "test"], "timeout": "invalid"}`,
statusCode: http.StatusBadRequest,
desc: "Should reject invalid timeout",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
req := httptest.NewRequest("POST", "/api/execute", bytes.NewBufferString(tt.request))
req.Header.Set("Content-Type", "application/json")

ctx, _ := gin.CreateTestContext(w)
ctx.Request = req

server.ExecuteHandler(ctx)

assert.Equal(t, tt.statusCode, w.Code, tt.desc)
})
}
}
Comment on lines +549 to +938
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation of TestExecuteHandler_ErrorPaths tests the handler in isolation by calling server.ExecuteHandler(ctx) directly. This bypasses the authentication middleware, which is a critical part of the request lifecycle for this endpoint.

A real request to /api/execute on an uninitialized server would be rejected by the AuthMiddleware with a 403 Forbidden status, and would never reach the handler to produce the 400 Bad Request that this test expects. This means the test could pass even if the endpoint is not correctly configured or is inaccessible.

To make this test more robust and representative of real-world usage, it should be structured as an integration test that goes through the full HTTP stack. This involves:

  1. Setting up a test server using httptest.NewServer.
  2. Initializing the server by making a request to the /init endpoint.
  3. Sending authenticated requests with invalid bodies to the /api/execute endpoint.

This approach ensures that the middleware, routing, and handler logic are all tested together. I've provided a suggestion to refactor the test accordingly.

func TestExecuteHandler_ErrorPaths(t *testing.T) {
	tmpDir, err := os.MkdirTemp("", "picod_execute_test")
	require.NoError(t, err)
	defer os.RemoveAll(tmpDir)

	bootstrapPriv, bootstrapPubStr := generateRSAKeys(t)
	sessionPriv, sessionPubStr := generateRSAKeys(t)

	server := NewServer(Config{
		BootstrapKey: []byte(bootstrapPubStr),
		Workspace:    tmpDir,
	})
	ts := httptest.NewServer(server.engine)
	defer ts.Close()
	client := ts.Client()

	// Initialize server to allow access to authenticated endpoints
	sessionPubB64 := base64.RawStdEncoding.EncodeToString([]byte(sessionPubStr))
	initClaims := jwt.MapClaims{
		"session_public_key": sessionPubB64,
		"iat":                time.Now().Unix(),
		"exp":                time.Now().Add(time.Hour).Unix(),
	}
	initToken := createToken(t, bootstrapPriv, initClaims)
	initReq, err := http.NewRequest("POST", ts.URL+"/init", nil)
	require.NoError(t, err)
	initReq.Header.Set("Authorization", "Bearer "+initToken)
	initResp, err := client.Do(initReq)
	require.NoError(t, err)
	require.Equal(t, http.StatusOK, initResp.StatusCode)
	initResp.Body.Close()

	tests := []struct {
		name       string
		request    string
		statusCode int
		desc       string
	}{
		{
			name:       "Empty command",
			request:    `{"command": []}`,
			statusCode: http.StatusBadRequest,
			desc:       "Should reject empty command",
		},
		{
			name:       "Invalid JSON",
			request:    `{"command": invalid}`,
			statusCode: http.StatusBadRequest,
			desc:       "Should reject invalid JSON",
		},
		{
			name:       "Invalid timeout format",
			request:    `{"command": ["echo", "test"], "timeout": "invalid"}`,
			statusCode: http.StatusBadRequest,
			desc:       "Should reject invalid timeout",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			reqBody := []byte(tt.request)
			hash := sha256.Sum256(reqBody)
			claims := jwt.MapClaims{
				"body_sha256": fmt.Sprintf("%x", hash),
				"iat":         time.Now().Unix(),
				"exp":         time.Now().Add(time.Hour).Unix(),
			}
			token := createToken(t, sessionPriv, claims)

			req, err := http.NewRequest("POST", ts.URL+"/api/execute", bytes.NewBuffer(reqBody))
			require.NoError(t, err)
			req.Header.Set("Content-Type", "application/json")
			req.Header.Set("Authorization", "Bearer "+token)

			resp, err := client.Do(req)
			require.NoError(t, err)
			defer resp.Body.Close()

			assert.Equal(t, tt.statusCode, resp.StatusCode, tt.desc)
		})
	}
}

Loading