Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions exec_service/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ load("@rules_go//proto:def.bzl", "go_proto_library")
proto_library(
name = "exec_service_proto",
srcs = ["exec_service.proto"],
deps = ["@protobuf//:duration_proto"],
)

go_proto_library(
Expand Down
20 changes: 16 additions & 4 deletions exec_service/cmd/exec_client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,25 @@ func run(args []string, stdout, stderr io.Writer) int {
defer conn.Close()

client := pb.NewExecServiceClient(conn)
stream, err := client.RunCommand(context.Background(), &pb.StartCommandRequest{
CommandLine: cmdLine,
WorkingDir: *dir,
})
stream, err := client.RunCommand(context.Background())
if err != nil {
fmt.Fprintf(stderr, "RunCommand: %v\n", err)
return 1
}

if err := stream.Send(&pb.ClientEvent{
Event: &pb.ClientEvent_Start{
Start: &pb.StartCommandRequest{
CommandLine: cmdLine,
WorkingDir: *dir,
},
},
}); err != nil {
fmt.Fprintf(stderr, "send start: %v\n", err)
return 1
}
stream.CloseSend()

for {
ev, err := stream.Recv()
if err == io.EOF {
Expand All @@ -98,6 +108,8 @@ func run(args []string, stdout, stderr io.Writer) int {
return 1
}
switch e := ev.Event.(type) {
case *pb.ServerEvent_Started:
// Ignore for CLI usage.
case *pb.ServerEvent_Output:
stdout.Write(e.Output)
case *pb.ServerEvent_Exited:
Expand Down
15 changes: 12 additions & 3 deletions exec_service/cmd/exec_server/exec_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ func collect(t *testing.T, stream pb.ExecService_RunCommandClient) (output strin
t.Fatalf("Recv: %v", err)
}
switch e := ev.Event.(type) {
case *pb.ServerEvent_Started:
// Ignore.
case *pb.ServerEvent_Output:
output += string(e.Output)
case *pb.ServerEvent_Exited:
Expand All @@ -119,12 +121,19 @@ func collect(t *testing.T, stream pb.ExecService_RunCommandClient) (output strin

func TestRunCommand(t *testing.T) {
client := startServer(t)
stream, err := client.RunCommand(context.Background(), &pb.StartCommandRequest{
CommandLine: "echo hello",
})
stream, err := client.RunCommand(context.Background())
if err != nil {
t.Fatal(err)
}
if err := stream.Send(&pb.ClientEvent{
Event: &pb.ClientEvent_Start{
Start: &pb.StartCommandRequest{CommandLine: "echo hello"},
},
}); err != nil {
t.Fatal(err)
}
stream.CloseSend()

output, exitCode, _ := collect(t, stream)
if exitCode != 0 {
t.Errorf("exit code = %d, want 0", exitCode)
Expand Down
38 changes: 36 additions & 2 deletions exec_service/exec_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,26 @@ syntax = "proto3";

package exec_service;

import "google/protobuf/duration.proto";

option go_package = "github.com/google/agent-shell-tools/exec_service/execservicepb";

// A simple command execution service.
service ExecService {
// Executes a command and streams the output until completion.
rpc RunCommand(StartCommandRequest) returns (stream ServerEvent);
// Executes a command and streams output until completion. The first
// ClientEvent must be a StartCommandRequest; subsequent events may
// include TerminateCommand.
rpc RunCommand(stream ClientEvent) returns (stream ServerEvent);
}

message ClientEvent {
oneof event {
// Must be the first message on the stream.
StartCommandRequest start = 1;

// Request termination of the running command.
TerminateCommand terminate = 2;
}
}

message StartCommandRequest {
Expand All @@ -32,16 +46,36 @@ message StartCommandRequest {
string working_dir = 2;
}

message TerminateCommand {
// When false, send SIGTERM only (grace_period is ignored).
// When true, guarantee the process dies via SIGKILL. If grace_period
// is set, send SIGTERM first and SIGKILL after the grace period;
// otherwise send SIGKILL immediately.
bool force = 1;

// Only used when force is true. If non-zero, SIGTERM is sent first
// and SIGKILL follows after this duration.
google.protobuf.Duration grace_period = 2;
}

message ServerEvent {
oneof event {
// Incremental stdout/stderr from the command.
bytes output = 1;

// Sent exactly once when the command finishes.
ExitInfo exited = 2;

// Sent once after the command starts, before any output.
CommandStarted started = 3;
}
}

message CommandStarted {
// Opaque identifier for this command invocation.
string command_id = 1;
}

message ExitInfo {
int32 exit_code = 1;
// Populated only if the system failed to execute the command at all.
Expand Down
6 changes: 6 additions & 0 deletions exec_service/server/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//exec_service:exec_service_go_proto",
"@org_golang_google_grpc//:grpc",
"@org_golang_google_grpc//codes",
"@org_golang_google_grpc//status",
],
)

Expand All @@ -31,7 +34,10 @@ go_test(
":server",
"//exec_service:exec_service_go_proto",
"@org_golang_google_grpc//:grpc",
"@org_golang_google_grpc//codes",
"@org_golang_google_grpc//credentials/insecure",
"@org_golang_google_grpc//status",
"@org_golang_google_grpc//test/bufconn",
"@org_golang_google_protobuf//types/known/durationpb",
],
)
122 changes: 107 additions & 15 deletions exec_service/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,18 @@
package server

import (
"crypto/rand"
"errors"
"fmt"
"io"
"os"
"os/exec"
"syscall"
"time"

pb "github.com/google/agent-shell-tools/exec_service/execservicepb"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

// ExecServer implements the ExecService gRPC service.
Expand All @@ -33,9 +37,22 @@ type ExecServer struct {
}

// RunCommand executes a shell command and streams output events until the
// command exits. The command_line is interpreted by sh -c. Stdout and stderr
// are merged into the output stream.
func (s *ExecServer) RunCommand(req *pb.StartCommandRequest, stream pb.ExecService_RunCommandServer) error {
// command exits. The first ClientEvent must be a StartCommandRequest.
// Subsequent TerminateCommand events signal the running process.
func (s *ExecServer) RunCommand(stream pb.ExecService_RunCommandServer) error {
// The first message must be a start request.
first, err := stream.Recv()
if err != nil {
if err == io.EOF {
return status.Error(codes.InvalidArgument, "empty stream: first message must be StartCommandRequest")
}
return err
}
req := first.GetStart()
if req == nil {
return status.Error(codes.InvalidArgument, "first message must be StartCommandRequest")
}

cmd := exec.Command("sh", "-c", req.GetCommandLine())
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}

Expand All @@ -57,32 +74,93 @@ func (s *ExecServer) RunCommand(req *pb.StartCommandRequest, stream pb.ExecServi
}
pw.Close()

killGroup := func() {
syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
pid := cmd.Process.Pid
killPG := func() { syscall.Kill(-pid, syscall.SIGKILL) }
termPG := func() { syscall.Kill(-pid, syscall.SIGTERM) }

// Send CommandStarted with a random ID.
cmdID := randomID()
if err := stream.Send(&pb.ServerEvent{
Event: &pb.ServerEvent_Started{
Started: &pb.CommandStarted{CommandId: cmdID},
},
}); err != nil {
killPG()
cmd.Wait()
return err
}

// Wait for the command concurrently. When it exits, kill the
// process group so that background children close their inherited
// pipe fds, allowing the read loop to reach EOF cleanly. The
// deadline is a fallback for children that escaped the group
// (e.g. via setsid); it is not reached in the normal case
// because SIGKILL closes the pipe before the deadline fires.
// pipe fds, allowing the read loop to reach EOF cleanly. This is
// pipe cleanup, not termination policy — it runs regardless of how
// the process exited. The deadline is a fallback for children that
// escaped the group (e.g. via setsid).
waitCh := make(chan error, 1)
procDone := make(chan struct{})
go func() {
err := cmd.Wait()
killGroup()
close(procDone)
killPG()
pr.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
waitCh <- err
}()

// Listen for client events (terminate) and context cancellation (disconnect).
protoErr := make(chan error, 1)
go func() {
for {
ev, err := stream.Recv()
if err != nil {
// io.EOF means the client closed its send side — normal,
// the command keeps running. Any other error means the
// stream is broken.
if err != io.EOF {
select {
case <-procDone:
default:
killPG()
}
}
return
}
switch ev.Event.(type) {
case *pb.ClientEvent_Terminate:
t := ev.GetTerminate()
select {
case <-procDone:
continue
default:
}
if !t.GetForce() {
termPG()
} else if gp := t.GetGracePeriod(); gp != nil && gp.AsDuration() > 0 {
termPG()
go func() {
select {
case <-time.After(gp.AsDuration()):
killPG()
case <-procDone:
}
}()
} else {
killPG()
}
default:
// Protocol violation: only terminate is valid after start.
killPG()
protoErr <- status.Error(codes.InvalidArgument, "only TerminateCommand is valid after StartCommandRequest")
return
}
}
}()

// Kill the process group if the client disconnects.
done := make(chan struct{})
defer close(done)
go func() {
select {
case <-stream.Context().Done():
killGroup()
case <-done:
killPG()
case <-procDone:
}
}()

Expand All @@ -93,7 +171,7 @@ func (s *ExecServer) RunCommand(req *pb.StartCommandRequest, stream pb.ExecServi
if sendErr := stream.Send(&pb.ServerEvent{
Event: &pb.ServerEvent_Output{Output: append([]byte(nil), buf[:n]...)},
}); sendErr != nil {
killGroup()
killPG()
<-waitCh
return sendErr
}
Expand All @@ -115,6 +193,14 @@ func (s *ExecServer) RunCommand(req *pb.StartCommandRequest, stream pb.ExecServi
}
}

// If the recv goroutine detected a protocol violation, return it
// instead of the normal exit event.
select {
case err := <-protoErr:
return err
default:
}

return stream.Send(&pb.ServerEvent{
Event: &pb.ServerEvent_Exited{
Exited: &pb.ExitInfo{
Expand All @@ -135,3 +221,9 @@ func sendError(stream pb.ExecService_RunCommandServer, msg string) error {
},
})
}

func randomID() string {
b := make([]byte, 16)
io.ReadFull(rand.Reader, b)
return fmt.Sprintf("%x", b)
}
Loading
Loading