Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion pkg/blobstore/grpcclients/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ go_library(
"fsac_blob_access.go",
"icas_blob_access.go",
"iscc_blob_access.go",
"zstd_config.go",
"zstd_pool.go",
],
importpath = "github.com/buildbarn/bb-storage/pkg/blobstore/grpcclients",
visibility = ["//visibility:public"],
Expand All @@ -28,12 +30,16 @@ go_library(
"@org_golang_google_grpc//codes",
"@org_golang_google_grpc//metadata",
"@org_golang_google_grpc//status",
"@org_golang_x_sync//semaphore",
],
)

go_test(
name = "grpcclients_test",
srcs = ["cas_blob_access_test.go"],
srcs = [
"cas_blob_access_test.go",
"zstd_pool_test.go",
],
deps = [
":grpcclients",
"//internal/mock",
Expand Down
129 changes: 103 additions & 26 deletions pkg/blobstore/grpcclients/cas_blob_access.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package grpcclients

import (
"context"
"errors"
"io"
"slices"
"sync"
Expand All @@ -14,7 +15,6 @@ import (
"github.com/buildbarn/bb-storage/pkg/digest"
"github.com/buildbarn/bb-storage/pkg/util"
"github.com/google/uuid"
"github.com/klauspost/compress/zstd"

"google.golang.org/genproto/googleapis/bytestream"
"google.golang.org/grpc"
Expand All @@ -23,6 +23,37 @@ import (
"google.golang.org/grpc/status"
)

// DefaultZstdPool is a shared pool for ZSTD encoders/decoders.
// It is initialized with default settings on first use.
// This can be overridden via SetDefaultZstdPool for custom configurations.
var (
defaultZstdPool *BoundedZstdPool
defaultZstdPoolOnce sync.Once
)

// GetDefaultZstdPool returns the shared ZSTD pool, initializing it if needed.
func GetDefaultZstdPool() *BoundedZstdPool {
defaultZstdPoolOnce.Do(func() {
defaultZstdPool = DefaultZstdPoolConfig().NewPool()
})
return defaultZstdPool
}

// SetDefaultZstdPool allows overriding the default pool before first use.
// This should be called during initialization, before any CAS operations.
// It panics if called after the pool has already been initialized.
func SetDefaultZstdPool(pool *BoundedZstdPool) {
// Check if already initialized
var alreadyInit bool
defaultZstdPoolOnce.Do(func() {
defaultZstdPool = pool
alreadyInit = false
})
if alreadyInit {
panic("SetDefaultZstdPool called after pool was already initialized")
}
}

type casBlobAccess struct {
byteStreamClient bytestream.ByteStreamClient
contentAddressableStorageClient remoteexecution.ContentAddressableStorageClient
Expand All @@ -31,6 +62,7 @@ type casBlobAccess struct {
readChunkSize int
enableZSTDCompression bool
supportedCompressors atomic.Pointer[[]remoteexecution.Compressor_Value]
zstdPool *BoundedZstdPool
}

// NewCASBlobAccess creates a BlobAccess handle that relays any requests
Expand All @@ -42,13 +74,23 @@ type casBlobAccess struct {
// If enableZSTDCompression is true, the client will use ZSTD compression
// for ByteStream operations if the server supports it.
func NewCASBlobAccess(client grpc.ClientConnInterface, uuidGenerator util.UUIDGenerator, readChunkSize int, enableZSTDCompression bool) blobstore.BlobAccess {
return NewCASBlobAccessWithPool(client, uuidGenerator, readChunkSize, enableZSTDCompression, nil)
}

// NewCASBlobAccessWithPool creates a BlobAccess handle with a custom ZSTD pool.
// If pool is nil, the default shared pool will be used.
func NewCASBlobAccessWithPool(client grpc.ClientConnInterface, uuidGenerator util.UUIDGenerator, readChunkSize int, enableZSTDCompression bool, pool *BoundedZstdPool) blobstore.BlobAccess {
if pool == nil {
pool = GetDefaultZstdPool()
}
return &casBlobAccess{
byteStreamClient: bytestream.NewByteStreamClient(client),
contentAddressableStorageClient: remoteexecution.NewContentAddressableStorageClient(client),
capabilitiesClient: remoteexecution.NewCapabilitiesClient(client),
uuidGenerator: uuidGenerator,
readChunkSize: readChunkSize,
enableZSTDCompression: enableZSTDCompression,
zstdPool: pool,
}
}

Expand All @@ -74,47 +116,61 @@ func (r *byteStreamChunkReader) Close() {
}
}

// zstdByteStreamChunkReader reads compressed data from gRPC stream and decompresses using pooled decoder.
type zstdByteStreamChunkReader struct {
client bytestream.ByteStream_ReadClient
cancel context.CancelFunc
zstdReader io.ReadCloser
pool *BoundedZstdPool
decoder *DecoderWrapper
pipeReader *io.PipeReader
pipeWriter *io.PipeWriter
readChunkSize int
wg sync.WaitGroup
initOnce sync.Once
initErr error
}

func (r *zstdByteStreamChunkReader) Read() ([]byte, error) {
if r.zstdReader == nil {
pr, pw := io.Pipe()
func (r *zstdByteStreamChunkReader) init(ctx context.Context) error {
r.initOnce.Do(func() {
r.pipeReader, r.pipeWriter = io.Pipe()

// Start goroutine to read from gRPC and write to pipe
r.wg.Add(1)
go func() {
defer r.wg.Done()
defer pw.Close()
defer r.pipeWriter.Close()
for {
chunk, err := r.client.Recv()
if err != nil {
if err != io.EOF {
pw.CloseWithError(err)
r.pipeWriter.CloseWithError(err)
}
return
}
if _, writeErr := pw.Write(chunk.Data); writeErr != nil {
pw.CloseWithError(writeErr)
if _, writeErr := r.pipeWriter.Write(chunk.Data); writeErr != nil {
r.pipeWriter.CloseWithError(writeErr)
return
}
}
}()

var err error
r.zstdReader, err = util.NewZstdReadCloser(pr, zstd.WithDecoderConcurrency(1))
if err != nil {
pr.Close()
return nil, err
// Acquire decoder from pool (blocking if at capacity)
r.decoder, r.initErr = r.pool.AcquireDecoder(ctx, r.pipeReader)
if r.initErr != nil {
r.pipeReader.CloseWithError(r.initErr)
}
})
return r.initErr
}

func (r *zstdByteStreamChunkReader) Read() ([]byte, error) {
// Lazy initialization on first read - allows context to be passed
if err := r.init(context.Background()); err != nil {
return nil, err
}

buf := make([]byte, r.readChunkSize)
n, err := r.zstdReader.Read(buf)
n, err := r.decoder.Read(buf)
if n > 0 {
if err != nil && err != io.EOF {
err = nil
Expand All @@ -125,12 +181,19 @@ func (r *zstdByteStreamChunkReader) Read() ([]byte, error) {
}

func (r *zstdByteStreamChunkReader) Close() {
if r.zstdReader != nil {
r.zstdReader.Close()
// Release decoder back to pool
if r.decoder != nil {
r.pool.ReleaseDecoder(r.decoder)
r.decoder = nil
}

if r.pipeReader != nil {
r.pipeReader.Close()
}

r.cancel()

// Drain the gRPC stream.
// Drain the gRPC stream
for {
if _, err := r.client.Recv(); err != nil {
break
Expand Down Expand Up @@ -223,6 +286,7 @@ func (ba *casBlobAccess) Get(ctx context.Context, digest digest.Digest) buffer.B
return buffer.NewCASBufferFromChunkReader(digest, &zstdByteStreamChunkReader{
client: client,
cancel: cancel,
pool: ba.zstdPool,
readChunkSize: ba.readChunkSize,
}, buffer.BackendProvided(buffer.Irreparable(digest)))
}
Expand Down Expand Up @@ -269,21 +333,34 @@ func (ba *casBlobAccess) Put(ctx context.Context, digest digest.Digest, b buffer
cancel: cancel,
}

zstdWriter, err := zstd.NewWriter(byteStreamWriter, zstd.WithEncoderConcurrency(1))
// Acquire encoder from pool (blocks if at capacity - provides backpressure)
encoder, err := ba.zstdPool.AcquireEncoder(ctx, byteStreamWriter)
if err != nil {
cancel()
client.CloseAndRecv()
return status.Errorf(codes.Internal, "Failed to create zstd writer: %v", err)
b.Discard()
if _, closeErr := client.CloseAndRecv(); closeErr != nil {
return status.Errorf(codes.Internal, "Failed to close client: %v and acquire encoder: %v", closeErr, err)
}
return status.Errorf(codes.ResourceExhausted, "Failed to acquire ZSTD encoder: %v", err)
}

if err := b.IntoWriter(zstdWriter); err != nil {
zstdWriter.Close()
byteStreamWriter.Close()
// Ensure encoder is returned to pool
defer ba.zstdPool.ReleaseEncoder(encoder)

if err := b.IntoWriter(encoder.Encoder); err != nil {
if zstdCloseErr := encoder.Close(); zstdCloseErr != nil {
err = errors.Join(err, zstdCloseErr)
}
if closeErr := byteStreamWriter.Close(); closeErr != nil {
err = errors.Join(err, closeErr)
}
return err
}

if err := zstdWriter.Close(); err != nil {
byteStreamWriter.Close()
if err := encoder.Close(); err != nil {
if closeErr := byteStreamWriter.Close(); closeErr != nil {
err = errors.Join(err, closeErr)
}
return err
}

Expand Down
Loading