From cd9c7f98e7aa492e4ce73b9d3a7e18041eba544a Mon Sep 17 00:00:00 2001 From: David Perez Date: Tue, 7 Oct 2025 11:49:57 -0500 Subject: [PATCH] PM-26358: Integrate the token auth logic with the SDK (#5967) --- .../data/auth/manager/AuthTokenManagerImpl.kt | 25 +-- .../platform/manager/SdkClientManagerImpl.kt | 26 +-- .../manager/di/PlatformManagerModule.kt | 2 + .../manager/sdk/SdkRepositoryFactory.kt | 6 + .../manager/sdk/SdkRepositoryFactoryImpl.kt | 12 ++ .../sdk/repository/SdkTokenRepository.kt | 15 ++ .../data/auth/manager/AuthTokenManagerTest.kt | 185 +++++++++++++----- .../manager/sdk/SdkRepositoryFactoryTests.kt | 21 ++ .../sdk/repository/SdkTokenRepositoryTest.kt | 57 ++++++ .../network/di/PlatformNetworkModule.kt | 1 + .../platform/manager/SdkClientManagerImpl.kt | 13 +- .../network/BitwardenServiceClient.kt | 5 + .../network/BitwardenServiceClientImpl.kt | 2 + .../network/interceptor/AuthTokenManager.kt | 117 ++++++----- .../network/interceptor/AuthTokenProvider.kt | 4 + .../network/provider/TokenProvider.kt | 11 ++ .../interceptor/AuthTokenManagerTest.kt | 48 ++++- 17 files changed, 407 insertions(+), 143 deletions(-) create mode 100644 app/src/main/kotlin/com/x8bit/bitwarden/data/platform/manager/sdk/repository/SdkTokenRepository.kt create mode 100644 app/src/test/kotlin/com/x8bit/bitwarden/data/platform/manager/sdk/repository/SdkTokenRepositoryTest.kt create mode 100644 network/src/main/kotlin/com/bitwarden/network/provider/TokenProvider.kt diff --git a/app/src/main/kotlin/com/x8bit/bitwarden/data/auth/manager/AuthTokenManagerImpl.kt b/app/src/main/kotlin/com/x8bit/bitwarden/data/auth/manager/AuthTokenManagerImpl.kt index 37e2277732..73520d05af 100644 --- a/app/src/main/kotlin/com/x8bit/bitwarden/data/auth/manager/AuthTokenManagerImpl.kt +++ b/app/src/main/kotlin/com/x8bit/bitwarden/data/auth/manager/AuthTokenManagerImpl.kt @@ -10,19 +10,20 @@ class AuthTokenManagerImpl( private val authDiskSource: AuthDiskSource, ) : AuthTokenManager { + override fun getAuthTokenDataOrNull(userId: String): AuthTokenData? = + authDiskSource + .getAccountTokens(userId = userId) + ?.takeIf { it.accessToken != null } + ?.let { + AuthTokenData( + userId = userId, + accessToken = requireNotNull(it.accessToken), + expiresAtSec = it.expiresAtSec, + ) + } + override fun getAuthTokenDataOrNull(): AuthTokenData? = authDiskSource .userState ?.activeUserId - ?.let { userId -> - authDiskSource - .getAccountTokens(userId = userId) - ?.takeIf { it.accessToken != null } - ?.let { - AuthTokenData( - userId = userId, - accessToken = requireNotNull(it.accessToken), - expiresAtSec = it.expiresAtSec, - ) - } - } + ?.let(::getAuthTokenDataOrNull) } diff --git a/app/src/main/kotlin/com/x8bit/bitwarden/data/platform/manager/SdkClientManagerImpl.kt b/app/src/main/kotlin/com/x8bit/bitwarden/data/platform/manager/SdkClientManagerImpl.kt index c81f44a7a4..4ac7e686d3 100644 --- a/app/src/main/kotlin/com/x8bit/bitwarden/data/platform/manager/SdkClientManagerImpl.kt +++ b/app/src/main/kotlin/com/x8bit/bitwarden/data/platform/manager/SdkClientManagerImpl.kt @@ -1,21 +1,11 @@ package com.x8bit.bitwarden.data.platform.manager import android.os.Build -import com.bitwarden.core.ClientManagedTokens import com.bitwarden.core.util.isBuildVersionAtLeast import com.bitwarden.data.manager.NativeLibraryManager import com.bitwarden.sdk.Client import com.x8bit.bitwarden.data.platform.manager.sdk.SdkRepositoryFactory -/** - * The token provider to pass to the SDK. - */ -class Token : ClientManagedTokens { - override suspend fun getAccessToken(): String? { - return null - } -} - /** * Primary implementation of [SdkClientManager]. */ @@ -24,14 +14,18 @@ class SdkClientManagerImpl( sdkRepoFactory: SdkRepositoryFactory, private val featureFlagManager: FeatureFlagManager, private val clientProvider: suspend (userId: String?) -> Client = { userId -> - Client(tokenProvider = Token(), settings = null).apply { - platform().loadFlags(featureFlagManager.sdkFeatureFlags) - userId?.let { - platform().state().apply { - registerCipherRepository(sdkRepoFactory.getCipherRepository(userId = it)) + Client( + tokenProvider = sdkRepoFactory.getClientManagedTokens(userId = userId), + settings = null, + ) + .apply { + platform().loadFlags(featureFlagManager.sdkFeatureFlags) + userId?.let { + platform().state().apply { + registerCipherRepository(sdkRepoFactory.getCipherRepository(userId = it)) + } } } - } }, ) : SdkClientManager { private val userIdToClientMap = mutableMapOf() diff --git a/app/src/main/kotlin/com/x8bit/bitwarden/data/platform/manager/di/PlatformManagerModule.kt b/app/src/main/kotlin/com/x8bit/bitwarden/data/platform/manager/di/PlatformManagerModule.kt index 57b7db9f4e..a263fbc7ec 100644 --- a/app/src/main/kotlin/com/x8bit/bitwarden/data/platform/manager/di/PlatformManagerModule.kt +++ b/app/src/main/kotlin/com/x8bit/bitwarden/data/platform/manager/di/PlatformManagerModule.kt @@ -389,8 +389,10 @@ object PlatformManagerModule { @Singleton fun provideSdkRepositoryFactory( vaultDiskSource: VaultDiskSource, + bitwardenServiceClient: BitwardenServiceClient, ): SdkRepositoryFactory = SdkRepositoryFactoryImpl( vaultDiskSource = vaultDiskSource, + bitwardenServiceClient = bitwardenServiceClient, ) @Provides diff --git a/app/src/main/kotlin/com/x8bit/bitwarden/data/platform/manager/sdk/SdkRepositoryFactory.kt b/app/src/main/kotlin/com/x8bit/bitwarden/data/platform/manager/sdk/SdkRepositoryFactory.kt index 1344a5b401..2f19f0165c 100644 --- a/app/src/main/kotlin/com/x8bit/bitwarden/data/platform/manager/sdk/SdkRepositoryFactory.kt +++ b/app/src/main/kotlin/com/x8bit/bitwarden/data/platform/manager/sdk/SdkRepositoryFactory.kt @@ -1,5 +1,6 @@ package com.x8bit.bitwarden.data.platform.manager.sdk +import com.bitwarden.core.ClientManagedTokens import com.bitwarden.sdk.CipherRepository /** @@ -10,4 +11,9 @@ interface SdkRepositoryFactory { * Retrieves or creates a [CipherRepository] for use with the Bitwarden SDK. */ fun getCipherRepository(userId: String): CipherRepository + + /** + * Retrieves or creates a [ClientManagedTokens] for use with the Bitwarden SDK. + */ + fun getClientManagedTokens(userId: String?): ClientManagedTokens } diff --git a/app/src/main/kotlin/com/x8bit/bitwarden/data/platform/manager/sdk/SdkRepositoryFactoryImpl.kt b/app/src/main/kotlin/com/x8bit/bitwarden/data/platform/manager/sdk/SdkRepositoryFactoryImpl.kt index 7f7546c807..a054a765ad 100644 --- a/app/src/main/kotlin/com/x8bit/bitwarden/data/platform/manager/sdk/SdkRepositoryFactoryImpl.kt +++ b/app/src/main/kotlin/com/x8bit/bitwarden/data/platform/manager/sdk/SdkRepositoryFactoryImpl.kt @@ -1,7 +1,10 @@ package com.x8bit.bitwarden.data.platform.manager.sdk +import com.bitwarden.core.ClientManagedTokens +import com.bitwarden.network.BitwardenServiceClient import com.bitwarden.sdk.CipherRepository import com.x8bit.bitwarden.data.platform.manager.sdk.repository.SdkCipherRepository +import com.x8bit.bitwarden.data.platform.manager.sdk.repository.SdkTokenRepository import com.x8bit.bitwarden.data.vault.datasource.disk.VaultDiskSource /** @@ -9,6 +12,7 @@ import com.x8bit.bitwarden.data.vault.datasource.disk.VaultDiskSource */ class SdkRepositoryFactoryImpl( private val vaultDiskSource: VaultDiskSource, + private val bitwardenServiceClient: BitwardenServiceClient, ) : SdkRepositoryFactory { override fun getCipherRepository( userId: String, @@ -17,4 +21,12 @@ class SdkRepositoryFactoryImpl( userId = userId, vaultDiskSource = vaultDiskSource, ) + + override fun getClientManagedTokens( + userId: String?, + ): ClientManagedTokens = + SdkTokenRepository( + userId = userId, + tokenProvider = bitwardenServiceClient.tokenProvider, + ) } diff --git a/app/src/main/kotlin/com/x8bit/bitwarden/data/platform/manager/sdk/repository/SdkTokenRepository.kt b/app/src/main/kotlin/com/x8bit/bitwarden/data/platform/manager/sdk/repository/SdkTokenRepository.kt new file mode 100644 index 0000000000..4d1e6b0226 --- /dev/null +++ b/app/src/main/kotlin/com/x8bit/bitwarden/data/platform/manager/sdk/repository/SdkTokenRepository.kt @@ -0,0 +1,15 @@ +package com.x8bit.bitwarden.data.platform.manager.sdk.repository + +import com.bitwarden.core.ClientManagedTokens +import com.bitwarden.network.provider.TokenProvider + +/** + * A user-scoped implementation of a Bitwarden SDK [ClientManagedTokens]. + */ +class SdkTokenRepository( + private val userId: String?, + private val tokenProvider: TokenProvider, +) : ClientManagedTokens { + override suspend fun getAccessToken(): String? = + userId?.let { tokenProvider.getAccessToken(userId = it) } +} diff --git a/app/src/test/kotlin/com/x8bit/bitwarden/data/auth/manager/AuthTokenManagerTest.kt b/app/src/test/kotlin/com/x8bit/bitwarden/data/auth/manager/AuthTokenManagerTest.kt index 3b9e824bf0..29f146e556 100644 --- a/app/src/test/kotlin/com/x8bit/bitwarden/data/auth/manager/AuthTokenManagerTest.kt +++ b/app/src/test/kotlin/com/x8bit/bitwarden/data/auth/manager/AuthTokenManagerTest.kt @@ -7,6 +7,7 @@ import com.x8bit.bitwarden.data.auth.datasource.disk.model.AccountTokensJson import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson import com.x8bit.bitwarden.data.auth.datasource.disk.util.FakeAuthDiskSource import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Nested import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertNull import java.time.ZonedDateTime @@ -16,27 +17,25 @@ class AuthTokenManagerTest { private val fakeAuthDiskSource = FakeAuthDiskSource() private val authTokenManager = AuthTokenManagerImpl(fakeAuthDiskSource) - @Test - fun `UserState is null`() { - fakeAuthDiskSource.userState = null - assertNull(authTokenManager.getAuthTokenDataOrNull()) - } + @Nested + inner class WithUserId { + @Test + fun `UserState is null`() { + fakeAuthDiskSource.userState = null + assertNull(authTokenManager.getAuthTokenDataOrNull(userId = USER_ID)) + } - @Test - fun `Account tokens are null`() { - fakeAuthDiskSource.userState = SINGLE_USER_STATE - .copy( - accounts = mapOf( - USER_ID to ACCOUNT.copy(tokens = null), - ), + @Test + fun `Account tokens are null`() { + fakeAuthDiskSource.userState = SINGLE_USER_STATE.copy( + accounts = mapOf(USER_ID to ACCOUNT.copy(tokens = null)), ) - assertNull(authTokenManager.getAuthTokenDataOrNull()) - } + assertNull(authTokenManager.getAuthTokenDataOrNull(userId = USER_ID)) + } - @Test - fun `Access token is null`() { - fakeAuthDiskSource.userState = SINGLE_USER_STATE - .copy( + @Test + fun `Access token is null`() { + fakeAuthDiskSource.userState = SINGLE_USER_STATE.copy( accounts = mapOf( USER_ID to ACCOUNT.copy( tokens = AccountTokensJson( @@ -46,42 +45,124 @@ class AuthTokenManagerTest { ), ), ) - assertNull(authTokenManager.getAuthTokenDataOrNull()) - } + assertNull(authTokenManager.getAuthTokenDataOrNull(userId = USER_ID)) + } - @Test - fun `getActiveAccessTokenOrNull should return null if user access token is null`() { - fakeAuthDiskSource.userState = SINGLE_USER_STATE - fakeAuthDiskSource.storeAccountTokens( - userId = USER_ID, - accountTokens = AccountTokensJson( - accessToken = null, - refreshToken = REFRESH_TOKEN, - expiresAtSec = EXPIRES_AT_SEC, - ), - ) - assertNull(authTokenManager.getAuthTokenDataOrNull()) - } - - @Test - fun `getActiveAccessTokenOrNull should return active user access token`() { - fakeAuthDiskSource.userState = SINGLE_USER_STATE - fakeAuthDiskSource.storeAccountTokens( - userId = USER_ID, - accountTokens = AccountTokensJson( - accessToken = ACCESS_TOKEN, - refreshToken = REFRESH_TOKEN, - expiresAtSec = EXPIRES_AT_SEC, - ), - ) - assertEquals( - AuthTokenData( + @Test + fun `getActiveAccessTokenOrNull should return null if user access token is null`() { + fakeAuthDiskSource.userState = SINGLE_USER_STATE + fakeAuthDiskSource.storeAccountTokens( userId = USER_ID, - accessToken = ACCESS_TOKEN, - expiresAtSec = EXPIRES_AT_SEC, - ), - authTokenManager.getAuthTokenDataOrNull(), - ) + accountTokens = AccountTokensJson( + accessToken = null, + refreshToken = REFRESH_TOKEN, + expiresAtSec = EXPIRES_AT_SEC, + ), + ) + assertNull(authTokenManager.getAuthTokenDataOrNull(userId = USER_ID)) + } + + @Test + fun `getActiveAccessTokenOrNull should return access token`() { + fakeAuthDiskSource.userState = SINGLE_USER_STATE + fakeAuthDiskSource.storeAccountTokens( + userId = USER_ID, + accountTokens = AccountTokensJson( + accessToken = ACCESS_TOKEN, + refreshToken = REFRESH_TOKEN, + expiresAtSec = EXPIRES_AT_SEC, + ), + ) + assertEquals( + AuthTokenData( + userId = USER_ID, + accessToken = ACCESS_TOKEN, + expiresAtSec = EXPIRES_AT_SEC, + ), + authTokenManager.getAuthTokenDataOrNull(userId = USER_ID), + ) + } + + @Test + fun `getActiveAccessTokenOrNull should return null for unknown userId`() { + fakeAuthDiskSource.userState = SINGLE_USER_STATE + fakeAuthDiskSource.storeAccountTokens( + userId = USER_ID, + accountTokens = AccountTokensJson( + accessToken = ACCESS_TOKEN, + refreshToken = REFRESH_TOKEN, + expiresAtSec = EXPIRES_AT_SEC, + ), + ) + assertNull(authTokenManager.getAuthTokenDataOrNull(userId = "unknown_user_id")) + } + } + + @Nested + inner class WithoutUserId { + @Test + fun `UserState is null`() { + fakeAuthDiskSource.userState = null + assertNull(authTokenManager.getAuthTokenDataOrNull()) + } + + @Test + fun `Account tokens are null`() { + fakeAuthDiskSource.userState = SINGLE_USER_STATE.copy( + accounts = mapOf(USER_ID to ACCOUNT.copy(tokens = null)), + ) + assertNull(authTokenManager.getAuthTokenDataOrNull()) + } + + @Test + fun `Access token is null`() { + fakeAuthDiskSource.userState = SINGLE_USER_STATE.copy( + accounts = mapOf( + USER_ID to ACCOUNT.copy( + tokens = AccountTokensJson( + accessToken = null, + refreshToken = null, + ), + ), + ), + ) + assertNull(authTokenManager.getAuthTokenDataOrNull()) + } + + @Test + fun `getActiveAccessTokenOrNull should return null if user access token is null`() { + fakeAuthDiskSource.userState = SINGLE_USER_STATE + fakeAuthDiskSource.storeAccountTokens( + userId = USER_ID, + accountTokens = AccountTokensJson( + accessToken = null, + refreshToken = REFRESH_TOKEN, + expiresAtSec = EXPIRES_AT_SEC, + ), + ) + assertNull(authTokenManager.getAuthTokenDataOrNull()) + } + + @Test + fun `getActiveAccessTokenOrNull should return active user access token`() { + fakeAuthDiskSource.userState = SINGLE_USER_STATE + fakeAuthDiskSource.storeAccountTokens( + userId = USER_ID, + accountTokens = AccountTokensJson( + accessToken = ACCESS_TOKEN, + refreshToken = REFRESH_TOKEN, + expiresAtSec = EXPIRES_AT_SEC, + ), + ) + assertEquals( + AuthTokenData( + userId = USER_ID, + accessToken = ACCESS_TOKEN, + expiresAtSec = EXPIRES_AT_SEC, + ), + authTokenManager.getAuthTokenDataOrNull(), + ) + } } } diff --git a/app/src/test/kotlin/com/x8bit/bitwarden/data/platform/manager/sdk/SdkRepositoryFactoryTests.kt b/app/src/test/kotlin/com/x8bit/bitwarden/data/platform/manager/sdk/SdkRepositoryFactoryTests.kt index 311fb091be..a2d654feec 100644 --- a/app/src/test/kotlin/com/x8bit/bitwarden/data/platform/manager/sdk/SdkRepositoryFactoryTests.kt +++ b/app/src/test/kotlin/com/x8bit/bitwarden/data/platform/manager/sdk/SdkRepositoryFactoryTests.kt @@ -1,6 +1,8 @@ package com.x8bit.bitwarden.data.platform.manager.sdk +import com.bitwarden.network.BitwardenServiceClient import com.x8bit.bitwarden.data.vault.datasource.disk.VaultDiskSource +import io.mockk.every import io.mockk.mockk import org.junit.jupiter.api.Assertions.assertNotEquals import org.junit.jupiter.api.Test @@ -8,9 +10,13 @@ import org.junit.jupiter.api.Test class SdkRepositoryFactoryTests { private val vaultDiskSource: VaultDiskSource = mockk() + private val bitwardenServiceClient: BitwardenServiceClient = mockk { + every { tokenProvider } returns mockk() + } private val sdkRepoFactory: SdkRepositoryFactory = SdkRepositoryFactoryImpl( vaultDiskSource = vaultDiskSource, + bitwardenServiceClient = bitwardenServiceClient, ) @Test @@ -27,4 +33,19 @@ class SdkRepositoryFactoryTests { val thirdClient = sdkRepoFactory.getCipherRepository(userId = otherUserId) assertNotEquals(firstClient, thirdClient) } + + @Test + fun `getClientManagedTokens should create a new client`() { + val userId = "userId" + val firstClient = sdkRepoFactory.getClientManagedTokens(userId = userId) + + // Additional calls for the same userId should create a repo + val secondClient = sdkRepoFactory.getClientManagedTokens(userId = userId) + assertNotEquals(firstClient, secondClient) + + // Additional calls for different userIds should return a different repo + val otherUserId = "otherUserId" + val thirdClient = sdkRepoFactory.getClientManagedTokens(userId = otherUserId) + assertNotEquals(firstClient, thirdClient) + } } diff --git a/app/src/test/kotlin/com/x8bit/bitwarden/data/platform/manager/sdk/repository/SdkTokenRepositoryTest.kt b/app/src/test/kotlin/com/x8bit/bitwarden/data/platform/manager/sdk/repository/SdkTokenRepositoryTest.kt new file mode 100644 index 0000000000..47425e9120 --- /dev/null +++ b/app/src/test/kotlin/com/x8bit/bitwarden/data/platform/manager/sdk/repository/SdkTokenRepositoryTest.kt @@ -0,0 +1,57 @@ +package com.x8bit.bitwarden.data.platform.manager.sdk.repository + +import com.bitwarden.network.provider.TokenProvider +import io.mockk.every +import io.mockk.mockk +import io.mockk.verify +import kotlinx.coroutines.test.runTest +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertNull +import org.junit.jupiter.api.Test + +class SdkTokenRepositoryTest { + + private val tokenProvider: TokenProvider = mockk() + + @Test + fun `getAccessToken should return null when userId is null`() = runTest { + val repository = createSdkTokenRepository(userId = null) + assertNull(repository.getAccessToken()) + verify(exactly = 0) { + tokenProvider.getAccessToken(userId = any()) + } + } + + @Test + fun `getAccessToken should return null when userId is valid and tokenProvider returns null`() = + runTest { + every { tokenProvider.getAccessToken(userId = USER_ID) } returns null + val repository = createSdkTokenRepository() + assertNull(repository.getAccessToken()) + verify(exactly = 1) { + tokenProvider.getAccessToken(userId = USER_ID) + } + } + + @Suppress("MaxLineLength") + @Test + fun `getAccessToken should return access token when userId is valid and tokenProvider returns an access token`() = + runTest { + val accessToken = "access_token" + every { tokenProvider.getAccessToken(userId = USER_ID) } returns accessToken + val repository = createSdkTokenRepository() + assertEquals(accessToken, repository.getAccessToken()) + verify(exactly = 1) { + tokenProvider.getAccessToken(userId = USER_ID) + } + } + + private fun createSdkTokenRepository( + userId: String? = USER_ID, + ): SdkTokenRepository = SdkTokenRepository( + userId = userId, + tokenProvider = tokenProvider, + ) +} + +private const val USER_ID: String = "userId" diff --git a/authenticator/src/main/kotlin/com/bitwarden/authenticator/data/platform/datasource/network/di/PlatformNetworkModule.kt b/authenticator/src/main/kotlin/com/bitwarden/authenticator/data/platform/datasource/network/di/PlatformNetworkModule.kt index 4dd98c36c5..b1ea2c4d12 100644 --- a/authenticator/src/main/kotlin/com/bitwarden/authenticator/data/platform/datasource/network/di/PlatformNetworkModule.kt +++ b/authenticator/src/main/kotlin/com/bitwarden/authenticator/data/platform/datasource/network/di/PlatformNetworkModule.kt @@ -56,6 +56,7 @@ object PlatformNetworkModule { enableHttpBodyLogging = BuildConfig.DEBUG, authTokenProvider = object : AuthTokenProvider { override fun getAuthTokenDataOrNull(): AuthTokenData? = null + override fun getAuthTokenDataOrNull(userId: String): AuthTokenData? = null }, certificateProvider = object : CertificateProvider { override fun chooseClientAlias( diff --git a/authenticator/src/main/kotlin/com/bitwarden/authenticator/data/platform/manager/SdkClientManagerImpl.kt b/authenticator/src/main/kotlin/com/bitwarden/authenticator/data/platform/manager/SdkClientManagerImpl.kt index 25123b08a9..b8add7c41b 100644 --- a/authenticator/src/main/kotlin/com/bitwarden/authenticator/data/platform/manager/SdkClientManagerImpl.kt +++ b/authenticator/src/main/kotlin/com/bitwarden/authenticator/data/platform/manager/SdkClientManagerImpl.kt @@ -9,7 +9,9 @@ import com.bitwarden.sdk.Client class SdkClientManagerImpl( private val clientProvider: suspend () -> Client = { Client( - tokenProvider = Token(), + tokenProvider = object : ClientManagedTokens { + override suspend fun getAccessToken(): String? = null + }, settings = null, ) }, @@ -21,13 +23,4 @@ class SdkClientManagerImpl( override fun destroyClient() { client = null } - - /** - * The token provider to pass to the SDK. - */ - private class Token : ClientManagedTokens { - override suspend fun getAccessToken(): String? { - return null - } - } } diff --git a/network/src/main/kotlin/com/bitwarden/network/BitwardenServiceClient.kt b/network/src/main/kotlin/com/bitwarden/network/BitwardenServiceClient.kt index b407db049e..08fe218473 100644 --- a/network/src/main/kotlin/com/bitwarden/network/BitwardenServiceClient.kt +++ b/network/src/main/kotlin/com/bitwarden/network/BitwardenServiceClient.kt @@ -5,6 +5,7 @@ package com.bitwarden.network import com.bitwarden.annotation.OmitFromCoverage import com.bitwarden.network.model.BitwardenServiceClientConfig import com.bitwarden.network.provider.RefreshTokenProvider +import com.bitwarden.network.provider.TokenProvider import com.bitwarden.network.service.AccountsService import com.bitwarden.network.service.AuthRequestsService import com.bitwarden.network.service.CiphersService @@ -48,6 +49,10 @@ import com.bitwarden.network.service.SyncService * ``` */ interface BitwardenServiceClient { + /** + * Provides access to the token provider. + */ + val tokenProvider: TokenProvider /** * Provides access to the Accounts service. diff --git a/network/src/main/kotlin/com/bitwarden/network/BitwardenServiceClientImpl.kt b/network/src/main/kotlin/com/bitwarden/network/BitwardenServiceClientImpl.kt index f039e4f487..9b4bec6c9e 100644 --- a/network/src/main/kotlin/com/bitwarden/network/BitwardenServiceClientImpl.kt +++ b/network/src/main/kotlin/com/bitwarden/network/BitwardenServiceClientImpl.kt @@ -7,6 +7,7 @@ import com.bitwarden.network.interceptor.BaseUrlInterceptors import com.bitwarden.network.interceptor.HeadersInterceptor import com.bitwarden.network.model.BitwardenServiceClientConfig import com.bitwarden.network.provider.RefreshTokenProvider +import com.bitwarden.network.provider.TokenProvider import com.bitwarden.network.retrofit.Retrofits import com.bitwarden.network.retrofit.RetrofitsImpl import com.bitwarden.network.service.AccountsServiceImpl @@ -55,6 +56,7 @@ internal class BitwardenServiceClientImpl( clock = bitwardenServiceClientConfig.clock, authTokenProvider = bitwardenServiceClientConfig.authTokenProvider, ) + override val tokenProvider: TokenProvider = authTokenManager private val clientJson = Json { // If there are keys returned by the server not modeled by a serializable class, diff --git a/network/src/main/kotlin/com/bitwarden/network/interceptor/AuthTokenManager.kt b/network/src/main/kotlin/com/bitwarden/network/interceptor/AuthTokenManager.kt index 31c337ded9..990b4ea229 100644 --- a/network/src/main/kotlin/com/bitwarden/network/interceptor/AuthTokenManager.kt +++ b/network/src/main/kotlin/com/bitwarden/network/interceptor/AuthTokenManager.kt @@ -1,6 +1,10 @@ package com.bitwarden.network.interceptor +import com.bitwarden.core.data.util.asFailure +import com.bitwarden.core.data.util.asSuccess +import com.bitwarden.network.model.AuthTokenData import com.bitwarden.network.provider.RefreshTokenProvider +import com.bitwarden.network.provider.TokenProvider import com.bitwarden.network.util.HEADER_KEY_AUTHORIZATION import com.bitwarden.network.util.HEADER_VALUE_BEARER_PREFIX import com.bitwarden.network.util.parseJwtTokenDataOrNull @@ -25,70 +29,61 @@ private const val EXPIRATION_OFFSET_MINUTES: Long = 5L internal class AuthTokenManager( private val clock: Clock, private val authTokenProvider: AuthTokenProvider, -) : Authenticator, Interceptor { +) : TokenProvider, Authenticator, Interceptor { var refreshTokenProvider: RefreshTokenProvider? = null + @Synchronized + override fun getAccessToken( + userId: String, + ): String? = authTokenProvider + .getAuthTokenDataOrNull(userId = userId) + ?.let { getAccessToken(authTokenData = it).getOrNull() } + + @Synchronized override fun authenticate( route: Route?, response: Response, ): Request? { - synchronized(this) { - if (response.shouldSkipAuthentication()) { - // If the same request keeps failing, let's just let the 401 pass through. - return null + if (response.shouldSkipAuthentication()) { + // If the same request keeps failing, let's just let the 401 pass through. + return null + } + val accessToken = requireNotNull( + response + .request + .header(name = HEADER_KEY_AUTHORIZATION) + ?.substringAfter(delimiter = HEADER_VALUE_BEARER_PREFIX), + ) + return when (val userId = parseJwtTokenDataOrNull(accessToken)?.userId) { + null -> { + // We are unable to get the user ID, let's just let the 401 pass through. + null } - val accessToken = requireNotNull( - response - .request - .header(name = HEADER_KEY_AUTHORIZATION) - ?.substringAfter(delimiter = HEADER_VALUE_BEARER_PREFIX), - ) - return when (val userId = parseJwtTokenDataOrNull(accessToken)?.userId) { - null -> { - // We are unable to get the user ID, let's just let the 401 pass through. - null - } - else -> { - Timber.d("Attempting to refresh token due to unauthorized") - refreshTokenProvider - ?.refreshAccessTokenSynchronously(userId = userId) - ?.fold( - onFailure = { null }, - onSuccess = { newAccessToken -> - response - .request - .newBuilder() - .header( - name = HEADER_KEY_AUTHORIZATION, - value = "$HEADER_VALUE_BEARER_PREFIX$newAccessToken", - ) - .build() - }, - ) - } + else -> { + Timber.d("Attempting to refresh token due to unauthorized") + refreshTokenProvider + ?.refreshAccessTokenSynchronously(userId = userId) + ?.fold( + onFailure = { null }, + onSuccess = { newAccessToken -> + response + .request + .newBuilder() + .header( + name = HEADER_KEY_AUTHORIZATION, + value = "$HEADER_VALUE_BEARER_PREFIX$newAccessToken", + ) + .build() + }, + ) } } } override fun intercept(chain: Interceptor.Chain): Response { - val token = synchronized(this) { - val tokenData = authTokenProvider - .getAuthTokenDataOrNull() - ?: throw IOException(IllegalStateException(MISSING_TOKEN_MESSAGE)) - val expirationTime = Instant - .ofEpochSecond(tokenData.expiresAtSec) - .minus(EXPIRATION_OFFSET_MINUTES, ChronoUnit.MINUTES) - if (clock.instant().isAfter(expirationTime)) { - Timber.d("Attempting to refresh token due to expiration") - refreshTokenProvider - ?.refreshAccessTokenSynchronously(userId = tokenData.userId) - ?.getOrElse { throw IOException(it) } - ?: throw IOException(IllegalStateException(MISSING_PROVIDER_MESSAGE)) - } else { - tokenData.accessToken - } - } + val token = getAccessToken() + ?: throw IOException(IllegalStateException(MISSING_TOKEN_MESSAGE)) val request = chain .request() .newBuilder() @@ -100,5 +95,25 @@ internal class AuthTokenManager( return chain.proceed(request) } + @Synchronized + private fun getAccessToken(): String? = authTokenProvider + .getAuthTokenDataOrNull() + ?.let { getAccessToken(authTokenData = it).getOrThrow() } + + @Synchronized + private fun getAccessToken(authTokenData: AuthTokenData): Result { + val expirationTime = Instant + .ofEpochSecond(authTokenData.expiresAtSec) + .minus(EXPIRATION_OFFSET_MINUTES, ChronoUnit.MINUTES) + return if (clock.instant().isAfter(expirationTime)) { + Timber.d("Attempting to refresh token due to expiration") + refreshTokenProvider + ?.refreshAccessTokenSynchronously(userId = authTokenData.userId) + ?: IOException(IllegalStateException(MISSING_PROVIDER_MESSAGE)).asFailure() + } else { + authTokenData.accessToken.asSuccess() + } + } + private fun Response.shouldSkipAuthentication(): Boolean = this.priorResponse != null } diff --git a/network/src/main/kotlin/com/bitwarden/network/interceptor/AuthTokenProvider.kt b/network/src/main/kotlin/com/bitwarden/network/interceptor/AuthTokenProvider.kt index 9ab30058bd..6d07e499c2 100644 --- a/network/src/main/kotlin/com/bitwarden/network/interceptor/AuthTokenProvider.kt +++ b/network/src/main/kotlin/com/bitwarden/network/interceptor/AuthTokenProvider.kt @@ -6,6 +6,10 @@ import com.bitwarden.network.model.AuthTokenData * A provider for all the functionality needed to properly refresh the users access token. */ interface AuthTokenProvider { + /** + * The specified user's auth token data. + */ + fun getAuthTokenDataOrNull(userId: String): AuthTokenData? /** * The currently active user's auth token data. diff --git a/network/src/main/kotlin/com/bitwarden/network/provider/TokenProvider.kt b/network/src/main/kotlin/com/bitwarden/network/provider/TokenProvider.kt new file mode 100644 index 0000000000..0ffe532dbe --- /dev/null +++ b/network/src/main/kotlin/com/bitwarden/network/provider/TokenProvider.kt @@ -0,0 +1,11 @@ +package com.bitwarden.network.provider + +/** + * A provider for authentication tokens. + */ +interface TokenProvider { + /** + * Retrieves an up-to-date token for the specified user. + */ + fun getAccessToken(userId: String): String? +} diff --git a/network/src/test/kotlin/com/bitwarden/network/interceptor/AuthTokenManagerTest.kt b/network/src/test/kotlin/com/bitwarden/network/interceptor/AuthTokenManagerTest.kt index 2a621ac140..c2acf07a67 100644 --- a/network/src/test/kotlin/com/bitwarden/network/interceptor/AuthTokenManagerTest.kt +++ b/network/src/test/kotlin/com/bitwarden/network/interceptor/AuthTokenManagerTest.kt @@ -51,6 +51,50 @@ class AuthTokenManagerTest { unmockkStatic(::parseJwtTokenDataOrNull) } + @Nested + inner class TokenProvider { + @Test + fun `returns null if token provider has no auth data for user ID`() { + val userId = "userId" + every { mockAuthTokenProvider.getAuthTokenDataOrNull(userId = userId) } returns null + val result = authTokenManager.getAccessToken(userId = userId) + assertNull(result) + } + + @Test + fun `returns null if refresh fails`() { + val userId = "userId" + val authData = AuthTokenData( + userId = userId, + accessToken = ACCESS_TOKEN, + expiresAtSec = FIXED_CLOCK.instant().epochSecond, + ) + every { mockAuthTokenProvider.getAuthTokenDataOrNull(userId = userId) } returns authData + every { + refreshTokenProvider.refreshAccessTokenSynchronously(userId = userId) + } returns Throwable("Fail!").asFailure() + val result = authTokenManager.getAccessToken(userId = userId) + assertNull(result) + } + + @Test + fun `returns access token if refresh is not required`() { + val userId = "userId" + val authData = AuthTokenData( + userId = userId, + accessToken = ACCESS_TOKEN, + expiresAtSec = 0L, + ) + val refreshedAccessToken = "refreshed_access_token" + every { mockAuthTokenProvider.getAuthTokenDataOrNull(userId = userId) } returns authData + every { + refreshTokenProvider.refreshAccessTokenSynchronously(userId = userId) + } returns refreshedAccessToken.asSuccess() + val result = authTokenManager.getAccessToken(userId = userId) + assertEquals(refreshedAccessToken, result) + } + } + @Nested inner class Authenticator { @Test @@ -158,7 +202,7 @@ class AuthTokenManagerTest { authTokenManager.refreshTokenProvider = object : RefreshTokenProvider { override fun refreshAccessTokenSynchronously( userId: String, - ): Result = Throwable(errorMessage).asFailure() + ): Result = IOException(errorMessage).asFailure() } val authTokenData = AuthTokenData( userId = USER_ID, @@ -172,7 +216,7 @@ class AuthTokenManagerTest { chain = FakeInterceptorChain(request = request), ) } - assertEquals(errorMessage, throwable.cause?.message) + assertEquals(errorMessage, throwable?.message) } @Suppress("MaxLineLength")