diff --git a/app/src/main/java/com/x8bit/bitwarden/data/vault/manager/VaultLockManager.kt b/app/src/main/java/com/x8bit/bitwarden/data/vault/manager/VaultLockManager.kt index 0ff153f6a9..00bdad6ccd 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/vault/manager/VaultLockManager.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/vault/manager/VaultLockManager.kt @@ -50,4 +50,9 @@ interface VaultLockManager { initUserCryptoMethod: InitUserCryptoMethod, organizationKeys: Map?, ): VaultUnlockResult + + /** + * Suspends until the vault for the given [userId] is unlocked. + */ + suspend fun waitUntilUnlocked(userId: String) } diff --git a/app/src/main/java/com/x8bit/bitwarden/data/vault/manager/VaultLockManagerImpl.kt b/app/src/main/java/com/x8bit/bitwarden/data/vault/manager/VaultLockManagerImpl.kt index f614e7c75d..da15b522b1 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/vault/manager/VaultLockManagerImpl.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/vault/manager/VaultLockManagerImpl.kt @@ -176,6 +176,17 @@ class VaultLockManagerImpl( .onCompletion { setVaultToNotUnlocking(userId = userId) } .first() + override suspend fun waitUntilUnlocked(userId: String) { + vaultUnlockDataStateFlow + .map { vaultUnlockDataList -> + // Get the list of currently-unlocked vaults and map them to user IDs. + vaultUnlockDataList + .filter { it.status == VaultUnlockData.Status.UNLOCKED } + .map { it.userId } + } + .first { unlockedUserIds -> userId in unlockedUserIds } + } + /** * Increments the stored invalid unlock count for the given [userId] and automatically logs out * if this new value is greater than [MAXIMUM_INVALID_UNLOCK_ATTEMPTS]. diff --git a/app/src/main/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryImpl.kt b/app/src/main/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryImpl.kt index 905f003502..b92784a792 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryImpl.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryImpl.kt @@ -1289,6 +1289,7 @@ class VaultRepositoryImpl( mutableCiphersStateFlow.value = DataState.Loading } .map { + waitUntilUnlocked(userId = userId) vaultSdkSource .decryptCipherList( userId = userId, @@ -1321,6 +1322,7 @@ class VaultRepositoryImpl( .getFolders(userId = userId) .onStart { mutableFoldersStateFlow.value = DataState.Loading } .map { + waitUntilUnlocked(userId = userId) vaultSdkSource .decryptFolderList( userId = userId, @@ -1340,6 +1342,7 @@ class VaultRepositoryImpl( .getCollections(userId = userId) .onStart { mutableCollectionsStateFlow.value = DataState.Loading } .map { + waitUntilUnlocked(userId = userId) vaultSdkSource .decryptCollectionList( userId = userId, @@ -1363,6 +1366,7 @@ class VaultRepositoryImpl( .getSends(userId = userId) .onStart { mutableSendDataStateFlow.value = DataState.Loading } .map { + waitUntilUnlocked(userId = userId) vaultSdkSource .decryptSendList( userId = userId, diff --git a/app/src/test/java/com/x8bit/bitwarden/data/vault/manager/VaultLockManagerTest.kt b/app/src/test/java/com/x8bit/bitwarden/data/vault/manager/VaultLockManagerTest.kt index 2634855dca..231b6036a6 100644 --- a/app/src/test/java/com/x8bit/bitwarden/data/vault/manager/VaultLockManagerTest.kt +++ b/app/src/test/java/com/x8bit/bitwarden/data/vault/manager/VaultLockManagerTest.kt @@ -1263,6 +1263,28 @@ class VaultLockManagerTest { } } + @Test + fun `waitUntilUnlocked should suspend until the user's vault has unlocked`() = runTest { + // Begin in a locked state + assertFalse(vaultLockManager.isVaultUnlocked(userId = USER_ID)) + + val waitUntilUnlockedJob = async { + vaultLockManager.waitUntilUnlocked(userId = USER_ID) + } + this.testScheduler.advanceUntilIdle() + + // Confirm waitUntilUnlocked has not yet completed + assertFalse(waitUntilUnlockedJob.isCompleted) + + // Unlock vault + verifyUnlockedVault(userId = USER_ID) + this.testScheduler.advanceUntilIdle() + + // Confirm unlock call has now completed and that the vault is unlocked + assertTrue(waitUntilUnlockedJob.isCompleted) + assertTrue(vaultLockManager.isVaultUnlocked(userId = USER_ID)) + } + /** * Resets the verification call count for the given [mock] while leaving all other mocked * behavior in place. diff --git a/app/src/test/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryTest.kt b/app/src/test/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryTest.kt index afdbb68230..1106358042 100644 --- a/app/src/test/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryTest.kt +++ b/app/src/test/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryTest.kt @@ -125,7 +125,9 @@ import io.mockk.unmockkStatic import io.mockk.verify import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.first import kotlinx.coroutines.flow.flowOf +import kotlinx.coroutines.flow.update import kotlinx.coroutines.test.runTest import org.junit.jupiter.api.AfterEach import org.junit.jupiter.api.Assertions.assertEquals @@ -164,12 +166,24 @@ class VaultRepositoryTest { private val mutableVaultStateFlow = MutableStateFlow>( emptyList(), ) + private val mutableUnlockedUserIdsStateFlow = MutableStateFlow>(emptySet()) private val vaultLockManager: VaultLockManager = mockk { every { vaultUnlockDataStateFlow } returns mutableVaultStateFlow - every { isVaultUnlocked(any()) } returns false + every { + isVaultUnlocked(any()) + } answers { call -> + val userId = call.invocation.args.first() + userId in mutableUnlockedUserIdsStateFlow.value + } every { isVaultUnlocking(any()) } returns false every { lockVault(any()) } just runs every { lockVaultForCurrentUser() } just runs + coEvery { + waitUntilUnlocked(any()) + } coAnswers { call -> + val userId = call.invocation.args.first() + mutableUnlockedUserIdsStateFlow.first { userId in it } + } } private val mutableFullSyncFlow = bufferedMutableSharedFlow() @@ -245,6 +259,11 @@ class VaultRepositoryTest { .test { assertEquals(DataState.Loading, awaitItem()) mutableCiphersStateFlow.tryEmit(mockCipherList) + + // No additional emissions until vault is unlocked + expectNoEvents() + setVaultToUnlocked(userId = userId) + assertEquals(DataState.Loaded(mockCipherViewList), awaitItem()) } } @@ -273,6 +292,11 @@ class VaultRepositoryTest { .test { assertEquals(DataState.Loading, awaitItem()) mutableCiphersStateFlow.tryEmit(mockCipherList) + + // No additional emissions until vault is unlocked + expectNoEvents() + setVaultToUnlocked(userId = userId) + assertEquals(DataState.Error>(throwable), awaitItem()) } } @@ -303,6 +327,11 @@ class VaultRepositoryTest { .test { assertEquals(DataState.Loading, awaitItem()) mutableCollectionsStateFlow.tryEmit(mockCollectionList) + + // No additional emissions until vault is unlocked + expectNoEvents() + setVaultToUnlocked(userId = userId) + assertEquals(DataState.Loaded(mockCollectionViewList), awaitItem()) } } @@ -331,6 +360,11 @@ class VaultRepositoryTest { .test { assertEquals(DataState.Loading, awaitItem()) mutableCollectionStateFlow.tryEmit(mockCollectionList) + + // No additional emissions until vault is unlocked + expectNoEvents() + setVaultToUnlocked(userId = userId) + assertEquals(DataState.Error>(throwable), awaitItem()) } } @@ -361,6 +395,11 @@ class VaultRepositoryTest { .test { assertEquals(DataState.Loading, awaitItem()) mutableFoldersStateFlow.tryEmit(mockFolderList) + + // No additional emissions until vault is unlocked + expectNoEvents() + setVaultToUnlocked(userId = userId) + assertEquals(DataState.Loaded(mockFolderViewList), awaitItem()) } } @@ -389,6 +428,11 @@ class VaultRepositoryTest { .test { assertEquals(DataState.Loading, awaitItem()) mutableFoldersStateFlow.tryEmit(mockFolderList) + + // No additional emissions until vault is unlocked + expectNoEvents() + setVaultToUnlocked(userId = userId) + assertEquals(DataState.Error>(throwable), awaitItem()) } } @@ -418,6 +462,11 @@ class VaultRepositoryTest { .test { assertEquals(DataState.Loading, awaitItem()) mutableSendsStateFlow.tryEmit(mockSendList) + + // No additional emissions until vault is unlocked + expectNoEvents() + setVaultToUnlocked(userId = userId) + assertEquals(DataState.Loaded(SendData(mockSendViewList)), awaitItem()) } } @@ -446,6 +495,11 @@ class VaultRepositoryTest { .test { assertEquals(DataState.Loading, awaitItem()) mutableSendsStateFlow.tryEmit(mockSendList) + + // No additional emissions until vault is unlocked + expectNoEvents() + setVaultToUnlocked(userId = userId) + assertEquals(DataState.Error(throwable), awaitItem()) } } @@ -634,6 +688,7 @@ class VaultRepositoryTest { runTest { fakeAuthDiskSource.userState = MOCK_USER_STATE val userId = "mockId-1" + setVaultToUnlocked(userId = userId) coEvery { syncService.sync() } returns UnknownHostException().asFailure() val sendsFlow = bufferedMutableSharedFlow>() setupVaultDiskSourceFlows(sendsFlow = sendsFlow) @@ -1395,6 +1450,10 @@ class VaultRepositoryTest { foldersFlow.tryEmit(listOf(createMockFolder(number = 1))) sendsFlow.tryEmit(listOf(createMockSend(number = 1))) + // No events received until unlocked + expectNoEvents() + setVaultToUnlocked(userId = userId) + assertEquals( DataState.Loaded( data = VaultData( @@ -1417,6 +1476,7 @@ class VaultRepositoryTest { fun `clearUnlockedData should update the sendDataStateFlow to Loading`() = runTest { fakeAuthDiskSource.userState = MOCK_USER_STATE val userId = "mockId-1" + setVaultToUnlocked(userId = userId) coEvery { vaultSdkSource.decryptSendList( userId = userId, @@ -1580,6 +1640,11 @@ class VaultRepositoryTest { vaultRepository.getSendStateFlow("mockId-$sendId").test { assertEquals(DataState.Loading, awaitItem()) sendsFlow.tryEmit(emptyList()) + + // No additional emissions until vault is unlocked + expectNoEvents() + setVaultToUnlocked(userId = MOCK_USER_STATE.activeUserId) + assertEquals(DataState.Loaded(null), awaitItem()) sendsFlow.tryEmit(listOf(createMockSend(number = sendId))) assertEquals(DataState.Loaded(sendView), awaitItem()) @@ -4053,6 +4118,7 @@ class VaultRepositoryTest { runTest { fakeAuthDiskSource.userState = MOCK_USER_STATE val userId = "mockId-1" + setVaultToUnlocked(userId = userId) val mockSyncResponse = createMockSyncResponse(number = 1) coEvery { syncService.sync() } returns mockSyncResponse.asSuccess() @@ -4121,6 +4187,7 @@ class VaultRepositoryTest { fun `getAuthCodesFlow should update data state when state changes`() = runTest { fakeAuthDiskSource.userState = MOCK_USER_STATE val userId = "mockId-1" + setVaultToUnlocked(userId = userId) val mockSyncResponse = createMockSyncResponse(number = 1) coEvery { syncService.sync() } returns mockSyncResponse.asSuccess() @@ -4214,6 +4281,7 @@ class VaultRepositoryTest { val cipherId = "mockId-$number" fakeAuthDiskSource.userState = MOCK_USER_STATE + setVaultToUnlocked(userId = MOCK_USER_STATE.activeUserId) val cipherView = createMockCipherView(number = number) coEvery { vaultSdkSource.decryptCipherList( @@ -4256,6 +4324,7 @@ class VaultRepositoryTest { val cipherId = "mockId-$number" fakeAuthDiskSource.userState = MOCK_USER_STATE + setVaultToUnlocked(userId = MOCK_USER_STATE.activeUserId) val cipherView = createMockCipherView(number = number) coEvery { vaultSdkSource.decryptCipherList( @@ -4326,6 +4395,7 @@ class VaultRepositoryTest { val cipherId = "mockId-$number" fakeAuthDiskSource.userState = MOCK_USER_STATE + setVaultToUnlocked(userId = MOCK_USER_STATE.activeUserId) coEvery { vaultSdkSource.decryptCipherList( userId = MOCK_USER_STATE.activeUserId, @@ -4393,6 +4463,7 @@ class VaultRepositoryTest { val cipherId = "mockId-$number" fakeAuthDiskSource.userState = MOCK_USER_STATE + setVaultToUnlocked(userId = MOCK_USER_STATE.activeUserId) coEvery { vaultSdkSource.decryptCipherList( userId = MOCK_USER_STATE.activeUserId, @@ -4432,6 +4503,7 @@ class VaultRepositoryTest { val cipherId = "mockId-$number" fakeAuthDiskSource.userState = MOCK_USER_STATE + setVaultToUnlocked(userId = MOCK_USER_STATE.activeUserId) val cipherView = createMockCipherView(number = number) coEvery { vaultSdkSource.decryptCipherList( @@ -4487,6 +4559,7 @@ class VaultRepositoryTest { } just runs fakeAuthDiskSource.userState = MOCK_USER_STATE + setVaultToUnlocked(userId = MOCK_USER_STATE.activeUserId) val cipherView = createMockCipherView(number = number) coEvery { vaultSdkSource.decryptCipherList( @@ -4531,6 +4604,7 @@ class VaultRepositoryTest { val cipherId = "mockId-1" fakeAuthDiskSource.userState = MOCK_USER_STATE + setVaultToUnlocked(userId = MOCK_USER_STATE.activeUserId) val response: HttpException = mockk { every { code() } returns 404 @@ -4585,6 +4659,7 @@ class VaultRepositoryTest { val cipherId = "mockId-$number" fakeAuthDiskSource.userState = MOCK_USER_STATE + setVaultToUnlocked(userId = MOCK_USER_STATE.activeUserId) coEvery { vaultSdkSource.decryptCipherList( userId = MOCK_USER_STATE.activeUserId, @@ -4638,6 +4713,7 @@ class VaultRepositoryTest { val cipherId = "mockId-$number" fakeAuthDiskSource.userState = MOCK_USER_STATE + setVaultToUnlocked(userId = MOCK_USER_STATE.activeUserId) val cipherView = createMockCipherView(number = number) coEvery { vaultSdkSource.decryptCipherList( @@ -4705,6 +4781,7 @@ class VaultRepositoryTest { val sendId = "mockId-$number" fakeAuthDiskSource.userState = MOCK_USER_STATE + setVaultToUnlocked(userId = MOCK_USER_STATE.activeUserId) val sendView = createMockSendView(number = number) coEvery { vaultSdkSource.decryptSendList( @@ -4743,6 +4820,7 @@ class VaultRepositoryTest { val sendId = "mockId-$number" fakeAuthDiskSource.userState = MOCK_USER_STATE + setVaultToUnlocked(userId = MOCK_USER_STATE.activeUserId) coEvery { vaultSdkSource.decryptSendList( userId = MOCK_USER_STATE.activeUserId, @@ -4780,6 +4858,7 @@ class VaultRepositoryTest { val sendId = "mockId-$number" fakeAuthDiskSource.userState = MOCK_USER_STATE + setVaultToUnlocked(userId = MOCK_USER_STATE.activeUserId) val sendView = createMockSendView(number = number) coEvery { vaultSdkSource.decryptSendList( @@ -4833,6 +4912,7 @@ class VaultRepositoryTest { } just runs fakeAuthDiskSource.userState = MOCK_USER_STATE + setVaultToUnlocked(userId = MOCK_USER_STATE.activeUserId) val sendView = createMockSendView(number = number) coEvery { vaultSdkSource.decryptSendList( @@ -4875,6 +4955,7 @@ class VaultRepositoryTest { val sendId = "mockId-1" fakeAuthDiskSource.userState = MOCK_USER_STATE + setVaultToUnlocked(userId = MOCK_USER_STATE.activeUserId) val response: HttpException = mockk { every { code() } returns 404 @@ -4927,6 +5008,7 @@ class VaultRepositoryTest { val sendId = "mockId-$number" fakeAuthDiskSource.userState = MOCK_USER_STATE + setVaultToUnlocked(userId = MOCK_USER_STATE.activeUserId) coEvery { vaultSdkSource.decryptSendList( userId = MOCK_USER_STATE.activeUserId, @@ -4978,6 +5060,7 @@ class VaultRepositoryTest { val sendId = "mockId-$number" fakeAuthDiskSource.userState = MOCK_USER_STATE + setVaultToUnlocked(userId = MOCK_USER_STATE.activeUserId) val sendView = createMockSendView(number = number) coEvery { vaultSdkSource.decryptSendList( @@ -5049,6 +5132,7 @@ class VaultRepositoryTest { val folderId = "mockId-$number" fakeAuthDiskSource.userState = MOCK_USER_STATE + setVaultToUnlocked(userId = MOCK_USER_STATE.activeUserId) val folderView = createMockFolderView(number = number) coEvery { vaultSdkSource.decryptFolderList( @@ -5087,6 +5171,7 @@ class VaultRepositoryTest { val folderId = "mockId-$number" fakeAuthDiskSource.userState = MOCK_USER_STATE + setVaultToUnlocked(userId = MOCK_USER_STATE.activeUserId) coEvery { vaultSdkSource.decryptFolderList( userId = MOCK_USER_STATE.activeUserId, @@ -5124,6 +5209,7 @@ class VaultRepositoryTest { val folderId = "mockId-$number" fakeAuthDiskSource.userState = MOCK_USER_STATE + setVaultToUnlocked(userId = MOCK_USER_STATE.activeUserId) val folderView = createMockFolderView(number = number) coEvery { vaultSdkSource.decryptFolderList( @@ -5166,6 +5252,7 @@ class VaultRepositoryTest { val folderId = "mockId-$number" fakeAuthDiskSource.userState = MOCK_USER_STATE + setVaultToUnlocked(userId = MOCK_USER_STATE.activeUserId) coEvery { vaultSdkSource.decryptFolderList( userId = MOCK_USER_STATE.activeUserId, @@ -5217,6 +5304,7 @@ class VaultRepositoryTest { val folderId = "mockId-$number" fakeAuthDiskSource.userState = MOCK_USER_STATE + setVaultToUnlocked(userId = MOCK_USER_STATE.activeUserId) val folderView = createMockFolderView(number = number) coEvery { vaultSdkSource.decryptFolderList( @@ -5343,6 +5431,14 @@ class VaultRepositoryTest { } returns unlockResult } + /** + * Ensures the vault for the given [userId] is unlocked and can pass any + * [VaultLockManager.waitUntilUnlocked] or [VaultLockManager.isVaultUnlocked] checks. + */ + private fun setVaultToUnlocked(userId: String) { + mutableUnlockedUserIdsStateFlow.update { it + userId } + } + /** * Helper setup all flows required to properly subscribe to the * [VaultRepository.vaultDataStateFlow].