Skip to content

Commit 36e9665

Browse files
committed
Create Credential Provider for EntraID
1 parent efe0f65 commit 36e9665

7 files changed

+422
-0
lines changed

entra_id/credentials_provider.go

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package entra_id
2+
3+
import (
4+
"log"
5+
"time"
6+
)
7+
8+
// EntraIdIdentityProvider defines the interface for an identity provider
9+
type EntraIdIdentityProvider interface {
10+
RequestToken(forceRefresh bool) (string, time.Duration, error)
11+
}
12+
13+
// EntraIdCredentialsProvider manages credentials and token lifecycle
14+
type EntraIdCredentialsProvider struct {
15+
tokenManager *TokenManager
16+
isStreaming bool
17+
}
18+
19+
// NewEntraIdCredentialsProvider initializes a new credentials provider
20+
func NewEntraIdCredentialsProvider(idp EntraIdIdentityProvider, refreshInterval time.Duration, telemetryEnabled bool) *EntraIdCredentialsProvider {
21+
refreshFunc := func() (string, time.Duration, error) {
22+
return idp.RequestToken(false)
23+
}
24+
25+
tokenManager := NewTokenManager(refreshFunc, refreshInterval, telemetryEnabled)
26+
27+
return &EntraIdCredentialsProvider{
28+
tokenManager: tokenManager,
29+
isStreaming: false,
30+
}
31+
}
32+
33+
// GetCredentials retrieves the current token or refreshes it if needed
34+
func (cp *EntraIdCredentialsProvider) GetCredentials() (string, error) {
35+
token, valid := cp.tokenManager.GetToken()
36+
if !valid {
37+
if err := cp.tokenManager.RefreshToken(); err != nil {
38+
log.Printf("[EntraIdCredentialsProvider] Failed to refresh token: %v", err)
39+
return "", err
40+
}
41+
token, _ = cp.tokenManager.GetToken()
42+
}
43+
44+
// Start streaming if not already started
45+
if !cp.isStreaming {
46+
cp.tokenManager.StartAutoRefresh()
47+
cp.isStreaming = true
48+
}
49+
50+
return token, nil
51+
}
52+
53+
// Stop stops the credentials provider and cleans up resources
54+
func (cp *EntraIdCredentialsProvider) Stop() {
55+
cp.tokenManager.StopAutoRefresh()
56+
log.Println("[EntraIdCredentialsProvider] Stopped and cleaned up resources.")
57+
}

entra_id/credentials_provider_test.go

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
package entra_id_test
2+
3+
import (
4+
"errors"
5+
"time"
6+
7+
. "github.com/bsm/ginkgo/v2"
8+
. "github.com/bsm/gomega"
9+
"github.com/go-redis/entra_id"
10+
)
11+
12+
type MockEntraIdIdentityProvider struct {
13+
Token string
14+
TTL time.Duration
15+
Error error
16+
}
17+
18+
func (m *MockEntraIdIdentityProvider) RequestToken(forceRefresh bool) (string, time.Duration, error) {
19+
if m.Error != nil {
20+
return "", 0, m.Error
21+
}
22+
return m.Token, m.TTL, nil
23+
}
24+
25+
var _ = Describe("EntraIdCredentialsProvider", func() {
26+
var (
27+
provider *entra_id.EntraIdCredentialsProvider
28+
mockIDP *MockEntraIdIdentityProvider
29+
refreshRate time.Duration
30+
)
31+
32+
BeforeEach(func() {
33+
refreshRate = 1 * time.Minute
34+
mockIDP = &MockEntraIdIdentityProvider{
35+
Token: "mock-token",
36+
TTL: 10 * time.Second,
37+
Error: nil,
38+
}
39+
provider = entra_id.NewEntraIdCredentialsProvider(mockIDP, refreshRate, true)
40+
})
41+
42+
AfterEach(func() {
43+
provider.Stop()
44+
})
45+
46+
Context("Initial Token Retrieval", func() {
47+
It("should retrieve a valid token from the identity provider", func() {
48+
token, err := provider.GetCredentials()
49+
Expect(err).To(BeNil())
50+
Expect(token).To(Equal("mock-token"))
51+
})
52+
53+
It("should return an error if the identity provider fails", func() {
54+
mockIDP.Error = errors.New("identity provider failure")
55+
_, err := provider.GetCredentials()
56+
Expect(err).To(HaveOccurred())
57+
Expect(err.Error()).To(ContainSubstring("identity provider failure"))
58+
})
59+
})
60+
61+
Context("Automatic Token Renewal", func() {
62+
It("should automatically refresh the token when it expires", func() {
63+
token, err := provider.GetCredentials()
64+
Expect(err).To(BeNil())
65+
Expect(token).To(Equal("mock-token"))
66+
67+
time.Sleep(11 * time.Second) // Wait for token expiry and auto-refresh
68+
69+
newToken, err := provider.GetCredentials()
70+
Expect(err).To(BeNil())
71+
Expect(newToken).To(Equal("mock-token")) // Mock still returns the same token
72+
})
73+
})
74+
75+
Context("Stop Streaming", func() {
76+
It("should stop token renewal and clean up resources when Stop is called", func() {
77+
provider.GetCredentials() // Start streaming
78+
// Ensure no further actions or panics occur after stopping, the stopping ocuur in the AfterEach
79+
})
80+
})
81+
})

entra_id/entra_id_suite_test.go

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package entra_id_test
2+
3+
import (
4+
"testing"
5+
6+
. "github.com/bsm/ginkgo/v2"
7+
. "github.com/bsm/gomega"
8+
)
9+
10+
func TestEntraId(t *testing.T) {
11+
RegisterFailHandler(Fail)
12+
RunSpecs(t, "EntraId Suite")
13+
}

entra_id/go.mod

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
module entra_id
2+
3+
go 1.22.0
4+
5+
toolchain go1.23.1
6+
7+
replace github.com/go-redis/entra_id => ./
8+
9+
require (
10+
github.com/bsm/ginkgo/v2 v2.12.0
11+
github.com/bsm/gomega v1.27.10
12+
github.com/go-redis/entra_id v0.0.0-00010101000000-000000000000
13+
)

entra_id/go.sum

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
2+
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
3+
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
4+
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=

entra_id/token_manager.go

+129
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
package entra_id
2+
3+
import (
4+
"log"
5+
"sync"
6+
"time"
7+
)
8+
9+
type TokenManager struct {
10+
token string
11+
expiresAt time.Time
12+
mutex sync.Mutex
13+
refreshFunc func() (string, time.Duration, error)
14+
stopChan chan struct{}
15+
refreshTicker *time.Ticker
16+
refreshInterval time.Duration
17+
telemetryEnabled bool
18+
}
19+
20+
// NewTokenManager initializes a new TokenManager.
21+
func NewTokenManager(refreshFunc func() (string, time.Duration, error), refreshInterval time.Duration, telemetryEnabled bool) *TokenManager {
22+
return &TokenManager{
23+
refreshFunc: refreshFunc,
24+
stopChan: make(chan struct{}),
25+
refreshInterval: refreshInterval,
26+
telemetryEnabled: telemetryEnabled,
27+
}
28+
}
29+
30+
// SetToken updates the token and its expiration.
31+
func (tm *TokenManager) SetToken(token string, ttl time.Duration) {
32+
tm.mutex.Lock()
33+
defer tm.mutex.Unlock()
34+
tm.token = token
35+
tm.expiresAt = time.Now().Add(ttl)
36+
log.Printf("[TokenManager] Token updated with TTL: %s", ttl)
37+
}
38+
39+
// GetToken returns the current token if it's still valid.
40+
func (tm *TokenManager) GetToken() (string, bool) {
41+
tm.mutex.Lock()
42+
defer tm.mutex.Unlock()
43+
if time.Now().After(tm.expiresAt) {
44+
return "", false
45+
}
46+
return tm.token, true
47+
}
48+
49+
// RefreshToken fetches a new token using the provided refresh function.
50+
func (tm *TokenManager) RefreshToken() error {
51+
if tm.refreshFunc == nil {
52+
return nil
53+
}
54+
token, ttl, err := tm.refreshFunc()
55+
if err != nil {
56+
log.Printf("[TokenManager] Failed to refresh token: %v", err)
57+
return err
58+
}
59+
tm.SetToken(token, ttl)
60+
log.Println("[TokenManager] Token refreshed successfully.")
61+
return nil
62+
}
63+
64+
// StartAutoRefresh starts a goroutine to proactively refresh the token.
65+
func (tm *TokenManager) StartAutoRefresh() {
66+
tm.refreshTicker = time.NewTicker(tm.refreshInterval)
67+
go func() {
68+
for {
69+
select {
70+
case <-tm.refreshTicker.C:
71+
if tm.shouldRefresh() {
72+
log.Println("[TokenManager] Proactively refreshing token...")
73+
if err := tm.RefreshToken(); err != nil {
74+
log.Printf("[TokenManager] Error during token refresh: %v", err)
75+
}
76+
}
77+
case <-tm.stopChan:
78+
log.Println("[TokenManager] Stopping auto-refresh...")
79+
return
80+
}
81+
}
82+
}()
83+
}
84+
85+
// StopAutoRefresh stops the auto-refresh goroutine and cleans up resources.
86+
func (tm *TokenManager) StopAutoRefresh() {
87+
if tm.refreshTicker != nil {
88+
tm.refreshTicker.Stop()
89+
}
90+
close(tm.stopChan)
91+
log.Println("[TokenManager] Auto-refresh stopped and resources cleaned.")
92+
}
93+
94+
// shouldRefresh determines if the token should be refreshed.
95+
func (tm *TokenManager) shouldRefresh() bool {
96+
tm.mutex.Lock()
97+
defer tm.mutex.Unlock()
98+
remaining := time.Until(tm.expiresAt)
99+
100+
// Trigger refresh when less than 20% of TTL remains
101+
return remaining < (tm.refreshInterval / 5)
102+
}
103+
104+
// MonitorTelemetry adds monitoring for token usage and expiration.
105+
func (tm *TokenManager) MonitorTelemetry() {
106+
if !tm.telemetryEnabled {
107+
return
108+
}
109+
110+
go func() {
111+
ticker := time.NewTicker(30 * time.Second) // Adjust as needed
112+
defer ticker.Stop()
113+
114+
for {
115+
select {
116+
case <-ticker.C:
117+
_, valid := tm.GetToken()
118+
if !valid {
119+
log.Println("[TokenManager] Token has expired.")
120+
} else {
121+
log.Printf("[TokenManager] Token is valid: expires in %s", time.Until(tm.expiresAt))
122+
}
123+
case <-tm.stopChan:
124+
log.Println("[TokenManager] Telemetry monitoring stopped.")
125+
return
126+
}
127+
}
128+
}()
129+
}

0 commit comments

Comments
 (0)