diff --git a/app/src/main/kotlin/com/x8bit/bitwarden/data/auth/datasource/disk/model/AccountTokensJson.kt b/app/src/main/kotlin/com/x8bit/bitwarden/data/auth/datasource/disk/model/AccountTokensJson.kt index dacd8a6f20..2e98adb6a9 100644 --- a/app/src/main/kotlin/com/x8bit/bitwarden/data/auth/datasource/disk/model/AccountTokensJson.kt +++ b/app/src/main/kotlin/com/x8bit/bitwarden/data/auth/datasource/disk/model/AccountTokensJson.kt @@ -8,6 +8,7 @@ import kotlinx.serialization.Serializable * * @property accessToken The user's primary access token. * @property refreshToken The user's refresh token. + * @property expiresAtSec The time at which the token expires in epoch seconds. */ @Serializable data class AccountTokensJson( @@ -16,6 +17,9 @@ data class AccountTokensJson( @SerialName("refreshToken") val refreshToken: String?, + + @SerialName("expiresAtSec") + val expiresAtSec: Long = Long.MAX_VALUE, ) { /** * Returns `true` if the user is logged in, `false otherwise. diff --git a/app/src/main/kotlin/com/x8bit/bitwarden/data/auth/manager/AuthTokenManagerImpl.kt b/app/src/main/kotlin/com/x8bit/bitwarden/data/auth/manager/AuthTokenManagerImpl.kt index 6cd83275e4..37e2277732 100644 --- a/app/src/main/kotlin/com/x8bit/bitwarden/data/auth/manager/AuthTokenManagerImpl.kt +++ b/app/src/main/kotlin/com/x8bit/bitwarden/data/auth/manager/AuthTokenManagerImpl.kt @@ -1,5 +1,6 @@ package com.x8bit.bitwarden.data.auth.manager +import com.bitwarden.network.model.AuthTokenData import com.x8bit.bitwarden.data.auth.datasource.disk.AuthDiskSource /** @@ -9,9 +10,19 @@ class AuthTokenManagerImpl( private val authDiskSource: AuthDiskSource, ) : AuthTokenManager { - override fun getActiveAccessTokenOrNull(): String? = authDiskSource + override fun getAuthTokenDataOrNull(): AuthTokenData? = authDiskSource .userState ?.activeUserId - ?.let { authDiskSource.getAccountTokens(it) } - ?.accessToken + ?.let { userId -> + authDiskSource + .getAccountTokens(userId = userId) + ?.takeIf { it.accessToken != null } + ?.let { + AuthTokenData( + userId = userId, + accessToken = requireNotNull(it.accessToken), + expiresAtSec = it.expiresAtSec, + ) + } + } } diff --git a/app/src/main/kotlin/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryImpl.kt b/app/src/main/kotlin/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryImpl.kt index 202dd5fb7c..7423944d5e 100644 --- a/app/src/main/kotlin/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryImpl.kt +++ b/app/src/main/kotlin/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryImpl.kt @@ -146,6 +146,7 @@ import kotlinx.coroutines.flow.onEach import kotlinx.coroutines.flow.receiveAsFlow import kotlinx.coroutines.flow.stateIn import kotlinx.coroutines.flow.update +import java.time.Clock import javax.inject.Singleton /** @@ -154,6 +155,7 @@ import javax.inject.Singleton @Suppress("LargeClass", "LongParameterList", "TooManyFunctions") @Singleton class AuthRepositoryImpl( + private val clock: Clock, private val accountsService: AccountsService, private val devicesService: DevicesService, private val haveIBeenPwnedService: HaveIBeenPwnedService, @@ -792,6 +794,8 @@ class AuthRepositoryImpl( accountTokens = AccountTokensJson( accessToken = refreshTokenResponse.accessToken, refreshToken = refreshTokenResponse.refreshToken, + expiresAtSec = clock.instant().epochSecond + + refreshTokenResponse.expiresIn, ), ) refreshTokenResponse.accessToken.asSuccess() @@ -1778,6 +1782,7 @@ class AuthRepositoryImpl( accountTokens = AccountTokensJson( accessToken = loginResponse.accessToken, refreshToken = loginResponse.refreshToken, + expiresAtSec = clock.instant().epochSecond + loginResponse.expiresInSeconds, ), ) settingsRepository.hasUserLoggedInOrCreatedAccount = true diff --git a/app/src/main/kotlin/com/x8bit/bitwarden/data/auth/repository/di/AuthRepositoryModule.kt b/app/src/main/kotlin/com/x8bit/bitwarden/data/auth/repository/di/AuthRepositoryModule.kt index b9753a24ba..ce7935e3d2 100644 --- a/app/src/main/kotlin/com/x8bit/bitwarden/data/auth/repository/di/AuthRepositoryModule.kt +++ b/app/src/main/kotlin/com/x8bit/bitwarden/data/auth/repository/di/AuthRepositoryModule.kt @@ -27,6 +27,7 @@ import dagger.Module import dagger.Provides import dagger.hilt.InstallIn import dagger.hilt.components.SingletonComponent +import java.time.Clock import javax.inject.Singleton /** @@ -39,6 +40,7 @@ object AuthRepositoryModule { @Provides @Singleton fun providesAuthRepository( + clock: Clock, accountsService: AccountsService, devicesService: DevicesService, identityService: IdentityService, @@ -61,6 +63,7 @@ object AuthRepositoryModule { firstTimeActionManager: FirstTimeActionManager, logsManager: LogsManager, ): AuthRepository = AuthRepositoryImpl( + clock = clock, accountsService = accountsService, devicesService = devicesService, identityService = identityService, diff --git a/app/src/test/kotlin/com/x8bit/bitwarden/data/auth/manager/AuthTokenManagerTest.kt b/app/src/test/kotlin/com/x8bit/bitwarden/data/auth/manager/AuthTokenManagerTest.kt index 26d1b99694..3b9e824bf0 100644 --- a/app/src/test/kotlin/com/x8bit/bitwarden/data/auth/manager/AuthTokenManagerTest.kt +++ b/app/src/test/kotlin/com/x8bit/bitwarden/data/auth/manager/AuthTokenManagerTest.kt @@ -1,5 +1,6 @@ package com.x8bit.bitwarden.data.auth.manager +import com.bitwarden.network.model.AuthTokenData import com.bitwarden.network.model.KdfTypeJson import com.x8bit.bitwarden.data.auth.datasource.disk.model.AccountJson import com.x8bit.bitwarden.data.auth.datasource.disk.model.AccountTokensJson @@ -18,7 +19,7 @@ class AuthTokenManagerTest { @Test fun `UserState is null`() { fakeAuthDiskSource.userState = null - assertNull(authTokenManager.getActiveAccessTokenOrNull()) + assertNull(authTokenManager.getAuthTokenDataOrNull()) } @Test @@ -29,7 +30,7 @@ class AuthTokenManagerTest { USER_ID to ACCOUNT.copy(tokens = null), ), ) - assertNull(authTokenManager.getActiveAccessTokenOrNull()) + assertNull(authTokenManager.getAuthTokenDataOrNull()) } @Test @@ -45,7 +46,21 @@ class AuthTokenManagerTest { ), ), ) - assertNull(authTokenManager.getActiveAccessTokenOrNull()) + assertNull(authTokenManager.getAuthTokenDataOrNull()) + } + + @Test + fun `getActiveAccessTokenOrNull should return null if user access token is null`() { + fakeAuthDiskSource.userState = SINGLE_USER_STATE + fakeAuthDiskSource.storeAccountTokens( + userId = USER_ID, + accountTokens = AccountTokensJson( + accessToken = null, + refreshToken = REFRESH_TOKEN, + expiresAtSec = EXPIRES_AT_SEC, + ), + ) + assertNull(authTokenManager.getAuthTokenDataOrNull()) } @Test @@ -56,11 +71,16 @@ class AuthTokenManagerTest { accountTokens = AccountTokensJson( accessToken = ACCESS_TOKEN, refreshToken = REFRESH_TOKEN, + expiresAtSec = EXPIRES_AT_SEC, ), ) assertEquals( - ACCESS_TOKEN, - authTokenManager.getActiveAccessTokenOrNull(), + AuthTokenData( + userId = USER_ID, + accessToken = ACCESS_TOKEN, + expiresAtSec = EXPIRES_AT_SEC, + ), + authTokenManager.getAuthTokenDataOrNull(), ) } } @@ -69,6 +89,7 @@ private const val EMAIL: String = "test@bitwarden.com" private const val USER_ID: String = "2a135b23-e1fb-42c9-bec3-573857bc8181" private const val ACCESS_TOKEN: String = "accessToken" private const val REFRESH_TOKEN: String = "refreshToken" +private const val EXPIRES_AT_SEC: Long = 3600 private val ACCOUNT: AccountJson = AccountJson( profile = AccountJson.Profile( userId = USER_ID, @@ -91,6 +112,7 @@ private val ACCOUNT: AccountJson = AccountJson( tokens = AccountTokensJson( accessToken = ACCESS_TOKEN, refreshToken = REFRESH_TOKEN, + expiresAtSec = EXPIRES_AT_SEC, ), settings = AccountJson.Settings( environmentUrlData = null, diff --git a/app/src/test/kotlin/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryTest.kt b/app/src/test/kotlin/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryTest.kt index 1827c784e2..4dcf8b75ad 100644 --- a/app/src/test/kotlin/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryTest.kt +++ b/app/src/test/kotlin/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryTest.kt @@ -154,6 +154,9 @@ import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertDoesNotThrow +import java.time.Clock +import java.time.Instant +import java.time.ZoneOffset import java.time.ZonedDateTime import javax.net.ssl.SSLHandshakeException @@ -259,7 +262,8 @@ class AuthRepositoryTest { every { setUserData(userId = any(), environmentType = any()) } just runs } - private val repository = AuthRepositoryImpl( + private val repository: AuthRepository = AuthRepositoryImpl( + clock = FIXED_CLOCK, accountsService = accountsService, devicesService = devicesService, identityService = identityService, @@ -899,6 +903,7 @@ class AuthRepositoryTest { val updatedAccountTokens = AccountTokensJson( accessToken = ACCESS_TOKEN_2, refreshToken = REFRESH_TOKEN_2, + expiresAtSec = FIXED_CLOCK.instant().epochSecond + ACCESS_TOKEN_2_EXPIRES_IN, ) fakeAuthDiskSource.storeAccountTokens( userId = USER_ID_1, @@ -6908,6 +6913,10 @@ class AuthRepositoryTest { } companion object { + private val FIXED_CLOCK: Clock = Clock.fixed( + Instant.parse("2023-10-27T12:00:00Z"), + ZoneOffset.UTC, + ) private const val UNIQUE_APP_ID = "testUniqueAppId" private const val NAME = "Example Name" private const val EMAIL = "test@bitwarden.com" @@ -6919,6 +6928,7 @@ class AuthRepositoryTest { private const val ACCESS_TOKEN_2 = "accessToken2" private const val REFRESH_TOKEN = "refreshToken" private const val REFRESH_TOKEN_2 = "refreshToken2" + private const val ACCESS_TOKEN_2_EXPIRES_IN = 3600 private const val CAPTCHA_KEY = "captcha" private const val TWO_FACTOR_CODE = "123456" private val TWO_FACTOR_METHOD = TwoFactorAuthMethod.EMAIL @@ -6961,7 +6971,7 @@ class AuthRepositoryTest { ) private val REFRESH_TOKEN_RESPONSE_JSON = RefreshTokenResponseJson.Success( accessToken = ACCESS_TOKEN_2, - expiresIn = 3600, + expiresIn = ACCESS_TOKEN_2_EXPIRES_IN, refreshToken = REFRESH_TOKEN_2, tokenType = "Bearer", ) diff --git a/authenticator/src/main/kotlin/com/bitwarden/authenticator/data/platform/datasource/network/di/PlatformNetworkModule.kt b/authenticator/src/main/kotlin/com/bitwarden/authenticator/data/platform/datasource/network/di/PlatformNetworkModule.kt index 01c78c91b6..4dd98c36c5 100644 --- a/authenticator/src/main/kotlin/com/bitwarden/authenticator/data/platform/datasource/network/di/PlatformNetworkModule.kt +++ b/authenticator/src/main/kotlin/com/bitwarden/authenticator/data/platform/datasource/network/di/PlatformNetworkModule.kt @@ -9,6 +9,7 @@ import com.bitwarden.network.BitwardenServiceClient import com.bitwarden.network.bitwardenServiceClient import com.bitwarden.network.interceptor.AuthTokenProvider import com.bitwarden.network.interceptor.BaseUrlsProvider +import com.bitwarden.network.model.AuthTokenData import com.bitwarden.network.model.BitwardenServiceClientConfig import com.bitwarden.network.service.ConfigService import com.bitwarden.network.ssl.CertificateProvider @@ -54,7 +55,7 @@ object PlatformNetworkModule { baseUrlsProvider = baseUrlsProvider, enableHttpBodyLogging = BuildConfig.DEBUG, authTokenProvider = object : AuthTokenProvider { - override fun getActiveAccessTokenOrNull(): String? = null + override fun getAuthTokenDataOrNull(): AuthTokenData? = null }, certificateProvider = object : CertificateProvider { override fun chooseClientAlias( diff --git a/network/src/main/kotlin/com/bitwarden/network/BitwardenServiceClientImpl.kt b/network/src/main/kotlin/com/bitwarden/network/BitwardenServiceClientImpl.kt index 5f46c2140a..90d4b3cd7b 100644 --- a/network/src/main/kotlin/com/bitwarden/network/BitwardenServiceClientImpl.kt +++ b/network/src/main/kotlin/com/bitwarden/network/BitwardenServiceClientImpl.kt @@ -53,6 +53,10 @@ internal class BitwardenServiceClientImpl( ) : BitwardenServiceClient { private val refreshAuthenticator: RefreshAuthenticator = RefreshAuthenticator() + private val authTokenInterceptor: AuthTokenInterceptor = AuthTokenInterceptor( + clock = bitwardenServiceClientConfig.clock, + authTokenProvider = bitwardenServiceClientConfig.authTokenProvider, + ) private val clientJson = Json { // If there are keys returned by the server not modeled by a serializable class, @@ -71,9 +75,7 @@ internal class BitwardenServiceClientImpl( } private val retrofits: Retrofits by lazy { RetrofitsImpl( - authTokenInterceptor = AuthTokenInterceptor( - authTokenProvider = bitwardenServiceClientConfig.authTokenProvider, - ), + authTokenInterceptor = authTokenInterceptor, baseUrlInterceptors = BaseUrlInterceptors( baseUrlsProvider = bitwardenServiceClientConfig.baseUrlsProvider, ), @@ -205,5 +207,6 @@ internal class BitwardenServiceClientImpl( override fun setRefreshTokenProvider(refreshTokenProvider: RefreshTokenProvider?) { refreshAuthenticator.refreshTokenProvider = refreshTokenProvider + authTokenInterceptor.refreshTokenProvider = refreshTokenProvider } } diff --git a/network/src/main/kotlin/com/bitwarden/network/interceptor/AuthTokenInterceptor.kt b/network/src/main/kotlin/com/bitwarden/network/interceptor/AuthTokenInterceptor.kt index 8e6818899d..7a07a1d035 100644 --- a/network/src/main/kotlin/com/bitwarden/network/interceptor/AuthTokenInterceptor.kt +++ b/network/src/main/kotlin/com/bitwarden/network/interceptor/AuthTokenInterceptor.kt @@ -1,22 +1,45 @@ package com.bitwarden.network.interceptor +import com.bitwarden.network.provider.RefreshTokenProvider import com.bitwarden.network.util.HEADER_KEY_AUTHORIZATION import com.bitwarden.network.util.HEADER_VALUE_BEARER_PREFIX import okhttp3.Interceptor import okhttp3.Response +import timber.log.Timber import java.io.IOException +import java.time.Clock +import java.time.Instant +import java.time.temporal.ChronoUnit + +private const val MISSING_TOKEN_MESSAGE: String = "Auth token is missing!" +private const val MISSING_PROVIDER_MESSAGE: String = "Refresh token provider is missing!" +private const val EXPIRATION_OFFSET_MINUTES: Long = 5L /** * Interceptor responsible for adding the auth token(Bearer) to API requests. */ internal class AuthTokenInterceptor( + private val clock: Clock, private val authTokenProvider: AuthTokenProvider, ) : Interceptor { - private val missingTokenMessage = "Auth token is missing!" + var refreshTokenProvider: RefreshTokenProvider? = null override fun intercept(chain: Interceptor.Chain): Response { - val token = authTokenProvider.getActiveAccessTokenOrNull() - ?: throw IOException(IllegalStateException(missingTokenMessage)) + val tokenData = authTokenProvider + .getAuthTokenDataOrNull() + ?: throw IOException(IllegalStateException(MISSING_TOKEN_MESSAGE)) + val expirationTime = Instant + .ofEpochSecond(tokenData.expiresAtSec) + .minus(EXPIRATION_OFFSET_MINUTES, ChronoUnit.MINUTES) + val token = if (clock.instant().isAfter(expirationTime)) { + Timber.d("Attempting to refresh token due to expiration") + refreshTokenProvider + ?.refreshAccessTokenSynchronously(userId = tokenData.userId) + ?.getOrElse { throw IOException(it) } + ?: throw IOException(IllegalStateException(MISSING_PROVIDER_MESSAGE)) + } else { + tokenData.accessToken + } val request = chain .request() .newBuilder() diff --git a/network/src/main/kotlin/com/bitwarden/network/interceptor/AuthTokenProvider.kt b/network/src/main/kotlin/com/bitwarden/network/interceptor/AuthTokenProvider.kt index 6847522c0c..9ab30058bd 100644 --- a/network/src/main/kotlin/com/bitwarden/network/interceptor/AuthTokenProvider.kt +++ b/network/src/main/kotlin/com/bitwarden/network/interceptor/AuthTokenProvider.kt @@ -1,12 +1,14 @@ package com.bitwarden.network.interceptor +import com.bitwarden.network.model.AuthTokenData + /** * A provider for all the functionality needed to properly refresh the users access token. */ interface AuthTokenProvider { /** - * The currently active user's access token. + * The currently active user's auth token data. */ - fun getActiveAccessTokenOrNull(): String? + fun getAuthTokenDataOrNull(): AuthTokenData? } diff --git a/network/src/main/kotlin/com/bitwarden/network/model/AuthTokenData.kt b/network/src/main/kotlin/com/bitwarden/network/model/AuthTokenData.kt new file mode 100644 index 0000000000..8a03b9cd24 --- /dev/null +++ b/network/src/main/kotlin/com/bitwarden/network/model/AuthTokenData.kt @@ -0,0 +1,10 @@ +package com.bitwarden.network.model + +/** + * Contains the access token and expiration data for a user. + */ +data class AuthTokenData( + val userId: String, + val accessToken: String, + val expiresAtSec: Long, +) diff --git a/network/src/test/kotlin/com/bitwarden/network/interceptor/AuthTokenInterceptorTest.kt b/network/src/test/kotlin/com/bitwarden/network/interceptor/AuthTokenInterceptorTest.kt index 7506726821..6ac846f709 100644 --- a/network/src/test/kotlin/com/bitwarden/network/interceptor/AuthTokenInterceptorTest.kt +++ b/network/src/test/kotlin/com/bitwarden/network/interceptor/AuthTokenInterceptorTest.kt @@ -1,5 +1,9 @@ package com.bitwarden.network.interceptor +import com.bitwarden.core.data.util.asFailure +import com.bitwarden.core.data.util.asSuccess +import com.bitwarden.network.model.AuthTokenData +import com.bitwarden.network.provider.RefreshTokenProvider import io.mockk.every import io.mockk.mockk import junit.framework.TestCase.assertEquals @@ -7,12 +11,16 @@ import okhttp3.Request import org.junit.Assert.assertThrows import org.junit.Test import java.io.IOException +import java.time.Clock +import java.time.Instant +import java.time.ZoneOffset class AuthTokenInterceptorTest { private val mockAuthTokenProvider = mockk { - every { getActiveAccessTokenOrNull() } returns null + every { getAuthTokenDataOrNull() } returns null } private val interceptor: AuthTokenInterceptor = AuthTokenInterceptor( + clock = FIXED_CLOCK, authTokenProvider = mockAuthTokenProvider, ) private val request: Request = Request @@ -22,7 +30,12 @@ class AuthTokenInterceptorTest { @Test fun `intercept should add the auth token when set`() { - every { mockAuthTokenProvider.getActiveAccessTokenOrNull() } returns ACCESS_TOKEN + val authTokenData = AuthTokenData( + userId = USER_ID, + accessToken = ACCESS_TOKEN, + expiresAtSec = FIXED_CLOCK.instant().epochSecond + 3600L, + ) + every { mockAuthTokenProvider.getAuthTokenDataOrNull() } returns authTokenData val response = interceptor.intercept( chain = FakeInterceptorChain(request = request), @@ -33,8 +46,78 @@ class AuthTokenInterceptorTest { ) } + @Suppress("MaxLineLength") @Test - fun `intercept should throw an exception when an auth token is missing`() { + fun `intercept should throw an exception when auth token is expired and refreshTokenProvider is missing`() { + val authTokenData = AuthTokenData( + userId = USER_ID, + accessToken = ACCESS_TOKEN, + expiresAtSec = FIXED_CLOCK.instant().epochSecond - 3600L, + ) + every { mockAuthTokenProvider.getAuthTokenDataOrNull() } returns authTokenData + + val throwable = assertThrows(IOException::class.java) { + interceptor.intercept( + chain = FakeInterceptorChain(request = request), + ) + } + assertEquals( + "Refresh token provider is missing!", + throwable.cause?.message, + ) + } + + @Suppress("MaxLineLength") + @Test + fun `intercept should throw an exception when auth token is expired and refreshAccessTokenSynchronously returns an error`() { + val errorMessage = "Fail!" + interceptor.refreshTokenProvider = object : RefreshTokenProvider { + override fun refreshAccessTokenSynchronously( + userId: String, + ): Result = Throwable(errorMessage).asFailure() + } + val authTokenData = AuthTokenData( + userId = USER_ID, + accessToken = ACCESS_TOKEN, + expiresAtSec = FIXED_CLOCK.instant().epochSecond - 3600L, + ) + every { mockAuthTokenProvider.getAuthTokenDataOrNull() } returns authTokenData + + val throwable = assertThrows(IOException::class.java) { + interceptor.intercept( + chain = FakeInterceptorChain(request = request), + ) + } + assertEquals(errorMessage, throwable.cause?.message) + } + + @Suppress("MaxLineLength") + @Test + fun `intercept should add the auth token when auth token is expired and refreshAccessTokenSynchronously returns new token`() { + val token = "token" + interceptor.refreshTokenProvider = object : RefreshTokenProvider { + override fun refreshAccessTokenSynchronously( + userId: String, + ): Result = token.asSuccess() + } + val authTokenData = AuthTokenData( + userId = USER_ID, + accessToken = ACCESS_TOKEN, + expiresAtSec = FIXED_CLOCK.instant().epochSecond - 3600L, + ) + every { mockAuthTokenProvider.getAuthTokenDataOrNull() } returns authTokenData + + val response = interceptor.intercept( + chain = FakeInterceptorChain(request = request), + ) + assertEquals( + "Bearer $token", + response.request.header("Authorization"), + ) + } + + @Test + fun `intercept should throw an exception when an auth token data is missing`() { val throwable = assertThrows(IOException::class.java) { interceptor.intercept( chain = FakeInterceptorChain(request = request), @@ -47,4 +130,10 @@ class AuthTokenInterceptorTest { } } +private val FIXED_CLOCK: Clock = Clock.fixed( + Instant.parse("2023-10-27T12:00:00Z"), + ZoneOffset.UTC, +) + +private const val USER_ID: String = "user_id" private const val ACCESS_TOKEN: String = "access_token"