Skip to content

Commit

Permalink
Merge pull request #1 from Ezzahhh/sig-terminate
Browse files Browse the repository at this point in the history
feat: cleaner terminate with signals
  • Loading branch information
Ezzahhh authored Jan 7, 2025
2 parents e0d5203 + 8300348 commit 0711d19
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 101 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ terraform {
required_providers {
tunnel = {
source = "Ezzahhh/tunnel"
version = ">= 1.2.0"
version = ">= 1.3.0"
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion internal/provider/common.go → internal/libs/common.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package provider
package libs

import (
"fmt"
Expand Down
81 changes: 81 additions & 0 deletions internal/libs/watch.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package libs

import (
"fmt"
"log"
"os"
"runtime"
"strconv"
"syscall"
"time"

ps "github.com/shirou/gopsutil/v4/process"
)

func WatchProcess(pid string) (err error) {
pidInt, err := strconv.Atoi(pid)
if err != nil {
return fmt.Errorf("invalid PID: %v", err)
}
parent, err := ps.NewProcess(int32(pidInt))
if err != nil {
return err
}
child, err := ps.NewProcess(int32(os.Getpid()))
if err != nil {
return err
}
// pool for parent process liveliness every 2 seconds
go func() {
for {
_, err := parent.Status()
if err != nil {
log.Printf("parent process exited: %v\n", err)
if runtime.GOOS == "windows" {
err = child.Terminate()
} else {
err = child.SendSignal(syscall.SIGINT)
}
if err != nil {
log.Printf("failed to terminate process: %v\n", err)
}
}
time.Sleep(2 * time.Second)
}
}()

return nil
}

func CheckProcessExists(pid int) error {
cmd, err := ps.NewProcess(int32(pid))
if err != nil {
return err
}
stats, err := cmd.Status()
if err != nil {
return err
}
if stats[0] == "zombie" {
return fmt.Errorf("process died")
}

return nil
}

func Interrupt(pid int) error {
cmd, err := ps.NewProcess(int32(pid))
if err != nil {
return err
}
if runtime.GOOS == "windows" {
err = cmd.Terminate()
} else {
err = cmd.SendSignal(syscall.SIGINT)
}
if err != nil {
return err
}

return nil
}
3 changes: 2 additions & 1 deletion internal/provider/data_source_ssm.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"strconv"

"github.com/Ezzahhh/terraform-provider-tunnel/internal/libs"
"github.com/Ezzahhh/terraform-provider-tunnel/internal/ssm"
"github.com/hashicorp/terraform-plugin-framework/datasource"
"github.com/hashicorp/terraform-plugin-framework/datasource/schema"
Expand Down Expand Up @@ -87,7 +88,7 @@ func (d *SSMDataSource) Read(ctx context.Context, req datasource.ReadRequest, re
}

// Get a free port for the local tunnel
localPort, err := GetFreePort()
localPort, err := libs.GetFreePort()
if err != nil {
resp.Diagnostics.AddError("Failed to find open port", fmt.Sprintf("Error: %s", err))
return
Expand Down
61 changes: 5 additions & 56 deletions internal/provider/ephemeral_ssm.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,14 @@ package provider

import (
"context"
"encoding/json"
"fmt"
"strconv"

"github.com/Ezzahhh/terraform-provider-tunnel/internal/libs"
"github.com/Ezzahhh/terraform-provider-tunnel/internal/ssm"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
aws_ssm "github.com/aws/aws-sdk-go-v2/service/ssm"
"github.com/hashicorp/terraform-plugin-framework/ephemeral"
"github.com/hashicorp/terraform-plugin-framework/ephemeral/schema"
"github.com/hashicorp/terraform-plugin-framework/types"
ps "github.com/shirou/gopsutil/v4/process"
)

// Ensure provider defined types fully satisfy framework interfaces.
Expand Down Expand Up @@ -92,7 +88,7 @@ func (d *SSMEphemeral) Open(ctx context.Context, req ephemeral.OpenRequest, resp
}

// Get a free port for the local tunnel
localPort, err := GetFreePort()
localPort, err := libs.GetFreePort()
if err != nil {
resp.Diagnostics.AddError("Failed to find open port", fmt.Sprintf("Error: %s", err))
return
Expand All @@ -103,8 +99,7 @@ func (d *SSMEphemeral) Open(ctx context.Context, req ephemeral.OpenRequest, resp
data.LocalHost = types.StringValue("localhost")
data.LocalPort = types.Int64Value(int64(localPort))

forkResult, err := ssm.ForkRemoteTunnel(ctx, ssm.TunnelConfig{
SSMProfile: data.SSMProfile.ValueString(),
cmd, err := ssm.ForkRemoteTunnel(ctx, ssm.TunnelConfig{
SSMRegion: data.SSMRegion.ValueString(),
SSMInstance: data.SSMInstance.ValueString(),
TargetHost: data.TargetHost.ValueString(),
Expand All @@ -116,21 +111,9 @@ func (d *SSMEphemeral) Open(ctx context.Context, req ephemeral.OpenRequest, resp
return
}

sessionID := forkResult.Session.SessionId

// Encode the session ID string as JSON
sessionIDBytes, err := json.Marshal(sessionID)
if err != nil {
resp.Diagnostics.AddError("Failed to serialise JSON for session ID", fmt.Sprintf("Error: %s", err))
return
}

// Save data into Terraform state
resp.Diagnostics.Append(resp.Result.Set(ctx, &data)...)
resp.Private.SetKey(ctx, "tunnel_pid", []byte(strconv.Itoa(forkResult.Command.Process.Pid)))
resp.Private.SetKey(ctx, "session_id", sessionIDBytes)
resp.Private.SetKey(ctx, "ssm_region", []byte(data.SSMRegion.ValueString()))
resp.Private.SetKey(ctx, "ssm_profile", []byte(data.SSMProfile.ValueString()))
resp.Private.SetKey(ctx, "tunnel_pid", []byte(strconv.Itoa(cmd.Process.Pid)))
}

func (d *SSMEphemeral) Close(ctx context.Context, req ephemeral.CloseRequest, resp *ephemeral.CloseResponse) {
Expand All @@ -141,42 +124,8 @@ func (d *SSMEphemeral) Close(ctx context.Context, req ephemeral.CloseRequest, re
return
}

tunnel, err := ps.NewProcess(int32(tunnelPID))
if err != nil {
resp.Diagnostics.AddError("Failed to find tunnel process", fmt.Sprintf("Error: %s", err))
return
}

if err := tunnel.Terminate(); err != nil {
if err := libs.Interrupt(tunnelPID); err != nil {
resp.Diagnostics.AddError("Failed to terminate tunnel process", fmt.Sprintf("Error: %s", err))
return
}

sessionIDBytes, _ := req.Private.GetKey(ctx, "session_id")
var sessionID string
if err := json.Unmarshal(sessionIDBytes, &sessionID); err != nil {
resp.Diagnostics.AddError("Failed to decode session ID JSON", fmt.Sprintf("Error: %s", err))
return
}
ssmRegion, _ := req.Private.GetKey(ctx, "ssm_region")
if len(sessionID) < 1 {
resp.Diagnostics.AddWarning("Cannot close SSM tunnel", "SessionID length received is 0")
return
}
ssmProfile, _ := req.Private.GetKey(ctx, "ssm_profile")
awsCfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(string(ssmRegion)), config.WithSharedConfigProfile(string(ssmProfile)))
if err != nil {
resp.Diagnostics.AddError("Failed to load AWS config", fmt.Sprintf("Error: %s", err))
return
}

ssmClient := aws_ssm.NewFromConfig(awsCfg)

_, err = ssmClient.TerminateSession(ctx, &aws_ssm.TerminateSessionInput{
SessionId: aws.String(sessionID),
})
if err != nil {
resp.Diagnostics.AddError("Failed to terminate SSM session", fmt.Sprintf("Error: %s", err))
return
}
}
46 changes: 4 additions & 42 deletions internal/ssm/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,13 @@ import (
"strconv"
"time"

"github.com/Ezzahhh/terraform-provider-tunnel/internal/libs"
"github.com/aws/aws-sdk-go-v2/service/ssm"
pluginSession "github.com/aws/session-manager-plugin/src/sessionmanagerplugin/session"
_ "github.com/aws/session-manager-plugin/src/sessionmanagerplugin/session/portsession"
"github.com/aws/smithy-go/ptr"
ps "github.com/shirou/gopsutil/v4/process"
)

type ForkRemoteResult struct {
Command *exec.Cmd
Session SessionParams
}

func GetEndpoint(ctx context.Context, region string) (string, error) {
resolver := ssm.NewDefaultEndpointResolverV2()
endpoint, err := resolver.ResolveEndpoint(ctx, ssm.EndpointParameters{
Expand All @@ -33,37 +28,7 @@ func GetEndpoint(ctx context.Context, region string) (string, error) {
return endpoint.URI.String(), nil
}

func WatchProcess(pid string) (err error) {
pidInt, err := strconv.Atoi(pid)
if err != nil {
return fmt.Errorf("invalid PID: %v", err)
}
parent, err := ps.NewProcess(int32(pidInt))
if err != nil {
return err
}
child, err := ps.NewProcess(int32(os.Getpid()))
if err != nil {
return err
}
// pool for parent process liveliness every 2 seconds
go func() {
for {
_, err := parent.Status()
if err != nil {
fmt.Printf("parent process exited: %v\n", err)
if err := child.Terminate(); err != nil {
fmt.Printf("failed to terminate process: %v\n", err)
}
}
time.Sleep(2 * time.Second)
}
}()

return nil
}

func ForkRemoteTunnel(ctx context.Context, cfg TunnelConfig) (*ForkRemoteResult, error) {
func ForkRemoteTunnel(ctx context.Context, cfg TunnelConfig) (*exec.Cmd, error) {
// First we start a session using AWS SDK
// see https://github.com/aws/aws-cli/blob/master/awscli/customizations/sessionmanager.py#L104
sessionParams, err := StartTunnelSession(ctx, cfg)
Expand Down Expand Up @@ -102,15 +67,12 @@ func ForkRemoteTunnel(ctx context.Context, cfg TunnelConfig) (*ForkRemoteResult,
return nil, err
}

return &ForkRemoteResult{
Command: cmd,
Session: sessionParams,
}, nil
return cmd, nil
}

func StartRemoteTunnel(ctx context.Context, cfg TunnelConfig, parentPid string) (err error) {
// Watch parent process lifecycle ie. main terraform process
err = WatchProcess(parentPid)
err = libs.WatchProcess(parentPid)
if err != nil {
return err
}
Expand Down

0 comments on commit 0711d19

Please sign in to comment.