[PM-24481] Update AuthTokenInterceptor to refresh token on expiration (#5647)

This commit is contained in:
David Perez
2025-08-06 13:05:07 -05:00
committed by GitHub
parent 60ee129e0b
commit 72250dce90
12 changed files with 205 additions and 22 deletions

View File

@@ -8,6 +8,7 @@ import kotlinx.serialization.Serializable
*
* @property accessToken The user's primary access token.
* @property refreshToken The user's refresh token.
* @property expiresAtSec The time at which the token expires in epoch seconds.
*/
@Serializable
data class AccountTokensJson(
@@ -16,6 +17,9 @@ data class AccountTokensJson(
@SerialName("refreshToken")
val refreshToken: String?,
@SerialName("expiresAtSec")
val expiresAtSec: Long = Long.MAX_VALUE,
) {
/**
* Returns `true` if the user is logged in, `false otherwise.

View File

@@ -1,5 +1,6 @@
package com.x8bit.bitwarden.data.auth.manager
import com.bitwarden.network.model.AuthTokenData
import com.x8bit.bitwarden.data.auth.datasource.disk.AuthDiskSource
/**
@@ -9,9 +10,19 @@ class AuthTokenManagerImpl(
private val authDiskSource: AuthDiskSource,
) : AuthTokenManager {
override fun getActiveAccessTokenOrNull(): String? = authDiskSource
override fun getAuthTokenDataOrNull(): AuthTokenData? = authDiskSource
.userState
?.activeUserId
?.let { authDiskSource.getAccountTokens(it) }
?.accessToken
?.let { userId ->
authDiskSource
.getAccountTokens(userId = userId)
?.takeIf { it.accessToken != null }
?.let {
AuthTokenData(
userId = userId,
accessToken = requireNotNull(it.accessToken),
expiresAtSec = it.expiresAtSec,
)
}
}
}

View File

@@ -146,6 +146,7 @@ import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.flow.receiveAsFlow
import kotlinx.coroutines.flow.stateIn
import kotlinx.coroutines.flow.update
import java.time.Clock
import javax.inject.Singleton
/**
@@ -154,6 +155,7 @@ import javax.inject.Singleton
@Suppress("LargeClass", "LongParameterList", "TooManyFunctions")
@Singleton
class AuthRepositoryImpl(
private val clock: Clock,
private val accountsService: AccountsService,
private val devicesService: DevicesService,
private val haveIBeenPwnedService: HaveIBeenPwnedService,
@@ -792,6 +794,8 @@ class AuthRepositoryImpl(
accountTokens = AccountTokensJson(
accessToken = refreshTokenResponse.accessToken,
refreshToken = refreshTokenResponse.refreshToken,
expiresAtSec = clock.instant().epochSecond +
refreshTokenResponse.expiresIn,
),
)
refreshTokenResponse.accessToken.asSuccess()
@@ -1778,6 +1782,7 @@ class AuthRepositoryImpl(
accountTokens = AccountTokensJson(
accessToken = loginResponse.accessToken,
refreshToken = loginResponse.refreshToken,
expiresAtSec = clock.instant().epochSecond + loginResponse.expiresInSeconds,
),
)
settingsRepository.hasUserLoggedInOrCreatedAccount = true

View File

@@ -27,6 +27,7 @@ import dagger.Module
import dagger.Provides
import dagger.hilt.InstallIn
import dagger.hilt.components.SingletonComponent
import java.time.Clock
import javax.inject.Singleton
/**
@@ -39,6 +40,7 @@ object AuthRepositoryModule {
@Provides
@Singleton
fun providesAuthRepository(
clock: Clock,
accountsService: AccountsService,
devicesService: DevicesService,
identityService: IdentityService,
@@ -61,6 +63,7 @@ object AuthRepositoryModule {
firstTimeActionManager: FirstTimeActionManager,
logsManager: LogsManager,
): AuthRepository = AuthRepositoryImpl(
clock = clock,
accountsService = accountsService,
devicesService = devicesService,
identityService = identityService,

View File

@@ -1,5 +1,6 @@
package com.x8bit.bitwarden.data.auth.manager
import com.bitwarden.network.model.AuthTokenData
import com.bitwarden.network.model.KdfTypeJson
import com.x8bit.bitwarden.data.auth.datasource.disk.model.AccountJson
import com.x8bit.bitwarden.data.auth.datasource.disk.model.AccountTokensJson
@@ -18,7 +19,7 @@ class AuthTokenManagerTest {
@Test
fun `UserState is null`() {
fakeAuthDiskSource.userState = null
assertNull(authTokenManager.getActiveAccessTokenOrNull())
assertNull(authTokenManager.getAuthTokenDataOrNull())
}
@Test
@@ -29,7 +30,7 @@ class AuthTokenManagerTest {
USER_ID to ACCOUNT.copy(tokens = null),
),
)
assertNull(authTokenManager.getActiveAccessTokenOrNull())
assertNull(authTokenManager.getAuthTokenDataOrNull())
}
@Test
@@ -45,7 +46,21 @@ class AuthTokenManagerTest {
),
),
)
assertNull(authTokenManager.getActiveAccessTokenOrNull())
assertNull(authTokenManager.getAuthTokenDataOrNull())
}
@Test
fun `getActiveAccessTokenOrNull should return null if user access token is null`() {
fakeAuthDiskSource.userState = SINGLE_USER_STATE
fakeAuthDiskSource.storeAccountTokens(
userId = USER_ID,
accountTokens = AccountTokensJson(
accessToken = null,
refreshToken = REFRESH_TOKEN,
expiresAtSec = EXPIRES_AT_SEC,
),
)
assertNull(authTokenManager.getAuthTokenDataOrNull())
}
@Test
@@ -56,11 +71,16 @@ class AuthTokenManagerTest {
accountTokens = AccountTokensJson(
accessToken = ACCESS_TOKEN,
refreshToken = REFRESH_TOKEN,
expiresAtSec = EXPIRES_AT_SEC,
),
)
assertEquals(
ACCESS_TOKEN,
authTokenManager.getActiveAccessTokenOrNull(),
AuthTokenData(
userId = USER_ID,
accessToken = ACCESS_TOKEN,
expiresAtSec = EXPIRES_AT_SEC,
),
authTokenManager.getAuthTokenDataOrNull(),
)
}
}
@@ -69,6 +89,7 @@ private const val EMAIL: String = "test@bitwarden.com"
private const val USER_ID: String = "2a135b23-e1fb-42c9-bec3-573857bc8181"
private const val ACCESS_TOKEN: String = "accessToken"
private const val REFRESH_TOKEN: String = "refreshToken"
private const val EXPIRES_AT_SEC: Long = 3600
private val ACCOUNT: AccountJson = AccountJson(
profile = AccountJson.Profile(
userId = USER_ID,
@@ -91,6 +112,7 @@ private val ACCOUNT: AccountJson = AccountJson(
tokens = AccountTokensJson(
accessToken = ACCESS_TOKEN,
refreshToken = REFRESH_TOKEN,
expiresAtSec = EXPIRES_AT_SEC,
),
settings = AccountJson.Settings(
environmentUrlData = null,

View File

@@ -154,6 +154,9 @@ import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertDoesNotThrow
import java.time.Clock
import java.time.Instant
import java.time.ZoneOffset
import java.time.ZonedDateTime
import javax.net.ssl.SSLHandshakeException
@@ -259,7 +262,8 @@ class AuthRepositoryTest {
every { setUserData(userId = any(), environmentType = any()) } just runs
}
private val repository = AuthRepositoryImpl(
private val repository: AuthRepository = AuthRepositoryImpl(
clock = FIXED_CLOCK,
accountsService = accountsService,
devicesService = devicesService,
identityService = identityService,
@@ -899,6 +903,7 @@ class AuthRepositoryTest {
val updatedAccountTokens = AccountTokensJson(
accessToken = ACCESS_TOKEN_2,
refreshToken = REFRESH_TOKEN_2,
expiresAtSec = FIXED_CLOCK.instant().epochSecond + ACCESS_TOKEN_2_EXPIRES_IN,
)
fakeAuthDiskSource.storeAccountTokens(
userId = USER_ID_1,
@@ -6908,6 +6913,10 @@ class AuthRepositoryTest {
}
companion object {
private val FIXED_CLOCK: Clock = Clock.fixed(
Instant.parse("2023-10-27T12:00:00Z"),
ZoneOffset.UTC,
)
private const val UNIQUE_APP_ID = "testUniqueAppId"
private const val NAME = "Example Name"
private const val EMAIL = "test@bitwarden.com"
@@ -6919,6 +6928,7 @@ class AuthRepositoryTest {
private const val ACCESS_TOKEN_2 = "accessToken2"
private const val REFRESH_TOKEN = "refreshToken"
private const val REFRESH_TOKEN_2 = "refreshToken2"
private const val ACCESS_TOKEN_2_EXPIRES_IN = 3600
private const val CAPTCHA_KEY = "captcha"
private const val TWO_FACTOR_CODE = "123456"
private val TWO_FACTOR_METHOD = TwoFactorAuthMethod.EMAIL
@@ -6961,7 +6971,7 @@ class AuthRepositoryTest {
)
private val REFRESH_TOKEN_RESPONSE_JSON = RefreshTokenResponseJson.Success(
accessToken = ACCESS_TOKEN_2,
expiresIn = 3600,
expiresIn = ACCESS_TOKEN_2_EXPIRES_IN,
refreshToken = REFRESH_TOKEN_2,
tokenType = "Bearer",
)

View File

@@ -9,6 +9,7 @@ import com.bitwarden.network.BitwardenServiceClient
import com.bitwarden.network.bitwardenServiceClient
import com.bitwarden.network.interceptor.AuthTokenProvider
import com.bitwarden.network.interceptor.BaseUrlsProvider
import com.bitwarden.network.model.AuthTokenData
import com.bitwarden.network.model.BitwardenServiceClientConfig
import com.bitwarden.network.service.ConfigService
import com.bitwarden.network.ssl.CertificateProvider
@@ -54,7 +55,7 @@ object PlatformNetworkModule {
baseUrlsProvider = baseUrlsProvider,
enableHttpBodyLogging = BuildConfig.DEBUG,
authTokenProvider = object : AuthTokenProvider {
override fun getActiveAccessTokenOrNull(): String? = null
override fun getAuthTokenDataOrNull(): AuthTokenData? = null
},
certificateProvider = object : CertificateProvider {
override fun chooseClientAlias(

View File

@@ -53,6 +53,10 @@ internal class BitwardenServiceClientImpl(
) : BitwardenServiceClient {
private val refreshAuthenticator: RefreshAuthenticator = RefreshAuthenticator()
private val authTokenInterceptor: AuthTokenInterceptor = AuthTokenInterceptor(
clock = bitwardenServiceClientConfig.clock,
authTokenProvider = bitwardenServiceClientConfig.authTokenProvider,
)
private val clientJson = Json {
// If there are keys returned by the server not modeled by a serializable class,
@@ -71,9 +75,7 @@ internal class BitwardenServiceClientImpl(
}
private val retrofits: Retrofits by lazy {
RetrofitsImpl(
authTokenInterceptor = AuthTokenInterceptor(
authTokenProvider = bitwardenServiceClientConfig.authTokenProvider,
),
authTokenInterceptor = authTokenInterceptor,
baseUrlInterceptors = BaseUrlInterceptors(
baseUrlsProvider = bitwardenServiceClientConfig.baseUrlsProvider,
),
@@ -205,5 +207,6 @@ internal class BitwardenServiceClientImpl(
override fun setRefreshTokenProvider(refreshTokenProvider: RefreshTokenProvider?) {
refreshAuthenticator.refreshTokenProvider = refreshTokenProvider
authTokenInterceptor.refreshTokenProvider = refreshTokenProvider
}
}

View File

@@ -1,22 +1,45 @@
package com.bitwarden.network.interceptor
import com.bitwarden.network.provider.RefreshTokenProvider
import com.bitwarden.network.util.HEADER_KEY_AUTHORIZATION
import com.bitwarden.network.util.HEADER_VALUE_BEARER_PREFIX
import okhttp3.Interceptor
import okhttp3.Response
import timber.log.Timber
import java.io.IOException
import java.time.Clock
import java.time.Instant
import java.time.temporal.ChronoUnit
private const val MISSING_TOKEN_MESSAGE: String = "Auth token is missing!"
private const val MISSING_PROVIDER_MESSAGE: String = "Refresh token provider is missing!"
private const val EXPIRATION_OFFSET_MINUTES: Long = 5L
/**
* Interceptor responsible for adding the auth token(Bearer) to API requests.
*/
internal class AuthTokenInterceptor(
private val clock: Clock,
private val authTokenProvider: AuthTokenProvider,
) : Interceptor {
private val missingTokenMessage = "Auth token is missing!"
var refreshTokenProvider: RefreshTokenProvider? = null
override fun intercept(chain: Interceptor.Chain): Response {
val token = authTokenProvider.getActiveAccessTokenOrNull()
?: throw IOException(IllegalStateException(missingTokenMessage))
val tokenData = authTokenProvider
.getAuthTokenDataOrNull()
?: throw IOException(IllegalStateException(MISSING_TOKEN_MESSAGE))
val expirationTime = Instant
.ofEpochSecond(tokenData.expiresAtSec)
.minus(EXPIRATION_OFFSET_MINUTES, ChronoUnit.MINUTES)
val token = if (clock.instant().isAfter(expirationTime)) {
Timber.d("Attempting to refresh token due to expiration")
refreshTokenProvider
?.refreshAccessTokenSynchronously(userId = tokenData.userId)
?.getOrElse { throw IOException(it) }
?: throw IOException(IllegalStateException(MISSING_PROVIDER_MESSAGE))
} else {
tokenData.accessToken
}
val request = chain
.request()
.newBuilder()

View File

@@ -1,12 +1,14 @@
package com.bitwarden.network.interceptor
import com.bitwarden.network.model.AuthTokenData
/**
* A provider for all the functionality needed to properly refresh the users access token.
*/
interface AuthTokenProvider {
/**
* The currently active user's access token.
* The currently active user's auth token data.
*/
fun getActiveAccessTokenOrNull(): String?
fun getAuthTokenDataOrNull(): AuthTokenData?
}

View File

@@ -0,0 +1,10 @@
package com.bitwarden.network.model
/**
* Contains the access token and expiration data for a user.
*/
data class AuthTokenData(
val userId: String,
val accessToken: String,
val expiresAtSec: Long,
)

View File

@@ -1,5 +1,9 @@
package com.bitwarden.network.interceptor
import com.bitwarden.core.data.util.asFailure
import com.bitwarden.core.data.util.asSuccess
import com.bitwarden.network.model.AuthTokenData
import com.bitwarden.network.provider.RefreshTokenProvider
import io.mockk.every
import io.mockk.mockk
import junit.framework.TestCase.assertEquals
@@ -7,12 +11,16 @@ import okhttp3.Request
import org.junit.Assert.assertThrows
import org.junit.Test
import java.io.IOException
import java.time.Clock
import java.time.Instant
import java.time.ZoneOffset
class AuthTokenInterceptorTest {
private val mockAuthTokenProvider = mockk<AuthTokenProvider> {
every { getActiveAccessTokenOrNull() } returns null
every { getAuthTokenDataOrNull() } returns null
}
private val interceptor: AuthTokenInterceptor = AuthTokenInterceptor(
clock = FIXED_CLOCK,
authTokenProvider = mockAuthTokenProvider,
)
private val request: Request = Request
@@ -22,7 +30,12 @@ class AuthTokenInterceptorTest {
@Test
fun `intercept should add the auth token when set`() {
every { mockAuthTokenProvider.getActiveAccessTokenOrNull() } returns ACCESS_TOKEN
val authTokenData = AuthTokenData(
userId = USER_ID,
accessToken = ACCESS_TOKEN,
expiresAtSec = FIXED_CLOCK.instant().epochSecond + 3600L,
)
every { mockAuthTokenProvider.getAuthTokenDataOrNull() } returns authTokenData
val response = interceptor.intercept(
chain = FakeInterceptorChain(request = request),
@@ -33,8 +46,78 @@ class AuthTokenInterceptorTest {
)
}
@Suppress("MaxLineLength")
@Test
fun `intercept should throw an exception when an auth token is missing`() {
fun `intercept should throw an exception when auth token is expired and refreshTokenProvider is missing`() {
val authTokenData = AuthTokenData(
userId = USER_ID,
accessToken = ACCESS_TOKEN,
expiresAtSec = FIXED_CLOCK.instant().epochSecond - 3600L,
)
every { mockAuthTokenProvider.getAuthTokenDataOrNull() } returns authTokenData
val throwable = assertThrows(IOException::class.java) {
interceptor.intercept(
chain = FakeInterceptorChain(request = request),
)
}
assertEquals(
"Refresh token provider is missing!",
throwable.cause?.message,
)
}
@Suppress("MaxLineLength")
@Test
fun `intercept should throw an exception when auth token is expired and refreshAccessTokenSynchronously returns an error`() {
val errorMessage = "Fail!"
interceptor.refreshTokenProvider = object : RefreshTokenProvider {
override fun refreshAccessTokenSynchronously(
userId: String,
): Result<String> = Throwable(errorMessage).asFailure()
}
val authTokenData = AuthTokenData(
userId = USER_ID,
accessToken = ACCESS_TOKEN,
expiresAtSec = FIXED_CLOCK.instant().epochSecond - 3600L,
)
every { mockAuthTokenProvider.getAuthTokenDataOrNull() } returns authTokenData
val throwable = assertThrows(IOException::class.java) {
interceptor.intercept(
chain = FakeInterceptorChain(request = request),
)
}
assertEquals(errorMessage, throwable.cause?.message)
}
@Suppress("MaxLineLength")
@Test
fun `intercept should add the auth token when auth token is expired and refreshAccessTokenSynchronously returns new token`() {
val token = "token"
interceptor.refreshTokenProvider = object : RefreshTokenProvider {
override fun refreshAccessTokenSynchronously(
userId: String,
): Result<String> = token.asSuccess()
}
val authTokenData = AuthTokenData(
userId = USER_ID,
accessToken = ACCESS_TOKEN,
expiresAtSec = FIXED_CLOCK.instant().epochSecond - 3600L,
)
every { mockAuthTokenProvider.getAuthTokenDataOrNull() } returns authTokenData
val response = interceptor.intercept(
chain = FakeInterceptorChain(request = request),
)
assertEquals(
"Bearer $token",
response.request.header("Authorization"),
)
}
@Test
fun `intercept should throw an exception when an auth token data is missing`() {
val throwable = assertThrows(IOException::class.java) {
interceptor.intercept(
chain = FakeInterceptorChain(request = request),
@@ -47,4 +130,10 @@ class AuthTokenInterceptorTest {
}
}
private val FIXED_CLOCK: Clock = Clock.fixed(
Instant.parse("2023-10-27T12:00:00Z"),
ZoneOffset.UTC,
)
private const val USER_ID: String = "user_id"
private const val ACCESS_TOKEN: String = "access_token"