mirror of
https://github.com/bitwarden/ios.git
synced 2025-12-10 00:42:29 -06:00
[PM-24482] Refresh access token preemptively and log out on 401/403 refresh errors (#2024)
This commit is contained in:
parent
039495e7e9
commit
4376077ab1
@ -1,3 +1,5 @@
|
||||
import BitwardenKit
|
||||
import Foundation
|
||||
import Networking
|
||||
|
||||
// MARK: - AccountTokenProvider
|
||||
@ -21,13 +23,16 @@ actor DefaultAccountTokenProvider: AccountTokenProvider {
|
||||
private weak var accountTokenProviderDelegate: AccountTokenProviderDelegate?
|
||||
|
||||
/// The `HTTPService` used to make the API call to refresh the access token.
|
||||
let httpService: HTTPService
|
||||
private let httpService: HTTPService
|
||||
|
||||
/// The task associated with refreshing the token, if one is in progress.
|
||||
private(set) var refreshTask: Task<String, Error>?
|
||||
|
||||
/// The service used to get the present time.
|
||||
private let timeProvider: TimeProvider
|
||||
|
||||
/// The `TokenService` used to get the current tokens from.
|
||||
let tokenService: TokenService
|
||||
private let tokenService: TokenService
|
||||
|
||||
// MARK: Initialization
|
||||
|
||||
@ -35,14 +40,17 @@ actor DefaultAccountTokenProvider: AccountTokenProvider {
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - httpService: The service used to make the API call to refresh the access token.
|
||||
/// - timeProvider: The service used to get the present time.
|
||||
/// - tokenService: The service used to get the current tokens from.
|
||||
///
|
||||
init(
|
||||
httpService: HTTPService,
|
||||
timeProvider: TimeProvider = CurrentTime(),
|
||||
tokenService: TokenService,
|
||||
) {
|
||||
self.tokenService = tokenService
|
||||
self.httpService = httpService
|
||||
self.timeProvider = timeProvider
|
||||
self.tokenService = tokenService
|
||||
}
|
||||
|
||||
// MARK: Methods
|
||||
@ -54,15 +62,19 @@ actor DefaultAccountTokenProvider: AccountTokenProvider {
|
||||
return try await refreshTask.value
|
||||
}
|
||||
|
||||
return try await tokenService.getAccessToken()
|
||||
let token = try await tokenService.getAccessToken()
|
||||
if await shouldRefresh(accessToken: token) {
|
||||
return try await refreshToken()
|
||||
} else {
|
||||
return token
|
||||
}
|
||||
}
|
||||
|
||||
func refreshToken() async throws {
|
||||
func refreshToken() async throws -> String {
|
||||
if let refreshTask {
|
||||
// If there's a refresh in progress, wait for it to complete rather than triggering
|
||||
// another refresh.
|
||||
_ = try await refreshTask.value
|
||||
return
|
||||
return try await refreshTask.value
|
||||
}
|
||||
|
||||
let refreshTask = Task {
|
||||
@ -73,9 +85,12 @@ actor DefaultAccountTokenProvider: AccountTokenProvider {
|
||||
let response = try await httpService.send(
|
||||
IdentityTokenRefreshRequest(refreshToken: refreshToken),
|
||||
)
|
||||
let expirationDate = timeProvider.presentTime.addingTimeInterval(TimeInterval(response.expiresIn))
|
||||
|
||||
try await tokenService.setTokens(
|
||||
accessToken: response.accessToken,
|
||||
refreshToken: response.refreshToken,
|
||||
expirationDate: expirationDate,
|
||||
)
|
||||
|
||||
return response.accessToken
|
||||
@ -88,17 +103,35 @@ actor DefaultAccountTokenProvider: AccountTokenProvider {
|
||||
}
|
||||
self.refreshTask = refreshTask
|
||||
|
||||
_ = try await refreshTask.value
|
||||
return try await refreshTask.value
|
||||
}
|
||||
|
||||
func setDelegate(delegate: AccountTokenProviderDelegate) async {
|
||||
accountTokenProviderDelegate = delegate
|
||||
}
|
||||
|
||||
// MARK: Private
|
||||
|
||||
/// Returns whether the access token needs to be refreshed based on the last stored access token
|
||||
/// expiration date. This is used to preemptively refresh the token prior to its expiration.
|
||||
///
|
||||
/// - Parameter accessToken: The access token to determine whether it needs to be refreshed.
|
||||
/// - Returns: Whether the access token needs to be refreshed.
|
||||
///
|
||||
private func shouldRefresh(accessToken: String) async -> Bool {
|
||||
guard let expirationDate = try? await tokenService.getAccessTokenExpirationDate() else {
|
||||
// If there's no stored expiration date, don't preemptively refresh the token.
|
||||
return false
|
||||
}
|
||||
|
||||
let refreshThreshold = timeProvider.presentTime.addingTimeInterval(Constants.tokenRefreshThreshold)
|
||||
return expirationDate <= refreshThreshold
|
||||
}
|
||||
}
|
||||
|
||||
/// Delegate to be used by the `AccountTokenProvider`.
|
||||
protocol AccountTokenProviderDelegate: AnyObject {
|
||||
/// Callbac to be used when an error is thrown when refreshing the access token.
|
||||
/// Callback to be used when an error is thrown when refreshing the access token.
|
||||
/// - Parameter error: `Error` thrown.
|
||||
func onRefreshTokenError(error: Error) async throws
|
||||
}
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import BitwardenKitMocks
|
||||
import Networking
|
||||
import TestHelpers
|
||||
import XCTest
|
||||
@ -9,18 +10,25 @@ class AccountTokenProviderTests: BitwardenTestCase {
|
||||
|
||||
var client: MockHTTPClient!
|
||||
var subject: DefaultAccountTokenProvider!
|
||||
var timeProvider: MockTimeProvider!
|
||||
var tokenService: MockTokenService!
|
||||
|
||||
let expirationDateExpired = Date(year: 2025, month: 10, day: 1, hour: 23, minute: 59, second: 0)
|
||||
let expirationDateExpiringSoon = Date(year: 2025, month: 10, day: 2, hour: 0, minute: 2, second: 0)
|
||||
let expirationDateUnexpired = Date(year: 2025, month: 10, day: 2, hour: 0, minute: 6, second: 0)
|
||||
|
||||
// MARK: Setup & Teardown
|
||||
|
||||
override func setUp() {
|
||||
super.setUp()
|
||||
|
||||
client = MockHTTPClient()
|
||||
timeProvider = MockTimeProvider(.mockTime(Date(year: 2025, month: 10, day: 2)))
|
||||
tokenService = MockTokenService()
|
||||
|
||||
subject = DefaultAccountTokenProvider(
|
||||
httpService: HTTPService(baseURL: URL(string: "https://example.com")!, client: client),
|
||||
timeProvider: timeProvider,
|
||||
tokenService: tokenService,
|
||||
)
|
||||
}
|
||||
@ -30,13 +38,55 @@ class AccountTokenProviderTests: BitwardenTestCase {
|
||||
|
||||
client = nil
|
||||
subject = nil
|
||||
timeProvider = nil
|
||||
tokenService = nil
|
||||
}
|
||||
|
||||
// MARK: Tests
|
||||
|
||||
/// `getToken()` returns the current access token.
|
||||
func test_getToken() async throws {
|
||||
/// `getToken()` returns the current access token if fetching the expiration date returns an error.
|
||||
func test_getToken_tokenError() async throws {
|
||||
tokenService.accessToken = "ACCESS_TOKEN"
|
||||
tokenService.accessTokenExpirationDateResult = .failure(BitwardenTestError.example)
|
||||
|
||||
let token = try await subject.getToken()
|
||||
XCTAssertEqual(token, "ACCESS_TOKEN")
|
||||
}
|
||||
|
||||
/// `getToken()` returns a refreshed access token if the current one is expired.
|
||||
func test_getToken_tokenExpired() async throws {
|
||||
client.result = .httpSuccess(testData: .identityTokenRefresh)
|
||||
tokenService.accessToken = "EXPIRED"
|
||||
tokenService.accessTokenExpirationDateResult = .success(expirationDateExpired)
|
||||
|
||||
let token = try await subject.getToken()
|
||||
XCTAssertEqual(token, "ACCESS_TOKEN")
|
||||
}
|
||||
|
||||
/// `getToken()` returns a refreshed access token if the current one is expiring soon.
|
||||
func test_getToken_tokenExpiringSoon() async throws {
|
||||
client.result = .httpSuccess(testData: .identityTokenRefresh)
|
||||
tokenService.accessToken = "EXPIRING_SOON"
|
||||
tokenService.accessTokenExpirationDateResult = .success(expirationDateExpiringSoon)
|
||||
|
||||
let token = try await subject.getToken()
|
||||
XCTAssertEqual(token, "ACCESS_TOKEN")
|
||||
}
|
||||
|
||||
/// `getToken()` returns the current access token if it is unexpired.
|
||||
func test_getToken_tokenUnexpired() async throws {
|
||||
tokenService.accessToken = "ACCESS_TOKEN"
|
||||
tokenService.accessTokenExpirationDateResult = .success(expirationDateUnexpired)
|
||||
|
||||
let token = try await subject.getToken()
|
||||
XCTAssertEqual(token, "ACCESS_TOKEN")
|
||||
}
|
||||
|
||||
/// `getToken()` returns the current access token if the expiration date doesn't yet exist.
|
||||
func test_getToken_tokenNil() async throws {
|
||||
tokenService.accessToken = "ACCESS_TOKEN"
|
||||
tokenService.accessTokenExpirationDateResult = .success(nil)
|
||||
|
||||
let token = try await subject.getToken()
|
||||
XCTAssertEqual(token, "ACCESS_TOKEN")
|
||||
}
|
||||
@ -58,12 +108,12 @@ class AccountTokenProviderTests: BitwardenTestCase {
|
||||
|
||||
client.result = .httpSuccess(testData: .identityTokenRefresh)
|
||||
|
||||
try await subject.refreshToken()
|
||||
let newAccessToken = try await subject.refreshToken()
|
||||
|
||||
let newAccessToken = try await subject.getToken()
|
||||
XCTAssertEqual(newAccessToken, "ACCESS_TOKEN")
|
||||
XCTAssertEqual(tokenService.accessToken, "ACCESS_TOKEN")
|
||||
XCTAssertEqual(tokenService.refreshToken, "REFRESH_TOKEN")
|
||||
XCTAssertEqual(tokenService.expirationDate, Date(year: 2025, month: 10, day: 2, hour: 1, minute: 0, second: 0))
|
||||
|
||||
let refreshTask = await subject.refreshTask
|
||||
XCTAssertNil(refreshTask)
|
||||
@ -76,14 +126,15 @@ class AccountTokenProviderTests: BitwardenTestCase {
|
||||
|
||||
client.result = .httpSuccess(testData: .identityTokenRefresh)
|
||||
|
||||
async let refreshTask1: Void = subject.refreshToken()
|
||||
async let refreshTask2: Void = subject.refreshToken()
|
||||
async let refreshTask1: String = subject.refreshToken()
|
||||
async let refreshTask2: String = subject.refreshToken()
|
||||
|
||||
_ = try await (refreshTask1, refreshTask2)
|
||||
|
||||
XCTAssertEqual(client.requests.count, 1)
|
||||
XCTAssertEqual(tokenService.accessToken, "ACCESS_TOKEN")
|
||||
XCTAssertEqual(tokenService.refreshToken, "REFRESH_TOKEN")
|
||||
XCTAssertEqual(tokenService.expirationDate, Date(year: 2025, month: 10, day: 2, hour: 1, minute: 0, second: 0))
|
||||
|
||||
let refreshTask = await subject.refreshTask
|
||||
XCTAssertNil(refreshTask)
|
||||
@ -101,7 +152,7 @@ class AccountTokenProviderTests: BitwardenTestCase {
|
||||
client.result = .failure(BitwardenTestError.example)
|
||||
|
||||
await assertAsyncThrows(error: BitwardenTestError.example) {
|
||||
try await subject.refreshToken()
|
||||
_ = try await subject.refreshToken()
|
||||
}
|
||||
XCTAssertTrue(delegate.onRefreshTokenErrorCalled)
|
||||
}
|
||||
@ -115,7 +166,7 @@ class AccountTokenProviderTests: BitwardenTestCase {
|
||||
client.result = .failure(BitwardenTestError.example)
|
||||
|
||||
await assertAsyncThrows(error: BitwardenTestError.example) {
|
||||
try await subject.refreshToken()
|
||||
_ = try await subject.refreshToken()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -9,6 +9,6 @@ protocol RefreshableAPIService { // sourcery: AutoMockable
|
||||
|
||||
extension APIService: RefreshableAPIService {
|
||||
func refreshAccessToken() async throws {
|
||||
try await accountTokenProvider.refreshToken()
|
||||
_ = try await accountTokenProvider.refreshToken()
|
||||
}
|
||||
}
|
||||
|
||||
@ -9,15 +9,15 @@ class MockAccountTokenProvider: AccountTokenProvider {
|
||||
var delegate: AccountTokenProviderDelegate?
|
||||
var getTokenResult: Result<String, Error> = .success("ACCESS_TOKEN")
|
||||
var refreshTokenCalled = false
|
||||
var refreshTokenResult: Result<Void, Error> = .success(())
|
||||
var refreshTokenResult: Result<String, Error> = .success("ACCESS_TOKEN")
|
||||
|
||||
func getToken() async throws -> String {
|
||||
try getTokenResult.get()
|
||||
}
|
||||
|
||||
func refreshToken() async throws {
|
||||
func refreshToken() async throws -> String {
|
||||
refreshTokenCalled = true
|
||||
try refreshTokenResult.get()
|
||||
return try refreshTokenResult.get()
|
||||
}
|
||||
|
||||
func setDelegate(delegate: AccountTokenProviderDelegate) async {
|
||||
|
||||
@ -44,6 +44,13 @@ protocol StateService: AnyObject {
|
||||
///
|
||||
func doesActiveAccountHavePremium() async -> Bool
|
||||
|
||||
/// Gets the access token's expiration date for an account.
|
||||
///
|
||||
/// - Parameter userId: The user ID associated with the access token expiration date.
|
||||
/// - Returns: The user's access token expiration date.
|
||||
///
|
||||
func getAccessTokenExpirationDate(userId: String) async -> Date?
|
||||
|
||||
/// Gets the account for an id.
|
||||
///
|
||||
/// - Parameter userId: The id for an account. If nil, the active account will be returned.
|
||||
@ -429,6 +436,14 @@ protocol StateService: AnyObject {
|
||||
///
|
||||
func pinUnlockRequiresPasswordAfterRestart() async throws -> Bool
|
||||
|
||||
/// Sets the access token's expiration date for an account.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - expirationDate: The user's access token expiration date.
|
||||
/// - userId: The user ID associated with the access token expiration date.
|
||||
///
|
||||
func setAccessTokenExpirationDate(_ expirationDate: Date?, userId: String) async
|
||||
|
||||
/// Sets the account encryption keys for an account.
|
||||
///
|
||||
/// - Parameters:
|
||||
@ -855,6 +870,14 @@ extension StateService {
|
||||
await setPendingAppIntentActions(actions: actions)
|
||||
}
|
||||
|
||||
/// Gets the access token's expiration date for the active account.
|
||||
///
|
||||
/// - Returns: The user's access token expiration date.
|
||||
///
|
||||
func getAccessTokenExpirationDate() async throws -> Date? {
|
||||
try await getAccessTokenExpirationDate(userId: getActiveAccountId())
|
||||
}
|
||||
|
||||
/// Gets the account encryptions keys for the active account.
|
||||
///
|
||||
/// - Returns: The account encryption keys.
|
||||
@ -1143,6 +1166,14 @@ extension StateService {
|
||||
try await pinProtectedUserKeyEnvelope(userId: nil)
|
||||
}
|
||||
|
||||
/// Sets the access token's expiration date for the active account.
|
||||
///
|
||||
/// - Parameter expirationDate: The user's access token expiration date.
|
||||
///
|
||||
func setAccessTokenExpirationDate(_ expirationDate: Date?) async throws {
|
||||
try await setAccessTokenExpirationDate(expirationDate, userId: getActiveAccountId())
|
||||
}
|
||||
|
||||
/// Sets the account encryption keys for the active account.
|
||||
///
|
||||
/// - Parameter encryptionKeys: The account encryption keys.
|
||||
@ -1542,6 +1573,10 @@ actor DefaultStateService: StateService, ConfigStateService { // swiftlint:disab
|
||||
}
|
||||
}
|
||||
|
||||
func getAccessTokenExpirationDate(userId: String) -> Date? {
|
||||
appSettingsStore.accessTokenExpirationDate(userId: userId)
|
||||
}
|
||||
|
||||
func getAccount(userId: String?) throws -> Account {
|
||||
guard let accounts = appSettingsStore.state?.accounts else {
|
||||
throw StateServiceError.noAccounts
|
||||
@ -1844,6 +1879,7 @@ actor DefaultStateService: StateService, ConfigStateService { // swiftlint:disab
|
||||
state.activeUserId = state.accounts.first?.key
|
||||
}
|
||||
|
||||
appSettingsStore.setAccessTokenExpirationDate(nil, userId: knownUserId)
|
||||
appSettingsStore.setBiometricAuthenticationEnabled(nil, for: knownUserId)
|
||||
appSettingsStore.setDefaultUriMatchType(nil, userId: knownUserId)
|
||||
appSettingsStore.setDisableAutoTotpCopy(nil, userId: knownUserId)
|
||||
@ -1876,6 +1912,10 @@ actor DefaultStateService: StateService, ConfigStateService { // swiftlint:disab
|
||||
&& appSettingsStore.pinProtectedUserKey(userId: userId) == nil
|
||||
}
|
||||
|
||||
func setAccessTokenExpirationDate(_ expirationDate: Date?, userId: String) async {
|
||||
appSettingsStore.setAccessTokenExpirationDate(expirationDate, userId: userId)
|
||||
}
|
||||
|
||||
func setAccountKdf(_ kdfConfig: KdfConfig, userId: String) async throws {
|
||||
try updateAccountProfile(userId: userId) { profile in
|
||||
profile.kdfType = kdfConfig.kdfType
|
||||
|
||||
@ -290,6 +290,29 @@ class StateServiceTests: BitwardenTestCase { // swiftlint:disable:this type_body
|
||||
XCTAssertEqual(errorReporter.errors as? [StateServiceError], [.noActiveAccount])
|
||||
}
|
||||
|
||||
/// `getAccessTokenExpirationDate(userId:)` gets the user's access token expiration date.
|
||||
func test_getAccessTokenExpirationDate() async throws {
|
||||
let date1 = Date(year: 2025, month: 1, day: 1)
|
||||
let date2 = Date(year: 2026, month: 6, day: 1)
|
||||
appSettingsStore.accessTokenExpirationDateByUserId["1"] = date1
|
||||
appSettingsStore.accessTokenExpirationDateByUserId["2"] = date2
|
||||
|
||||
await subject.addAccount(.fixture(profile: .fixture(userId: "1")))
|
||||
await subject.addAccount(.fixture(profile: .fixture(userId: "2")))
|
||||
|
||||
let expirationDate1 = await subject.getAccessTokenExpirationDate(userId: "1")
|
||||
XCTAssertEqual(expirationDate1, date1)
|
||||
let expirationDate2 = try await subject.getAccessTokenExpirationDate()
|
||||
XCTAssertEqual(expirationDate2, date2)
|
||||
}
|
||||
|
||||
/// `getAccessTokenExpirationDate(userId:)` throws an error if there's no accounts.
|
||||
func test_getAccessTokenExpirationDate_noAccount() async throws {
|
||||
await assertAsyncThrows(error: StateServiceError.noActiveAccount) {
|
||||
_ = try await subject.getAccessTokenExpirationDate()
|
||||
}
|
||||
}
|
||||
|
||||
/// `getAccountEncryptionKeys(_:)` returns the encryption keys for the user account.
|
||||
func test_getAccountEncryptionKeys() async throws {
|
||||
appSettingsStore.accountKeys["1"] = .fixture(
|
||||
@ -1681,6 +1704,29 @@ class StateServiceTests: BitwardenTestCase { // swiftlint:disable:this type_body
|
||||
XCTAssertTrue(result == true)
|
||||
}
|
||||
|
||||
/// `setAccessTokenExpirationDate(_:userId:)` sets the access token expiration date for the account.
|
||||
func test_setAccessTokenExpirationDate() async throws {
|
||||
let date1 = Date(year: 2025, month: 1, day: 1)
|
||||
let date2 = Date(year: 2026, month: 6, day: 1)
|
||||
await subject.addAccount(.fixture(profile: .fixture(userId: "1")))
|
||||
await subject.addAccount(.fixture(profile: .fixture(userId: "2")))
|
||||
|
||||
await subject.setAccessTokenExpirationDate(date1, userId: "1")
|
||||
try await subject.setAccessTokenExpirationDate(date2)
|
||||
|
||||
XCTAssertEqual(
|
||||
appSettingsStore.accessTokenExpirationDateByUserId,
|
||||
["1": date1, "2": date2],
|
||||
)
|
||||
}
|
||||
|
||||
/// `setAccessTokenExpirationDate(_:userId:)` throws an error if there's no accounts.
|
||||
func test_setAccessTokenExpirationDate_noAccounts() async throws {
|
||||
await assertAsyncThrows(error: StateServiceError.noActiveAccount) {
|
||||
_ = try await subject.setAccessTokenExpirationDate(.now)
|
||||
}
|
||||
}
|
||||
|
||||
/// `setAccountEncryptionKeys(_:userId:)` sets the encryption keys for the user account.
|
||||
func test_setAccountEncryptionKeys() async throws {
|
||||
await subject.addAccount(.fixture(profile: .fixture(userId: "1")))
|
||||
|
||||
@ -73,6 +73,13 @@ protocol AppSettingsStore: AnyObject {
|
||||
/// The app's account state.
|
||||
var state: State? { get set }
|
||||
|
||||
/// The user's access token expiration date.
|
||||
///
|
||||
/// - Parameter userId: The user ID associated with the access token expiration date.
|
||||
/// - Returns: The user's access token expiration date.
|
||||
///
|
||||
func accessTokenExpirationDate(userId: String) -> Date?
|
||||
|
||||
/// The user's v2 account keys.
|
||||
///
|
||||
/// - Parameter userId: The user ID associated with the stored account keys.
|
||||
@ -265,6 +272,14 @@ protocol AppSettingsStore: AnyObject {
|
||||
/// - Returns: The server config for that user ID.
|
||||
func serverConfig(userId: String) -> ServerConfig?
|
||||
|
||||
/// Sets the user's access token expiration date
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - expirationDate: The user's access token expiration date
|
||||
/// - userId: The user ID associated with the access token expiration date.
|
||||
///
|
||||
func setAccessTokenExpirationDate(_ expirationDate: Date?, userId: String)
|
||||
|
||||
/// Sets the account v2 keys for a user ID.
|
||||
///
|
||||
/// - Parameters:
|
||||
@ -740,6 +755,7 @@ extension DefaultAppSettingsStore: AppSettingsStore, ConfigSettingsStore {
|
||||
/// The keys used to store their associated values.
|
||||
///
|
||||
enum Keys {
|
||||
case accessTokenExpirationDate(userId: String)
|
||||
case accountKeys(userId: String)
|
||||
case accountSetupAutofill(userId: String)
|
||||
case accountSetupImportLogins(userId: String)
|
||||
@ -800,6 +816,8 @@ extension DefaultAppSettingsStore: AppSettingsStore, ConfigSettingsStore {
|
||||
/// Returns the key used to store the data under for retrieving it later.
|
||||
var storageKey: String {
|
||||
let key = switch self {
|
||||
case let .accessTokenExpirationDate(userId):
|
||||
"accessTokenExpirationDate_\(userId)"
|
||||
case let .accountKeys(userId):
|
||||
"accountKeys_\(userId)"
|
||||
case let .accountSetupAutofill(userId):
|
||||
@ -1019,6 +1037,10 @@ extension DefaultAppSettingsStore: AppSettingsStore, ConfigSettingsStore {
|
||||
}
|
||||
}
|
||||
|
||||
func accessTokenExpirationDate(userId: String) -> Date? {
|
||||
fetch(for: .accessTokenExpirationDate(userId: userId))
|
||||
}
|
||||
|
||||
func accountKeys(userId: String) -> PrivateKeysResponseModel? {
|
||||
fetch(for: .accountKeys(userId: userId))
|
||||
}
|
||||
@ -1141,6 +1163,10 @@ extension DefaultAppSettingsStore: AppSettingsStore, ConfigSettingsStore {
|
||||
fetch(for: .serverConfig(userId: userId))
|
||||
}
|
||||
|
||||
func setAccessTokenExpirationDate(_ expirationDate: Date?, userId: String) {
|
||||
store(expirationDate, for: .accessTokenExpirationDate(userId: userId))
|
||||
}
|
||||
|
||||
func setAccountKeys(_ keys: PrivateKeysResponseModel?, userId: String) {
|
||||
store(keys, for: .accountKeys(userId: userId))
|
||||
}
|
||||
|
||||
@ -39,6 +39,24 @@ class AppSettingsStoreTests: BitwardenTestCase { // swiftlint:disable:this type_
|
||||
|
||||
// MARK: Tests
|
||||
|
||||
/// `accessTokenExpirationDate(userId:)` returns `nil` if there isn't a previously stored value.
|
||||
func test_accessTokenExpirationDate_isInitiallyNil() {
|
||||
XCTAssertNil(subject.accessTokenExpirationDate(userId: "-1"))
|
||||
}
|
||||
|
||||
/// `accessTokenExpirationDate(userId:)` can be used to get the user's access token expiration date.
|
||||
func test_accessTokenExpirationDate_withValue() {
|
||||
let date1 = Date(year: 2025, month: 10, day: 1)
|
||||
let date2 = Date(year: 2026, month: 1, day: 2)
|
||||
subject.setAccessTokenExpirationDate(date1, userId: "1")
|
||||
subject.setAccessTokenExpirationDate(date2, userId: "2")
|
||||
|
||||
XCTAssertEqual(subject.accessTokenExpirationDate(userId: "1"), date1)
|
||||
XCTAssertEqual(subject.accessTokenExpirationDate(userId: "2"), date2)
|
||||
XCTAssertEqual(userDefaults.integer(forKey: "bwPreferencesStorage:accessTokenExpirationDate_1"), 780_969_600)
|
||||
XCTAssertEqual(userDefaults.integer(forKey: "bwPreferencesStorage:accessTokenExpirationDate_2"), 789_004_800)
|
||||
}
|
||||
|
||||
/// `accountKeys(userId:)` returns `nil` if there isn't a previously stored value.
|
||||
func test_accountKeys_isInitiallyNil() {
|
||||
XCTAssertNil(subject.accountKeys(userId: "-1"))
|
||||
|
||||
@ -7,6 +7,7 @@ import Foundation
|
||||
// swiftlint:disable file_length
|
||||
|
||||
class MockAppSettingsStore: AppSettingsStore { // swiftlint:disable:this type_body_length
|
||||
var accessTokenExpirationDateByUserId = [String: Date]()
|
||||
var accountKeys = [String: PrivateKeysResponseModel]()
|
||||
var accountSetupAutofill = [String: AccountSetupProgress]()
|
||||
var accountSetupImportLogins = [String: AccountSetupProgress]()
|
||||
@ -74,6 +75,10 @@ class MockAppSettingsStore: AppSettingsStore { // swiftlint:disable:this type_bo
|
||||
|
||||
var activeIdSubject = CurrentValueSubject<String?, Never>(nil)
|
||||
|
||||
func accessTokenExpirationDate(userId: String) -> Date? {
|
||||
accessTokenExpirationDateByUserId[userId]
|
||||
}
|
||||
|
||||
func accountKeys(userId: String) -> PrivateKeysResponseModel? {
|
||||
accountKeys[userId]
|
||||
}
|
||||
@ -191,6 +196,10 @@ class MockAppSettingsStore: AppSettingsStore { // swiftlint:disable:this type_bo
|
||||
serverConfig[userId]
|
||||
}
|
||||
|
||||
func setAccessTokenExpirationDate(_ expirationDate: Date?, userId: String) {
|
||||
accessTokenExpirationDateByUserId[userId] = expirationDate
|
||||
}
|
||||
|
||||
func setAccountKeys(_ keys: BitwardenShared.PrivateKeysResponseModel?, userId: String) {
|
||||
accountKeys[userId] = keys
|
||||
}
|
||||
|
||||
@ -7,6 +7,7 @@ import Foundation
|
||||
@testable import BitwardenShared
|
||||
|
||||
class MockStateService: StateService { // swiftlint:disable:this type_body_length
|
||||
var accessTokenExpirationDateByUserId = [String: Date]()
|
||||
var accountEncryptionKeys = [String: AccountEncryptionKeys]()
|
||||
var accountSetupAutofill = [String: AccountSetupProgress]()
|
||||
var accountSetupAutofillError: Error?
|
||||
@ -138,6 +139,10 @@ class MockStateService: StateService { // swiftlint:disable:this type_body_lengt
|
||||
})
|
||||
}
|
||||
|
||||
func getAccessTokenExpirationDate(userId: String) async -> Date? {
|
||||
accessTokenExpirationDateByUserId[userId]
|
||||
}
|
||||
|
||||
func didAccountSwitchInExtension() async throws -> Bool {
|
||||
try didAccountSwitchInExtensionResult.get()
|
||||
}
|
||||
@ -445,6 +450,10 @@ class MockStateService: StateService { // swiftlint:disable:this type_body_lengt
|
||||
pinUnlockRequiresPasswordAfterRestartValue
|
||||
}
|
||||
|
||||
func setAccessTokenExpirationDate(_ expirationDate: Date?, userId: String) async {
|
||||
accessTokenExpirationDateByUserId[userId] = expirationDate
|
||||
}
|
||||
|
||||
func setAccountEncryptionKeys(_ encryptionKeys: AccountEncryptionKeys, userId: String?) async throws {
|
||||
let userId = try unwrapUserId(userId)
|
||||
accountEncryptionKeys[userId] = encryptionKeys
|
||||
|
||||
@ -1,9 +1,12 @@
|
||||
import Foundation
|
||||
import Networking
|
||||
|
||||
@testable import BitwardenShared
|
||||
|
||||
class MockTokenService: TokenService {
|
||||
var accessToken: String? = "ACCESS_TOKEN"
|
||||
var accessTokenExpirationDateResult: Result<Date?, Error> = .success(nil)
|
||||
var expirationDate: Date?
|
||||
var getIsExternalResult: Result<Bool, Error> = .success(false)
|
||||
var refreshToken: String? = "REFRESH_TOKEN"
|
||||
|
||||
@ -12,6 +15,10 @@ class MockTokenService: TokenService {
|
||||
return accessToken
|
||||
}
|
||||
|
||||
func getAccessTokenExpirationDate() async throws -> Date? {
|
||||
try accessTokenExpirationDateResult.get()
|
||||
}
|
||||
|
||||
func getIsExternal() async throws -> Bool {
|
||||
try getIsExternalResult.get()
|
||||
}
|
||||
@ -21,8 +28,9 @@ class MockTokenService: TokenService {
|
||||
return refreshToken
|
||||
}
|
||||
|
||||
func setTokens(accessToken: String, refreshToken: String) async {
|
||||
func setTokens(accessToken: String, refreshToken: String, expirationDate: Date) async {
|
||||
self.accessToken = accessToken
|
||||
self.refreshToken = refreshToken
|
||||
self.expirationDate = expirationDate
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import BitwardenKit
|
||||
import BitwardenSdk
|
||||
import Foundation
|
||||
|
||||
/// A protocol for a `TokenService` which manages accessing and updating the active account's tokens.
|
||||
///
|
||||
@ -10,6 +11,12 @@ protocol TokenService: AnyObject {
|
||||
///
|
||||
func getAccessToken() async throws -> String
|
||||
|
||||
/// Returns the access token's expiration date for the current account.
|
||||
///
|
||||
/// - Returns: The access token's expiration date for the current account.
|
||||
///
|
||||
func getAccessTokenExpirationDate() async throws -> Date?
|
||||
|
||||
/// Returns whether the user is an external user.
|
||||
///
|
||||
/// - Returns: Whether the user is an external user.
|
||||
@ -27,8 +34,9 @@ protocol TokenService: AnyObject {
|
||||
/// - Parameters:
|
||||
/// - accessToken: The account's updated access token.
|
||||
/// - refreshToken: The account's updated refresh token.
|
||||
/// - expirationDate: The access token's expiration date.
|
||||
///
|
||||
func setTokens(accessToken: String, refreshToken: String) async throws
|
||||
func setTokens(accessToken: String, refreshToken: String, expirationDate: Date) async throws
|
||||
}
|
||||
|
||||
// MARK: - DefaultTokenService
|
||||
@ -73,6 +81,10 @@ actor DefaultTokenService: TokenService {
|
||||
return try await keychainRepository.getAccessToken(userId: userId)
|
||||
}
|
||||
|
||||
func getAccessTokenExpirationDate() async throws -> Date? {
|
||||
try await stateService.getAccessTokenExpirationDate()
|
||||
}
|
||||
|
||||
func getIsExternal() async throws -> Bool {
|
||||
let accessToken: String = try await getAccessToken()
|
||||
let tokenPayload = try TokenParser.parseToken(accessToken)
|
||||
@ -84,10 +96,11 @@ actor DefaultTokenService: TokenService {
|
||||
return try await keychainRepository.getRefreshToken(userId: userId)
|
||||
}
|
||||
|
||||
func setTokens(accessToken: String, refreshToken: String) async throws {
|
||||
func setTokens(accessToken: String, refreshToken: String, expirationDate: Date) async throws {
|
||||
let userId = try await stateService.getActiveAccountId()
|
||||
try await keychainRepository.setAccessToken(accessToken, userId: userId)
|
||||
try await keychainRepository.setRefreshToken(refreshToken, userId: userId)
|
||||
await stateService.setAccessTokenExpirationDate(expirationDate, userId: userId)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -69,6 +69,22 @@ class TokenServiceTests: BitwardenTestCase {
|
||||
XCTAssertNil(accessToken)
|
||||
}
|
||||
|
||||
/// `getAccessTokenExpirationDate()` returns the access token's expiration date.
|
||||
func test_getAccessTokenExpirationDate() async throws {
|
||||
stateService.accessTokenExpirationDateByUserId["1"] = Date(year: 2025, month: 10, day: 2)
|
||||
stateService.activeAccount = .fixture()
|
||||
|
||||
let expirationDate = try await subject.getAccessTokenExpirationDate()
|
||||
XCTAssertEqual(expirationDate, Date(year: 2025, month: 10, day: 2))
|
||||
}
|
||||
|
||||
/// `getAccessTokenExpirationDate()` throws an error if there isn't an active account.
|
||||
func test_getAccessTokenExpirationDate_error() async throws {
|
||||
await assertAsyncThrows(error: StateServiceError.noActiveAccount) {
|
||||
_ = try await subject.getAccessTokenExpirationDate()
|
||||
}
|
||||
}
|
||||
|
||||
/// `getIsExternal()` returns false if the user isn't an external user.
|
||||
func test_getIsExternal_false() async throws {
|
||||
// swiftlint:disable:next line_length
|
||||
@ -144,7 +160,8 @@ class TokenServiceTests: BitwardenTestCase {
|
||||
func test_setTokens() async throws {
|
||||
stateService.activeAccount = .fixture()
|
||||
|
||||
try await subject.setTokens(accessToken: "🔑", refreshToken: "🔒")
|
||||
let expirationDate = Date(year: 2025, month: 10, day: 1)
|
||||
try await subject.setTokens(accessToken: "🔑", refreshToken: "🔒", expirationDate: expirationDate)
|
||||
|
||||
XCTAssertEqual(
|
||||
keychainRepository.mockStorage[keychainRepository.formattedKey(for: .accessToken(userId: "1"))],
|
||||
@ -154,5 +171,6 @@ class TokenServiceTests: BitwardenTestCase {
|
||||
keychainRepository.mockStorage[keychainRepository.formattedKey(for: .refreshToken(userId: "1"))],
|
||||
"🔒",
|
||||
)
|
||||
XCTAssertEqual(stateService.accessTokenExpirationDateByUserId["1"], expirationDate)
|
||||
}
|
||||
}
|
||||
|
||||
@ -87,6 +87,10 @@ extension Constants {
|
||||
|
||||
/// The default number of KDF iterations to perform.
|
||||
static let pbkdf2Iterations = 600_000
|
||||
|
||||
/// The number of seconds before an access token's expiration time at which the app will
|
||||
/// preemptively refresh the token.
|
||||
static let tokenRefreshThreshold: TimeInterval = 5 * 60 // 5 minutes
|
||||
}
|
||||
|
||||
// MARK: Extension Constants
|
||||
|
||||
@ -682,6 +682,8 @@ extension AppProcessor: AccountTokenProviderDelegate {
|
||||
func onRefreshTokenError(error: any Error) async throws {
|
||||
if case IdentityTokenRefreshRequestError.invalidGrant = error {
|
||||
await logOutAutomatically()
|
||||
} else if let error = error as? ResponseValidationError, [401, 403].contains(error.response.statusCode) {
|
||||
await logOutAutomatically()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import AuthenticationServices
|
||||
import AuthenticatorBridgeKit
|
||||
import BitwardenKit
|
||||
import BitwardenKitMocks
|
||||
import BitwardenResources
|
||||
import Foundation
|
||||
@ -1402,6 +1403,36 @@ class AppProcessorTests: BitwardenTestCase { // swiftlint:disable:this type_body
|
||||
XCTAssertEqual(coordinator.events, [.didLogout(userId: "1", userInitiated: false)])
|
||||
}
|
||||
|
||||
/// `onRefreshTokenError(error:)` logs the user out and notifies the coordinator when a 401 is
|
||||
/// received while refreshing the token.
|
||||
@MainActor
|
||||
func test_onRefreshTokenError_logOut401() async throws {
|
||||
coordinator.isLoadingOverlayShowing = true
|
||||
|
||||
try await subject.onRefreshTokenError(error: ResponseValidationError(response: .failure(statusCode: 401)))
|
||||
|
||||
XCTAssertTrue(authRepository.logoutCalled)
|
||||
XCTAssertEqual(authRepository.logoutUserId, nil)
|
||||
XCTAssertFalse(authRepository.logoutUserInitiated)
|
||||
XCTAssertFalse(coordinator.isLoadingOverlayShowing)
|
||||
XCTAssertEqual(coordinator.events, [.didLogout(userId: nil, userInitiated: false)])
|
||||
}
|
||||
|
||||
/// `onRefreshTokenError(error:)` logs the user out and notifies the coordinator a 403 is
|
||||
/// received while refreshing the token.
|
||||
@MainActor
|
||||
func test_onRefreshTokenError_logOut403() async throws {
|
||||
coordinator.isLoadingOverlayShowing = true
|
||||
|
||||
try await subject.onRefreshTokenError(error: ResponseValidationError(response: .failure(statusCode: 403)))
|
||||
|
||||
XCTAssertTrue(authRepository.logoutCalled)
|
||||
XCTAssertEqual(authRepository.logoutUserId, nil)
|
||||
XCTAssertFalse(authRepository.logoutUserInitiated)
|
||||
XCTAssertFalse(coordinator.isLoadingOverlayShowing)
|
||||
XCTAssertEqual(coordinator.events, [.didLogout(userId: nil, userInitiated: false)])
|
||||
}
|
||||
|
||||
/// `onRefreshTokenError(error:)` logs the user out and notifies the coordinator when error is `.invalidGrant`.
|
||||
@MainActor
|
||||
func test_onRefreshTokenError_logOutInvalidGrant() async throws {
|
||||
|
||||
@ -145,7 +145,7 @@ public final class HTTPService: Sendable {
|
||||
}
|
||||
|
||||
if let tokenProvider, httpResponse.statusCode == 401, shouldRetryIfUnauthorized {
|
||||
try await tokenProvider.refreshToken()
|
||||
_ = try await tokenProvider.refreshToken()
|
||||
|
||||
// Send the request again, but don't retry if still unauthorized to prevent a retry loop.
|
||||
return try await send(httpRequest, validate: validate, shouldRetryIfUnauthorized: false)
|
||||
|
||||
@ -9,5 +9,7 @@ public protocol TokenProvider: Sendable {
|
||||
|
||||
/// Refreshes the access token by using the refresh token to acquire a new access token.
|
||||
///
|
||||
func refreshToken() async throws
|
||||
/// - Returns: A new access token.
|
||||
///
|
||||
func refreshToken() async throws -> String
|
||||
}
|
||||
|
||||
@ -8,7 +8,7 @@ class MockTokenProvider: TokenProvider {
|
||||
|
||||
var getTokenCallCount = 0
|
||||
var tokenResults: [Result<String, Error>] = [.success("ACCESS_TOKEN")]
|
||||
var refreshTokenResult: Result<Void, Error> = .success(())
|
||||
var refreshTokenResult: Result<String, Error> = .success("ACCESS_TOKEN")
|
||||
var refreshTokenCallCount = 0
|
||||
|
||||
func getToken() async throws -> String {
|
||||
@ -17,8 +17,8 @@ class MockTokenProvider: TokenProvider {
|
||||
return try tokenResults.removeFirst().get()
|
||||
}
|
||||
|
||||
func refreshToken() async throws {
|
||||
func refreshToken() async throws -> String {
|
||||
refreshTokenCallCount += 1
|
||||
try refreshTokenResult.get()
|
||||
return try refreshTokenResult.get()
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user