Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion scripts/tee-mgmt-cli/cmd/cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ var certSetAWSCmd = &cobra.Command{
return fmt.Errorf("failed: %w", err)
}
fmt.Printf("TX: %s\n", txHash)
registry.PrintTxResult(client.WaitForTx(txHash), "AWS cert set")
success, reason := client.WaitForTx(txHash)
registry.PrintTxResult(success, reason, "AWS cert set")
return nil
},
}
Expand Down
3 changes: 2 additions & 1 deletion scripts/tee-mgmt-cli/cmd/heartbeat.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ var heartbeatSetMaxAgeCmd = &cobra.Command{
return fmt.Errorf("failed: %w", err)
}
fmt.Printf("TX: %s\n", txHash)
registry.PrintTxResult(client.WaitForTx(txHash), fmt.Sprintf("Heartbeat max age set to %s seconds", maxAge.String()))
success, reason := client.WaitForTx(txHash)
registry.PrintTxResult(success, reason, fmt.Sprintf("Heartbeat max age set to %s seconds", maxAge.String()))
return nil
},
}
Expand Down
6 changes: 4 additions & 2 deletions scripts/tee-mgmt-cli/cmd/pcr.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ var pcrApproveCmd = &cobra.Command{
return fmt.Errorf("failed: %w", err)
}
fmt.Printf("TX: %s\n", txHash)
registry.PrintTxResult(client.WaitForTx(txHash), "PCR approved")
success, reason := client.WaitForTx(txHash)
registry.PrintTxResult(success, reason, "PCR approved")
return nil
},
}
Expand All @@ -81,7 +82,8 @@ var pcrRevokeCmd = &cobra.Command{
return fmt.Errorf("failed: %w", err)
}
fmt.Printf("TX: %s\n", txHash)
registry.PrintTxResult(client.WaitForTx(txHash), "PCR revoked")
success, reason := client.WaitForTx(txHash)
registry.PrintTxResult(success, reason, "PCR revoked")
return nil
},
}
Expand Down
12 changes: 8 additions & 4 deletions scripts/tee-mgmt-cli/cmd/role.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ var roleGrantAdminCmd = &cobra.Command{
return fmt.Errorf("failed: %w", err)
}
fmt.Printf("TX: %s\n", txHash)
registry.PrintTxResult(client.WaitForTx(txHash), "Admin added")
success, reason := client.WaitForTx(txHash)
registry.PrintTxResult(success, reason, "Admin added")
return nil
},
}
Expand All @@ -43,7 +44,8 @@ var roleGrantOperatorCmd = &cobra.Command{
return fmt.Errorf("failed: %w", err)
}
fmt.Printf("TX: %s\n", txHash)
registry.PrintTxResult(client.WaitForTx(txHash), "Operator added")
success, reason := client.WaitForTx(txHash)
registry.PrintTxResult(success, reason, "Operator added")
return nil
},
}
Expand All @@ -60,7 +62,8 @@ var roleRevokeAdminCmd = &cobra.Command{
return fmt.Errorf("failed: %w", err)
}
fmt.Printf("TX: %s\n", txHash)
registry.PrintTxResult(client.WaitForTx(txHash), "Admin revoked")
success, reason := client.WaitForTx(txHash)
registry.PrintTxResult(success, reason, "Admin revoked")
return nil
},
}
Expand All @@ -77,7 +80,8 @@ var roleRevokeOperatorCmd = &cobra.Command{
return fmt.Errorf("failed: %w", err)
}
fmt.Printf("TX: %s\n", txHash)
registry.PrintTxResult(client.WaitForTx(txHash), "Operator revoked")
success, reason := client.WaitForTx(txHash)
registry.PrintTxResult(success, reason, "Operator revoked")
return nil
},
}
Expand Down
15 changes: 11 additions & 4 deletions scripts/tee-mgmt-cli/cmd/tee.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,15 @@ var teeRegisterCmd = &cobra.Command{
}

fmt.Printf(" TX: %s\n", txHash)
if client.WaitForTx(txHash) {
success, revertReason := client.WaitForTx(txHash)
if success {
fmt.Printf("\nTEE registered! ID: 0x%s\n", hex.EncodeToString(expectedId[:]))
} else {
fmt.Println("\nRegistration failed")
if revertReason != "" {
fmt.Printf("\nRegistration failed: %s\n", revertReason)
} else {
fmt.Println("\nRegistration failed")
}
os.Exit(1)
}
return nil
Expand All @@ -179,7 +184,8 @@ var teeDisableCmd = &cobra.Command{
return fmt.Errorf("failed: %w", err)
}
fmt.Printf("TX: %s\n", txHash)
registry.PrintTxResult(client.WaitForTx(txHash), "TEE disabled")
success, reason := client.WaitForTx(txHash)
registry.PrintTxResult(success, reason, "TEE disabled")
return nil
},
}
Expand All @@ -201,7 +207,8 @@ var teeEnableCmd = &cobra.Command{
return fmt.Errorf("failed: %w", err)
}
fmt.Printf("TX: %s\n", txHash)
registry.PrintTxResult(client.WaitForTx(txHash), "TEE enabled")
success, reason := client.WaitForTx(txHash)
registry.PrintTxResult(success, reason, "TEE enabled")
return nil
},
}
Expand Down
3 changes: 2 additions & 1 deletion scripts/tee-mgmt-cli/cmd/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ var typeAddCmd = &cobra.Command{
return fmt.Errorf("failed: %w", err)
}
fmt.Printf("TX: %s\n", txHash)
registry.PrintTxResult(client.WaitForTx(txHash), "Type added")
success, reason := client.WaitForTx(txHash)
registry.PrintTxResult(success, reason, "Type added")
return nil
},
}
Expand Down
116 changes: 110 additions & 6 deletions scripts/tee-mgmt-cli/registry/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -453,18 +453,118 @@ func (c *Client) sendTxSigned(data []byte) (string, error) {
return signed.Hash().Hex(), nil
}

func (c *Client) WaitForTx(txHash string) bool {
func (c *Client) WaitForTx(txHash string) (bool, string) {
Log("Waiting for confirmation...")
for i := 0; i < 30; i++ {
resp, _ := c.rpcCall("eth_getTransactionReceipt", []string{txHash})
var result struct{ Result *struct{ Status string } }
var result struct {
Result *struct {
Status string `json:"status"`
BlockNumber string `json:"blockNumber"`
} `json:"result"`
}
json.Unmarshal(resp, &result)
if result.Result != nil {
return result.Result.Status == "0x1"
if result.Result.Status == "0x1" {
return true, ""
}
reason := c.getRevertReason(txHash, result.Result.BlockNumber)
return false, reason
}
time.Sleep(time.Second)
}
return false
return false, "timed out waiting for receipt"
}

func (c *Client) getRevertReason(txHash, blockNumber string) string {
// Fetch the original transaction to replay it
resp, err := c.rpcCall("eth_getTransactionByHash", []string{txHash})
if err != nil {
return ""
}
var txResult struct {
Result *struct {
From string `json:"from"`
To string `json:"to"`
Data string `json:"input"`
Value string `json:"value"`
Gas string `json:"gas"`
} `json:"result"`
}
json.Unmarshal(resp, &txResult)
if txResult.Result == nil {
return ""
}

// Replay the call at the block it was mined to get the revert data
callParams := map[string]string{
"from": txResult.Result.From,
"to": txResult.Result.To,
"data": txResult.Result.Data,
}
if txResult.Result.Value != "" {
callParams["value"] = txResult.Result.Value
}
if txResult.Result.Gas != "" {
callParams["gas"] = txResult.Result.Gas
}

resp, err = c.rpcCall("eth_call", []interface{}{callParams, blockNumber})
if err != nil {
return ""
}

var callResult struct {
Error *struct {
Message string `json:"message"`
Data string `json:"data"`
} `json:"error"`
Result string `json:"result"`
}
json.Unmarshal(resp, &callResult)

if callResult.Error != nil {
// Try to decode revert reason from error data
if callResult.Error.Data != "" {
if reason := decodeRevertReason(callResult.Error.Data); reason != "" {
return reason
}
}
return callResult.Error.Message
}

// Some nodes return revert data in the result field
if len(callResult.Result) > 2 {
if reason := decodeRevertReason(callResult.Result); reason != "" {
return reason
}
}

return "transaction reverted"
}

func decodeRevertReason(hexData string) string {
hexData = strings.TrimPrefix(hexData, "0x")
if len(hexData) < 8 {
return ""
}

data, err := hex.DecodeString(hexData)
if err != nil {
return ""
}

// Check for Error(string) selector: 0x08c379a0
if len(data) >= 68 && hex.EncodeToString(data[:4]) == "08c379a0" {
strT, _ := abi.NewType("string", "", nil)
args := abi.Arguments{{Type: strT}}
values, err := args.UnpackValues(data[4:])
if err == nil && len(values) > 0 {
return fmt.Sprintf("%v", values[0])
}
}

return ""
}

func (c *Client) rpcCall(method string, params interface{}) ([]byte, error) {
Expand Down Expand Up @@ -648,11 +748,15 @@ func Log(format string, args ...interface{}) {
fmt.Printf("[%s] %s\n", time.Now().Format("15:04:05"), fmt.Sprintf(format, args...))
}

func PrintTxResult(success bool, msg string) {
func PrintTxResult(success bool, revertReason string, msg string) {
if success {
fmt.Printf("%s\n", msg)
} else {
fmt.Println("Transaction failed")
if revertReason != "" {
fmt.Printf("Transaction failed: %s\n", revertReason)
} else {
fmt.Println("Transaction failed")
}
}
}

Expand Down
Loading