Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add connection pool so we don't leak connections #152

Merged
merged 2 commits into from
Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
273 changes: 273 additions & 0 deletions controllers/jetstream/conn_pool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,273 @@
package jetstream

import (
"crypto/sha256"
"crypto/tls"
"encoding/json"
"fmt"
"os"
"sync"

"github.com/nats-io/nats.go"
"github.com/sirupsen/logrus"
"golang.org/x/sync/singleflight"
)

// natsContext describes how to reach and authenticate against a NATS
// deployment. Credentials, Nkey, TLSCAs, TLSCert and TLSKey are treated as
// file paths: hash reads the files they name.
type natsContext struct {
	Name        string   `json:"name"`
	URL         string   `json:"url"`
	JWT         string   `json:"jwt"`
	Seed        string   `json:"seed"`
	Credentials string   `json:"credential"`
	Nkey        string   `json:"nkey"`
	Token       string   `json:"token"`
	Username    string   `json:"username"`
	Password    string   `json:"password"`
	TLSCAs      []string `json:"tls_ca"`
	TLSCert     string   `json:"tls_cert"`
	TLSKey      string   `json:"tls_key"`
}

// copy returns an independent copy of c, or nil when c is nil.
//
// TLSCAs is duplicated into its own backing array so that writes through the
// copy can never be observed through the original (and vice versa). A nil
// TLSCAs stays nil and an empty non-nil one stays empty and non-nil, so the
// JSON encoding used by hash is unaffected.
func (c *natsContext) copy() *natsContext {
	if c == nil {
		return nil
	}
	cp := *c
	if c.TLSCAs != nil {
		cp.TLSCAs = append(make([]string, 0, len(c.TLSCAs)), c.TLSCAs...)
	}
	return &cp
}

// hash returns the lowercase hex SHA-256 digest of the JSON encoding of c
// followed by the contents of every file it references, in this order: Nkey,
// Credentials, each TLSCAs entry, TLSCert, TLSKey. Single-valued paths that
// are empty are skipped; every TLSCAs entry is read as given.
//
// It returns an error when c is nil, when c cannot be marshaled, or when any
// referenced file cannot be read. Underlying errors are wrapped so callers
// can inspect them with errors.Is / errors.As.
func (c *natsContext) hash() (string, error) {
	if c == nil {
		return "", fmt.Errorf("nats context must not be nil")
	}

	b, err := json.Marshal(c)
	if err != nil {
		return "", fmt.Errorf("error marshaling context to json: %w", err)
	}

	// Collect the referenced files up front so they are read in one loop,
	// preserving the order that determines the digest.
	type fileRef struct {
		kind string
		path string
	}
	files := make([]fileRef, 0, len(c.TLSCAs)+4)
	if c.Nkey != "" {
		files = append(files, fileRef{kind: "nkey", path: c.Nkey})
	}
	if c.Credentials != "" {
		files = append(files, fileRef{kind: "creds", path: c.Credentials})
	}
	for _, ca := range c.TLSCAs {
		files = append(files, fileRef{kind: "ca", path: ca})
	}
	if c.TLSCert != "" {
		files = append(files, fileRef{kind: "cert", path: c.TLSCert})
	}
	if c.TLSKey != "" {
		files = append(files, fileRef{kind: "key", path: c.TLSKey})
	}

	for _, f := range files {
		fb, err := os.ReadFile(f.path)
		if err != nil {
			return "", fmt.Errorf("error opening %s file %s: %w", f.kind, f.path, err)
		}
		b = append(b, fb...)
	}

	sum := sha256.Sum256(b)
	return fmt.Sprintf("%x", sum[:]), nil
}

// natsContextDefaults holds the pool-wide fallback values that
// natsConnPool.Get copies into a natsContext whose corresponding field is
// empty.
type natsContextDefaults struct {
	// Name and URL are used when the context leaves them blank.
	Name, URL string
	// TLSCAs is used when the context lists no CA entries.
	TLSCAs []string
	// TLSCert and TLSKey are used when the context leaves them blank.
	TLSCert, TLSKey string
	// TLSConfig is not read by the pool code in this file.
	// NOTE(review): presumably consumed elsewhere; confirm against callers.
	TLSConfig *tls.Config
}

// pooledNatsConn is a reference-counted NATS connection handed out by a
// natsConnPool. Its count and closed fields are read and written while
// holding the owning pool's mutex (see ReturnToPool and natsConnPool.Get).
type pooledNatsConn struct {
// nc is the underlying NATS connection.
nc *nats.Conn
// cp is the pool that created this connection; its mutex guards count and closed.
cp *natsConnPool
// key is the cache key under which the pool stores this connection.
key string
// count is the number of Get calls that have not yet been matched by ReturnToPool.
count uint64
// closed is set by ReturnToPool when the last reference is returned, just before nc is closed.
closed bool
}

// ReturnToPool releases one reference to the pooled connection.
//
// When the last reference is released the connection is removed from the
// pool's cache (only if the cache still maps its key to this very
// connection), marked closed, and the underlying NATS connection is closed.
//
// Calling ReturnToPool more often than the connection was obtained is a
// no-op: count is unsigned, so decrementing it at zero would wrap around to
// the maximum uint64 value and corrupt the reference count.
func (pc *pooledNatsConn) ReturnToPool() {
	pc.cp.Lock()

	if pc.count == 0 {
		// Unbalanced return: nothing is checked out, so there is nothing to release.
		pc.cp.Unlock()
		return
	}

	pc.count--
	if pc.count > 0 {
		// Other holders remain; keep the connection open.
		pc.cp.Unlock()
		return
	}

	// The cache entry may already have been replaced by a newer connection
	// for the same key; only remove it when it is still this one.
	if cached, ok := pc.cp.cache[pc.key]; ok && cached == pc {
		delete(pc.cp.cache, pc.key)
	}
	pc.closed = true
	pc.cp.Unlock()

	// Close outside the pool lock so other pool users are not blocked on it.
	pc.nc.Close()
}

// natsConnPool caches NATS connections keyed by the hash of the natsContext
// they were created from, so that identical contexts share one connection.
type natsConnPool struct {
// The embedded mutex guards cache as well as the count and closed fields of
// every pooledNatsConn that belongs to this pool.
sync.Mutex
// cache maps a context hash to the pooled connection created for it.
cache map[string]*pooledNatsConn
// logger is used to report each newly established connection.
logger *logrus.Logger
// group runs getPooledConn's lookup-or-connect step through singleflight, keyed by context hash.
group *singleflight.Group
// natsDefaults supplies values for context fields that callers of Get leave empty.
natsDefaults *natsContextDefaults
// natsOpts are the options applied to every connection the pool establishes.
natsOpts []nats.Option
}

// newNatsConnPool builds an empty connection pool.
//
// logger receives a message for each connection the pool establishes,
// natsDefaults provides fallbacks for context fields left empty in Get, and
// natsOpts are applied to every connection the pool creates.
func newNatsConnPool(logger *logrus.Logger, natsDefaults *natsContextDefaults, natsOpts []nats.Option) *natsConnPool {
	pool := new(natsConnPool)
	pool.cache = make(map[string]*pooledNatsConn)
	pool.group = new(singleflight.Group)
	pool.logger = logger
	pool.natsDefaults = natsDefaults
	pool.natsOpts = natsOpts
	return pool
}

// getPooledConnMaxTries bounds how many times Get retries when the connection
// it obtained was closed by a concurrent ReturnToPool before it could be
// claimed.
const getPooledConnMaxTries = 10

// Get returns a *pooledNatsConn for cfg, sharing an existing connection when
// one is cached for an identical context, and increments its reference
// count. Every successful Get must be balanced by one ReturnToPool call.
//
// cfg itself is never modified: defaults are applied to a private copy. A
// pool created without defaults (nil natsDefaults) simply applies none.
//
// It returns an error when cfg is nil, when the context cannot be hashed
// (for example a referenced file is unreadable), when connecting fails, or
// when no usable connection could be claimed within getPooledConnMaxTries
// attempts.
func (cp *natsConnPool) Get(cfg *natsContext) (*pooledNatsConn, error) {
	if cfg == nil {
		return nil, fmt.Errorf("nats context must not be nil")
	}

	// Work on a copy so the caller's context is left untouched.
	cfg = cfg.copy()

	// Fill in pool-wide defaults for anything the caller left empty.
	if defaults := cp.natsDefaults; defaults != nil {
		if cfg.Name == "" {
			cfg.Name = defaults.Name
		}
		if cfg.URL == "" {
			cfg.URL = defaults.URL
		}
		if len(cfg.TLSCAs) == 0 {
			cfg.TLSCAs = defaults.TLSCAs
		}
		if cfg.TLSCert == "" {
			cfg.TLSCert = defaults.TLSCert
		}
		if cfg.TLSKey == "" {
			cfg.TLSKey = defaults.TLSKey
		}
	}

	// The hash of the fully defaulted context is the cache key.
	key, err := cfg.hash()
	if err != nil {
		return nil, err
	}

	for i := 0; i < getPooledConnMaxTries; i++ {
		connection, err := cp.getPooledConn(key, cfg)
		if err != nil {
			return nil, err
		}

		cp.Lock()
		if connection.closed {
			// ReturnToPool closed this connection while the lock was not
			// held; ask for a fresh one.
			cp.Unlock()
			continue
		}

		// Claim a reference before releasing the lock.
		connection.count++
		cp.Unlock()
		return connection, nil
	}

	return nil, fmt.Errorf("failed to get pooled connection after %d attempts", getPooledConnMaxTries)
}

// getPooledConn gets or establishes a *pooledNatsConn in a singleflight
// group, but does not increment its count.
//
// If the cache holds a connection for key that is currently connected, it is
// returned. Otherwise a new connection is established from cfg plus the
// pool's base options, stored in the cache under key (replacing any previous
// entry) and returned. Concurrent calls for the same key share one execution
// of that logic; calls for different keys run independently.
//
// It returns an error when the nkey option cannot be built or when
// nats.Connect fails.
func (cp *natsConnPool) getPooledConn(key string, cfg *natsContext) (*pooledNatsConn, error) {
	conn, err, _ := cp.group.Do(key, func() (interface{}, error) {
		cp.Lock()
		pooledConn, ok := cp.cache[key]
		if ok && pooledConn.nc.IsConnected() {
			cp.Unlock()
			return pooledConn, nil
		}
		cp.Unlock()

		// Build the option list in a fresh slice. Appending directly to
		// cp.natsOpts could write into its backing array when it has spare
		// capacity, letting concurrent calls for different keys overwrite
		// each other's options.
		opts := make([]nats.Option, 0, len(cp.natsOpts)+6)
		opts = append(opts, cp.natsOpts...)
		opts = append(opts, func(options *nats.Options) error {
			if cfg.Name != "" {
				options.Name = cfg.Name
			}
			if cfg.Token != "" {
				options.Token = cfg.Token
			}
			if cfg.Username != "" {
				options.User = cfg.Username
			}
			if cfg.Password != "" {
				options.Password = cfg.Password
			}
			return nil
		})

		// JWT authentication is only configured when both halves are present.
		if cfg.JWT != "" && cfg.Seed != "" {
			opts = append(opts, nats.UserJWTAndSeed(cfg.JWT, cfg.Seed))
		}

		if cfg.Nkey != "" {
			opt, err := nats.NkeyOptionFromSeed(cfg.Nkey)
			if err != nil {
				return nil, fmt.Errorf("unable to load nkey: %w", err)
			}
			opts = append(opts, opt)
		}

		if cfg.Credentials != "" {
			opts = append(opts, nats.UserCredentials(cfg.Credentials))
		}

		if len(cfg.TLSCAs) > 0 {
			opts = append(opts, nats.RootCAs(cfg.TLSCAs...))
		}

		// A client certificate needs both the cert and its key.
		if cfg.TLSCert != "" && cfg.TLSKey != "" {
			opts = append(opts, nats.ClientCert(cfg.TLSCert, cfg.TLSKey))
		}

		nc, err := nats.Connect(cfg.URL, opts...)
		if err != nil {
			return nil, err
		}
		cp.logger.Infof("%s connected to NATS Deployment: %s", cfg.Name, nc.ConnectedAddr())

		connection := &pooledNatsConn{
			nc:  nc,
			cp:  cp,
			key: key,
		}

		cp.Lock()
		cp.cache[key] = connection
		cp.Unlock()

		return connection, nil
	})

	if err != nil {
		return nil, err
	}

	connection, ok := conn.(*pooledNatsConn)
	if !ok {
		return nil, fmt.Errorf("not a pooledNatsConn")
	}
	return connection, nil
}
92 changes: 92 additions & 0 deletions controllers/jetstream/conn_pool_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package jetstream

import (
"sync"
"testing"
"time"

"github.com/nats-io/nats.go"

natsservertest "github.com/nats-io/nats-server/v2/test"
"github.com/sirupsen/logrus"
testifyAssert "github.com/stretchr/testify/assert"
)

// TestConnPool exercises the pool against an in-process NATS server: equal
// contexts must share one pooled connection, different contexts must not,
// and a connection must be closed only once every holder has returned it.
//
// The test stops as soon as a Get fails, because every later step
// dereferences the returned connection and would otherwise panic on nil.
func TestConnPool(t *testing.T) {
	t.Parallel()

	s := natsservertest.RunRandClientPortServer()
	defer s.Shutdown()

	// o1 and o2 are equal by value and so must map to the same pooled
	// connection; o3 differs by name and must get its own.
	o1 := &natsContext{
		Name: "Client 1",
	}
	o2 := &natsContext{
		Name: "Client 1",
	}
	o3 := &natsContext{
		Name: "Client 2",
	}

	natsDefaults := &natsContextDefaults{
		URL: s.ClientURL(),
	}
	natsOptions := []nats.Option{
		nats.MaxReconnects(10240),
	}
	cp := newNatsConnPool(logrus.New(), natsDefaults, natsOptions)

	// Request all three connections concurrently.
	var c1, c2, c3 *pooledNatsConn
	var c1e, c2e, c3e error
	wg := &sync.WaitGroup{}
	wg.Add(3)
	go func() {
		defer wg.Done()
		c1, c1e = cp.Get(o1)
	}()
	go func() {
		defer wg.Done()
		c2, c2e = cp.Get(o2)
	}()
	go func() {
		defer wg.Done()
		c3, c3e = cp.Get(o3)
	}()
	wg.Wait()

	assert := testifyAssert.New(t)

	// Evaluate all three checks so every failure is reported, then stop.
	ok1, ok2, ok3 := assert.NoError(c1e), assert.NoError(c2e), assert.NoError(c3e)
	if !(ok1 && ok2 && ok3) {
		t.FailNow()
	}
	assert.Same(c1, c2)
	assert.NotSame(c1, c3)
	assert.NotSame(c2, c3)

	// c1/c2 is still held once (by c2); c3 had a single holder and must close.
	c1.ReturnToPool()
	c3.ReturnToPool()
	time.Sleep(1 * time.Second)
	assert.False(c1.nc.IsClosed())
	assert.False(c2.nc.IsClosed())
	assert.True(c3.nc.IsClosed())

	// A new Get for the same context reuses the still-open connection.
	c4, c4e := cp.Get(o1)
	if !assert.NoError(c4e) {
		t.FailNow()
	}
	assert.Same(c2, c4)

	// Returning the last two references closes the shared connection.
	c2.ReturnToPool()
	c4.ReturnToPool()
	time.Sleep(1 * time.Second)
	assert.True(c1.nc.IsClosed())
	assert.True(c2.nc.IsClosed())
	assert.True(c4.nc.IsClosed())

	// After closing, the same context yields a brand-new pooled connection.
	c5, c5e := cp.Get(o1)
	if !assert.NoError(c5e) {
		t.FailNow()
	}
	assert.NotSame(c1, c5)

	c5.ReturnToPool()
	time.Sleep(1 * time.Second)
	assert.True(c5.nc.IsClosed())
}
Loading