Skip to content

Commit 27c8c29

Browse files
lgarofaloLeland Garofalo
and
Leland Garofalo
authoredJul 25, 2023
Testing retry and timeout for signing ops (#366)
* Testing retry and timeout for signing ops * bugfixes * Adjust use of context for retry/timeout --------- Co-authored-by: Leland Garofalo <leland@cloudflare.com>
1 parent 83f280f commit 27c8c29

File tree

3 files changed

+98
-7
lines changed

3 files changed

+98
-7
lines changed
 

‎cmd/gokeyless/gokeyless.go

+13
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ type Config struct {
6464
TracingEnabled bool `yaml:"tracing_enabled" mapstructure:"tracing_enabled"`
6565
TracingAddress string `yaml:"tracing_address" mapstructure:"tracing_address"`
6666
TracingSampleRate float64 `yaml:"tracing_sample_rate" mapstructure:"tracing_sample_rate"` // between 0 and 1
67+
68+
SignTimeout string `yaml:"sign_timeout" mapstructure:"sign_timeout"`
69+
SignRetryCount int `yaml:"sign_retry_count" mapstructure:"sign_retry_count"`
6770
}
6871

6972
// PrivateKeyStoreConfig defines a key store.
@@ -309,6 +312,16 @@ func runMain() error {
309312
}
310313

311314
cfg := server.DefaultServeConfig()
315+
if config.SignTimeout != "" {
316+
signTimeoutDuration, err := time.ParseDuration(config.SignTimeout)
317+
if err != nil {
318+
log.Fatalf("failed to parse signTimeout: %s", err)
319+
}
320+
cfg = cfg.WithSignTimeout(signTimeoutDuration)
321+
}
322+
if config.SignRetryCount > 0 {
323+
cfg = cfg.WithSignRetryCount(config.SignRetryCount)
324+
}
312325
s, err := server.NewServerFromFile(cfg, config.CertFile, config.KeyFile, config.CACertFile)
313326
if err != nil {
314327
return fmt.Errorf("cannot start server: %w", err)

‎server/server.go

+75-6
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ type Server struct {
5151
listeners map[net.Listener]map[net.Conn]struct{}
5252
shutdown bool
5353
mtx sync.Mutex
54+
55+
signTimeout time.Duration
56+
signRetryCount int
5457
}
5558

5659
// NewServer prepares a TLS server capable of receiving connections from keyless clients.
@@ -73,6 +76,8 @@ func NewServer(config *ServeConfig, cert tls.Certificate, keylessCA *x509.CertPo
7376
dispatcher: rpc.NewServer(),
7477
limitedDispatcher: rpc.NewServer(),
7578
listeners: make(map[net.Listener]map[net.Conn]struct{}),
79+
signTimeout: config.signTimeout,
80+
signRetryCount: config.signRetryCount,
7681
}
7782

7883
return s, nil
@@ -448,19 +453,56 @@ func (s *Server) unlimitedDo(pkt *protocol.Packet, connName string) response {
448453
return makeErrResponse(pkt, protocol.ErrKeyNotFound)
449454
}
450455

451-
signSpan, _ := opentracing.StartSpanFromContext(ctx, "execute.Sign")
456+
signSpan, ctx := opentracing.StartSpanFromContext(ctx, "execute.Sign")
452457
defer signSpan.Finish()
453458
var sig []byte
454-
sig, err = key.Sign(rand.Reader, pkt.Operation.Payload, opts)
455-
if err != nil {
456-
tracing.LogError(span, err)
457-
log.Errorf("Connection %v: %s: Signing error: %v\n", connName, protocol.ErrCrypto, err)
458-
return makeErrResponse(pkt, protocol.ErrCrypto)
459+
// By default, we only try the request once, unless retry count is configured
460+
for attempts := 1 + s.signRetryCount; attempts > 0; attempts-- {
461+
var err error
462+
// If signTimeout is not set, the value will be zero
463+
if s.signTimeout == 0 {
464+
sig, err = key.Sign(rand.Reader, pkt.Operation.Payload, opts)
465+
} else {
466+
ch := make(chan signWithTimeoutWrapper, 1)
467+
ctxTimeout, cancel := context.WithTimeout(ctx, s.signTimeout)
468+
defer cancel()
469+
470+
go signWithTimeout(ctxTimeout, ch, key, rand.Reader, pkt.Operation.Payload, opts)
471+
select {
472+
case <-ctxTimeout.Done():
473+
sig = nil
474+
err = ctxTimeout.Err()
475+
case result := <-ch:
476+
sig = result.sig
477+
err = result.error
478+
}
479+
}
480+
if err != nil {
481+
if attempts > 1 {
482+
log.Debugf("Connection %v: failed sign attempt: %s, %d attempt(s) left", connName, err, attempts-1)
483+
continue
484+
} else {
485+
tracing.LogError(span, err)
486+
log.Errorf("Connection %v: %s: Signing error: %v\n", connName, protocol.ErrCrypto, err)
487+
return makeErrResponse(pkt, protocol.ErrCrypto)
488+
}
489+
}
490+
break
459491
}
460492

461493
return makeRespondResponse(pkt, sig)
462494
}
463495

496+
type signWithTimeoutWrapper struct {
497+
sig []byte
498+
error error
499+
}
500+
501+
func signWithTimeout(ctx context.Context, ch chan signWithTimeoutWrapper, key crypto.Signer, rand io.Reader, digest []byte, opts crypto.SignerOpts) {
502+
sig, err := key.Sign(rand, digest, opts)
503+
ch <- signWithTimeoutWrapper{sig, err}
504+
}
505+
464506
func (s *Server) limitedDo(pkt *protocol.Packet, connName string) response {
465507
spanCtx, err := tracing.SpanContextFromBinary(pkt.Operation.JaegerSpan)
466508
if err != nil {
@@ -697,6 +739,8 @@ type ServeConfig struct {
697739
tcpTimeout, unixTimeout time.Duration
698740
isLimited func(state tls.ConnectionState) (bool, error)
699741
customOpFunc CustomOpFunction
742+
signTimeout time.Duration
743+
signRetryCount int
700744
}
701745

702746
const (
@@ -718,6 +762,8 @@ func DefaultServeConfig() *ServeConfig {
718762
unixTimeout: defaultUnixTimeout,
719763
maxConnPendingRequests: 1024,
720764
isLimited: func(state tls.ConnectionState) (bool, error) { return false, nil },
765+
signTimeout: 0,
766+
signRetryCount: 0,
721767
}
722768
}
723769

@@ -757,6 +803,29 @@ func (s *ServeConfig) WithIsLimited(f func(state tls.ConnectionState) (bool, err
757803
return s
758804
}
759805

806+
// WithSignTimeout specifies the sign operation timeout. This timeout is used to enforce a
807+
// max execution time for a single sign operation
808+
func (s *ServeConfig) WithSignTimeout(timeout time.Duration) *ServeConfig {
809+
s.signTimeout = timeout
810+
return s
811+
}
812+
813+
// SignTimeout returns the sign operation timeout
814+
func (s *ServeConfig) SignTimeout() time.Duration {
815+
return s.signTimeout
816+
}
817+
818+
// WithSignRetryCount specifics a number of retries to allow for failed sign operations
819+
func (s *ServeConfig) WithSignRetryCount(signRetryCount int) *ServeConfig {
820+
s.signRetryCount = signRetryCount
821+
return s
822+
}
823+
824+
// SignRetryCount returns the count of retries allowed for sign operations
825+
func (s *ServeConfig) SignRetryCount() int {
826+
return s.signRetryCount
827+
}
828+
760829
// CustomOpFunction is the signature for custom opcode functions.
761830
//
762831
// If it returns a non-nil error which implements protocol.Error, the server

‎tests/common_test.go

+10-1
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ type IntegrationTestSuite struct {
6464
ecdsaKey *client.PrivateKey
6565
ed25519Key *client.PrivateKey
6666
remote client.Remote
67+
68+
retryCount int
69+
timeout time.Duration
6770
}
6871

6972
func fixedCurrentTime() time.Time {
@@ -148,6 +151,11 @@ func (s *IntegrationTestSuite) NewRemoteSignerByPubKeyFile(filepath string) (cry
148151
func TestSuite(t *testing.T) {
149152
s := &IntegrationTestSuite{}
150153
suite.Run(t, s)
154+
s2 := &IntegrationTestSuite{
155+
timeout: time.Second,
156+
retryCount: 3,
157+
}
158+
suite.Run(t, s2)
151159
}
152160

153161
// SetupTest sets up a compatible server and client for use by tests.
@@ -160,7 +168,8 @@ func (s *IntegrationTestSuite) SetupTest() {
160168
atomic.StoreUint32(&client.TestDisableConnectionPool, 1)
161169

162170
var err error
163-
s.server, err = server.NewServerFromFile(nil, serverCert, serverKey, keylessCA)
171+
config := server.DefaultServeConfig().WithSignTimeout(s.timeout).WithSignRetryCount(s.retryCount)
172+
s.server, err = server.NewServerFromFile(config, serverCert, serverKey, keylessCA)
164173
require.NoError(err)
165174
s.server.TLSConfig().Time = fixedCurrentTime
166175

0 commit comments

Comments
 (0)