PM-24481: Logout when token refresh API returns 401 or 403 (#5651)

This commit is contained in:
David Perez 2025-08-06 15:38:01 -05:00 committed by GitHub
parent 59c2261e7c
commit 3c033d4aa2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 158 additions and 16 deletions

View File

@ -781,12 +781,21 @@ class AuthRepositoryImpl(
when (refreshTokenResponse) { when (refreshTokenResponse) {
is RefreshTokenResponseJson.Error -> { is RefreshTokenResponseJson.Error -> {
if (refreshTokenResponse.isInvalidGrant) { if (refreshTokenResponse.isInvalidGrant) {
// We only logout for an invalid grant
logout(userId = userId, reason = LogoutReason.InvalidGrant) logout(userId = userId, reason = LogoutReason.InvalidGrant)
} }
IllegalStateException(refreshTokenResponse.error).asFailure() IllegalStateException(refreshTokenResponse.error).asFailure()
} }
is RefreshTokenResponseJson.Forbidden -> {
logout(userId = userId, reason = LogoutReason.RefreshForbidden)
refreshTokenResponse.error.asFailure()
}
is RefreshTokenResponseJson.Unauthorized -> {
logout(userId = userId, reason = LogoutReason.RefreshUnauthorized)
refreshTokenResponse.error.asFailure()
}
is RefreshTokenResponseJson.Success -> { is RefreshTokenResponseJson.Success -> {
// Store the new token information // Store the new token information
authDiskSource.storeAccountTokens( authDiskSource.storeAccountTokens(

View File

@ -35,6 +35,18 @@ sealed class LogoutReason {
*/ */
data object InvalidGrant : LogoutReason() data object InvalidGrant : LogoutReason()
/**
* Indicates that the logout is happening because the there was a "Forbidden" response from
* token refresh API.
*/
data object RefreshForbidden : LogoutReason()
/**
* Indicates that the logout is happening because the there was a "Unauthorized" response from
* token refresh API.
*/
data object RefreshUnauthorized : LogoutReason()
/** /**
* Indicates that the logout is happening because of an invalid state. * Indicates that the logout is happening because of an invalid state.
*/ */

View File

@ -882,21 +882,88 @@ class AuthRepositoryTest {
} }
@Test @Test
fun `refreshAccessTokenSynchronously returns failure and logs out on failure`() = runTest { fun `refreshAccessTokenSynchronously returns failure if refreshTokenSynchronously fails`() =
fakeAuthDiskSource.storeAccountTokens( runTest {
userId = USER_ID_1, fakeAuthDiskSource.storeAccountTokens(
accountTokens = ACCOUNT_TOKENS_1, userId = USER_ID_1,
) accountTokens = ACCOUNT_TOKENS_1,
coEvery { )
identityService.refreshTokenSynchronously(REFRESH_TOKEN) coEvery {
} returns Throwable("Fail").asFailure() identityService.refreshTokenSynchronously(REFRESH_TOKEN)
} returns Throwable("Fail").asFailure()
assertTrue(repository.refreshAccessTokenSynchronously(USER_ID_1).isFailure) assertTrue(repository.refreshAccessTokenSynchronously(USER_ID_1).isFailure)
coVerify(exactly = 1) { coVerify(exactly = 1) {
identityService.refreshTokenSynchronously(REFRESH_TOKEN) identityService.refreshTokenSynchronously(REFRESH_TOKEN)
}
}
@Suppress("MaxLineLength")
@Test
fun `refreshAccessTokenSynchronously returns logs out and returns failure if refreshTokenSynchronously returns invalid_grant`() =
runTest {
fakeAuthDiskSource.userState = SINGLE_USER_STATE_1
fakeAuthDiskSource.storeAccountTokens(
userId = USER_ID_1,
accountTokens = ACCOUNT_TOKENS_1,
)
coEvery {
identityService.refreshTokenSynchronously(REFRESH_TOKEN)
} returns RefreshTokenResponseJson.Error(error = "invalid_grant").asSuccess()
assertTrue(repository.refreshAccessTokenSynchronously(USER_ID_1).isFailure)
coVerify(exactly = 1) {
identityService.refreshTokenSynchronously(REFRESH_TOKEN)
userLogoutManager.logout(userId = USER_ID_1, reason = LogoutReason.InvalidGrant)
}
}
@Suppress("MaxLineLength")
@Test
fun `refreshAccessTokenSynchronously returns logs out and returns failure if refreshTokenSynchronously returns Forbidden`() =
runTest {
fakeAuthDiskSource.userState = SINGLE_USER_STATE_1
fakeAuthDiskSource.storeAccountTokens(
userId = USER_ID_1,
accountTokens = ACCOUNT_TOKENS_1,
)
coEvery {
identityService.refreshTokenSynchronously(REFRESH_TOKEN)
} returns RefreshTokenResponseJson.Forbidden(error = Throwable("Fail!")).asSuccess()
assertTrue(repository.refreshAccessTokenSynchronously(USER_ID_1).isFailure)
coVerify(exactly = 1) {
identityService.refreshTokenSynchronously(REFRESH_TOKEN)
userLogoutManager.logout(userId = USER_ID_1, reason = LogoutReason.RefreshForbidden)
}
}
@Suppress("MaxLineLength")
@Test
fun `refreshAccessTokenSynchronously returns logs out and returns failure if refreshTokenSynchronously returns Unauthorized`() =
runTest {
fakeAuthDiskSource.userState = SINGLE_USER_STATE_1
fakeAuthDiskSource.storeAccountTokens(
userId = USER_ID_1,
accountTokens = ACCOUNT_TOKENS_1,
)
coEvery {
identityService.refreshTokenSynchronously(REFRESH_TOKEN)
} returns RefreshTokenResponseJson.Unauthorized(error = Throwable("Fail!")).asSuccess()
assertTrue(repository.refreshAccessTokenSynchronously(USER_ID_1).isFailure)
coVerify(exactly = 1) {
identityService.refreshTokenSynchronously(REFRESH_TOKEN)
userLogoutManager.logout(
userId = USER_ID_1,
reason = LogoutReason.RefreshUnauthorized,
)
}
} }
}
@Test @Test
fun `refreshAccessTokenSynchronously returns success and sets account tokens`() = runTest { fun `refreshAccessTokenSynchronously returns success and sets account tokens`() = runTest {

View File

@ -40,4 +40,18 @@ sealed class RefreshTokenResponseJson {
) : RefreshTokenResponseJson() { ) : RefreshTokenResponseJson() {
val isInvalidGrant: Boolean get() = error == "invalid_grant" val isInvalidGrant: Boolean get() = error == "invalid_grant"
} }
/**
* Models a failure response with a 403 "Forbidden" response code.
*/
data class Forbidden(
val error: Throwable,
) : RefreshTokenResponseJson()
/**
* Models a failure response with a 401 "Unauthorized" response code.
*/
data class Unauthorized(
val error: Throwable,
) : RefreshTokenResponseJson()
} }

View File

@ -20,6 +20,7 @@ import com.bitwarden.network.util.DeviceModelProvider
import com.bitwarden.network.util.NetworkErrorCode import com.bitwarden.network.util.NetworkErrorCode
import com.bitwarden.network.util.base64UrlEncode import com.bitwarden.network.util.base64UrlEncode
import com.bitwarden.network.util.executeForNetworkResult import com.bitwarden.network.util.executeForNetworkResult
import com.bitwarden.network.util.getNetworkErrorCodeOrNull
import com.bitwarden.network.util.parseErrorBodyOrNull import com.bitwarden.network.util.parseErrorBodyOrNull
import com.bitwarden.network.util.toResult import com.bitwarden.network.util.toResult
import kotlinx.serialization.json.Json import kotlinx.serialization.json.Json
@ -131,13 +132,28 @@ internal class IdentityServiceImpl(
.executeForNetworkResult() .executeForNetworkResult()
.toResult() .toResult()
.recoverCatching { throwable -> .recoverCatching { throwable ->
throwable val bitwardenError = throwable.toBitwardenError()
.toBitwardenError() bitwardenError
.parseErrorBodyOrNull<RefreshTokenResponseJson.Error>( .parseErrorBodyOrNull<RefreshTokenResponseJson.Error>(
code = NetworkErrorCode.BAD_REQUEST, code = NetworkErrorCode.BAD_REQUEST,
json = json, json = json,
) )
?: throw throwable ?: run {
when (bitwardenError.getNetworkErrorCodeOrNull()) {
NetworkErrorCode.UNAUTHORIZED -> {
RefreshTokenResponseJson.Unauthorized(throwable)
}
NetworkErrorCode.FORBIDDEN -> {
RefreshTokenResponseJson.Forbidden(throwable)
}
NetworkErrorCode.BAD_REQUEST,
NetworkErrorCode.TOO_MANY_REQUESTS,
null,
-> throw throwable
}
}
} }
override suspend fun registerFinish( override suspend fun registerFinish(

View File

@ -5,6 +5,14 @@ import com.bitwarden.network.model.BitwardenError
import kotlinx.serialization.json.Json import kotlinx.serialization.json.Json
import retrofit2.HttpException import retrofit2.HttpException
/**
* Returns the [NetworkErrorCode] for the given error if it is available.
*/
internal fun BitwardenError.getNetworkErrorCodeOrNull(): NetworkErrorCode? =
(this as? BitwardenError.Http)?.let { httpError ->
NetworkErrorCode.entries.firstOrNull { httpError.code == it.code }
}
/** /**
* Attempt to parse the error body to serializable type [T]. * Attempt to parse the error body to serializable type [T].
* *

View File

@ -7,5 +7,7 @@ internal enum class NetworkErrorCode(
val code: Int, val code: Int,
) { ) {
BAD_REQUEST(code = 400), BAD_REQUEST(code = 400),
UNAUTHORIZED(code = 401),
FORBIDDEN(code = 403),
TOO_MANY_REQUESTS(code = 429), TOO_MANY_REQUESTS(code = 429),
} }

View File

@ -333,6 +333,20 @@ class IdentityServiceTest : BaseServiceTest() {
assertTrue(result.isFailure) assertTrue(result.isFailure)
} }
@Test
fun `refreshTokenSynchronously when response is a 403 error should return an Forbidden`() {
server.enqueue(MockResponse().setResponseCode(403))
val result = identityService.refreshTokenSynchronously(refreshToken = REFRESH_TOKEN)
assertTrue(result.getOrThrow() is RefreshTokenResponseJson.Forbidden)
}
@Test
fun `refreshTokenSynchronously when response is a 401 error should return an Unauthorized`() {
server.enqueue(MockResponse().setResponseCode(401))
val result = identityService.refreshTokenSynchronously(refreshToken = REFRESH_TOKEN)
assertTrue(result.getOrThrow() is RefreshTokenResponseJson.Unauthorized)
}
@Test @Test
fun `registerFinish success json should be Success`() = runTest { fun `registerFinish success json should be Success`() = runTest {
val expectedResponse = RegisterResponseJson.Success( val expectedResponse = RegisterResponseJson.Success(