Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
buger committed Oct 17, 2024
1 parent 369b685 commit 3b6324f
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 48 deletions.
28 changes: 14 additions & 14 deletions ee/middleware/streams/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type Middleware struct {
base BaseMiddleware

createStreamManagerLock sync.Mutex
streamManagerCache sync.Map // Map of payload hash to Manager
StreamManagerCache sync.Map // Map of payload hash to Manager

ctx context.Context
cancel context.CancelFunc
Expand Down Expand Up @@ -87,7 +87,7 @@ func (s *Middleware) Init() {
s.ctx, s.cancel = context.WithCancel(context.Background())

s.Logger().Debug("Initializing default stream manager")
s.defaultManager = s.createStreamManager(nil)
s.defaultManager = s.CreateStreamManager(nil)

// Start garbage collection routine
go func() {
Expand All @@ -97,16 +97,16 @@ func (s *Middleware) Init() {
for {
select {
case <-ticker.C:
s.garbageCollect()
s.GC()
case <-s.ctx.Done():
return
}
}
}()
}

// createStreamManager creates or retrieves a stream manager based on the request.
func (s *Middleware) createStreamManager(r *http.Request) *Manager {
// CreateStreamManager creates or retrieves a stream manager based on the request.
func (s *Middleware) CreateStreamManager(r *http.Request) *Manager {
streamsConfig := s.getStreamsConfig(r)
configJSON, _ := json.Marshal(streamsConfig)
cacheKey := fmt.Sprintf("%x", sha256.Sum256(configJSON))
Expand All @@ -116,7 +116,7 @@ func (s *Middleware) createStreamManager(r *http.Request) *Manager {

s.Logger().Debug("Attempting to load stream manager from cache")
s.Logger().Debugf("Cache key: %s", cacheKey)
if cachedManager, found := s.streamManagerCache.Load(cacheKey); found {
if cachedManager, found := s.StreamManagerCache.Load(cacheKey); found {
s.Logger().Debug("Found cached stream manager")
return cachedManager.(*Manager)
}
Expand All @@ -130,16 +130,16 @@ func (s *Middleware) createStreamManager(r *http.Request) *Manager {
newManager.initStreams(r, streamsConfig)

if r != nil {
s.streamManagerCache.Store(cacheKey, newManager)
s.StreamManagerCache.Store(cacheKey, newManager)
}
return newManager
}

// garbageCollect removes inactive stream managers.
func (s *Middleware) garbageCollect() {
// GC removes inactive stream managers.
func (s *Middleware) GC() {
s.Logger().Debug("Starting garbage collection for inactive stream managers")

s.streamManagerCache.Range(func(key, value interface{}) bool {
s.StreamManagerCache.Range(func(key, value interface{}) bool {
manager := value.(*Manager)
if manager == s.defaultManager {
return true
Expand All @@ -155,7 +155,7 @@ func (s *Middleware) garbageCollect() {
}
return true
})
s.streamManagerCache.Delete(key)
s.StreamManagerCache.Delete(key)
}

return true
Expand Down Expand Up @@ -232,7 +232,7 @@ func (s *Middleware) ProcessRequest(w http.ResponseWriter, r *http.Request, _ in
}

var match mux.RouteMatch
streamManager := s.createStreamManager(r)
streamManager := s.CreateStreamManager(r)
streamManager.routeLock.Lock()
streamManager.muxer.Match(newRequest, &match)
streamManager.routeLock.Unlock()
Expand All @@ -258,7 +258,7 @@ func (s *Middleware) Unload() {
totalStreams := 0
s.cancel()

s.streamManagerCache.Range(func(_, value interface{}) bool {
s.StreamManagerCache.Range(func(_, value interface{}) bool {
manager, ok := value.(*Manager)
if !ok {
return true
Expand All @@ -276,6 +276,6 @@ func (s *Middleware) Unload() {
})

GlobalStreamCounter.Add(-int64(totalStreams))
s.streamManagerCache = sync.Map{}
s.StreamManagerCache = sync.Map{}
s.Logger().Info("All streams successfully removed")
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package streams
package gateway

import (
"bytes"
Expand Down Expand Up @@ -30,7 +30,7 @@ import (

"github.com/TykTechnologies/tyk/apidef/oas"
"github.com/TykTechnologies/tyk/config"
"github.com/TykTechnologies/tyk/gateway"
"github.com/TykTechnologies/tyk/ee/middleware/streams"
"github.com/TykTechnologies/tyk/internal/model"
"github.com/TykTechnologies/tyk/test"
)
Expand Down Expand Up @@ -92,7 +92,7 @@ output:
t.Run(tc.name, func(t *testing.T) {
config, err := yamlConfigToMap(tc.configYaml)
require.NoError(t, err)
httpPaths := GetHTTPPaths(config)
httpPaths := streams.GetHTTPPaths(config)
assert.ElementsMatch(t, tc.expected, httpPaths)
})
}
Expand Down Expand Up @@ -182,7 +182,7 @@ func TestStreamingAPISingleClient(t *testing.T) {
assert.NoError(t, err)
streamConfig := fmt.Sprintf(bentoNatsTemplate, configSubject, connectionStr)

ts := gateway.StartTest(func(globalConf *config.Config) {
ts := StartTest(func(globalConf *config.Config) {
globalConf.Streaming.Enabled = true
})
t.Cleanup(func() {
Expand Down Expand Up @@ -248,7 +248,7 @@ func TestStreamingAPIMultipleClients(t *testing.T) {

streamConfig := fmt.Sprintf(bentoNatsTemplate, "test", connectionStr)

ts := gateway.StartTest(func(globalConf *config.Config) {
ts := StartTest(func(globalConf *config.Config) {
globalConf.Streaming.Enabled = true
})
t.Cleanup(func() {
Expand Down Expand Up @@ -325,12 +325,12 @@ func TestStreamingAPIMultipleClients(t *testing.T) {
require.Empty(t, messages)
}

func setUpStreamAPI(ts *gateway.Test, apiName string, streamConfig string) error {
func setUpStreamAPI(ts *Test, apiName string, streamConfig string) error {
oasAPI, err := setupOASForStreamAPI(streamConfig)
if err != nil {
return err
}
ts.Gw.BuildAndLoadAPI(func(spec *gateway.APISpec) {
ts.Gw.BuildAndLoadAPI(func(spec *APISpec) {
spec.Proxy.ListenPath = fmt.Sprintf("/%s", apiName)
spec.UseKeylessAccess = true
spec.IsOAS = true
Expand Down Expand Up @@ -359,7 +359,7 @@ func setupOASForStreamAPI(streamingConfig string) (oas.OAS, error) {
}

oasAPI.Extensions = map[string]interface{}{
ExtensionTykStreaming: parsedStreamingConfig,
streams.ExtensionTykStreaming: parsedStreamingConfig,
}

return oasAPI, nil
Expand All @@ -381,11 +381,11 @@ func yamlConfigToMap(streamingConfig string) (map[string]interface{}, error) {
func TestAsyncAPI(t *testing.T) {
t.SkipNow()

ts := gateway.StartTest(func(globalConf *config.Config) {
ts := StartTest(func(globalConf *config.Config) {
globalConf.Streaming.Enabled = true
})

ts.Gw.BuildAndLoadAPI(func(spec *gateway.APISpec) {
ts.Gw.BuildAndLoadAPI(func(spec *APISpec) {
spec.Proxy.ListenPath = "/test"
spec.UseKeylessAccess = true
})
Expand Down Expand Up @@ -447,14 +447,14 @@ streams:
}

oasAPI.Extensions = map[string]interface{}{
ExtensionTykStreaming: parsedStreamingConfig,
streams.ExtensionTykStreaming: parsedStreamingConfig,
// oas.ExtensionTykAPIGateway: tykExtension,
}

ts.Gw.BuildAndLoadAPI(func(spec *gateway.APISpec) {
ts.Gw.BuildAndLoadAPI(func(spec *APISpec) {
spec.Proxy.ListenPath = "/test"
spec.UseKeylessAccess = true
}, func(spec *gateway.APISpec) {
}, func(spec *APISpec) {
spec.SetDisabledFlags()
spec.APIID = "base-api-id"
spec.VersionDefinition.Enabled = false
Expand All @@ -469,8 +469,8 @@ streams:
// Check that standard API still works
_, _ = ts.Run(t, test.TestCase{Code: http.StatusOK, Method: http.MethodGet, Path: "/test"})

if GlobalStreamCounter.Load() != 1 {
t.Fatalf("Expected 1 stream, got %d", GlobalStreamCounter.Load())
if streams.GlobalStreamCounter.Load() != 1 {
t.Fatalf("Expected 1 stream, got %d", streams.GlobalStreamCounter.Load())
}

time.Sleep(500 * time.Millisecond)
Expand Down Expand Up @@ -524,7 +524,7 @@ func TestAsyncAPIHttp(t *testing.T) {
t.Fatalf("Failed to get Kafka brokers: %v", err)
}

ts := gateway.StartTest(func(globalConf *config.Config) {
ts := StartTest(func(globalConf *config.Config) {
globalConf.Streaming.Enabled = true
})
defer ts.Close()
Expand All @@ -537,13 +537,13 @@ func TestAsyncAPIHttp(t *testing.T) {
}
}

func setupStreamingAPI(t *testing.T, ts *gateway.Test, consumerGroup string, tenantID string, kafkaHost string) string {
func setupStreamingAPI(t *testing.T, ts *Test, consumerGroup string, tenantID string, kafkaHost string) string {
t.Helper()
t.Logf("Setting up streaming API for tenant: %s with consumer group: %s", tenantID, consumerGroup)

apiName := fmt.Sprintf("streaming-api-%s", tenantID)

ts.Gw.BuildAndLoadAPI(func(spec *gateway.APISpec) {
ts.Gw.BuildAndLoadAPI(func(spec *APISpec) {
spec.Proxy.ListenPath = fmt.Sprintf("/%s", apiName)
spec.UseKeylessAccess = true
spec.IsOAS = true
Expand Down Expand Up @@ -593,19 +593,19 @@ streams:
}

oasAPI.Extensions = map[string]interface{}{
ExtensionTykStreaming: parsedStreamingConfig,
streams.ExtensionTykStreaming: parsedStreamingConfig,
}

return oasAPI
}

func testAsyncAPIHttp(t *testing.T, ts *gateway.Test, isDynamic bool, tenantID string, apiName string, kafkaHost string) {
func testAsyncAPIHttp(t *testing.T, ts *Test, isDynamic bool, tenantID string, apiName string, kafkaHost string) {
t.Helper()
const messageToSend = "hello websocket"
const numMessages = 2
const numClients = 2

streamCount := GlobalStreamCounter.Load()
streamCount := streams.GlobalStreamCounter.Load()
t.Logf("Stream count for tenant %s: %d", tenantID, streamCount)

// Create WebSocket clients
Expand Down Expand Up @@ -713,7 +713,7 @@ func testAsyncAPIHttp(t *testing.T, ts *gateway.Test, isDynamic bool, tenantID s
t.Log("Test completed, closing WebSocket connections")
}

func waitForAPIToBeLoaded(ts *gateway.Test) error {
func waitForAPIToBeLoaded(ts *Test) error {
maxAttempts := 2
for i := 0; i < maxAttempts; i++ {
req, err := http.NewRequestWithContext(context.Background(), "GET", ts.URL+"/streaming-api-default/metrics", nil)
Expand Down Expand Up @@ -755,7 +755,7 @@ func TestWebSocketConnectionClosedOnAPIReload(t *testing.T) {
t.Fatalf("Failed to get Kafka brokers: %v", err)
}

ts := gateway.StartTest(func(globalConf *config.Config) {
ts := StartTest(func(globalConf *config.Config) {
globalConf.Streaming.Enabled = true
})
defer ts.Close()
Expand All @@ -780,7 +780,7 @@ func TestWebSocketConnectionClosedOnAPIReload(t *testing.T) {
})

// Reload the API by rebuilding and loading it
ts.Gw.BuildAndLoadAPI(func(spec *gateway.APISpec) {
ts.Gw.BuildAndLoadAPI(func(spec *APISpec) {
spec.Proxy.ListenPath = fmt.Sprintf("/%s", apiName)
spec.UseKeylessAccess = true
spec.IsOAS = true
Expand Down Expand Up @@ -808,7 +808,7 @@ func TestWebSocketConnectionClosedOnAPIReload(t *testing.T) {
}

func TestStreamingAPISingleClient_Input_HTTPServer(t *testing.T) {
ts := gateway.StartTest(func(globalConf *config.Config) {
ts := StartTest(func(globalConf *config.Config) {
globalConf.Streaming.Enabled = true
})
t.Cleanup(func() {
Expand Down Expand Up @@ -859,7 +859,7 @@ func TestStreamingAPIMultipleClients_Input_HTTPServer(t *testing.T) {
// Testing input http -> output http (3 output instances and 10 messages)
// Messages are distributed in a round-robin fashion.

ts := gateway.StartTest(func(globalConf *config.Config) {
ts := StartTest(func(globalConf *config.Config) {
globalConf.Streaming.Enabled = true
})
t.Cleanup(func() {
Expand Down Expand Up @@ -942,7 +942,7 @@ func (d *DummyBase) Logger() *logrus.Entry {
}

func TestStreamingAPIGarbageCollection(t *testing.T) {
ts := gateway.StartTest(func(globalConf *config.Config) {
ts := StartTest(func(globalConf *config.Config) {
globalConf.Streaming.Enabled = true
})
t.Cleanup(func() {
Expand All @@ -954,17 +954,17 @@ func TestStreamingAPIGarbageCollection(t *testing.T) {

apiName := "test-api"

specs := ts.Gw.BuildAndLoadAPI(func(spec *gateway.APISpec) {
specs := ts.Gw.BuildAndLoadAPI(func(spec *APISpec) {
spec.Proxy.ListenPath = fmt.Sprintf("/%s", apiName)
spec.UseKeylessAccess = true
spec.IsOAS = true
spec.OAS = oasAPI
spec.OAS.Fill(*spec.APIDefinition)
})

apiSpec := NewAPISpec(specs[0].APIID, specs[0].Name, specs[0].IsOAS, specs[0].OAS, specs[0].StripListenPath)
apiSpec := streams.NewAPISpec(specs[0].APIID, specs[0].Name, specs[0].IsOAS, specs[0].OAS, specs[0].StripListenPath)

s := Middleware{Gw: ts.Gw, Spec: apiSpec, base: &DummyBase{}}
s := streams.NewMiddleware(ts.Gw, &DummyBase{}, apiSpec)

if err := setUpStreamAPI(ts, apiName, bentoHTTPServerTemplate); err != nil {
t.Fatal(err)
Expand All @@ -975,21 +975,21 @@ func TestStreamingAPIGarbageCollection(t *testing.T) {
r, err := http.NewRequest("POST", publishURL, nil)
require.NoError(t, err)

s.createStreamManager(r)
s.CreateStreamManager(r)

// We should have a Stream manager in the cache.
var streamManagersBeforeGC int
s.streamManagerCache.Range(func(k, v interface{}) bool {
s.StreamManagerCache.Range(func(k, v interface{}) bool {
streamManagersBeforeGC++
return true
})
require.Equal(t, 1, streamManagersBeforeGC)

s.garbageCollect()
s.GC()

// Garbage collection removed the stream manager because the activity counter is zero.
var streamManagersAfterGC int
s.streamManagerCache.Range(func(k, v interface{}) bool {
s.StreamManagerCache.Range(func(k, v interface{}) bool {
streamManagersAfterGC++
return true
})
Expand Down

0 comments on commit 3b6324f

Please sign in to comment.