mirror of
https://github.com/bitwarden/android.git
synced 2025-12-10 20:07:59 -06:00
PM-26358: Integrate the token auth logic with the SDK (#5967)
This commit is contained in:
parent
0c9530472f
commit
cd9c7f98e7
@ -10,19 +10,20 @@ class AuthTokenManagerImpl(
|
||||
private val authDiskSource: AuthDiskSource,
|
||||
) : AuthTokenManager {
|
||||
|
||||
override fun getAuthTokenDataOrNull(userId: String): AuthTokenData? =
|
||||
authDiskSource
|
||||
.getAccountTokens(userId = userId)
|
||||
?.takeIf { it.accessToken != null }
|
||||
?.let {
|
||||
AuthTokenData(
|
||||
userId = userId,
|
||||
accessToken = requireNotNull(it.accessToken),
|
||||
expiresAtSec = it.expiresAtSec,
|
||||
)
|
||||
}
|
||||
|
||||
override fun getAuthTokenDataOrNull(): AuthTokenData? = authDiskSource
|
||||
.userState
|
||||
?.activeUserId
|
||||
?.let { userId ->
|
||||
authDiskSource
|
||||
.getAccountTokens(userId = userId)
|
||||
?.takeIf { it.accessToken != null }
|
||||
?.let {
|
||||
AuthTokenData(
|
||||
userId = userId,
|
||||
accessToken = requireNotNull(it.accessToken),
|
||||
expiresAtSec = it.expiresAtSec,
|
||||
)
|
||||
}
|
||||
}
|
||||
?.let(::getAuthTokenDataOrNull)
|
||||
}
|
||||
|
||||
@ -1,21 +1,11 @@
|
||||
package com.x8bit.bitwarden.data.platform.manager
|
||||
|
||||
import android.os.Build
|
||||
import com.bitwarden.core.ClientManagedTokens
|
||||
import com.bitwarden.core.util.isBuildVersionAtLeast
|
||||
import com.bitwarden.data.manager.NativeLibraryManager
|
||||
import com.bitwarden.sdk.Client
|
||||
import com.x8bit.bitwarden.data.platform.manager.sdk.SdkRepositoryFactory
|
||||
|
||||
/**
|
||||
* The token provider to pass to the SDK.
|
||||
*/
|
||||
class Token : ClientManagedTokens {
|
||||
override suspend fun getAccessToken(): String? {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Primary implementation of [SdkClientManager].
|
||||
*/
|
||||
@ -24,14 +14,18 @@ class SdkClientManagerImpl(
|
||||
sdkRepoFactory: SdkRepositoryFactory,
|
||||
private val featureFlagManager: FeatureFlagManager,
|
||||
private val clientProvider: suspend (userId: String?) -> Client = { userId ->
|
||||
Client(tokenProvider = Token(), settings = null).apply {
|
||||
platform().loadFlags(featureFlagManager.sdkFeatureFlags)
|
||||
userId?.let {
|
||||
platform().state().apply {
|
||||
registerCipherRepository(sdkRepoFactory.getCipherRepository(userId = it))
|
||||
Client(
|
||||
tokenProvider = sdkRepoFactory.getClientManagedTokens(userId = userId),
|
||||
settings = null,
|
||||
)
|
||||
.apply {
|
||||
platform().loadFlags(featureFlagManager.sdkFeatureFlags)
|
||||
userId?.let {
|
||||
platform().state().apply {
|
||||
registerCipherRepository(sdkRepoFactory.getCipherRepository(userId = it))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
) : SdkClientManager {
|
||||
private val userIdToClientMap = mutableMapOf<String?, Client>()
|
||||
|
||||
@ -389,8 +389,10 @@ object PlatformManagerModule {
|
||||
@Singleton
|
||||
fun provideSdkRepositoryFactory(
|
||||
vaultDiskSource: VaultDiskSource,
|
||||
bitwardenServiceClient: BitwardenServiceClient,
|
||||
): SdkRepositoryFactory = SdkRepositoryFactoryImpl(
|
||||
vaultDiskSource = vaultDiskSource,
|
||||
bitwardenServiceClient = bitwardenServiceClient,
|
||||
)
|
||||
|
||||
@Provides
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
package com.x8bit.bitwarden.data.platform.manager.sdk
|
||||
|
||||
import com.bitwarden.core.ClientManagedTokens
|
||||
import com.bitwarden.sdk.CipherRepository
|
||||
|
||||
/**
|
||||
@ -10,4 +11,9 @@ interface SdkRepositoryFactory {
|
||||
* Retrieves or creates a [CipherRepository] for use with the Bitwarden SDK.
|
||||
*/
|
||||
fun getCipherRepository(userId: String): CipherRepository
|
||||
|
||||
/**
|
||||
* Retrieves or creates a [ClientManagedTokens] for use with the Bitwarden SDK.
|
||||
*/
|
||||
fun getClientManagedTokens(userId: String?): ClientManagedTokens
|
||||
}
|
||||
|
||||
@ -1,7 +1,10 @@
|
||||
package com.x8bit.bitwarden.data.platform.manager.sdk
|
||||
|
||||
import com.bitwarden.core.ClientManagedTokens
|
||||
import com.bitwarden.network.BitwardenServiceClient
|
||||
import com.bitwarden.sdk.CipherRepository
|
||||
import com.x8bit.bitwarden.data.platform.manager.sdk.repository.SdkCipherRepository
|
||||
import com.x8bit.bitwarden.data.platform.manager.sdk.repository.SdkTokenRepository
|
||||
import com.x8bit.bitwarden.data.vault.datasource.disk.VaultDiskSource
|
||||
|
||||
/**
|
||||
@ -9,6 +12,7 @@ import com.x8bit.bitwarden.data.vault.datasource.disk.VaultDiskSource
|
||||
*/
|
||||
class SdkRepositoryFactoryImpl(
|
||||
private val vaultDiskSource: VaultDiskSource,
|
||||
private val bitwardenServiceClient: BitwardenServiceClient,
|
||||
) : SdkRepositoryFactory {
|
||||
override fun getCipherRepository(
|
||||
userId: String,
|
||||
@ -17,4 +21,12 @@ class SdkRepositoryFactoryImpl(
|
||||
userId = userId,
|
||||
vaultDiskSource = vaultDiskSource,
|
||||
)
|
||||
|
||||
override fun getClientManagedTokens(
|
||||
userId: String?,
|
||||
): ClientManagedTokens =
|
||||
SdkTokenRepository(
|
||||
userId = userId,
|
||||
tokenProvider = bitwardenServiceClient.tokenProvider,
|
||||
)
|
||||
}
|
||||
|
||||
@ -0,0 +1,15 @@
|
||||
package com.x8bit.bitwarden.data.platform.manager.sdk.repository
|
||||
|
||||
import com.bitwarden.core.ClientManagedTokens
|
||||
import com.bitwarden.network.provider.TokenProvider
|
||||
|
||||
/**
|
||||
* A user-scoped implementation of a Bitwarden SDK [ClientManagedTokens].
|
||||
*/
|
||||
class SdkTokenRepository(
|
||||
private val userId: String?,
|
||||
private val tokenProvider: TokenProvider,
|
||||
) : ClientManagedTokens {
|
||||
override suspend fun getAccessToken(): String? =
|
||||
userId?.let { tokenProvider.getAccessToken(userId = it) }
|
||||
}
|
||||
@ -7,6 +7,7 @@ import com.x8bit.bitwarden.data.auth.datasource.disk.model.AccountTokensJson
|
||||
import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson
|
||||
import com.x8bit.bitwarden.data.auth.datasource.disk.util.FakeAuthDiskSource
|
||||
import org.junit.jupiter.api.Assertions.assertEquals
|
||||
import org.junit.jupiter.api.Nested
|
||||
import org.junit.jupiter.api.Test
|
||||
import org.junit.jupiter.api.assertNull
|
||||
import java.time.ZonedDateTime
|
||||
@ -16,27 +17,25 @@ class AuthTokenManagerTest {
|
||||
private val fakeAuthDiskSource = FakeAuthDiskSource()
|
||||
private val authTokenManager = AuthTokenManagerImpl(fakeAuthDiskSource)
|
||||
|
||||
@Test
|
||||
fun `UserState is null`() {
|
||||
fakeAuthDiskSource.userState = null
|
||||
assertNull(authTokenManager.getAuthTokenDataOrNull())
|
||||
}
|
||||
@Nested
|
||||
inner class WithUserId {
|
||||
@Test
|
||||
fun `UserState is null`() {
|
||||
fakeAuthDiskSource.userState = null
|
||||
assertNull(authTokenManager.getAuthTokenDataOrNull(userId = USER_ID))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `Account tokens are null`() {
|
||||
fakeAuthDiskSource.userState = SINGLE_USER_STATE
|
||||
.copy(
|
||||
accounts = mapOf(
|
||||
USER_ID to ACCOUNT.copy(tokens = null),
|
||||
),
|
||||
@Test
|
||||
fun `Account tokens are null`() {
|
||||
fakeAuthDiskSource.userState = SINGLE_USER_STATE.copy(
|
||||
accounts = mapOf(USER_ID to ACCOUNT.copy(tokens = null)),
|
||||
)
|
||||
assertNull(authTokenManager.getAuthTokenDataOrNull())
|
||||
}
|
||||
assertNull(authTokenManager.getAuthTokenDataOrNull(userId = USER_ID))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `Access token is null`() {
|
||||
fakeAuthDiskSource.userState = SINGLE_USER_STATE
|
||||
.copy(
|
||||
@Test
|
||||
fun `Access token is null`() {
|
||||
fakeAuthDiskSource.userState = SINGLE_USER_STATE.copy(
|
||||
accounts = mapOf(
|
||||
USER_ID to ACCOUNT.copy(
|
||||
tokens = AccountTokensJson(
|
||||
@ -46,42 +45,124 @@ class AuthTokenManagerTest {
|
||||
),
|
||||
),
|
||||
)
|
||||
assertNull(authTokenManager.getAuthTokenDataOrNull())
|
||||
}
|
||||
assertNull(authTokenManager.getAuthTokenDataOrNull(userId = USER_ID))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `getActiveAccessTokenOrNull should return null if user access token is null`() {
|
||||
fakeAuthDiskSource.userState = SINGLE_USER_STATE
|
||||
fakeAuthDiskSource.storeAccountTokens(
|
||||
userId = USER_ID,
|
||||
accountTokens = AccountTokensJson(
|
||||
accessToken = null,
|
||||
refreshToken = REFRESH_TOKEN,
|
||||
expiresAtSec = EXPIRES_AT_SEC,
|
||||
),
|
||||
)
|
||||
assertNull(authTokenManager.getAuthTokenDataOrNull())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `getActiveAccessTokenOrNull should return active user access token`() {
|
||||
fakeAuthDiskSource.userState = SINGLE_USER_STATE
|
||||
fakeAuthDiskSource.storeAccountTokens(
|
||||
userId = USER_ID,
|
||||
accountTokens = AccountTokensJson(
|
||||
accessToken = ACCESS_TOKEN,
|
||||
refreshToken = REFRESH_TOKEN,
|
||||
expiresAtSec = EXPIRES_AT_SEC,
|
||||
),
|
||||
)
|
||||
assertEquals(
|
||||
AuthTokenData(
|
||||
@Test
|
||||
fun `getActiveAccessTokenOrNull should return null if user access token is null`() {
|
||||
fakeAuthDiskSource.userState = SINGLE_USER_STATE
|
||||
fakeAuthDiskSource.storeAccountTokens(
|
||||
userId = USER_ID,
|
||||
accessToken = ACCESS_TOKEN,
|
||||
expiresAtSec = EXPIRES_AT_SEC,
|
||||
),
|
||||
authTokenManager.getAuthTokenDataOrNull(),
|
||||
)
|
||||
accountTokens = AccountTokensJson(
|
||||
accessToken = null,
|
||||
refreshToken = REFRESH_TOKEN,
|
||||
expiresAtSec = EXPIRES_AT_SEC,
|
||||
),
|
||||
)
|
||||
assertNull(authTokenManager.getAuthTokenDataOrNull(userId = USER_ID))
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `getActiveAccessTokenOrNull should return access token`() {
|
||||
fakeAuthDiskSource.userState = SINGLE_USER_STATE
|
||||
fakeAuthDiskSource.storeAccountTokens(
|
||||
userId = USER_ID,
|
||||
accountTokens = AccountTokensJson(
|
||||
accessToken = ACCESS_TOKEN,
|
||||
refreshToken = REFRESH_TOKEN,
|
||||
expiresAtSec = EXPIRES_AT_SEC,
|
||||
),
|
||||
)
|
||||
assertEquals(
|
||||
AuthTokenData(
|
||||
userId = USER_ID,
|
||||
accessToken = ACCESS_TOKEN,
|
||||
expiresAtSec = EXPIRES_AT_SEC,
|
||||
),
|
||||
authTokenManager.getAuthTokenDataOrNull(userId = USER_ID),
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `getActiveAccessTokenOrNull should return null for unknown userId`() {
|
||||
fakeAuthDiskSource.userState = SINGLE_USER_STATE
|
||||
fakeAuthDiskSource.storeAccountTokens(
|
||||
userId = USER_ID,
|
||||
accountTokens = AccountTokensJson(
|
||||
accessToken = ACCESS_TOKEN,
|
||||
refreshToken = REFRESH_TOKEN,
|
||||
expiresAtSec = EXPIRES_AT_SEC,
|
||||
),
|
||||
)
|
||||
assertNull(authTokenManager.getAuthTokenDataOrNull(userId = "unknown_user_id"))
|
||||
}
|
||||
}
|
||||
|
||||
@Nested
|
||||
inner class WithoutUserId {
|
||||
@Test
|
||||
fun `UserState is null`() {
|
||||
fakeAuthDiskSource.userState = null
|
||||
assertNull(authTokenManager.getAuthTokenDataOrNull())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `Account tokens are null`() {
|
||||
fakeAuthDiskSource.userState = SINGLE_USER_STATE.copy(
|
||||
accounts = mapOf(USER_ID to ACCOUNT.copy(tokens = null)),
|
||||
)
|
||||
assertNull(authTokenManager.getAuthTokenDataOrNull())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `Access token is null`() {
|
||||
fakeAuthDiskSource.userState = SINGLE_USER_STATE.copy(
|
||||
accounts = mapOf(
|
||||
USER_ID to ACCOUNT.copy(
|
||||
tokens = AccountTokensJson(
|
||||
accessToken = null,
|
||||
refreshToken = null,
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
assertNull(authTokenManager.getAuthTokenDataOrNull())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `getActiveAccessTokenOrNull should return null if user access token is null`() {
|
||||
fakeAuthDiskSource.userState = SINGLE_USER_STATE
|
||||
fakeAuthDiskSource.storeAccountTokens(
|
||||
userId = USER_ID,
|
||||
accountTokens = AccountTokensJson(
|
||||
accessToken = null,
|
||||
refreshToken = REFRESH_TOKEN,
|
||||
expiresAtSec = EXPIRES_AT_SEC,
|
||||
),
|
||||
)
|
||||
assertNull(authTokenManager.getAuthTokenDataOrNull())
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `getActiveAccessTokenOrNull should return active user access token`() {
|
||||
fakeAuthDiskSource.userState = SINGLE_USER_STATE
|
||||
fakeAuthDiskSource.storeAccountTokens(
|
||||
userId = USER_ID,
|
||||
accountTokens = AccountTokensJson(
|
||||
accessToken = ACCESS_TOKEN,
|
||||
refreshToken = REFRESH_TOKEN,
|
||||
expiresAtSec = EXPIRES_AT_SEC,
|
||||
),
|
||||
)
|
||||
assertEquals(
|
||||
AuthTokenData(
|
||||
userId = USER_ID,
|
||||
accessToken = ACCESS_TOKEN,
|
||||
expiresAtSec = EXPIRES_AT_SEC,
|
||||
),
|
||||
authTokenManager.getAuthTokenDataOrNull(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
package com.x8bit.bitwarden.data.platform.manager.sdk
|
||||
|
||||
import com.bitwarden.network.BitwardenServiceClient
|
||||
import com.x8bit.bitwarden.data.vault.datasource.disk.VaultDiskSource
|
||||
import io.mockk.every
|
||||
import io.mockk.mockk
|
||||
import org.junit.jupiter.api.Assertions.assertNotEquals
|
||||
import org.junit.jupiter.api.Test
|
||||
@ -8,9 +10,13 @@ import org.junit.jupiter.api.Test
|
||||
class SdkRepositoryFactoryTests {
|
||||
|
||||
private val vaultDiskSource: VaultDiskSource = mockk()
|
||||
private val bitwardenServiceClient: BitwardenServiceClient = mockk {
|
||||
every { tokenProvider } returns mockk()
|
||||
}
|
||||
|
||||
private val sdkRepoFactory: SdkRepositoryFactory = SdkRepositoryFactoryImpl(
|
||||
vaultDiskSource = vaultDiskSource,
|
||||
bitwardenServiceClient = bitwardenServiceClient,
|
||||
)
|
||||
|
||||
@Test
|
||||
@ -27,4 +33,19 @@ class SdkRepositoryFactoryTests {
|
||||
val thirdClient = sdkRepoFactory.getCipherRepository(userId = otherUserId)
|
||||
assertNotEquals(firstClient, thirdClient)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `getClientManagedTokens should create a new client`() {
|
||||
val userId = "userId"
|
||||
val firstClient = sdkRepoFactory.getClientManagedTokens(userId = userId)
|
||||
|
||||
// Additional calls for the same userId should create a repo
|
||||
val secondClient = sdkRepoFactory.getClientManagedTokens(userId = userId)
|
||||
assertNotEquals(firstClient, secondClient)
|
||||
|
||||
// Additional calls for different userIds should return a different repo
|
||||
val otherUserId = "otherUserId"
|
||||
val thirdClient = sdkRepoFactory.getClientManagedTokens(userId = otherUserId)
|
||||
assertNotEquals(firstClient, thirdClient)
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,57 @@
|
||||
package com.x8bit.bitwarden.data.platform.manager.sdk.repository
|
||||
|
||||
import com.bitwarden.network.provider.TokenProvider
|
||||
import io.mockk.every
|
||||
import io.mockk.mockk
|
||||
import io.mockk.verify
|
||||
import kotlinx.coroutines.test.runTest
|
||||
import org.junit.jupiter.api.Assertions.assertEquals
|
||||
import org.junit.jupiter.api.Assertions.assertNull
|
||||
import org.junit.jupiter.api.Test
|
||||
|
||||
class SdkTokenRepositoryTest {
|
||||
|
||||
private val tokenProvider: TokenProvider = mockk()
|
||||
|
||||
@Test
|
||||
fun `getAccessToken should return null when userId is null`() = runTest {
|
||||
val repository = createSdkTokenRepository(userId = null)
|
||||
assertNull(repository.getAccessToken())
|
||||
verify(exactly = 0) {
|
||||
tokenProvider.getAccessToken(userId = any())
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `getAccessToken should return null when userId is valid and tokenProvider returns null`() =
|
||||
runTest {
|
||||
every { tokenProvider.getAccessToken(userId = USER_ID) } returns null
|
||||
val repository = createSdkTokenRepository()
|
||||
assertNull(repository.getAccessToken())
|
||||
verify(exactly = 1) {
|
||||
tokenProvider.getAccessToken(userId = USER_ID)
|
||||
}
|
||||
}
|
||||
|
||||
@Suppress("MaxLineLength")
|
||||
@Test
|
||||
fun `getAccessToken should return access token when userId is valid and tokenProvider returns an access token`() =
|
||||
runTest {
|
||||
val accessToken = "access_token"
|
||||
every { tokenProvider.getAccessToken(userId = USER_ID) } returns accessToken
|
||||
val repository = createSdkTokenRepository()
|
||||
assertEquals(accessToken, repository.getAccessToken())
|
||||
verify(exactly = 1) {
|
||||
tokenProvider.getAccessToken(userId = USER_ID)
|
||||
}
|
||||
}
|
||||
|
||||
private fun createSdkTokenRepository(
|
||||
userId: String? = USER_ID,
|
||||
): SdkTokenRepository = SdkTokenRepository(
|
||||
userId = userId,
|
||||
tokenProvider = tokenProvider,
|
||||
)
|
||||
}
|
||||
|
||||
private const val USER_ID: String = "userId"
|
||||
@ -56,6 +56,7 @@ object PlatformNetworkModule {
|
||||
enableHttpBodyLogging = BuildConfig.DEBUG,
|
||||
authTokenProvider = object : AuthTokenProvider {
|
||||
override fun getAuthTokenDataOrNull(): AuthTokenData? = null
|
||||
override fun getAuthTokenDataOrNull(userId: String): AuthTokenData? = null
|
||||
},
|
||||
certificateProvider = object : CertificateProvider {
|
||||
override fun chooseClientAlias(
|
||||
|
||||
@ -9,7 +9,9 @@ import com.bitwarden.sdk.Client
|
||||
class SdkClientManagerImpl(
|
||||
private val clientProvider: suspend () -> Client = {
|
||||
Client(
|
||||
tokenProvider = Token(),
|
||||
tokenProvider = object : ClientManagedTokens {
|
||||
override suspend fun getAccessToken(): String? = null
|
||||
},
|
||||
settings = null,
|
||||
)
|
||||
},
|
||||
@ -21,13 +23,4 @@ class SdkClientManagerImpl(
|
||||
override fun destroyClient() {
|
||||
client = null
|
||||
}
|
||||
|
||||
/**
|
||||
* The token provider to pass to the SDK.
|
||||
*/
|
||||
private class Token : ClientManagedTokens {
|
||||
override suspend fun getAccessToken(): String? {
|
||||
return null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -5,6 +5,7 @@ package com.bitwarden.network
|
||||
import com.bitwarden.annotation.OmitFromCoverage
|
||||
import com.bitwarden.network.model.BitwardenServiceClientConfig
|
||||
import com.bitwarden.network.provider.RefreshTokenProvider
|
||||
import com.bitwarden.network.provider.TokenProvider
|
||||
import com.bitwarden.network.service.AccountsService
|
||||
import com.bitwarden.network.service.AuthRequestsService
|
||||
import com.bitwarden.network.service.CiphersService
|
||||
@ -48,6 +49,10 @@ import com.bitwarden.network.service.SyncService
|
||||
* ```
|
||||
*/
|
||||
interface BitwardenServiceClient {
|
||||
/**
|
||||
* Provides access to the token provider.
|
||||
*/
|
||||
val tokenProvider: TokenProvider
|
||||
|
||||
/**
|
||||
* Provides access to the Accounts service.
|
||||
|
||||
@ -7,6 +7,7 @@ import com.bitwarden.network.interceptor.BaseUrlInterceptors
|
||||
import com.bitwarden.network.interceptor.HeadersInterceptor
|
||||
import com.bitwarden.network.model.BitwardenServiceClientConfig
|
||||
import com.bitwarden.network.provider.RefreshTokenProvider
|
||||
import com.bitwarden.network.provider.TokenProvider
|
||||
import com.bitwarden.network.retrofit.Retrofits
|
||||
import com.bitwarden.network.retrofit.RetrofitsImpl
|
||||
import com.bitwarden.network.service.AccountsServiceImpl
|
||||
@ -55,6 +56,7 @@ internal class BitwardenServiceClientImpl(
|
||||
clock = bitwardenServiceClientConfig.clock,
|
||||
authTokenProvider = bitwardenServiceClientConfig.authTokenProvider,
|
||||
)
|
||||
override val tokenProvider: TokenProvider = authTokenManager
|
||||
private val clientJson = Json {
|
||||
|
||||
// If there are keys returned by the server not modeled by a serializable class,
|
||||
|
||||
@ -1,6 +1,10 @@
|
||||
package com.bitwarden.network.interceptor
|
||||
|
||||
import com.bitwarden.core.data.util.asFailure
|
||||
import com.bitwarden.core.data.util.asSuccess
|
||||
import com.bitwarden.network.model.AuthTokenData
|
||||
import com.bitwarden.network.provider.RefreshTokenProvider
|
||||
import com.bitwarden.network.provider.TokenProvider
|
||||
import com.bitwarden.network.util.HEADER_KEY_AUTHORIZATION
|
||||
import com.bitwarden.network.util.HEADER_VALUE_BEARER_PREFIX
|
||||
import com.bitwarden.network.util.parseJwtTokenDataOrNull
|
||||
@ -25,70 +29,61 @@ private const val EXPIRATION_OFFSET_MINUTES: Long = 5L
|
||||
internal class AuthTokenManager(
|
||||
private val clock: Clock,
|
||||
private val authTokenProvider: AuthTokenProvider,
|
||||
) : Authenticator, Interceptor {
|
||||
) : TokenProvider, Authenticator, Interceptor {
|
||||
var refreshTokenProvider: RefreshTokenProvider? = null
|
||||
|
||||
@Synchronized
|
||||
override fun getAccessToken(
|
||||
userId: String,
|
||||
): String? = authTokenProvider
|
||||
.getAuthTokenDataOrNull(userId = userId)
|
||||
?.let { getAccessToken(authTokenData = it).getOrNull() }
|
||||
|
||||
@Synchronized
|
||||
override fun authenticate(
|
||||
route: Route?,
|
||||
response: Response,
|
||||
): Request? {
|
||||
synchronized(this) {
|
||||
if (response.shouldSkipAuthentication()) {
|
||||
// If the same request keeps failing, let's just let the 401 pass through.
|
||||
return null
|
||||
if (response.shouldSkipAuthentication()) {
|
||||
// If the same request keeps failing, let's just let the 401 pass through.
|
||||
return null
|
||||
}
|
||||
val accessToken = requireNotNull(
|
||||
response
|
||||
.request
|
||||
.header(name = HEADER_KEY_AUTHORIZATION)
|
||||
?.substringAfter(delimiter = HEADER_VALUE_BEARER_PREFIX),
|
||||
)
|
||||
return when (val userId = parseJwtTokenDataOrNull(accessToken)?.userId) {
|
||||
null -> {
|
||||
// We are unable to get the user ID, let's just let the 401 pass through.
|
||||
null
|
||||
}
|
||||
val accessToken = requireNotNull(
|
||||
response
|
||||
.request
|
||||
.header(name = HEADER_KEY_AUTHORIZATION)
|
||||
?.substringAfter(delimiter = HEADER_VALUE_BEARER_PREFIX),
|
||||
)
|
||||
return when (val userId = parseJwtTokenDataOrNull(accessToken)?.userId) {
|
||||
null -> {
|
||||
// We are unable to get the user ID, let's just let the 401 pass through.
|
||||
null
|
||||
}
|
||||
|
||||
else -> {
|
||||
Timber.d("Attempting to refresh token due to unauthorized")
|
||||
refreshTokenProvider
|
||||
?.refreshAccessTokenSynchronously(userId = userId)
|
||||
?.fold(
|
||||
onFailure = { null },
|
||||
onSuccess = { newAccessToken ->
|
||||
response
|
||||
.request
|
||||
.newBuilder()
|
||||
.header(
|
||||
name = HEADER_KEY_AUTHORIZATION,
|
||||
value = "$HEADER_VALUE_BEARER_PREFIX$newAccessToken",
|
||||
)
|
||||
.build()
|
||||
},
|
||||
)
|
||||
}
|
||||
else -> {
|
||||
Timber.d("Attempting to refresh token due to unauthorized")
|
||||
refreshTokenProvider
|
||||
?.refreshAccessTokenSynchronously(userId = userId)
|
||||
?.fold(
|
||||
onFailure = { null },
|
||||
onSuccess = { newAccessToken ->
|
||||
response
|
||||
.request
|
||||
.newBuilder()
|
||||
.header(
|
||||
name = HEADER_KEY_AUTHORIZATION,
|
||||
value = "$HEADER_VALUE_BEARER_PREFIX$newAccessToken",
|
||||
)
|
||||
.build()
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun intercept(chain: Interceptor.Chain): Response {
|
||||
val token = synchronized(this) {
|
||||
val tokenData = authTokenProvider
|
||||
.getAuthTokenDataOrNull()
|
||||
?: throw IOException(IllegalStateException(MISSING_TOKEN_MESSAGE))
|
||||
val expirationTime = Instant
|
||||
.ofEpochSecond(tokenData.expiresAtSec)
|
||||
.minus(EXPIRATION_OFFSET_MINUTES, ChronoUnit.MINUTES)
|
||||
if (clock.instant().isAfter(expirationTime)) {
|
||||
Timber.d("Attempting to refresh token due to expiration")
|
||||
refreshTokenProvider
|
||||
?.refreshAccessTokenSynchronously(userId = tokenData.userId)
|
||||
?.getOrElse { throw IOException(it) }
|
||||
?: throw IOException(IllegalStateException(MISSING_PROVIDER_MESSAGE))
|
||||
} else {
|
||||
tokenData.accessToken
|
||||
}
|
||||
}
|
||||
val token = getAccessToken()
|
||||
?: throw IOException(IllegalStateException(MISSING_TOKEN_MESSAGE))
|
||||
val request = chain
|
||||
.request()
|
||||
.newBuilder()
|
||||
@ -100,5 +95,25 @@ internal class AuthTokenManager(
|
||||
return chain.proceed(request)
|
||||
}
|
||||
|
||||
@Synchronized
|
||||
private fun getAccessToken(): String? = authTokenProvider
|
||||
.getAuthTokenDataOrNull()
|
||||
?.let { getAccessToken(authTokenData = it).getOrThrow() }
|
||||
|
||||
@Synchronized
|
||||
private fun getAccessToken(authTokenData: AuthTokenData): Result<String> {
|
||||
val expirationTime = Instant
|
||||
.ofEpochSecond(authTokenData.expiresAtSec)
|
||||
.minus(EXPIRATION_OFFSET_MINUTES, ChronoUnit.MINUTES)
|
||||
return if (clock.instant().isAfter(expirationTime)) {
|
||||
Timber.d("Attempting to refresh token due to expiration")
|
||||
refreshTokenProvider
|
||||
?.refreshAccessTokenSynchronously(userId = authTokenData.userId)
|
||||
?: IOException(IllegalStateException(MISSING_PROVIDER_MESSAGE)).asFailure()
|
||||
} else {
|
||||
authTokenData.accessToken.asSuccess()
|
||||
}
|
||||
}
|
||||
|
||||
private fun Response.shouldSkipAuthentication(): Boolean = this.priorResponse != null
|
||||
}
|
||||
|
||||
@ -6,6 +6,10 @@ import com.bitwarden.network.model.AuthTokenData
|
||||
* A provider for all the functionality needed to properly refresh the users access token.
|
||||
*/
|
||||
interface AuthTokenProvider {
|
||||
/**
|
||||
* The specified user's auth token data.
|
||||
*/
|
||||
fun getAuthTokenDataOrNull(userId: String): AuthTokenData?
|
||||
|
||||
/**
|
||||
* The currently active user's auth token data.
|
||||
|
||||
@ -0,0 +1,11 @@
|
||||
package com.bitwarden.network.provider
|
||||
|
||||
/**
|
||||
* A provider for authentication tokens.
|
||||
*/
|
||||
interface TokenProvider {
|
||||
/**
|
||||
* Retrieves an up-to-date token for the specified user.
|
||||
*/
|
||||
fun getAccessToken(userId: String): String?
|
||||
}
|
||||
@ -51,6 +51,50 @@ class AuthTokenManagerTest {
|
||||
unmockkStatic(::parseJwtTokenDataOrNull)
|
||||
}
|
||||
|
||||
@Nested
|
||||
inner class TokenProvider {
|
||||
@Test
|
||||
fun `returns null if token provider has no auth data for user ID`() {
|
||||
val userId = "userId"
|
||||
every { mockAuthTokenProvider.getAuthTokenDataOrNull(userId = userId) } returns null
|
||||
val result = authTokenManager.getAccessToken(userId = userId)
|
||||
assertNull(result)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `returns null if refresh fails`() {
|
||||
val userId = "userId"
|
||||
val authData = AuthTokenData(
|
||||
userId = userId,
|
||||
accessToken = ACCESS_TOKEN,
|
||||
expiresAtSec = FIXED_CLOCK.instant().epochSecond,
|
||||
)
|
||||
every { mockAuthTokenProvider.getAuthTokenDataOrNull(userId = userId) } returns authData
|
||||
every {
|
||||
refreshTokenProvider.refreshAccessTokenSynchronously(userId = userId)
|
||||
} returns Throwable("Fail!").asFailure()
|
||||
val result = authTokenManager.getAccessToken(userId = userId)
|
||||
assertNull(result)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `returns access token if refresh is not required`() {
|
||||
val userId = "userId"
|
||||
val authData = AuthTokenData(
|
||||
userId = userId,
|
||||
accessToken = ACCESS_TOKEN,
|
||||
expiresAtSec = 0L,
|
||||
)
|
||||
val refreshedAccessToken = "refreshed_access_token"
|
||||
every { mockAuthTokenProvider.getAuthTokenDataOrNull(userId = userId) } returns authData
|
||||
every {
|
||||
refreshTokenProvider.refreshAccessTokenSynchronously(userId = userId)
|
||||
} returns refreshedAccessToken.asSuccess()
|
||||
val result = authTokenManager.getAccessToken(userId = userId)
|
||||
assertEquals(refreshedAccessToken, result)
|
||||
}
|
||||
}
|
||||
|
||||
@Nested
|
||||
inner class Authenticator {
|
||||
@Test
|
||||
@ -158,7 +202,7 @@ class AuthTokenManagerTest {
|
||||
authTokenManager.refreshTokenProvider = object : RefreshTokenProvider {
|
||||
override fun refreshAccessTokenSynchronously(
|
||||
userId: String,
|
||||
): Result<String> = Throwable(errorMessage).asFailure()
|
||||
): Result<String> = IOException(errorMessage).asFailure()
|
||||
}
|
||||
val authTokenData = AuthTokenData(
|
||||
userId = USER_ID,
|
||||
@ -172,7 +216,7 @@ class AuthTokenManagerTest {
|
||||
chain = FakeInterceptorChain(request = request),
|
||||
)
|
||||
}
|
||||
assertEquals(errorMessage, throwable.cause?.message)
|
||||
assertEquals(errorMessage, throwable?.message)
|
||||
}
|
||||
|
||||
@Suppress("MaxLineLength")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user