[PM-24481] Update AuthTokenInterceptor to refresh token on expiration (#5647)

This commit is contained in:
David Perez
2025-08-06 13:05:07 -05:00
committed by GitHub
parent 60ee129e0b
commit 72250dce90
12 changed files with 205 additions and 22 deletions

View File

@@ -53,6 +53,10 @@ internal class BitwardenServiceClientImpl(
) : BitwardenServiceClient {
private val refreshAuthenticator: RefreshAuthenticator = RefreshAuthenticator()
private val authTokenInterceptor: AuthTokenInterceptor = AuthTokenInterceptor(
clock = bitwardenServiceClientConfig.clock,
authTokenProvider = bitwardenServiceClientConfig.authTokenProvider,
)
private val clientJson = Json {
// If there are keys returned by the server not modeled by a serializable class,
@@ -71,9 +75,7 @@ internal class BitwardenServiceClientImpl(
}
private val retrofits: Retrofits by lazy {
RetrofitsImpl(
authTokenInterceptor = AuthTokenInterceptor(
authTokenProvider = bitwardenServiceClientConfig.authTokenProvider,
),
authTokenInterceptor = authTokenInterceptor,
baseUrlInterceptors = BaseUrlInterceptors(
baseUrlsProvider = bitwardenServiceClientConfig.baseUrlsProvider,
),
@@ -205,5 +207,6 @@ internal class BitwardenServiceClientImpl(
override fun setRefreshTokenProvider(refreshTokenProvider: RefreshTokenProvider?) {
refreshAuthenticator.refreshTokenProvider = refreshTokenProvider
authTokenInterceptor.refreshTokenProvider = refreshTokenProvider
}
}

View File

@@ -1,22 +1,45 @@
package com.bitwarden.network.interceptor
import com.bitwarden.network.provider.RefreshTokenProvider
import com.bitwarden.network.util.HEADER_KEY_AUTHORIZATION
import com.bitwarden.network.util.HEADER_VALUE_BEARER_PREFIX
import okhttp3.Interceptor
import okhttp3.Response
import timber.log.Timber
import java.io.IOException
import java.time.Clock
import java.time.Instant
import java.time.temporal.ChronoUnit
private const val MISSING_TOKEN_MESSAGE: String = "Auth token is missing!"
private const val MISSING_PROVIDER_MESSAGE: String = "Refresh token provider is missing!"
private const val EXPIRATION_OFFSET_MINUTES: Long = 5L
/**
* Interceptor responsible for adding the auth token(Bearer) to API requests.
*/
internal class AuthTokenInterceptor(
private val clock: Clock,
private val authTokenProvider: AuthTokenProvider,
) : Interceptor {
private val missingTokenMessage = "Auth token is missing!"
var refreshTokenProvider: RefreshTokenProvider? = null
override fun intercept(chain: Interceptor.Chain): Response {
val token = authTokenProvider.getActiveAccessTokenOrNull()
?: throw IOException(IllegalStateException(missingTokenMessage))
val tokenData = authTokenProvider
.getAuthTokenDataOrNull()
?: throw IOException(IllegalStateException(MISSING_TOKEN_MESSAGE))
val expirationTime = Instant
.ofEpochSecond(tokenData.expiresAtSec)
.minus(EXPIRATION_OFFSET_MINUTES, ChronoUnit.MINUTES)
val token = if (clock.instant().isAfter(expirationTime)) {
Timber.d("Attempting to refresh token due to expiration")
refreshTokenProvider
?.refreshAccessTokenSynchronously(userId = tokenData.userId)
?.getOrElse { throw IOException(it) }
?: throw IOException(IllegalStateException(MISSING_PROVIDER_MESSAGE))
} else {
tokenData.accessToken
}
val request = chain
.request()
.newBuilder()

View File

@@ -1,12 +1,14 @@
package com.bitwarden.network.interceptor
import com.bitwarden.network.model.AuthTokenData
/**
* A provider for all the functionality needed to properly refresh the users access token.
*/
interface AuthTokenProvider {
/**
* The currently active user's access token.
* The currently active user's auth token data.
*/
fun getActiveAccessTokenOrNull(): String?
fun getAuthTokenDataOrNull(): AuthTokenData?
}

View File

@@ -0,0 +1,10 @@
package com.bitwarden.network.model
/**
* Contains the access token and expiration data for a user.
*/
data class AuthTokenData(
val userId: String,
val accessToken: String,
val expiresAtSec: Long,
)

View File

@@ -1,5 +1,9 @@
package com.bitwarden.network.interceptor
import com.bitwarden.core.data.util.asFailure
import com.bitwarden.core.data.util.asSuccess
import com.bitwarden.network.model.AuthTokenData
import com.bitwarden.network.provider.RefreshTokenProvider
import io.mockk.every
import io.mockk.mockk
import junit.framework.TestCase.assertEquals
@@ -7,12 +11,16 @@ import okhttp3.Request
import org.junit.Assert.assertThrows
import org.junit.Test
import java.io.IOException
import java.time.Clock
import java.time.Instant
import java.time.ZoneOffset
class AuthTokenInterceptorTest {
private val mockAuthTokenProvider = mockk<AuthTokenProvider> {
every { getActiveAccessTokenOrNull() } returns null
every { getAuthTokenDataOrNull() } returns null
}
private val interceptor: AuthTokenInterceptor = AuthTokenInterceptor(
clock = FIXED_CLOCK,
authTokenProvider = mockAuthTokenProvider,
)
private val request: Request = Request
@@ -22,7 +30,12 @@ class AuthTokenInterceptorTest {
@Test
fun `intercept should add the auth token when set`() {
every { mockAuthTokenProvider.getActiveAccessTokenOrNull() } returns ACCESS_TOKEN
val authTokenData = AuthTokenData(
userId = USER_ID,
accessToken = ACCESS_TOKEN,
expiresAtSec = FIXED_CLOCK.instant().epochSecond + 3600L,
)
every { mockAuthTokenProvider.getAuthTokenDataOrNull() } returns authTokenData
val response = interceptor.intercept(
chain = FakeInterceptorChain(request = request),
@@ -33,8 +46,78 @@ class AuthTokenInterceptorTest {
)
}
@Suppress("MaxLineLength")
@Test
fun `intercept should throw an exception when an auth token is missing`() {
fun `intercept should throw an exception when auth token is expired and refreshTokenProvider is missing`() {
val authTokenData = AuthTokenData(
userId = USER_ID,
accessToken = ACCESS_TOKEN,
expiresAtSec = FIXED_CLOCK.instant().epochSecond - 3600L,
)
every { mockAuthTokenProvider.getAuthTokenDataOrNull() } returns authTokenData
val throwable = assertThrows(IOException::class.java) {
interceptor.intercept(
chain = FakeInterceptorChain(request = request),
)
}
assertEquals(
"Refresh token provider is missing!",
throwable.cause?.message,
)
}
@Suppress("MaxLineLength")
@Test
fun `intercept should throw an exception when auth token is expired and refreshAccessTokenSynchronously returns an error`() {
val errorMessage = "Fail!"
interceptor.refreshTokenProvider = object : RefreshTokenProvider {
override fun refreshAccessTokenSynchronously(
userId: String,
): Result<String> = Throwable(errorMessage).asFailure()
}
val authTokenData = AuthTokenData(
userId = USER_ID,
accessToken = ACCESS_TOKEN,
expiresAtSec = FIXED_CLOCK.instant().epochSecond - 3600L,
)
every { mockAuthTokenProvider.getAuthTokenDataOrNull() } returns authTokenData
val throwable = assertThrows(IOException::class.java) {
interceptor.intercept(
chain = FakeInterceptorChain(request = request),
)
}
assertEquals(errorMessage, throwable.cause?.message)
}
@Suppress("MaxLineLength")
@Test
fun `intercept should add the auth token when auth token is expired and refreshAccessTokenSynchronously returns new token`() {
val token = "token"
interceptor.refreshTokenProvider = object : RefreshTokenProvider {
override fun refreshAccessTokenSynchronously(
userId: String,
): Result<String> = token.asSuccess()
}
val authTokenData = AuthTokenData(
userId = USER_ID,
accessToken = ACCESS_TOKEN,
expiresAtSec = FIXED_CLOCK.instant().epochSecond - 3600L,
)
every { mockAuthTokenProvider.getAuthTokenDataOrNull() } returns authTokenData
val response = interceptor.intercept(
chain = FakeInterceptorChain(request = request),
)
assertEquals(
"Bearer $token",
response.request.header("Authorization"),
)
}
@Test
fun `intercept should throw an exception when an auth token data is missing`() {
val throwable = assertThrows(IOException::class.java) {
interceptor.intercept(
chain = FakeInterceptorChain(request = request),
@@ -47,4 +130,10 @@ class AuthTokenInterceptorTest {
}
}
private val FIXED_CLOCK: Clock = Clock.fixed(
Instant.parse("2023-10-27T12:00:00Z"),
ZoneOffset.UTC,
)
private const val USER_ID: String = "user_id"
private const val ACCESS_TOKEN: String = "access_token"