Skip to content

Commit

Permalink
Edible Scripts Backend (#25739)
Browse files Browse the repository at this point in the history
  • Loading branch information
dantecatalfamo authored Jan 30, 2025
1 parent af475c7 commit 3c8033f
Show file tree
Hide file tree
Showing 10 changed files with 259 additions and 0 deletions.
1 change: 1 addition & 0 deletions changes/24602-editable-scripts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
- Added API endpoint for updating script contents
25 changes: 25 additions & 0 deletions server/datastore/mysql/scripts.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,31 @@ func (ds *Datastore) NewScript(ctx context.Context, script *fleet.Script) (*flee
return ds.getScriptDB(ctx, ds.writer(ctx), uint(id)) //nolint:gosec // dismiss G115
}

func (ds *Datastore) UpdateScriptContents(ctx context.Context, scriptID uint, scriptContents string) (*fleet.Script, error) {
const stmt = `
UPDATE script_contents
INNER JOIN
scripts ON scripts.script_content_id = script_contents.id
SET
contents = ?,
md5_checksum = UNHEX(?)
WHERE
scripts.id = ?
`
md5Checksum := md5ChecksumScriptContent(scriptContents)

_, err := ds.writer(ctx).ExecContext(ctx, stmt, scriptContents, md5Checksum, scriptID)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "updating script contents")
}

if _, err := ds.writer(ctx).ExecContext(ctx, "UPDATE scripts SET updated_at = NOW() WHERE id = ?", scriptID); err != nil {
return nil, ctxerr.Wrap(ctx, err, "updating script updated_at time")
}

return ds.Script(ctx, scriptID)
}

func insertScript(ctx context.Context, tx sqlx.ExtContext, script *fleet.Script, scriptContentsID uint) (sql.Result, error) {
const insertStmt = `
INSERT INTO
Expand Down
40 changes: 40 additions & 0 deletions server/datastore/mysql/scripts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ func TestScripts(t *testing.T) {
{"TestGetAnyScriptContents", testGetAnyScriptContents},
{"TestDeleteScriptsAssignedToPolicy", testDeleteScriptsAssignedToPolicy},
{"TestDeletePendingHostScriptExecutionsForPolicy", testDeletePendingHostScriptExecutionsForPolicy},
{"UpdateScriptContents", testUpdateScriptContents},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
Expand Down Expand Up @@ -1586,3 +1587,42 @@ func testDeletePendingHostScriptExecutionsForPolicy(t *testing.T, ds *Datastore)
)
require.Equal(t, 1, count)
}

func testUpdateScriptContents(t *testing.T, ds *Datastore) {
ctx := context.Background()

originalScript, err := ds.NewScript(ctx, &fleet.Script{
Name: "script1",
ScriptContents: "hello world",
})
require.NoError(t, err)

originalContents, err := ds.GetScriptContents(ctx, originalScript.ScriptContentID)
require.NoError(t, err)
require.Equal(t, "hello world", string(originalContents))

ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error {
_, err := q.ExecContext(ctx, "UPDATE scripts SET updated_at = ? WHERE id = ?", time.Now().Add(-2*time.Minute), originalScript.ID)
if err != nil {
return err
}
return nil
})

// Make sure updated_at was changed correctly, but the script is the same
oldScript, err := ds.Script(ctx, originalScript.ID)
require.Equal(t, originalScript.ScriptContentID, oldScript.ScriptContentID)
require.NoError(t, err)
require.NotEqual(t, originalScript.UpdatedAt, oldScript.UpdatedAt)

// Modify the script
updatedScript, err := ds.UpdateScriptContents(ctx, originalScript.ID, "updated script")
require.NoError(t, err)
require.Equal(t, originalScript.ID, updatedScript.ID)
require.Equal(t, originalScript.ScriptContentID, updatedScript.ScriptContentID)

updatedContents, err := ds.GetScriptContents(ctx, originalScript.ScriptContentID)
require.NoError(t, err)
require.Equal(t, "updated script", string(updatedContents))
require.NotEqual(t, oldScript.UpdatedAt, updatedScript.UpdatedAt)
}
22 changes: 22 additions & 0 deletions server/fleet/activities.go
Original file line number Diff line number Diff line change
Expand Up @@ -1332,6 +1332,28 @@ func (a ActivityTypeAddedScript) Documentation() (activity, details, detailsExam
}`
}

type ActivityTypeUpdatedScript struct {
ScriptName string `json:"script_name"`
TeamID *uint `json:"team_id"`
TeamName *string `json:"team_name"`
}

func (a ActivityTypeUpdatedScript) ActivityName() string {
return "updated_script"
}

func (a ActivityTypeUpdatedScript) Documentation() (activity, details, detailsExample string) {
return `Generated when a script is updated.`,
`This activity contains the following fields:
- "script_name": Name of the script.
- "team_id": The ID of the team that the script applies to, ` + "`null`" + ` if it applies to devices that are not in a team.
- "team_name": The name of the team that the script applies to, ` + "`null`" + ` if it applies to devices that are not in a team.`, `{
"script_name": "set-timezones.sh",
"team_id": 123,
"team_name": "Workstations"
}`
}

type ActivityTypeDeletedScript struct {
ScriptName string `json:"script_name"`
TeamID *uint `json:"team_id"`
Expand Down
3 changes: 3 additions & 0 deletions server/fleet/datastore.go
Original file line number Diff line number Diff line change
Expand Up @@ -1644,6 +1644,9 @@ type Datastore interface {
// NewScript creates a new saved script.
NewScript(ctx context.Context, script *Script) (*Script, error)

// UpdateScriptContents replaces the script contents of a script
UpdateScriptContents(ctx context.Context, scriptID uint, scriptContents string) (*Script, error)

// Script returns the saved script corresponding to id.
Script(ctx context.Context, id uint) (*Script, error)

Expand Down
3 changes: 3 additions & 0 deletions server/fleet/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -1123,6 +1123,9 @@ type Service interface {
// io.Reader r.
NewScript(ctx context.Context, teamID *uint, name string, r io.Reader) (*Script, error)

// UpdateScript updates a saved script with the contents of io.Reader r
UpdateScript(ctx context.Context, scriptID uint, r io.Reader) (*Script, error)

// DeleteScript deletes an existing (saved) script.
DeleteScript(ctx context.Context, scriptID uint) error

Expand Down
12 changes: 12 additions & 0 deletions server/mock/datastore_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -1049,6 +1049,8 @@ type ListPendingHostScriptExecutionsFunc func(ctx context.Context, hostID uint,

type NewScriptFunc func(ctx context.Context, script *fleet.Script) (*fleet.Script, error)

type UpdateScriptContentsFunc func(ctx context.Context, scriptID uint, scriptContents string) (*fleet.Script, error)

type ScriptFunc func(ctx context.Context, id uint) (*fleet.Script, error)

type GetScriptContentsFunc func(ctx context.Context, id uint) ([]byte, error)
Expand Down Expand Up @@ -2752,6 +2754,9 @@ type DataStore struct {
NewScriptFunc NewScriptFunc
NewScriptFuncInvoked bool

UpdateScriptContentsFunc UpdateScriptContentsFunc
UpdateScriptContentsFuncInvoked bool

ScriptFunc ScriptFunc
ScriptFuncInvoked bool

Expand Down Expand Up @@ -6593,6 +6598,13 @@ func (s *DataStore) NewScript(ctx context.Context, script *fleet.Script) (*fleet
return s.NewScriptFunc(ctx, script)
}

func (s *DataStore) UpdateScriptContents(ctx context.Context, scriptID uint, scriptContents string) (*fleet.Script, error) {
s.mu.Lock()
s.UpdateScriptContentsFuncInvoked = true
s.mu.Unlock()
return s.UpdateScriptContentsFunc(ctx, scriptID, scriptContents)
}

func (s *DataStore) Script(ctx context.Context, id uint) (*fleet.Script, error) {
s.mu.Lock()
s.ScriptFuncInvoked = true
Expand Down
1 change: 1 addition & 0 deletions server/service/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,7 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC
ue.POST("/api/_version_/fleet/scripts", createScriptEndpoint, createScriptRequest{})
ue.GET("/api/_version_/fleet/scripts", listScriptsEndpoint, listScriptsRequest{})
ue.GET("/api/_version_/fleet/scripts/{script_id:[0-9]+}", getScriptEndpoint, getScriptRequest{})
ue.PATCH("/api/_version_/fleet/scripts/{script_id:[0-9]+}", updateScriptEndpoint, updateScriptRequest{})
ue.DELETE("/api/_version_/fleet/scripts/{script_id:[0-9]+}", deleteScriptEndpoint, deleteScriptRequest{})
ue.POST("/api/_version_/fleet/scripts/batch", batchSetScriptsEndpoint, batchSetScriptsRequest{})

Expand Down
27 changes: 27 additions & 0 deletions server/service/integration_enterprise_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7200,6 +7200,33 @@ func (s *integrationEnterpriseTestSuite) TestSavedScripts() {
require.NotEqual(t, tmScriptID, newScriptResp.ScriptID)
s.lastActivityMatches("added_script", fmt.Sprintf(`{"script_name": %q, "team_name": %q, "team_id": %d}`, "script2.sh", tm.Name, tm.ID), 0)

// Update a script
updateScriptRep := updateScriptResponse{}
body, headers = generateNewScriptMultipartRequest(t,
"script1.sh", []byte(`echo "updated script"`), s.token, nil)
res = s.DoRawWithHeaders("PATCH", fmt.Sprintf("/api/latest/fleet/scripts/%d", tmScriptID), body.Bytes(), http.StatusOK, headers)
err = json.NewDecoder(res.Body).Decode(&updateScriptRep)
require.NoError(t, err)
require.NotZero(t, newScriptResp.ScriptID)
require.Equal(t, tmScriptID, updateScriptRep.ScriptID)
s.lastActivityMatches("updated_script", fmt.Sprintf(`{"script_name": %q, "team_name": %q, "team_id": %d}`, "script1.sh", tm.Name, tm.ID), 0)

// Download the updated script
res = s.Do("GET", fmt.Sprintf("/api/latest/fleet/scripts/%d", tmScriptID), nil, http.StatusOK, "alt", "media")
b, err = io.ReadAll(res.Body)
require.NoError(t, err)
require.Equal(t, `echo "updated script"`, string(b))
require.Equal(t, int64(len(`echo "updated script"`)), res.ContentLength)
require.Equal(t, fmt.Sprintf("attachment;filename=\"%s %s\"", time.Now().Format(time.DateOnly), "script1.sh"), res.Header.Get("Content-Disposition"))

// Try updating a non-existant script
updateScriptRep = updateScriptResponse{}
body, headers = generateNewScriptMultipartRequest(t,
"script1.sh", []byte(`echo "updated script"`), s.token, nil)
res = s.DoRawWithHeaders("PATCH", fmt.Sprintf("/api/latest/fleet/scripts/%d", 999999999999), body.Bytes(), http.StatusNotFound, headers)
err = json.NewDecoder(res.Body).Decode(&updateScriptRep)
require.NoError(t, err)

// delete the no-team script
s.Do("DELETE", fmt.Sprintf("/api/latest/fleet/scripts/%d", noTeamScriptID), nil, http.StatusNoContent)
s.lastActivityMatches("deleted_script", fmt.Sprintf(`{"script_name": %q, "team_name": null, "team_id": null}`, "script1.sh"), 0)
Expand Down
125 changes: 125 additions & 0 deletions server/service/scripts.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/fleetdm/fleet/v4/server/contexts/logging"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/gorilla/mux"
)

////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -746,6 +747,130 @@ func (svc *Service) GetScript(ctx context.Context, scriptID uint, withContent bo
return script, content, nil
}

////////////////////////////////////////////////////////////////////////////////
// Update Script Contents
////////////////////////////////////////////////////////////////////////////////

type updateScriptRequest struct {
Script *multipart.FileHeader
ScriptID uint
}

func (updateScriptRequest) DecodeRequest(ctx context.Context, r *http.Request) (interface{}, error) {
var decoded updateScriptRequest

err := r.ParseMultipartForm(512 * units.MiB) // same in-memory size as for other multipart requests we have
if err != nil {
return nil, &fleet.BadRequestError{
Message: "failed to parse multipart form",
InternalErr: err,
}
}

vars := mux.Vars(r)
scriptIDStr, ok := vars["script_id"]
if !ok {
return nil, &fleet.BadRequestError{Message: "missing script id"}
}
scriptID, err := strconv.ParseUint(scriptIDStr, 10, 64)
if err != nil {
return nil, &fleet.BadRequestError{Message: "invalid script id"}
}
// Check if scriptID exceeds the maximum value for uint, code linter
if scriptID > uint64(^uint(0)) {
return nil, &fleet.BadRequestError{Message: "script id out of bounds"}
}

decoded.ScriptID = uint(scriptID)

fhs, ok := r.MultipartForm.File["script"]
if !ok || len(fhs) < 1 {
return nil, &fleet.BadRequestError{Message: "no file headers for script"}
}
decoded.Script = fhs[0]

return &decoded, nil
}

type updateScriptResponse struct {
Err error `json:"error,omitempty"`
ScriptID uint `json:"script_id,omitempty"`
}

func (r updateScriptResponse) error() error { return r.Err }

func updateScriptEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (errorer, error) {
req := request.(*updateScriptRequest)

scriptFile, err := req.Script.Open()
if err != nil {
return &updateScriptResponse{Err: err}, nil
}
defer scriptFile.Close()

script, err := svc.UpdateScript(ctx, req.ScriptID, scriptFile)
if err != nil {
return updateScriptResponse{Err: err}, nil
}
return updateScriptResponse{ScriptID: script.ID}, nil
}

func (svc *Service) UpdateScript(ctx context.Context, scriptID uint, r io.Reader) (*fleet.Script, error) {
script, err := svc.ds.Script(ctx, scriptID)
if err != nil {
svc.authz.SkipAuthorization(ctx)
return nil, ctxerr.Wrap(ctx, err, "finding original script to update")
}

if err := svc.authz.Authorize(ctx, &fleet.Script{TeamID: script.TeamID}, fleet.ActionWrite); err != nil {
return nil, err
}

b, err := io.ReadAll(r)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "read script contents")
}

scriptContents := file.Dos2UnixNewlines(string(b))

if err := svc.ds.ValidateEmbeddedSecrets(ctx, []string{scriptContents}); err != nil {
return nil, fleet.NewInvalidArgumentError("script", err.Error())
}

if err := fleet.ValidateHostScriptContents(scriptContents, true); err != nil {
return nil, fleet.NewInvalidArgumentError("script", err.Error())
}

// Update the script
savedScript, err := svc.ds.UpdateScriptContents(ctx, scriptID, scriptContents)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "updating script contents")
}

var teamName *string
if script.TeamID != nil && *script.TeamID != 0 {
tm, err := svc.EnterpriseOverrides.TeamByIDOrName(ctx, script.TeamID, nil)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "get team name for create script activity")
}
teamName = &tm.Name
}

if err := svc.NewActivity(
ctx,
authz.UserFromContext(ctx),
fleet.ActivityTypeUpdatedScript{
TeamID: script.TeamID,
TeamName: teamName,
ScriptName: script.Name,
},
); err != nil {
return nil, ctxerr.Wrap(ctx, err, "new activity for update script")
}

return savedScript, nil
}

////////////////////////////////////////////////////////////////////////////////
// Get Host Script Details
////////////////////////////////////////////////////////////////////////////////
Expand Down

0 comments on commit 3c8033f

Please sign in to comment.