Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
masahide committed Sep 22, 2024
1 parent f1d27e4 commit 21de990
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 74 deletions.
69 changes: 57 additions & 12 deletions cmd/wsl2-ssh-agent-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"encoding/binary"
"errors"
"flag"
"fmt"
"io"
"log"
"net"
Expand All @@ -20,9 +21,25 @@ import (
"syscall"
)

const maxAgentResponseBytes = 16 << 20

//go:embed pwsh.ps1
var pwshScript string
var debug bool
var byteOrder binary.ByteOrder // エンディアンを保持する変数

func isLittleEndian() bool {
var i int32 = 0x01020304
buf := make([]byte, 4)
binary.LittleEndian.PutUint32(buf, uint32(i))
return buf[0] == 0x04
}
func init() {
byteOrder = map[bool]binary.ByteOrder{
true: binary.LittleEndian,
false: binary.BigEndian,
}[isLittleEndian()]
}

const (
HeaderSize = 12 // 4 bytes for channel ID, 4 bytes for message length
Expand Down Expand Up @@ -78,9 +95,9 @@ func (mux *Multiplexer) readLoop(ctx context.Context) error {
log.Printf("mux readFull header:[%v]", header)
}

packetType := binary.LittleEndian.Uint32(header[:4])
channelID := binary.LittleEndian.Uint32(header[4:])
length := binary.LittleEndian.Uint32(header[8:])
packetType := byteOrder.Uint32(header[:4])
channelID := byteOrder.Uint32(header[4:])
length := byteOrder.Uint32(header[8:])
payload := make([]byte, length)
_, err = io.ReadFull(mux.reader, payload)
if err != nil {
Expand Down Expand Up @@ -112,9 +129,9 @@ func (mux *Multiplexer) readLoop(ctx context.Context) error {

func (mux *Multiplexer) WriteChannel(packet Packet) error {
buf := make([]byte, HeaderSize+len(packet.Payload))
binary.LittleEndian.PutUint32(buf[0:4], packet.PacketType)
binary.LittleEndian.PutUint32(buf[4:8], packet.ChannelID)
binary.LittleEndian.PutUint32(buf[8:HeaderSize], uint32(len(packet.Payload)))
byteOrder.PutUint32(buf[0:4], packet.PacketType)
byteOrder.PutUint32(buf[4:8], packet.ChannelID)
byteOrder.PutUint32(buf[8:HeaderSize], uint32(len(packet.Payload)))
copy(buf[HeaderSize:], packet.Payload)
_, err := mux.writer.Write(buf)
//n, err := mux.writer.Write(buf)
Expand All @@ -141,6 +158,37 @@ func (mux *Multiplexer) CloseChannel(channelID uint32) {
}
}

func writePacket(w io.Writer, b []byte) error {
var length [4]byte
binary.BigEndian.PutUint32(length[:], uint32(len(b)))
if _, err := w.Write(length[:]); err != nil {
return err
}
if _, err := w.Write(b); err != nil {
return err
}
return nil
}
func readPacket(conn net.Conn) ([]byte, error) {
var length [4]byte
if _, err := io.ReadFull(conn, length[:]); err != nil {
return []byte{}, err
}
l := binary.BigEndian.Uint32(length[:])
if l == 0 {
return []byte{}, fmt.Errorf("agent: request size is 0")
}
if l > maxAgentResponseBytes {
return []byte{}, fmt.Errorf("agent: request too large: %d", l)
}
b := make([]byte, l)
if _, err := io.ReadFull(conn, b); err != nil {
return []byte{}, err
}
//log.Printf("readPacket b:%s", string(b))
return b, nil
}

func (ps *pwshIOStream) handleConnection(ctx context.Context, conn net.Conn, channelID uint32) {
defer conn.Close()
ch := ps.OpenChannel(channelID)
Expand All @@ -151,19 +199,16 @@ func (ps *pwshIOStream) handleConnection(ctx context.Context, conn net.Conn, cha
}()
packetType := PacketTypeConnectSend
for {
payload := make([]byte, 4096)
n, err := conn.Read(payload)
b, err := readPacket(conn)
if err != nil {
if err == io.EOF {
if debug {
log.Printf("DomainSocket.read ch:%d io.EOF", channelID)
}
break
}
log.Println("Error reading from connection:", err)
break
}
ps.WriteChannel(Packet{PacketType: packetType, ChannelID: channelID, Payload: payload[:n]})
ps.WriteChannel(Packet{PacketType: packetType, ChannelID: channelID, Payload: b})
packetType = PacketTypeSend
select {
case <-ctx.Done():
Expand All @@ -175,7 +220,7 @@ func (ps *pwshIOStream) handleConnection(ctx context.Context, conn net.Conn, cha

domainSocketWriter := bufio.NewWriter(conn)
for msg := range ch {
_, err := domainSocketWriter.Write(msg)
err := writePacket(domainSocketWriter, msg)
if err != nil {
log.Println("Error writing to connection:", err)
break
Expand Down
84 changes: 22 additions & 62 deletions cmd/wsl2-ssh-agent-proxy/pwsh.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -32,63 +32,6 @@ $WritePacketWorker = {
}
}

$NamePipeReadWorkerScript = {
param (
[Hashtable] $WorkerInstance
)

class NamedPipeReadWorker {
[Hashtable] $WorkerInstance

NamedPipeReadWorker([Hashtable] $WorkerInstance) {
$this.WorkerInstance = $WorkerInstance
}

[void]SendResponse([hashtable]$Packet) {
$null = $this.WorkerInstance.MainPacketQueue.Enqueue($Packet)
$null = $this.WorkerInstance.MainPacketQueueSignal.Set()
# [Console]::Error.WriteLine("PacketWorker [ch:$($Packet.ChannelID) type:$($Packet.TypeNum)]: Response sent.")
}

[void]StopWorker([Int32]$ChannelID) {
$this.SendResponse(@{ Type = 2; Payload = [byte[]]::new(0); ChannelID = $ChannelID })
$null = $this.WorkerInstance.WorkerQueue.Enqueue($this.WorkerInstance)
# [Console]::Error.WriteLine("PacketWorker [ch:$($ChannelID)]: Worker stopped.")
}
[void]NamedPipeReaderRun() {
# [Console]::Error.WriteLine("PacketReadWorker started.")
$Payload = [byte[]]::new(10240)
while ($true) {
# TODO: NamedPipeReadWorkerQueueSignal, NamedPipeReadWorkerQueue
$null = $this.WorkerInstance.NamedPipeReadWorkerQueueSignal.WaitOne()
$task = @{}
while ($this.WorkerInstance.NamedPipeReadWorkerQueue.TryDequeue([ref]$task)) {
# [Console]::Error.WriteLine("PacketReadWorker [ch:$($task.channelID) type:$($this.WorkerInstance.TypeNum)]: Packet received.")
while ($true) {
try {
$n = $task.NamedPipeStream.Read($Payload, 0, $Payload.Length)
if ($n -gt 0) {
$Payload = $Payload[0..($n - 1)]
$this.SendResponse(@{ Type = 1; Payload = $Payload; ChannelID = $task.channelID })
throw "PacketReadWorker [ch:$($task.channelID)]: Response read from named pipe and sent."
}
}
catch {
[Console]::Error.WriteLine("PacketReadWorker [ch:$($task.channelID)]: Exception occurred while processing. Error: $($_.Exception.Message). Worker will stop.")
$this.SendResponse(@{ Type = 2; Payload = [byte[]]::new(0); ChannelID = $task.ChannelID })
$task.NamedPipeStream.Close()
$null = $this.WorkerInstance.WorkerQueue.Enqueue($this.WorkerInstance)
# [Console]::Error.WriteLine("PacketReadWorker [ch:$($task.ChannelID)]: Worker stopped.")
continue
}
}
}
}
}
}
[NamedPipeReadWorker]::new($WorkerInstance).NamedPipeReaderRun()
}


$PacketWorkerScript = {
param (
Expand Down Expand Up @@ -149,6 +92,7 @@ $PacketWorkerScript = {
$this.NamedPipeStream = $null
return $false
}
# [Console]::Error.WriteLine("PacketWorker [ch:$($Packet.channelID) type:$($Packet.TypeNum)]: open Named-pipe:openssh-ssh-agent...")
$this.NamedPipeStream = [System.IO.Pipes.NamedPipeClientStream]::new(".", "openssh-ssh-agent", [System.IO.Pipes.PipeDirection]::InOut)
$this.NamedPipeStream.Connect()
$this.WorkerInstance.ChannelID = $Packet.ChannelID
Expand All @@ -165,19 +109,35 @@ $PacketWorkerScript = {
return $false
}

# TODO: Write -> Read
# Write -> Read

# [Console]::Error.WriteLine("PacketWorker [ch:$($Packet.channelID) type:$($Packet.TypeNum)]: Named pipe write...")
$Header = [BitConverter]::GetBytes($Packet.Payload.Length)
$null = [Array]::Reverse($Header)
$this.NamedPipeStream.Write($Header, 0, $Header.Length)
$this.NamedPipeStream.Write($Packet.Payload, 0, $Packet.Payload.Length)
$this.NamedPipeStream.Flush()
# [Console]::Error.WriteLine("PacketWorker [ch:$($Packet.channelID) type:$($Packet.TypeNum)]: Data written to named pipe.")
# [Console]::Error.WriteLine("PacketWorker [ch:$($Packet.channelID) type:$($Packet.TypeNum)]: Data written to named pipe. length:$($Packet.Payload.Length)")

$Payload = [byte[]]::new(10240)
$n = $this.NamedPipeStream.Read($Payload, 0, $Payload.Length)
$Header = [byte[]]::new(4)
$n = $this.NamedPipeStream.Read($Header, 0, $Header.Length)
if ($n -eq 0) {
[Console]::Error.WriteLine("PacketWorker [ch:$($Packet.channelID) type:$($Packet.TypeNum)]: header reaad error length:0")
return $false
}
$null = [Array]::Reverse($Header)
$length = [BitConverter]::ToInt32($Header, 0)
$Payload = [byte[]]::new($length)
$n = $this.NamedPipeStream.Read($Payload, 0, $length)
if ($n -gt 0) {
# [Console]::Error.WriteLine("PacketWorker [ch:$($Packet.channelID) type:$($Packet.TypeNum)]: read payload:length:$($Payload.Length), n:$($n)")
$Payload = $Payload[0..($n - 1)]
# [Console]::Error.WriteLine("PacketWorker [ch:$($Packet.channelID) type:$($Packet.TypeNum)]: read payload:length:$($Payload.Length), n:$($n)")
$this.SendResponse(@{ Type = 1; Payload = $Payload; ChannelID = $Packet.ChannelID })
# [Console]::Error.WriteLine("PacketWorker [ch:$($Packet.channelID) type:$($Packet.TypeNum)]: Response read from named pipe and sent.")
return $true
}
return $true
return $false
}
}

Expand Down

0 comments on commit 21de990

Please sign in to comment.