mirror of
https://github.com/bitwarden/android.git
synced 2026-03-09 03:33:36 -05:00
PM-24440: Log user out for 'invalid_grant' (#5641)
This commit is contained in:
@@ -404,10 +404,7 @@ class AuthRepositoryImpl(
|
||||
.onEach {
|
||||
val userId = activeUserId ?: return@onEach
|
||||
// TODO: [PM-20593] Investigate why tokens are explicitly refreshed.
|
||||
refreshAccessTokenSynchronouslyInternal(
|
||||
userId = userId,
|
||||
logOutOnFailure = false,
|
||||
)
|
||||
refreshAccessTokenSynchronously(userId = userId)
|
||||
vaultRepository.sync(forced = true)
|
||||
}
|
||||
// This requires the ioScope to ensure that refreshAccessTokenSynchronously
|
||||
@@ -760,11 +757,48 @@ class AuthRepositoryImpl(
|
||||
orgIdentifier = organizationIdentifier,
|
||||
)
|
||||
|
||||
override fun refreshAccessTokenSynchronously(userId: String): Result<RefreshTokenResponseJson> =
|
||||
refreshAccessTokenSynchronouslyInternal(
|
||||
userId = userId,
|
||||
logOutOnFailure = true,
|
||||
)
|
||||
override fun refreshAccessTokenSynchronously(
|
||||
userId: String,
|
||||
): Result<String> {
|
||||
val refreshToken = authDiskSource
|
||||
.getAccountTokens(userId = userId)
|
||||
?.refreshToken
|
||||
?: return IllegalStateException("Must be logged in.").asFailure()
|
||||
return identityService
|
||||
.refreshTokenSynchronously(refreshToken)
|
||||
.flatMap { refreshTokenResponse ->
|
||||
// Check to make sure the user is still logged in after making the request
|
||||
authDiskSource
|
||||
.userState
|
||||
?.accounts
|
||||
?.get(userId)
|
||||
?.let { refreshTokenResponse.asSuccess() }
|
||||
?: IllegalStateException("Must be logged in.").asFailure()
|
||||
}
|
||||
.flatMap { refreshTokenResponse ->
|
||||
when (refreshTokenResponse) {
|
||||
is RefreshTokenResponseJson.Error -> {
|
||||
if (refreshTokenResponse.isInvalidGrant) {
|
||||
// We only logout for an invalid grant
|
||||
logout(userId = userId, reason = LogoutReason.InvalidGrant)
|
||||
}
|
||||
IllegalStateException(refreshTokenResponse.error).asFailure()
|
||||
}
|
||||
|
||||
is RefreshTokenResponseJson.Success -> {
|
||||
// Store the new token information
|
||||
authDiskSource.storeAccountTokens(
|
||||
userId = userId,
|
||||
accountTokens = AccountTokensJson(
|
||||
accessToken = refreshTokenResponse.accessToken,
|
||||
refreshToken = refreshTokenResponse.refreshToken,
|
||||
),
|
||||
)
|
||||
refreshTokenResponse.accessToken.asSuccess()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun logout(reason: LogoutReason) {
|
||||
activeUserId?.let { userId -> logout(userId = userId, reason = reason) }
|
||||
@@ -1422,42 +1456,6 @@ class AuthRepositoryImpl(
|
||||
onFailure = { LeaveOrganizationResult.Error(error = it) },
|
||||
)
|
||||
|
||||
private fun refreshAccessTokenSynchronouslyInternal(
|
||||
userId: String,
|
||||
logOutOnFailure: Boolean,
|
||||
): Result<RefreshTokenResponseJson> {
|
||||
val refreshToken = authDiskSource
|
||||
.getAccountTokens(userId = userId)
|
||||
?.refreshToken
|
||||
?: return IllegalStateException("Must be logged in.").asFailure()
|
||||
return identityService
|
||||
.refreshTokenSynchronously(refreshToken)
|
||||
.flatMap { refreshTokenResponse ->
|
||||
// Check to make sure the user is still logged in after making the request
|
||||
authDiskSource
|
||||
.userState
|
||||
?.accounts
|
||||
?.get(userId)
|
||||
?.let { refreshTokenResponse.asSuccess() }
|
||||
?: IllegalStateException("Must be logged in.").asFailure()
|
||||
}
|
||||
.onFailure {
|
||||
if (logOutOnFailure) {
|
||||
logout(userId = userId, reason = LogoutReason.TokenRefreshFail)
|
||||
}
|
||||
}
|
||||
.onSuccess { refreshTokenResponse ->
|
||||
// Update the existing UserState with updated token information
|
||||
authDiskSource.storeAccountTokens(
|
||||
userId = userId,
|
||||
accountTokens = AccountTokensJson(
|
||||
accessToken = refreshTokenResponse.accessToken,
|
||||
refreshToken = refreshTokenResponse.refreshToken,
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@Suppress("CyclomaticComplexMethod")
|
||||
private suspend fun validatePasswordAgainstPolicy(
|
||||
password: String,
|
||||
|
||||
@@ -29,6 +29,12 @@ sealed class LogoutReason {
|
||||
data object NoLongerSupported : Biometrics()
|
||||
}
|
||||
|
||||
/**
|
||||
* Indicates that the logout is happening because the there was an "invalid_grant" response
|
||||
* from the network.
|
||||
*/
|
||||
data object InvalidGrant : LogoutReason()
|
||||
|
||||
/**
|
||||
* Indicates that the logout is happening because of an invalid state.
|
||||
*/
|
||||
@@ -58,11 +64,6 @@ sealed class LogoutReason {
|
||||
*/
|
||||
data object Timeout : LogoutReason()
|
||||
|
||||
/**
|
||||
* Indicates that the logout is happening because the access token could not be refreshed.
|
||||
*/
|
||||
data object TokenRefreshFail : LogoutReason()
|
||||
|
||||
/**
|
||||
* Indicates that the logout is happening because the user tried to unlock the vault
|
||||
* unsuccessfully too many times.
|
||||
|
||||
@@ -911,7 +911,7 @@ class AuthRepositoryTest {
|
||||
|
||||
val result = repository.refreshAccessTokenSynchronously(USER_ID_1)
|
||||
|
||||
assertEquals(REFRESH_TOKEN_RESPONSE_JSON.asSuccess(), result)
|
||||
assertEquals(REFRESH_TOKEN_RESPONSE_JSON.accessToken.asSuccess(), result)
|
||||
fakeAuthDiskSource.assertAccountTokens(
|
||||
userId = USER_ID_1,
|
||||
accountTokens = updatedAccountTokens,
|
||||
@@ -6959,7 +6959,7 @@ class AuthRepositoryTest {
|
||||
accessCode = "accessCode",
|
||||
fingerprint = "fingerprint",
|
||||
)
|
||||
private val REFRESH_TOKEN_RESPONSE_JSON = RefreshTokenResponseJson(
|
||||
private val REFRESH_TOKEN_RESPONSE_JSON = RefreshTokenResponseJson.Success(
|
||||
accessToken = ACCESS_TOKEN_2,
|
||||
expiresIn = 3600,
|
||||
refreshToken = REFRESH_TOKEN_2,
|
||||
|
||||
@@ -65,7 +65,7 @@ internal interface UnauthenticatedIdentityApi {
|
||||
@Field(value = "client_id") clientId: String,
|
||||
@Field(value = "refresh_token") refreshToken: String,
|
||||
@Field(value = "grant_type") grantType: String,
|
||||
): Call<RefreshTokenResponseJson>
|
||||
): Call<RefreshTokenResponseJson.Success>
|
||||
|
||||
@POST("/accounts/prelogin")
|
||||
suspend fun preLogin(@Body body: PreLoginRequestJson): NetworkResult<PreLoginResponseJson>
|
||||
|
||||
@@ -39,11 +39,12 @@ internal class RefreshAuthenticator : Authenticator {
|
||||
?.fold(
|
||||
onFailure = { null },
|
||||
onSuccess = { newAccessToken ->
|
||||
response.request
|
||||
response
|
||||
.request
|
||||
.newBuilder()
|
||||
.header(
|
||||
name = HEADER_KEY_AUTHORIZATION,
|
||||
value = HEADER_VALUE_BEARER_PREFIX + newAccessToken.accessToken,
|
||||
value = "$HEADER_VALUE_BEARER_PREFIX$newAccessToken",
|
||||
)
|
||||
.build()
|
||||
},
|
||||
|
||||
@@ -4,24 +4,40 @@ import kotlinx.serialization.SerialName
|
||||
import kotlinx.serialization.Serializable
|
||||
|
||||
/**
|
||||
* Models the response body from the refresh token request.
|
||||
*
|
||||
* @property accessToken The new access token.
|
||||
* @property expiresIn When the new [accessToken] expires.
|
||||
* @property refreshToken The new refresh token.
|
||||
* @property tokenType The type of token the new [accessToken] is.
|
||||
* Represents the JSON response from refreshing the access token.
|
||||
*/
|
||||
@Serializable
|
||||
data class RefreshTokenResponseJson(
|
||||
@SerialName("access_token")
|
||||
val accessToken: String,
|
||||
sealed class RefreshTokenResponseJson {
|
||||
/**
|
||||
* Models a successful response body from the refresh token request.
|
||||
*
|
||||
* @property accessToken The new access token.
|
||||
* @property expiresIn When the new [accessToken] expires.
|
||||
* @property refreshToken The new refresh token.
|
||||
* @property tokenType The type of token the new [accessToken] is.
|
||||
*/
|
||||
@Serializable
|
||||
data class Success(
|
||||
@SerialName("access_token")
|
||||
val accessToken: String,
|
||||
|
||||
@SerialName("expires_in")
|
||||
val expiresIn: Int,
|
||||
@SerialName("expires_in")
|
||||
val expiresIn: Int,
|
||||
|
||||
@SerialName("refresh_token")
|
||||
val refreshToken: String,
|
||||
@SerialName("refresh_token")
|
||||
val refreshToken: String,
|
||||
|
||||
@SerialName("token_type")
|
||||
val tokenType: String,
|
||||
)
|
||||
@SerialName("token_type")
|
||||
val tokenType: String,
|
||||
) : RefreshTokenResponseJson()
|
||||
|
||||
/**
|
||||
* Models a failure response body from the refresh token request.
|
||||
*/
|
||||
@Serializable
|
||||
data class Error(
|
||||
@SerialName("error")
|
||||
val error: String,
|
||||
) : RefreshTokenResponseJson() {
|
||||
val isInvalidGrant: Boolean get() = error == "invalid_grant"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
package com.bitwarden.network.provider
|
||||
|
||||
import com.bitwarden.network.model.RefreshTokenResponseJson
|
||||
|
||||
/**
|
||||
* A provider for all the functionality needed to refresh a user's access token.
|
||||
*/
|
||||
@@ -12,5 +10,5 @@ interface RefreshTokenProvider {
|
||||
* This call is both synchronous and performs a network request. Make sure that you are calling
|
||||
* from an appropriate thread.
|
||||
*/
|
||||
fun refreshAccessTokenSynchronously(userId: String): Result<RefreshTokenResponseJson>
|
||||
fun refreshAccessTokenSynchronously(userId: String): Result<String>
|
||||
}
|
||||
|
||||
@@ -130,6 +130,15 @@ internal class IdentityServiceImpl(
|
||||
)
|
||||
.executeForNetworkResult()
|
||||
.toResult()
|
||||
.recoverCatching { throwable ->
|
||||
throwable
|
||||
.toBitwardenError()
|
||||
.parseErrorBodyOrNull<RefreshTokenResponseJson.Error>(
|
||||
code = NetworkErrorCode.BAD_REQUEST,
|
||||
json = json,
|
||||
)
|
||||
?: throw throwable
|
||||
}
|
||||
|
||||
override suspend fun registerFinish(
|
||||
body: RegisterFinishRequestJson,
|
||||
|
||||
@@ -3,7 +3,6 @@ package com.bitwarden.network.authenticator
|
||||
import com.bitwarden.core.data.util.asFailure
|
||||
import com.bitwarden.core.data.util.asSuccess
|
||||
import com.bitwarden.network.model.JwtTokenDataJson
|
||||
import com.bitwarden.network.model.RefreshTokenResponseJson
|
||||
import com.bitwarden.network.provider.RefreshTokenProvider
|
||||
import com.bitwarden.network.util.parseJwtTokenDataOrNull
|
||||
import io.mockk.every
|
||||
@@ -72,20 +71,13 @@ class RefreshAuthenticatorTests {
|
||||
}
|
||||
}
|
||||
|
||||
@Suppress("MaxLineLength")
|
||||
@Test
|
||||
fun `RefreshAuthenticator returns updated request when refresh is success`() {
|
||||
val newAccessToken = "newAccessToken"
|
||||
val refreshResponse = RefreshTokenResponseJson(
|
||||
accessToken = newAccessToken,
|
||||
expiresIn = 3600,
|
||||
refreshToken = "refreshToken",
|
||||
tokenType = "Bearer",
|
||||
)
|
||||
every { parseJwtTokenDataOrNull(JWT_ACCESS_TOKEN) } returns JTW_TOKEN
|
||||
every {
|
||||
refreshTokenProvider.refreshAccessTokenSynchronously(USER_ID)
|
||||
} returns refreshResponse.asSuccess()
|
||||
} returns newAccessToken.asSuccess()
|
||||
|
||||
val authenticatedRequest = authenticator.authenticate(null, RESPONSE_401)
|
||||
|
||||
|
||||
@@ -312,12 +312,18 @@ class IdentityServiceTest : BaseServiceTest() {
|
||||
assertEquals(PREVALIDATE_SSO_ERROR_BODY.asSuccess(), result)
|
||||
}
|
||||
|
||||
@Suppress("MaxLineLength")
|
||||
@Test
|
||||
fun `refreshTokenSynchronously when response is success should return RefreshTokenResponseJson`() {
|
||||
server.enqueue(MockResponse().setResponseCode(200).setBody(REFRESH_TOKEN_JSON))
|
||||
fun `refreshTokenSynchronously when response is success should return Success`() {
|
||||
server.enqueue(MockResponse().setResponseCode(200).setBody(REFRESH_TOKEN_SUCCESS_JSON))
|
||||
val result = identityService.refreshTokenSynchronously(refreshToken = REFRESH_TOKEN)
|
||||
assertEquals(REFRESH_TOKEN_BODY.asSuccess(), result)
|
||||
assertEquals(REFRESH_TOKEN_SUCCESS_BODY.asSuccess(), result)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `refreshTokenSynchronously when response is error should return Error`() {
|
||||
server.enqueue(MockResponse().setResponseCode(400).setBody(REFRESH_TOKEN_ERROR_JSON))
|
||||
val result = identityService.refreshTokenSynchronously(refreshToken = REFRESH_TOKEN)
|
||||
assertEquals(REFRESH_TOKEN_ERROR_BODY.asSuccess(), result)
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -520,7 +526,17 @@ private val PREVALIDATE_SSO_ERROR_BODY = PrevalidateSsoResponseJson.Error(
|
||||
message = "Organization not found from identifier.",
|
||||
)
|
||||
|
||||
private const val REFRESH_TOKEN_JSON = """
|
||||
private const val REFRESH_TOKEN_ERROR_JSON = """
|
||||
{
|
||||
"error": "invalid_grant"
|
||||
}
|
||||
"""
|
||||
|
||||
private val REFRESH_TOKEN_ERROR_BODY = RefreshTokenResponseJson.Error(
|
||||
error = "invalid_grant",
|
||||
)
|
||||
|
||||
private const val REFRESH_TOKEN_SUCCESS_JSON = """
|
||||
{
|
||||
"access_token": "accessToken",
|
||||
"expires_in": 3600,
|
||||
@@ -529,7 +545,7 @@ private const val REFRESH_TOKEN_JSON = """
|
||||
}
|
||||
"""
|
||||
|
||||
private val REFRESH_TOKEN_BODY = RefreshTokenResponseJson(
|
||||
private val REFRESH_TOKEN_SUCCESS_BODY = RefreshTokenResponseJson.Success(
|
||||
accessToken = "accessToken",
|
||||
expiresIn = 3600,
|
||||
refreshToken = "refreshToken",
|
||||
|
||||
Reference in New Issue
Block a user