PM-26358: Integrate the token auth logic with the SDK (#5967)

This commit is contained in:
David Perez 2025-10-07 11:49:57 -05:00 committed by GitHub
parent 0c9530472f
commit cd9c7f98e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 407 additions and 143 deletions

View File

@ -10,19 +10,20 @@ class AuthTokenManagerImpl(
private val authDiskSource: AuthDiskSource,
) : AuthTokenManager {
override fun getAuthTokenDataOrNull(userId: String): AuthTokenData? =
authDiskSource
.getAccountTokens(userId = userId)
?.takeIf { it.accessToken != null }
?.let {
AuthTokenData(
userId = userId,
accessToken = requireNotNull(it.accessToken),
expiresAtSec = it.expiresAtSec,
)
}
override fun getAuthTokenDataOrNull(): AuthTokenData? = authDiskSource
.userState
?.activeUserId
?.let { userId ->
authDiskSource
.getAccountTokens(userId = userId)
?.takeIf { it.accessToken != null }
?.let {
AuthTokenData(
userId = userId,
accessToken = requireNotNull(it.accessToken),
expiresAtSec = it.expiresAtSec,
)
}
}
?.let(::getAuthTokenDataOrNull)
}

View File

@ -1,21 +1,11 @@
package com.x8bit.bitwarden.data.platform.manager
import android.os.Build
import com.bitwarden.core.ClientManagedTokens
import com.bitwarden.core.util.isBuildVersionAtLeast
import com.bitwarden.data.manager.NativeLibraryManager
import com.bitwarden.sdk.Client
import com.x8bit.bitwarden.data.platform.manager.sdk.SdkRepositoryFactory
/**
* The token provider to pass to the SDK.
*/
class Token : ClientManagedTokens {
override suspend fun getAccessToken(): String? {
return null
}
}
/**
* Primary implementation of [SdkClientManager].
*/
@ -24,14 +14,18 @@ class SdkClientManagerImpl(
sdkRepoFactory: SdkRepositoryFactory,
private val featureFlagManager: FeatureFlagManager,
private val clientProvider: suspend (userId: String?) -> Client = { userId ->
Client(tokenProvider = Token(), settings = null).apply {
platform().loadFlags(featureFlagManager.sdkFeatureFlags)
userId?.let {
platform().state().apply {
registerCipherRepository(sdkRepoFactory.getCipherRepository(userId = it))
Client(
tokenProvider = sdkRepoFactory.getClientManagedTokens(userId = userId),
settings = null,
)
.apply {
platform().loadFlags(featureFlagManager.sdkFeatureFlags)
userId?.let {
platform().state().apply {
registerCipherRepository(sdkRepoFactory.getCipherRepository(userId = it))
}
}
}
}
},
) : SdkClientManager {
private val userIdToClientMap = mutableMapOf<String?, Client>()

View File

@ -389,8 +389,10 @@ object PlatformManagerModule {
@Singleton
fun provideSdkRepositoryFactory(
vaultDiskSource: VaultDiskSource,
bitwardenServiceClient: BitwardenServiceClient,
): SdkRepositoryFactory = SdkRepositoryFactoryImpl(
vaultDiskSource = vaultDiskSource,
bitwardenServiceClient = bitwardenServiceClient,
)
@Provides

View File

@ -1,5 +1,6 @@
package com.x8bit.bitwarden.data.platform.manager.sdk
import com.bitwarden.core.ClientManagedTokens
import com.bitwarden.sdk.CipherRepository
/**
@ -10,4 +11,9 @@ interface SdkRepositoryFactory {
* Retrieves or creates a [CipherRepository] for use with the Bitwarden SDK.
*/
fun getCipherRepository(userId: String): CipherRepository
/**
* Retrieves or creates a [ClientManagedTokens] for use with the Bitwarden SDK.
*/
fun getClientManagedTokens(userId: String?): ClientManagedTokens
}

View File

@ -1,7 +1,10 @@
package com.x8bit.bitwarden.data.platform.manager.sdk
import com.bitwarden.core.ClientManagedTokens
import com.bitwarden.network.BitwardenServiceClient
import com.bitwarden.sdk.CipherRepository
import com.x8bit.bitwarden.data.platform.manager.sdk.repository.SdkCipherRepository
import com.x8bit.bitwarden.data.platform.manager.sdk.repository.SdkTokenRepository
import com.x8bit.bitwarden.data.vault.datasource.disk.VaultDiskSource
/**
@ -9,6 +12,7 @@ import com.x8bit.bitwarden.data.vault.datasource.disk.VaultDiskSource
*/
class SdkRepositoryFactoryImpl(
private val vaultDiskSource: VaultDiskSource,
private val bitwardenServiceClient: BitwardenServiceClient,
) : SdkRepositoryFactory {
override fun getCipherRepository(
userId: String,
@ -17,4 +21,12 @@ class SdkRepositoryFactoryImpl(
userId = userId,
vaultDiskSource = vaultDiskSource,
)
override fun getClientManagedTokens(
userId: String?,
): ClientManagedTokens =
SdkTokenRepository(
userId = userId,
tokenProvider = bitwardenServiceClient.tokenProvider,
)
}

View File

@ -0,0 +1,15 @@
package com.x8bit.bitwarden.data.platform.manager.sdk.repository
import com.bitwarden.core.ClientManagedTokens
import com.bitwarden.network.provider.TokenProvider
/**
* A user-scoped implementation of a Bitwarden SDK [ClientManagedTokens].
*/
class SdkTokenRepository(
private val userId: String?,
private val tokenProvider: TokenProvider,
) : ClientManagedTokens {
override suspend fun getAccessToken(): String? =
userId?.let { tokenProvider.getAccessToken(userId = it) }
}

View File

@ -7,6 +7,7 @@ import com.x8bit.bitwarden.data.auth.datasource.disk.model.AccountTokensJson
import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson
import com.x8bit.bitwarden.data.auth.datasource.disk.util.FakeAuthDiskSource
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Nested
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertNull
import java.time.ZonedDateTime
@ -16,27 +17,25 @@ class AuthTokenManagerTest {
private val fakeAuthDiskSource = FakeAuthDiskSource()
private val authTokenManager = AuthTokenManagerImpl(fakeAuthDiskSource)
@Test
fun `UserState is null`() {
fakeAuthDiskSource.userState = null
assertNull(authTokenManager.getAuthTokenDataOrNull())
}
@Nested
inner class WithUserId {
@Test
fun `UserState is null`() {
fakeAuthDiskSource.userState = null
assertNull(authTokenManager.getAuthTokenDataOrNull(userId = USER_ID))
}
@Test
fun `Account tokens are null`() {
fakeAuthDiskSource.userState = SINGLE_USER_STATE
.copy(
accounts = mapOf(
USER_ID to ACCOUNT.copy(tokens = null),
),
@Test
fun `Account tokens are null`() {
fakeAuthDiskSource.userState = SINGLE_USER_STATE.copy(
accounts = mapOf(USER_ID to ACCOUNT.copy(tokens = null)),
)
assertNull(authTokenManager.getAuthTokenDataOrNull())
}
assertNull(authTokenManager.getAuthTokenDataOrNull(userId = USER_ID))
}
@Test
fun `Access token is null`() {
fakeAuthDiskSource.userState = SINGLE_USER_STATE
.copy(
@Test
fun `Access token is null`() {
fakeAuthDiskSource.userState = SINGLE_USER_STATE.copy(
accounts = mapOf(
USER_ID to ACCOUNT.copy(
tokens = AccountTokensJson(
@ -46,42 +45,124 @@ class AuthTokenManagerTest {
),
),
)
assertNull(authTokenManager.getAuthTokenDataOrNull())
}
assertNull(authTokenManager.getAuthTokenDataOrNull(userId = USER_ID))
}
@Test
fun `getActiveAccessTokenOrNull should return null if user access token is null`() {
fakeAuthDiskSource.userState = SINGLE_USER_STATE
fakeAuthDiskSource.storeAccountTokens(
userId = USER_ID,
accountTokens = AccountTokensJson(
accessToken = null,
refreshToken = REFRESH_TOKEN,
expiresAtSec = EXPIRES_AT_SEC,
),
)
assertNull(authTokenManager.getAuthTokenDataOrNull())
}
@Test
fun `getActiveAccessTokenOrNull should return active user access token`() {
fakeAuthDiskSource.userState = SINGLE_USER_STATE
fakeAuthDiskSource.storeAccountTokens(
userId = USER_ID,
accountTokens = AccountTokensJson(
accessToken = ACCESS_TOKEN,
refreshToken = REFRESH_TOKEN,
expiresAtSec = EXPIRES_AT_SEC,
),
)
assertEquals(
AuthTokenData(
@Test
fun `getActiveAccessTokenOrNull should return null if user access token is null`() {
fakeAuthDiskSource.userState = SINGLE_USER_STATE
fakeAuthDiskSource.storeAccountTokens(
userId = USER_ID,
accessToken = ACCESS_TOKEN,
expiresAtSec = EXPIRES_AT_SEC,
),
authTokenManager.getAuthTokenDataOrNull(),
)
accountTokens = AccountTokensJson(
accessToken = null,
refreshToken = REFRESH_TOKEN,
expiresAtSec = EXPIRES_AT_SEC,
),
)
assertNull(authTokenManager.getAuthTokenDataOrNull(userId = USER_ID))
}
@Test
fun `getActiveAccessTokenOrNull should return access token`() {
fakeAuthDiskSource.userState = SINGLE_USER_STATE
fakeAuthDiskSource.storeAccountTokens(
userId = USER_ID,
accountTokens = AccountTokensJson(
accessToken = ACCESS_TOKEN,
refreshToken = REFRESH_TOKEN,
expiresAtSec = EXPIRES_AT_SEC,
),
)
assertEquals(
AuthTokenData(
userId = USER_ID,
accessToken = ACCESS_TOKEN,
expiresAtSec = EXPIRES_AT_SEC,
),
authTokenManager.getAuthTokenDataOrNull(userId = USER_ID),
)
}
@Test
fun `getActiveAccessTokenOrNull should return null for unknown userId`() {
fakeAuthDiskSource.userState = SINGLE_USER_STATE
fakeAuthDiskSource.storeAccountTokens(
userId = USER_ID,
accountTokens = AccountTokensJson(
accessToken = ACCESS_TOKEN,
refreshToken = REFRESH_TOKEN,
expiresAtSec = EXPIRES_AT_SEC,
),
)
assertNull(authTokenManager.getAuthTokenDataOrNull(userId = "unknown_user_id"))
}
}
@Nested
inner class WithoutUserId {
@Test
fun `UserState is null`() {
fakeAuthDiskSource.userState = null
assertNull(authTokenManager.getAuthTokenDataOrNull())
}
@Test
fun `Account tokens are null`() {
fakeAuthDiskSource.userState = SINGLE_USER_STATE.copy(
accounts = mapOf(USER_ID to ACCOUNT.copy(tokens = null)),
)
assertNull(authTokenManager.getAuthTokenDataOrNull())
}
@Test
fun `Access token is null`() {
fakeAuthDiskSource.userState = SINGLE_USER_STATE.copy(
accounts = mapOf(
USER_ID to ACCOUNT.copy(
tokens = AccountTokensJson(
accessToken = null,
refreshToken = null,
),
),
),
)
assertNull(authTokenManager.getAuthTokenDataOrNull())
}
@Test
fun `getActiveAccessTokenOrNull should return null if user access token is null`() {
fakeAuthDiskSource.userState = SINGLE_USER_STATE
fakeAuthDiskSource.storeAccountTokens(
userId = USER_ID,
accountTokens = AccountTokensJson(
accessToken = null,
refreshToken = REFRESH_TOKEN,
expiresAtSec = EXPIRES_AT_SEC,
),
)
assertNull(authTokenManager.getAuthTokenDataOrNull())
}
@Test
fun `getActiveAccessTokenOrNull should return active user access token`() {
fakeAuthDiskSource.userState = SINGLE_USER_STATE
fakeAuthDiskSource.storeAccountTokens(
userId = USER_ID,
accountTokens = AccountTokensJson(
accessToken = ACCESS_TOKEN,
refreshToken = REFRESH_TOKEN,
expiresAtSec = EXPIRES_AT_SEC,
),
)
assertEquals(
AuthTokenData(
userId = USER_ID,
accessToken = ACCESS_TOKEN,
expiresAtSec = EXPIRES_AT_SEC,
),
authTokenManager.getAuthTokenDataOrNull(),
)
}
}
}

View File

@ -1,6 +1,8 @@
package com.x8bit.bitwarden.data.platform.manager.sdk
import com.bitwarden.network.BitwardenServiceClient
import com.x8bit.bitwarden.data.vault.datasource.disk.VaultDiskSource
import io.mockk.every
import io.mockk.mockk
import org.junit.jupiter.api.Assertions.assertNotEquals
import org.junit.jupiter.api.Test
@ -8,9 +10,13 @@ import org.junit.jupiter.api.Test
class SdkRepositoryFactoryTests {
private val vaultDiskSource: VaultDiskSource = mockk()
private val bitwardenServiceClient: BitwardenServiceClient = mockk {
every { tokenProvider } returns mockk()
}
private val sdkRepoFactory: SdkRepositoryFactory = SdkRepositoryFactoryImpl(
vaultDiskSource = vaultDiskSource,
bitwardenServiceClient = bitwardenServiceClient,
)
@Test
@ -27,4 +33,19 @@ class SdkRepositoryFactoryTests {
val thirdClient = sdkRepoFactory.getCipherRepository(userId = otherUserId)
assertNotEquals(firstClient, thirdClient)
}
@Test
fun `getClientManagedTokens should create a new client`() {
val userId = "userId"
val firstClient = sdkRepoFactory.getClientManagedTokens(userId = userId)
// Additional calls for the same userId should create a repo
val secondClient = sdkRepoFactory.getClientManagedTokens(userId = userId)
assertNotEquals(firstClient, secondClient)
// Additional calls for different userIds should return a different repo
val otherUserId = "otherUserId"
val thirdClient = sdkRepoFactory.getClientManagedTokens(userId = otherUserId)
assertNotEquals(firstClient, thirdClient)
}
}

View File

@ -0,0 +1,57 @@
package com.x8bit.bitwarden.data.platform.manager.sdk.repository
import com.bitwarden.network.provider.TokenProvider
import io.mockk.every
import io.mockk.mockk
import io.mockk.verify
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertNull
import org.junit.jupiter.api.Test
class SdkTokenRepositoryTest {
private val tokenProvider: TokenProvider = mockk()
@Test
fun `getAccessToken should return null when userId is null`() = runTest {
val repository = createSdkTokenRepository(userId = null)
assertNull(repository.getAccessToken())
verify(exactly = 0) {
tokenProvider.getAccessToken(userId = any())
}
}
@Test
fun `getAccessToken should return null when userId is valid and tokenProvider returns null`() =
runTest {
every { tokenProvider.getAccessToken(userId = USER_ID) } returns null
val repository = createSdkTokenRepository()
assertNull(repository.getAccessToken())
verify(exactly = 1) {
tokenProvider.getAccessToken(userId = USER_ID)
}
}
@Suppress("MaxLineLength")
@Test
fun `getAccessToken should return access token when userId is valid and tokenProvider returns an access token`() =
runTest {
val accessToken = "access_token"
every { tokenProvider.getAccessToken(userId = USER_ID) } returns accessToken
val repository = createSdkTokenRepository()
assertEquals(accessToken, repository.getAccessToken())
verify(exactly = 1) {
tokenProvider.getAccessToken(userId = USER_ID)
}
}
private fun createSdkTokenRepository(
userId: String? = USER_ID,
): SdkTokenRepository = SdkTokenRepository(
userId = userId,
tokenProvider = tokenProvider,
)
}
private const val USER_ID: String = "userId"

View File

@ -56,6 +56,7 @@ object PlatformNetworkModule {
enableHttpBodyLogging = BuildConfig.DEBUG,
authTokenProvider = object : AuthTokenProvider {
override fun getAuthTokenDataOrNull(): AuthTokenData? = null
override fun getAuthTokenDataOrNull(userId: String): AuthTokenData? = null
},
certificateProvider = object : CertificateProvider {
override fun chooseClientAlias(

View File

@ -9,7 +9,9 @@ import com.bitwarden.sdk.Client
class SdkClientManagerImpl(
private val clientProvider: suspend () -> Client = {
Client(
tokenProvider = Token(),
tokenProvider = object : ClientManagedTokens {
override suspend fun getAccessToken(): String? = null
},
settings = null,
)
},
@ -21,13 +23,4 @@ class SdkClientManagerImpl(
override fun destroyClient() {
client = null
}
/**
* The token provider to pass to the SDK.
*/
private class Token : ClientManagedTokens {
override suspend fun getAccessToken(): String? {
return null
}
}
}

View File

@ -5,6 +5,7 @@ package com.bitwarden.network
import com.bitwarden.annotation.OmitFromCoverage
import com.bitwarden.network.model.BitwardenServiceClientConfig
import com.bitwarden.network.provider.RefreshTokenProvider
import com.bitwarden.network.provider.TokenProvider
import com.bitwarden.network.service.AccountsService
import com.bitwarden.network.service.AuthRequestsService
import com.bitwarden.network.service.CiphersService
@ -48,6 +49,10 @@ import com.bitwarden.network.service.SyncService
* ```
*/
interface BitwardenServiceClient {
/**
* Provides access to the token provider.
*/
val tokenProvider: TokenProvider
/**
* Provides access to the Accounts service.

View File

@ -7,6 +7,7 @@ import com.bitwarden.network.interceptor.BaseUrlInterceptors
import com.bitwarden.network.interceptor.HeadersInterceptor
import com.bitwarden.network.model.BitwardenServiceClientConfig
import com.bitwarden.network.provider.RefreshTokenProvider
import com.bitwarden.network.provider.TokenProvider
import com.bitwarden.network.retrofit.Retrofits
import com.bitwarden.network.retrofit.RetrofitsImpl
import com.bitwarden.network.service.AccountsServiceImpl
@ -55,6 +56,7 @@ internal class BitwardenServiceClientImpl(
clock = bitwardenServiceClientConfig.clock,
authTokenProvider = bitwardenServiceClientConfig.authTokenProvider,
)
override val tokenProvider: TokenProvider = authTokenManager
private val clientJson = Json {
// If there are keys returned by the server not modeled by a serializable class,

View File

@ -1,6 +1,10 @@
package com.bitwarden.network.interceptor
import com.bitwarden.core.data.util.asFailure
import com.bitwarden.core.data.util.asSuccess
import com.bitwarden.network.model.AuthTokenData
import com.bitwarden.network.provider.RefreshTokenProvider
import com.bitwarden.network.provider.TokenProvider
import com.bitwarden.network.util.HEADER_KEY_AUTHORIZATION
import com.bitwarden.network.util.HEADER_VALUE_BEARER_PREFIX
import com.bitwarden.network.util.parseJwtTokenDataOrNull
@ -25,70 +29,61 @@ private const val EXPIRATION_OFFSET_MINUTES: Long = 5L
internal class AuthTokenManager(
private val clock: Clock,
private val authTokenProvider: AuthTokenProvider,
) : Authenticator, Interceptor {
) : TokenProvider, Authenticator, Interceptor {
var refreshTokenProvider: RefreshTokenProvider? = null
@Synchronized
override fun getAccessToken(
userId: String,
): String? = authTokenProvider
.getAuthTokenDataOrNull(userId = userId)
?.let { getAccessToken(authTokenData = it).getOrNull() }
@Synchronized
override fun authenticate(
route: Route?,
response: Response,
): Request? {
synchronized(this) {
if (response.shouldSkipAuthentication()) {
// If the same request keeps failing, let's just let the 401 pass through.
return null
if (response.shouldSkipAuthentication()) {
// If the same request keeps failing, let's just let the 401 pass through.
return null
}
val accessToken = requireNotNull(
response
.request
.header(name = HEADER_KEY_AUTHORIZATION)
?.substringAfter(delimiter = HEADER_VALUE_BEARER_PREFIX),
)
return when (val userId = parseJwtTokenDataOrNull(accessToken)?.userId) {
null -> {
// We are unable to get the user ID, let's just let the 401 pass through.
null
}
val accessToken = requireNotNull(
response
.request
.header(name = HEADER_KEY_AUTHORIZATION)
?.substringAfter(delimiter = HEADER_VALUE_BEARER_PREFIX),
)
return when (val userId = parseJwtTokenDataOrNull(accessToken)?.userId) {
null -> {
// We are unable to get the user ID, let's just let the 401 pass through.
null
}
else -> {
Timber.d("Attempting to refresh token due to unauthorized")
refreshTokenProvider
?.refreshAccessTokenSynchronously(userId = userId)
?.fold(
onFailure = { null },
onSuccess = { newAccessToken ->
response
.request
.newBuilder()
.header(
name = HEADER_KEY_AUTHORIZATION,
value = "$HEADER_VALUE_BEARER_PREFIX$newAccessToken",
)
.build()
},
)
}
else -> {
Timber.d("Attempting to refresh token due to unauthorized")
refreshTokenProvider
?.refreshAccessTokenSynchronously(userId = userId)
?.fold(
onFailure = { null },
onSuccess = { newAccessToken ->
response
.request
.newBuilder()
.header(
name = HEADER_KEY_AUTHORIZATION,
value = "$HEADER_VALUE_BEARER_PREFIX$newAccessToken",
)
.build()
},
)
}
}
}
override fun intercept(chain: Interceptor.Chain): Response {
val token = synchronized(this) {
val tokenData = authTokenProvider
.getAuthTokenDataOrNull()
?: throw IOException(IllegalStateException(MISSING_TOKEN_MESSAGE))
val expirationTime = Instant
.ofEpochSecond(tokenData.expiresAtSec)
.minus(EXPIRATION_OFFSET_MINUTES, ChronoUnit.MINUTES)
if (clock.instant().isAfter(expirationTime)) {
Timber.d("Attempting to refresh token due to expiration")
refreshTokenProvider
?.refreshAccessTokenSynchronously(userId = tokenData.userId)
?.getOrElse { throw IOException(it) }
?: throw IOException(IllegalStateException(MISSING_PROVIDER_MESSAGE))
} else {
tokenData.accessToken
}
}
val token = getAccessToken()
?: throw IOException(IllegalStateException(MISSING_TOKEN_MESSAGE))
val request = chain
.request()
.newBuilder()
@ -100,5 +95,25 @@ internal class AuthTokenManager(
return chain.proceed(request)
}
@Synchronized
private fun getAccessToken(): String? = authTokenProvider
.getAuthTokenDataOrNull()
?.let { getAccessToken(authTokenData = it).getOrThrow() }
@Synchronized
private fun getAccessToken(authTokenData: AuthTokenData): Result<String> {
val expirationTime = Instant
.ofEpochSecond(authTokenData.expiresAtSec)
.minus(EXPIRATION_OFFSET_MINUTES, ChronoUnit.MINUTES)
return if (clock.instant().isAfter(expirationTime)) {
Timber.d("Attempting to refresh token due to expiration")
refreshTokenProvider
?.refreshAccessTokenSynchronously(userId = authTokenData.userId)
?: IOException(IllegalStateException(MISSING_PROVIDER_MESSAGE)).asFailure()
} else {
authTokenData.accessToken.asSuccess()
}
}
private fun Response.shouldSkipAuthentication(): Boolean = this.priorResponse != null
}

View File

@ -6,6 +6,10 @@ import com.bitwarden.network.model.AuthTokenData
* A provider for all the functionality needed to properly refresh the users access token.
*/
interface AuthTokenProvider {
/**
* The specified user's auth token data.
*/
fun getAuthTokenDataOrNull(userId: String): AuthTokenData?
/**
* The currently active user's auth token data.

View File

@ -0,0 +1,11 @@
package com.bitwarden.network.provider
/**
* A provider for authentication tokens.
*/
interface TokenProvider {
/**
* Retrieves an up-to-date token for the specified user.
*/
fun getAccessToken(userId: String): String?
}

View File

@ -51,6 +51,50 @@ class AuthTokenManagerTest {
unmockkStatic(::parseJwtTokenDataOrNull)
}
@Nested
inner class TokenProvider {
@Test
fun `returns null if token provider has no auth data for user ID`() {
val userId = "userId"
every { mockAuthTokenProvider.getAuthTokenDataOrNull(userId = userId) } returns null
val result = authTokenManager.getAccessToken(userId = userId)
assertNull(result)
}
@Test
fun `returns null if refresh fails`() {
val userId = "userId"
val authData = AuthTokenData(
userId = userId,
accessToken = ACCESS_TOKEN,
expiresAtSec = FIXED_CLOCK.instant().epochSecond,
)
every { mockAuthTokenProvider.getAuthTokenDataOrNull(userId = userId) } returns authData
every {
refreshTokenProvider.refreshAccessTokenSynchronously(userId = userId)
} returns Throwable("Fail!").asFailure()
val result = authTokenManager.getAccessToken(userId = userId)
assertNull(result)
}
@Test
fun `returns access token if refresh is not required`() {
val userId = "userId"
val authData = AuthTokenData(
userId = userId,
accessToken = ACCESS_TOKEN,
expiresAtSec = 0L,
)
val refreshedAccessToken = "refreshed_access_token"
every { mockAuthTokenProvider.getAuthTokenDataOrNull(userId = userId) } returns authData
every {
refreshTokenProvider.refreshAccessTokenSynchronously(userId = userId)
} returns refreshedAccessToken.asSuccess()
val result = authTokenManager.getAccessToken(userId = userId)
assertEquals(refreshedAccessToken, result)
}
}
@Nested
inner class Authenticator {
@Test
@ -158,7 +202,7 @@ class AuthTokenManagerTest {
authTokenManager.refreshTokenProvider = object : RefreshTokenProvider {
override fun refreshAccessTokenSynchronously(
userId: String,
): Result<String> = Throwable(errorMessage).asFailure()
): Result<String> = IOException(errorMessage).asFailure()
}
val authTokenData = AuthTokenData(
userId = USER_ID,
@ -172,7 +216,7 @@ class AuthTokenManagerTest {
chain = FakeInterceptorChain(request = request),
)
}
assertEquals(errorMessage, throwable.cause?.message)
assertEquals(errorMessage, throwable?.message)
}
@Suppress("MaxLineLength")