diff --git a/app/src/main/java/com/x8bit/bitwarden/data/auth/repository/AuthRepository.kt b/app/src/main/java/com/x8bit/bitwarden/data/auth/repository/AuthRepository.kt index d46170d100..90c3ab608a 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/auth/repository/AuthRepository.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/auth/repository/AuthRepository.kt @@ -1,18 +1,19 @@ package com.x8bit.bitwarden.data.auth.repository import com.x8bit.bitwarden.data.auth.datasource.sdk.model.PasswordStrength -import com.x8bit.bitwarden.data.auth.repository.model.UserState import com.x8bit.bitwarden.data.auth.repository.model.AuthState import com.x8bit.bitwarden.data.auth.repository.model.LoginResult import com.x8bit.bitwarden.data.auth.repository.model.RegisterResult +import com.x8bit.bitwarden.data.auth.repository.model.UserState import com.x8bit.bitwarden.data.auth.repository.util.CaptchaCallbackTokenResult +import com.x8bit.bitwarden.data.platform.datasource.network.authenticator.AuthenticatorProvider import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.StateFlow /** * Provides an API for observing an modifying authentication state. */ -interface AuthRepository { +interface AuthRepository : AuthenticatorProvider { /** * Models the current auth state. */ diff --git a/app/src/main/java/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryImpl.kt b/app/src/main/java/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryImpl.kt index be5858012c..9e10786e5c 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryImpl.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryImpl.kt @@ -5,6 +5,7 @@ import com.x8bit.bitwarden.data.auth.datasource.disk.AuthDiskSource import com.x8bit.bitwarden.data.auth.datasource.network.model.GetTokenResponseJson import com.x8bit.bitwarden.data.auth.datasource.network.model.GetTokenResponseJson.CaptchaRequired import com.x8bit.bitwarden.data.auth.datasource.network.model.GetTokenResponseJson.Success +import com.x8bit.bitwarden.data.auth.datasource.network.model.RefreshTokenResponseJson import com.x8bit.bitwarden.data.auth.datasource.network.model.RegisterRequestJson import com.x8bit.bitwarden.data.auth.datasource.network.model.RegisterResponseJson import com.x8bit.bitwarden.data.auth.datasource.network.service.AccountsService @@ -20,6 +21,7 @@ import com.x8bit.bitwarden.data.auth.repository.model.UserState import com.x8bit.bitwarden.data.auth.repository.util.CaptchaCallbackTokenResult import com.x8bit.bitwarden.data.auth.repository.util.toSdkParams import com.x8bit.bitwarden.data.auth.repository.util.toUserState +import com.x8bit.bitwarden.data.auth.repository.util.toUserStateJson import com.x8bit.bitwarden.data.auth.util.KdfParamsConstants.DEFAULT_PBKDF2_ITERATIONS import com.x8bit.bitwarden.data.auth.util.toSdkParams import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager @@ -56,12 +58,13 @@ class AuthRepositoryImpl constructor( ) : AuthRepository { private val scope = CoroutineScope(dispatcherManager.io) + override val activeUserId: String? get() = authDiskSource.userState?.activeUserId + override val authStateFlow: StateFlow = authDiskSource .userStateFlow .map { userState -> userState ?.let { - @Suppress("UnsafeCallOnNullableType") AuthState.Authenticated( userState .activeAccount @@ -179,21 +182,42 @@ class AuthRepositoryImpl constructor( }, ) - override fun logout() { - val currentUserState = authDiskSource.userState ?: return + override fun refreshAccessTokenSynchronously(userId: String): Result { + val refreshAccount = authDiskSource.userState?.accounts?.get(userId) + ?: return IllegalStateException("Must be logged in.").asFailure() + return identityService + .refreshTokenSynchronously(refreshAccount.tokens.refreshToken) + .onSuccess { + // Update the existing UserState with updated token information + authDiskSource.userState = it.toUserStateJson( + userId = userId, + previousUserState = requireNotNull(authDiskSource.userState), + ) + } + } - val activeUserId = currentUserState.activeUserId + override fun logout() { + activeUserId?.let { userId -> logout(userId) } + } + + override fun logout(userId: String) { + val currentUserState = authDiskSource.userState ?: return // Remove the active user from the accounts map val updatedAccounts = currentUserState .accounts - .filterKeys { it != activeUserId } - authDiskSource.storeUserKey(userId = activeUserId, userKey = null) - authDiskSource.storePrivateKey(userId = activeUserId, privateKey = null) + .filterKeys { it != userId } + authDiskSource.storeUserKey(userId = userId, userKey = null) + authDiskSource.storePrivateKey(userId = userId, privateKey = null) // Check if there is a new active user if (updatedAccounts.isNotEmpty()) { - val (updatedActiveUserId, updatedActiveAccount) = - updatedAccounts.entries.first() + // If we logged out a non-active user, we want to leave the active user unchanged. + // If we logged out the active user, we want to set the active user to the first one + // in the list. + val updatedActiveUserId = currentUserState + .activeUserId + .takeUnless { it == userId } + ?: updatedAccounts.entries.first().key // Update the user information and emit an updated token authDiskSource.userState = currentUserState.copy( diff --git a/app/src/main/java/com/x8bit/bitwarden/data/auth/repository/util/RefreshTokenResponseExtensions.kt b/app/src/main/java/com/x8bit/bitwarden/data/auth/repository/util/RefreshTokenResponseExtensions.kt new file mode 100644 index 0000000000..d73efbd1ae --- /dev/null +++ b/app/src/main/java/com/x8bit/bitwarden/data/auth/repository/util/RefreshTokenResponseExtensions.kt @@ -0,0 +1,45 @@ +package com.x8bit.bitwarden.data.auth.repository.util + +import com.x8bit.bitwarden.data.auth.datasource.disk.model.AccountJson +import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson +import com.x8bit.bitwarden.data.auth.datasource.network.model.RefreshTokenResponseJson + +/** + * Converts the given [RefreshTokenResponseJson] to a [UserStateJson], given the following + * additional information: + * + * - the [userId] + * - the [previousUserState] + */ +fun RefreshTokenResponseJson.toUserStateJson( + userId: String, + previousUserState: UserStateJson, +): UserStateJson { + val refreshedAccount = requireNotNull(previousUserState.accounts[userId]) + val accessToken = this.accessToken + val jwtTokenData = requireNotNull(parseJwtTokenDataOrNull(jwtToken = accessToken)) + + val account = refreshedAccount.copy( + profile = refreshedAccount.profile.copy( + userId = jwtTokenData.userId, + email = jwtTokenData.email, + isEmailVerified = jwtTokenData.isEmailVerified, + name = jwtTokenData.name, + hasPremium = jwtTokenData.hasPremium, + ), + tokens = AccountJson.Tokens( + accessToken = accessToken, + refreshToken = this.refreshToken, + ), + ) + + // Update the existing UserState. + return previousUserState.copy( + accounts = previousUserState + .accounts + .toMutableMap() + .apply { + put(userId, account) + }, + ) +} diff --git a/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/authenticator/AuthenticatorProvider.kt b/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/authenticator/AuthenticatorProvider.kt new file mode 100644 index 0000000000..a0473dddf0 --- /dev/null +++ b/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/authenticator/AuthenticatorProvider.kt @@ -0,0 +1,27 @@ +package com.x8bit.bitwarden.data.platform.datasource.network.authenticator + +import com.x8bit.bitwarden.data.auth.datasource.network.model.RefreshTokenResponseJson + +/** + * A provider for all the functionality needed to properly refresh the users access token. + */ +interface AuthenticatorProvider { + + /** + * The currently active user's ID. + */ + val activeUserId: String? + + /** + * Attempts to logout the user based on the [userId]. + */ + fun logout(userId: String) + + /** + * Attempt to refresh the user's access token based on the [userId]. + * + * This call is both synchronous and performs a network request. Make sure that you are calling + * from an appropriate thread. + */ + fun refreshAccessTokenSynchronously(userId: String): Result +} diff --git a/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/authenticator/RefreshAuthenticator.kt b/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/authenticator/RefreshAuthenticator.kt new file mode 100644 index 0000000000..e2722e766f --- /dev/null +++ b/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/authenticator/RefreshAuthenticator.kt @@ -0,0 +1,69 @@ +package com.x8bit.bitwarden.data.platform.datasource.network.authenticator + +import com.x8bit.bitwarden.data.auth.repository.util.parseJwtTokenDataOrNull +import com.x8bit.bitwarden.data.platform.datasource.network.util.HEADER_KEY_AUTHORIZATION +import com.x8bit.bitwarden.data.platform.datasource.network.util.HEADER_VALUE_BEARER_PREFIX +import okhttp3.Authenticator +import okhttp3.Request +import okhttp3.Response +import okhttp3.Route +import javax.inject.Singleton + +/** + * An authenticator used to refresh the access token when a 401 is returned from an API. Upon + * successfully getting a new access token, the original request is retried. + */ +@Singleton +class RefreshAuthenticator : Authenticator { + + /** + * A provider required to update tokens. + */ + var authenticatorProvider: AuthenticatorProvider? = null + + override fun authenticate( + route: Route?, + response: Response, + ): Request? { + val accessToken = requireNotNull( + response + .request + .header(name = HEADER_KEY_AUTHORIZATION) + ?.substringAfter(delimiter = HEADER_VALUE_BEARER_PREFIX), + ) + return when (val userId = parseJwtTokenDataOrNull(accessToken)?.userId) { + null -> { + // We unable to get the user ID, let's just let the 401 pass through. + null + } + + authenticatorProvider?.activeUserId -> { + // In order to prevent potential deadlocks or thread starvation we want the call + // to refresh the access token to be strictly synchronous with no internal thread + // hopping. + authenticatorProvider + ?.refreshAccessTokenSynchronously(userId) + ?.fold( + onFailure = { + authenticatorProvider?.logout(userId) + null + }, + onSuccess = { + response.request + .newBuilder() + .header( + name = HEADER_KEY_AUTHORIZATION, + value = "$HEADER_VALUE_BEARER_PREFIX${it.accessToken}", + ) + .build() + }, + ) + } + + else -> { + // We are no longer the active user, let's just cancel. + null + } + } + } +} diff --git a/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/di/PlatformNetworkModule.kt b/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/di/PlatformNetworkModule.kt index cd87085b3a..cef1e287af 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/di/PlatformNetworkModule.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/di/PlatformNetworkModule.kt @@ -1,5 +1,6 @@ package com.x8bit.bitwarden.data.platform.datasource.network.di +import com.x8bit.bitwarden.data.platform.datasource.network.authenticator.RefreshAuthenticator import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.AuthTokenInterceptor import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.BaseUrlInterceptors import com.x8bit.bitwarden.data.platform.datasource.network.retrofit.Retrofits @@ -36,16 +37,22 @@ object PlatformNetworkModule { @Singleton fun providesAuthTokenInterceptor(): AuthTokenInterceptor = AuthTokenInterceptor() + @Provides + @Singleton + fun providesRefreshAuthenticator(): RefreshAuthenticator = RefreshAuthenticator() + @Provides @Singleton fun provideRetrofits( authTokenInterceptor: AuthTokenInterceptor, baseUrlInterceptors: BaseUrlInterceptors, + refreshAuthenticator: RefreshAuthenticator, json: Json, ): Retrofits = RetrofitsImpl( authTokenInterceptor = authTokenInterceptor, baseUrlInterceptors = baseUrlInterceptors, + refreshAuthenticator = refreshAuthenticator, json = json, ) diff --git a/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/interceptor/AuthTokenInterceptor.kt b/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/interceptor/AuthTokenInterceptor.kt index c0e0630106..2bb2805f98 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/interceptor/AuthTokenInterceptor.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/interceptor/AuthTokenInterceptor.kt @@ -1,5 +1,7 @@ package com.x8bit.bitwarden.data.platform.datasource.network.interceptor +import com.x8bit.bitwarden.data.platform.datasource.network.util.HEADER_KEY_AUTHORIZATION +import com.x8bit.bitwarden.data.platform.datasource.network.util.HEADER_VALUE_BEARER_PREFIX import okhttp3.Interceptor import okhttp3.Response import java.io.IOException @@ -22,7 +24,10 @@ class AuthTokenInterceptor : Interceptor { val request = chain .request() .newBuilder() - .addHeader("Authorization", "Bearer $token") + .addHeader( + name = HEADER_KEY_AUTHORIZATION, + value = "$HEADER_VALUE_BEARER_PREFIX$token", + ) .build() return chain .proceed(request) diff --git a/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/retrofit/RetrofitsImpl.kt b/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/retrofit/RetrofitsImpl.kt index 2926b581a1..93e8026e34 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/retrofit/RetrofitsImpl.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/retrofit/RetrofitsImpl.kt @@ -1,6 +1,7 @@ package com.x8bit.bitwarden.data.platform.datasource.network.retrofit import com.jakewharton.retrofit2.converter.kotlinx.serialization.asConverterFactory +import com.x8bit.bitwarden.data.platform.datasource.network.authenticator.RefreshAuthenticator import com.x8bit.bitwarden.data.platform.datasource.network.core.ResultCallAdapterFactory import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.AuthTokenInterceptor import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.BaseUrlInterceptor @@ -17,6 +18,7 @@ import retrofit2.Retrofit class RetrofitsImpl( authTokenInterceptor: AuthTokenInterceptor, baseUrlInterceptors: BaseUrlInterceptors, + refreshAuthenticator: RefreshAuthenticator, json: Json, ) : Retrofits { //region Authenticated Retrofits @@ -73,6 +75,7 @@ class RetrofitsImpl( private val authenticatedOkHttpClient: OkHttpClient by lazy { baseOkHttpClient .newBuilder() + .authenticator(refreshAuthenticator) .addInterceptor(authTokenInterceptor) .build() } diff --git a/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/util/HeaderUtils.kt b/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/util/HeaderUtils.kt new file mode 100644 index 0000000000..92816d65f6 --- /dev/null +++ b/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/util/HeaderUtils.kt @@ -0,0 +1,11 @@ +package com.x8bit.bitwarden.data.platform.datasource.network.util + +/** + * The bearer prefix used for the 'authorization' headers value. + */ +const val HEADER_VALUE_BEARER_PREFIX: String = "Bearer " + +/** + * The key used for the 'authorization' headers. + */ +const val HEADER_KEY_AUTHORIZATION: String = "Authorization" diff --git a/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/NetworkConfigManagerImpl.kt b/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/NetworkConfigManagerImpl.kt index a6feb56ef1..868f205a4d 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/NetworkConfigManagerImpl.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/NetworkConfigManagerImpl.kt @@ -2,6 +2,7 @@ package com.x8bit.bitwarden.data.platform.manager import com.x8bit.bitwarden.data.auth.repository.AuthRepository import com.x8bit.bitwarden.data.auth.repository.model.AuthState +import com.x8bit.bitwarden.data.platform.datasource.network.authenticator.RefreshAuthenticator import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.AuthTokenInterceptor import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.BaseUrlInterceptors import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager @@ -18,6 +19,7 @@ class NetworkConfigManagerImpl( private val authTokenInterceptor: AuthTokenInterceptor, private val environmentRepository: EnvironmentRepository, private val baseUrlInterceptors: BaseUrlInterceptors, + refreshAuthenticator: RefreshAuthenticator, dispatcherManager: DispatcherManager, ) : NetworkConfigManager { @@ -41,5 +43,7 @@ class NetworkConfigManagerImpl( baseUrlInterceptors.environment = environment } .launchIn(scope) + + refreshAuthenticator.authenticatorProvider = authRepository } } diff --git a/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/di/PlatformManagerModule.kt b/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/di/PlatformManagerModule.kt index 0bc9f4949d..7568b21af7 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/di/PlatformManagerModule.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/di/PlatformManagerModule.kt @@ -1,6 +1,7 @@ package com.x8bit.bitwarden.data.platform.manager.di import com.x8bit.bitwarden.data.auth.repository.AuthRepository +import com.x8bit.bitwarden.data.platform.datasource.network.authenticator.RefreshAuthenticator import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.AuthTokenInterceptor import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.BaseUrlInterceptors import com.x8bit.bitwarden.data.platform.manager.NetworkConfigManager @@ -25,6 +26,7 @@ object PlatformManagerModule { @Singleton fun provideBitwardenDispatchers(): DispatcherManager = DispatcherManagerImpl() + @Suppress("LongParameterList") @Provides @Singleton fun provideNetworkConfigManager( @@ -32,6 +34,7 @@ object PlatformManagerModule { authTokenInterceptor: AuthTokenInterceptor, environmentRepository: EnvironmentRepository, baseUrlInterceptors: BaseUrlInterceptors, + refreshAuthenticator: RefreshAuthenticator, dispatcherManager: DispatcherManager, ): NetworkConfigManager = NetworkConfigManagerImpl( @@ -39,6 +42,7 @@ object PlatformManagerModule { authTokenInterceptor = authTokenInterceptor, environmentRepository = environmentRepository, baseUrlInterceptors = baseUrlInterceptors, + refreshAuthenticator = refreshAuthenticator, dispatcherManager = dispatcherManager, ) } diff --git a/app/src/test/java/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryTest.kt b/app/src/test/java/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryTest.kt index e827bddf17..507abc6389 100644 --- a/app/src/test/java/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryTest.kt +++ b/app/src/test/java/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryTest.kt @@ -12,6 +12,7 @@ import com.x8bit.bitwarden.data.auth.datasource.network.model.GetTokenResponseJs import com.x8bit.bitwarden.data.auth.datasource.network.model.KdfTypeJson import com.x8bit.bitwarden.data.auth.datasource.network.model.KdfTypeJson.PBKDF2_SHA256 import com.x8bit.bitwarden.data.auth.datasource.network.model.PreLoginResponseJson +import com.x8bit.bitwarden.data.auth.datasource.network.model.RefreshTokenResponseJson import com.x8bit.bitwarden.data.auth.datasource.network.model.RegisterRequestJson import com.x8bit.bitwarden.data.auth.datasource.network.model.RegisterResponseJson import com.x8bit.bitwarden.data.auth.datasource.network.service.AccountsService @@ -29,6 +30,7 @@ import com.x8bit.bitwarden.data.auth.repository.model.RegisterResult import com.x8bit.bitwarden.data.auth.repository.util.CaptchaCallbackTokenResult import com.x8bit.bitwarden.data.auth.repository.util.toSdkParams import com.x8bit.bitwarden.data.auth.repository.util.toUserState +import com.x8bit.bitwarden.data.auth.repository.util.toUserStateJson import com.x8bit.bitwarden.data.auth.util.toSdkParams import com.x8bit.bitwarden.data.platform.base.FakeDispatcherManager import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager @@ -114,12 +116,18 @@ class AuthRepositoryTest { @BeforeEach fun beforeEach() { clearMocks(identityService, accountsService, haveIBeenPwnedService) - mockkStatic(GET_TOKEN_RESPONSE_EXTENSIONS_PATH) + mockkStatic( + GET_TOKEN_RESPONSE_EXTENSIONS_PATH, + REFRESH_TOKEN_RESPONSE_EXTENSIONS_PATH, + ) } @AfterEach fun tearDown() { - unmockkStatic(GET_TOKEN_RESPONSE_EXTENSIONS_PATH) + unmockkStatic( + GET_TOKEN_RESPONSE_EXTENSIONS_PATH, + REFRESH_TOKEN_RESPONSE_EXTENSIONS_PATH, + ) } @Test @@ -240,6 +248,54 @@ class AuthRepositoryTest { } } + @Test + fun `refreshTokenSynchronously returns failure if not logged in`() = runTest { + fakeAuthDiskSource.userState = null + + val result = repository.refreshAccessTokenSynchronously(USER_ID_1) + + assertTrue(result.isFailure) + } + + @Test + fun `refreshTokenSynchronously returns failure and logs out on failure`() = runTest { + fakeAuthDiskSource.userState = SINGLE_USER_STATE_1 + coEvery { + identityService.refreshTokenSynchronously(REFRESH_TOKEN) + } returns Throwable("Fail").asFailure() + + assertTrue(repository.refreshAccessTokenSynchronously(USER_ID_1).isFailure) + + coVerify(exactly = 1) { + identityService.refreshTokenSynchronously(REFRESH_TOKEN) + } + } + + @Test + fun `refreshTokenSynchronously returns success and update user state on success`() = runTest { + fakeAuthDiskSource.userState = SINGLE_USER_STATE_1 + coEvery { + identityService.refreshTokenSynchronously(REFRESH_TOKEN) + } returns REFRESH_TOKEN_RESPONSE_JSON.asSuccess() + every { + REFRESH_TOKEN_RESPONSE_JSON.toUserStateJson( + userId = USER_ID_1, + previousUserState = SINGLE_USER_STATE_1, + ) + } returns SINGLE_USER_STATE_1 + + val result = repository.refreshAccessTokenSynchronously(USER_ID_1) + + assertEquals(REFRESH_TOKEN_RESPONSE_JSON.asSuccess(), result) + coVerify(exactly = 1) { + identityService.refreshTokenSynchronously(REFRESH_TOKEN) + REFRESH_TOKEN_RESPONSE_JSON.toUserStateJson( + userId = USER_ID_1, + previousUserState = SINGLE_USER_STATE_1, + ) + } + } + @Test fun `login when pre login fails should return Error with no message`() = runTest { coEvery { @@ -854,6 +910,36 @@ class AuthRepositoryTest { } } + @Test + fun `logout for non-active accounts should leave the active user unchanged`() = runTest { + // First populate multiple user accounts and active user is #3 + val initialUserState = MULTI_USER_STATE_2 + val finalUserState = initialUserState.copy( + accounts = initialUserState.accounts.filter { it.key != USER_ID_2 }, + ) + fakeAuthDiskSource.userState = initialUserState + + assertEquals(initialUserState, fakeAuthDiskSource.userState) + + repository.authStateFlow.test { + assertEquals(AuthState.Authenticated(ACCESS_TOKEN_3), awaitItem()) + + repository.logout(USER_ID_2) + + // The auth state does not actually change + expectNoEvents() + assertEquals(finalUserState, fakeAuthDiskSource.userState) + fakeAuthDiskSource.assertPrivateKey( + userId = USER_ID_2, + privateKey = null, + ) + fakeAuthDiskSource.assertUserKey( + userId = USER_ID_2, + userKey = null, + ) + } + } + @Test fun `getPasswordStrength should be based on password length`() = runTest { // TODO: Replace with SDK call (BIT-964) @@ -878,11 +964,17 @@ class AuthRepositoryTest { companion object { private const val GET_TOKEN_RESPONSE_EXTENSIONS_PATH = "com.x8bit.bitwarden.data.auth.repository.util.GetTokenResponseExtensionsKt" + private const val REFRESH_TOKEN_RESPONSE_EXTENSIONS_PATH = + "com.x8bit.bitwarden.data.auth.repository.util.RefreshTokenResponseExtensionsKt" private const val EMAIL = "test@bitwarden.com" + private const val EMAIL_2 = "test2@bitwarden.com" private const val PASSWORD = "password" private const val PASSWORD_HASH = "passwordHash" private const val ACCESS_TOKEN = "accessToken" private const val ACCESS_TOKEN_2 = "accessToken2" + private const val ACCESS_TOKEN_3 = "accessToken3" + private const val REFRESH_TOKEN = "refreshToken" + private const val REFRESH_TOKEN_2 = "refreshToken2" private const val CAPTCHA_KEY = "captcha" private const val DEFAULT_KDF_ITERATIONS = 600000 private const val ENCRYPTED_USER_KEY = "encryptedUserKey" @@ -890,9 +982,16 @@ class AuthRepositoryTest { private const val PRIVATE_KEY = "privateKey" private const val USER_ID_1 = "2a135b23-e1fb-42c9-bec3-573857bc8181" private const val USER_ID_2 = "b9d32ec0-6497-4582-9798-b350f53bfa02" + private const val USER_ID_3 = "3816ef34-0747-4133-9b7a-ba35d3768a68" private val PRE_LOGIN_SUCCESS = PreLoginResponseJson( kdfParams = PreLoginResponseJson.KdfParams.Pbkdf2(iterations = 1u), ) + private val REFRESH_TOKEN_RESPONSE_JSON = RefreshTokenResponseJson( + accessToken = ACCESS_TOKEN_2, + expiresIn = 3600, + refreshToken = REFRESH_TOKEN_2, + tokenType = "Bearer", + ) private val GET_TOKEN_RESPONSE_SUCCESS = GetTokenResponseJson.Success( accessToken = ACCESS_TOKEN, refreshToken = "refreshToken", @@ -928,7 +1027,7 @@ class AuthRepositoryTest { ), tokens = AccountJson.Tokens( accessToken = ACCESS_TOKEN, - refreshToken = "refreshToken", + refreshToken = REFRESH_TOKEN, ), settings = AccountJson.Settings( environmentUrlData = null, @@ -937,7 +1036,7 @@ class AuthRepositoryTest { private val ACCOUNT_2 = AccountJson( profile = AccountJson.Profile( userId = USER_ID_2, - email = "test2@bitwarden.com", + email = EMAIL_2, isEmailVerified = true, name = "Bitwarden Tester 2", hasPremium = false, @@ -959,6 +1058,31 @@ class AuthRepositoryTest { environmentUrlData = null, ), ) + private val ACCOUNT_3 = AccountJson( + profile = AccountJson.Profile( + userId = USER_ID_3, + email = "test3@bitwarden.com", + isEmailVerified = true, + name = "Bitwarden Tester 3", + hasPremium = false, + stamp = null, + organizationId = null, + avatarColorHex = null, + forcePasswordResetReason = null, + kdfType = KdfTypeJson.PBKDF2_SHA256, + kdfIterations = 400000, + kdfMemory = null, + kdfParallelism = null, + userDecryptionOptions = null, + ), + tokens = AccountJson.Tokens( + accessToken = ACCESS_TOKEN_3, + refreshToken = "refreshToken", + ), + settings = AccountJson.Settings( + environmentUrlData = null, + ), + ) private val SINGLE_USER_STATE_1 = UserStateJson( activeUserId = USER_ID_1, accounts = mapOf( @@ -978,6 +1102,14 @@ class AuthRepositoryTest { USER_ID_2 to ACCOUNT_2, ), ) + private val MULTI_USER_STATE_2 = UserStateJson( + activeUserId = USER_ID_3, + accounts = mapOf( + USER_ID_1 to ACCOUNT_1, + USER_ID_2 to ACCOUNT_2, + USER_ID_3 to ACCOUNT_3, + ), + ) private val VAULT_STATE = VaultState( unlockedVaultUserIds = setOf(USER_ID_1), ) diff --git a/app/src/test/java/com/x8bit/bitwarden/data/auth/repository/util/RefreshTokenResponseJsonTest.kt b/app/src/test/java/com/x8bit/bitwarden/data/auth/repository/util/RefreshTokenResponseJsonTest.kt new file mode 100644 index 0000000000..25a357d3d9 --- /dev/null +++ b/app/src/test/java/com/x8bit/bitwarden/data/auth/repository/util/RefreshTokenResponseJsonTest.kt @@ -0,0 +1,177 @@ +package com.x8bit.bitwarden.data.auth.repository.util + +import com.x8bit.bitwarden.data.auth.datasource.disk.model.AccountJson +import com.x8bit.bitwarden.data.auth.datasource.disk.model.EnvironmentUrlDataJson +import com.x8bit.bitwarden.data.auth.datasource.disk.model.UserStateJson +import com.x8bit.bitwarden.data.auth.datasource.network.model.KdfTypeJson +import com.x8bit.bitwarden.data.auth.datasource.network.model.RefreshTokenResponseJson +import com.x8bit.bitwarden.data.auth.repository.model.JwtTokenDataJson +import io.mockk.every +import io.mockk.mockkStatic +import io.mockk.unmockkStatic +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test + +class RefreshTokenResponseJsonTest { + + @BeforeEach + fun beforeEach() { + mockkStatic(JWT_TOKEN_UTILS_PATH) + } + + @AfterEach + fun tearDown() { + unmockkStatic(JWT_TOKEN_UTILS_PATH) + } + + @Test + fun `toUserState updates the previous state`() { + every { parseJwtTokenDataOrNull(ACCESS_TOKEN_UPDATED) } returns JWT_TOKEN_DATA + + assertEquals( + SINGLE_USER_STATE_UPDATED, + REFRESH_TOKEN_RESPONSE.toUserStateJson( + userId = USER_ID_1, + previousUserState = SINGLE_USER_STATE, + ), + ) + } + + @Test + fun `toUserState updates the previous state for non-active user`() { + every { parseJwtTokenDataOrNull(ACCESS_TOKEN_UPDATED) } returns JWT_TOKEN_DATA + + assertEquals( + MULTI_USER_STATE_UPDATED, + REFRESH_TOKEN_RESPONSE.toUserStateJson( + userId = USER_ID_1, + previousUserState = MULTI_USER_STATE, + ), + ) + } +} + +private const val ACCESS_TOKEN = "accessToken" +private const val ACCESS_TOKEN_UPDATED = "updatedAccessToken" +private const val REFRESH_TOKEN = "refreshToken" +private const val REFRESH_TOKEN_UPDATED = "updatedRefreshToken" +private const val USER_ID_1 = "2a135b23-e1fb-42c9-bec3-573857bc8181" +private const val USER_ID_2 = "b9d32ec0-6497-4582-9798-b350f53bfa02" + +private const val JWT_TOKEN_UTILS_PATH = + "com.x8bit.bitwarden.data.auth.repository.util.JwtTokenUtilsKt" + +private val JWT_TOKEN_DATA = JwtTokenDataJson( + userId = USER_ID_1, + email = "updated@bitwarden.com", + isEmailVerified = false, + name = "Updated Bitwarden Tester", + expirationAsEpochTime = 1697495714, + hasPremium = true, + authenticationMethodsReference = listOf("Application"), +) + +private val REFRESH_TOKEN_RESPONSE = RefreshTokenResponseJson( + accessToken = ACCESS_TOKEN_UPDATED, + expiresIn = 3600, + refreshToken = REFRESH_TOKEN_UPDATED, + tokenType = "Bearer", +) + +private val ACCOUNT_1 = AccountJson( + profile = AccountJson.Profile( + userId = USER_ID_1, + email = "test@bitwarden.com", + isEmailVerified = true, + name = "Bitwarden Tester", + hasPremium = false, + stamp = null, + organizationId = null, + avatarColorHex = null, + forcePasswordResetReason = null, + kdfType = KdfTypeJson.ARGON2_ID, + kdfIterations = 600000, + kdfMemory = 16, + kdfParallelism = 4, + userDecryptionOptions = null, + ), + tokens = AccountJson.Tokens( + accessToken = ACCESS_TOKEN, + refreshToken = REFRESH_TOKEN, + ), + settings = AccountJson.Settings( + environmentUrlData = EnvironmentUrlDataJson.DEFAULT_US, + ), +) + +private val ACCOUNT_1_UPDATED = ACCOUNT_1.copy( + profile = ACCOUNT_1.profile.copy( + userId = JWT_TOKEN_DATA.userId, + email = JWT_TOKEN_DATA.email, + isEmailVerified = JWT_TOKEN_DATA.isEmailVerified, + name = JWT_TOKEN_DATA.name, + hasPremium = JWT_TOKEN_DATA.hasPremium, + ), + tokens = AccountJson.Tokens( + accessToken = ACCESS_TOKEN_UPDATED, + refreshToken = REFRESH_TOKEN_UPDATED, + ), +) + +private val ACCOUNT_2 = AccountJson( + profile = AccountJson.Profile( + userId = USER_ID_2, + email = "test2@bitwarden.com", + isEmailVerified = true, + name = "Bitwarden Tester 2", + hasPremium = false, + stamp = null, + organizationId = null, + avatarColorHex = null, + forcePasswordResetReason = null, + kdfType = KdfTypeJson.PBKDF2_SHA256, + kdfIterations = 400000, + kdfMemory = null, + kdfParallelism = null, + userDecryptionOptions = null, + ), + tokens = AccountJson.Tokens( + accessToken = "accessToken2", + refreshToken = "refreshToken2", + ), + settings = AccountJson.Settings( + environmentUrlData = EnvironmentUrlDataJson.DEFAULT_US, + ), +) + +private val SINGLE_USER_STATE = UserStateJson( + activeUserId = USER_ID_1, + accounts = mapOf( + USER_ID_1 to ACCOUNT_1, + ), +) + +private val SINGLE_USER_STATE_UPDATED = UserStateJson( + activeUserId = USER_ID_1, + accounts = mapOf( + USER_ID_1 to ACCOUNT_1_UPDATED, + ), +) + +private val MULTI_USER_STATE = UserStateJson( + activeUserId = USER_ID_2, + accounts = mapOf( + USER_ID_1 to ACCOUNT_1, + USER_ID_2 to ACCOUNT_2, + ), +) + +private val MULTI_USER_STATE_UPDATED = UserStateJson( + activeUserId = USER_ID_2, + accounts = mapOf( + USER_ID_1 to ACCOUNT_1_UPDATED, + USER_ID_2 to ACCOUNT_2, + ), +) diff --git a/app/src/test/java/com/x8bit/bitwarden/data/platform/datasource/network/authenticator/RefreshAuthenticatorTests.kt b/app/src/test/java/com/x8bit/bitwarden/data/platform/datasource/network/authenticator/RefreshAuthenticatorTests.kt new file mode 100644 index 0000000000..abf55b67f8 --- /dev/null +++ b/app/src/test/java/com/x8bit/bitwarden/data/platform/datasource/network/authenticator/RefreshAuthenticatorTests.kt @@ -0,0 +1,143 @@ +package com.x8bit.bitwarden.data.platform.datasource.network.authenticator + +import com.x8bit.bitwarden.data.auth.datasource.network.model.RefreshTokenResponseJson +import com.x8bit.bitwarden.data.auth.repository.model.JwtTokenDataJson +import com.x8bit.bitwarden.data.auth.repository.util.parseJwtTokenDataOrNull +import com.x8bit.bitwarden.data.platform.util.asFailure +import com.x8bit.bitwarden.data.platform.util.asSuccess +import io.mockk.every +import io.mockk.just +import io.mockk.mockk +import io.mockk.mockkStatic +import io.mockk.runs +import io.mockk.unmockkStatic +import io.mockk.verify +import okhttp3.Protocol +import okhttp3.Request +import okhttp3.Response +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertNull +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test + +class RefreshAuthenticatorTests { + private lateinit var authenticator: RefreshAuthenticator + private val authenticatorProvider: AuthenticatorProvider = mockk() + + @BeforeEach + fun setup() { + authenticator = RefreshAuthenticator() + authenticator.authenticatorProvider = authenticatorProvider + + mockkStatic(JWT_TOKEN_UTILS_PATH) + } + + @AfterEach + fun tearDown() { + unmockkStatic(JWT_TOKEN_UTILS_PATH) + } + + @Test + fun `RefreshAuthenticator returns null if the request is for a different user`() { + every { parseJwtTokenDataOrNull(JWT_ACCESS_TOKEN) } returns JTW_TOKEN + every { authenticatorProvider.activeUserId } returns "different_user_id" + + assertNull(authenticator.authenticate(null, RESPONSE_401)) + + verify(exactly = 1) { + authenticatorProvider.activeUserId + } + } + + @Test + fun `RefreshAuthenticator returns null if API has no authorization user ID`() { + every { parseJwtTokenDataOrNull(JWT_ACCESS_TOKEN) } returns null + + assertNull(authenticator.authenticate(null, RESPONSE_401)) + + verify(exactly = 0) { + authenticatorProvider.activeUserId + authenticatorProvider.refreshAccessTokenSynchronously(any()) + authenticatorProvider.logout(any()) + } + } + + @Suppress("MaxLineLength") + @Test + fun `RefreshAuthenticator returns null and logs out when request is for active user and refresh is failure`() { + every { parseJwtTokenDataOrNull(JWT_ACCESS_TOKEN) } returns JTW_TOKEN + every { authenticatorProvider.activeUserId } returns USER_ID + every { + authenticatorProvider.refreshAccessTokenSynchronously(USER_ID) + } returns Throwable("Fail").asFailure() + every { authenticatorProvider.logout(USER_ID) } just runs + + assertNull(authenticator.authenticate(null, RESPONSE_401)) + + verify(exactly = 1) { + authenticatorProvider.activeUserId + authenticatorProvider.refreshAccessTokenSynchronously(USER_ID) + authenticatorProvider.logout(USER_ID) + } + } + + @Suppress("MaxLineLength") + @Test + fun `RefreshAuthenticator returns updated request when request is for active user and refresh is success`() { + val newAccessToken = "newAccessToken" + val refreshResponse = RefreshTokenResponseJson( + accessToken = newAccessToken, + expiresIn = 3600, + refreshToken = "refreshToken", + tokenType = "Bearer", + ) + every { parseJwtTokenDataOrNull(JWT_ACCESS_TOKEN) } returns JTW_TOKEN + every { authenticatorProvider.activeUserId } returns USER_ID + every { + authenticatorProvider.refreshAccessTokenSynchronously(USER_ID) + } returns refreshResponse.asSuccess() + + val authenticatedRequest = authenticator.authenticate(null, RESPONSE_401) + + // The okhttp3 Request is not a data class and does not implement equals + // so we are manually checking that the correct header is added. + assertEquals( + "Bearer $newAccessToken", + authenticatedRequest!!.header("Authorization"), + ) + verify(exactly = 1) { + authenticatorProvider.activeUserId + authenticatorProvider.refreshAccessTokenSynchronously(USER_ID) + } + } +} + +private const val JWT_TOKEN_UTILS_PATH = + "com.x8bit.bitwarden.data.auth.repository.util.JwtTokenUtilsKt" + +private const val USER_ID = "2a135b23-e1fb-42c9-bec3-573857bc8181" + +private val JTW_TOKEN = JwtTokenDataJson( + userId = USER_ID, + email = "test@bitwarden.com", + isEmailVerified = true, + name = "Bitwarden Tester", + expirationAsEpochTime = 1697495714, + hasPremium = false, + authenticationMethodsReference = listOf("Application"), +) + +private const val JWT_ACCESS_TOKEN = "jwt" + +private val RESPONSE_401 = Response.Builder() + .code(401) + .request( + request = Request.Builder() + .header(name = "Authorization", value = "Bearer $JWT_ACCESS_TOKEN") + .url("https://www.bitwarden.com") + .build(), + ) + .protocol(Protocol.HTTP_2) + .message("Unauthenticated") + .build() diff --git a/app/src/test/java/com/x8bit/bitwarden/data/platform/datasource/network/retrofit/RetrofitsTest.kt b/app/src/test/java/com/x8bit/bitwarden/data/platform/datasource/network/retrofit/RetrofitsTest.kt index cfcde12093..520e8ec4af 100644 --- a/app/src/test/java/com/x8bit/bitwarden/data/platform/datasource/network/retrofit/RetrofitsTest.kt +++ b/app/src/test/java/com/x8bit/bitwarden/data/platform/datasource/network/retrofit/RetrofitsTest.kt @@ -1,5 +1,6 @@ package com.x8bit.bitwarden.data.platform.datasource.network.retrofit +import com.x8bit.bitwarden.data.platform.datasource.network.authenticator.RefreshAuthenticator import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.AuthTokenInterceptor import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.BaseUrlInterceptors import io.mockk.every @@ -8,6 +9,7 @@ import io.mockk.slot import kotlinx.coroutines.runBlocking import kotlinx.serialization.json.Json import kotlinx.serialization.json.JsonObject +import okhttp3.Authenticator import okhttp3.Interceptor import okhttp3.mockwebserver.MockResponse import okhttp3.mockwebserver.MockWebServer @@ -35,12 +37,16 @@ class RetrofitsTest { mockIntercept { isEventsInterceptorCalled = true } } } + private val refreshAuthenticator = mockk { + mockAuthenticate { isRefreshAuthenticatorCalled = true } + } private val json = Json private val server = MockWebServer() private val retrofits = RetrofitsImpl( authTokenInterceptor = authTokenInterceptor, baseUrlInterceptors = baseUrlInterceptors, + refreshAuthenticator = refreshAuthenticator, json = json, ) @@ -48,6 +54,7 @@ class RetrofitsTest { private var isApiInterceptorCalled = false private var isIdentityInterceptorCalled = false private var isEventsInterceptorCalled = false + private var isRefreshAuthenticatorCalled = false @Before fun setUp() { @@ -59,6 +66,49 @@ class RetrofitsTest { server.shutdown() } + @Test + fun `authenticatedApiRetrofit should not invoke the RefreshAuthenticator on success`() = + runBlocking { + val testApi = retrofits + .authenticatedApiRetrofit + .createMockRetrofit() + .create() + + server.enqueue(MockResponse().setBody("""{}""")) + + testApi.test() + + assertFalse(isRefreshAuthenticatorCalled) + } + + @Test + fun `authenticatedApiRetrofit should invoke the RefreshAuthenticator on 401`() = runBlocking { + val testApi = retrofits + .authenticatedApiRetrofit + .createMockRetrofit() + .create() + + server.enqueue(MockResponse().setResponseCode(401).setBody("""{}""")) + + testApi.test() + + assertTrue(isRefreshAuthenticatorCalled) + } + + @Test + fun `unauthenticatedApiRetrofit should not invoke the RefreshAuthenticator`() = runBlocking { + val testApi = retrofits + .unauthenticatedApiRetrofit + .createMockRetrofit() + .create() + + server.enqueue(MockResponse().setResponseCode(401).setBody("""{}""")) + + testApi.test() + + assertFalse(isRefreshAuthenticatorCalled) + } + @Test fun `authenticatedApiRetrofit should invoke the correct interceptors`() = runBlocking { val testApi = retrofits @@ -138,7 +188,18 @@ class RetrofitsTest { interface TestApi { @GET("/test") - suspend fun test(): JsonObject + suspend fun test(): Result +} + +/** + * Mocks the given [Authenticator] such that the [Authenticator.authenticate] is a no-op and + * returns `null` but triggers the [isCalledCallback]. + */ +private fun Authenticator.mockAuthenticate(isCalledCallback: () -> Unit) { + every { authenticate(any(), any()) } answers { + isCalledCallback() + null + } } /** diff --git a/app/src/test/java/com/x8bit/bitwarden/data/platform/manager/NetworkConfigManagerTest.kt b/app/src/test/java/com/x8bit/bitwarden/data/platform/manager/NetworkConfigManagerTest.kt index 2328551591..6a6efc018c 100644 --- a/app/src/test/java/com/x8bit/bitwarden/data/platform/manager/NetworkConfigManagerTest.kt +++ b/app/src/test/java/com/x8bit/bitwarden/data/platform/manager/NetworkConfigManagerTest.kt @@ -3,6 +3,7 @@ package com.x8bit.bitwarden.data.platform.manager import com.x8bit.bitwarden.data.auth.repository.AuthRepository import com.x8bit.bitwarden.data.auth.repository.model.AuthState import com.x8bit.bitwarden.data.platform.base.FakeDispatcherManager +import com.x8bit.bitwarden.data.platform.datasource.network.authenticator.RefreshAuthenticator import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.AuthTokenInterceptor import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.BaseUrlInterceptors import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager @@ -22,7 +23,7 @@ class NetworkConfigManagerTest { private val mutableAuthStateFlow = MutableStateFlow(AuthState.Uninitialized) private val mutableEnvironmentStateFlow = MutableStateFlow(Environment.Us) - private val authRepository: AuthRepository = mockk() { + private val authRepository: AuthRepository = mockk { every { authStateFlow } returns mutableAuthStateFlow } @@ -30,6 +31,7 @@ class NetworkConfigManagerTest { every { environmentStateFlow } returns mutableEnvironmentStateFlow } + private val refreshAuthenticator = RefreshAuthenticator() private val authTokenInterceptor = AuthTokenInterceptor() private val baseUrlInterceptors = BaseUrlInterceptors() @@ -42,10 +44,19 @@ class NetworkConfigManagerTest { authTokenInterceptor = authTokenInterceptor, environmentRepository = environmentRepository, baseUrlInterceptors = baseUrlInterceptors, + refreshAuthenticator = refreshAuthenticator, dispatcherManager = dispatcherManager, ) } + @Test + fun `authenticatorProvider should be set on initialization`() { + assertEquals( + authRepository, + refreshAuthenticator.authenticatorProvider, + ) + } + @Test fun `changes in the AuthState should update the AuthTokenInterceptor`() { mutableAuthStateFlow.value = AuthState.Uninitialized