diff --git a/BitwardenShared/Core/Platform/Services/API/AccountTokenProvider.swift b/BitwardenShared/Core/Platform/Services/API/AccountTokenProvider.swift index badbd05810..f76c45baa7 100644 --- a/BitwardenShared/Core/Platform/Services/API/AccountTokenProvider.swift +++ b/BitwardenShared/Core/Platform/Services/API/AccountTokenProvider.swift @@ -1,3 +1,5 @@ +import BitwardenKit +import Foundation import Networking // MARK: - AccountTokenProvider @@ -21,13 +23,16 @@ actor DefaultAccountTokenProvider: AccountTokenProvider { private weak var accountTokenProviderDelegate: AccountTokenProviderDelegate? /// The `HTTPService` used to make the API call to refresh the access token. - let httpService: HTTPService + private let httpService: HTTPService /// The task associated with refreshing the token, if one is in progress. private(set) var refreshTask: Task? + /// The service used to get the present time. + private let timeProvider: TimeProvider + /// The `TokenService` used to get the current tokens from. - let tokenService: TokenService + private let tokenService: TokenService // MARK: Initialization @@ -35,14 +40,17 @@ actor DefaultAccountTokenProvider: AccountTokenProvider { /// /// - Parameters: /// - httpService: The service used to make the API call to refresh the access token. + /// - timeProvider: The service used to get the present time. /// - tokenService: The service used to get the current tokens from. /// init( httpService: HTTPService, + timeProvider: TimeProvider = CurrentTime(), tokenService: TokenService, ) { - self.tokenService = tokenService self.httpService = httpService + self.timeProvider = timeProvider + self.tokenService = tokenService } // MARK: Methods @@ -54,15 +62,19 @@ actor DefaultAccountTokenProvider: AccountTokenProvider { return try await refreshTask.value } - return try await tokenService.getAccessToken() + let token = try await tokenService.getAccessToken() + if await shouldRefresh(accessToken: token) { + return try await refreshToken() + } else { + return token + } } - func refreshToken() async throws { + func refreshToken() async throws -> String { if let refreshTask { // If there's a refresh in progress, wait for it to complete rather than triggering // another refresh. - _ = try await refreshTask.value - return + return try await refreshTask.value } let refreshTask = Task { @@ -73,9 +85,12 @@ actor DefaultAccountTokenProvider: AccountTokenProvider { let response = try await httpService.send( IdentityTokenRefreshRequest(refreshToken: refreshToken), ) + let expirationDate = timeProvider.presentTime.addingTimeInterval(TimeInterval(response.expiresIn)) + try await tokenService.setTokens( accessToken: response.accessToken, refreshToken: response.refreshToken, + expirationDate: expirationDate, ) return response.accessToken @@ -88,17 +103,35 @@ actor DefaultAccountTokenProvider: AccountTokenProvider { } self.refreshTask = refreshTask - _ = try await refreshTask.value + return try await refreshTask.value } func setDelegate(delegate: AccountTokenProviderDelegate) async { accountTokenProviderDelegate = delegate } + + // MARK: Private + + /// Returns whether the access token needs to be refreshed based on the last stored access token + /// expiration date. This is used to preemptively refresh the token prior to its expiration. + /// + /// - Parameter accessToken: The access token to determine whether it needs to be refreshed. + /// - Returns: Whether the access token needs to be refreshed. + /// + private func shouldRefresh(accessToken: String) async -> Bool { + guard let expirationDate = try? await tokenService.getAccessTokenExpirationDate() else { + // If there's no stored expiration date, don't preemptively refresh the token. + return false + } + + let refreshThreshold = timeProvider.presentTime.addingTimeInterval(Constants.tokenRefreshThreshold) + return expirationDate <= refreshThreshold + } } /// Delegate to be used by the `AccountTokenProvider`. protocol AccountTokenProviderDelegate: AnyObject { - /// Callbac to be used when an error is thrown when refreshing the access token. + /// Callback to be used when an error is thrown when refreshing the access token. /// - Parameter error: `Error` thrown. func onRefreshTokenError(error: Error) async throws } diff --git a/BitwardenShared/Core/Platform/Services/API/AccountTokenProviderTests.swift b/BitwardenShared/Core/Platform/Services/API/AccountTokenProviderTests.swift index a8f188933a..7120b8ac68 100644 --- a/BitwardenShared/Core/Platform/Services/API/AccountTokenProviderTests.swift +++ b/BitwardenShared/Core/Platform/Services/API/AccountTokenProviderTests.swift @@ -1,3 +1,4 @@ +import BitwardenKitMocks import Networking import TestHelpers import XCTest @@ -9,18 +10,25 @@ class AccountTokenProviderTests: BitwardenTestCase { var client: MockHTTPClient! var subject: DefaultAccountTokenProvider! + var timeProvider: MockTimeProvider! var tokenService: MockTokenService! + let expirationDateExpired = Date(year: 2025, month: 10, day: 1, hour: 23, minute: 59, second: 0) + let expirationDateExpiringSoon = Date(year: 2025, month: 10, day: 2, hour: 0, minute: 2, second: 0) + let expirationDateUnexpired = Date(year: 2025, month: 10, day: 2, hour: 0, minute: 6, second: 0) + // MARK: Setup & Teardown override func setUp() { super.setUp() client = MockHTTPClient() + timeProvider = MockTimeProvider(.mockTime(Date(year: 2025, month: 10, day: 2))) tokenService = MockTokenService() subject = DefaultAccountTokenProvider( httpService: HTTPService(baseURL: URL(string: "https://example.com")!, client: client), + timeProvider: timeProvider, tokenService: tokenService, ) } @@ -30,13 +38,55 @@ class AccountTokenProviderTests: BitwardenTestCase { client = nil subject = nil + timeProvider = nil tokenService = nil } // MARK: Tests - /// `getToken()` returns the current access token. - func test_getToken() async throws { + /// `getToken()` returns the current access token if fetching the expiration date returns an error. + func test_getToken_tokenError() async throws { + tokenService.accessToken = "ACCESS_TOKEN" + tokenService.accessTokenExpirationDateResult = .failure(BitwardenTestError.example) + + let token = try await subject.getToken() + XCTAssertEqual(token, "ACCESS_TOKEN") + } + + /// `getToken()` returns a refreshed access token if the current one is expired. + func test_getToken_tokenExpired() async throws { + client.result = .httpSuccess(testData: .identityTokenRefresh) + tokenService.accessToken = "EXPIRED" + tokenService.accessTokenExpirationDateResult = .success(expirationDateExpired) + + let token = try await subject.getToken() + XCTAssertEqual(token, "ACCESS_TOKEN") + } + + /// `getToken()` returns a refreshed access token if the current one is expiring soon. + func test_getToken_tokenExpiringSoon() async throws { + client.result = .httpSuccess(testData: .identityTokenRefresh) + tokenService.accessToken = "EXPIRING_SOON" + tokenService.accessTokenExpirationDateResult = .success(expirationDateExpiringSoon) + + let token = try await subject.getToken() + XCTAssertEqual(token, "ACCESS_TOKEN") + } + + /// `getToken()` returns the current access token if it is unexpired. + func test_getToken_tokenUnexpired() async throws { + tokenService.accessToken = "ACCESS_TOKEN" + tokenService.accessTokenExpirationDateResult = .success(expirationDateUnexpired) + + let token = try await subject.getToken() + XCTAssertEqual(token, "ACCESS_TOKEN") + } + + /// `getToken()` returns the current access token if the expiration date doesn't yet exist. + func test_getToken_tokenNil() async throws { + tokenService.accessToken = "ACCESS_TOKEN" + tokenService.accessTokenExpirationDateResult = .success(nil) + let token = try await subject.getToken() XCTAssertEqual(token, "ACCESS_TOKEN") } @@ -58,12 +108,12 @@ class AccountTokenProviderTests: BitwardenTestCase { client.result = .httpSuccess(testData: .identityTokenRefresh) - try await subject.refreshToken() + let newAccessToken = try await subject.refreshToken() - let newAccessToken = try await subject.getToken() XCTAssertEqual(newAccessToken, "ACCESS_TOKEN") XCTAssertEqual(tokenService.accessToken, "ACCESS_TOKEN") XCTAssertEqual(tokenService.refreshToken, "REFRESH_TOKEN") + XCTAssertEqual(tokenService.expirationDate, Date(year: 2025, month: 10, day: 2, hour: 1, minute: 0, second: 0)) let refreshTask = await subject.refreshTask XCTAssertNil(refreshTask) @@ -76,14 +126,15 @@ class AccountTokenProviderTests: BitwardenTestCase { client.result = .httpSuccess(testData: .identityTokenRefresh) - async let refreshTask1: Void = subject.refreshToken() - async let refreshTask2: Void = subject.refreshToken() + async let refreshTask1: String = subject.refreshToken() + async let refreshTask2: String = subject.refreshToken() _ = try await (refreshTask1, refreshTask2) XCTAssertEqual(client.requests.count, 1) XCTAssertEqual(tokenService.accessToken, "ACCESS_TOKEN") XCTAssertEqual(tokenService.refreshToken, "REFRESH_TOKEN") + XCTAssertEqual(tokenService.expirationDate, Date(year: 2025, month: 10, day: 2, hour: 1, minute: 0, second: 0)) let refreshTask = await subject.refreshTask XCTAssertNil(refreshTask) @@ -101,7 +152,7 @@ class AccountTokenProviderTests: BitwardenTestCase { client.result = .failure(BitwardenTestError.example) await assertAsyncThrows(error: BitwardenTestError.example) { - try await subject.refreshToken() + _ = try await subject.refreshToken() } XCTAssertTrue(delegate.onRefreshTokenErrorCalled) } @@ -115,7 +166,7 @@ class AccountTokenProviderTests: BitwardenTestCase { client.result = .failure(BitwardenTestError.example) await assertAsyncThrows(error: BitwardenTestError.example) { - try await subject.refreshToken() + _ = try await subject.refreshToken() } } } diff --git a/BitwardenShared/Core/Platform/Services/API/RefreshableAPIService.swift b/BitwardenShared/Core/Platform/Services/API/RefreshableAPIService.swift index d09b50abaa..27db0a594c 100644 --- a/BitwardenShared/Core/Platform/Services/API/RefreshableAPIService.swift +++ b/BitwardenShared/Core/Platform/Services/API/RefreshableAPIService.swift @@ -9,6 +9,6 @@ protocol RefreshableAPIService { // sourcery: AutoMockable extension APIService: RefreshableAPIService { func refreshAccessToken() async throws { - try await accountTokenProvider.refreshToken() + _ = try await accountTokenProvider.refreshToken() } } diff --git a/BitwardenShared/Core/Platform/Services/API/TestHelpers/MockAccountTokenProvider.swift b/BitwardenShared/Core/Platform/Services/API/TestHelpers/MockAccountTokenProvider.swift index 71b3ecae63..69f24247b5 100644 --- a/BitwardenShared/Core/Platform/Services/API/TestHelpers/MockAccountTokenProvider.swift +++ b/BitwardenShared/Core/Platform/Services/API/TestHelpers/MockAccountTokenProvider.swift @@ -9,15 +9,15 @@ class MockAccountTokenProvider: AccountTokenProvider { var delegate: AccountTokenProviderDelegate? var getTokenResult: Result = .success("ACCESS_TOKEN") var refreshTokenCalled = false - var refreshTokenResult: Result = .success(()) + var refreshTokenResult: Result = .success("ACCESS_TOKEN") func getToken() async throws -> String { try getTokenResult.get() } - func refreshToken() async throws { + func refreshToken() async throws -> String { refreshTokenCalled = true - try refreshTokenResult.get() + return try refreshTokenResult.get() } func setDelegate(delegate: AccountTokenProviderDelegate) async { diff --git a/BitwardenShared/Core/Platform/Services/StateService.swift b/BitwardenShared/Core/Platform/Services/StateService.swift index 4ef1f05733..cd43b6a3d3 100644 --- a/BitwardenShared/Core/Platform/Services/StateService.swift +++ b/BitwardenShared/Core/Platform/Services/StateService.swift @@ -44,6 +44,13 @@ protocol StateService: AnyObject { /// func doesActiveAccountHavePremium() async -> Bool + /// Gets the access token's expiration date for an account. + /// + /// - Parameter userId: The user ID associated with the access token expiration date. + /// - Returns: The user's access token expiration date. + /// + func getAccessTokenExpirationDate(userId: String) async -> Date? + /// Gets the account for an id. /// /// - Parameter userId: The id for an account. If nil, the active account will be returned. @@ -429,6 +436,14 @@ protocol StateService: AnyObject { /// func pinUnlockRequiresPasswordAfterRestart() async throws -> Bool + /// Sets the access token's expiration date for an account. + /// + /// - Parameters: + /// - expirationDate: The user's access token expiration date. + /// - userId: The user ID associated with the access token expiration date. + /// + func setAccessTokenExpirationDate(_ expirationDate: Date?, userId: String) async + /// Sets the account encryption keys for an account. /// /// - Parameters: @@ -855,6 +870,14 @@ extension StateService { await setPendingAppIntentActions(actions: actions) } + /// Gets the access token's expiration date for the active account. + /// + /// - Returns: The user's access token expiration date. + /// + func getAccessTokenExpirationDate() async throws -> Date? { + try await getAccessTokenExpirationDate(userId: getActiveAccountId()) + } + /// Gets the account encryptions keys for the active account. /// /// - Returns: The account encryption keys. @@ -1143,6 +1166,14 @@ extension StateService { try await pinProtectedUserKeyEnvelope(userId: nil) } + /// Sets the access token's expiration date for the active account. + /// + /// - Parameter expirationDate: The user's access token expiration date. + /// + func setAccessTokenExpirationDate(_ expirationDate: Date?) async throws { + try await setAccessTokenExpirationDate(expirationDate, userId: getActiveAccountId()) + } + /// Sets the account encryption keys for the active account. /// /// - Parameter encryptionKeys: The account encryption keys. @@ -1542,6 +1573,10 @@ actor DefaultStateService: StateService, ConfigStateService { // swiftlint:disab } } + func getAccessTokenExpirationDate(userId: String) -> Date? { + appSettingsStore.accessTokenExpirationDate(userId: userId) + } + func getAccount(userId: String?) throws -> Account { guard let accounts = appSettingsStore.state?.accounts else { throw StateServiceError.noAccounts @@ -1844,6 +1879,7 @@ actor DefaultStateService: StateService, ConfigStateService { // swiftlint:disab state.activeUserId = state.accounts.first?.key } + appSettingsStore.setAccessTokenExpirationDate(nil, userId: knownUserId) appSettingsStore.setBiometricAuthenticationEnabled(nil, for: knownUserId) appSettingsStore.setDefaultUriMatchType(nil, userId: knownUserId) appSettingsStore.setDisableAutoTotpCopy(nil, userId: knownUserId) @@ -1876,6 +1912,10 @@ actor DefaultStateService: StateService, ConfigStateService { // swiftlint:disab && appSettingsStore.pinProtectedUserKey(userId: userId) == nil } + func setAccessTokenExpirationDate(_ expirationDate: Date?, userId: String) async { + appSettingsStore.setAccessTokenExpirationDate(expirationDate, userId: userId) + } + func setAccountKdf(_ kdfConfig: KdfConfig, userId: String) async throws { try updateAccountProfile(userId: userId) { profile in profile.kdfType = kdfConfig.kdfType diff --git a/BitwardenShared/Core/Platform/Services/StateServiceTests.swift b/BitwardenShared/Core/Platform/Services/StateServiceTests.swift index 8263e9d656..ef50f26e69 100644 --- a/BitwardenShared/Core/Platform/Services/StateServiceTests.swift +++ b/BitwardenShared/Core/Platform/Services/StateServiceTests.swift @@ -290,6 +290,29 @@ class StateServiceTests: BitwardenTestCase { // swiftlint:disable:this type_body XCTAssertEqual(errorReporter.errors as? [StateServiceError], [.noActiveAccount]) } + /// `getAccessTokenExpirationDate(userId:)` gets the user's access token expiration date. + func test_getAccessTokenExpirationDate() async throws { + let date1 = Date(year: 2025, month: 1, day: 1) + let date2 = Date(year: 2026, month: 6, day: 1) + appSettingsStore.accessTokenExpirationDateByUserId["1"] = date1 + appSettingsStore.accessTokenExpirationDateByUserId["2"] = date2 + + await subject.addAccount(.fixture(profile: .fixture(userId: "1"))) + await subject.addAccount(.fixture(profile: .fixture(userId: "2"))) + + let expirationDate1 = await subject.getAccessTokenExpirationDate(userId: "1") + XCTAssertEqual(expirationDate1, date1) + let expirationDate2 = try await subject.getAccessTokenExpirationDate() + XCTAssertEqual(expirationDate2, date2) + } + + /// `getAccessTokenExpirationDate(userId:)` throws an error if there's no accounts. + func test_getAccessTokenExpirationDate_noAccount() async throws { + await assertAsyncThrows(error: StateServiceError.noActiveAccount) { + _ = try await subject.getAccessTokenExpirationDate() + } + } + /// `getAccountEncryptionKeys(_:)` returns the encryption keys for the user account. func test_getAccountEncryptionKeys() async throws { appSettingsStore.accountKeys["1"] = .fixture( @@ -1681,6 +1704,29 @@ class StateServiceTests: BitwardenTestCase { // swiftlint:disable:this type_body XCTAssertTrue(result == true) } + /// `setAccessTokenExpirationDate(_:userId:)` sets the access token expiration date for the account. + func test_setAccessTokenExpirationDate() async throws { + let date1 = Date(year: 2025, month: 1, day: 1) + let date2 = Date(year: 2026, month: 6, day: 1) + await subject.addAccount(.fixture(profile: .fixture(userId: "1"))) + await subject.addAccount(.fixture(profile: .fixture(userId: "2"))) + + await subject.setAccessTokenExpirationDate(date1, userId: "1") + try await subject.setAccessTokenExpirationDate(date2) + + XCTAssertEqual( + appSettingsStore.accessTokenExpirationDateByUserId, + ["1": date1, "2": date2], + ) + } + + /// `setAccessTokenExpirationDate(_:userId:)` throws an error if there's no accounts. + func test_setAccessTokenExpirationDate_noAccounts() async throws { + await assertAsyncThrows(error: StateServiceError.noActiveAccount) { + _ = try await subject.setAccessTokenExpirationDate(.now) + } + } + /// `setAccountEncryptionKeys(_:userId:)` sets the encryption keys for the user account. func test_setAccountEncryptionKeys() async throws { await subject.addAccount(.fixture(profile: .fixture(userId: "1"))) diff --git a/BitwardenShared/Core/Platform/Services/Stores/AppSettingsStore.swift b/BitwardenShared/Core/Platform/Services/Stores/AppSettingsStore.swift index a01f846044..0d64aed3c6 100644 --- a/BitwardenShared/Core/Platform/Services/Stores/AppSettingsStore.swift +++ b/BitwardenShared/Core/Platform/Services/Stores/AppSettingsStore.swift @@ -73,6 +73,13 @@ protocol AppSettingsStore: AnyObject { /// The app's account state. var state: State? { get set } + /// The user's access token expiration date. + /// + /// - Parameter userId: The user ID associated with the access token expiration date. + /// - Returns: The user's access token expiration date. + /// + func accessTokenExpirationDate(userId: String) -> Date? + /// The user's v2 account keys. /// /// - Parameter userId: The user ID associated with the stored account keys. @@ -265,6 +272,14 @@ protocol AppSettingsStore: AnyObject { /// - Returns: The server config for that user ID. func serverConfig(userId: String) -> ServerConfig? + /// Sets the user's access token expiration date + /// + /// - Parameters: + /// - expirationDate: The user's access token expiration date + /// - userId: The user ID associated with the access token expiration date. + /// + func setAccessTokenExpirationDate(_ expirationDate: Date?, userId: String) + /// Sets the account v2 keys for a user ID. /// /// - Parameters: @@ -740,6 +755,7 @@ extension DefaultAppSettingsStore: AppSettingsStore, ConfigSettingsStore { /// The keys used to store their associated values. /// enum Keys { + case accessTokenExpirationDate(userId: String) case accountKeys(userId: String) case accountSetupAutofill(userId: String) case accountSetupImportLogins(userId: String) @@ -800,6 +816,8 @@ extension DefaultAppSettingsStore: AppSettingsStore, ConfigSettingsStore { /// Returns the key used to store the data under for retrieving it later. var storageKey: String { let key = switch self { + case let .accessTokenExpirationDate(userId): + "accessTokenExpirationDate_\(userId)" case let .accountKeys(userId): "accountKeys_\(userId)" case let .accountSetupAutofill(userId): @@ -1019,6 +1037,10 @@ extension DefaultAppSettingsStore: AppSettingsStore, ConfigSettingsStore { } } + func accessTokenExpirationDate(userId: String) -> Date? { + fetch(for: .accessTokenExpirationDate(userId: userId)) + } + func accountKeys(userId: String) -> PrivateKeysResponseModel? { fetch(for: .accountKeys(userId: userId)) } @@ -1141,6 +1163,10 @@ extension DefaultAppSettingsStore: AppSettingsStore, ConfigSettingsStore { fetch(for: .serverConfig(userId: userId)) } + func setAccessTokenExpirationDate(_ expirationDate: Date?, userId: String) { + store(expirationDate, for: .accessTokenExpirationDate(userId: userId)) + } + func setAccountKeys(_ keys: PrivateKeysResponseModel?, userId: String) { store(keys, for: .accountKeys(userId: userId)) } diff --git a/BitwardenShared/Core/Platform/Services/Stores/AppSettingsStoreTests.swift b/BitwardenShared/Core/Platform/Services/Stores/AppSettingsStoreTests.swift index 2a5b3ab548..b36dc4f536 100644 --- a/BitwardenShared/Core/Platform/Services/Stores/AppSettingsStoreTests.swift +++ b/BitwardenShared/Core/Platform/Services/Stores/AppSettingsStoreTests.swift @@ -39,6 +39,24 @@ class AppSettingsStoreTests: BitwardenTestCase { // swiftlint:disable:this type_ // MARK: Tests + /// `accessTokenExpirationDate(userId:)` returns `nil` if there isn't a previously stored value. + func test_accessTokenExpirationDate_isInitiallyNil() { + XCTAssertNil(subject.accessTokenExpirationDate(userId: "-1")) + } + + /// `accessTokenExpirationDate(userId:)` can be used to get the user's access token expiration date. + func test_accessTokenExpirationDate_withValue() { + let date1 = Date(year: 2025, month: 10, day: 1) + let date2 = Date(year: 2026, month: 1, day: 2) + subject.setAccessTokenExpirationDate(date1, userId: "1") + subject.setAccessTokenExpirationDate(date2, userId: "2") + + XCTAssertEqual(subject.accessTokenExpirationDate(userId: "1"), date1) + XCTAssertEqual(subject.accessTokenExpirationDate(userId: "2"), date2) + XCTAssertEqual(userDefaults.integer(forKey: "bwPreferencesStorage:accessTokenExpirationDate_1"), 780_969_600) + XCTAssertEqual(userDefaults.integer(forKey: "bwPreferencesStorage:accessTokenExpirationDate_2"), 789_004_800) + } + /// `accountKeys(userId:)` returns `nil` if there isn't a previously stored value. func test_accountKeys_isInitiallyNil() { XCTAssertNil(subject.accountKeys(userId: "-1")) diff --git a/BitwardenShared/Core/Platform/Services/Stores/TestHelpers/MockAppSettingsStore.swift b/BitwardenShared/Core/Platform/Services/Stores/TestHelpers/MockAppSettingsStore.swift index ebe81316b9..6a345c08d2 100644 --- a/BitwardenShared/Core/Platform/Services/Stores/TestHelpers/MockAppSettingsStore.swift +++ b/BitwardenShared/Core/Platform/Services/Stores/TestHelpers/MockAppSettingsStore.swift @@ -7,6 +7,7 @@ import Foundation // swiftlint:disable file_length class MockAppSettingsStore: AppSettingsStore { // swiftlint:disable:this type_body_length + var accessTokenExpirationDateByUserId = [String: Date]() var accountKeys = [String: PrivateKeysResponseModel]() var accountSetupAutofill = [String: AccountSetupProgress]() var accountSetupImportLogins = [String: AccountSetupProgress]() @@ -74,6 +75,10 @@ class MockAppSettingsStore: AppSettingsStore { // swiftlint:disable:this type_bo var activeIdSubject = CurrentValueSubject(nil) + func accessTokenExpirationDate(userId: String) -> Date? { + accessTokenExpirationDateByUserId[userId] + } + func accountKeys(userId: String) -> PrivateKeysResponseModel? { accountKeys[userId] } @@ -191,6 +196,10 @@ class MockAppSettingsStore: AppSettingsStore { // swiftlint:disable:this type_bo serverConfig[userId] } + func setAccessTokenExpirationDate(_ expirationDate: Date?, userId: String) { + accessTokenExpirationDateByUserId[userId] = expirationDate + } + func setAccountKeys(_ keys: BitwardenShared.PrivateKeysResponseModel?, userId: String) { accountKeys[userId] = keys } diff --git a/BitwardenShared/Core/Platform/Services/TestHelpers/MockStateService.swift b/BitwardenShared/Core/Platform/Services/TestHelpers/MockStateService.swift index ca70fa6f54..5a9d2b8f9d 100644 --- a/BitwardenShared/Core/Platform/Services/TestHelpers/MockStateService.swift +++ b/BitwardenShared/Core/Platform/Services/TestHelpers/MockStateService.swift @@ -7,6 +7,7 @@ import Foundation @testable import BitwardenShared class MockStateService: StateService { // swiftlint:disable:this type_body_length + var accessTokenExpirationDateByUserId = [String: Date]() var accountEncryptionKeys = [String: AccountEncryptionKeys]() var accountSetupAutofill = [String: AccountSetupProgress]() var accountSetupAutofillError: Error? @@ -138,6 +139,10 @@ class MockStateService: StateService { // swiftlint:disable:this type_body_lengt }) } + func getAccessTokenExpirationDate(userId: String) async -> Date? { + accessTokenExpirationDateByUserId[userId] + } + func didAccountSwitchInExtension() async throws -> Bool { try didAccountSwitchInExtensionResult.get() } @@ -445,6 +450,10 @@ class MockStateService: StateService { // swiftlint:disable:this type_body_lengt pinUnlockRequiresPasswordAfterRestartValue } + func setAccessTokenExpirationDate(_ expirationDate: Date?, userId: String) async { + accessTokenExpirationDateByUserId[userId] = expirationDate + } + func setAccountEncryptionKeys(_ encryptionKeys: AccountEncryptionKeys, userId: String?) async throws { let userId = try unwrapUserId(userId) accountEncryptionKeys[userId] = encryptionKeys diff --git a/BitwardenShared/Core/Platform/Services/TestHelpers/MockTokenService.swift b/BitwardenShared/Core/Platform/Services/TestHelpers/MockTokenService.swift index 852917071e..f739f53ee0 100644 --- a/BitwardenShared/Core/Platform/Services/TestHelpers/MockTokenService.swift +++ b/BitwardenShared/Core/Platform/Services/TestHelpers/MockTokenService.swift @@ -1,9 +1,12 @@ +import Foundation import Networking @testable import BitwardenShared class MockTokenService: TokenService { var accessToken: String? = "ACCESS_TOKEN" + var accessTokenExpirationDateResult: Result = .success(nil) + var expirationDate: Date? var getIsExternalResult: Result = .success(false) var refreshToken: String? = "REFRESH_TOKEN" @@ -12,6 +15,10 @@ class MockTokenService: TokenService { return accessToken } + func getAccessTokenExpirationDate() async throws -> Date? { + try accessTokenExpirationDateResult.get() + } + func getIsExternal() async throws -> Bool { try getIsExternalResult.get() } @@ -21,8 +28,9 @@ class MockTokenService: TokenService { return refreshToken } - func setTokens(accessToken: String, refreshToken: String) async { + func setTokens(accessToken: String, refreshToken: String, expirationDate: Date) async { self.accessToken = accessToken self.refreshToken = refreshToken + self.expirationDate = expirationDate } } diff --git a/BitwardenShared/Core/Platform/Services/TokenService.swift b/BitwardenShared/Core/Platform/Services/TokenService.swift index 2ef9a9f9db..306e0bdabb 100644 --- a/BitwardenShared/Core/Platform/Services/TokenService.swift +++ b/BitwardenShared/Core/Platform/Services/TokenService.swift @@ -1,5 +1,6 @@ import BitwardenKit import BitwardenSdk +import Foundation /// A protocol for a `TokenService` which manages accessing and updating the active account's tokens. /// @@ -10,6 +11,12 @@ protocol TokenService: AnyObject { /// func getAccessToken() async throws -> String + /// Returns the access token's expiration date for the current account. + /// + /// - Returns: The access token's expiration date for the current account. + /// + func getAccessTokenExpirationDate() async throws -> Date? + /// Returns whether the user is an external user. /// /// - Returns: Whether the user is an external user. @@ -27,8 +34,9 @@ protocol TokenService: AnyObject { /// - Parameters: /// - accessToken: The account's updated access token. /// - refreshToken: The account's updated refresh token. + /// - expirationDate: The access token's expiration date. /// - func setTokens(accessToken: String, refreshToken: String) async throws + func setTokens(accessToken: String, refreshToken: String, expirationDate: Date) async throws } // MARK: - DefaultTokenService @@ -73,6 +81,10 @@ actor DefaultTokenService: TokenService { return try await keychainRepository.getAccessToken(userId: userId) } + func getAccessTokenExpirationDate() async throws -> Date? { + try await stateService.getAccessTokenExpirationDate() + } + func getIsExternal() async throws -> Bool { let accessToken: String = try await getAccessToken() let tokenPayload = try TokenParser.parseToken(accessToken) @@ -84,10 +96,11 @@ actor DefaultTokenService: TokenService { return try await keychainRepository.getRefreshToken(userId: userId) } - func setTokens(accessToken: String, refreshToken: String) async throws { + func setTokens(accessToken: String, refreshToken: String, expirationDate: Date) async throws { let userId = try await stateService.getActiveAccountId() try await keychainRepository.setAccessToken(accessToken, userId: userId) try await keychainRepository.setRefreshToken(refreshToken, userId: userId) + await stateService.setAccessTokenExpirationDate(expirationDate, userId: userId) } } diff --git a/BitwardenShared/Core/Platform/Services/TokenServiceTests.swift b/BitwardenShared/Core/Platform/Services/TokenServiceTests.swift index 45a53a5e67..1141b53025 100644 --- a/BitwardenShared/Core/Platform/Services/TokenServiceTests.swift +++ b/BitwardenShared/Core/Platform/Services/TokenServiceTests.swift @@ -69,6 +69,22 @@ class TokenServiceTests: BitwardenTestCase { XCTAssertNil(accessToken) } + /// `getAccessTokenExpirationDate()` returns the access token's expiration date. + func test_getAccessTokenExpirationDate() async throws { + stateService.accessTokenExpirationDateByUserId["1"] = Date(year: 2025, month: 10, day: 2) + stateService.activeAccount = .fixture() + + let expirationDate = try await subject.getAccessTokenExpirationDate() + XCTAssertEqual(expirationDate, Date(year: 2025, month: 10, day: 2)) + } + + /// `getAccessTokenExpirationDate()` throws an error if there isn't an active account. + func test_getAccessTokenExpirationDate_error() async throws { + await assertAsyncThrows(error: StateServiceError.noActiveAccount) { + _ = try await subject.getAccessTokenExpirationDate() + } + } + /// `getIsExternal()` returns false if the user isn't an external user. func test_getIsExternal_false() async throws { // swiftlint:disable:next line_length @@ -144,7 +160,8 @@ class TokenServiceTests: BitwardenTestCase { func test_setTokens() async throws { stateService.activeAccount = .fixture() - try await subject.setTokens(accessToken: "🔑", refreshToken: "🔒") + let expirationDate = Date(year: 2025, month: 10, day: 1) + try await subject.setTokens(accessToken: "🔑", refreshToken: "🔒", expirationDate: expirationDate) XCTAssertEqual( keychainRepository.mockStorage[keychainRepository.formattedKey(for: .accessToken(userId: "1"))], @@ -154,5 +171,6 @@ class TokenServiceTests: BitwardenTestCase { keychainRepository.mockStorage[keychainRepository.formattedKey(for: .refreshToken(userId: "1"))], "🔒", ) + XCTAssertEqual(stateService.accessTokenExpirationDateByUserId["1"], expirationDate) } } diff --git a/BitwardenShared/Core/Platform/Utilities/Constants.swift b/BitwardenShared/Core/Platform/Utilities/Constants.swift index 903897b0c9..1e49c2e268 100644 --- a/BitwardenShared/Core/Platform/Utilities/Constants.swift +++ b/BitwardenShared/Core/Platform/Utilities/Constants.swift @@ -87,6 +87,10 @@ extension Constants { /// The default number of KDF iterations to perform. static let pbkdf2Iterations = 600_000 + + /// The number of seconds before an access token's expiration time at which the app will + /// preemptively refresh the token. + static let tokenRefreshThreshold: TimeInterval = 5 * 60 // 5 minutes } // MARK: Extension Constants diff --git a/BitwardenShared/UI/Platform/Application/AppProcessor.swift b/BitwardenShared/UI/Platform/Application/AppProcessor.swift index 41444e433e..a8f27dc7e3 100644 --- a/BitwardenShared/UI/Platform/Application/AppProcessor.swift +++ b/BitwardenShared/UI/Platform/Application/AppProcessor.swift @@ -682,6 +682,8 @@ extension AppProcessor: AccountTokenProviderDelegate { func onRefreshTokenError(error: any Error) async throws { if case IdentityTokenRefreshRequestError.invalidGrant = error { await logOutAutomatically() + } else if let error = error as? ResponseValidationError, [401, 403].contains(error.response.statusCode) { + await logOutAutomatically() } } } diff --git a/BitwardenShared/UI/Platform/Application/AppProcessorTests.swift b/BitwardenShared/UI/Platform/Application/AppProcessorTests.swift index 42969a2346..8c06d4937e 100644 --- a/BitwardenShared/UI/Platform/Application/AppProcessorTests.swift +++ b/BitwardenShared/UI/Platform/Application/AppProcessorTests.swift @@ -1,5 +1,6 @@ import AuthenticationServices import AuthenticatorBridgeKit +import BitwardenKit import BitwardenKitMocks import BitwardenResources import Foundation @@ -1402,6 +1403,36 @@ class AppProcessorTests: BitwardenTestCase { // swiftlint:disable:this type_body XCTAssertEqual(coordinator.events, [.didLogout(userId: "1", userInitiated: false)]) } + /// `onRefreshTokenError(error:)` logs the user out and notifies the coordinator when a 401 is + /// received while refreshing the token. + @MainActor + func test_onRefreshTokenError_logOut401() async throws { + coordinator.isLoadingOverlayShowing = true + + try await subject.onRefreshTokenError(error: ResponseValidationError(response: .failure(statusCode: 401))) + + XCTAssertTrue(authRepository.logoutCalled) + XCTAssertEqual(authRepository.logoutUserId, nil) + XCTAssertFalse(authRepository.logoutUserInitiated) + XCTAssertFalse(coordinator.isLoadingOverlayShowing) + XCTAssertEqual(coordinator.events, [.didLogout(userId: nil, userInitiated: false)]) + } + + /// `onRefreshTokenError(error:)` logs the user out and notifies the coordinator a 403 is + /// received while refreshing the token. + @MainActor + func test_onRefreshTokenError_logOut403() async throws { + coordinator.isLoadingOverlayShowing = true + + try await subject.onRefreshTokenError(error: ResponseValidationError(response: .failure(statusCode: 403))) + + XCTAssertTrue(authRepository.logoutCalled) + XCTAssertEqual(authRepository.logoutUserId, nil) + XCTAssertFalse(authRepository.logoutUserInitiated) + XCTAssertFalse(coordinator.isLoadingOverlayShowing) + XCTAssertEqual(coordinator.events, [.didLogout(userId: nil, userInitiated: false)]) + } + /// `onRefreshTokenError(error:)` logs the user out and notifies the coordinator when error is `.invalidGrant`. @MainActor func test_onRefreshTokenError_logOutInvalidGrant() async throws { diff --git a/Networking/Sources/Networking/HTTPService.swift b/Networking/Sources/Networking/HTTPService.swift index f8f3f8f28b..66b778923f 100644 --- a/Networking/Sources/Networking/HTTPService.swift +++ b/Networking/Sources/Networking/HTTPService.swift @@ -145,7 +145,7 @@ public final class HTTPService: Sendable { } if let tokenProvider, httpResponse.statusCode == 401, shouldRetryIfUnauthorized { - try await tokenProvider.refreshToken() + _ = try await tokenProvider.refreshToken() // Send the request again, but don't retry if still unauthorized to prevent a retry loop. return try await send(httpRequest, validate: validate, shouldRetryIfUnauthorized: false) diff --git a/Networking/Sources/Networking/TokenProvider.swift b/Networking/Sources/Networking/TokenProvider.swift index 8aa87b421e..b176e90671 100644 --- a/Networking/Sources/Networking/TokenProvider.swift +++ b/Networking/Sources/Networking/TokenProvider.swift @@ -9,5 +9,7 @@ public protocol TokenProvider: Sendable { /// Refreshes the access token by using the refresh token to acquire a new access token. /// - func refreshToken() async throws + /// - Returns: A new access token. + /// + func refreshToken() async throws -> String } diff --git a/Networking/Tests/NetworkingTests/Support/MockTokenProvider.swift b/Networking/Tests/NetworkingTests/Support/MockTokenProvider.swift index 10b6f43c9b..63c15a84a3 100644 --- a/Networking/Tests/NetworkingTests/Support/MockTokenProvider.swift +++ b/Networking/Tests/NetworkingTests/Support/MockTokenProvider.swift @@ -8,7 +8,7 @@ class MockTokenProvider: TokenProvider { var getTokenCallCount = 0 var tokenResults: [Result] = [.success("ACCESS_TOKEN")] - var refreshTokenResult: Result = .success(()) + var refreshTokenResult: Result = .success("ACCESS_TOKEN") var refreshTokenCallCount = 0 func getToken() async throws -> String { @@ -17,8 +17,8 @@ class MockTokenProvider: TokenProvider { return try tokenResults.removeFirst().get() } - func refreshToken() async throws { + func refreshToken() async throws -> String { refreshTokenCallCount += 1 - try refreshTokenResult.get() + return try refreshTokenResult.get() } }