[PM-24482] Refresh access token preemptively and log out on 401/403 refresh errors (#2024)

This commit is contained in:
Matt Czech 2025-10-08 12:07:21 -05:00 committed by GitHub
parent 039495e7e9
commit 4376077ab1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 340 additions and 30 deletions

View File

@ -1,3 +1,5 @@
import BitwardenKit
import Foundation
import Networking
// MARK: - AccountTokenProvider
@ -21,13 +23,16 @@ actor DefaultAccountTokenProvider: AccountTokenProvider {
private weak var accountTokenProviderDelegate: AccountTokenProviderDelegate?
/// The `HTTPService` used to make the API call to refresh the access token.
let httpService: HTTPService
private let httpService: HTTPService
/// The task associated with refreshing the token, if one is in progress.
private(set) var refreshTask: Task<String, Error>?
/// The service used to get the present time.
private let timeProvider: TimeProvider
/// The `TokenService` used to get the current tokens from.
let tokenService: TokenService
private let tokenService: TokenService
// MARK: Initialization
@ -35,14 +40,17 @@ actor DefaultAccountTokenProvider: AccountTokenProvider {
///
/// - Parameters:
/// - httpService: The service used to make the API call to refresh the access token.
/// - timeProvider: The service used to get the present time.
/// - tokenService: The service used to get the current tokens from.
///
init(
httpService: HTTPService,
timeProvider: TimeProvider = CurrentTime(),
tokenService: TokenService,
) {
self.tokenService = tokenService
self.httpService = httpService
self.timeProvider = timeProvider
self.tokenService = tokenService
}
// MARK: Methods
@ -54,15 +62,19 @@ actor DefaultAccountTokenProvider: AccountTokenProvider {
return try await refreshTask.value
}
return try await tokenService.getAccessToken()
let token = try await tokenService.getAccessToken()
if await shouldRefresh(accessToken: token) {
return try await refreshToken()
} else {
return token
}
}
func refreshToken() async throws {
func refreshToken() async throws -> String {
if let refreshTask {
// If there's a refresh in progress, wait for it to complete rather than triggering
// another refresh.
_ = try await refreshTask.value
return
return try await refreshTask.value
}
let refreshTask = Task {
@ -73,9 +85,12 @@ actor DefaultAccountTokenProvider: AccountTokenProvider {
let response = try await httpService.send(
IdentityTokenRefreshRequest(refreshToken: refreshToken),
)
let expirationDate = timeProvider.presentTime.addingTimeInterval(TimeInterval(response.expiresIn))
try await tokenService.setTokens(
accessToken: response.accessToken,
refreshToken: response.refreshToken,
expirationDate: expirationDate,
)
return response.accessToken
@ -88,17 +103,35 @@ actor DefaultAccountTokenProvider: AccountTokenProvider {
}
self.refreshTask = refreshTask
_ = try await refreshTask.value
return try await refreshTask.value
}
func setDelegate(delegate: AccountTokenProviderDelegate) async {
accountTokenProviderDelegate = delegate
}
// MARK: Private
/// Returns whether the access token needs to be refreshed based on the last stored access token
/// expiration date. This is used to preemptively refresh the token prior to its expiration.
///
/// - Parameter accessToken: The access token to determine whether it needs to be refreshed.
/// - Returns: Whether the access token needs to be refreshed.
///
private func shouldRefresh(accessToken: String) async -> Bool {
guard let expirationDate = try? await tokenService.getAccessTokenExpirationDate() else {
// If there's no stored expiration date, don't preemptively refresh the token.
return false
}
let refreshThreshold = timeProvider.presentTime.addingTimeInterval(Constants.tokenRefreshThreshold)
return expirationDate <= refreshThreshold
}
}
/// Delegate to be used by the `AccountTokenProvider`.
protocol AccountTokenProviderDelegate: AnyObject {
/// Callbac to be used when an error is thrown when refreshing the access token.
/// Callback to be used when an error is thrown when refreshing the access token.
/// - Parameter error: `Error` thrown.
func onRefreshTokenError(error: Error) async throws
}

View File

@ -1,3 +1,4 @@
import BitwardenKitMocks
import Networking
import TestHelpers
import XCTest
@ -9,18 +10,25 @@ class AccountTokenProviderTests: BitwardenTestCase {
var client: MockHTTPClient!
var subject: DefaultAccountTokenProvider!
var timeProvider: MockTimeProvider!
var tokenService: MockTokenService!
let expirationDateExpired = Date(year: 2025, month: 10, day: 1, hour: 23, minute: 59, second: 0)
let expirationDateExpiringSoon = Date(year: 2025, month: 10, day: 2, hour: 0, minute: 2, second: 0)
let expirationDateUnexpired = Date(year: 2025, month: 10, day: 2, hour: 0, minute: 6, second: 0)
// MARK: Setup & Teardown
override func setUp() {
super.setUp()
client = MockHTTPClient()
timeProvider = MockTimeProvider(.mockTime(Date(year: 2025, month: 10, day: 2)))
tokenService = MockTokenService()
subject = DefaultAccountTokenProvider(
httpService: HTTPService(baseURL: URL(string: "https://example.com")!, client: client),
timeProvider: timeProvider,
tokenService: tokenService,
)
}
@ -30,13 +38,55 @@ class AccountTokenProviderTests: BitwardenTestCase {
client = nil
subject = nil
timeProvider = nil
tokenService = nil
}
// MARK: Tests
/// `getToken()` returns the current access token.
func test_getToken() async throws {
/// `getToken()` returns the current access token if fetching the expiration date returns an error.
func test_getToken_tokenError() async throws {
tokenService.accessToken = "ACCESS_TOKEN"
tokenService.accessTokenExpirationDateResult = .failure(BitwardenTestError.example)
let token = try await subject.getToken()
XCTAssertEqual(token, "ACCESS_TOKEN")
}
/// `getToken()` returns a refreshed access token if the current one is expired.
func test_getToken_tokenExpired() async throws {
client.result = .httpSuccess(testData: .identityTokenRefresh)
tokenService.accessToken = "EXPIRED"
tokenService.accessTokenExpirationDateResult = .success(expirationDateExpired)
let token = try await subject.getToken()
XCTAssertEqual(token, "ACCESS_TOKEN")
}
/// `getToken()` returns a refreshed access token if the current one is expiring soon.
func test_getToken_tokenExpiringSoon() async throws {
client.result = .httpSuccess(testData: .identityTokenRefresh)
tokenService.accessToken = "EXPIRING_SOON"
tokenService.accessTokenExpirationDateResult = .success(expirationDateExpiringSoon)
let token = try await subject.getToken()
XCTAssertEqual(token, "ACCESS_TOKEN")
}
/// `getToken()` returns the current access token if it is unexpired.
func test_getToken_tokenUnexpired() async throws {
tokenService.accessToken = "ACCESS_TOKEN"
tokenService.accessTokenExpirationDateResult = .success(expirationDateUnexpired)
let token = try await subject.getToken()
XCTAssertEqual(token, "ACCESS_TOKEN")
}
/// `getToken()` returns the current access token if the expiration date doesn't yet exist.
func test_getToken_tokenNil() async throws {
tokenService.accessToken = "ACCESS_TOKEN"
tokenService.accessTokenExpirationDateResult = .success(nil)
let token = try await subject.getToken()
XCTAssertEqual(token, "ACCESS_TOKEN")
}
@ -58,12 +108,12 @@ class AccountTokenProviderTests: BitwardenTestCase {
client.result = .httpSuccess(testData: .identityTokenRefresh)
try await subject.refreshToken()
let newAccessToken = try await subject.refreshToken()
let newAccessToken = try await subject.getToken()
XCTAssertEqual(newAccessToken, "ACCESS_TOKEN")
XCTAssertEqual(tokenService.accessToken, "ACCESS_TOKEN")
XCTAssertEqual(tokenService.refreshToken, "REFRESH_TOKEN")
XCTAssertEqual(tokenService.expirationDate, Date(year: 2025, month: 10, day: 2, hour: 1, minute: 0, second: 0))
let refreshTask = await subject.refreshTask
XCTAssertNil(refreshTask)
@ -76,14 +126,15 @@ class AccountTokenProviderTests: BitwardenTestCase {
client.result = .httpSuccess(testData: .identityTokenRefresh)
async let refreshTask1: Void = subject.refreshToken()
async let refreshTask2: Void = subject.refreshToken()
async let refreshTask1: String = subject.refreshToken()
async let refreshTask2: String = subject.refreshToken()
_ = try await (refreshTask1, refreshTask2)
XCTAssertEqual(client.requests.count, 1)
XCTAssertEqual(tokenService.accessToken, "ACCESS_TOKEN")
XCTAssertEqual(tokenService.refreshToken, "REFRESH_TOKEN")
XCTAssertEqual(tokenService.expirationDate, Date(year: 2025, month: 10, day: 2, hour: 1, minute: 0, second: 0))
let refreshTask = await subject.refreshTask
XCTAssertNil(refreshTask)
@ -101,7 +152,7 @@ class AccountTokenProviderTests: BitwardenTestCase {
client.result = .failure(BitwardenTestError.example)
await assertAsyncThrows(error: BitwardenTestError.example) {
try await subject.refreshToken()
_ = try await subject.refreshToken()
}
XCTAssertTrue(delegate.onRefreshTokenErrorCalled)
}
@ -115,7 +166,7 @@ class AccountTokenProviderTests: BitwardenTestCase {
client.result = .failure(BitwardenTestError.example)
await assertAsyncThrows(error: BitwardenTestError.example) {
try await subject.refreshToken()
_ = try await subject.refreshToken()
}
}
}

View File

@ -9,6 +9,6 @@ protocol RefreshableAPIService { // sourcery: AutoMockable
extension APIService: RefreshableAPIService {
func refreshAccessToken() async throws {
try await accountTokenProvider.refreshToken()
_ = try await accountTokenProvider.refreshToken()
}
}

View File

@ -9,15 +9,15 @@ class MockAccountTokenProvider: AccountTokenProvider {
var delegate: AccountTokenProviderDelegate?
var getTokenResult: Result<String, Error> = .success("ACCESS_TOKEN")
var refreshTokenCalled = false
var refreshTokenResult: Result<Void, Error> = .success(())
var refreshTokenResult: Result<String, Error> = .success("ACCESS_TOKEN")
func getToken() async throws -> String {
try getTokenResult.get()
}
func refreshToken() async throws {
func refreshToken() async throws -> String {
refreshTokenCalled = true
try refreshTokenResult.get()
return try refreshTokenResult.get()
}
func setDelegate(delegate: AccountTokenProviderDelegate) async {

View File

@ -44,6 +44,13 @@ protocol StateService: AnyObject {
///
func doesActiveAccountHavePremium() async -> Bool
/// Gets the access token's expiration date for an account.
///
/// - Parameter userId: The user ID associated with the access token expiration date.
/// - Returns: The user's access token expiration date.
///
func getAccessTokenExpirationDate(userId: String) async -> Date?
/// Gets the account for an id.
///
/// - Parameter userId: The id for an account. If nil, the active account will be returned.
@ -429,6 +436,14 @@ protocol StateService: AnyObject {
///
func pinUnlockRequiresPasswordAfterRestart() async throws -> Bool
/// Sets the access token's expiration date for an account.
///
/// - Parameters:
/// - expirationDate: The user's access token expiration date.
/// - userId: The user ID associated with the access token expiration date.
///
func setAccessTokenExpirationDate(_ expirationDate: Date?, userId: String) async
/// Sets the account encryption keys for an account.
///
/// - Parameters:
@ -855,6 +870,14 @@ extension StateService {
await setPendingAppIntentActions(actions: actions)
}
/// Gets the access token's expiration date for the active account.
///
/// - Returns: The user's access token expiration date.
///
func getAccessTokenExpirationDate() async throws -> Date? {
try await getAccessTokenExpirationDate(userId: getActiveAccountId())
}
/// Gets the account encryptions keys for the active account.
///
/// - Returns: The account encryption keys.
@ -1143,6 +1166,14 @@ extension StateService {
try await pinProtectedUserKeyEnvelope(userId: nil)
}
/// Sets the access token's expiration date for the active account.
///
/// - Parameter expirationDate: The user's access token expiration date.
///
func setAccessTokenExpirationDate(_ expirationDate: Date?) async throws {
try await setAccessTokenExpirationDate(expirationDate, userId: getActiveAccountId())
}
/// Sets the account encryption keys for the active account.
///
/// - Parameter encryptionKeys: The account encryption keys.
@ -1542,6 +1573,10 @@ actor DefaultStateService: StateService, ConfigStateService { // swiftlint:disab
}
}
func getAccessTokenExpirationDate(userId: String) -> Date? {
appSettingsStore.accessTokenExpirationDate(userId: userId)
}
func getAccount(userId: String?) throws -> Account {
guard let accounts = appSettingsStore.state?.accounts else {
throw StateServiceError.noAccounts
@ -1844,6 +1879,7 @@ actor DefaultStateService: StateService, ConfigStateService { // swiftlint:disab
state.activeUserId = state.accounts.first?.key
}
appSettingsStore.setAccessTokenExpirationDate(nil, userId: knownUserId)
appSettingsStore.setBiometricAuthenticationEnabled(nil, for: knownUserId)
appSettingsStore.setDefaultUriMatchType(nil, userId: knownUserId)
appSettingsStore.setDisableAutoTotpCopy(nil, userId: knownUserId)
@ -1876,6 +1912,10 @@ actor DefaultStateService: StateService, ConfigStateService { // swiftlint:disab
&& appSettingsStore.pinProtectedUserKey(userId: userId) == nil
}
func setAccessTokenExpirationDate(_ expirationDate: Date?, userId: String) async {
appSettingsStore.setAccessTokenExpirationDate(expirationDate, userId: userId)
}
func setAccountKdf(_ kdfConfig: KdfConfig, userId: String) async throws {
try updateAccountProfile(userId: userId) { profile in
profile.kdfType = kdfConfig.kdfType

View File

@ -290,6 +290,29 @@ class StateServiceTests: BitwardenTestCase { // swiftlint:disable:this type_body
XCTAssertEqual(errorReporter.errors as? [StateServiceError], [.noActiveAccount])
}
/// `getAccessTokenExpirationDate(userId:)` gets the user's access token expiration date.
func test_getAccessTokenExpirationDate() async throws {
let date1 = Date(year: 2025, month: 1, day: 1)
let date2 = Date(year: 2026, month: 6, day: 1)
appSettingsStore.accessTokenExpirationDateByUserId["1"] = date1
appSettingsStore.accessTokenExpirationDateByUserId["2"] = date2
await subject.addAccount(.fixture(profile: .fixture(userId: "1")))
await subject.addAccount(.fixture(profile: .fixture(userId: "2")))
let expirationDate1 = await subject.getAccessTokenExpirationDate(userId: "1")
XCTAssertEqual(expirationDate1, date1)
let expirationDate2 = try await subject.getAccessTokenExpirationDate()
XCTAssertEqual(expirationDate2, date2)
}
/// `getAccessTokenExpirationDate(userId:)` throws an error if there's no accounts.
func test_getAccessTokenExpirationDate_noAccount() async throws {
await assertAsyncThrows(error: StateServiceError.noActiveAccount) {
_ = try await subject.getAccessTokenExpirationDate()
}
}
/// `getAccountEncryptionKeys(_:)` returns the encryption keys for the user account.
func test_getAccountEncryptionKeys() async throws {
appSettingsStore.accountKeys["1"] = .fixture(
@ -1681,6 +1704,29 @@ class StateServiceTests: BitwardenTestCase { // swiftlint:disable:this type_body
XCTAssertTrue(result == true)
}
/// `setAccessTokenExpirationDate(_:userId:)` sets the access token expiration date for the account.
func test_setAccessTokenExpirationDate() async throws {
let date1 = Date(year: 2025, month: 1, day: 1)
let date2 = Date(year: 2026, month: 6, day: 1)
await subject.addAccount(.fixture(profile: .fixture(userId: "1")))
await subject.addAccount(.fixture(profile: .fixture(userId: "2")))
await subject.setAccessTokenExpirationDate(date1, userId: "1")
try await subject.setAccessTokenExpirationDate(date2)
XCTAssertEqual(
appSettingsStore.accessTokenExpirationDateByUserId,
["1": date1, "2": date2],
)
}
/// `setAccessTokenExpirationDate(_:userId:)` throws an error if there's no accounts.
func test_setAccessTokenExpirationDate_noAccounts() async throws {
await assertAsyncThrows(error: StateServiceError.noActiveAccount) {
_ = try await subject.setAccessTokenExpirationDate(.now)
}
}
/// `setAccountEncryptionKeys(_:userId:)` sets the encryption keys for the user account.
func test_setAccountEncryptionKeys() async throws {
await subject.addAccount(.fixture(profile: .fixture(userId: "1")))

View File

@ -73,6 +73,13 @@ protocol AppSettingsStore: AnyObject {
/// The app's account state.
var state: State? { get set }
/// The user's access token expiration date.
///
/// - Parameter userId: The user ID associated with the access token expiration date.
/// - Returns: The user's access token expiration date.
///
func accessTokenExpirationDate(userId: String) -> Date?
/// The user's v2 account keys.
///
/// - Parameter userId: The user ID associated with the stored account keys.
@ -265,6 +272,14 @@ protocol AppSettingsStore: AnyObject {
/// - Returns: The server config for that user ID.
func serverConfig(userId: String) -> ServerConfig?
/// Sets the user's access token expiration date
///
/// - Parameters:
/// - expirationDate: The user's access token expiration date
/// - userId: The user ID associated with the access token expiration date.
///
func setAccessTokenExpirationDate(_ expirationDate: Date?, userId: String)
/// Sets the account v2 keys for a user ID.
///
/// - Parameters:
@ -740,6 +755,7 @@ extension DefaultAppSettingsStore: AppSettingsStore, ConfigSettingsStore {
/// The keys used to store their associated values.
///
enum Keys {
case accessTokenExpirationDate(userId: String)
case accountKeys(userId: String)
case accountSetupAutofill(userId: String)
case accountSetupImportLogins(userId: String)
@ -800,6 +816,8 @@ extension DefaultAppSettingsStore: AppSettingsStore, ConfigSettingsStore {
/// Returns the key used to store the data under for retrieving it later.
var storageKey: String {
let key = switch self {
case let .accessTokenExpirationDate(userId):
"accessTokenExpirationDate_\(userId)"
case let .accountKeys(userId):
"accountKeys_\(userId)"
case let .accountSetupAutofill(userId):
@ -1019,6 +1037,10 @@ extension DefaultAppSettingsStore: AppSettingsStore, ConfigSettingsStore {
}
}
func accessTokenExpirationDate(userId: String) -> Date? {
fetch(for: .accessTokenExpirationDate(userId: userId))
}
func accountKeys(userId: String) -> PrivateKeysResponseModel? {
fetch(for: .accountKeys(userId: userId))
}
@ -1141,6 +1163,10 @@ extension DefaultAppSettingsStore: AppSettingsStore, ConfigSettingsStore {
fetch(for: .serverConfig(userId: userId))
}
func setAccessTokenExpirationDate(_ expirationDate: Date?, userId: String) {
store(expirationDate, for: .accessTokenExpirationDate(userId: userId))
}
func setAccountKeys(_ keys: PrivateKeysResponseModel?, userId: String) {
store(keys, for: .accountKeys(userId: userId))
}

View File

@ -39,6 +39,24 @@ class AppSettingsStoreTests: BitwardenTestCase { // swiftlint:disable:this type_
// MARK: Tests
/// `accessTokenExpirationDate(userId:)` returns `nil` if there isn't a previously stored value.
func test_accessTokenExpirationDate_isInitiallyNil() {
XCTAssertNil(subject.accessTokenExpirationDate(userId: "-1"))
}
/// `accessTokenExpirationDate(userId:)` can be used to get the user's access token expiration date.
func test_accessTokenExpirationDate_withValue() {
let date1 = Date(year: 2025, month: 10, day: 1)
let date2 = Date(year: 2026, month: 1, day: 2)
subject.setAccessTokenExpirationDate(date1, userId: "1")
subject.setAccessTokenExpirationDate(date2, userId: "2")
XCTAssertEqual(subject.accessTokenExpirationDate(userId: "1"), date1)
XCTAssertEqual(subject.accessTokenExpirationDate(userId: "2"), date2)
XCTAssertEqual(userDefaults.integer(forKey: "bwPreferencesStorage:accessTokenExpirationDate_1"), 780_969_600)
XCTAssertEqual(userDefaults.integer(forKey: "bwPreferencesStorage:accessTokenExpirationDate_2"), 789_004_800)
}
/// `accountKeys(userId:)` returns `nil` if there isn't a previously stored value.
func test_accountKeys_isInitiallyNil() {
XCTAssertNil(subject.accountKeys(userId: "-1"))

View File

@ -7,6 +7,7 @@ import Foundation
// swiftlint:disable file_length
class MockAppSettingsStore: AppSettingsStore { // swiftlint:disable:this type_body_length
var accessTokenExpirationDateByUserId = [String: Date]()
var accountKeys = [String: PrivateKeysResponseModel]()
var accountSetupAutofill = [String: AccountSetupProgress]()
var accountSetupImportLogins = [String: AccountSetupProgress]()
@ -74,6 +75,10 @@ class MockAppSettingsStore: AppSettingsStore { // swiftlint:disable:this type_bo
var activeIdSubject = CurrentValueSubject<String?, Never>(nil)
func accessTokenExpirationDate(userId: String) -> Date? {
accessTokenExpirationDateByUserId[userId]
}
func accountKeys(userId: String) -> PrivateKeysResponseModel? {
accountKeys[userId]
}
@ -191,6 +196,10 @@ class MockAppSettingsStore: AppSettingsStore { // swiftlint:disable:this type_bo
serverConfig[userId]
}
func setAccessTokenExpirationDate(_ expirationDate: Date?, userId: String) {
accessTokenExpirationDateByUserId[userId] = expirationDate
}
func setAccountKeys(_ keys: BitwardenShared.PrivateKeysResponseModel?, userId: String) {
accountKeys[userId] = keys
}

View File

@ -7,6 +7,7 @@ import Foundation
@testable import BitwardenShared
class MockStateService: StateService { // swiftlint:disable:this type_body_length
var accessTokenExpirationDateByUserId = [String: Date]()
var accountEncryptionKeys = [String: AccountEncryptionKeys]()
var accountSetupAutofill = [String: AccountSetupProgress]()
var accountSetupAutofillError: Error?
@ -138,6 +139,10 @@ class MockStateService: StateService { // swiftlint:disable:this type_body_lengt
})
}
func getAccessTokenExpirationDate(userId: String) async -> Date? {
accessTokenExpirationDateByUserId[userId]
}
func didAccountSwitchInExtension() async throws -> Bool {
try didAccountSwitchInExtensionResult.get()
}
@ -445,6 +450,10 @@ class MockStateService: StateService { // swiftlint:disable:this type_body_lengt
pinUnlockRequiresPasswordAfterRestartValue
}
func setAccessTokenExpirationDate(_ expirationDate: Date?, userId: String) async {
accessTokenExpirationDateByUserId[userId] = expirationDate
}
func setAccountEncryptionKeys(_ encryptionKeys: AccountEncryptionKeys, userId: String?) async throws {
let userId = try unwrapUserId(userId)
accountEncryptionKeys[userId] = encryptionKeys

View File

@ -1,9 +1,12 @@
import Foundation
import Networking
@testable import BitwardenShared
class MockTokenService: TokenService {
var accessToken: String? = "ACCESS_TOKEN"
var accessTokenExpirationDateResult: Result<Date?, Error> = .success(nil)
var expirationDate: Date?
var getIsExternalResult: Result<Bool, Error> = .success(false)
var refreshToken: String? = "REFRESH_TOKEN"
@ -12,6 +15,10 @@ class MockTokenService: TokenService {
return accessToken
}
func getAccessTokenExpirationDate() async throws -> Date? {
try accessTokenExpirationDateResult.get()
}
func getIsExternal() async throws -> Bool {
try getIsExternalResult.get()
}
@ -21,8 +28,9 @@ class MockTokenService: TokenService {
return refreshToken
}
func setTokens(accessToken: String, refreshToken: String) async {
func setTokens(accessToken: String, refreshToken: String, expirationDate: Date) async {
self.accessToken = accessToken
self.refreshToken = refreshToken
self.expirationDate = expirationDate
}
}

View File

@ -1,5 +1,6 @@
import BitwardenKit
import BitwardenSdk
import Foundation
/// A protocol for a `TokenService` which manages accessing and updating the active account's tokens.
///
@ -10,6 +11,12 @@ protocol TokenService: AnyObject {
///
func getAccessToken() async throws -> String
/// Returns the access token's expiration date for the current account.
///
/// - Returns: The access token's expiration date for the current account.
///
func getAccessTokenExpirationDate() async throws -> Date?
/// Returns whether the user is an external user.
///
/// - Returns: Whether the user is an external user.
@ -27,8 +34,9 @@ protocol TokenService: AnyObject {
/// - Parameters:
/// - accessToken: The account's updated access token.
/// - refreshToken: The account's updated refresh token.
/// - expirationDate: The access token's expiration date.
///
func setTokens(accessToken: String, refreshToken: String) async throws
func setTokens(accessToken: String, refreshToken: String, expirationDate: Date) async throws
}
// MARK: - DefaultTokenService
@ -73,6 +81,10 @@ actor DefaultTokenService: TokenService {
return try await keychainRepository.getAccessToken(userId: userId)
}
func getAccessTokenExpirationDate() async throws -> Date? {
try await stateService.getAccessTokenExpirationDate()
}
func getIsExternal() async throws -> Bool {
let accessToken: String = try await getAccessToken()
let tokenPayload = try TokenParser.parseToken(accessToken)
@ -84,10 +96,11 @@ actor DefaultTokenService: TokenService {
return try await keychainRepository.getRefreshToken(userId: userId)
}
func setTokens(accessToken: String, refreshToken: String) async throws {
func setTokens(accessToken: String, refreshToken: String, expirationDate: Date) async throws {
let userId = try await stateService.getActiveAccountId()
try await keychainRepository.setAccessToken(accessToken, userId: userId)
try await keychainRepository.setRefreshToken(refreshToken, userId: userId)
await stateService.setAccessTokenExpirationDate(expirationDate, userId: userId)
}
}

View File

@ -69,6 +69,22 @@ class TokenServiceTests: BitwardenTestCase {
XCTAssertNil(accessToken)
}
/// `getAccessTokenExpirationDate()` returns the access token's expiration date.
func test_getAccessTokenExpirationDate() async throws {
stateService.accessTokenExpirationDateByUserId["1"] = Date(year: 2025, month: 10, day: 2)
stateService.activeAccount = .fixture()
let expirationDate = try await subject.getAccessTokenExpirationDate()
XCTAssertEqual(expirationDate, Date(year: 2025, month: 10, day: 2))
}
/// `getAccessTokenExpirationDate()` throws an error if there isn't an active account.
func test_getAccessTokenExpirationDate_error() async throws {
await assertAsyncThrows(error: StateServiceError.noActiveAccount) {
_ = try await subject.getAccessTokenExpirationDate()
}
}
/// `getIsExternal()` returns false if the user isn't an external user.
func test_getIsExternal_false() async throws {
// swiftlint:disable:next line_length
@ -144,7 +160,8 @@ class TokenServiceTests: BitwardenTestCase {
func test_setTokens() async throws {
stateService.activeAccount = .fixture()
try await subject.setTokens(accessToken: "🔑", refreshToken: "🔒")
let expirationDate = Date(year: 2025, month: 10, day: 1)
try await subject.setTokens(accessToken: "🔑", refreshToken: "🔒", expirationDate: expirationDate)
XCTAssertEqual(
keychainRepository.mockStorage[keychainRepository.formattedKey(for: .accessToken(userId: "1"))],
@ -154,5 +171,6 @@ class TokenServiceTests: BitwardenTestCase {
keychainRepository.mockStorage[keychainRepository.formattedKey(for: .refreshToken(userId: "1"))],
"🔒",
)
XCTAssertEqual(stateService.accessTokenExpirationDateByUserId["1"], expirationDate)
}
}

View File

@ -87,6 +87,10 @@ extension Constants {
/// The default number of KDF iterations to perform.
static let pbkdf2Iterations = 600_000
/// The number of seconds before an access token's expiration time at which the app will
/// preemptively refresh the token.
static let tokenRefreshThreshold: TimeInterval = 5 * 60 // 5 minutes
}
// MARK: Extension Constants

View File

@ -682,6 +682,8 @@ extension AppProcessor: AccountTokenProviderDelegate {
func onRefreshTokenError(error: any Error) async throws {
if case IdentityTokenRefreshRequestError.invalidGrant = error {
await logOutAutomatically()
} else if let error = error as? ResponseValidationError, [401, 403].contains(error.response.statusCode) {
await logOutAutomatically()
}
}
}

View File

@ -1,5 +1,6 @@
import AuthenticationServices
import AuthenticatorBridgeKit
import BitwardenKit
import BitwardenKitMocks
import BitwardenResources
import Foundation
@ -1402,6 +1403,36 @@ class AppProcessorTests: BitwardenTestCase { // swiftlint:disable:this type_body
XCTAssertEqual(coordinator.events, [.didLogout(userId: "1", userInitiated: false)])
}
/// `onRefreshTokenError(error:)` logs the user out and notifies the coordinator when a 401 is
/// received while refreshing the token.
@MainActor
func test_onRefreshTokenError_logOut401() async throws {
coordinator.isLoadingOverlayShowing = true
try await subject.onRefreshTokenError(error: ResponseValidationError(response: .failure(statusCode: 401)))
XCTAssertTrue(authRepository.logoutCalled)
XCTAssertEqual(authRepository.logoutUserId, nil)
XCTAssertFalse(authRepository.logoutUserInitiated)
XCTAssertFalse(coordinator.isLoadingOverlayShowing)
XCTAssertEqual(coordinator.events, [.didLogout(userId: nil, userInitiated: false)])
}
/// `onRefreshTokenError(error:)` logs the user out and notifies the coordinator a 403 is
/// received while refreshing the token.
@MainActor
func test_onRefreshTokenError_logOut403() async throws {
coordinator.isLoadingOverlayShowing = true
try await subject.onRefreshTokenError(error: ResponseValidationError(response: .failure(statusCode: 403)))
XCTAssertTrue(authRepository.logoutCalled)
XCTAssertEqual(authRepository.logoutUserId, nil)
XCTAssertFalse(authRepository.logoutUserInitiated)
XCTAssertFalse(coordinator.isLoadingOverlayShowing)
XCTAssertEqual(coordinator.events, [.didLogout(userId: nil, userInitiated: false)])
}
/// `onRefreshTokenError(error:)` logs the user out and notifies the coordinator when error is `.invalidGrant`.
@MainActor
func test_onRefreshTokenError_logOutInvalidGrant() async throws {

View File

@ -145,7 +145,7 @@ public final class HTTPService: Sendable {
}
if let tokenProvider, httpResponse.statusCode == 401, shouldRetryIfUnauthorized {
try await tokenProvider.refreshToken()
_ = try await tokenProvider.refreshToken()
// Send the request again, but don't retry if still unauthorized to prevent a retry loop.
return try await send(httpRequest, validate: validate, shouldRetryIfUnauthorized: false)

View File

@ -9,5 +9,7 @@ public protocol TokenProvider: Sendable {
/// Refreshes the access token by using the refresh token to acquire a new access token.
///
func refreshToken() async throws
/// - Returns: A new access token.
///
func refreshToken() async throws -> String
}

View File

@ -8,7 +8,7 @@ class MockTokenProvider: TokenProvider {
var getTokenCallCount = 0
var tokenResults: [Result<String, Error>] = [.success("ACCESS_TOKEN")]
var refreshTokenResult: Result<Void, Error> = .success(())
var refreshTokenResult: Result<String, Error> = .success("ACCESS_TOKEN")
var refreshTokenCallCount = 0
func getToken() async throws -> String {
@ -17,8 +17,8 @@ class MockTokenProvider: TokenProvider {
return try tokenResults.removeFirst().get()
}
func refreshToken() async throws {
func refreshToken() async throws -> String {
refreshTokenCallCount += 1
try refreshTokenResult.get()
return try refreshTokenResult.get()
}
}