From 3c8033fa8ed7af55fbc3210250d483847b33ea01 Mon Sep 17 00:00:00 2001 From: Dante Catalfamo <43040593+dantecatalfamo@users.noreply.github.com> Date: Thu, 30 Jan 2025 13:01:51 -0500 Subject: [PATCH] Edible Scripts Backend (#25739) #24602 --- changes/24602-editable-scripts | 1 + server/datastore/mysql/scripts.go | 25 ++++ server/datastore/mysql/scripts_test.go | 40 ++++++ server/fleet/activities.go | 22 +++ server/fleet/datastore.go | 3 + server/fleet/service.go | 3 + server/mock/datastore_mock.go | 12 ++ server/service/handler.go | 1 + server/service/integration_enterprise_test.go | 27 ++++ server/service/scripts.go | 125 ++++++++++++++++++ 10 files changed, 259 insertions(+) create mode 100644 changes/24602-editable-scripts diff --git a/changes/24602-editable-scripts b/changes/24602-editable-scripts new file mode 100644 index 000000000000..c6380edcbad0 --- /dev/null +++ b/changes/24602-editable-scripts @@ -0,0 +1 @@ +- Added API endpoint for updating script contents diff --git a/server/datastore/mysql/scripts.go b/server/datastore/mysql/scripts.go index c0c5280823b2..544d74a097ea 100644 --- a/server/datastore/mysql/scripts.go +++ b/server/datastore/mysql/scripts.go @@ -339,6 +339,31 @@ func (ds *Datastore) NewScript(ctx context.Context, script *fleet.Script) (*flee return ds.getScriptDB(ctx, ds.writer(ctx), uint(id)) //nolint:gosec // dismiss G115 } +func (ds *Datastore) UpdateScriptContents(ctx context.Context, scriptID uint, scriptContents string) (*fleet.Script, error) { + const stmt = ` +UPDATE script_contents +INNER JOIN + scripts ON scripts.script_content_id = script_contents.id +SET + contents = ?, + md5_checksum = UNHEX(?) +WHERE + scripts.id = ? +` + md5Checksum := md5ChecksumScriptContent(scriptContents) + + _, err := ds.writer(ctx).ExecContext(ctx, stmt, scriptContents, md5Checksum, scriptID) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "updating script contents") + } + + if _, err := ds.writer(ctx).ExecContext(ctx, "UPDATE scripts SET updated_at = NOW() WHERE id = ?", scriptID); err != nil { + return nil, ctxerr.Wrap(ctx, err, "updating script updated_at time") + } + + return ds.Script(ctx, scriptID) +} + func insertScript(ctx context.Context, tx sqlx.ExtContext, script *fleet.Script, scriptContentsID uint) (sql.Result, error) { const insertStmt = ` INSERT INTO diff --git a/server/datastore/mysql/scripts_test.go b/server/datastore/mysql/scripts_test.go index 197732d04125..879f21e43532 100644 --- a/server/datastore/mysql/scripts_test.go +++ b/server/datastore/mysql/scripts_test.go @@ -40,6 +40,7 @@ func TestScripts(t *testing.T) { {"TestGetAnyScriptContents", testGetAnyScriptContents}, {"TestDeleteScriptsAssignedToPolicy", testDeleteScriptsAssignedToPolicy}, {"TestDeletePendingHostScriptExecutionsForPolicy", testDeletePendingHostScriptExecutionsForPolicy}, + {"UpdateScriptContents", testUpdateScriptContents}, } for _, c := range cases { t.Run(c.name, func(t *testing.T) { @@ -1586,3 +1587,42 @@ func testDeletePendingHostScriptExecutionsForPolicy(t *testing.T, ds *Datastore) ) require.Equal(t, 1, count) } + +func testUpdateScriptContents(t *testing.T, ds *Datastore) { + ctx := context.Background() + + originalScript, err := ds.NewScript(ctx, &fleet.Script{ + Name: "script1", + ScriptContents: "hello world", + }) + require.NoError(t, err) + + originalContents, err := ds.GetScriptContents(ctx, originalScript.ScriptContentID) + require.NoError(t, err) + require.Equal(t, "hello world", string(originalContents)) + + ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error { + _, err := q.ExecContext(ctx, "UPDATE scripts SET updated_at = ? WHERE id = ?", time.Now().Add(-2*time.Minute), originalScript.ID) + if err != nil { + return err + } + return nil + }) + + // Make sure updated_at was changed correctly, but the script is the same + oldScript, err := ds.Script(ctx, originalScript.ID) + require.Equal(t, originalScript.ScriptContentID, oldScript.ScriptContentID) + require.NoError(t, err) + require.NotEqual(t, originalScript.UpdatedAt, oldScript.UpdatedAt) + + // Modify the script + updatedScript, err := ds.UpdateScriptContents(ctx, originalScript.ID, "updated script") + require.NoError(t, err) + require.Equal(t, originalScript.ID, updatedScript.ID) + require.Equal(t, originalScript.ScriptContentID, updatedScript.ScriptContentID) + + updatedContents, err := ds.GetScriptContents(ctx, originalScript.ScriptContentID) + require.NoError(t, err) + require.Equal(t, "updated script", string(updatedContents)) + require.NotEqual(t, oldScript.UpdatedAt, updatedScript.UpdatedAt) +} diff --git a/server/fleet/activities.go b/server/fleet/activities.go index 4d184426f7df..aaa41ef5d943 100644 --- a/server/fleet/activities.go +++ b/server/fleet/activities.go @@ -1332,6 +1332,28 @@ func (a ActivityTypeAddedScript) Documentation() (activity, details, detailsExam }` } +type ActivityTypeUpdatedScript struct { + ScriptName string `json:"script_name"` + TeamID *uint `json:"team_id"` + TeamName *string `json:"team_name"` +} + +func (a ActivityTypeUpdatedScript) ActivityName() string { + return "updated_script" +} + +func (a ActivityTypeUpdatedScript) Documentation() (activity, details, detailsExample string) { + return `Generated when a script is updated.`, + `This activity contains the following fields: +- "script_name": Name of the script. +- "team_id": The ID of the team that the script applies to, ` + "`null`" + ` if it applies to devices that are not in a team. +- "team_name": The name of the team that the script applies to, ` + "`null`" + ` if it applies to devices that are not in a team.`, `{ + "script_name": "set-timezones.sh", + "team_id": 123, + "team_name": "Workstations" +}` +} + type ActivityTypeDeletedScript struct { ScriptName string `json:"script_name"` TeamID *uint `json:"team_id"` diff --git a/server/fleet/datastore.go b/server/fleet/datastore.go index e05965137fcd..f2ded0aca910 100644 --- a/server/fleet/datastore.go +++ b/server/fleet/datastore.go @@ -1644,6 +1644,9 @@ type Datastore interface { // NewScript creates a new saved script. NewScript(ctx context.Context, script *Script) (*Script, error) + // UpdateScriptContents replaces the script contents of a script + UpdateScriptContents(ctx context.Context, scriptID uint, scriptContents string) (*Script, error) + // Script returns the saved script corresponding to id. Script(ctx context.Context, id uint) (*Script, error) diff --git a/server/fleet/service.go b/server/fleet/service.go index bf7b833e1026..26a91765f6dd 100644 --- a/server/fleet/service.go +++ b/server/fleet/service.go @@ -1123,6 +1123,9 @@ type Service interface { // io.Reader r. NewScript(ctx context.Context, teamID *uint, name string, r io.Reader) (*Script, error) + // UpdateScript updates a saved script with the contents of io.Reader r + UpdateScript(ctx context.Context, scriptID uint, r io.Reader) (*Script, error) + // DeleteScript deletes an existing (saved) script. DeleteScript(ctx context.Context, scriptID uint) error diff --git a/server/mock/datastore_mock.go b/server/mock/datastore_mock.go index e5306955b41a..396c549243ba 100644 --- a/server/mock/datastore_mock.go +++ b/server/mock/datastore_mock.go @@ -1049,6 +1049,8 @@ type ListPendingHostScriptExecutionsFunc func(ctx context.Context, hostID uint, type NewScriptFunc func(ctx context.Context, script *fleet.Script) (*fleet.Script, error) +type UpdateScriptContentsFunc func(ctx context.Context, scriptID uint, scriptContents string) (*fleet.Script, error) + type ScriptFunc func(ctx context.Context, id uint) (*fleet.Script, error) type GetScriptContentsFunc func(ctx context.Context, id uint) ([]byte, error) @@ -2752,6 +2754,9 @@ type DataStore struct { NewScriptFunc NewScriptFunc NewScriptFuncInvoked bool + UpdateScriptContentsFunc UpdateScriptContentsFunc + UpdateScriptContentsFuncInvoked bool + ScriptFunc ScriptFunc ScriptFuncInvoked bool @@ -6593,6 +6598,13 @@ func (s *DataStore) NewScript(ctx context.Context, script *fleet.Script) (*fleet return s.NewScriptFunc(ctx, script) } +func (s *DataStore) UpdateScriptContents(ctx context.Context, scriptID uint, scriptContents string) (*fleet.Script, error) { + s.mu.Lock() + s.UpdateScriptContentsFuncInvoked = true + s.mu.Unlock() + return s.UpdateScriptContentsFunc(ctx, scriptID, scriptContents) +} + func (s *DataStore) Script(ctx context.Context, id uint) (*fleet.Script, error) { s.mu.Lock() s.ScriptFuncInvoked = true diff --git a/server/service/handler.go b/server/service/handler.go index 005884a21bea..f0aee564d621 100644 --- a/server/service/handler.go +++ b/server/service/handler.go @@ -513,6 +513,7 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC ue.POST("/api/_version_/fleet/scripts", createScriptEndpoint, createScriptRequest{}) ue.GET("/api/_version_/fleet/scripts", listScriptsEndpoint, listScriptsRequest{}) ue.GET("/api/_version_/fleet/scripts/{script_id:[0-9]+}", getScriptEndpoint, getScriptRequest{}) + ue.PATCH("/api/_version_/fleet/scripts/{script_id:[0-9]+}", updateScriptEndpoint, updateScriptRequest{}) ue.DELETE("/api/_version_/fleet/scripts/{script_id:[0-9]+}", deleteScriptEndpoint, deleteScriptRequest{}) ue.POST("/api/_version_/fleet/scripts/batch", batchSetScriptsEndpoint, batchSetScriptsRequest{}) diff --git a/server/service/integration_enterprise_test.go b/server/service/integration_enterprise_test.go index b44cd0cd1e1f..822cba9e51e4 100644 --- a/server/service/integration_enterprise_test.go +++ b/server/service/integration_enterprise_test.go @@ -7200,6 +7200,33 @@ func (s *integrationEnterpriseTestSuite) TestSavedScripts() { require.NotEqual(t, tmScriptID, newScriptResp.ScriptID) s.lastActivityMatches("added_script", fmt.Sprintf(`{"script_name": %q, "team_name": %q, "team_id": %d}`, "script2.sh", tm.Name, tm.ID), 0) + // Update a script + updateScriptRep := updateScriptResponse{} + body, headers = generateNewScriptMultipartRequest(t, + "script1.sh", []byte(`echo "updated script"`), s.token, nil) + res = s.DoRawWithHeaders("PATCH", fmt.Sprintf("/api/latest/fleet/scripts/%d", tmScriptID), body.Bytes(), http.StatusOK, headers) + err = json.NewDecoder(res.Body).Decode(&updateScriptRep) + require.NoError(t, err) + require.NotZero(t, newScriptResp.ScriptID) + require.Equal(t, tmScriptID, updateScriptRep.ScriptID) + s.lastActivityMatches("updated_script", fmt.Sprintf(`{"script_name": %q, "team_name": %q, "team_id": %d}`, "script1.sh", tm.Name, tm.ID), 0) + + // Download the updated script + res = s.Do("GET", fmt.Sprintf("/api/latest/fleet/scripts/%d", tmScriptID), nil, http.StatusOK, "alt", "media") + b, err = io.ReadAll(res.Body) + require.NoError(t, err) + require.Equal(t, `echo "updated script"`, string(b)) + require.Equal(t, int64(len(`echo "updated script"`)), res.ContentLength) + require.Equal(t, fmt.Sprintf("attachment;filename=\"%s %s\"", time.Now().Format(time.DateOnly), "script1.sh"), res.Header.Get("Content-Disposition")) + + // Try updating a non-existant script + updateScriptRep = updateScriptResponse{} + body, headers = generateNewScriptMultipartRequest(t, + "script1.sh", []byte(`echo "updated script"`), s.token, nil) + res = s.DoRawWithHeaders("PATCH", fmt.Sprintf("/api/latest/fleet/scripts/%d", 999999999999), body.Bytes(), http.StatusNotFound, headers) + err = json.NewDecoder(res.Body).Decode(&updateScriptRep) + require.NoError(t, err) + // delete the no-team script s.Do("DELETE", fmt.Sprintf("/api/latest/fleet/scripts/%d", noTeamScriptID), nil, http.StatusNoContent) s.lastActivityMatches("deleted_script", fmt.Sprintf(`{"script_name": %q, "team_name": null, "team_id": null}`, "script1.sh"), 0) diff --git a/server/service/scripts.go b/server/service/scripts.go index d2fd67703cd6..299390cf3e49 100644 --- a/server/service/scripts.go +++ b/server/service/scripts.go @@ -20,6 +20,7 @@ import ( "github.com/fleetdm/fleet/v4/server/contexts/logging" "github.com/fleetdm/fleet/v4/server/fleet" "github.com/fleetdm/fleet/v4/server/ptr" + "github.com/gorilla/mux" ) //////////////////////////////////////////////////////////////////////////////// @@ -746,6 +747,130 @@ func (svc *Service) GetScript(ctx context.Context, scriptID uint, withContent bo return script, content, nil } +//////////////////////////////////////////////////////////////////////////////// +// Update Script Contents +//////////////////////////////////////////////////////////////////////////////// + +type updateScriptRequest struct { + Script *multipart.FileHeader + ScriptID uint +} + +func (updateScriptRequest) DecodeRequest(ctx context.Context, r *http.Request) (interface{}, error) { + var decoded updateScriptRequest + + err := r.ParseMultipartForm(512 * units.MiB) // same in-memory size as for other multipart requests we have + if err != nil { + return nil, &fleet.BadRequestError{ + Message: "failed to parse multipart form", + InternalErr: err, + } + } + + vars := mux.Vars(r) + scriptIDStr, ok := vars["script_id"] + if !ok { + return nil, &fleet.BadRequestError{Message: "missing script id"} + } + scriptID, err := strconv.ParseUint(scriptIDStr, 10, 64) + if err != nil { + return nil, &fleet.BadRequestError{Message: "invalid script id"} + } + // Check if scriptID exceeds the maximum value for uint, code linter + if scriptID > uint64(^uint(0)) { + return nil, &fleet.BadRequestError{Message: "script id out of bounds"} + } + + decoded.ScriptID = uint(scriptID) + + fhs, ok := r.MultipartForm.File["script"] + if !ok || len(fhs) < 1 { + return nil, &fleet.BadRequestError{Message: "no file headers for script"} + } + decoded.Script = fhs[0] + + return &decoded, nil +} + +type updateScriptResponse struct { + Err error `json:"error,omitempty"` + ScriptID uint `json:"script_id,omitempty"` +} + +func (r updateScriptResponse) error() error { return r.Err } + +func updateScriptEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (errorer, error) { + req := request.(*updateScriptRequest) + + scriptFile, err := req.Script.Open() + if err != nil { + return &updateScriptResponse{Err: err}, nil + } + defer scriptFile.Close() + + script, err := svc.UpdateScript(ctx, req.ScriptID, scriptFile) + if err != nil { + return updateScriptResponse{Err: err}, nil + } + return updateScriptResponse{ScriptID: script.ID}, nil +} + +func (svc *Service) UpdateScript(ctx context.Context, scriptID uint, r io.Reader) (*fleet.Script, error) { + script, err := svc.ds.Script(ctx, scriptID) + if err != nil { + svc.authz.SkipAuthorization(ctx) + return nil, ctxerr.Wrap(ctx, err, "finding original script to update") + } + + if err := svc.authz.Authorize(ctx, &fleet.Script{TeamID: script.TeamID}, fleet.ActionWrite); err != nil { + return nil, err + } + + b, err := io.ReadAll(r) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "read script contents") + } + + scriptContents := file.Dos2UnixNewlines(string(b)) + + if err := svc.ds.ValidateEmbeddedSecrets(ctx, []string{scriptContents}); err != nil { + return nil, fleet.NewInvalidArgumentError("script", err.Error()) + } + + if err := fleet.ValidateHostScriptContents(scriptContents, true); err != nil { + return nil, fleet.NewInvalidArgumentError("script", err.Error()) + } + + // Update the script + savedScript, err := svc.ds.UpdateScriptContents(ctx, scriptID, scriptContents) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "updating script contents") + } + + var teamName *string + if script.TeamID != nil && *script.TeamID != 0 { + tm, err := svc.EnterpriseOverrides.TeamByIDOrName(ctx, script.TeamID, nil) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "get team name for create script activity") + } + teamName = &tm.Name + } + + if err := svc.NewActivity( + ctx, + authz.UserFromContext(ctx), + fleet.ActivityTypeUpdatedScript{ + TeamID: script.TeamID, + TeamName: teamName, + ScriptName: script.Name, + }, + ); err != nil { + return nil, ctxerr.Wrap(ctx, err, "new activity for update script") + } + + return savedScript, nil +} + //////////////////////////////////////////////////////////////////////////////// // Get Host Script Details ////////////////////////////////////////////////////////////////////////////////