[PM-24482] Refresh access token preemptively and log out on 401/403 refresh errors (#2024)

This commit is contained in:
Matt Czech 2025-10-08 12:07:21 -05:00 committed by GitHub
parent 039495e7e9
commit 4376077ab1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 340 additions and 30 deletions

View File

@ -1,3 +1,5 @@
import BitwardenKit
import Foundation
import Networking import Networking
// MARK: - AccountTokenProvider // MARK: - AccountTokenProvider
@ -21,13 +23,16 @@ actor DefaultAccountTokenProvider: AccountTokenProvider {
private weak var accountTokenProviderDelegate: AccountTokenProviderDelegate? private weak var accountTokenProviderDelegate: AccountTokenProviderDelegate?
/// The `HTTPService` used to make the API call to refresh the access token. /// The `HTTPService` used to make the API call to refresh the access token.
let httpService: HTTPService private let httpService: HTTPService
/// The task associated with refreshing the token, if one is in progress. /// The task associated with refreshing the token, if one is in progress.
private(set) var refreshTask: Task<String, Error>? private(set) var refreshTask: Task<String, Error>?
/// The service used to get the present time.
private let timeProvider: TimeProvider
/// The `TokenService` used to get the current tokens from. /// The `TokenService` used to get the current tokens from.
let tokenService: TokenService private let tokenService: TokenService
// MARK: Initialization // MARK: Initialization
@ -35,14 +40,17 @@ actor DefaultAccountTokenProvider: AccountTokenProvider {
/// ///
/// - Parameters: /// - Parameters:
/// - httpService: The service used to make the API call to refresh the access token. /// - httpService: The service used to make the API call to refresh the access token.
/// - timeProvider: The service used to get the present time.
/// - tokenService: The service used to get the current tokens from. /// - tokenService: The service used to get the current tokens from.
/// ///
init( init(
httpService: HTTPService, httpService: HTTPService,
timeProvider: TimeProvider = CurrentTime(),
tokenService: TokenService, tokenService: TokenService,
) { ) {
self.tokenService = tokenService
self.httpService = httpService self.httpService = httpService
self.timeProvider = timeProvider
self.tokenService = tokenService
} }
// MARK: Methods // MARK: Methods
@ -54,15 +62,19 @@ actor DefaultAccountTokenProvider: AccountTokenProvider {
return try await refreshTask.value return try await refreshTask.value
} }
return try await tokenService.getAccessToken() let token = try await tokenService.getAccessToken()
if await shouldRefresh(accessToken: token) {
return try await refreshToken()
} else {
return token
}
} }
func refreshToken() async throws { func refreshToken() async throws -> String {
if let refreshTask { if let refreshTask {
// If there's a refresh in progress, wait for it to complete rather than triggering // If there's a refresh in progress, wait for it to complete rather than triggering
// another refresh. // another refresh.
_ = try await refreshTask.value return try await refreshTask.value
return
} }
let refreshTask = Task { let refreshTask = Task {
@ -73,9 +85,12 @@ actor DefaultAccountTokenProvider: AccountTokenProvider {
let response = try await httpService.send( let response = try await httpService.send(
IdentityTokenRefreshRequest(refreshToken: refreshToken), IdentityTokenRefreshRequest(refreshToken: refreshToken),
) )
let expirationDate = timeProvider.presentTime.addingTimeInterval(TimeInterval(response.expiresIn))
try await tokenService.setTokens( try await tokenService.setTokens(
accessToken: response.accessToken, accessToken: response.accessToken,
refreshToken: response.refreshToken, refreshToken: response.refreshToken,
expirationDate: expirationDate,
) )
return response.accessToken return response.accessToken
@ -88,17 +103,35 @@ actor DefaultAccountTokenProvider: AccountTokenProvider {
} }
self.refreshTask = refreshTask self.refreshTask = refreshTask
_ = try await refreshTask.value return try await refreshTask.value
} }
func setDelegate(delegate: AccountTokenProviderDelegate) async { func setDelegate(delegate: AccountTokenProviderDelegate) async {
accountTokenProviderDelegate = delegate accountTokenProviderDelegate = delegate
} }
// MARK: Private
/// Returns whether the access token needs to be refreshed based on the last stored access token
/// expiration date. This is used to preemptively refresh the token prior to its expiration.
///
/// - Parameter accessToken: The access token to determine whether it needs to be refreshed.
/// - Returns: Whether the access token needs to be refreshed.
///
private func shouldRefresh(accessToken: String) async -> Bool {
guard let expirationDate = try? await tokenService.getAccessTokenExpirationDate() else {
// If there's no stored expiration date, don't preemptively refresh the token.
return false
}
let refreshThreshold = timeProvider.presentTime.addingTimeInterval(Constants.tokenRefreshThreshold)
return expirationDate <= refreshThreshold
}
} }
/// Delegate to be used by the `AccountTokenProvider`. /// Delegate to be used by the `AccountTokenProvider`.
protocol AccountTokenProviderDelegate: AnyObject { protocol AccountTokenProviderDelegate: AnyObject {
/// Callbac to be used when an error is thrown when refreshing the access token. /// Callback to be used when an error is thrown when refreshing the access token.
/// - Parameter error: `Error` thrown. /// - Parameter error: `Error` thrown.
func onRefreshTokenError(error: Error) async throws func onRefreshTokenError(error: Error) async throws
} }

View File

@ -1,3 +1,4 @@
import BitwardenKitMocks
import Networking import Networking
import TestHelpers import TestHelpers
import XCTest import XCTest
@ -9,18 +10,25 @@ class AccountTokenProviderTests: BitwardenTestCase {
var client: MockHTTPClient! var client: MockHTTPClient!
var subject: DefaultAccountTokenProvider! var subject: DefaultAccountTokenProvider!
var timeProvider: MockTimeProvider!
var tokenService: MockTokenService! var tokenService: MockTokenService!
let expirationDateExpired = Date(year: 2025, month: 10, day: 1, hour: 23, minute: 59, second: 0)
let expirationDateExpiringSoon = Date(year: 2025, month: 10, day: 2, hour: 0, minute: 2, second: 0)
let expirationDateUnexpired = Date(year: 2025, month: 10, day: 2, hour: 0, minute: 6, second: 0)
// MARK: Setup & Teardown // MARK: Setup & Teardown
override func setUp() { override func setUp() {
super.setUp() super.setUp()
client = MockHTTPClient() client = MockHTTPClient()
timeProvider = MockTimeProvider(.mockTime(Date(year: 2025, month: 10, day: 2)))
tokenService = MockTokenService() tokenService = MockTokenService()
subject = DefaultAccountTokenProvider( subject = DefaultAccountTokenProvider(
httpService: HTTPService(baseURL: URL(string: "https://example.com")!, client: client), httpService: HTTPService(baseURL: URL(string: "https://example.com")!, client: client),
timeProvider: timeProvider,
tokenService: tokenService, tokenService: tokenService,
) )
} }
@ -30,13 +38,55 @@ class AccountTokenProviderTests: BitwardenTestCase {
client = nil client = nil
subject = nil subject = nil
timeProvider = nil
tokenService = nil tokenService = nil
} }
// MARK: Tests // MARK: Tests
/// `getToken()` returns the current access token. /// `getToken()` returns the current access token if fetching the expiration date returns an error.
func test_getToken() async throws { func test_getToken_tokenError() async throws {
tokenService.accessToken = "ACCESS_TOKEN"
tokenService.accessTokenExpirationDateResult = .failure(BitwardenTestError.example)
let token = try await subject.getToken()
XCTAssertEqual(token, "ACCESS_TOKEN")
}
/// `getToken()` returns a refreshed access token if the current one is expired.
func test_getToken_tokenExpired() async throws {
client.result = .httpSuccess(testData: .identityTokenRefresh)
tokenService.accessToken = "EXPIRED"
tokenService.accessTokenExpirationDateResult = .success(expirationDateExpired)
let token = try await subject.getToken()
XCTAssertEqual(token, "ACCESS_TOKEN")
}
/// `getToken()` returns a refreshed access token if the current one is expiring soon.
func test_getToken_tokenExpiringSoon() async throws {
client.result = .httpSuccess(testData: .identityTokenRefresh)
tokenService.accessToken = "EXPIRING_SOON"
tokenService.accessTokenExpirationDateResult = .success(expirationDateExpiringSoon)
let token = try await subject.getToken()
XCTAssertEqual(token, "ACCESS_TOKEN")
}
/// `getToken()` returns the current access token if it is unexpired.
func test_getToken_tokenUnexpired() async throws {
tokenService.accessToken = "ACCESS_TOKEN"
tokenService.accessTokenExpirationDateResult = .success(expirationDateUnexpired)
let token = try await subject.getToken()
XCTAssertEqual(token, "ACCESS_TOKEN")
}
/// `getToken()` returns the current access token if the expiration date doesn't yet exist.
func test_getToken_tokenNil() async throws {
tokenService.accessToken = "ACCESS_TOKEN"
tokenService.accessTokenExpirationDateResult = .success(nil)
let token = try await subject.getToken() let token = try await subject.getToken()
XCTAssertEqual(token, "ACCESS_TOKEN") XCTAssertEqual(token, "ACCESS_TOKEN")
} }
@ -58,12 +108,12 @@ class AccountTokenProviderTests: BitwardenTestCase {
client.result = .httpSuccess(testData: .identityTokenRefresh) client.result = .httpSuccess(testData: .identityTokenRefresh)
try await subject.refreshToken() let newAccessToken = try await subject.refreshToken()
let newAccessToken = try await subject.getToken()
XCTAssertEqual(newAccessToken, "ACCESS_TOKEN") XCTAssertEqual(newAccessToken, "ACCESS_TOKEN")
XCTAssertEqual(tokenService.accessToken, "ACCESS_TOKEN") XCTAssertEqual(tokenService.accessToken, "ACCESS_TOKEN")
XCTAssertEqual(tokenService.refreshToken, "REFRESH_TOKEN") XCTAssertEqual(tokenService.refreshToken, "REFRESH_TOKEN")
XCTAssertEqual(tokenService.expirationDate, Date(year: 2025, month: 10, day: 2, hour: 1, minute: 0, second: 0))
let refreshTask = await subject.refreshTask let refreshTask = await subject.refreshTask
XCTAssertNil(refreshTask) XCTAssertNil(refreshTask)
@ -76,14 +126,15 @@ class AccountTokenProviderTests: BitwardenTestCase {
client.result = .httpSuccess(testData: .identityTokenRefresh) client.result = .httpSuccess(testData: .identityTokenRefresh)
async let refreshTask1: Void = subject.refreshToken() async let refreshTask1: String = subject.refreshToken()
async let refreshTask2: Void = subject.refreshToken() async let refreshTask2: String = subject.refreshToken()
_ = try await (refreshTask1, refreshTask2) _ = try await (refreshTask1, refreshTask2)
XCTAssertEqual(client.requests.count, 1) XCTAssertEqual(client.requests.count, 1)
XCTAssertEqual(tokenService.accessToken, "ACCESS_TOKEN") XCTAssertEqual(tokenService.accessToken, "ACCESS_TOKEN")
XCTAssertEqual(tokenService.refreshToken, "REFRESH_TOKEN") XCTAssertEqual(tokenService.refreshToken, "REFRESH_TOKEN")
XCTAssertEqual(tokenService.expirationDate, Date(year: 2025, month: 10, day: 2, hour: 1, minute: 0, second: 0))
let refreshTask = await subject.refreshTask let refreshTask = await subject.refreshTask
XCTAssertNil(refreshTask) XCTAssertNil(refreshTask)
@ -101,7 +152,7 @@ class AccountTokenProviderTests: BitwardenTestCase {
client.result = .failure(BitwardenTestError.example) client.result = .failure(BitwardenTestError.example)
await assertAsyncThrows(error: BitwardenTestError.example) { await assertAsyncThrows(error: BitwardenTestError.example) {
try await subject.refreshToken() _ = try await subject.refreshToken()
} }
XCTAssertTrue(delegate.onRefreshTokenErrorCalled) XCTAssertTrue(delegate.onRefreshTokenErrorCalled)
} }
@ -115,7 +166,7 @@ class AccountTokenProviderTests: BitwardenTestCase {
client.result = .failure(BitwardenTestError.example) client.result = .failure(BitwardenTestError.example)
await assertAsyncThrows(error: BitwardenTestError.example) { await assertAsyncThrows(error: BitwardenTestError.example) {
try await subject.refreshToken() _ = try await subject.refreshToken()
} }
} }
} }

View File

@ -9,6 +9,6 @@ protocol RefreshableAPIService { // sourcery: AutoMockable
extension APIService: RefreshableAPIService { extension APIService: RefreshableAPIService {
func refreshAccessToken() async throws { func refreshAccessToken() async throws {
try await accountTokenProvider.refreshToken() _ = try await accountTokenProvider.refreshToken()
} }
} }

View File

@ -9,15 +9,15 @@ class MockAccountTokenProvider: AccountTokenProvider {
var delegate: AccountTokenProviderDelegate? var delegate: AccountTokenProviderDelegate?
var getTokenResult: Result<String, Error> = .success("ACCESS_TOKEN") var getTokenResult: Result<String, Error> = .success("ACCESS_TOKEN")
var refreshTokenCalled = false var refreshTokenCalled = false
var refreshTokenResult: Result<Void, Error> = .success(()) var refreshTokenResult: Result<String, Error> = .success("ACCESS_TOKEN")
func getToken() async throws -> String { func getToken() async throws -> String {
try getTokenResult.get() try getTokenResult.get()
} }
func refreshToken() async throws { func refreshToken() async throws -> String {
refreshTokenCalled = true refreshTokenCalled = true
try refreshTokenResult.get() return try refreshTokenResult.get()
} }
func setDelegate(delegate: AccountTokenProviderDelegate) async { func setDelegate(delegate: AccountTokenProviderDelegate) async {

View File

@ -44,6 +44,13 @@ protocol StateService: AnyObject {
/// ///
func doesActiveAccountHavePremium() async -> Bool func doesActiveAccountHavePremium() async -> Bool
/// Gets the access token's expiration date for an account.
///
/// - Parameter userId: The user ID associated with the access token expiration date.
/// - Returns: The user's access token expiration date.
///
func getAccessTokenExpirationDate(userId: String) async -> Date?
/// Gets the account for an id. /// Gets the account for an id.
/// ///
/// - Parameter userId: The id for an account. If nil, the active account will be returned. /// - Parameter userId: The id for an account. If nil, the active account will be returned.
@ -429,6 +436,14 @@ protocol StateService: AnyObject {
/// ///
func pinUnlockRequiresPasswordAfterRestart() async throws -> Bool func pinUnlockRequiresPasswordAfterRestart() async throws -> Bool
/// Sets the access token's expiration date for an account.
///
/// - Parameters:
/// - expirationDate: The user's access token expiration date.
/// - userId: The user ID associated with the access token expiration date.
///
func setAccessTokenExpirationDate(_ expirationDate: Date?, userId: String) async
/// Sets the account encryption keys for an account. /// Sets the account encryption keys for an account.
/// ///
/// - Parameters: /// - Parameters:
@ -855,6 +870,14 @@ extension StateService {
await setPendingAppIntentActions(actions: actions) await setPendingAppIntentActions(actions: actions)
} }
/// Gets the access token's expiration date for the active account.
///
/// - Returns: The user's access token expiration date.
///
func getAccessTokenExpirationDate() async throws -> Date? {
try await getAccessTokenExpirationDate(userId: getActiveAccountId())
}
/// Gets the account encryptions keys for the active account. /// Gets the account encryptions keys for the active account.
/// ///
/// - Returns: The account encryption keys. /// - Returns: The account encryption keys.
@ -1143,6 +1166,14 @@ extension StateService {
try await pinProtectedUserKeyEnvelope(userId: nil) try await pinProtectedUserKeyEnvelope(userId: nil)
} }
/// Sets the access token's expiration date for the active account.
///
/// - Parameter expirationDate: The user's access token expiration date.
///
func setAccessTokenExpirationDate(_ expirationDate: Date?) async throws {
try await setAccessTokenExpirationDate(expirationDate, userId: getActiveAccountId())
}
/// Sets the account encryption keys for the active account. /// Sets the account encryption keys for the active account.
/// ///
/// - Parameter encryptionKeys: The account encryption keys. /// - Parameter encryptionKeys: The account encryption keys.
@ -1542,6 +1573,10 @@ actor DefaultStateService: StateService, ConfigStateService { // swiftlint:disab
} }
} }
func getAccessTokenExpirationDate(userId: String) -> Date? {
appSettingsStore.accessTokenExpirationDate(userId: userId)
}
func getAccount(userId: String?) throws -> Account { func getAccount(userId: String?) throws -> Account {
guard let accounts = appSettingsStore.state?.accounts else { guard let accounts = appSettingsStore.state?.accounts else {
throw StateServiceError.noAccounts throw StateServiceError.noAccounts
@ -1844,6 +1879,7 @@ actor DefaultStateService: StateService, ConfigStateService { // swiftlint:disab
state.activeUserId = state.accounts.first?.key state.activeUserId = state.accounts.first?.key
} }
appSettingsStore.setAccessTokenExpirationDate(nil, userId: knownUserId)
appSettingsStore.setBiometricAuthenticationEnabled(nil, for: knownUserId) appSettingsStore.setBiometricAuthenticationEnabled(nil, for: knownUserId)
appSettingsStore.setDefaultUriMatchType(nil, userId: knownUserId) appSettingsStore.setDefaultUriMatchType(nil, userId: knownUserId)
appSettingsStore.setDisableAutoTotpCopy(nil, userId: knownUserId) appSettingsStore.setDisableAutoTotpCopy(nil, userId: knownUserId)
@ -1876,6 +1912,10 @@ actor DefaultStateService: StateService, ConfigStateService { // swiftlint:disab
&& appSettingsStore.pinProtectedUserKey(userId: userId) == nil && appSettingsStore.pinProtectedUserKey(userId: userId) == nil
} }
func setAccessTokenExpirationDate(_ expirationDate: Date?, userId: String) async {
appSettingsStore.setAccessTokenExpirationDate(expirationDate, userId: userId)
}
func setAccountKdf(_ kdfConfig: KdfConfig, userId: String) async throws { func setAccountKdf(_ kdfConfig: KdfConfig, userId: String) async throws {
try updateAccountProfile(userId: userId) { profile in try updateAccountProfile(userId: userId) { profile in
profile.kdfType = kdfConfig.kdfType profile.kdfType = kdfConfig.kdfType

View File

@ -290,6 +290,29 @@ class StateServiceTests: BitwardenTestCase { // swiftlint:disable:this type_body
XCTAssertEqual(errorReporter.errors as? [StateServiceError], [.noActiveAccount]) XCTAssertEqual(errorReporter.errors as? [StateServiceError], [.noActiveAccount])
} }
/// `getAccessTokenExpirationDate(userId:)` gets the user's access token expiration date.
func test_getAccessTokenExpirationDate() async throws {
let date1 = Date(year: 2025, month: 1, day: 1)
let date2 = Date(year: 2026, month: 6, day: 1)
appSettingsStore.accessTokenExpirationDateByUserId["1"] = date1
appSettingsStore.accessTokenExpirationDateByUserId["2"] = date2
await subject.addAccount(.fixture(profile: .fixture(userId: "1")))
await subject.addAccount(.fixture(profile: .fixture(userId: "2")))
let expirationDate1 = await subject.getAccessTokenExpirationDate(userId: "1")
XCTAssertEqual(expirationDate1, date1)
let expirationDate2 = try await subject.getAccessTokenExpirationDate()
XCTAssertEqual(expirationDate2, date2)
}
/// `getAccessTokenExpirationDate(userId:)` throws an error if there's no accounts.
func test_getAccessTokenExpirationDate_noAccount() async throws {
await assertAsyncThrows(error: StateServiceError.noActiveAccount) {
_ = try await subject.getAccessTokenExpirationDate()
}
}
/// `getAccountEncryptionKeys(_:)` returns the encryption keys for the user account. /// `getAccountEncryptionKeys(_:)` returns the encryption keys for the user account.
func test_getAccountEncryptionKeys() async throws { func test_getAccountEncryptionKeys() async throws {
appSettingsStore.accountKeys["1"] = .fixture( appSettingsStore.accountKeys["1"] = .fixture(
@ -1681,6 +1704,29 @@ class StateServiceTests: BitwardenTestCase { // swiftlint:disable:this type_body
XCTAssertTrue(result == true) XCTAssertTrue(result == true)
} }
/// `setAccessTokenExpirationDate(_:userId:)` sets the access token expiration date for the account.
func test_setAccessTokenExpirationDate() async throws {
let date1 = Date(year: 2025, month: 1, day: 1)
let date2 = Date(year: 2026, month: 6, day: 1)
await subject.addAccount(.fixture(profile: .fixture(userId: "1")))
await subject.addAccount(.fixture(profile: .fixture(userId: "2")))
await subject.setAccessTokenExpirationDate(date1, userId: "1")
try await subject.setAccessTokenExpirationDate(date2)
XCTAssertEqual(
appSettingsStore.accessTokenExpirationDateByUserId,
["1": date1, "2": date2],
)
}
/// `setAccessTokenExpirationDate(_:userId:)` throws an error if there's no accounts.
func test_setAccessTokenExpirationDate_noAccounts() async throws {
await assertAsyncThrows(error: StateServiceError.noActiveAccount) {
_ = try await subject.setAccessTokenExpirationDate(.now)
}
}
/// `setAccountEncryptionKeys(_:userId:)` sets the encryption keys for the user account. /// `setAccountEncryptionKeys(_:userId:)` sets the encryption keys for the user account.
func test_setAccountEncryptionKeys() async throws { func test_setAccountEncryptionKeys() async throws {
await subject.addAccount(.fixture(profile: .fixture(userId: "1"))) await subject.addAccount(.fixture(profile: .fixture(userId: "1")))

View File

@ -73,6 +73,13 @@ protocol AppSettingsStore: AnyObject {
/// The app's account state. /// The app's account state.
var state: State? { get set } var state: State? { get set }
/// The user's access token expiration date.
///
/// - Parameter userId: The user ID associated with the access token expiration date.
/// - Returns: The user's access token expiration date.
///
func accessTokenExpirationDate(userId: String) -> Date?
/// The user's v2 account keys. /// The user's v2 account keys.
/// ///
/// - Parameter userId: The user ID associated with the stored account keys. /// - Parameter userId: The user ID associated with the stored account keys.
@ -265,6 +272,14 @@ protocol AppSettingsStore: AnyObject {
/// - Returns: The server config for that user ID. /// - Returns: The server config for that user ID.
func serverConfig(userId: String) -> ServerConfig? func serverConfig(userId: String) -> ServerConfig?
/// Sets the user's access token expiration date
///
/// - Parameters:
/// - expirationDate: The user's access token expiration date
/// - userId: The user ID associated with the access token expiration date.
///
func setAccessTokenExpirationDate(_ expirationDate: Date?, userId: String)
/// Sets the account v2 keys for a user ID. /// Sets the account v2 keys for a user ID.
/// ///
/// - Parameters: /// - Parameters:
@ -740,6 +755,7 @@ extension DefaultAppSettingsStore: AppSettingsStore, ConfigSettingsStore {
/// The keys used to store their associated values. /// The keys used to store their associated values.
/// ///
enum Keys { enum Keys {
case accessTokenExpirationDate(userId: String)
case accountKeys(userId: String) case accountKeys(userId: String)
case accountSetupAutofill(userId: String) case accountSetupAutofill(userId: String)
case accountSetupImportLogins(userId: String) case accountSetupImportLogins(userId: String)
@ -800,6 +816,8 @@ extension DefaultAppSettingsStore: AppSettingsStore, ConfigSettingsStore {
/// Returns the key used to store the data under for retrieving it later. /// Returns the key used to store the data under for retrieving it later.
var storageKey: String { var storageKey: String {
let key = switch self { let key = switch self {
case let .accessTokenExpirationDate(userId):
"accessTokenExpirationDate_\(userId)"
case let .accountKeys(userId): case let .accountKeys(userId):
"accountKeys_\(userId)" "accountKeys_\(userId)"
case let .accountSetupAutofill(userId): case let .accountSetupAutofill(userId):
@ -1019,6 +1037,10 @@ extension DefaultAppSettingsStore: AppSettingsStore, ConfigSettingsStore {
} }
} }
func accessTokenExpirationDate(userId: String) -> Date? {
fetch(for: .accessTokenExpirationDate(userId: userId))
}
func accountKeys(userId: String) -> PrivateKeysResponseModel? { func accountKeys(userId: String) -> PrivateKeysResponseModel? {
fetch(for: .accountKeys(userId: userId)) fetch(for: .accountKeys(userId: userId))
} }
@ -1141,6 +1163,10 @@ extension DefaultAppSettingsStore: AppSettingsStore, ConfigSettingsStore {
fetch(for: .serverConfig(userId: userId)) fetch(for: .serverConfig(userId: userId))
} }
func setAccessTokenExpirationDate(_ expirationDate: Date?, userId: String) {
store(expirationDate, for: .accessTokenExpirationDate(userId: userId))
}
func setAccountKeys(_ keys: PrivateKeysResponseModel?, userId: String) { func setAccountKeys(_ keys: PrivateKeysResponseModel?, userId: String) {
store(keys, for: .accountKeys(userId: userId)) store(keys, for: .accountKeys(userId: userId))
} }

View File

@ -39,6 +39,24 @@ class AppSettingsStoreTests: BitwardenTestCase { // swiftlint:disable:this type_
// MARK: Tests // MARK: Tests
/// `accessTokenExpirationDate(userId:)` returns `nil` if there isn't a previously stored value.
func test_accessTokenExpirationDate_isInitiallyNil() {
XCTAssertNil(subject.accessTokenExpirationDate(userId: "-1"))
}
/// `accessTokenExpirationDate(userId:)` can be used to get the user's access token expiration date.
func test_accessTokenExpirationDate_withValue() {
let date1 = Date(year: 2025, month: 10, day: 1)
let date2 = Date(year: 2026, month: 1, day: 2)
subject.setAccessTokenExpirationDate(date1, userId: "1")
subject.setAccessTokenExpirationDate(date2, userId: "2")
XCTAssertEqual(subject.accessTokenExpirationDate(userId: "1"), date1)
XCTAssertEqual(subject.accessTokenExpirationDate(userId: "2"), date2)
XCTAssertEqual(userDefaults.integer(forKey: "bwPreferencesStorage:accessTokenExpirationDate_1"), 780_969_600)
XCTAssertEqual(userDefaults.integer(forKey: "bwPreferencesStorage:accessTokenExpirationDate_2"), 789_004_800)
}
/// `accountKeys(userId:)` returns `nil` if there isn't a previously stored value. /// `accountKeys(userId:)` returns `nil` if there isn't a previously stored value.
func test_accountKeys_isInitiallyNil() { func test_accountKeys_isInitiallyNil() {
XCTAssertNil(subject.accountKeys(userId: "-1")) XCTAssertNil(subject.accountKeys(userId: "-1"))

View File

@ -7,6 +7,7 @@ import Foundation
// swiftlint:disable file_length // swiftlint:disable file_length
class MockAppSettingsStore: AppSettingsStore { // swiftlint:disable:this type_body_length class MockAppSettingsStore: AppSettingsStore { // swiftlint:disable:this type_body_length
var accessTokenExpirationDateByUserId = [String: Date]()
var accountKeys = [String: PrivateKeysResponseModel]() var accountKeys = [String: PrivateKeysResponseModel]()
var accountSetupAutofill = [String: AccountSetupProgress]() var accountSetupAutofill = [String: AccountSetupProgress]()
var accountSetupImportLogins = [String: AccountSetupProgress]() var accountSetupImportLogins = [String: AccountSetupProgress]()
@ -74,6 +75,10 @@ class MockAppSettingsStore: AppSettingsStore { // swiftlint:disable:this type_bo
var activeIdSubject = CurrentValueSubject<String?, Never>(nil) var activeIdSubject = CurrentValueSubject<String?, Never>(nil)
func accessTokenExpirationDate(userId: String) -> Date? {
accessTokenExpirationDateByUserId[userId]
}
func accountKeys(userId: String) -> PrivateKeysResponseModel? { func accountKeys(userId: String) -> PrivateKeysResponseModel? {
accountKeys[userId] accountKeys[userId]
} }
@ -191,6 +196,10 @@ class MockAppSettingsStore: AppSettingsStore { // swiftlint:disable:this type_bo
serverConfig[userId] serverConfig[userId]
} }
func setAccessTokenExpirationDate(_ expirationDate: Date?, userId: String) {
accessTokenExpirationDateByUserId[userId] = expirationDate
}
func setAccountKeys(_ keys: BitwardenShared.PrivateKeysResponseModel?, userId: String) { func setAccountKeys(_ keys: BitwardenShared.PrivateKeysResponseModel?, userId: String) {
accountKeys[userId] = keys accountKeys[userId] = keys
} }

View File

@ -7,6 +7,7 @@ import Foundation
@testable import BitwardenShared @testable import BitwardenShared
class MockStateService: StateService { // swiftlint:disable:this type_body_length class MockStateService: StateService { // swiftlint:disable:this type_body_length
var accessTokenExpirationDateByUserId = [String: Date]()
var accountEncryptionKeys = [String: AccountEncryptionKeys]() var accountEncryptionKeys = [String: AccountEncryptionKeys]()
var accountSetupAutofill = [String: AccountSetupProgress]() var accountSetupAutofill = [String: AccountSetupProgress]()
var accountSetupAutofillError: Error? var accountSetupAutofillError: Error?
@ -138,6 +139,10 @@ class MockStateService: StateService { // swiftlint:disable:this type_body_lengt
}) })
} }
func getAccessTokenExpirationDate(userId: String) async -> Date? {
accessTokenExpirationDateByUserId[userId]
}
func didAccountSwitchInExtension() async throws -> Bool { func didAccountSwitchInExtension() async throws -> Bool {
try didAccountSwitchInExtensionResult.get() try didAccountSwitchInExtensionResult.get()
} }
@ -445,6 +450,10 @@ class MockStateService: StateService { // swiftlint:disable:this type_body_lengt
pinUnlockRequiresPasswordAfterRestartValue pinUnlockRequiresPasswordAfterRestartValue
} }
func setAccessTokenExpirationDate(_ expirationDate: Date?, userId: String) async {
accessTokenExpirationDateByUserId[userId] = expirationDate
}
func setAccountEncryptionKeys(_ encryptionKeys: AccountEncryptionKeys, userId: String?) async throws { func setAccountEncryptionKeys(_ encryptionKeys: AccountEncryptionKeys, userId: String?) async throws {
let userId = try unwrapUserId(userId) let userId = try unwrapUserId(userId)
accountEncryptionKeys[userId] = encryptionKeys accountEncryptionKeys[userId] = encryptionKeys

View File

@ -1,9 +1,12 @@
import Foundation
import Networking import Networking
@testable import BitwardenShared @testable import BitwardenShared
class MockTokenService: TokenService { class MockTokenService: TokenService {
var accessToken: String? = "ACCESS_TOKEN" var accessToken: String? = "ACCESS_TOKEN"
var accessTokenExpirationDateResult: Result<Date?, Error> = .success(nil)
var expirationDate: Date?
var getIsExternalResult: Result<Bool, Error> = .success(false) var getIsExternalResult: Result<Bool, Error> = .success(false)
var refreshToken: String? = "REFRESH_TOKEN" var refreshToken: String? = "REFRESH_TOKEN"
@ -12,6 +15,10 @@ class MockTokenService: TokenService {
return accessToken return accessToken
} }
func getAccessTokenExpirationDate() async throws -> Date? {
try accessTokenExpirationDateResult.get()
}
func getIsExternal() async throws -> Bool { func getIsExternal() async throws -> Bool {
try getIsExternalResult.get() try getIsExternalResult.get()
} }
@ -21,8 +28,9 @@ class MockTokenService: TokenService {
return refreshToken return refreshToken
} }
func setTokens(accessToken: String, refreshToken: String) async { func setTokens(accessToken: String, refreshToken: String, expirationDate: Date) async {
self.accessToken = accessToken self.accessToken = accessToken
self.refreshToken = refreshToken self.refreshToken = refreshToken
self.expirationDate = expirationDate
} }
} }

View File

@ -1,5 +1,6 @@
import BitwardenKit import BitwardenKit
import BitwardenSdk import BitwardenSdk
import Foundation
/// A protocol for a `TokenService` which manages accessing and updating the active account's tokens. /// A protocol for a `TokenService` which manages accessing and updating the active account's tokens.
/// ///
@ -10,6 +11,12 @@ protocol TokenService: AnyObject {
/// ///
func getAccessToken() async throws -> String func getAccessToken() async throws -> String
/// Returns the access token's expiration date for the current account.
///
/// - Returns: The access token's expiration date for the current account.
///
func getAccessTokenExpirationDate() async throws -> Date?
/// Returns whether the user is an external user. /// Returns whether the user is an external user.
/// ///
/// - Returns: Whether the user is an external user. /// - Returns: Whether the user is an external user.
@ -27,8 +34,9 @@ protocol TokenService: AnyObject {
/// - Parameters: /// - Parameters:
/// - accessToken: The account's updated access token. /// - accessToken: The account's updated access token.
/// - refreshToken: The account's updated refresh token. /// - refreshToken: The account's updated refresh token.
/// - expirationDate: The access token's expiration date.
/// ///
func setTokens(accessToken: String, refreshToken: String) async throws func setTokens(accessToken: String, refreshToken: String, expirationDate: Date) async throws
} }
// MARK: - DefaultTokenService // MARK: - DefaultTokenService
@ -73,6 +81,10 @@ actor DefaultTokenService: TokenService {
return try await keychainRepository.getAccessToken(userId: userId) return try await keychainRepository.getAccessToken(userId: userId)
} }
func getAccessTokenExpirationDate() async throws -> Date? {
try await stateService.getAccessTokenExpirationDate()
}
func getIsExternal() async throws -> Bool { func getIsExternal() async throws -> Bool {
let accessToken: String = try await getAccessToken() let accessToken: String = try await getAccessToken()
let tokenPayload = try TokenParser.parseToken(accessToken) let tokenPayload = try TokenParser.parseToken(accessToken)
@ -84,10 +96,11 @@ actor DefaultTokenService: TokenService {
return try await keychainRepository.getRefreshToken(userId: userId) return try await keychainRepository.getRefreshToken(userId: userId)
} }
func setTokens(accessToken: String, refreshToken: String) async throws { func setTokens(accessToken: String, refreshToken: String, expirationDate: Date) async throws {
let userId = try await stateService.getActiveAccountId() let userId = try await stateService.getActiveAccountId()
try await keychainRepository.setAccessToken(accessToken, userId: userId) try await keychainRepository.setAccessToken(accessToken, userId: userId)
try await keychainRepository.setRefreshToken(refreshToken, userId: userId) try await keychainRepository.setRefreshToken(refreshToken, userId: userId)
await stateService.setAccessTokenExpirationDate(expirationDate, userId: userId)
} }
} }

View File

@ -69,6 +69,22 @@ class TokenServiceTests: BitwardenTestCase {
XCTAssertNil(accessToken) XCTAssertNil(accessToken)
} }
/// `getAccessTokenExpirationDate()` returns the access token's expiration date.
func test_getAccessTokenExpirationDate() async throws {
stateService.accessTokenExpirationDateByUserId["1"] = Date(year: 2025, month: 10, day: 2)
stateService.activeAccount = .fixture()
let expirationDate = try await subject.getAccessTokenExpirationDate()
XCTAssertEqual(expirationDate, Date(year: 2025, month: 10, day: 2))
}
/// `getAccessTokenExpirationDate()` throws an error if there isn't an active account.
func test_getAccessTokenExpirationDate_error() async throws {
await assertAsyncThrows(error: StateServiceError.noActiveAccount) {
_ = try await subject.getAccessTokenExpirationDate()
}
}
/// `getIsExternal()` returns false if the user isn't an external user. /// `getIsExternal()` returns false if the user isn't an external user.
func test_getIsExternal_false() async throws { func test_getIsExternal_false() async throws {
// swiftlint:disable:next line_length // swiftlint:disable:next line_length
@ -144,7 +160,8 @@ class TokenServiceTests: BitwardenTestCase {
func test_setTokens() async throws { func test_setTokens() async throws {
stateService.activeAccount = .fixture() stateService.activeAccount = .fixture()
try await subject.setTokens(accessToken: "🔑", refreshToken: "🔒") let expirationDate = Date(year: 2025, month: 10, day: 1)
try await subject.setTokens(accessToken: "🔑", refreshToken: "🔒", expirationDate: expirationDate)
XCTAssertEqual( XCTAssertEqual(
keychainRepository.mockStorage[keychainRepository.formattedKey(for: .accessToken(userId: "1"))], keychainRepository.mockStorage[keychainRepository.formattedKey(for: .accessToken(userId: "1"))],
@ -154,5 +171,6 @@ class TokenServiceTests: BitwardenTestCase {
keychainRepository.mockStorage[keychainRepository.formattedKey(for: .refreshToken(userId: "1"))], keychainRepository.mockStorage[keychainRepository.formattedKey(for: .refreshToken(userId: "1"))],
"🔒", "🔒",
) )
XCTAssertEqual(stateService.accessTokenExpirationDateByUserId["1"], expirationDate)
} }
} }

View File

@ -87,6 +87,10 @@ extension Constants {
/// The default number of KDF iterations to perform. /// The default number of KDF iterations to perform.
static let pbkdf2Iterations = 600_000 static let pbkdf2Iterations = 600_000
/// The number of seconds before an access token's expiration time at which the app will
/// preemptively refresh the token.
static let tokenRefreshThreshold: TimeInterval = 5 * 60 // 5 minutes
} }
// MARK: Extension Constants // MARK: Extension Constants

View File

@ -682,6 +682,8 @@ extension AppProcessor: AccountTokenProviderDelegate {
func onRefreshTokenError(error: any Error) async throws { func onRefreshTokenError(error: any Error) async throws {
if case IdentityTokenRefreshRequestError.invalidGrant = error { if case IdentityTokenRefreshRequestError.invalidGrant = error {
await logOutAutomatically() await logOutAutomatically()
} else if let error = error as? ResponseValidationError, [401, 403].contains(error.response.statusCode) {
await logOutAutomatically()
} }
} }
} }

View File

@ -1,5 +1,6 @@
import AuthenticationServices import AuthenticationServices
import AuthenticatorBridgeKit import AuthenticatorBridgeKit
import BitwardenKit
import BitwardenKitMocks import BitwardenKitMocks
import BitwardenResources import BitwardenResources
import Foundation import Foundation
@ -1402,6 +1403,36 @@ class AppProcessorTests: BitwardenTestCase { // swiftlint:disable:this type_body
XCTAssertEqual(coordinator.events, [.didLogout(userId: "1", userInitiated: false)]) XCTAssertEqual(coordinator.events, [.didLogout(userId: "1", userInitiated: false)])
} }
/// `onRefreshTokenError(error:)` logs the user out and notifies the coordinator when a 401 is
/// received while refreshing the token.
@MainActor
func test_onRefreshTokenError_logOut401() async throws {
coordinator.isLoadingOverlayShowing = true
try await subject.onRefreshTokenError(error: ResponseValidationError(response: .failure(statusCode: 401)))
XCTAssertTrue(authRepository.logoutCalled)
XCTAssertEqual(authRepository.logoutUserId, nil)
XCTAssertFalse(authRepository.logoutUserInitiated)
XCTAssertFalse(coordinator.isLoadingOverlayShowing)
XCTAssertEqual(coordinator.events, [.didLogout(userId: nil, userInitiated: false)])
}
/// `onRefreshTokenError(error:)` logs the user out and notifies the coordinator a 403 is
/// received while refreshing the token.
@MainActor
func test_onRefreshTokenError_logOut403() async throws {
coordinator.isLoadingOverlayShowing = true
try await subject.onRefreshTokenError(error: ResponseValidationError(response: .failure(statusCode: 403)))
XCTAssertTrue(authRepository.logoutCalled)
XCTAssertEqual(authRepository.logoutUserId, nil)
XCTAssertFalse(authRepository.logoutUserInitiated)
XCTAssertFalse(coordinator.isLoadingOverlayShowing)
XCTAssertEqual(coordinator.events, [.didLogout(userId: nil, userInitiated: false)])
}
/// `onRefreshTokenError(error:)` logs the user out and notifies the coordinator when error is `.invalidGrant`. /// `onRefreshTokenError(error:)` logs the user out and notifies the coordinator when error is `.invalidGrant`.
@MainActor @MainActor
func test_onRefreshTokenError_logOutInvalidGrant() async throws { func test_onRefreshTokenError_logOutInvalidGrant() async throws {

View File

@ -145,7 +145,7 @@ public final class HTTPService: Sendable {
} }
if let tokenProvider, httpResponse.statusCode == 401, shouldRetryIfUnauthorized { if let tokenProvider, httpResponse.statusCode == 401, shouldRetryIfUnauthorized {
try await tokenProvider.refreshToken() _ = try await tokenProvider.refreshToken()
// Send the request again, but don't retry if still unauthorized to prevent a retry loop. // Send the request again, but don't retry if still unauthorized to prevent a retry loop.
return try await send(httpRequest, validate: validate, shouldRetryIfUnauthorized: false) return try await send(httpRequest, validate: validate, shouldRetryIfUnauthorized: false)

View File

@ -9,5 +9,7 @@ public protocol TokenProvider: Sendable {
/// Refreshes the access token by using the refresh token to acquire a new access token. /// Refreshes the access token by using the refresh token to acquire a new access token.
/// ///
func refreshToken() async throws /// - Returns: A new access token.
///
func refreshToken() async throws -> String
} }

View File

@ -8,7 +8,7 @@ class MockTokenProvider: TokenProvider {
var getTokenCallCount = 0 var getTokenCallCount = 0
var tokenResults: [Result<String, Error>] = [.success("ACCESS_TOKEN")] var tokenResults: [Result<String, Error>] = [.success("ACCESS_TOKEN")]
var refreshTokenResult: Result<Void, Error> = .success(()) var refreshTokenResult: Result<String, Error> = .success("ACCESS_TOKEN")
var refreshTokenCallCount = 0 var refreshTokenCallCount = 0
func getToken() async throws -> String { func getToken() async throws -> String {
@ -17,8 +17,8 @@ class MockTokenProvider: TokenProvider {
return try tokenResults.removeFirst().get() return try tokenResults.removeFirst().get()
} }
func refreshToken() async throws { func refreshToken() async throws -> String {
refreshTokenCallCount += 1 refreshTokenCallCount += 1
try refreshTokenResult.get() return try refreshTokenResult.get()
} }
} }