diff --git a/spannerlib/socket-server/.gitignore b/spannerlib/socket-server/.gitignore new file mode 100644 index 00000000..e935fd38 --- /dev/null +++ b/spannerlib/socket-server/.gitignore @@ -0,0 +1 @@ +binaries diff --git a/spannerlib/socket-server/build-executables.sh b/spannerlib/socket-server/build-executables.sh new file mode 100755 index 00000000..dcad6fdb --- /dev/null +++ b/spannerlib/socket-server/build-executables.sh @@ -0,0 +1,17 @@ +# Builds the socket server binary for darwin/arm64, linux/x64, and windows/x64. +# The binaries are stored in the following files: +# binaries/osx-arm64/spannerlib_socket_server +# binaries/linux-x64/spannerlib_socket_server +# binaries/win-x64/spannerlib_socket_server.exe + +mkdir -p binaries/osx-arm64 +GOOS=darwin GOARCH=arm64 go build -o binaries/osx-arm64/spannerlib_socket_server server.go +chmod +x binaries/osx-arm64/spannerlib_socket_server + +mkdir -p binaries/linux-x64 +GOOS=linux GOARCH=amd64 go build -o binaries/linux-x64/spannerlib_socket_server server.go +chmod +x binaries/linux-x64/spannerlib_socket_server + +mkdir -p binaries/win-x64 +GOOS=windows GOARCH=amd64 go build -o binaries/win-x64/spannerlib_socket_server.exe server.go +chmod +x binaries/win-x64/spannerlib_socket_server.exe diff --git a/spannerlib/socket-server/client/connection.go b/spannerlib/socket-server/client/connection.go new file mode 100644 index 00000000..07c6036c --- /dev/null +++ b/spannerlib/socket-server/client/connection.go @@ -0,0 +1,126 @@ +package client + +import ( + "bufio" + "net" + + "cloud.google.com/go/spanner/apiv1/spannerpb" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "spannerlib/socket-server/message" + "spannerlib/socket-server/protocol" +) + +type Connection struct { + pool *Pool + + conn net.Conn + reader *bufio.Reader + writer *bufio.Writer +} + +func (c *Connection) Begin(options *spannerpb.TransactionOptions) error { + msg := message.CreateBeginMessage(options) + if err := msg.Write(c.writer); err != nil { + return err + } + if err := c.writer.Flush(); err != nil { + return err + } + _, err := message.ReadMessageOrError(c.reader) + if err != nil { + return err + } + return nil +} + +func (c *Connection) Commit() (*spannerpb.CommitResponse, error) { + msg := message.CreateCommitMessage() + if err := msg.Write(c.writer); err != nil { + return nil, err + } + if err := c.writer.Flush(); err != nil { + return nil, err + } + m, err := message.ReadMessageOrError(c.reader) + if err != nil { + return nil, err + } + result, ok := m.(*message.CommitResultMessage) + if !ok { + return nil, status.Error(codes.Internal, "message is not a CommitResultMessage") + } + + return result.Response, nil +} + +func (c *Connection) Rollback() error { + msg := message.CreateRollbackMessage() + if err := msg.Write(c.writer); err != nil { + return err + } + if err := c.writer.Flush(); err != nil { + return err + } + _, err := message.ReadMessageOrError(c.reader) + if err != nil { + return err + } + return nil +} + +func (c *Connection) Execute(request *spannerpb.ExecuteSqlRequest) (*Rows, error) { + msg := message.CreateExecuteMessage(request) + if err := msg.Write(c.writer); err != nil { + return nil, err + } + if err := c.writer.Flush(); err != nil { + return nil, err + } + r, err := message.ReadMessageOrError(c.reader) + if err != nil { + return nil, err + } + rowsMsg, ok := r.(*message.RowsMessage) + if !ok { + return nil, status.Error(codes.Internal, "message type is not RowsMessage") + } + metadata, err := protocol.ReadMetadata(c.reader) + if err != nil { + return nil, err + } + return &Rows{ + conn: c, + id: rowsMsg.Id, + metadata: metadata, + }, nil +} + +func (c *Connection) ExecuteBatch(request *spannerpb.ExecuteBatchDmlRequest) (*spannerpb.ExecuteBatchDmlResponse, error) { + msg := message.CreateExecuteBatchMessage(request) + if err := msg.Write(c.writer); err != nil { + return nil, err + } + if err := c.writer.Flush(); err != nil { + return nil, err + } + r, err := message.ReadMessageOrError(c.reader) + if err != nil { + return nil, err + } + batchMsg, ok := r.(*message.BatchResultMessage) + if !ok { + return nil, status.Error(codes.Internal, "message type is not BatchResultMessage") + } + return batchMsg.Response, nil +} + +func (c *Connection) Close() error { + if err := c.conn.Close(); err != nil { + return err + } + c.reader = nil + c.writer = nil + c.conn = nil + return nil +} diff --git a/spannerlib/socket-server/client/pool.go b/spannerlib/socket-server/client/pool.go new file mode 100644 index 00000000..80817d1b --- /dev/null +++ b/spannerlib/socket-server/client/pool.go @@ -0,0 +1,57 @@ +package client + +import ( + "bufio" + "net" + + "github.com/google/uuid" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "spannerlib/socket-server/message" +) + +type Pool struct { + tp string + addr string + + id string + dsn string +} + +func CreatePool(tp, addr, dsn string) *Pool { + return CreatePoolWithId(tp, addr, dsn, uuid.New().String()) +} + +func CreatePoolWithId(tp, addr, dsn, id string) *Pool { + return &Pool{tp: tp, addr: addr, id: id, dsn: dsn} +} + +func (p *Pool) CreateConnection() (*Connection, error) { + conn, err := net.Dial(p.tp, p.addr) + if err != nil { + return nil, err + } + connection := &Connection{ + pool: p, + conn: conn, + reader: bufio.NewReader(conn), + writer: bufio.NewWriter(conn), + } + startup := message.CreateStartupMessage(p.id, p.dsn) + if err := startup.Write(connection.writer); err != nil { + return nil, err + } + if err := connection.writer.Flush(); err != nil { + return nil, err + } + if msg, err := message.ReadMessageOrError(connection.reader); err != nil { + _ = connection.Close() + return nil, err + } else { + if _, ok := msg.(*message.StatusMessage); !ok { + return nil, status.Error(codes.Internal, "message type is not StatusMessage") + } + } + + return connection, nil +} diff --git a/spannerlib/socket-server/client/rows.go b/spannerlib/socket-server/client/rows.go new file mode 100644 index 00000000..2a7754ee --- /dev/null +++ b/spannerlib/socket-server/client/rows.go @@ -0,0 +1,51 @@ +package client + +import ( + "cloud.google.com/go/spanner/apiv1/spannerpb" + "google.golang.org/protobuf/types/known/structpb" + "spannerlib/socket-server/protocol" +) + +type Rows struct { + conn *Connection + id int64 + + metadata *spannerpb.ResultSetMetadata + stats *spannerpb.ResultSetStats +} + +func (r *Rows) Next() (*structpb.ListValue, error) { + var hasMoreRows bool + if err := protocol.ReadBool(r.conn.reader, &hasMoreRows); err != nil { + return nil, err + } + if !hasMoreRows { + stats, err := protocol.ReadStats(r.conn.reader) + if err != nil { + return nil, err + } + r.stats = stats + return nil, nil + } + + row, err := protocol.ReadRow(r.conn.reader, r.metadata) + if err != nil { + return nil, err + } + return row, nil +} + +func (r *Rows) Close() error { + if r.stats == nil { + for { + row, err := r.Next() + if err != nil { + return err + } + if row == nil { + break + } + } + } + return nil +} diff --git a/spannerlib/socket-server/message/batch_result.go b/spannerlib/socket-server/message/batch_result.go new file mode 100644 index 00000000..fdfbb90c --- /dev/null +++ b/spannerlib/socket-server/message/batch_result.go @@ -0,0 +1,45 @@ +package message + +import ( + "bufio" + "fmt" + + "cloud.google.com/go/spanner/apiv1/spannerpb" + "spannerlib/socket-server/protocol" +) + +var _ Message = &BatchResultMessage{} + +type BatchResultMessage struct { + message + Response *spannerpb.ExecuteBatchDmlResponse +} + +func CreateBatchResultMessage(res *spannerpb.ExecuteBatchDmlResponse) *BatchResultMessage { + return &BatchResultMessage{ + message: message{messageId: BatchResultMessageId}, + Response: res, + } +} + +func (m *BatchResultMessage) String() string { + return fmt.Sprintf("BatchResultMessage: %v", m.Response) +} + +func (m *BatchResultMessage) MessageId() Id { + return BatchResultMessageId +} + +func (m *BatchResultMessage) Write(writer *bufio.Writer) error { + if err := m.writeHeader(writer); err != nil { + return err + } + if err := protocol.WriteExecuteBatchResponse(writer, m.Response); err != nil { + return err + } + return nil +} + +func (m *BatchResultMessage) Handle(handler Handler) error { + return handler.HandleBatchResult(m) +} diff --git a/spannerlib/socket-server/message/begin.go b/spannerlib/socket-server/message/begin.go new file mode 100644 index 00000000..b8c3fd12 --- /dev/null +++ b/spannerlib/socket-server/message/begin.go @@ -0,0 +1,41 @@ +package message + +import ( + "bufio" + "fmt" + + "cloud.google.com/go/spanner/apiv1/spannerpb" + "spannerlib/socket-server/protocol" +) + +var _ Message = &BeginMessage{} + +type BeginMessage struct { + message + Options *spannerpb.TransactionOptions +} + +func CreateBeginMessage(options *spannerpb.TransactionOptions) *BeginMessage { + return &BeginMessage{ + message: message{messageId: BeginMessageId}, + Options: options, + } +} + +func (m *BeginMessage) String() string { + return fmt.Sprintf("BeginMessage: %v", m.Options) +} + +func (m *BeginMessage) Write(writer *bufio.Writer) error { + if err := m.writeHeader(writer); err != nil { + return err + } + if err := protocol.WriteTransactionOptions(writer, m.Options); err != nil { + return err + } + return nil +} + +func (m *BeginMessage) Handle(handler Handler) error { + return handler.HandleBegin(m) +} diff --git a/spannerlib/socket-server/message/commit.go b/spannerlib/socket-server/message/commit.go new file mode 100644 index 00000000..f6fcd7c6 --- /dev/null +++ b/spannerlib/socket-server/message/commit.go @@ -0,0 +1,33 @@ +package message + +import ( + "bufio" + "fmt" +) + +var _ Message = &CommitMessage{} + +type CommitMessage struct { + message +} + +func CreateCommitMessage() *CommitMessage { + return &CommitMessage{ + message: message{messageId: CommitMessageId}, + } +} + +func (m *CommitMessage) String() string { + return fmt.Sprintf("CommitMessage") +} + +func (m *CommitMessage) Write(writer *bufio.Writer) error { + if err := m.writeHeader(writer); err != nil { + return err + } + return nil +} + +func (m *CommitMessage) Handle(handler Handler) error { + return handler.HandleCommit(m) +} diff --git a/spannerlib/socket-server/message/commit_result.go b/spannerlib/socket-server/message/commit_result.go new file mode 100644 index 00000000..506ba0e0 --- /dev/null +++ b/spannerlib/socket-server/message/commit_result.go @@ -0,0 +1,41 @@ +package message + +import ( + "bufio" + "fmt" + + "cloud.google.com/go/spanner/apiv1/spannerpb" + "spannerlib/socket-server/protocol" +) + +var _ Message = &CommitResultMessage{} + +type CommitResultMessage struct { + message + Response *spannerpb.CommitResponse +} + +func CreateCommitResultMessage(resp *spannerpb.CommitResponse) *CommitResultMessage { + return &CommitResultMessage{ + message: message{messageId: CommitResultMessageId}, + Response: resp, + } +} + +func (m *CommitResultMessage) String() string { + return fmt.Sprintf("CommitResultMessage: %v", m.Response) +} + +func (m *CommitResultMessage) Write(writer *bufio.Writer) error { + if err := m.writeHeader(writer); err != nil { + return err + } + if err := protocol.WriteCommitResponse(writer, m.Response); err != nil { + return err + } + return nil +} + +func (m *CommitResultMessage) Handle(handler Handler) error { + return handler.HandleCommitResult(m) +} diff --git a/spannerlib/socket-server/message/execute.go b/spannerlib/socket-server/message/execute.go new file mode 100644 index 00000000..16aad422 --- /dev/null +++ b/spannerlib/socket-server/message/execute.go @@ -0,0 +1,41 @@ +package message + +import ( + "bufio" + "fmt" + + "cloud.google.com/go/spanner/apiv1/spannerpb" + "spannerlib/socket-server/protocol" +) + +var _ Message = &ExecuteMessage{} + +type ExecuteMessage struct { + message + Request *spannerpb.ExecuteSqlRequest +} + +func CreateExecuteMessage(request *spannerpb.ExecuteSqlRequest) *ExecuteMessage { + return &ExecuteMessage{ + message: message{messageId: ExecuteMessageId}, + Request: request, + } +} + +func (m *ExecuteMessage) String() string { + return fmt.Sprintf("ExecuteMessage: %v", m.Request.Sql) +} + +func (m *ExecuteMessage) Write(writer *bufio.Writer) error { + if err := m.writeHeader(writer); err != nil { + return err + } + if err := protocol.WriteExecuteSqlRequest(writer, m.Request); err != nil { + return err + } + return nil +} + +func (m *ExecuteMessage) Handle(handler Handler) error { + return handler.HandleExecute(m) +} diff --git a/spannerlib/socket-server/message/execute_batch.go b/spannerlib/socket-server/message/execute_batch.go new file mode 100644 index 00000000..9d504575 --- /dev/null +++ b/spannerlib/socket-server/message/execute_batch.go @@ -0,0 +1,41 @@ +package message + +import ( + "bufio" + "fmt" + + "cloud.google.com/go/spanner/apiv1/spannerpb" + "spannerlib/socket-server/protocol" +) + +var _ Message = &ExecuteBatchMessage{} + +type ExecuteBatchMessage struct { + message + Request *spannerpb.ExecuteBatchDmlRequest +} + +func CreateExecuteBatchMessage(request *spannerpb.ExecuteBatchDmlRequest) *ExecuteBatchMessage { + return &ExecuteBatchMessage{ + message: message{messageId: ExecuteBatchMessageId}, + Request: request, + } +} + +func (m *ExecuteBatchMessage) String() string { + return fmt.Sprintf("ExecuteBatchMessage: %v", m.Request) +} + +func (m *ExecuteBatchMessage) Write(writer *bufio.Writer) error { + if err := m.writeHeader(writer); err != nil { + return err + } + if err := protocol.WriteExecuteBatchRequest(writer, m.Request); err != nil { + return err + } + return nil +} + +func (m *ExecuteBatchMessage) Handle(handler Handler) error { + return handler.HandleExecuteBatch(m) +} diff --git a/spannerlib/socket-server/message/handler.go b/spannerlib/socket-server/message/handler.go new file mode 100644 index 00000000..4a2e28a6 --- /dev/null +++ b/spannerlib/socket-server/message/handler.go @@ -0,0 +1,16 @@ +package message + +type Handler interface { + HandleStartup(msg *StartupMessage) error + HandleExecute(msg *ExecuteMessage) error + HandleExecuteBatch(msg *ExecuteBatchMessage) error + HandleRows(msg *RowsMessage) error + HandleBatchResult(msg *BatchResultMessage) error + + HandleBegin(msg *BeginMessage) error + HandleCommit(msg *CommitMessage) error + HandleRollback(msg *RollbackMessage) error + HandleCommitResult(msg *CommitResultMessage) error + + HandleStatus(msg *StatusMessage) error +} diff --git a/spannerlib/socket-server/message/message.go b/spannerlib/socket-server/message/message.go new file mode 100644 index 00000000..28ae37b6 --- /dev/null +++ b/spannerlib/socket-server/message/message.go @@ -0,0 +1,200 @@ +package message + +import ( + "bufio" + + spb "google.golang.org/genproto/googleapis/rpc/status" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + "spannerlib/socket-server/protocol" +) + +type Id byte + +const StatusMessageId Id = '0' +const StartupMessageId Id = 'S' +const ExecuteMessageId Id = 'E' +const ExecuteBatchMessageId Id = 'B' +const RowsMessageId Id = 'R' +const BatchResultMessageId Id = 'A' + +const BeginMessageId Id = 'b' +const CommitMessageId Id = 'c' +const RollbackMessageId Id = 'r' +const CommitResultMessageId Id = 't' + +const CloseMessageId Id = 'X' + +type Message interface { + String() string + MessageId() Id + Write(writer *bufio.Writer) error + + Handle(handler Handler) error +} + +type message struct { + messageId Id +} + +func (m *message) writeHeader(writer *bufio.Writer) error { + if err := writer.WriteByte(byte(m.messageId)); err != nil { + return err + } + return nil +} + +func (m *message) MessageId() Id { + return m.messageId +} + +type CloseMessage struct { + message + RowsId int64 +} + +func ReadMessageOrError(reader *bufio.Reader) (Message, error) { + msg, err := ReadMessage(reader) + if err != nil { + return nil, err + } + if msg.MessageId() != StatusMessageId { + return msg, nil + } + s := msg.(*StatusMessage) + if s.status == nil || s.status.Code == int32(codes.OK) { + return msg, nil + } + return nil, status.ErrorProto(s.status) +} + +func ReadMessage(reader *bufio.Reader) (Message, error) { + var id byte + if err := protocol.ReadByte(reader, &id); err != nil { + return nil, err + } + msg, err := CreateMessage(Id(id), reader) + if err != nil { + return nil, err + } + return msg, nil +} + +func CreateMessage(id Id, reader *bufio.Reader) (Message, error) { + switch id { + case StatusMessageId: + return createStatusMessage(reader) + case StartupMessageId: + return createStartupMessage(reader) + case BeginMessageId: + return createBeginMessage(reader) + case CommitMessageId: + return createCommitMessage(reader) + case RollbackMessageId: + return createRollbackMessage(reader) + case CommitResultMessageId: + return createCommitResultMessage(reader) + case ExecuteMessageId: + return createExecuteMessage(reader) + case ExecuteBatchMessageId: + return createExecuteBatchMessage(reader) + case RowsMessageId: + return createRowsMessage(reader) + case BatchResultMessageId: + return createBatchResultMessage(reader) + default: + return nil, status.Errorf(codes.InvalidArgument, "unknown MessageId: %v", id) + } +} + +func createStatusMessage(reader *bufio.Reader) (*StatusMessage, error) { + msg := &StatusMessage{message: message{messageId: StatusMessageId}} + var b []byte + if err := protocol.ReadBytes(reader, &b); err != nil { + return nil, err + } + msg.status = &spb.Status{} + if err := proto.Unmarshal(b, msg.status); err != nil { + return nil, err + } + return msg, nil +} + +func createStartupMessage(reader *bufio.Reader) (*StartupMessage, error) { + msg := &StartupMessage{message: message{messageId: StartupMessageId}} + if err := protocol.ReadString(reader, &msg.Pool); err != nil { + return nil, err + } + if err := protocol.ReadString(reader, &msg.DSN); err != nil { + return nil, err + } + return msg, nil +} + +func createBeginMessage(reader *bufio.Reader) (*BeginMessage, error) { + msg := &BeginMessage{} + options, err := protocol.ReadTransactionOptions(reader) + if err != nil { + return nil, err + } + msg.Options = options + return msg, nil +} + +func createCommitMessage(reader *bufio.Reader) (*CommitMessage, error) { + return &CommitMessage{}, nil +} + +func createRollbackMessage(reader *bufio.Reader) (*RollbackMessage, error) { + return &RollbackMessage{}, nil +} + +func createCommitResultMessage(reader *bufio.Reader) (*CommitResultMessage, error) { + msg := &CommitResultMessage{} + resp, err := protocol.ReadCommitResponse(reader) + if err != nil { + return nil, err + } + msg.Response = resp + + return msg, nil +} + +func createExecuteMessage(reader *bufio.Reader) (*ExecuteMessage, error) { + msg := &ExecuteMessage{message: message{messageId: ExecuteMessageId}} + request, err := protocol.ReadExecuteSqlRequest(reader) + if err != nil { + return nil, err + } + msg.Request = request + return msg, nil +} + +func createExecuteBatchMessage(reader *bufio.Reader) (*ExecuteBatchMessage, error) { + msg := &ExecuteBatchMessage{message: message{messageId: ExecuteBatchMessageId}} + request, err := protocol.ReadExecuteBatchRequest(reader) + if err != nil { + return nil, err + } + msg.Request = request + return msg, nil +} + +func createRowsMessage(reader *bufio.Reader) (*RowsMessage, error) { + msg := &RowsMessage{message: message{messageId: RowsMessageId}} + if err := protocol.ReadInt64(reader, &msg.Id); err != nil { + return nil, err + } + return msg, nil +} + +func createBatchResultMessage(reader *bufio.Reader) (*BatchResultMessage, error) { + msg := &BatchResultMessage{message: message{messageId: BatchResultMessageId}} + resp, err := protocol.ReadExecuteBatchResponse(reader) + if err != nil { + return nil, err + } + msg.Response = resp + return msg, nil +} diff --git a/spannerlib/socket-server/message/rollback.go b/spannerlib/socket-server/message/rollback.go new file mode 100644 index 00000000..71f007a8 --- /dev/null +++ b/spannerlib/socket-server/message/rollback.go @@ -0,0 +1,33 @@ +package message + +import ( + "bufio" + "fmt" +) + +var _ Message = &RollbackMessage{} + +type RollbackMessage struct { + message +} + +func CreateRollbackMessage() *RollbackMessage { + return &RollbackMessage{ + message: message{messageId: RollbackMessageId}, + } +} + +func (m *RollbackMessage) String() string { + return fmt.Sprintf("RollbackMessage") +} + +func (m *RollbackMessage) Write(writer *bufio.Writer) error { + if err := m.writeHeader(writer); err != nil { + return err + } + return nil +} + +func (m *RollbackMessage) Handle(handler Handler) error { + return handler.HandleRollback(m) +} diff --git a/spannerlib/socket-server/message/rows.go b/spannerlib/socket-server/message/rows.go new file mode 100644 index 00000000..e1b7e3b1 --- /dev/null +++ b/spannerlib/socket-server/message/rows.go @@ -0,0 +1,44 @@ +package message + +import ( + "bufio" + "fmt" + + "spannerlib/socket-server/protocol" +) + +var _ Message = &RowsMessage{} + +type RowsMessage struct { + message + Id int64 +} + +func CreateRowsMessage(id int64) *RowsMessage { + return &RowsMessage{ + message: message{messageId: RowsMessageId}, + Id: id, + } +} + +func (m *RowsMessage) String() string { + return fmt.Sprintf("RowsMessage: %v", m.Id) +} + +func (m *RowsMessage) MessageId() Id { + return RowsMessageId +} + +func (m *RowsMessage) Write(writer *bufio.Writer) error { + if err := m.writeHeader(writer); err != nil { + return err + } + if err := protocol.WriteInt64(writer, m.Id); err != nil { + return err + } + return nil +} + +func (m *RowsMessage) Handle(handler Handler) error { + return handler.HandleRows(m) +} diff --git a/spannerlib/socket-server/message/startup.go b/spannerlib/socket-server/message/startup.go new file mode 100644 index 00000000..d2d48576 --- /dev/null +++ b/spannerlib/socket-server/message/startup.go @@ -0,0 +1,45 @@ +package message + +import ( + "bufio" + "fmt" + + "spannerlib/socket-server/protocol" +) + +var _ Message = &StartupMessage{} + +type StartupMessage struct { + message + Pool string + DSN string +} + +func CreateStartupMessage(pool, dsn string) *StartupMessage { + return &StartupMessage{ + message: message{messageId: StartupMessageId}, + Pool: pool, + DSN: dsn, + } +} + +func (m *StartupMessage) String() string { + return fmt.Sprintf("StartupMessage: %v", m.DSN) +} + +func (m *StartupMessage) Write(writer *bufio.Writer) error { + if err := m.writeHeader(writer); err != nil { + return err + } + if err := protocol.WriteString(writer, m.Pool); err != nil { + return err + } + if err := protocol.WriteString(writer, m.DSN); err != nil { + return err + } + return nil +} + +func (m *StartupMessage) Handle(handler Handler) error { + return handler.HandleStartup(m) +} diff --git a/spannerlib/socket-server/message/status.go b/spannerlib/socket-server/message/status.go new file mode 100644 index 00000000..39bcf5f8 --- /dev/null +++ b/spannerlib/socket-server/message/status.go @@ -0,0 +1,60 @@ +package message + +import ( + "bufio" + "fmt" + + spb "google.golang.org/genproto/googleapis/rpc/status" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + "spannerlib/socket-server/protocol" +) + +var _ Message = &StatusMessage{} +var OK = &StatusMessage{ + message: message{messageId: StatusMessageId}, + status: &spb.Status{Code: int32(codes.OK)}, +} + +type StatusMessage struct { + message + status *spb.Status +} + +func CreateStatusMessage(err error) *StatusMessage { + s, _ := status.FromError(err) + msg := &StatusMessage{ + message: message{messageId: StatusMessageId}, + status: s.Proto(), + } + return msg +} + +func (m *StatusMessage) Err() error { + if m.status == nil { + return nil + } + return status.ErrorProto(m.status) +} + +func (m *StatusMessage) String() string { + return fmt.Sprintf("StatusMessage: %v", m.status) +} + +func (m *StatusMessage) Write(writer *bufio.Writer) error { + if err := m.writeHeader(writer); err != nil { + return err + } + if m.status != nil { + b, _ := proto.Marshal(m.status) + if err := protocol.WriteBytes(writer, b); err != nil { + return err + } + } + return nil +} + +func (m *StatusMessage) Handle(handler Handler) error { + return handler.HandleStatus(m) +} diff --git a/spannerlib/socket-server/protocol/encoding.go b/spannerlib/socket-server/protocol/encoding.go new file mode 100644 index 00000000..3d84ef0f --- /dev/null +++ b/spannerlib/socket-server/protocol/encoding.go @@ -0,0 +1,298 @@ +package protocol + +import ( + "bufio" + "encoding/binary" + "io" + + "cloud.google.com/go/spanner/apiv1/spannerpb" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/structpb" +) + +func StringLength(s string) uint32 { + return 4 + uint32(len([]byte(s))) +} + +func ReadBool(reader *bufio.Reader, b *bool) error { + if err := binary.Read(reader, binary.BigEndian, b); err != nil { + return err + } + return nil +} + +func WriteBool(writer *bufio.Writer, b bool) error { + if err := binary.Write(writer, binary.BigEndian, b); err != nil { + return err + } + return nil +} + +func ReadInt32(reader *bufio.Reader, n *int32) error { + if err := binary.Read(reader, binary.BigEndian, n); err != nil { + return err + } + return nil +} + +func WriteInt32(writer *bufio.Writer, value int32) error { + if err := binary.Write(writer, binary.BigEndian, value); err != nil { + return err + } + return nil +} + +func ReadUInt32(reader *bufio.Reader, n *uint32) error { + if err := binary.Read(reader, binary.BigEndian, n); err != nil { + return err + } + return nil +} + +func WriteUInt32(writer *bufio.Writer, value uint32) error { + if err := binary.Write(writer, binary.BigEndian, value); err != nil { + return err + } + return nil +} + +func ReadInt64(reader *bufio.Reader, n *int64) error { + if err := binary.Read(reader, binary.BigEndian, n); err != nil { + return err + } + return nil +} + +func WriteInt64(writer *bufio.Writer, value int64) error { + if err := binary.Write(writer, binary.BigEndian, value); err != nil { + return err + } + return nil +} + +func ReadFloat64(reader *bufio.Reader, n *float64) error { + if err := binary.Read(reader, binary.BigEndian, n); err != nil { + return err + } + return nil +} + +func WriteFloat64(writer *bufio.Writer, value float64) error { + if err := binary.Write(writer, binary.BigEndian, value); err != nil { + return err + } + return nil +} + +func ReadString(reader *bufio.Reader, s *string) error { + var l uint32 + if err := binary.Read(reader, binary.BigEndian, &l); err != nil { + return err + } + buf := make([]byte, l) + if _, err := io.ReadFull(reader, buf); err != nil { + return err + } + *s = string(buf) + return nil +} + +func WriteString(writer *bufio.Writer, s string) error { + l := uint32(len([]byte(s))) + if err := binary.Write(writer, binary.BigEndian, l); err != nil { + return err + } + if _, err := writer.WriteString(s); err != nil { + return err + } + return nil +} + +func ReadByte(reader *bufio.Reader, b *byte) error { + if err := binary.Read(reader, binary.BigEndian, b); err != nil { + return err + } + return nil +} + +func WriteByte(writer *bufio.Writer, b byte) error { + if err := binary.Write(writer, binary.BigEndian, b); err != nil { + return err + } + return nil +} + +func ReadBytes(reader *bufio.Reader, b *[]byte) error { + var l uint32 + if err := binary.Read(reader, binary.BigEndian, &l); err != nil { + return err + } + buf := make([]byte, l) + if _, err := io.ReadFull(reader, buf); err != nil { + return err + } + *b = buf + return nil +} + +func WriteBytes(writer *bufio.Writer, b []byte) error { + l := uint32(len(b)) + if err := binary.Write(writer, binary.BigEndian, l); err != nil { + return err + } + if _, err := writer.Write(b); err != nil { + return err + } + return nil +} + +func ReadStats(reader *bufio.Reader) (*spannerpb.ResultSetStats, error) { + c := &spannerpb.ResultSetStats_RowCountExact{} + if err := binary.Read(reader, binary.BigEndian, &c.RowCountExact); err != nil { + return nil, err + } + return &spannerpb.ResultSetStats{RowCount: c}, nil +} + +func WriteStats(writer *bufio.Writer, stats *spannerpb.ResultSetStats) error { + var c int64 + switch stats.GetRowCount().(type) { + case *spannerpb.ResultSetStats_RowCountExact: + c = stats.GetRowCountExact() + case *spannerpb.ResultSetStats_RowCountLowerBound: + c = stats.GetRowCountLowerBound() + } + if err := binary.Write(writer, binary.BigEndian, c); err != nil { + return err + } + return nil +} + +func ReadRow(reader *bufio.Reader, metadata *spannerpb.ResultSetMetadata) (*structpb.ListValue, error) { + row := &structpb.ListValue{} + row.Values = make([]*structpb.Value, len(metadata.RowType.Fields)) + for i := range metadata.RowType.Fields { + value, err := ReadValue(reader) + if err != nil { + return nil, err + } + row.Values[i] = value + } + return row, nil +} + +func WriteRow(writer *bufio.Writer, values *structpb.ListValue) error { + for _, value := range values.Values { + if err := WriteValue(writer, value); err != nil { + return err + } + } + return nil +} + +type ValueType byte + +const ( + ValueTypeNull ValueType = iota + ValueTypeBool + ValueTypeNumber + ValueTypeString + ValueTypeList + ValueTypeStruct +) + +func ReadValue(reader *bufio.Reader) (*structpb.Value, error) { + var tp byte + if err := ReadByte(reader, &tp); err != nil { + return nil, err + } + + valueType := ValueType(tp) + switch valueType { + case ValueTypeNull: + return structpb.NewNullValue(), nil + case ValueTypeBool: + var b bool + if err := binary.Read(reader, binary.BigEndian, &b); err != nil { + return nil, err + } + return structpb.NewBoolValue(b), nil + case ValueTypeNumber: + var f float64 + if err := binary.Read(reader, binary.BigEndian, &f); err != nil { + return nil, err + } + return structpb.NewNumberValue(f), nil + case ValueTypeString: + var s string + if err := ReadString(reader, &s); err != nil { + return nil, err + } + return structpb.NewStringValue(s), nil + case ValueTypeList: + var num int32 + if err := binary.Read(reader, binary.BigEndian, &num); err != nil { + return nil, err + } + lv := &structpb.ListValue{Values: make([]*structpb.Value, num)} + for i := 0; i < int(num); i++ { + v, err := ReadValue(reader) + if err != nil { + return nil, err + } + lv.Values[i] = v + } + return structpb.NewListValue(lv), nil + case ValueTypeStruct: + return nil, status.Errorf(codes.Unimplemented, "struct type not yet implemented") + default: + return nil, status.Errorf(codes.Internal, "unknown type: %v", valueType) + } +} + +func WriteValue(writer *bufio.Writer, value *structpb.Value) error { + switch value.Kind.(type) { + case *structpb.Value_NullValue: + if err := WriteByte(writer, 0); err != nil { + return err + } + case *structpb.Value_BoolValue: + if err := WriteByte(writer, 1); err != nil { + return err + } + if err := WriteBool(writer, value.GetBoolValue()); err != nil { + return err + } + case *structpb.Value_NumberValue: + if err := WriteByte(writer, 2); err != nil { + return err + } + if err := WriteFloat64(writer, value.GetNumberValue()); err != nil { + return err + } + case *structpb.Value_StringValue: + if err := WriteByte(writer, 3); err != nil { + return err + } + if err := WriteString(writer, value.GetStringValue()); err != nil { + return err + } + case *structpb.Value_ListValue: + if err := WriteByte(writer, 4); err != nil { + return err + } + // Write the number of values and then each value. + if err := WriteInt32(writer, int32(len(value.GetListValue().GetValues()))); err != nil { + return err + } + for _, value := range value.GetListValue().GetValues() { + if err := WriteValue(writer, value); err != nil { + return err + } + } + case *structpb.Value_StructValue: + return status.Errorf(codes.Unimplemented, "struct value not yet supported") + } + return nil +} diff --git a/spannerlib/socket-server/protocol/execute_batch.go b/spannerlib/socket-server/protocol/execute_batch.go new file mode 100644 index 00000000..eff01a70 --- /dev/null +++ b/spannerlib/socket-server/protocol/execute_batch.go @@ -0,0 +1,123 @@ +package protocol + +import ( + "bufio" + + "cloud.google.com/go/spanner/apiv1/spannerpb" + spb "google.golang.org/genproto/googleapis/rpc/status" + "google.golang.org/grpc/codes" + "google.golang.org/protobuf/proto" +) + +func ReadExecuteBatchRequest(reader *bufio.Reader) (*spannerpb.ExecuteBatchDmlRequest, error) { + req := &spannerpb.ExecuteBatchDmlRequest{} + var numStatements int32 + if err := ReadInt32(reader, &numStatements); err != nil { + return nil, err + } + req.Statements = make([]*spannerpb.ExecuteBatchDmlRequest_Statement, numStatements) + for i := int32(0); i < numStatements; i++ { + statement := &spannerpb.ExecuteBatchDmlRequest_Statement{} + if err := ReadString(reader, &statement.Sql); err != nil { + return nil, err + } + params, err := ReadParams(reader) + if err != nil { + return nil, err + } + statement.Params = params + paramTypes, err := ReadParamTypes(reader) + if err != nil { + return nil, err + } + statement.ParamTypes = paramTypes + req.Statements[i] = statement + } + return req, nil +} + +func WriteExecuteBatchRequest(writer *bufio.Writer, request *spannerpb.ExecuteBatchDmlRequest) error { + if err := WriteInt32(writer, int32(len(request.Statements))); err != nil { + return err + } + for _, statement := range request.Statements { + if err := WriteString(writer, statement.Sql); err != nil { + return err + } + if err := WriteParams(writer, statement.Params); err != nil { + return err + } + if err := WriteParamTypes(writer, statement.ParamTypes); err != nil { + return err + } + } + return nil +} + +func ReadExecuteBatchResponse(reader *bufio.Reader) (*spannerpb.ExecuteBatchDmlResponse, error) { + res := &spannerpb.ExecuteBatchDmlResponse{} + var code int32 + if err := ReadInt32(reader, &code); err != nil { + return nil, err + } + if code != int32(codes.OK) { + var b []byte + if err := ReadBytes(reader, &b); err != nil { + return nil, err + } + var status *spb.Status + if err := proto.Unmarshal(b, status); err != nil { + return nil, err + } + res.Status = status + } + var numStatements int32 + if err := ReadInt32(reader, &numStatements); err != nil { + return nil, err + } + res.ResultSets = make([]*spannerpb.ResultSet, numStatements) + for i := int32(0); i < numStatements; i++ { + var c int64 + if err := ReadInt64(reader, &c); err != nil { + return nil, err + } + res.ResultSets[i] = &spannerpb.ResultSet{ + Stats: &spannerpb.ResultSetStats{ + RowCount: &spannerpb.ResultSetStats_RowCountExact{ + RowCountExact: c, + }, + }, + } + } + return res, nil +} + +func WriteExecuteBatchResponse(writer *bufio.Writer, response *spannerpb.ExecuteBatchDmlResponse) error { + if response.Status != nil { + if err := WriteInt32(writer, response.Status.Code); err != nil { + return err + } + if response.Status.Code != int32(codes.OK) { + s, err := proto.Marshal(response.Status) + if err != nil { + return err + } + if err := WriteBytes(writer, s); err != nil { + return err + } + } + } else { + if err := WriteInt32(writer, int32(codes.OK)); err != nil { + return err + } + } + if err := WriteInt32(writer, int32(len(response.ResultSets))); err != nil { + return err + } + for _, resultSet := range response.ResultSets { + if err := WriteInt64(writer, resultSet.Stats.GetRowCountExact()); err != nil { + return err + } + } + return nil +} diff --git a/spannerlib/socket-server/protocol/execute_sql.go b/spannerlib/socket-server/protocol/execute_sql.go new file mode 100644 index 00000000..0754b7ec --- /dev/null +++ b/spannerlib/socket-server/protocol/execute_sql.go @@ -0,0 +1,38 @@ +package protocol + +import ( + "bufio" + + "cloud.google.com/go/spanner/apiv1/spannerpb" +) + +func ReadExecuteSqlRequest(reader *bufio.Reader) (*spannerpb.ExecuteSqlRequest, error) { + req := &spannerpb.ExecuteSqlRequest{} + if err := ReadString(reader, &req.Sql); err != nil { + return nil, err + } + params, err := ReadParams(reader) + if err != nil { + return nil, err + } + req.Params = params + paramTypes, err := ReadParamTypes(reader) + if err != nil { + return nil, err + } + req.ParamTypes = paramTypes + return req, nil +} + +func WriteExecuteSqlRequest(writer *bufio.Writer, request *spannerpb.ExecuteSqlRequest) error { + if err := WriteString(writer, request.Sql); err != nil { + return err + } + if err := WriteParams(writer, request.Params); err != nil { + return err + } + if err := WriteParamTypes(writer, request.ParamTypes); err != nil { + return err + } + return nil +} diff --git a/spannerlib/socket-server/protocol/metadata.go b/spannerlib/socket-server/protocol/metadata.go new file mode 100644 index 00000000..38e9b801 --- /dev/null +++ b/spannerlib/socket-server/protocol/metadata.go @@ -0,0 +1,52 @@ +package protocol + +import ( + "bufio" + "encoding/binary" + + "cloud.google.com/go/spanner/apiv1/spannerpb" +) + +func ReadMetadata(reader *bufio.Reader) (*spannerpb.ResultSetMetadata, error) { + metadata := &spannerpb.ResultSetMetadata{} + var numFields int32 + if err := binary.Read(reader, binary.BigEndian, &numFields); err != nil { + return nil, err + } + metadata.RowType = &spannerpb.StructType{} + metadata.RowType.Fields = make([]*spannerpb.StructType_Field, numFields) + for i := 0; i < int(numFields); i++ { + metadata.RowType.Fields[i] = &spannerpb.StructType_Field{Type: &spannerpb.Type{}} + var code int32 + if err := binary.Read(reader, binary.BigEndian, &code); err != nil { + return nil, err + } + metadata.RowType.Fields[i].Type.Code = spannerpb.TypeCode(code) + if code == int32(spannerpb.TypeCode_ARRAY) { + var elementCode int32 + if err := binary.Read(reader, binary.BigEndian, &elementCode); err != nil { + return nil, err + } + metadata.RowType.Fields[i].Type.ArrayElementType = &spannerpb.Type{Code: spannerpb.TypeCode(elementCode)} + } + if err := ReadString(reader, &metadata.RowType.Fields[i].Name); err != nil { + return nil, err + } + } + return metadata, nil +} + +func WriteMetadata(writer *bufio.Writer, metadata *spannerpb.ResultSetMetadata) error { + if err := WriteInt32(writer, int32(len(metadata.RowType.Fields))); err != nil { + return err + } + for _, field := range metadata.RowType.Fields { + if err := WriteType(writer, field.Type); err != nil { + return err + } + if err := WriteString(writer, field.Name); err != nil { + return err + } + } + return nil +} diff --git a/spannerlib/socket-server/protocol/params.go b/spannerlib/socket-server/protocol/params.go new file mode 100644 index 00000000..a46fa29f --- /dev/null +++ b/spannerlib/socket-server/protocol/params.go @@ -0,0 +1,97 @@ +package protocol + +import ( + "bufio" + + "cloud.google.com/go/spanner/apiv1/spannerpb" + "google.golang.org/protobuf/types/known/structpb" +) + +func ReadParams(reader *bufio.Reader) (*structpb.Struct, error) { + params := &structpb.Struct{} + var numParams int32 + if err := ReadInt32(reader, &numParams); err != nil { + return nil, err + } + if numParams > 0 { + params = &structpb.Struct{} + params.Fields = map[string]*structpb.Value{} + for i := 0; i < int(numParams); i++ { + var name string + if err := ReadString(reader, &name); err != nil { + return nil, err + } + value, err := ReadValue(reader) + if err != nil { + return nil, err + } + params.Fields[name] = value + } + } + return params, nil +} + +func WriteParams(writer *bufio.Writer, params *structpb.Struct) error { + if params != nil { + if err := WriteInt32(writer, int32(len(params.Fields))); err != nil { + return err + } + for key, value := range params.Fields { + if err := WriteString(writer, key); err != nil { + return err + } + if err := WriteValue(writer, value); err != nil { + return err + } + } + } else { + if err := WriteInt32(writer, int32(0)); err != nil { + return err + } + } + return nil +} + +func ReadParamTypes(reader *bufio.Reader) (map[string]*spannerpb.Type, error) { + paramTypes := map[string]*spannerpb.Type{} + var numParamTypes int32 + if err := ReadInt32(reader, &numParamTypes); err != nil { + return nil, err + } + if numParamTypes > 0 { + paramTypes = map[string]*spannerpb.Type{} + for i := 0; i < int(numParamTypes); i++ { + var name string + if err := ReadString(reader, &name); err != nil { + return nil, err + } + tp, err := ReadType(reader) + if err != nil { + return nil, err + } + paramTypes[name] = tp + } + } + return paramTypes, nil +} + +func WriteParamTypes(writer *bufio.Writer, paramTypes map[string]*spannerpb.Type) error { + if paramTypes != nil { + if err := WriteInt32(writer, int32(len(paramTypes))); err != nil { + return err + } + for key, tp := range paramTypes { + if err := WriteString(writer, key); err != nil { + return err + } + if err := WriteType(writer, tp); err != nil { + return err + } + } + } else { + if err := WriteInt32(writer, int32(0)); err != nil { + return err + } + } + return nil +} diff --git a/spannerlib/socket-server/protocol/transaction.go b/spannerlib/socket-server/protocol/transaction.go new file mode 100644 index 00000000..9d81d414 --- /dev/null +++ b/spannerlib/socket-server/protocol/transaction.go @@ -0,0 +1,72 @@ +package protocol + +import ( + "bufio" + + "cloud.google.com/go/spanner/apiv1/spannerpb" + "google.golang.org/protobuf/proto" +) + +func ReadTransactionOptions(reader *bufio.Reader) (*spannerpb.TransactionOptions, error) { + var b []byte + if err := ReadBytes(reader, &b); err != nil { + return nil, err + } + if len(b) == 0 { + return nil, nil + } + options := &spannerpb.TransactionOptions{} + if err := proto.Unmarshal(b, options); err != nil { + return nil, err + } + return options, nil +} + +func WriteTransactionOptions(writer *bufio.Writer, options *spannerpb.TransactionOptions) error { + if options == nil { + if err := WriteBytes(writer, []byte{}); err != nil { + return err + } + return nil + } + b, err := proto.Marshal(options) + if err != nil { + return err + } + if err := WriteBytes(writer, b); err != nil { + return err + } + return nil +} + +func ReadCommitResponse(reader *bufio.Reader) (*spannerpb.CommitResponse, error) { + var b []byte + if err := ReadBytes(reader, &b); err != nil { + return nil, err + } + if len(b) == 0 { + return nil, nil + } + resp := &spannerpb.CommitResponse{} + if err := proto.Unmarshal(b, resp); err != nil { + return nil, err + } + return resp, nil +} + +func WriteCommitResponse(writer *bufio.Writer, response *spannerpb.CommitResponse) error { + if response == nil { + if err := WriteBytes(writer, []byte{}); err != nil { + return err + } + return nil + } + b, err := proto.Marshal(response) + if err != nil { + return err + } + if err := WriteBytes(writer, b); err != nil { + return err + } + return nil +} diff --git a/spannerlib/socket-server/protocol/type.go b/spannerlib/socket-server/protocol/type.go new file mode 100644 index 00000000..e83f3732 --- /dev/null +++ b/spannerlib/socket-server/protocol/type.go @@ -0,0 +1,36 @@ +package protocol + +import ( + "bufio" + + "cloud.google.com/go/spanner/apiv1/spannerpb" +) + +func ReadType(reader *bufio.Reader) (*spannerpb.Type, error) { + tp := &spannerpb.Type{} + var code int32 + if err := ReadInt32(reader, &code); err != nil { + return nil, err + } + tp.Code = spannerpb.TypeCode(code) + if code == int32(spannerpb.TypeCode_ARRAY) { + elementType, err := ReadType(reader) + if err != nil { + return nil, err + } + tp.ArrayElementType = elementType + } + return tp, nil +} + +func WriteType(writer *bufio.Writer, tp *spannerpb.Type) error { + if err := WriteInt32(writer, int32(tp.Code)); err != nil { + return err + } + if tp.Code == spannerpb.TypeCode_ARRAY { + if err := WriteInt32(writer, int32(tp.ArrayElementType.Code)); err != nil { + return err + } + } + return nil +} diff --git a/spannerlib/socket-server/server.go b/spannerlib/socket-server/server.go new file mode 100644 index 00000000..0c269c1e --- /dev/null +++ b/spannerlib/socket-server/server.go @@ -0,0 +1,51 @@ +package main + +import ( + "log" + "net" + "os" + "os/signal" + "syscall" + + "spannerlib/socket-server/server" +) + +func main() { + if len(os.Args) < 2 { + log.Fatalf("Missing server address\n") + } + name := os.Args[1] + tp := "unix" + if len(os.Args) > 2 { + tp = os.Args[2] + } + // + if tp == "unix" { + defer func() { _ = os.Remove(name) }() + // Set up a channel to listen for OS signals that terminate the process, + // so we can clean up the temp file in those cases as well. + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) + go func() { + // Wait for a signal. + <-sigs + // Delete the temp file. + _ = os.Remove(name) + os.Exit(0) + }() + } + + listener, err := net.Listen(tp, name) + if err != nil { + log.Fatalf("failed to listen: %v\n", err) + } + defer func() { _ = listener.Close() }() + log.Printf("Starting server on %s\n", listener.Addr().String()) + s, err := server.CreateServer() + if err != nil { + log.Fatalf("failed to create server: %v\n", err) + } + if err := s.Serve(listener); err != nil { + log.Fatalf("failed to serve: %v\n", err) + } +} diff --git a/spannerlib/socket-server/server/connection_handler.go b/spannerlib/socket-server/server/connection_handler.go new file mode 100644 index 00000000..abfa5c53 --- /dev/null +++ b/spannerlib/socket-server/server/connection_handler.go @@ -0,0 +1,46 @@ +package server + +import ( + "bufio" + "log" + "net" + + "spannerlib/socket-server/message" +) + +type ConnectionHandler struct { + server *Server + conn net.Conn + + reader *bufio.Reader + writer *bufio.Writer + + poolId int64 + connId int64 + + handler *messageHandler +} + +func (c *ConnectionHandler) handleConnection() { + for { + msg, err := message.ReadMessage(c.reader) + if err != nil { + log.Printf("Error reading message from client: %v", err) + break + } + log.Printf("Received message: %s\n", msg) + if err := msg.Handle(c.handler); err != nil { + log.Printf("error handling message: %v\n", err) + resp := message.CreateStatusMessage(err) + if err := resp.Write(c.writer); err != nil { + log.Printf("error writing response: %v\n", err) + break + } + } + if err := c.writer.Flush(); err != nil { + log.Printf("error flushing response: %v\n", err) + break + } + } + _ = c.conn.Close() +} diff --git a/spannerlib/socket-server/server/message_handler.go b/spannerlib/socket-server/server/message_handler.go new file mode 100644 index 00000000..cf5af91a --- /dev/null +++ b/spannerlib/socket-server/server/message_handler.go @@ -0,0 +1,162 @@ +package server + +import ( + "context" + + "spannerlib/api" + "spannerlib/socket-server/message" + "spannerlib/socket-server/protocol" +) + +var _ message.Handler = &messageHandler{} + +type messageHandler struct { + conn *ConnectionHandler +} + +func (h *messageHandler) HandleStatus(msg *message.StatusMessage) error { + return nil +} + +func (h *messageHandler) HandleStartup(msg *message.StartupMessage) error { + ctx := context.Background() + var err error + + p, ok := h.conn.server.pools.Load(msg.Pool) + if !ok { + h.conn.poolId, err = api.CreatePool(ctx, msg.DSN) + if err != nil { + return err + } + h.conn.server.pools.Store(msg.Pool, h.conn.poolId) + } else { + h.conn.poolId = p.(int64) + } + h.conn.connId, err = api.CreateConnection(ctx, h.conn.poolId) + if err != nil { + return err + } + + if err := message.OK.Write(h.conn.writer); err != nil { + return err + } + + return nil +} + +func (h *messageHandler) HandleBegin(msg *message.BeginMessage) error { + ctx := context.Background() + + if err := api.BeginTransaction(ctx, h.conn.poolId, h.conn.connId, msg.Options); err != nil { + return err + } + if err := message.OK.Write(h.conn.writer); err != nil { + return err + } + return nil +} + +func (h *messageHandler) HandleCommit(msg *message.CommitMessage) error { + ctx := context.Background() + + resp, err := api.Commit(ctx, h.conn.poolId, h.conn.connId) + if err != nil { + return err + } + result := message.CreateCommitResultMessage(resp) + if err := result.Write(h.conn.writer); err != nil { + return err + } + + return nil +} + +func (h *messageHandler) HandleRollback(msg *message.RollbackMessage) error { + ctx := context.Background() + + if err := api.Rollback(ctx, h.conn.poolId, h.conn.connId); err != nil { + return err + } + if err := message.OK.Write(h.conn.writer); err != nil { + return err + } + + return nil +} + +func (h *messageHandler) HandleCommitResult(msg *message.CommitResultMessage) error { + return nil +} + +func (h *messageHandler) HandleExecute(msg *message.ExecuteMessage) error { + ctx := context.Background() + + rows, err := api.Execute(ctx, h.conn.poolId, h.conn.connId, msg.Request) + if err != nil { + return err + } + defer func() { _ = api.CloseRows(ctx, h.conn.poolId, h.conn.connId, rows) }() + + rowsMsg := message.CreateRowsMessage(rows) + if err := rowsMsg.Write(h.conn.writer); err != nil { + return err + } + metadata, err := api.Metadata(ctx, h.conn.poolId, h.conn.connId, rows) + if err != nil { + return err + } + if err := protocol.WriteMetadata(h.conn.writer, metadata); err != nil { + return err + } + + for { + values, err := api.Next(ctx, h.conn.poolId, h.conn.connId, rows) + if err != nil { + return err + } + if values == nil { + if err := protocol.WriteBool(h.conn.writer, false); err != nil { + return err + } + break + } + if err := protocol.WriteBool(h.conn.writer, true); err != nil { + return err + } + if err := protocol.WriteRow(h.conn.writer, values); err != nil { + return err + } + } + stats, err := api.ResultSetStats(ctx, h.conn.poolId, h.conn.connId, rows) + if err != nil { + return err + } + if err := protocol.WriteStats(h.conn.writer, stats); err != nil { + return err + } + + return nil +} + +func (h *messageHandler) HandleExecuteBatch(msg *message.ExecuteBatchMessage) error { + ctx := context.Background() + + res, err := api.ExecuteBatch(ctx, h.conn.poolId, h.conn.connId, msg.Request) + if err != nil { + return err + } + batchResultMsg := message.CreateBatchResultMessage(res) + if err := batchResultMsg.Write(h.conn.writer); err != nil { + return err + } + + return nil +} + +func (h *messageHandler) HandleRows(msg *message.RowsMessage) error { + return nil +} + +func (h *messageHandler) HandleBatchResult(msg *message.BatchResultMessage) error { + return nil +} diff --git a/spannerlib/socket-server/server/server.go b/spannerlib/socket-server/server/server.go new file mode 100644 index 00000000..d1d3e4ea --- /dev/null +++ b/spannerlib/socket-server/server/server.go @@ -0,0 +1,50 @@ +package server + +import ( + "bufio" + "net" + "sync" +) + +func CreateServer() (*Server, error) { + return &Server{ + handlers: make([]*ConnectionHandler, 0), + pools: &sync.Map{}, + }, nil +} + +type Server struct { + listener net.Listener + handlers []*ConnectionHandler + + pools *sync.Map +} + +func (s *Server) GracefulStop() { + if s.listener != nil { + _ = s.listener.Close() + s.listener = nil + } +} + +func (s *Server) Serve(listener net.Listener) error { + s.listener = listener + for { + conn, err := listener.Accept() + if err != nil { + return err + } + + connectionHandler := &ConnectionHandler{ + server: s, + conn: conn, + reader: bufio.NewReader(conn), + writer: bufio.NewWriter(conn), + } + connectionHandler.handler = &messageHandler{ + conn: connectionHandler, + } + s.handlers = append(s.handlers, connectionHandler) + go connectionHandler.handleConnection() + } +} diff --git a/spannerlib/socket-server/server/server_test.go b/spannerlib/socket-server/server/server_test.go new file mode 100644 index 00000000..cfb44fd5 --- /dev/null +++ b/spannerlib/socket-server/server/server_test.go @@ -0,0 +1,198 @@ +package server + +import ( + "fmt" + "net" + "os" + "path/filepath" + "runtime" + "testing" + + "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" + "cloud.google.com/go/spanner/apiv1/spannerpb" + "github.com/google/uuid" + "github.com/googleapis/go-sql-spanner/testutil" + "spannerlib/socket-server/client" +) + +func TestConnect(t *testing.T) { + t.Parallel() + + server, teardown := setupMockSpannerServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + tp, addr, cleanup := startTestSpannerLibServer(t) + defer cleanup() + + pool := client.CreatePool(tp, addr, dsn) + conn, err := pool.CreateConnection() + if err != nil { + t.Fatal(err) + } + if err := conn.Close(); err != nil { + t.Fatal(err) + } +} + +func TestExecute(t *testing.T) { + t.Parallel() + + server, teardown := setupMockSpannerServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + tp, addr, cleanup := startTestSpannerLibServer(t) + defer cleanup() + + pool := client.CreatePool(tp, addr, dsn) + conn, err := pool.CreateConnection() + if err != nil { + t.Fatal(err) + } + defer func() { _ = conn.Close() }() + rows, err := conn.Execute(&spannerpb.ExecuteSqlRequest{Sql: testutil.SelectFooFromBar}) + if err != nil { + t.Fatal(err) + } + c := 0 + for { + row, err := rows.Next() + if err != nil { + t.Fatal(err) + } + if row == nil { + break + } + c++ + if g, w := len(row.Values), 1; g != w { + t.Fatalf("col count mismatch\n Got: %v\n Want: %v", g, w) + } + } + if g, w := c, 2; g != w { + t.Fatalf("row count mismatch\n Got: %d\n Want: %d", g, w) + } +} + +func TestExecuteBatch(t *testing.T) { + t.Parallel() + + server, teardown := setupMockSpannerServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + tp, addr, cleanup := startTestSpannerLibServer(t) + defer cleanup() + + pool := client.CreatePool(tp, addr, dsn) + conn, err := pool.CreateConnection() + if err != nil { + t.Fatal(err) + } + defer func() { _ = conn.Close() }() + + result, err := conn.ExecuteBatch(&spannerpb.ExecuteBatchDmlRequest{ + Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{ + {Sql: testutil.UpdateBarSetFoo}, + {Sql: testutil.UpdateBarSetFoo}, + }, + }) + if err != nil { + t.Fatal(err) + } + if g, w := len(result.ResultSets), 2; g != w { + t.Fatalf("result count mismatch\n Got: %v\n Want: %v", g, w) + } + for _, result := range result.ResultSets { + if g, w := result.Stats.GetRowCountExact(), int64(testutil.UpdateBarSetFooRowCount); g != w { + t.Fatalf("update count mismatch\n Got: %v\n Want: %v", g, w) + } + } +} + +func TestTransaction(t *testing.T) { + t.Parallel() + + server, teardown := setupMockSpannerServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + tp, addr, cleanup := startTestSpannerLibServer(t) + defer cleanup() + + pool := client.CreatePool(tp, addr, dsn) + conn, err := pool.CreateConnection() + if err != nil { + t.Fatal(err) + } + defer func() { _ = conn.Close() }() + + for _, commit := range []bool{true, false} { + if err := conn.Begin(&spannerpb.TransactionOptions{}); err != nil { + t.Fatal(err) + } + rows, err := conn.Execute(&spannerpb.ExecuteSqlRequest{Sql: testutil.UpdateBarSetFoo}) + if err != nil { + t.Fatal(err) + } + row, err := rows.Next() + if err != nil { + t.Fatal(err) + } + if row != nil { + t.Fatal("expected nil row") + } + + if commit { + resp, err := conn.Commit() + if err != nil { + t.Fatal(err) + } + if resp == nil || resp.CommitTimestamp == nil { + t.Fatal("missing commit timestamp") + } + } else { + if err := conn.Rollback(); err != nil { + t.Fatal(err) + } + } + } +} + +func startTestSpannerLibServer(t *testing.T) (tp, addr string, cleanup func()) { + var name string + if runtime.GOOS == "windows" { + tp = "tcp" + name = "localhost:0" + } else { + tp = "unix" + name = filepath.Join(os.TempDir(), fmt.Sprintf("spannerlib-%s", uuid.NewString())) + } + lis, err := net.Listen(tp, name) + if err != nil { + t.Fatalf("failed to listen: %v\n", err) + } + addr = lis.Addr().String() + server, err := CreateServer() + if err != nil { + t.Fatalf("failed to create server: %v\n", err) + } + go func() { _ = server.Serve(lis) }() + + cleanup = func() { + server.GracefulStop() + _ = os.Remove(name) + } + + return +} + +func setupMockSpannerServer(t *testing.T) (server *testutil.MockedSpannerInMemTestServer, teardown func()) { + return setupMockSpannerServerWithDialect(t, databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL) +} + +func setupMockSpannerServerWithDialect(t *testing.T, dialect databasepb.DatabaseDialect) (server *testutil.MockedSpannerInMemTestServer, teardown func()) { + server, _, serverTeardown := testutil.NewMockedSpannerInMemTestServer(t) + server.SetupSelectDialectResult(dialect) + return server, serverTeardown +} diff --git a/spannerlib/wrappers/spannerlib-dotnet/build-socket-server.sh b/spannerlib/wrappers/spannerlib-dotnet/build-socket-server.sh new file mode 100755 index 00000000..3626e95c --- /dev/null +++ b/spannerlib/wrappers/spannerlib-dotnet/build-socket-server.sh @@ -0,0 +1,18 @@ +# Builds the socket server binary for darwin/arm64, linux/x64, and windows/x64 +# and copies the binaries to the appropriate folders of the .NET wrapper. + +cd ../../socket-server || exit 1 +./build-executables.sh +cd ../wrappers/spannerlib-dotnet || exit 1 + +mkdir -p spannerlib-dotnet-socket-server/binaries/any +rm spannerlib-dotnet-socket-server/binaries/any/spannerlib_socket_server 2> /dev/null + +mkdir -p spannerlib-dotnet-socket-server/binaries/osx-arm64 +cp ../../socket-server/binaries/osx-arm64/spannerlib_socket_server spannerlib-dotnet-socket-server/binaries/osx-arm64/spannerlib_socket_server + +mkdir -p spannerlib-dotnet-socket-server/binaries/linux-x64 +cp ../../socket-server/binaries/linux-x64/spannerlib_socket_server spannerlib-dotnet-socket-server/binaries/linux-x64/spannerlib_socket_server + +mkdir -p spannerlib-dotnet-socket-server/binaries/win-x64 +cp ../../socket-server/binaries/win-x64/spannerlib_socket_server.exe spannerlib-dotnet-socket-server/binaries/win-x64/spannerlib_socket_server.exe diff --git a/spannerlib/wrappers/spannerlib-dotnet/build.sh b/spannerlib/wrappers/spannerlib-dotnet/build.sh index 2e71a82d..8d27979d 100755 --- a/spannerlib/wrappers/spannerlib-dotnet/build.sh +++ b/spannerlib/wrappers/spannerlib-dotnet/build.sh @@ -48,6 +48,10 @@ dotnet nuget locals global-packages --clear echo "Building gRPC server..." ./build-grpc-server.sh +# Build socket server +echo "Building socket server..." +./build-socket-server.sh + # Build shared library echo "Building shared library..." ./build-shared-lib.sh @@ -76,8 +80,17 @@ else mkdir -p "$PWD"/spannerlib-dotnet-grpc-server/bin/Release dotnet nuget add source "$PWD"/spannerlib-dotnet-grpc-server/bin/Release --name local-grpc-server-build fi +dotnet nuget remove source local-socket-server-build 2>/dev/null +if [ "$RUNNER_OS" == "Windows" ]; then + # PWD does not work on Windows + mkdir -p "${GITHUB_WORKSPACE}\spannerlib\wrappers\spannerlib-dotnet\spannerlib-dotnet-socket-server\bin\Release" + dotnet nuget add source "${GITHUB_WORKSPACE}\spannerlib\wrappers\spannerlib-dotnet\spannerlib-dotnet-socket-server\bin\Release" --name local-socket-server-build +else + mkdir -p "$PWD"/spannerlib-dotnet-socket-server/bin/Release + dotnet nuget add source "$PWD"/spannerlib-dotnet-socket-server/bin/Release --name local-socket-server-build +fi -# Create packages for the two components that contain the binaries (shared library + gRPC server) +# Create packages for the two components that contain the binaries (shared library + gRPC server + socket server) find ./**/bin/Release -type f -name "Alpha*.nupkg" -exec rm {} \; cd spannerlib-dotnet-native || exit 1 dotnet pack @@ -85,6 +98,9 @@ cd .. || exit 1 cd spannerlib-dotnet-grpc-server || exit 1 dotnet pack cd .. || exit 1 +cd spannerlib-dotnet-socket-server || exit 1 +dotnet pack +cd .. || exit 1 # Restore the packages of all the projects so they pick up the locally built packages. dotnet restore diff --git a/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server-impl/ConnectionImpl.cs b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server-impl/ConnectionImpl.cs new file mode 100644 index 00000000..cab372b7 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server-impl/ConnectionImpl.cs @@ -0,0 +1,69 @@ +using System; +using System.Net.Sockets; +using System.Threading; +using System.Threading.Tasks; +using Google.Cloud.SpannerLib.SocketServer.Message; + +namespace Google.Cloud.SpannerLib.SocketServer; + +internal class ConnectionImpl : Connection +{ + private readonly PoolImpl _pool; + + internal NetworkStream Stream { get; } + + internal static ConnectionImpl Create(PoolImpl pool) + { + NetworkStream? stream = null; + try + { + var socket = pool.Spanner.CreateSocket(); + stream = new NetworkStream(socket); + ExecuteStartup(stream, pool); + return new ConnectionImpl(pool, stream); + } + catch (Exception) + { + stream?.Close(); + throw; + } + } + + internal static async Task CreateAsync(PoolImpl pool, CancellationToken cancellationToken) + { + NetworkStream? stream = null; + try + { + Socket socket = await pool.Spanner.CreateSocketAsync().ConfigureAwait(false); + stream = new NetworkStream(socket); + await ExecuteStartupAsync(stream, pool, cancellationToken).ConfigureAwait(false); + return new ConnectionImpl(pool, stream); + } + catch (Exception) + { + stream?.Close(); + throw; + } + } + + private ConnectionImpl(PoolImpl pool, NetworkStream stream) : base(pool, 0) + { + _pool = pool; + Stream = stream; + } + + private static void ExecuteStartup(NetworkStream stream, PoolImpl pool) + { + var startup = new StartupMessage(pool); + startup.Write(stream); + stream.Flush(); + Message.Message.ReadStatusMessage(stream); + } + + private static Task ExecuteStartupAsync(NetworkStream stream, PoolImpl pool, CancellationToken cancellationToken) + { + var startup = new StartupMessage(pool); + return startup.WriteAsync(stream, cancellationToken); + } + +} \ No newline at end of file diff --git a/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server-impl/Message/Message.cs b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server-impl/Message/Message.cs new file mode 100644 index 00000000..3f7df4fb --- /dev/null +++ b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server-impl/Message/Message.cs @@ -0,0 +1,47 @@ +using System; +using System.Buffers.Binary; +using System.IO; +using System.Net.Sockets; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Google.Rpc; + +namespace Google.Cloud.SpannerLib.SocketServer.Message; + +internal abstract class Message +{ + internal const byte StatusMessageId = (byte) '0'; + internal const byte StartupMessageId = (byte) 'S'; + + internal static Message ReadMessage(NetworkStream stream) + { + var id = stream.ReadByte(); + if (id == -1) + { + throw new EndOfStreamException(); + } + + var messageId = (byte)id; + return messageId switch + { + StatusMessageId => StatusMessage.Read(stream), + _ => throw new InvalidOperationException("Unknown message id: " + messageId) + }; + } + + internal static void ReadStatusMessage(NetworkStream stream) + { + var message = ReadMessage(stream); + if (message is StatusMessage statusMessage) + { + if (statusMessage.Status.Code == (int) Code.Ok) + { + return; + } + throw new SpannerException(statusMessage.Status); + } + throw new InvalidOperationException("Unexpected message type: " + message.GetType().Name); + } + +} \ No newline at end of file diff --git a/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server-impl/Message/RowsMessage.cs b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server-impl/Message/RowsMessage.cs new file mode 100644 index 00000000..7541f976 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server-impl/Message/RowsMessage.cs @@ -0,0 +1,6 @@ +namespace Google.Cloud.SpannerLib.SocketServer.Message; + +internal class RowsMessage : Message +{ + +} \ No newline at end of file diff --git a/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server-impl/Message/StartupMessage.cs b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server-impl/Message/StartupMessage.cs new file mode 100644 index 00000000..5c72e002 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server-impl/Message/StartupMessage.cs @@ -0,0 +1,32 @@ +using System; +using System.Net.Sockets; +using System.Threading; +using System.Threading.Tasks; + +namespace Google.Cloud.SpannerLib.SocketServer.Message; + +internal class StartupMessage : Message +{ + private static readonly byte MessageId = (byte)'S'; + + private PoolImpl Pool { get; } + + internal StartupMessage(PoolImpl pool) + { + Pool = pool; + } + + internal void Write(NetworkStream stream) + { + stream.WriteByte(MessageId); + WriteString(stream, Pool.GeneratedId); + WriteString(stream, Pool.ConnectionString); + } + + internal async Task WriteAsync(NetworkStream stream, CancellationToken cancellationToken) + { + stream.WriteByte(MessageId); + await WriteStringAsync(stream, Pool.GeneratedId, cancellationToken).ConfigureAwait(false); + await WriteStringAsync(stream, Pool.ConnectionString, cancellationToken).ConfigureAwait(false); + } +} diff --git a/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server-impl/Message/StatusMessage.cs b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server-impl/Message/StatusMessage.cs new file mode 100644 index 00000000..99f425f3 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server-impl/Message/StatusMessage.cs @@ -0,0 +1,21 @@ +using System.Net.Sockets; +using Google.Rpc; + +namespace Google.Cloud.SpannerLib.SocketServer.Message; + +internal class StatusMessage : Message +{ + internal Status Status { get; } + + internal static StatusMessage Read(NetworkStream stream) + { + var bytes = ReadBytes(stream); + var status = Status.Parser.ParseFrom(bytes); + return new StatusMessage(status); + } + + internal StatusMessage(Status status) + { + Status = status; + } +} \ No newline at end of file diff --git a/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server-impl/PoolImpl.cs b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server-impl/PoolImpl.cs new file mode 100644 index 00000000..459258c8 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server-impl/PoolImpl.cs @@ -0,0 +1,18 @@ +using System; + +namespace Google.Cloud.SpannerLib.SocketServer; + +internal class PoolImpl : Google.Cloud.SpannerLib.Pool +{ + internal SocketLibSpanner Spanner { get; } + + internal string GeneratedId { get; } = Guid.NewGuid().ToString(); + + internal string ConnectionString { get; } + + internal PoolImpl(SocketLibSpanner spanner, long id, string connectionString) : base(spanner, id) + { + Spanner = spanner; + ConnectionString = connectionString; + } +} \ No newline at end of file diff --git a/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server-impl/Protocol/Encoding.cs b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server-impl/Protocol/Encoding.cs new file mode 100644 index 00000000..554f18ca --- /dev/null +++ b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server-impl/Protocol/Encoding.cs @@ -0,0 +1,110 @@ +using System; +using System.Buffers.Binary; +using System.IO; +using System.Net.Sockets; +using System.Threading; +using System.Threading.Tasks; + +namespace Google.Cloud.SpannerLib.SocketServer.Protocol; + +internal static class Encoding +{ + + private static void ReadExactly(NetworkStream stream, byte[] buffer) + { + ReadExactly(stream, buffer, 0, buffer.Length); + } + + private static void ReadExactly(NetworkStream stream, byte[] buffer, int offset, int count) + { + var bytesRead = 0; + while (bytesRead < count) + { + var read = stream.Read(buffer, offset + bytesRead, count - bytesRead); + if (read == 0) + { + // The remote host has closed the connection prematurely + throw new EndOfStreamException("The remote host closed the connection before all data was received."); + } + bytesRead += read; + } + } + + internal static byte[] ReadBytes(NetworkStream stream) + { + var length = ReadInt(stream); + var buffer = new byte[length]; + ReadExactly(stream, buffer); + return buffer; + } + + internal static void WriteBytes(NetworkStream stream, byte[] bytes) + { + WriteInt(stream, bytes.Length); + stream.Write(bytes, 0, bytes.Length); + } + + internal static Task WriteBytesAsync(NetworkStream stream, byte[] bytes, CancellationToken cancellationToken) + { + WriteInt(stream, bytes.Length); + return stream.WriteAsync(bytes, 0, bytes.Length, cancellationToken); + } + + internal static void WriteString(NetworkStream stream, string str) + { + var bytes = System.Text.Encoding.UTF8.GetBytes(str); + WriteInt(stream, bytes.Length); + stream.Write(bytes, 0, bytes.Length); + } + + internal static Task WriteStringAsync(NetworkStream stream, string str, CancellationToken cancellationToken) + { + var bytes = System.Text.Encoding.UTF8.GetBytes(str); + WriteInt(stream, bytes.Length); + return stream.WriteAsync(bytes, 0, bytes.Length, cancellationToken); + } + + internal static int ReadInt(NetworkStream stream) + { + var buffer = new byte[4]; + ReadExactly(stream, buffer); + return BinaryPrimitives.ReadInt32BigEndian(buffer); + } + + internal static void WriteInt(NetworkStream stream, int value) + { + Span buffer = stackalloc byte[4]; + BinaryPrimitives.WriteInt32BigEndian(buffer, value); + stream.Write(buffer); + } + + internal static long ReadLong(NetworkStream stream) + { + var buffer = new byte[8]; + ReadExactly(stream, buffer); + return BinaryPrimitives.ReadInt64BigEndian(buffer); + } + + internal static void WriteLong(NetworkStream stream, long value) + { + Span buffer = stackalloc byte[8]; + BinaryPrimitives.WriteInt64BigEndian(buffer, value); + stream.Write(buffer); + } + + internal static bool ReadBool(NetworkStream stream) + { + var b = stream.ReadByte(); + if (b == -1) + { + throw new EndOfStreamException(); + } + return b != 0; + } + + internal static void WriteBool(NetworkStream stream, bool value) + { + stream.WriteByte((byte)(value ? 1 : 0)); + } + +} \ No newline at end of file diff --git a/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server-impl/RowsImpl.cs b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server-impl/RowsImpl.cs new file mode 100644 index 00000000..cf71c307 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server-impl/RowsImpl.cs @@ -0,0 +1,48 @@ +using System.Net.Sockets; +using Google.Cloud.Spanner.V1; +using Google.Cloud.SpannerLib.SocketServer.Protocol; +using Google.Protobuf.WellKnownTypes; + +namespace Google.Cloud.SpannerLib.SocketServer; + +internal class RowsImpl : Rows +{ + internal RowsImpl Create(ConnectionImpl connection) + { + var stream = connection.Stream; + var id = Encoding.ReadLong(stream); + var metadataBytes = Encoding.ReadBytes(stream); + var metadata = ResultSetMetadata.Parser.ParseFrom(metadataBytes); + return new RowsImpl(connection, id, metadata); + } + + private readonly ConnectionImpl _connection; + + private readonly NetworkStream _stream; + + public override ResultSetMetadata? Metadata { get; } + + private ResultSetStats _stats; + + internal RowsImpl(ConnectionImpl connection, long id, ResultSetMetadata metadata) : base(connection, id) + { + _connection = connection; + _stream = connection.Stream; + Metadata = metadata; + } + + public override ListValue? Next() + { + var hasMoreRows = Encoding.ReadBool(_stream); + if (!hasMoreRows) + { + + } + } + + private ResultSetStats ReadStats() + { + var bytes = Encoding.ReadBytes(_stream); + return ResultSetStats.Parser.ParseFrom(bytes); + } +} \ No newline at end of file diff --git a/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server-impl/SocketLibSpanner.cs b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server-impl/SocketLibSpanner.cs new file mode 100644 index 00000000..9c9a7c42 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server-impl/SocketLibSpanner.cs @@ -0,0 +1,334 @@ +using System; +using System.Collections.Generic; +using System.Net.Sockets; +using System.Threading; +using System.Threading.Tasks; +using Google.Api.Gax; +using Google.Cloud.Spanner.V1; +using Google.Protobuf.WellKnownTypes; +using Grpc.Core; + +namespace Google.Cloud.SpannerLib.SocketServer; + +public sealed class SocketLibSpanner : ISpannerLib +{ + private readonly Server _server; + private readonly string _address; + private bool _disposed; + private readonly Dictionary _pools = new(); + private readonly Dictionary _connections = new(); + + public SocketLibSpanner(Server.AddressType addressType = Server.AddressType.UnixDomainSocket) + { + _server = new Server(); + _address = _server.Start(addressType: addressType); + } + + ~SocketLibSpanner() => Dispose(false); + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + private void Dispose(bool disposing) + { + if (_disposed) + { + return; + } + try + { + _server.Dispose(); + } + finally + { + _disposed = true; + } + } + + internal Socket CreateSocket() + { + var socket = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified); + var endpoint = new UnixDomainSocketEndPoint(_address); + socket.Connect(endpoint); + return socket; + } + + internal async Task CreateSocketAsync() + { + var socket = new Socket(AddressFamily.Unix, SocketType.Stream, ProtocolType.Unspecified); + var endpoint = new UnixDomainSocketEndPoint(_address); + await socket.ConnectAsync(endpoint).ConfigureAwait(false); + return socket; + } + + public Pool CreatePool(string connectionString) + { + var pool = new PoolImpl(this, 0, connectionString); + _pools.Add(pool, pool); + return pool; + } + + public void ClosePool(Pool pool) + { + _pools.Remove(pool); + } + + public Connection CreateConnection(Pool pool) + { + var connection = ConnectionImpl.Create(_pools[pool]); + _connections.Add(connection, connection); + return connection; + } + + public void CloseConnection(Connection connection) + { + _connections.Remove(connection); + } + + public Task CloseConnectionAsync(Connection connection, CancellationToken cancellationToken = default) + { + _connections.Remove(connection); + return Task.CompletedTask; + } + + public CommitResponse? WriteMutations(Connection connection, BatchWriteRequest.Types.MutationGroup mutations) + { + throw new NotImplementedException(); + } + + public Task WriteMutationsAsync(Connection connection, + BatchWriteRequest.Types.MutationGroup mutations, CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + + public Rows Execute(Connection connection, ExecuteSqlRequest statement, int prefetchRows = 0) + { + return ExecuteStreaming(connection, statement, prefetchRows); + } + + public async Task ExecuteAsync(Connection connection, ExecuteSqlRequest statement, int prefetchRows = 0, CancellationToken cancellationToken = default) + { + try + { + return await ExecuteStreamingAsync(connection, statement, prefetchRows, cancellationToken).ConfigureAwait(false); + } + catch (RpcException exception) + { + throw SpannerException.ToSpannerException(exception); + } + } + + private async Task ExecuteStreamingAsync(Connection connection, ExecuteSqlRequest statement, int prefetchRows, CancellationToken cancellationToken = default) + { + var client = _clients[Random.Shared.Next(_clients.Length)]; + var stream = TranslateException(() => client.ExecuteStreaming(new ExecuteRequest + { + Connection = ToProto(connection), + ExecuteSqlRequest = statement, + FetchOptions = new FetchOptions + { + NumRows = prefetchRows, + }, + })); + return await StreamingRows.CreateAsync(this, connection, stream, cancellationToken).ConfigureAwait(false); + } + + internal AsyncServerStreamingCall ContinueStreamingAsync(Connection connection, long rowsId, CancellationToken cancellationToken) + { + var client = _clients[Random.Shared.Next(_clients.Length)]; + return TranslateException(() => client.ContinueStreaming(new V1.Rows + { + Connection = ToProto(connection), + Id = rowsId, + }, cancellationToken: cancellationToken)); + } + + public long[] ExecuteBatch(Connection connection, ExecuteBatchDmlRequest statements) + { + var response = TranslateException(() => Client.ExecuteBatch(new ExecuteBatchRequest + { + Connection = ToProto(connection), + ExecuteBatchDmlRequest = statements, + })); + return ISpannerLib.ToUpdateCounts(response); + } + + public async Task ExecuteBatchAsync(Connection connection, ExecuteBatchDmlRequest statements, CancellationToken cancellationToken = default) + { + try + { + var response = await Client.ExecuteBatchAsync(new ExecuteBatchRequest + { + Connection = ToProto(connection), + ExecuteBatchDmlRequest = statements, + }, cancellationToken: cancellationToken).ConfigureAwait(false); + return ISpannerLib.ToUpdateCounts(response); + } + catch (RpcException exception) + { + throw SpannerException.ToSpannerException(exception); + } + } + + public ResultSetMetadata? Metadata(Rows rows) + { + return TranslateException(() => Client.Metadata(ToProto(rows))); + } + + public async Task MetadataAsync(Rows rows, CancellationToken cancellationToken = default) + { + try + { + return await Client.MetadataAsync(ToProto(rows), cancellationToken: cancellationToken).ConfigureAwait(false); + } + catch (RpcException exception) + { + throw SpannerException.ToSpannerException(exception); + } + } + + public ResultSetMetadata? NextResultSet(Rows rows) + { + return TranslateException(() => Client.NextResultSet(ToProto(rows))); + } + + public async Task NextResultSetAsync(Rows rows, CancellationToken cancellationToken = default) + { + try + { + return await Client.NextResultSetAsync(ToProto(rows), cancellationToken: cancellationToken).ConfigureAwait(false); + } + catch (RpcException exception) + { + throw SpannerException.ToSpannerException(exception); + } + } + + public ResultSetStats? Stats(Rows rows) + { + return TranslateException(() => Client.ResultSetStats(ToProto(rows))); + } + + public ListValue? Next(Rows rows, int numRows, ISpannerLib.RowEncoding encoding) + { + var row = TranslateException(() =>Client.Next(new NextRequest + { + Rows = ToProto(rows), + FetchOptions = new FetchOptions + { + NumRows = numRows, + Encoding = (long) encoding, + }, + })); + return row.Values.Count == 0 ? null : row; + } + + public async Task NextAsync(Rows rows, int numRows, ISpannerLib.RowEncoding encoding, CancellationToken cancellationToken = default) + { + try + { + return await Client.NextAsync(new NextRequest + { + Rows = ToProto(rows), + FetchOptions = new FetchOptions + { + NumRows = numRows, + Encoding = (long) encoding, + }, + }, cancellationToken: cancellationToken).ConfigureAwait(false); + } + catch (RpcException exception) + { + throw SpannerException.ToSpannerException(exception); + } + } + + public void CloseRows(Rows rows) + { + TranslateException(() => Client.CloseRows(ToProto(rows))); + } + + public async Task CloseRowsAsync(Rows rows, CancellationToken cancellationToken = default) + { + try + { + await Client.CloseRowsAsync(ToProto(rows), cancellationToken: cancellationToken).ConfigureAwait(false); + } + catch (RpcException exception) + { + throw SpannerException.ToSpannerException(exception); + } + } + + public void BeginTransaction(Connection connection, TransactionOptions transactionOptions) + { + TranslateException(() => Client.BeginTransaction(new BeginTransactionRequest + { + Connection = ToProto(connection), + TransactionOptions = transactionOptions, + })); + } + + public async Task BeginTransactionAsync(Connection connection, TransactionOptions transactionOptions, CancellationToken cancellationToken = default) + { + await TranslateException(() => Client.BeginTransactionAsync(new BeginTransactionRequest + { + Connection = ToProto(connection), + TransactionOptions = transactionOptions, + })).ConfigureAwait(false); + } + + public CommitResponse? Commit(Connection connection) + { + var response = TranslateException(() => Client.Commit(ToProto(connection))); + return response.CommitTimestamp == null ? null : response; + } + + public async Task CommitAsync(Connection connection, CancellationToken cancellationToken = default) + { + try + { + var response = await Client.CommitAsync(ToProto(connection), cancellationToken: cancellationToken).ConfigureAwait(false); + return response.CommitTimestamp == null ? null : response; + } + catch (RpcException exception) + { + throw new SpannerException(new Status { Code = (int) exception.Status.StatusCode, Message = exception.Status.Detail }); + } + } + + public void Rollback(Connection connection) + { + TranslateException(() => Client.Rollback(ToProto(connection))); + } + + public async Task RollbackAsync(Connection connection, CancellationToken cancellationToken = default) + { + try + { + await Client.RollbackAsync(ToProto(connection), cancellationToken: cancellationToken).ConfigureAwait(false); + } + catch (RpcException exception) + { + throw SpannerException.ToSpannerException(exception); + } + } + + private PoolImpl FromProto(V1.Pool pool) => new(this, pool.Id); + + private static V1.Pool ToProto(PoolImpl poolImpl) => new() { Id = poolImpl.Id }; + + private Connection FromProto(PoolImpl poolImpl, V1.Connection proto) => + _communicationStyle == CommunicationStyle.ServerStreaming + ? new Connection(poolImpl, proto.Id) + : new GrpcBidiConnection(this, poolImpl, proto.Id); + + internal static V1.Connection ToProto(Connection connection) => new() { Id = connection.Id, Pool = ToProto(connection.Pool), }; + + private V1.Rows ToProto(Rows rows) => new() { Id = rows.Id, Connection = ToProto(rows.SpannerConnection), }; + +} \ No newline at end of file diff --git a/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server-impl/spannerlib-dotnet-socket-server-impl.csproj b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server-impl/spannerlib-dotnet-socket-server-impl.csproj new file mode 100644 index 00000000..d4d611af --- /dev/null +++ b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server-impl/spannerlib-dotnet-socket-server-impl.csproj @@ -0,0 +1,22 @@ + + + + netstandard2.1 + Google.Cloud.SpannerLib.SocketServer + enable + default + Alpha.Google.Cloud.SpannerLib.SocketServerImpl + 1.0.0-alpha.20251217171434 + Google + + + + + + + + + + + + diff --git a/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server/Server.cs b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server/Server.cs new file mode 100644 index 00000000..9237227c --- /dev/null +++ b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server/Server.cs @@ -0,0 +1,195 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Diagnostics; +using System.IO; +using System.Runtime.InteropServices; +using System.Threading; + +namespace Google.Cloud.SpannerLib.SocketServer; + +public class Server : IDisposable +{ + public enum AddressType + { + UnixDomainSocket, + Tcp, + } + + private Process? _process; + private string? _host; + private bool _disposed; + + public bool IsRunning => _process is { HasExited: false }; + + public Server() + { + } + + public string Start(AddressType addressType = AddressType.UnixDomainSocket) + { + if (_disposed) + { + throw new ObjectDisposedException(nameof(Server)); + } + if (IsRunning) + { + throw new InvalidOperationException("The server is already started."); + } + (_host, _process) = StartServer(addressType, TimeSpan.FromSeconds(5)); + return _host; + } + + private static Tuple StartServer(AddressType addressType, TimeSpan timeout) + { + string arguments; + if (addressType == AddressType.UnixDomainSocket) + { + // Generate a random temp file name that will be used for the Unix domain socket communication. + arguments = Path.GetTempPath() + Guid.NewGuid(); + } + else if (addressType == AddressType.Tcp) + { + arguments = "localhost:0 tcp"; + } + else + { + arguments = "localhost:0 tcp"; + } + + var binaryFileName = GetBinaryFileName().Replace('/', Path.DirectorySeparatorChar); + var info = new ProcessStartInfo + { + Arguments = arguments, + UseShellExecute = false, + FileName = binaryFileName, + RedirectStandardOutput = true, + RedirectStandardError = true, + }; + // Start the process as a child process. The process will automatically stop when the + // parent process stops. + var process = Process.Start(info); + if (process == null) + { + throw new InvalidOperationException("Failed to start spanner"); + } + if (addressType == AddressType.UnixDomainSocket) + { + var watch = new Stopwatch(); + while (!File.Exists(arguments)) + { + if (watch.Elapsed > timeout) + { + throw new TimeoutException($"Attempt to start server timed out after {timeout}"); + } + Thread.Sleep(1); + } + } + if (addressType == AddressType.UnixDomainSocket) + { + // Return the name of the Unix domain socket. + return Tuple.Create(arguments, process); + } + // Read the dynamically assigned port. + var address = process.StandardError.ReadLine(); + if (address?.Contains("Starting server on") ?? false) + { + var lastSpace = address.LastIndexOf(" ", StringComparison.Ordinal); + return Tuple.Create(address.Substring(lastSpace + 1), process); + } + throw new InvalidOperationException("Failed to read server address"); + } + + private static string GetBinaryFileName() + { + string? fileName = null; + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + switch (RuntimeInformation.OSArchitecture) + { + case Architecture.X64: + fileName = "runtimes/win-x64/native/spannerlib_socket_server.exe"; + break; + } + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + switch (RuntimeInformation.OSArchitecture) + { + case Architecture.X64: + fileName = "runtimes/linux-x64/native/spannerlib_socket_server"; + break; + } + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + switch (RuntimeInformation.ProcessArchitecture) + { + case Architecture.Arm64: + fileName = "runtimes/osx-arm64/native/spannerlib_socket_server"; + break; + } + } + if (fileName != null && File.Exists(fileName)) + { + return fileName; + } + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + if (File.Exists("runtimes/any/native/spannerlib_socket_server.exe")) + { + return "runtimes/any/native/spannerlib_socket_server.exe"; + } + } + if (File.Exists("runtimes/any/native/spannerlib_socket_server")) + { + return "runtimes/any/native/spannerlib_socket_server"; + } + + throw new PlatformNotSupportedException(); + } + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + public void Stop() + { + if (_process == null || _process.HasExited) + { + return; + } + _process.Kill(); + } + + protected virtual void Dispose(bool disposing) + { + if (_disposed) + { + return; + } + try + { + Stop(); + _process?.Dispose(); + } + finally + { + _disposed = true; + } + } +} \ No newline at end of file diff --git a/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server/spannerlib-dotnet-socket-server.csproj b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server/spannerlib-dotnet-socket-server.csproj new file mode 100644 index 00000000..ef666409 --- /dev/null +++ b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-socket-server/spannerlib-dotnet-socket-server.csproj @@ -0,0 +1,25 @@ + + + + netstandard2.1 + Google.Cloud.SpannerLib.SocketServer + enable + default + Alpha.Google.Cloud.SpannerLib.SocketServer + SpannerLib Socket Server + Google + 1.0.0-alpha.20251217171434 + + + + + + + + + + + + + + diff --git a/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet.sln b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet.sln index c12c60ca..11bfdf2d 100644 --- a/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet.sln +++ b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet.sln @@ -18,6 +18,10 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "spannerlib-dotnet-grpc-impl EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "spannerlib-dotnet-grpc-v1", "spannerlib-dotnet-grpc-v1\spannerlib-dotnet-grpc-v1.csproj", "{C2538D0C-6544-4B44-88A7-02517B786FFD}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "spannerlib-dotnet-socket-server", "spannerlib-dotnet-socket-server\spannerlib-dotnet-socket-server.csproj", "{C343D513-6F3D-4AB1-9951-C967C568BCC6}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "spannerlib-dotnet-socket-server-impl", "spannerlib-dotnet-socket-server-impl\spannerlib-dotnet-socket-server-impl.csproj", "{90E2B68A-9133-4F36-8E41-479E1F24C5BE}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -60,5 +64,13 @@ Global {C2538D0C-6544-4B44-88A7-02517B786FFD}.Debug|Any CPU.Build.0 = Debug|Any CPU {C2538D0C-6544-4B44-88A7-02517B786FFD}.Release|Any CPU.ActiveCfg = Release|Any CPU {C2538D0C-6544-4B44-88A7-02517B786FFD}.Release|Any CPU.Build.0 = Release|Any CPU + {C343D513-6F3D-4AB1-9951-C967C568BCC6}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {C343D513-6F3D-4AB1-9951-C967C568BCC6}.Debug|Any CPU.Build.0 = Debug|Any CPU + {C343D513-6F3D-4AB1-9951-C967C568BCC6}.Release|Any CPU.ActiveCfg = Release|Any CPU + {C343D513-6F3D-4AB1-9951-C967C568BCC6}.Release|Any CPU.Build.0 = Release|Any CPU + {90E2B68A-9133-4F36-8E41-479E1F24C5BE}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {90E2B68A-9133-4F36-8E41-479E1F24C5BE}.Debug|Any CPU.Build.0 = Debug|Any CPU + {90E2B68A-9133-4F36-8E41-479E1F24C5BE}.Release|Any CPU.ActiveCfg = Release|Any CPU + {90E2B68A-9133-4F36-8E41-479E1F24C5BE}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection EndGlobal