PM-24440: Log user out for 'invalid_grant' (#5641)

This commit is contained in:
David Perez
2025-08-04 14:13:44 -05:00
committed by GitHub
parent e381d72d5c
commit 38b92133ff
10 changed files with 121 additions and 90 deletions

View File

@@ -404,10 +404,7 @@ class AuthRepositoryImpl(
.onEach {
val userId = activeUserId ?: return@onEach
// TODO: [PM-20593] Investigate why tokens are explicitly refreshed.
refreshAccessTokenSynchronouslyInternal(
userId = userId,
logOutOnFailure = false,
)
refreshAccessTokenSynchronously(userId = userId)
vaultRepository.sync(forced = true)
}
// This requires the ioScope to ensure that refreshAccessTokenSynchronously
@@ -760,11 +757,48 @@ class AuthRepositoryImpl(
orgIdentifier = organizationIdentifier,
)
override fun refreshAccessTokenSynchronously(userId: String): Result<RefreshTokenResponseJson> =
refreshAccessTokenSynchronouslyInternal(
userId = userId,
logOutOnFailure = true,
)
override fun refreshAccessTokenSynchronously(
userId: String,
): Result<String> {
val refreshToken = authDiskSource
.getAccountTokens(userId = userId)
?.refreshToken
?: return IllegalStateException("Must be logged in.").asFailure()
return identityService
.refreshTokenSynchronously(refreshToken)
.flatMap { refreshTokenResponse ->
// Check to make sure the user is still logged in after making the request
authDiskSource
.userState
?.accounts
?.get(userId)
?.let { refreshTokenResponse.asSuccess() }
?: IllegalStateException("Must be logged in.").asFailure()
}
.flatMap { refreshTokenResponse ->
when (refreshTokenResponse) {
is RefreshTokenResponseJson.Error -> {
if (refreshTokenResponse.isInvalidGrant) {
// We only logout for an invalid grant
logout(userId = userId, reason = LogoutReason.InvalidGrant)
}
IllegalStateException(refreshTokenResponse.error).asFailure()
}
is RefreshTokenResponseJson.Success -> {
// Store the new token information
authDiskSource.storeAccountTokens(
userId = userId,
accountTokens = AccountTokensJson(
accessToken = refreshTokenResponse.accessToken,
refreshToken = refreshTokenResponse.refreshToken,
),
)
refreshTokenResponse.accessToken.asSuccess()
}
}
}
}
override fun logout(reason: LogoutReason) {
activeUserId?.let { userId -> logout(userId = userId, reason = reason) }
@@ -1422,42 +1456,6 @@ class AuthRepositoryImpl(
onFailure = { LeaveOrganizationResult.Error(error = it) },
)
private fun refreshAccessTokenSynchronouslyInternal(
userId: String,
logOutOnFailure: Boolean,
): Result<RefreshTokenResponseJson> {
val refreshToken = authDiskSource
.getAccountTokens(userId = userId)
?.refreshToken
?: return IllegalStateException("Must be logged in.").asFailure()
return identityService
.refreshTokenSynchronously(refreshToken)
.flatMap { refreshTokenResponse ->
// Check to make sure the user is still logged in after making the request
authDiskSource
.userState
?.accounts
?.get(userId)
?.let { refreshTokenResponse.asSuccess() }
?: IllegalStateException("Must be logged in.").asFailure()
}
.onFailure {
if (logOutOnFailure) {
logout(userId = userId, reason = LogoutReason.TokenRefreshFail)
}
}
.onSuccess { refreshTokenResponse ->
// Update the existing UserState with updated token information
authDiskSource.storeAccountTokens(
userId = userId,
accountTokens = AccountTokensJson(
accessToken = refreshTokenResponse.accessToken,
refreshToken = refreshTokenResponse.refreshToken,
),
)
}
}
@Suppress("CyclomaticComplexMethod")
private suspend fun validatePasswordAgainstPolicy(
password: String,

View File

@@ -29,6 +29,12 @@ sealed class LogoutReason {
data object NoLongerSupported : Biometrics()
}
/**
* Indicates that the logout is happening because the there was an "invalid_grant" response
* from the network.
*/
data object InvalidGrant : LogoutReason()
/**
* Indicates that the logout is happening because of an invalid state.
*/
@@ -58,11 +64,6 @@ sealed class LogoutReason {
*/
data object Timeout : LogoutReason()
/**
* Indicates that the logout is happening because the access token could not be refreshed.
*/
data object TokenRefreshFail : LogoutReason()
/**
* Indicates that the logout is happening because the user tried to unlock the vault
* unsuccessfully too many times.

View File

@@ -911,7 +911,7 @@ class AuthRepositoryTest {
val result = repository.refreshAccessTokenSynchronously(USER_ID_1)
assertEquals(REFRESH_TOKEN_RESPONSE_JSON.asSuccess(), result)
assertEquals(REFRESH_TOKEN_RESPONSE_JSON.accessToken.asSuccess(), result)
fakeAuthDiskSource.assertAccountTokens(
userId = USER_ID_1,
accountTokens = updatedAccountTokens,
@@ -6959,7 +6959,7 @@ class AuthRepositoryTest {
accessCode = "accessCode",
fingerprint = "fingerprint",
)
private val REFRESH_TOKEN_RESPONSE_JSON = RefreshTokenResponseJson(
private val REFRESH_TOKEN_RESPONSE_JSON = RefreshTokenResponseJson.Success(
accessToken = ACCESS_TOKEN_2,
expiresIn = 3600,
refreshToken = REFRESH_TOKEN_2,

View File

@@ -65,7 +65,7 @@ internal interface UnauthenticatedIdentityApi {
@Field(value = "client_id") clientId: String,
@Field(value = "refresh_token") refreshToken: String,
@Field(value = "grant_type") grantType: String,
): Call<RefreshTokenResponseJson>
): Call<RefreshTokenResponseJson.Success>
@POST("/accounts/prelogin")
suspend fun preLogin(@Body body: PreLoginRequestJson): NetworkResult<PreLoginResponseJson>

View File

@@ -39,11 +39,12 @@ internal class RefreshAuthenticator : Authenticator {
?.fold(
onFailure = { null },
onSuccess = { newAccessToken ->
response.request
response
.request
.newBuilder()
.header(
name = HEADER_KEY_AUTHORIZATION,
value = HEADER_VALUE_BEARER_PREFIX + newAccessToken.accessToken,
value = "$HEADER_VALUE_BEARER_PREFIX$newAccessToken",
)
.build()
},

View File

@@ -4,24 +4,40 @@ import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
/**
* Models the response body from the refresh token request.
*
* @property accessToken The new access token.
* @property expiresIn When the new [accessToken] expires.
* @property refreshToken The new refresh token.
* @property tokenType The type of token the new [accessToken] is.
* Represents the JSON response from refreshing the access token.
*/
@Serializable
data class RefreshTokenResponseJson(
@SerialName("access_token")
val accessToken: String,
sealed class RefreshTokenResponseJson {
/**
* Models a successful response body from the refresh token request.
*
* @property accessToken The new access token.
* @property expiresIn When the new [accessToken] expires.
* @property refreshToken The new refresh token.
* @property tokenType The type of token the new [accessToken] is.
*/
@Serializable
data class Success(
@SerialName("access_token")
val accessToken: String,
@SerialName("expires_in")
val expiresIn: Int,
@SerialName("expires_in")
val expiresIn: Int,
@SerialName("refresh_token")
val refreshToken: String,
@SerialName("refresh_token")
val refreshToken: String,
@SerialName("token_type")
val tokenType: String,
)
@SerialName("token_type")
val tokenType: String,
) : RefreshTokenResponseJson()
/**
* Models a failure response body from the refresh token request.
*/
@Serializable
data class Error(
@SerialName("error")
val error: String,
) : RefreshTokenResponseJson() {
val isInvalidGrant: Boolean get() = error == "invalid_grant"
}
}

View File

@@ -1,7 +1,5 @@
package com.bitwarden.network.provider
import com.bitwarden.network.model.RefreshTokenResponseJson
/**
* A provider for all the functionality needed to refresh a user's access token.
*/
@@ -12,5 +10,5 @@ interface RefreshTokenProvider {
* This call is both synchronous and performs a network request. Make sure that you are calling
* from an appropriate thread.
*/
fun refreshAccessTokenSynchronously(userId: String): Result<RefreshTokenResponseJson>
fun refreshAccessTokenSynchronously(userId: String): Result<String>
}

View File

@@ -130,6 +130,15 @@ internal class IdentityServiceImpl(
)
.executeForNetworkResult()
.toResult()
.recoverCatching { throwable ->
throwable
.toBitwardenError()
.parseErrorBodyOrNull<RefreshTokenResponseJson.Error>(
code = NetworkErrorCode.BAD_REQUEST,
json = json,
)
?: throw throwable
}
override suspend fun registerFinish(
body: RegisterFinishRequestJson,

View File

@@ -3,7 +3,6 @@ package com.bitwarden.network.authenticator
import com.bitwarden.core.data.util.asFailure
import com.bitwarden.core.data.util.asSuccess
import com.bitwarden.network.model.JwtTokenDataJson
import com.bitwarden.network.model.RefreshTokenResponseJson
import com.bitwarden.network.provider.RefreshTokenProvider
import com.bitwarden.network.util.parseJwtTokenDataOrNull
import io.mockk.every
@@ -72,20 +71,13 @@ class RefreshAuthenticatorTests {
}
}
@Suppress("MaxLineLength")
@Test
fun `RefreshAuthenticator returns updated request when refresh is success`() {
val newAccessToken = "newAccessToken"
val refreshResponse = RefreshTokenResponseJson(
accessToken = newAccessToken,
expiresIn = 3600,
refreshToken = "refreshToken",
tokenType = "Bearer",
)
every { parseJwtTokenDataOrNull(JWT_ACCESS_TOKEN) } returns JTW_TOKEN
every {
refreshTokenProvider.refreshAccessTokenSynchronously(USER_ID)
} returns refreshResponse.asSuccess()
} returns newAccessToken.asSuccess()
val authenticatedRequest = authenticator.authenticate(null, RESPONSE_401)

View File

@@ -312,12 +312,18 @@ class IdentityServiceTest : BaseServiceTest() {
assertEquals(PREVALIDATE_SSO_ERROR_BODY.asSuccess(), result)
}
@Suppress("MaxLineLength")
@Test
fun `refreshTokenSynchronously when response is success should return RefreshTokenResponseJson`() {
server.enqueue(MockResponse().setResponseCode(200).setBody(REFRESH_TOKEN_JSON))
fun `refreshTokenSynchronously when response is success should return Success`() {
server.enqueue(MockResponse().setResponseCode(200).setBody(REFRESH_TOKEN_SUCCESS_JSON))
val result = identityService.refreshTokenSynchronously(refreshToken = REFRESH_TOKEN)
assertEquals(REFRESH_TOKEN_BODY.asSuccess(), result)
assertEquals(REFRESH_TOKEN_SUCCESS_BODY.asSuccess(), result)
}
@Test
fun `refreshTokenSynchronously when response is error should return Error`() {
server.enqueue(MockResponse().setResponseCode(400).setBody(REFRESH_TOKEN_ERROR_JSON))
val result = identityService.refreshTokenSynchronously(refreshToken = REFRESH_TOKEN)
assertEquals(REFRESH_TOKEN_ERROR_BODY.asSuccess(), result)
}
@Test
@@ -520,7 +526,17 @@ private val PREVALIDATE_SSO_ERROR_BODY = PrevalidateSsoResponseJson.Error(
message = "Organization not found from identifier.",
)
private const val REFRESH_TOKEN_JSON = """
private const val REFRESH_TOKEN_ERROR_JSON = """
{
"error": "invalid_grant"
}
"""
private val REFRESH_TOKEN_ERROR_BODY = RefreshTokenResponseJson.Error(
error = "invalid_grant",
)
private const val REFRESH_TOKEN_SUCCESS_JSON = """
{
"access_token": "accessToken",
"expires_in": 3600,
@@ -529,7 +545,7 @@ private const val REFRESH_TOKEN_JSON = """
}
"""
private val REFRESH_TOKEN_BODY = RefreshTokenResponseJson(
private val REFRESH_TOKEN_SUCCESS_BODY = RefreshTokenResponseJson.Success(
accessToken = "accessToken",
expiresIn = 3600,
refreshToken = "refreshToken",