Skip to content

Commit f97369d

Browse files
committed
implement WriteStateBytes
1 parent a032b33 commit f97369d

File tree

2 files changed

+83
-0
lines changed

2 files changed

+83
-0
lines changed

tfprotov6/state_store.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,16 @@ type ReadStateBytesStream struct {
5656
Chunks iter.Seq[ReadStateByteChunk]
5757
}
5858

59+
type WriteStateBytesStream struct {
60+
Chunks iter.Seq[WriteStateByteChunk]
61+
}
62+
63+
type WriteStateBytesResponse struct {
64+
Diagnostics []*Diagnostic
65+
}
66+
67+
type WriteStateByteChunk = StateByteChunk
68+
5969
type ReadStateByteChunk struct {
6070
StateByteChunk
6171
Diagnostics []*Diagnostic

tfprotov6/tf6server/server.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"encoding/json"
99
"errors"
1010
"fmt"
11+
"io"
1112
"os"
1213
"os/signal"
1314
"regexp"
@@ -1624,6 +1625,78 @@ func (s *server) ReadStateBytes(protoReq *tfplugin6.ReadStateBytes_Request, prot
16241625
return nil
16251626
}
16261627

1628+
func (s *server) WriteStateBytes(srv grpc.ClientStreamingServer[tfplugin6.WriteStateBytes_RequestChunk, tfplugin6.WriteStateBytes_Response]) error {
1629+
rpc := "WriteStateBytes"
1630+
ctx := srv.Context()
1631+
ctx = s.loggingContext(ctx)
1632+
ctx = logging.RpcContext(ctx, rpc)
1633+
// ctx = logging.StateStoreContext(ctx, protoReq.TypeName)
1634+
ctx = s.stoppableContext(ctx)
1635+
// logging.ProtocolTrace(ctx, "Received request")
1636+
// defer logging.ProtocolTrace(ctx, "Served request")
1637+
1638+
ctx = tf6serverlogging.DownstreamRequest(ctx)
1639+
1640+
server, ok := s.downstream.(tfprotov6.StateStoreServer)
1641+
if !ok {
1642+
err := status.Error(codes.Unimplemented, "ProviderServer does not implement WriteStateBytes")
1643+
logging.ProtocolError(ctx, err.Error())
1644+
return err
1645+
}
1646+
1647+
iterator := func(yield func(tfprotov6.WriteStateByteChunk) bool) {
1648+
for {
1649+
chunk, err := srv.Recv()
1650+
if err == io.EOF {
1651+
break
1652+
}
1653+
if err != nil {
1654+
// attempt to send the error back to client
1655+
msgErr := srv.SendMsg(&tfplugin6.WriteStateBytes_Response{
1656+
Diagnostics: toproto.Diagnostics([]*tfprotov6.Diagnostic{
1657+
{
1658+
Severity: tfprotov6.DiagnosticSeverityError,
1659+
Summary: "Writing state chunk failed",
1660+
Detail: fmt.Sprintf("Attempt to write a byte chunk of state %q to %q failed: %s",
1661+
chunk.StateId, chunk.TypeName, err),
1662+
},
1663+
}),
1664+
})
1665+
if msgErr != nil {
1666+
err := status.Error(codes.Unimplemented, "ProviderServer does not implement WriteStateBytes")
1667+
logging.ProtocolError(ctx, err.Error())
1668+
return
1669+
}
1670+
return
1671+
}
1672+
1673+
ok := yield(tfprotov6.WriteStateByteChunk{
1674+
Bytes: chunk.Bytes,
1675+
TotalLength: chunk.TotalLength,
1676+
Range: tfprotov6.StateByteRange{
1677+
Start: chunk.Range.Start,
1678+
End: chunk.Range.End,
1679+
},
1680+
})
1681+
if !ok {
1682+
return
1683+
}
1684+
1685+
}
1686+
}
1687+
1688+
resp, err := server.WriteStateBytes(ctx, &tfprotov6.WriteStateBytesStream{
1689+
Chunks: iterator,
1690+
})
1691+
if err != nil {
1692+
return err
1693+
}
1694+
1695+
return srv.SendAndClose(&tfplugin6.WriteStateBytes_Response{
1696+
Diagnostics: toproto.Diagnostics(resp.Diagnostics),
1697+
})
1698+
}
1699+
16271700
func (s *server) GetStates(ctx context.Context, protoReq *tfplugin6.GetStates_Request) (*tfplugin6.GetStates_Response, error) {
16281701
rpc := "GetStates"
16291702
ctx = s.loggingContext(ctx)

0 commit comments

Comments
 (0)