From d95e5df2a7b0852f440dee3fffff9d6c4ac33263 Mon Sep 17 00:00:00 2001 From: Brian Yencho Date: Mon, 8 Jan 2024 22:56:35 -0600 Subject: [PATCH] Add VaultState.unlockingVaultUserIds and clean up the vault unlock logic (#546) --- .../vault/repository/VaultRepositoryImpl.kt | 185 +++++++++++------- .../data/vault/repository/model/VaultState.kt | 3 + .../auth/repository/AuthRepositoryTest.kt | 6 +- .../util/UserStateJsonExtensionsTest.kt | 2 + .../vault/repository/VaultRepositoryTest.kt | 40 ++++ 5 files changed, 169 insertions(+), 67 deletions(-) diff --git a/app/src/main/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryImpl.kt b/app/src/main/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryImpl.kt index cd60633a3d..824e484c03 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryImpl.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryImpl.kt @@ -91,14 +91,17 @@ class VaultRepositoryImpl( private var syncJob: Job = Job().apply { complete() } - private var willSyncAfterUnlock = false - private val activeUserId: String? get() = authDiskSource.userState?.activeUserId private val mutableTotpCodeFlow = bufferedMutableSharedFlow() private val mutableVaultStateStateFlow = - MutableStateFlow(VaultState(unlockedVaultUserIds = emptySet())) + MutableStateFlow( + VaultState( + unlockedVaultUserIds = emptySet(), + unlockingVaultUserIds = emptySet(), + ), + ) private val mutableSendDataStateFlow = MutableStateFlow>(DataState.Loading) @@ -194,8 +197,8 @@ class VaultRepositoryImpl( } override fun sync() { - if (!syncJob.isCompleted || willSyncAfterUnlock) return val userId = activeUserId ?: return + if (!syncJob.isCompleted || isVaultUnlocking(userId)) return mutableCiphersStateFlow.updateToPendingOrLoading() mutableFoldersStateFlow.updateToPendingOrLoading() mutableCollectionsStateFlow.updateToPendingOrLoading() @@ -290,22 +293,15 @@ class VaultRepositoryImpl( override suspend fun unlockVaultAndSyncForCurrentUser( masterPassword: String, ): VaultUnlockResult { - val userState = authDiskSource.userState + val userId = activeUserId ?: return VaultUnlockResult.InvalidStateError + val userKey = authDiskSource.getUserKey(userId = userId) ?: return VaultUnlockResult.InvalidStateError - val userKey = authDiskSource.getUserKey(userId = userState.activeUserId) - ?: return VaultUnlockResult.InvalidStateError - val privateKey = authDiskSource.getPrivateKey(userId = userState.activeUserId) - ?: return VaultUnlockResult.InvalidStateError - val organizationKeys = authDiskSource - .getOrganizationKeys(userId = userState.activeUserId) - return unlockVault( - userId = userState.activeUserId, - masterPassword = masterPassword, - email = userState.activeAccount.profile.email, - kdf = userState.activeAccount.profile.toSdkParams(), - userKey = userKey, - privateKey = privateKey, - organizationKeys = organizationKeys, + return unlockVaultForUser( + userId = userId, + initUserCryptoMethod = InitUserCryptoMethod.Password( + password = masterPassword, + userKey = userKey, + ), ) .also { if (it is VaultUnlockResult.Success) { @@ -323,53 +319,17 @@ class VaultRepositoryImpl( privateKey: String, organizationKeys: Map?, ): VaultUnlockResult = - flow { - willSyncAfterUnlock = true - emit( - vaultSdkSource - .initializeCrypto( - userId = userId, - request = InitUserCryptoRequest( - kdfParams = kdf, - email = email, - privateKey = privateKey, - method = InitUserCryptoMethod.Password( - password = masterPassword, - userKey = userKey, - ), - ), - ) - .flatMap { result -> - // Initialize the SDK for organizations if necessary - if (organizationKeys != null && - result is InitializeCryptoResult.Success - ) { - vaultSdkSource.initializeOrganizationCrypto( - userId = userId, - request = InitOrgCryptoRequest( - organizationKeys = organizationKeys, - ), - ) - } else { - result.asSuccess() - } - } - .fold( - onFailure = { VaultUnlockResult.GenericError }, - onSuccess = { initializeCryptoResult -> - initializeCryptoResult - .toVaultUnlockResult() - .also { - if (it is VaultUnlockResult.Success) { - setVaultToUnlocked(userId = userId) - } - } - }, - ), - ) - } - .onCompletion { willSyncAfterUnlock = false } - .first() + unlockVaultInternal( + userId = userId, + email = email, + kdf = kdf, + privateKey = privateKey, + initUserCryptoMethod = InitUserCryptoMethod.Password( + password = masterPassword, + userKey = userKey, + ), + organizationKeys = organizationKeys, + ) override suspend fun createCipher(cipherView: CipherView): CreateCipherResult { val userId = requireNotNull(activeUserId) @@ -506,6 +466,25 @@ class VaultRepositoryImpl( } } + private fun setVaultToUnlocking(userId: String) { + mutableVaultStateStateFlow.update { + it.copy( + unlockingVaultUserIds = it.unlockingVaultUserIds + userId, + ) + } + } + + private fun setVaultToNotUnlocking(userId: String) { + mutableVaultStateStateFlow.update { + it.copy( + unlockingVaultUserIds = it.unlockingVaultUserIds - userId, + ) + } + } + + private fun isVaultUnlocking(userId: String) = + userId in mutableVaultStateStateFlow.value.unlockingVaultUserIds + private fun storeProfileData( syncResponse: SyncResponseJson, ) { @@ -536,6 +515,80 @@ class VaultRepositoryImpl( } } + @Suppress("ReturnCount") + private suspend fun unlockVaultForUser( + userId: String, + initUserCryptoMethod: InitUserCryptoMethod, + ): VaultUnlockResult { + val account = authDiskSource.userState?.accounts?.get(userId) + ?: return VaultUnlockResult.InvalidStateError + val privateKey = authDiskSource.getPrivateKey(userId = userId) + ?: return VaultUnlockResult.InvalidStateError + val organizationKeys = authDiskSource + .getOrganizationKeys(userId = userId) + return unlockVaultInternal( + userId = userId, + email = account.profile.email, + kdf = account.profile.toSdkParams(), + privateKey = privateKey, + initUserCryptoMethod = initUserCryptoMethod, + organizationKeys = organizationKeys, + ) + } + + private suspend fun unlockVaultInternal( + userId: String, + email: String, + kdf: Kdf, + privateKey: String, + initUserCryptoMethod: InitUserCryptoMethod, + organizationKeys: Map?, + ): VaultUnlockResult = + flow { + setVaultToUnlocking(userId = userId) + emit( + vaultSdkSource + .initializeCrypto( + userId = userId, + request = InitUserCryptoRequest( + kdfParams = kdf, + email = email, + privateKey = privateKey, + method = initUserCryptoMethod, + ), + ) + .flatMap { result -> + // Initialize the SDK for organizations if necessary + if (organizationKeys != null && + result is InitializeCryptoResult.Success + ) { + vaultSdkSource.initializeOrganizationCrypto( + userId = userId, + request = InitOrgCryptoRequest( + organizationKeys = organizationKeys, + ), + ) + } else { + result.asSuccess() + } + } + .fold( + onFailure = { VaultUnlockResult.GenericError }, + onSuccess = { initializeCryptoResult -> + initializeCryptoResult + .toVaultUnlockResult() + .also { + if (it is VaultUnlockResult.Success) { + setVaultToUnlocked(userId = userId) + } + } + }, + ), + ) + } + .onCompletion { setVaultToNotUnlocking(userId = userId) } + .first() + private suspend fun unlockVaultForOrganizationsIfNecessary( syncResponse: SyncResponseJson, ) { diff --git a/app/src/main/java/com/x8bit/bitwarden/data/vault/repository/model/VaultState.kt b/app/src/main/java/com/x8bit/bitwarden/data/vault/repository/model/VaultState.kt index ec59e41af9..0b9debbaea 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/vault/repository/model/VaultState.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/vault/repository/model/VaultState.kt @@ -4,7 +4,10 @@ package com.x8bit.bitwarden.data.vault.repository.model * General description of the vault across multiple users. * * @property unlockedVaultUserIds The user IDs for all users that currently have unlocked vaults. + * @property unlockedVaultUserIds The user IDs for all users that are actively unlocking their + * vaults. */ data class VaultState( val unlockedVaultUserIds: Set, + val unlockingVaultUserIds: Set, ) diff --git a/app/src/test/java/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryTest.kt b/app/src/test/java/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryTest.kt index 88215f1dd6..dc1a22d7ed 100644 --- a/app/src/test/java/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryTest.kt +++ b/app/src/test/java/com/x8bit/bitwarden/data/auth/repository/AuthRepositoryTest.kt @@ -179,7 +179,10 @@ class AuthRepositoryTest { repository.userStateFlow.value, ) - val emptyVaultState = VaultState(unlockedVaultUserIds = emptySet()) + val emptyVaultState = VaultState( + unlockedVaultUserIds = emptySet(), + unlockingVaultUserIds = emptySet(), + ) mutableVaultStateFlow.value = emptyVaultState assertEquals( MULTI_USER_STATE.toUserState( @@ -1535,6 +1538,7 @@ class AuthRepositoryTest { ) private val VAULT_STATE = VaultState( unlockedVaultUserIds = setOf(USER_ID_1), + unlockingVaultUserIds = emptySet(), ) } } diff --git a/app/src/test/java/com/x8bit/bitwarden/data/auth/repository/util/UserStateJsonExtensionsTest.kt b/app/src/test/java/com/x8bit/bitwarden/data/auth/repository/util/UserStateJsonExtensionsTest.kt index d3a7c9e963..5b904b660e 100644 --- a/app/src/test/java/com/x8bit/bitwarden/data/auth/repository/util/UserStateJsonExtensionsTest.kt +++ b/app/src/test/java/com/x8bit/bitwarden/data/auth/repository/util/UserStateJsonExtensionsTest.kt @@ -135,6 +135,7 @@ class UserStateJsonExtensionsTest { .toUserState( vaultState = VaultState( unlockedVaultUserIds = setOf("activeUserId"), + unlockingVaultUserIds = emptySet(), ), userOrganizationsList = listOf( UserOrganizations( @@ -198,6 +199,7 @@ class UserStateJsonExtensionsTest { .toUserState( vaultState = VaultState( unlockedVaultUserIds = emptySet(), + unlockingVaultUserIds = emptySet(), ), userOrganizationsList = listOf( UserOrganizations( diff --git a/app/src/test/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryTest.kt b/app/src/test/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryTest.kt index 926b2cc3ba..9930abfce7 100644 --- a/app/src/test/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryTest.kt +++ b/app/src/test/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryTest.kt @@ -538,6 +538,7 @@ class VaultRepositoryTest { assertEquals( VaultState( unlockedVaultUserIds = setOf(userId), + unlockingVaultUserIds = emptySet(), ), vaultRepository.vaultStateFlow.value, ) @@ -547,6 +548,7 @@ class VaultRepositoryTest { assertEquals( VaultState( unlockedVaultUserIds = emptySet(), + unlockingVaultUserIds = emptySet(), ), vaultRepository.vaultStateFlow.value, ) @@ -562,6 +564,7 @@ class VaultRepositoryTest { assertEquals( VaultState( unlockedVaultUserIds = setOf(userId), + unlockingVaultUserIds = emptySet(), ), vaultRepository.vaultStateFlow.value, ) @@ -571,6 +574,7 @@ class VaultRepositoryTest { assertEquals( VaultState( unlockedVaultUserIds = emptySet(), + unlockingVaultUserIds = emptySet(), ), vaultRepository.vaultStateFlow.value, ) @@ -634,6 +638,7 @@ class VaultRepositoryTest { assertEquals( VaultState( unlockedVaultUserIds = emptySet(), + unlockingVaultUserIds = emptySet(), ), vaultRepository.vaultStateFlow.value, ) @@ -649,6 +654,7 @@ class VaultRepositoryTest { assertEquals( VaultState( unlockedVaultUserIds = setOf("mockId-1"), + unlockingVaultUserIds = emptySet(), ), vaultRepository.vaultStateFlow.value, ) @@ -794,6 +800,7 @@ class VaultRepositoryTest { assertEquals( VaultState( unlockedVaultUserIds = emptySet(), + unlockingVaultUserIds = emptySet(), ), vaultRepository.vaultStateFlow.value, ) @@ -809,6 +816,7 @@ class VaultRepositoryTest { assertEquals( VaultState( unlockedVaultUserIds = emptySet(), + unlockingVaultUserIds = emptySet(), ), vaultRepository.vaultStateFlow.value, ) @@ -861,6 +869,7 @@ class VaultRepositoryTest { assertEquals( VaultState( unlockedVaultUserIds = emptySet(), + unlockingVaultUserIds = emptySet(), ), vaultRepository.vaultStateFlow.value, ) @@ -876,6 +885,7 @@ class VaultRepositoryTest { assertEquals( VaultState( unlockedVaultUserIds = emptySet(), + unlockingVaultUserIds = emptySet(), ), vaultRepository.vaultStateFlow.value, ) @@ -913,6 +923,7 @@ class VaultRepositoryTest { assertEquals( VaultState( unlockedVaultUserIds = emptySet(), + unlockingVaultUserIds = emptySet(), ), vaultRepository.vaultStateFlow.value, ) @@ -925,6 +936,7 @@ class VaultRepositoryTest { assertEquals( VaultState( unlockedVaultUserIds = emptySet(), + unlockingVaultUserIds = emptySet(), ), vaultRepository.vaultStateFlow.value, ) @@ -974,6 +986,7 @@ class VaultRepositoryTest { assertEquals( VaultState( unlockedVaultUserIds = emptySet(), + unlockingVaultUserIds = emptySet(), ), vaultRepository.vaultStateFlow.value, ) @@ -986,6 +999,7 @@ class VaultRepositoryTest { assertEquals( VaultState( unlockedVaultUserIds = emptySet(), + unlockingVaultUserIds = emptySet(), ), vaultRepository.vaultStateFlow.value, ) @@ -999,6 +1013,7 @@ class VaultRepositoryTest { assertEquals( VaultState( unlockedVaultUserIds = emptySet(), + unlockingVaultUserIds = emptySet(), ), vaultRepository.vaultStateFlow.value, ) @@ -1012,6 +1027,7 @@ class VaultRepositoryTest { assertEquals( VaultState( unlockedVaultUserIds = emptySet(), + unlockingVaultUserIds = emptySet(), ), vaultRepository.vaultStateFlow.value, ) @@ -1024,6 +1040,7 @@ class VaultRepositoryTest { assertEquals( VaultState( unlockedVaultUserIds = emptySet(), + unlockingVaultUserIds = emptySet(), ), vaultRepository.vaultStateFlow.value, ) @@ -1045,6 +1062,7 @@ class VaultRepositoryTest { assertEquals( VaultState( unlockedVaultUserIds = emptySet(), + unlockingVaultUserIds = emptySet(), ), vaultRepository.vaultStateFlow.value, ) @@ -1057,6 +1075,7 @@ class VaultRepositoryTest { assertEquals( VaultState( unlockedVaultUserIds = emptySet(), + unlockingVaultUserIds = emptySet(), ), vaultRepository.vaultStateFlow.value, ) @@ -1077,6 +1096,7 @@ class VaultRepositoryTest { assertEquals( VaultState( unlockedVaultUserIds = emptySet(), + unlockingVaultUserIds = emptySet(), ), vaultRepository.vaultStateFlow.value, ) @@ -1114,6 +1134,7 @@ class VaultRepositoryTest { assertEquals( VaultState( unlockedVaultUserIds = emptySet(), + unlockingVaultUserIds = emptySet(), ), vaultRepository.vaultStateFlow.value, ) @@ -1132,6 +1153,7 @@ class VaultRepositoryTest { assertEquals( VaultState( unlockedVaultUserIds = setOf(userId), + unlockingVaultUserIds = emptySet(), ), vaultRepository.vaultStateFlow.value, ) @@ -1186,6 +1208,7 @@ class VaultRepositoryTest { assertEquals( VaultState( unlockedVaultUserIds = emptySet(), + unlockingVaultUserIds = emptySet(), ), vaultRepository.vaultStateFlow.value, ) @@ -1204,6 +1227,7 @@ class VaultRepositoryTest { assertEquals( VaultState( unlockedVaultUserIds = emptySet(), + unlockingVaultUserIds = emptySet(), ), vaultRepository.vaultStateFlow.value, ) @@ -1258,6 +1282,7 @@ class VaultRepositoryTest { assertEquals( VaultState( unlockedVaultUserIds = emptySet(), + unlockingVaultUserIds = emptySet(), ), vaultRepository.vaultStateFlow.value, ) @@ -1276,6 +1301,7 @@ class VaultRepositoryTest { assertEquals( VaultState( unlockedVaultUserIds = emptySet(), + unlockingVaultUserIds = emptySet(), ), vaultRepository.vaultStateFlow.value, ) @@ -1328,6 +1354,7 @@ class VaultRepositoryTest { assertEquals( VaultState( unlockedVaultUserIds = emptySet(), + unlockingVaultUserIds = emptySet(), ), vaultRepository.vaultStateFlow.value, ) @@ -1346,6 +1373,7 @@ class VaultRepositoryTest { assertEquals( VaultState( unlockedVaultUserIds = emptySet(), + unlockingVaultUserIds = emptySet(), ), vaultRepository.vaultStateFlow.value, ) @@ -1398,6 +1426,7 @@ class VaultRepositoryTest { assertEquals( VaultState( unlockedVaultUserIds = emptySet(), + unlockingVaultUserIds = emptySet(), ), vaultRepository.vaultStateFlow.value, ) @@ -1416,6 +1445,7 @@ class VaultRepositoryTest { assertEquals( VaultState( unlockedVaultUserIds = emptySet(), + unlockingVaultUserIds = emptySet(), ), vaultRepository.vaultStateFlow.value, ) @@ -1477,6 +1507,16 @@ class VaultRepositoryTest { organizationKeys = organizationKeys, ) } + + // The given userId is marked as unlocking + assertEquals( + VaultState( + unlockedVaultUserIds = emptySet(), + unlockingVaultUserIds = setOf(userId), + ), + vaultRepository.vaultStateFlow.value, + ) + // Does nothing because we are blocking vaultRepository.sync() scope.cancel()