mirror of
https://github.com/bitwarden/android.git
synced 2026-04-28 03:48:14 -05:00
[PM-24481] Update AuthTokenInterceptor to refresh token on expiration (#5647)
This commit is contained in:
@@ -53,6 +53,10 @@ internal class BitwardenServiceClientImpl(
|
||||
) : BitwardenServiceClient {
|
||||
|
||||
private val refreshAuthenticator: RefreshAuthenticator = RefreshAuthenticator()
|
||||
private val authTokenInterceptor: AuthTokenInterceptor = AuthTokenInterceptor(
|
||||
clock = bitwardenServiceClientConfig.clock,
|
||||
authTokenProvider = bitwardenServiceClientConfig.authTokenProvider,
|
||||
)
|
||||
private val clientJson = Json {
|
||||
|
||||
// If there are keys returned by the server not modeled by a serializable class,
|
||||
@@ -71,9 +75,7 @@ internal class BitwardenServiceClientImpl(
|
||||
}
|
||||
private val retrofits: Retrofits by lazy {
|
||||
RetrofitsImpl(
|
||||
authTokenInterceptor = AuthTokenInterceptor(
|
||||
authTokenProvider = bitwardenServiceClientConfig.authTokenProvider,
|
||||
),
|
||||
authTokenInterceptor = authTokenInterceptor,
|
||||
baseUrlInterceptors = BaseUrlInterceptors(
|
||||
baseUrlsProvider = bitwardenServiceClientConfig.baseUrlsProvider,
|
||||
),
|
||||
@@ -205,5 +207,6 @@ internal class BitwardenServiceClientImpl(
|
||||
|
||||
override fun setRefreshTokenProvider(refreshTokenProvider: RefreshTokenProvider?) {
|
||||
refreshAuthenticator.refreshTokenProvider = refreshTokenProvider
|
||||
authTokenInterceptor.refreshTokenProvider = refreshTokenProvider
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,22 +1,45 @@
|
||||
package com.bitwarden.network.interceptor
|
||||
|
||||
import com.bitwarden.network.provider.RefreshTokenProvider
|
||||
import com.bitwarden.network.util.HEADER_KEY_AUTHORIZATION
|
||||
import com.bitwarden.network.util.HEADER_VALUE_BEARER_PREFIX
|
||||
import okhttp3.Interceptor
|
||||
import okhttp3.Response
|
||||
import timber.log.Timber
|
||||
import java.io.IOException
|
||||
import java.time.Clock
|
||||
import java.time.Instant
|
||||
import java.time.temporal.ChronoUnit
|
||||
|
||||
private const val MISSING_TOKEN_MESSAGE: String = "Auth token is missing!"
|
||||
private const val MISSING_PROVIDER_MESSAGE: String = "Refresh token provider is missing!"
|
||||
private const val EXPIRATION_OFFSET_MINUTES: Long = 5L
|
||||
|
||||
/**
|
||||
* Interceptor responsible for adding the auth token(Bearer) to API requests.
|
||||
*/
|
||||
internal class AuthTokenInterceptor(
|
||||
private val clock: Clock,
|
||||
private val authTokenProvider: AuthTokenProvider,
|
||||
) : Interceptor {
|
||||
private val missingTokenMessage = "Auth token is missing!"
|
||||
var refreshTokenProvider: RefreshTokenProvider? = null
|
||||
|
||||
override fun intercept(chain: Interceptor.Chain): Response {
|
||||
val token = authTokenProvider.getActiveAccessTokenOrNull()
|
||||
?: throw IOException(IllegalStateException(missingTokenMessage))
|
||||
val tokenData = authTokenProvider
|
||||
.getAuthTokenDataOrNull()
|
||||
?: throw IOException(IllegalStateException(MISSING_TOKEN_MESSAGE))
|
||||
val expirationTime = Instant
|
||||
.ofEpochSecond(tokenData.expiresAtSec)
|
||||
.minus(EXPIRATION_OFFSET_MINUTES, ChronoUnit.MINUTES)
|
||||
val token = if (clock.instant().isAfter(expirationTime)) {
|
||||
Timber.d("Attempting to refresh token due to expiration")
|
||||
refreshTokenProvider
|
||||
?.refreshAccessTokenSynchronously(userId = tokenData.userId)
|
||||
?.getOrElse { throw IOException(it) }
|
||||
?: throw IOException(IllegalStateException(MISSING_PROVIDER_MESSAGE))
|
||||
} else {
|
||||
tokenData.accessToken
|
||||
}
|
||||
val request = chain
|
||||
.request()
|
||||
.newBuilder()
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
package com.bitwarden.network.interceptor
|
||||
|
||||
import com.bitwarden.network.model.AuthTokenData
|
||||
|
||||
/**
|
||||
* A provider for all the functionality needed to properly refresh the users access token.
|
||||
*/
|
||||
interface AuthTokenProvider {
|
||||
|
||||
/**
|
||||
* The currently active user's access token.
|
||||
* The currently active user's auth token data.
|
||||
*/
|
||||
fun getActiveAccessTokenOrNull(): String?
|
||||
fun getAuthTokenDataOrNull(): AuthTokenData?
|
||||
}
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
package com.bitwarden.network.model
|
||||
|
||||
/**
|
||||
* Contains the access token and expiration data for a user.
|
||||
*/
|
||||
data class AuthTokenData(
|
||||
val userId: String,
|
||||
val accessToken: String,
|
||||
val expiresAtSec: Long,
|
||||
)
|
||||
@@ -1,5 +1,9 @@
|
||||
package com.bitwarden.network.interceptor
|
||||
|
||||
import com.bitwarden.core.data.util.asFailure
|
||||
import com.bitwarden.core.data.util.asSuccess
|
||||
import com.bitwarden.network.model.AuthTokenData
|
||||
import com.bitwarden.network.provider.RefreshTokenProvider
|
||||
import io.mockk.every
|
||||
import io.mockk.mockk
|
||||
import junit.framework.TestCase.assertEquals
|
||||
@@ -7,12 +11,16 @@ import okhttp3.Request
|
||||
import org.junit.Assert.assertThrows
|
||||
import org.junit.Test
|
||||
import java.io.IOException
|
||||
import java.time.Clock
|
||||
import java.time.Instant
|
||||
import java.time.ZoneOffset
|
||||
|
||||
class AuthTokenInterceptorTest {
|
||||
private val mockAuthTokenProvider = mockk<AuthTokenProvider> {
|
||||
every { getActiveAccessTokenOrNull() } returns null
|
||||
every { getAuthTokenDataOrNull() } returns null
|
||||
}
|
||||
private val interceptor: AuthTokenInterceptor = AuthTokenInterceptor(
|
||||
clock = FIXED_CLOCK,
|
||||
authTokenProvider = mockAuthTokenProvider,
|
||||
)
|
||||
private val request: Request = Request
|
||||
@@ -22,7 +30,12 @@ class AuthTokenInterceptorTest {
|
||||
|
||||
@Test
|
||||
fun `intercept should add the auth token when set`() {
|
||||
every { mockAuthTokenProvider.getActiveAccessTokenOrNull() } returns ACCESS_TOKEN
|
||||
val authTokenData = AuthTokenData(
|
||||
userId = USER_ID,
|
||||
accessToken = ACCESS_TOKEN,
|
||||
expiresAtSec = FIXED_CLOCK.instant().epochSecond + 3600L,
|
||||
)
|
||||
every { mockAuthTokenProvider.getAuthTokenDataOrNull() } returns authTokenData
|
||||
|
||||
val response = interceptor.intercept(
|
||||
chain = FakeInterceptorChain(request = request),
|
||||
@@ -33,8 +46,78 @@ class AuthTokenInterceptorTest {
|
||||
)
|
||||
}
|
||||
|
||||
@Suppress("MaxLineLength")
|
||||
@Test
|
||||
fun `intercept should throw an exception when an auth token is missing`() {
|
||||
fun `intercept should throw an exception when auth token is expired and refreshTokenProvider is missing`() {
|
||||
val authTokenData = AuthTokenData(
|
||||
userId = USER_ID,
|
||||
accessToken = ACCESS_TOKEN,
|
||||
expiresAtSec = FIXED_CLOCK.instant().epochSecond - 3600L,
|
||||
)
|
||||
every { mockAuthTokenProvider.getAuthTokenDataOrNull() } returns authTokenData
|
||||
|
||||
val throwable = assertThrows(IOException::class.java) {
|
||||
interceptor.intercept(
|
||||
chain = FakeInterceptorChain(request = request),
|
||||
)
|
||||
}
|
||||
assertEquals(
|
||||
"Refresh token provider is missing!",
|
||||
throwable.cause?.message,
|
||||
)
|
||||
}
|
||||
|
||||
@Suppress("MaxLineLength")
|
||||
@Test
|
||||
fun `intercept should throw an exception when auth token is expired and refreshAccessTokenSynchronously returns an error`() {
|
||||
val errorMessage = "Fail!"
|
||||
interceptor.refreshTokenProvider = object : RefreshTokenProvider {
|
||||
override fun refreshAccessTokenSynchronously(
|
||||
userId: String,
|
||||
): Result<String> = Throwable(errorMessage).asFailure()
|
||||
}
|
||||
val authTokenData = AuthTokenData(
|
||||
userId = USER_ID,
|
||||
accessToken = ACCESS_TOKEN,
|
||||
expiresAtSec = FIXED_CLOCK.instant().epochSecond - 3600L,
|
||||
)
|
||||
every { mockAuthTokenProvider.getAuthTokenDataOrNull() } returns authTokenData
|
||||
|
||||
val throwable = assertThrows(IOException::class.java) {
|
||||
interceptor.intercept(
|
||||
chain = FakeInterceptorChain(request = request),
|
||||
)
|
||||
}
|
||||
assertEquals(errorMessage, throwable.cause?.message)
|
||||
}
|
||||
|
||||
@Suppress("MaxLineLength")
|
||||
@Test
|
||||
fun `intercept should add the auth token when auth token is expired and refreshAccessTokenSynchronously returns new token`() {
|
||||
val token = "token"
|
||||
interceptor.refreshTokenProvider = object : RefreshTokenProvider {
|
||||
override fun refreshAccessTokenSynchronously(
|
||||
userId: String,
|
||||
): Result<String> = token.asSuccess()
|
||||
}
|
||||
val authTokenData = AuthTokenData(
|
||||
userId = USER_ID,
|
||||
accessToken = ACCESS_TOKEN,
|
||||
expiresAtSec = FIXED_CLOCK.instant().epochSecond - 3600L,
|
||||
)
|
||||
every { mockAuthTokenProvider.getAuthTokenDataOrNull() } returns authTokenData
|
||||
|
||||
val response = interceptor.intercept(
|
||||
chain = FakeInterceptorChain(request = request),
|
||||
)
|
||||
assertEquals(
|
||||
"Bearer $token",
|
||||
response.request.header("Authorization"),
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `intercept should throw an exception when an auth token data is missing`() {
|
||||
val throwable = assertThrows(IOException::class.java) {
|
||||
interceptor.intercept(
|
||||
chain = FakeInterceptorChain(request = request),
|
||||
@@ -47,4 +130,10 @@ class AuthTokenInterceptorTest {
|
||||
}
|
||||
}
|
||||
|
||||
private val FIXED_CLOCK: Clock = Clock.fixed(
|
||||
Instant.parse("2023-10-27T12:00:00Z"),
|
||||
ZoneOffset.UTC,
|
||||
)
|
||||
|
||||
private const val USER_ID: String = "user_id"
|
||||
private const val ACCESS_TOKEN: String = "access_token"
|
||||
|
||||
Reference in New Issue
Block a user