Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import BitwardenKit
import Foundation
import Networking

// MARK: - AccountTokenProvider
Expand All @@ -21,28 +23,34 @@ actor DefaultAccountTokenProvider: AccountTokenProvider {
private weak var accountTokenProviderDelegate: AccountTokenProviderDelegate?

/// The `HTTPService` used to make the API call to refresh the access token.
let httpService: HTTPService
private let httpService: HTTPService

/// The task associated with refreshing the token, if one is in progress.
private(set) var refreshTask: Task<String, Error>?

/// The service used to get the present time.
private let timeProvider: TimeProvider

/// The `TokenService` used to get the current tokens from.
let tokenService: TokenService
private let tokenService: TokenService

// MARK: Initialization

/// Initialize an `AccountTokenProvider`.
///
/// - Parameters:
/// - httpService: The service used to make the API call to refresh the access token.
/// - timeProvider: The service used to get the present time.
/// - tokenService: The service used to get the current tokens from.
///
init(
httpService: HTTPService,
timeProvider: TimeProvider = CurrentTime(),
tokenService: TokenService,
) {
self.tokenService = tokenService
self.httpService = httpService
self.timeProvider = timeProvider
self.tokenService = tokenService
}

// MARK: Methods
Expand All @@ -54,15 +62,19 @@ actor DefaultAccountTokenProvider: AccountTokenProvider {
return try await refreshTask.value
}

return try await tokenService.getAccessToken()
let token = try await tokenService.getAccessToken()
if await shouldRefresh(accessToken: token) {
return try await refreshToken()
} else {
return token
}
}

func refreshToken() async throws {
func refreshToken() async throws -> String {
if let refreshTask {
// If there's a refresh in progress, wait for it to complete rather than triggering
// another refresh.
_ = try await refreshTask.value
return
return try await refreshTask.value
}

let refreshTask = Task {
Expand All @@ -73,9 +85,12 @@ actor DefaultAccountTokenProvider: AccountTokenProvider {
let response = try await httpService.send(
IdentityTokenRefreshRequest(refreshToken: refreshToken),
)
let expirationDate = timeProvider.presentTime.addingTimeInterval(TimeInterval(response.expiresIn))

try await tokenService.setTokens(
accessToken: response.accessToken,
refreshToken: response.refreshToken,
expirationDate: expirationDate,
)

return response.accessToken
Expand All @@ -88,17 +103,35 @@ actor DefaultAccountTokenProvider: AccountTokenProvider {
}
self.refreshTask = refreshTask

_ = try await refreshTask.value
return try await refreshTask.value
}

func setDelegate(delegate: AccountTokenProviderDelegate) async {
accountTokenProviderDelegate = delegate
}

// MARK: Private

/// Returns whether the access token needs to be refreshed based on the last stored access token
/// expiration date. This is used to preemptively refresh the token prior to its expiration.
///
/// - Parameter accessToken: The access token to determine whether it needs to be refreshed.
/// - Returns: Whether the access token needs to be refreshed.
///
private func shouldRefresh(accessToken: String) async -> Bool {
guard let expirationDate = try? await tokenService.getAccessTokenExpirationDate() else {
// If there's no stored expiration date, don't preemptively refresh the token.
return false
}

let refreshThreshold = timeProvider.presentTime.addingTimeInterval(Constants.tokenRefreshThreshold)
return expirationDate <= refreshThreshold
}
}

/// Delegate to be used by the `AccountTokenProvider`.
protocol AccountTokenProviderDelegate: AnyObject {
/// Callbac to be used when an error is thrown when refreshing the access token.
/// Callback to be used when an error is thrown when refreshing the access token.
/// - Parameter error: `Error` thrown.
func onRefreshTokenError(error: Error) async throws
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import BitwardenKitMocks
import Networking
import TestHelpers
import XCTest
Expand All @@ -9,18 +10,25 @@ class AccountTokenProviderTests: BitwardenTestCase {

var client: MockHTTPClient!
var subject: DefaultAccountTokenProvider!
var timeProvider: MockTimeProvider!
var tokenService: MockTokenService!

let expirationDateExpired = Date(year: 2025, month: 10, day: 1, hour: 23, minute: 59, second: 0)
let expirationDateExpiringSoon = Date(year: 2025, month: 10, day: 2, hour: 0, minute: 2, second: 0)
let expirationDateUnexpired = Date(year: 2025, month: 10, day: 2, hour: 0, minute: 6, second: 0)

// MARK: Setup & Teardown

override func setUp() {
super.setUp()

client = MockHTTPClient()
timeProvider = MockTimeProvider(.mockTime(Date(year: 2025, month: 10, day: 2)))
tokenService = MockTokenService()

subject = DefaultAccountTokenProvider(
httpService: HTTPService(baseURL: URL(string: "https://example.com")!, client: client),
timeProvider: timeProvider,
tokenService: tokenService,
)
}
Expand All @@ -30,13 +38,55 @@ class AccountTokenProviderTests: BitwardenTestCase {

client = nil
subject = nil
timeProvider = nil
tokenService = nil
}

// MARK: Tests

/// `getToken()` returns the current access token.
func test_getToken() async throws {
/// `getToken()` returns the current access token if fetching the expiration date returns an error.
func test_getToken_tokenError() async throws {
tokenService.accessToken = "ACCESS_TOKEN"
tokenService.accessTokenExpirationDateResult = .failure(BitwardenTestError.example)

let token = try await subject.getToken()
XCTAssertEqual(token, "ACCESS_TOKEN")
}

/// `getToken()` returns a refreshed access token if the current one is expired.
func test_getToken_tokenExpired() async throws {
client.result = .httpSuccess(testData: .identityTokenRefresh)
tokenService.accessToken = "EXPIRED"
tokenService.accessTokenExpirationDateResult = .success(expirationDateExpired)

let token = try await subject.getToken()
XCTAssertEqual(token, "ACCESS_TOKEN")
}

/// `getToken()` returns a refreshed access token if the current one is expiring soon.
func test_getToken_tokenExpiringSoon() async throws {
client.result = .httpSuccess(testData: .identityTokenRefresh)
tokenService.accessToken = "EXPIRING_SOON"
tokenService.accessTokenExpirationDateResult = .success(expirationDateExpiringSoon)

let token = try await subject.getToken()
XCTAssertEqual(token, "ACCESS_TOKEN")
}

/// `getToken()` returns the current access token if it is unexpired.
func test_getToken_tokenUnexpired() async throws {
tokenService.accessToken = "ACCESS_TOKEN"
tokenService.accessTokenExpirationDateResult = .success(expirationDateUnexpired)

let token = try await subject.getToken()
XCTAssertEqual(token, "ACCESS_TOKEN")
}

/// `getToken()` returns the current access token if the expiration date doesn't yet exist.
func test_getToken_tokenNil() async throws {
tokenService.accessToken = "ACCESS_TOKEN"
tokenService.accessTokenExpirationDateResult = .success(nil)

let token = try await subject.getToken()
XCTAssertEqual(token, "ACCESS_TOKEN")
}
Expand All @@ -58,12 +108,12 @@ class AccountTokenProviderTests: BitwardenTestCase {

client.result = .httpSuccess(testData: .identityTokenRefresh)

try await subject.refreshToken()
let newAccessToken = try await subject.refreshToken()

let newAccessToken = try await subject.getToken()
XCTAssertEqual(newAccessToken, "ACCESS_TOKEN")
XCTAssertEqual(tokenService.accessToken, "ACCESS_TOKEN")
XCTAssertEqual(tokenService.refreshToken, "REFRESH_TOKEN")
XCTAssertEqual(tokenService.expirationDate, Date(year: 2025, month: 10, day: 2, hour: 1, minute: 0, second: 0))

let refreshTask = await subject.refreshTask
XCTAssertNil(refreshTask)
Expand All @@ -76,14 +126,15 @@ class AccountTokenProviderTests: BitwardenTestCase {

client.result = .httpSuccess(testData: .identityTokenRefresh)

async let refreshTask1: Void = subject.refreshToken()
async let refreshTask2: Void = subject.refreshToken()
async let refreshTask1: String = subject.refreshToken()
async let refreshTask2: String = subject.refreshToken()

_ = try await (refreshTask1, refreshTask2)

XCTAssertEqual(client.requests.count, 1)
XCTAssertEqual(tokenService.accessToken, "ACCESS_TOKEN")
XCTAssertEqual(tokenService.refreshToken, "REFRESH_TOKEN")
XCTAssertEqual(tokenService.expirationDate, Date(year: 2025, month: 10, day: 2, hour: 1, minute: 0, second: 0))

let refreshTask = await subject.refreshTask
XCTAssertNil(refreshTask)
Expand All @@ -101,7 +152,7 @@ class AccountTokenProviderTests: BitwardenTestCase {
client.result = .failure(BitwardenTestError.example)

await assertAsyncThrows(error: BitwardenTestError.example) {
try await subject.refreshToken()
_ = try await subject.refreshToken()
}
XCTAssertTrue(delegate.onRefreshTokenErrorCalled)
}
Expand All @@ -115,7 +166,7 @@ class AccountTokenProviderTests: BitwardenTestCase {
client.result = .failure(BitwardenTestError.example)

await assertAsyncThrows(error: BitwardenTestError.example) {
try await subject.refreshToken()
_ = try await subject.refreshToken()
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ protocol RefreshableAPIService { // sourcery: AutoMockable

extension APIService: RefreshableAPIService {
func refreshAccessToken() async throws {
try await accountTokenProvider.refreshToken()
_ = try await accountTokenProvider.refreshToken()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ class MockAccountTokenProvider: AccountTokenProvider {
var delegate: AccountTokenProviderDelegate?
var getTokenResult: Result<String, Error> = .success("ACCESS_TOKEN")
var refreshTokenCalled = false
var refreshTokenResult: Result<Void, Error> = .success(())
var refreshTokenResult: Result<String, Error> = .success("ACCESS_TOKEN")

func getToken() async throws -> String {
try getTokenResult.get()
}

func refreshToken() async throws {
func refreshToken() async throws -> String {
refreshTokenCalled = true
try refreshTokenResult.get()
return try refreshTokenResult.get()
}

func setDelegate(delegate: AccountTokenProviderDelegate) async {
Expand Down
40 changes: 40 additions & 0 deletions BitwardenShared/Core/Platform/Services/StateService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ protocol StateService: AnyObject {
///
func doesActiveAccountHavePremium() async -> Bool

/// Gets the access token's expiration date for an account.
///
/// - Parameter userId: The user ID associated with the access token expiration date.
/// - Returns: The user's access token expiration date.
///
func getAccessTokenExpirationDate(userId: String) async -> Date?

/// Gets the account for an id.
///
/// - Parameter userId: The id for an account. If nil, the active account will be returned.
Expand Down Expand Up @@ -429,6 +436,14 @@ protocol StateService: AnyObject {
///
func pinUnlockRequiresPasswordAfterRestart() async throws -> Bool

/// Sets the access token's expiration date for an account.
///
/// - Parameters:
/// - expirationDate: The user's access token expiration date.
/// - userId: The user ID associated with the access token expiration date.
///
func setAccessTokenExpirationDate(_ expirationDate: Date?, userId: String) async

/// Sets the account encryption keys for an account.
///
/// - Parameters:
Expand Down Expand Up @@ -855,6 +870,14 @@ extension StateService {
await setPendingAppIntentActions(actions: actions)
}

/// Gets the access token's expiration date for the active account.
///
/// - Returns: The user's access token expiration date.
///
func getAccessTokenExpirationDate() async throws -> Date? {
try await getAccessTokenExpirationDate(userId: getActiveAccountId())
}

/// Gets the account encryptions keys for the active account.
///
/// - Returns: The account encryption keys.
Expand Down Expand Up @@ -1143,6 +1166,14 @@ extension StateService {
try await pinProtectedUserKeyEnvelope(userId: nil)
}

/// Sets the access token's expiration date for the active account.
///
/// - Parameter expirationDate: The user's access token expiration date.
///
func setAccessTokenExpirationDate(_ expirationDate: Date?) async throws {
try await setAccessTokenExpirationDate(expirationDate, userId: getActiveAccountId())
}

/// Sets the account encryption keys for the active account.
///
/// - Parameter encryptionKeys: The account encryption keys.
Expand Down Expand Up @@ -1542,6 +1573,10 @@ actor DefaultStateService: StateService, ConfigStateService { // swiftlint:disab
}
}

func getAccessTokenExpirationDate(userId: String) -> Date? {
appSettingsStore.accessTokenExpirationDate(userId: userId)
}

func getAccount(userId: String?) throws -> Account {
guard let accounts = appSettingsStore.state?.accounts else {
throw StateServiceError.noAccounts
Expand Down Expand Up @@ -1844,6 +1879,7 @@ actor DefaultStateService: StateService, ConfigStateService { // swiftlint:disab
state.activeUserId = state.accounts.first?.key
}

appSettingsStore.setAccessTokenExpirationDate(nil, userId: knownUserId)
appSettingsStore.setBiometricAuthenticationEnabled(nil, for: knownUserId)
appSettingsStore.setDefaultUriMatchType(nil, userId: knownUserId)
appSettingsStore.setDisableAutoTotpCopy(nil, userId: knownUserId)
Expand Down Expand Up @@ -1876,6 +1912,10 @@ actor DefaultStateService: StateService, ConfigStateService { // swiftlint:disab
&& appSettingsStore.pinProtectedUserKey(userId: userId) == nil
}

func setAccessTokenExpirationDate(_ expirationDate: Date?, userId: String) async {
appSettingsStore.setAccessTokenExpirationDate(expirationDate, userId: userId)
}

func setAccountKdf(_ kdfConfig: KdfConfig, userId: String) async throws {
try updateAccountProfile(userId: userId) { profile in
profile.kdfType = kdfConfig.kdfType
Expand Down
Loading