From 41c35e23ddded322d4dbd0b2a36ff10b16d39241 Mon Sep 17 00:00:00 2001 From: Brian Yencho Date: Fri, 5 Jan 2024 08:57:49 -0600 Subject: [PATCH] Add the SdkClientManager and use a single Client per user for vault (#499) --- .../data/platform/manager/SdkClientManager.kt | 21 ++ .../platform/manager/SdkClientManagerImpl.kt | 25 ++ .../manager/di/PlatformManagerModule.kt | 6 + .../repository/GeneratorRepositoryImpl.kt | 11 +- .../vault/datasource/sdk/VaultSdkSource.kt | 139 +++++++++-- .../datasource/sdk/VaultSdkSourceImpl.kt | 153 +++++++++--- .../vault/datasource/sdk/di/VaultSdkModule.kt | 8 +- .../vault/repository/VaultRepositoryImpl.kt | 38 ++- .../platform/manager/SdkClientManagerTest.kt | 44 ++++ .../repository/GeneratorRepositoryTest.kt | 26 +- .../datasource/sdk/VaultSdkSourceTest.kt | 85 ++++++- .../vault/repository/VaultRepositoryTest.kt | 222 ++++++++++++++++-- 12 files changed, 671 insertions(+), 107 deletions(-) create mode 100644 app/src/main/java/com/x8bit/bitwarden/data/platform/manager/SdkClientManager.kt create mode 100644 app/src/main/java/com/x8bit/bitwarden/data/platform/manager/SdkClientManagerImpl.kt create mode 100644 app/src/test/java/com/x8bit/bitwarden/data/platform/manager/SdkClientManagerTest.kt diff --git a/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/SdkClientManager.kt b/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/SdkClientManager.kt new file mode 100644 index 0000000000..445f12ae31 --- /dev/null +++ b/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/SdkClientManager.kt @@ -0,0 +1,21 @@ +package com.x8bit.bitwarden.data.platform.manager + +import com.bitwarden.sdk.Client + +/** + * Manages the creation, caching, and destruction of SDK [Client] instances on a per-user basis. + */ +interface SdkClientManager { + + /** + * Returns the cached [Client] instance for the given [userId], otherwise creates and caches + * a new one and returns it. + */ + fun getOrCreateClient(userId: String): Client + + /** + * Clears any resources from the [Client] associated with the given [userId] and removes it + * from the internal cache. + */ + fun destroyClient(userId: String) +} diff --git a/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/SdkClientManagerImpl.kt b/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/SdkClientManagerImpl.kt new file mode 100644 index 0000000000..26f4f7e21b --- /dev/null +++ b/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/SdkClientManagerImpl.kt @@ -0,0 +1,25 @@ +package com.x8bit.bitwarden.data.platform.manager + +import com.bitwarden.sdk.Client + +/** + * Primary implementation of [SdkClientManager]. + */ +class SdkClientManagerImpl( + private val clientProvider: () -> Client = { Client(null) }, +) : SdkClientManager { + private val userIdToClientMap = mutableMapOf() + + override fun getOrCreateClient( + userId: String, + ): Client = + userIdToClientMap.getOrPut(key = userId) { clientProvider() } + + override fun destroyClient( + userId: String, + ) { + userIdToClientMap + .remove(key = userId) + ?.close() + } +} diff --git a/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/di/PlatformManagerModule.kt b/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/di/PlatformManagerModule.kt index ef4d6dd6cd..f6cb52d0a6 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/di/PlatformManagerModule.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/di/PlatformManagerModule.kt @@ -6,6 +6,8 @@ import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.AuthToke import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.BaseUrlInterceptors import com.x8bit.bitwarden.data.platform.manager.NetworkConfigManager import com.x8bit.bitwarden.data.platform.manager.NetworkConfigManagerImpl +import com.x8bit.bitwarden.data.platform.manager.SdkClientManager +import com.x8bit.bitwarden.data.platform.manager.SdkClientManagerImpl import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManager import com.x8bit.bitwarden.data.platform.manager.dispatcher.DispatcherManagerImpl import com.x8bit.bitwarden.data.platform.repository.EnvironmentRepository @@ -26,6 +28,10 @@ object PlatformManagerModule { @Singleton fun provideBitwardenDispatchers(): DispatcherManager = DispatcherManagerImpl() + @Provides + @Singleton + fun provideSdkClientManager(): SdkClientManager = SdkClientManagerImpl() + @Provides @Singleton fun provideNetworkConfigManager( diff --git a/app/src/main/java/com/x8bit/bitwarden/data/tools/generator/repository/GeneratorRepositoryImpl.kt b/app/src/main/java/com/x8bit/bitwarden/data/tools/generator/repository/GeneratorRepositoryImpl.kt index a128eb08ea..ec73a78118 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/tools/generator/repository/GeneratorRepositoryImpl.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/tools/generator/repository/GeneratorRepositoryImpl.kt @@ -67,7 +67,11 @@ class GeneratorRepositoryImpl( .onStart { mutablePasswordHistoryStateFlow.value = LocalDataState.Loading } .map { encryptedPasswordHistoryList -> val passwordHistories = encryptedPasswordHistoryList.map { it.toPasswordHistory() } - vaultSdkSource.decryptPasswordHistoryList(passwordHistories) + vaultSdkSource + .decryptPasswordHistoryList( + userId = userId, + passwordHistoryList = passwordHistories, + ) } .onEach { encryptedPasswordHistoryListResult -> mutablePasswordHistoryStateFlow.value = encryptedPasswordHistoryListResult.fold( @@ -148,7 +152,10 @@ class GeneratorRepositoryImpl( override suspend fun storePasswordHistory(passwordHistoryView: PasswordHistoryView) { val userId = authDiskSource.userState?.activeUserId ?: return val encryptedPasswordHistory = vaultSdkSource - .encryptPasswordHistory(passwordHistoryView) + .encryptPasswordHistory( + userId = userId, + passwordHistory = passwordHistoryView, + ) .getOrNull() ?: return passwordHistoryDiskSource.insertPasswordHistory( encryptedPasswordHistory.toPasswordHistoryEntity(userId), diff --git a/app/src/main/java/com/x8bit/bitwarden/data/vault/datasource/sdk/VaultSdkSource.kt b/app/src/main/java/com/x8bit/bitwarden/data/vault/datasource/sdk/VaultSdkSource.kt index ffbda5a59c..be3965df25 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/vault/datasource/sdk/VaultSdkSource.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/vault/datasource/sdk/VaultSdkSource.kt @@ -22,84 +22,171 @@ import com.x8bit.bitwarden.data.vault.datasource.sdk.model.InitializeCryptoResul interface VaultSdkSource { /** - * Attempts to initialize cryptography functionality for an individual user for the - * Bitwarden SDK with a given [InitUserCryptoRequest]. + * Clears any cryptography-related functionality for the given [userId], effectively locking + * the associated vault. */ - suspend fun initializeCrypto(request: InitUserCryptoRequest): Result + fun clearCrypto(userId: String) + + /** + * Attempts to initialize cryptography functionality for an individual user with the given + * [userId] for the Bitwarden SDK with a given [InitUserCryptoRequest]. + */ + suspend fun initializeCrypto( + userId: String, + request: InitUserCryptoRequest, + ): Result /** * Attempts to initialize cryptography functionality for organization data associated with - * the current user for the Bitwarden SDK with a given [InitOrgCryptoRequest]. + * the user with the given [userId] for the Bitwarden SDK with a given [InitOrgCryptoRequest]. * - * This should only be called after a successful call to [initializeCrypto]. + * This should only be called after a successful call to [initializeCrypto] for the associated + * user. */ suspend fun initializeOrganizationCrypto( + userId: String, request: InitOrgCryptoRequest, ): Result /** - * Encrypts a [CipherView] returning a [Cipher] wrapped in a [Result]. + * Encrypts a [CipherView] for the user with the given [userId], returning a [Cipher] wrapped + * in a [Result]. + * + * This should only be called after a successful call to [initializeCrypto] for the associated + * user. */ - suspend fun encryptCipher(cipherView: CipherView): Result + suspend fun encryptCipher( + userId: String, + cipherView: CipherView, + ): Result /** - * Decrypts a [Cipher] returning a [CipherView] wrapped in a [Result]. + * Decrypts a [Cipher] for the user with the given [userId], returning a [CipherView] wrapped + * in a [Result]. + * + * This should only be called after a successful call to [initializeCrypto] for the associated + * user. */ - suspend fun decryptCipher(cipher: Cipher): Result + suspend fun decryptCipher( + userId: String, + cipher: Cipher, + ): Result /** - * Decrypts a list of [Cipher]s returning a list of [CipherListView] wrapped in a [Result]. + * Decrypts a list of [Cipher]s for the user with the given [userId], returning a list of + * [CipherListView] wrapped in a [Result]. + * + * This should only be called after a successful call to [initializeCrypto] for the associated + * user. */ - suspend fun decryptCipherListCollection(cipherList: List): Result> + suspend fun decryptCipherListCollection( + userId: String, + cipherList: List, + ): Result> /** - * Decrypts a list of [Cipher]s returning a list of [CipherView] wrapped in a [Result]. + * Decrypts a list of [Cipher]s for the user with the given [userId], returning a list of + * [CipherView] wrapped in a [Result]. + * + * This should only be called after a successful call to [initializeCrypto] for the associated + * user. */ - suspend fun decryptCipherList(cipherList: List): Result> + suspend fun decryptCipherList( + userId: String, + cipherList: List, + ): Result> /** - * Decrypts a [Collection] returning a [CollectionView] wrapped in a [Result]. + * Decrypts a [Collection] for the user with the given [userId], returning a [CollectionView] + * wrapped in a [Result]. + * + * This should only be called after a successful call to [initializeCrypto] for the associated + * user. */ - suspend fun decryptCollection(collection: Collection): Result + suspend fun decryptCollection( + userId: String, + collection: Collection, + ): Result /** - * Decrypts a list of [Collection]s returning a list of [CollectionView] wrapped in a [Result]. + * Decrypts a list of [Collection]s for the user with the given [userId], returning a list of + * [CollectionView] wrapped in a [Result]. + * + * This should only be called after a successful call to [initializeCrypto] for the associated + * user. */ suspend fun decryptCollectionList( + userId: String, collectionList: List, ): Result> /** - * Decrypts a [Send] returning a [SendView] wrapped in a [Result]. + * Decrypts a [Send] for the user with the given [userId], returning a [SendView] wrapped in a + * [Result]. + * + * This should only be called after a successful call to [initializeCrypto] for the associated + * user. */ - suspend fun decryptSend(send: Send): Result + suspend fun decryptSend( + userId: String, + send: Send, + ): Result /** - * Decrypts a list of [Send]s returning a list of [SendView] wrapped in a [Result]. + * Decrypts a list of [Send]s for the user with the given [userId], returning a list of + * [SendView] wrapped in a [Result]. + * + * This should only be called after a successful call to [initializeCrypto] for the associated + * user. */ - suspend fun decryptSendList(sendList: List): Result> + suspend fun decryptSendList( + userId: String, + sendList: List, + ): Result> /** - * Decrypts a [Folder] returning a [FolderView] wrapped in a [Result]. + * Decrypts a [Folder] for the user with the given [userId], returning a [FolderView] wrapped + * in a [Result]. + * + * This should only be called after a successful call to [initializeCrypto] for the associated + * user. */ - suspend fun decryptFolder(folder: Folder): Result + suspend fun decryptFolder( + userId: String, + folder: Folder, + ): Result /** - * Decrypts a list of [Folder]s returning a list of [FolderView] wrapped in a [Result]. + * Decrypts a list of [Folder]s for the user with the given [userId], returning a list of + * [FolderView] wrapped in a [Result]. + * + * This should only be called after a successful call to [initializeCrypto] for the associated + * user. */ - suspend fun decryptFolderList(folderList: List): Result> + suspend fun decryptFolderList( + userId: String, + folderList: List, + ): Result> /** - * Encrypts a given password history item. + * Encrypts a given password history item for the user with the given [userId]. + * + * This should only be called after a successful call to [initializeCrypto] for the associated + * user. */ suspend fun encryptPasswordHistory( + userId: String, passwordHistory: PasswordHistoryView, ): Result /** - * Decrypts a list of password history items. + * Decrypts a list of password history items for the user with the given [userId]. + * + * This should only be called after a successful call to [initializeCrypto] for the associated + * user. */ suspend fun decryptPasswordHistoryList( + userId: String, passwordHistoryList: List, ): Result> } diff --git a/app/src/main/java/com/x8bit/bitwarden/data/vault/datasource/sdk/VaultSdkSourceImpl.kt b/app/src/main/java/com/x8bit/bitwarden/data/vault/datasource/sdk/VaultSdkSourceImpl.kt index bec761541b..a5034cf3e3 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/vault/datasource/sdk/VaultSdkSourceImpl.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/vault/datasource/sdk/VaultSdkSourceImpl.kt @@ -14,9 +14,9 @@ import com.bitwarden.core.PasswordHistoryView import com.bitwarden.core.Send import com.bitwarden.core.SendView import com.bitwarden.sdk.BitwardenException -import com.bitwarden.sdk.ClientCrypto -import com.bitwarden.sdk.ClientPasswordHistory +import com.bitwarden.sdk.Client import com.bitwarden.sdk.ClientVault +import com.x8bit.bitwarden.data.platform.manager.SdkClientManager import com.x8bit.bitwarden.data.vault.datasource.sdk.model.InitializeCryptoResult /** @@ -25,16 +25,21 @@ import com.x8bit.bitwarden.data.vault.datasource.sdk.model.InitializeCryptoResul */ @Suppress("TooManyFunctions") class VaultSdkSourceImpl( - private val clientVault: ClientVault, - private val clientCrypto: ClientCrypto, - private val clientPasswordHistory: ClientPasswordHistory, + private val sdkClientManager: SdkClientManager, ) : VaultSdkSource { + override fun clearCrypto(userId: String) { + sdkClientManager.destroyClient(userId = userId) + } + override suspend fun initializeCrypto( + userId: String, request: InitUserCryptoRequest, ): Result = runCatching { try { - clientCrypto.initializeUserCrypto(req = request) + getClient(userId = userId) + .crypto() + .initializeUserCrypto(req = request) InitializeCryptoResult.Success } catch (exception: BitwardenException) { // The only truly expected error from the SDK is an incorrect key/password. @@ -43,11 +48,14 @@ class VaultSdkSourceImpl( } override suspend fun initializeOrganizationCrypto( + userId: String, request: InitOrgCryptoRequest, ): Result = runCatching { try { - clientCrypto.initializeOrgCrypto(req = request) + getClient(userId = userId) + .crypto() + .initializeOrgCrypto(req = request) InitializeCryptoResult.Success } catch (exception: BitwardenException) { // The only truly expected error from the SDK is for incorrect keys. @@ -55,53 +63,140 @@ class VaultSdkSourceImpl( } } - override suspend fun encryptCipher(cipherView: CipherView): Result = - runCatching { clientVault.ciphers().encrypt(cipherView) } + override suspend fun encryptCipher( + userId: String, + cipherView: CipherView, + ): Result = + runCatching { + getClient(userId = userId) + .vault() + .ciphers() + .encrypt(cipherView) + } - override suspend fun decryptCipher(cipher: Cipher): Result = - runCatching { clientVault.ciphers().decrypt(cipher) } + override suspend fun decryptCipher( + userId: String, + cipher: Cipher, + ): Result = + runCatching { + getClient(userId = userId) + .vault() + .ciphers() + .decrypt(cipher) + } override suspend fun decryptCipherListCollection( + userId: String, cipherList: List, ): Result> = - runCatching { clientVault.ciphers().decryptList(cipherList) } - - override suspend fun decryptCipherList(cipherList: List): Result> = - runCatching { cipherList.map { clientVault.ciphers().decrypt(it) } } - - override suspend fun decryptCollection(collection: Collection): Result = runCatching { - clientVault.collections().decrypt(collection) + getClient(userId = userId) + .vault().ciphers() + .decryptList(cipherList) + } + + override suspend fun decryptCipherList( + userId: String, + cipherList: List, + ): Result> = + runCatching { + cipherList.map { + getClient(userId = userId) + .vault() + .ciphers() + .decrypt(it) + } + } + + override suspend fun decryptCollection( + userId: String, + collection: Collection, + ): Result = + runCatching { + getClient(userId = userId) + .vault() + .collections() + .decrypt(collection) } override suspend fun decryptCollectionList( + userId: String, collectionList: List, ): Result> = runCatching { - clientVault.collections().decryptList(collectionList) + getClient(userId = userId) + .vault() + .collections() + .decryptList(collectionList) } - override suspend fun decryptSend(send: Send): Result = - runCatching { clientVault.sends().decrypt(send) } + override suspend fun decryptSend( + userId: String, + send: Send, + ): Result = + runCatching { + getClient(userId = userId) + .vault() + .sends() + .decrypt(send) + } - override suspend fun decryptSendList(sendList: List): Result> = - runCatching { sendList.map { clientVault.sends().decrypt(it) } } + override suspend fun decryptSendList( + userId: String, + sendList: List, + ): Result> = + runCatching { + sendList.map { + getClient(userId = userId) + .vault() + .sends() + .decrypt(it) + } + } - override suspend fun decryptFolder(folder: Folder): Result = - runCatching { clientVault.folders().decrypt(folder) } + override suspend fun decryptFolder( + userId: String, + folder: Folder, + ): Result = + runCatching { + getClient(userId = userId) + .vault() + .folders() + .decrypt(folder) + } - override suspend fun decryptFolderList(folderList: List): Result> = - runCatching { clientVault.folders().decryptList(folderList) } + override suspend fun decryptFolderList( + userId: String, + folderList: List, + ): Result> = + runCatching { + getClient(userId = userId) + .vault() + .folders() + .decryptList(folderList) + } override suspend fun encryptPasswordHistory( + userId: String, passwordHistory: PasswordHistoryView, ): Result = runCatching { - clientPasswordHistory.encrypt(passwordHistory) + getClient(userId = userId) + .vault() + .passwordHistory() + .encrypt(passwordHistory) } override suspend fun decryptPasswordHistoryList( + userId: String, passwordHistoryList: List, ): Result> = runCatching { - clientPasswordHistory.decryptList(passwordHistoryList) + getClient(userId = userId) + .vault() + .passwordHistory() + .decryptList(passwordHistoryList) } + + private fun getClient( + userId: String, + ): Client = sdkClientManager.getOrCreateClient(userId = userId) } diff --git a/app/src/main/java/com/x8bit/bitwarden/data/vault/datasource/sdk/di/VaultSdkModule.kt b/app/src/main/java/com/x8bit/bitwarden/data/vault/datasource/sdk/di/VaultSdkModule.kt index 1a1e211fab..46dc7eeac6 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/vault/datasource/sdk/di/VaultSdkModule.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/vault/datasource/sdk/di/VaultSdkModule.kt @@ -1,6 +1,6 @@ package com.x8bit.bitwarden.data.vault.datasource.sdk.di -import com.bitwarden.sdk.Client +import com.x8bit.bitwarden.data.platform.manager.SdkClientManager import com.x8bit.bitwarden.data.vault.datasource.sdk.VaultSdkSource import com.x8bit.bitwarden.data.vault.datasource.sdk.VaultSdkSourceImpl import dagger.Module @@ -19,11 +19,9 @@ object VaultSdkModule { @Provides @Singleton fun providesVaultSdkSource( - client: Client, + sdkClientManager: SdkClientManager, ): VaultSdkSource = VaultSdkSourceImpl( - clientVault = client.vault(), - clientCrypto = client.crypto(), - clientPasswordHistory = client.vault().passwordHistory(), + sdkClientManager = sdkClientManager, ) } diff --git a/app/src/main/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryImpl.kt b/app/src/main/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryImpl.kt index 19289546c0..a3d5ce23e0 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryImpl.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryImpl.kt @@ -320,6 +320,7 @@ class VaultRepositoryImpl( emit( vaultSdkSource .initializeCrypto( + userId = userId, request = InitUserCryptoRequest( kdfParams = kdf, email = email, @@ -336,6 +337,7 @@ class VaultRepositoryImpl( result is InitializeCryptoResult.Success ) { vaultSdkSource.initializeOrganizationCrypto( + userId = userId, request = InitOrgCryptoRequest( organizationKeys = organizationKeys, ), @@ -363,7 +365,10 @@ class VaultRepositoryImpl( override suspend fun createCipher(cipherView: CipherView): CreateCipherResult = vaultSdkSource - .encryptCipher(cipherView = cipherView) + .encryptCipher( + userId = requireNotNull(activeUserId), + cipherView = cipherView, + ) .flatMap { cipher -> ciphersService .createCipher( @@ -385,7 +390,10 @@ class VaultRepositoryImpl( cipherView: CipherView, ): UpdateCipherResult = vaultSdkSource - .encryptCipher(cipherView = cipherView) + .encryptCipher( + userId = requireNotNull(activeUserId), + cipherView = cipherView, + ) .flatMap { cipher -> ciphersService.updateCipher( cipherId = cipherId, @@ -421,6 +429,7 @@ class VaultRepositoryImpl( // TODO: This is temporary. Eventually this needs to be based on the presence of various // user keys but this will likely require SDK updates to support this (BIT-1190). private fun setVaultToLocked(userId: String) { + vaultSdkSource.clearCrypto(userId = userId) mutableVaultStateStateFlow.update { it.copy( unlockedVaultUserIds = it.unlockedVaultUserIds - userId, @@ -473,6 +482,7 @@ class VaultRepositoryImpl( // the return type here. vaultSdkSource .initializeOrganizationCrypto( + userId = syncResponse.profile.id, request = InitOrgCryptoRequest( organizationKeys = organizationKeys, ), @@ -484,10 +494,15 @@ class VaultRepositoryImpl( ): Flow>> = vaultDiskSource .getCiphers(userId = userId) - .onStart { mutableCiphersStateFlow.value = DataState.Loading } + .onStart { + mutableCiphersStateFlow.value = DataState.Loading + } .map { vaultSdkSource - .decryptCipherList(cipherList = it.toEncryptedSdkCipherList()) + .decryptCipherList( + userId = userId, + cipherList = it.toEncryptedSdkCipherList(), + ) .fold( onSuccess = { ciphers -> DataState.Loaded(ciphers) }, onFailure = { throwable -> DataState.Error(throwable) }, @@ -503,7 +518,10 @@ class VaultRepositoryImpl( .onStart { mutableFoldersStateFlow.value = DataState.Loading } .map { vaultSdkSource - .decryptFolderList(folderList = it.toEncryptedSdkFolderList()) + .decryptFolderList( + userId = userId, + folderList = it.toEncryptedSdkFolderList(), + ) .fold( onSuccess = { folders -> DataState.Loaded(folders) }, onFailure = { throwable -> DataState.Error(throwable) }, @@ -519,7 +537,10 @@ class VaultRepositoryImpl( .onStart { mutableCollectionsStateFlow.value = DataState.Loading } .map { vaultSdkSource - .decryptCollectionList(collectionList = it.toEncryptedSdkCollectionList()) + .decryptCollectionList( + userId = userId, + collectionList = it.toEncryptedSdkCollectionList(), + ) .fold( onSuccess = { collections -> DataState.Loaded(collections) }, onFailure = { throwable -> DataState.Error(throwable) }, @@ -535,7 +556,10 @@ class VaultRepositoryImpl( .onStart { mutableSendDataStateFlow.value = DataState.Loading } .map { vaultSdkSource - .decryptSendList(sendList = it.toEncryptedSdkSendList()) + .decryptSendList( + userId = userId, + sendList = it.toEncryptedSdkSendList(), + ) .fold( onSuccess = { sends -> DataState.Loaded(SendData(sends)) }, onFailure = { throwable -> DataState.Error(throwable) }, diff --git a/app/src/test/java/com/x8bit/bitwarden/data/platform/manager/SdkClientManagerTest.kt b/app/src/test/java/com/x8bit/bitwarden/data/platform/manager/SdkClientManagerTest.kt new file mode 100644 index 0000000000..ccd4c3c9e1 --- /dev/null +++ b/app/src/test/java/com/x8bit/bitwarden/data/platform/manager/SdkClientManagerTest.kt @@ -0,0 +1,44 @@ +package com.x8bit.bitwarden.data.platform.manager + +import io.mockk.mockk +import io.mockk.verify +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertNotEquals +import org.junit.jupiter.api.Test + +class SdkClientManagerTest { + + private val sdkClientManager = SdkClientManagerImpl( + clientProvider = { mockk(relaxed = true) }, + ) + + @Suppress("MaxLineLength") + @Test + fun `getOrCreateClient should create a new client for each userId and return a cached client for subsequent calls`() { + val userId = "userId" + val firstClient = sdkClientManager.getOrCreateClient(userId = userId) + + // Additional calls for the same userId return the same value + val secondClient = sdkClientManager.getOrCreateClient(userId = userId) + assertEquals(firstClient, secondClient) + + // Additional calls for different userIds should return different values + val otherUserId = "otherUserId" + val thirdClient = sdkClientManager.getOrCreateClient(userId = otherUserId) + assertNotEquals(firstClient, thirdClient) + } + + @Test + fun `destroyClient should call close on the Client and remove it from the cache`() { + val userId = "userId" + val firstClient = sdkClientManager.getOrCreateClient(userId = userId) + + sdkClientManager.destroyClient(userId = userId) + + verify { firstClient.close() } + + // New calls for the same userId should return different values + val secondClient = sdkClientManager.getOrCreateClient(userId = userId) + assertNotEquals(firstClient, secondClient) + } +} diff --git a/app/src/test/java/com/x8bit/bitwarden/data/tools/generator/repository/GeneratorRepositoryTest.kt b/app/src/test/java/com/x8bit/bitwarden/data/tools/generator/repository/GeneratorRepositoryTest.kt index 0438ab0889..164d61129c 100644 --- a/app/src/test/java/com/x8bit/bitwarden/data/tools/generator/repository/GeneratorRepositoryTest.kt +++ b/app/src/test/java/com/x8bit/bitwarden/data/tools/generator/repository/GeneratorRepositoryTest.kt @@ -105,7 +105,7 @@ class GeneratorRepositoryTest { coEvery { generatorSdkSource.generatePassword(request) } returns Result.success(generatedPassword) - coEvery { vaultSdkSource.encryptPasswordHistory(any()) } returns + coEvery { vaultSdkSource.encryptPasswordHistory(any(), any()) } returns Result.success(encryptedPasswordHistory) coEvery { passwordHistoryDiskSource.insertPasswordHistory(any()) } just runs @@ -156,7 +156,7 @@ class GeneratorRepositoryTest { coEvery { generatorSdkSource.generatePassword(request) } returns Result.success(generatedPassword) - coEvery { vaultSdkSource.encryptPasswordHistory(any()) } returns + coEvery { vaultSdkSource.encryptPasswordHistory(any(), any()) } returns Result.success(encryptedPasswordHistory) coEvery { passwordHistoryDiskSource.insertPasswordHistory(any()) } just runs @@ -220,7 +220,7 @@ class GeneratorRepositoryTest { coEvery { generatorSdkSource.generatePassphrase(request) } returns Result.success(generatedPassphrase) - coEvery { vaultSdkSource.encryptPasswordHistory(any()) } returns + coEvery { vaultSdkSource.encryptPasswordHistory(any(), any()) } returns Result.success(encryptedPasswordHistory) coEvery { passwordHistoryDiskSource.insertPasswordHistory(any()) } just runs @@ -232,7 +232,7 @@ class GeneratorRepositoryTest { (result as GeneratedPassphraseResult.Success).generatedString, ) coVerify { generatorSdkSource.generatePassphrase(request) } - coVerify { vaultSdkSource.encryptPasswordHistory(any()) } + coVerify { vaultSdkSource.encryptPasswordHistory(any(), any()) } coVerify { passwordHistoryDiskSource.insertPasswordHistory( encryptedPasswordHistory.toPasswordHistoryEntity(userId), @@ -406,7 +406,12 @@ class GeneratorRepositoryTest { coEvery { authDiskSource.userState?.activeUserId } returns testUserId - coEvery { vaultSdkSource.encryptPasswordHistory(passwordHistoryView) } returns + coEvery { + vaultSdkSource.encryptPasswordHistory( + userId = testUserId, + passwordHistory = passwordHistoryView, + ) + } returns Result.success(encryptedPasswordHistory) coEvery { @@ -415,7 +420,12 @@ class GeneratorRepositoryTest { repository.storePasswordHistory(passwordHistoryView) - coVerify { vaultSdkSource.encryptPasswordHistory(passwordHistoryView) } + coVerify { + vaultSdkSource.encryptPasswordHistory( + userId = testUserId, + passwordHistory = passwordHistoryView, + ) + } coVerify { passwordHistoryDiskSource.insertPasswordHistory(expectedPasswordHistoryEntity) } } @@ -451,7 +461,7 @@ class GeneratorRepositoryTest { } returns flowOf(encryptedPasswordHistoryEntities) coEvery { - vaultSdkSource.decryptPasswordHistoryList(any()) + vaultSdkSource.decryptPasswordHistoryList(any(), any()) } returns Result.success(decryptedPasswordHistoryList) val historyFlow = repository.passwordHistoryStateFlow @@ -467,7 +477,7 @@ class GeneratorRepositoryTest { passwordHistoryDiskSource.getPasswordHistoriesForUser(USER_STATE.activeUserId) } - coVerify { vaultSdkSource.decryptPasswordHistoryList(any()) } + coVerify { vaultSdkSource.decryptPasswordHistoryList(any(), any()) } } @Test diff --git a/app/src/test/java/com/x8bit/bitwarden/data/vault/datasource/sdk/VaultSdkSourceTest.kt b/app/src/test/java/com/x8bit/bitwarden/data/vault/datasource/sdk/VaultSdkSourceTest.kt index 977503463b..845a16237e 100644 --- a/app/src/test/java/com/x8bit/bitwarden/data/vault/datasource/sdk/VaultSdkSourceTest.kt +++ b/app/src/test/java/com/x8bit/bitwarden/data/vault/datasource/sdk/VaultSdkSourceTest.kt @@ -14,32 +14,56 @@ import com.bitwarden.core.PasswordHistoryView import com.bitwarden.core.Send import com.bitwarden.core.SendView import com.bitwarden.sdk.BitwardenException +import com.bitwarden.sdk.Client import com.bitwarden.sdk.ClientCrypto import com.bitwarden.sdk.ClientPasswordHistory import com.bitwarden.sdk.ClientVault +import com.x8bit.bitwarden.data.platform.manager.SdkClientManager import com.x8bit.bitwarden.data.platform.util.asFailure import com.x8bit.bitwarden.data.platform.util.asSuccess import com.x8bit.bitwarden.data.vault.datasource.sdk.model.InitializeCryptoResult import io.mockk.coEvery import io.mockk.coVerify +import io.mockk.every +import io.mockk.just import io.mockk.mockk +import io.mockk.runs +import io.mockk.verify import kotlinx.coroutines.runBlocking import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Test class VaultSdkSourceTest { - private val clientVault = mockk() private val clientCrypto = mockk() private val clientPasswordHistory = mockk() + private val clientVault = mockk() { + every { passwordHistory() } returns clientPasswordHistory + } + private val client = mockk() { + every { vault() } returns clientVault + every { crypto() } returns clientCrypto + } + private val sdkClientManager = mockk { + every { getOrCreateClient(any()) } returns client + every { destroyClient(any()) } just runs + } private val vaultSdkSource: VaultSdkSource = VaultSdkSourceImpl( - clientVault = clientVault, - clientCrypto = clientCrypto, - clientPasswordHistory = clientPasswordHistory, + sdkClientManager = sdkClientManager, ) + @Test + fun `clearCrypto should destroy the associated client via the SDK Manager`() { + val userId = "userId" + + vaultSdkSource.clearCrypto(userId = userId) + + verify { sdkClientManager.destroyClient(userId = userId) } + } + @Test fun `initializeUserCrypto with sdk success should return InitializeCryptoResult Success`() = runBlocking { + val userId = "userId" val mockInitCryptoRequest = mockk() coEvery { clientCrypto.initializeUserCrypto( @@ -47,6 +71,7 @@ class VaultSdkSourceTest { ) } returns Unit val result = vaultSdkSource.initializeCrypto( + userId = userId, request = mockInitCryptoRequest, ) assertEquals( @@ -58,10 +83,12 @@ class VaultSdkSourceTest { req = mockInitCryptoRequest, ) } + verify { sdkClientManager.getOrCreateClient(userId = userId) } } @Test fun `initializeUserCrypto with sdk failure should return failure`() = runBlocking { + val userId = "userId" val mockInitCryptoRequest = mockk() val expectedException = IllegalStateException("mock") coEvery { @@ -70,6 +97,7 @@ class VaultSdkSourceTest { ) } throws expectedException val result = vaultSdkSource.initializeCrypto( + userId = userId, request = mockInitCryptoRequest, ) assertEquals( @@ -81,11 +109,13 @@ class VaultSdkSourceTest { req = mockInitCryptoRequest, ) } + verify { sdkClientManager.getOrCreateClient(userId = userId) } } @Test fun `initializeUserCrypto with BitwardenException failure should return AuthenticationError`() = runBlocking { + val userId = "userId" val mockInitCryptoRequest = mockk() val expectedException = BitwardenException.E(message = "") coEvery { @@ -94,6 +124,7 @@ class VaultSdkSourceTest { ) } throws expectedException val result = vaultSdkSource.initializeCrypto( + userId = userId, request = mockInitCryptoRequest, ) assertEquals( @@ -105,11 +136,13 @@ class VaultSdkSourceTest { req = mockInitCryptoRequest, ) } + verify { sdkClientManager.getOrCreateClient(userId = userId) } } @Test fun `initializeOrgCrypto with sdk success should return InitializeCryptoResult Success`() = runBlocking { + val userId = "userId" val mockInitCryptoRequest = mockk() coEvery { clientCrypto.initializeOrgCrypto( @@ -117,6 +150,7 @@ class VaultSdkSourceTest { ) } returns Unit val result = vaultSdkSource.initializeOrganizationCrypto( + userId = userId, request = mockInitCryptoRequest, ) assertEquals( @@ -128,10 +162,12 @@ class VaultSdkSourceTest { req = mockInitCryptoRequest, ) } + verify { sdkClientManager.getOrCreateClient(userId = userId) } } @Test fun `initializeOrgCrypto with sdk failure should return failure`() = runBlocking { + val userId = "userId" val mockInitCryptoRequest = mockk() val expectedException = IllegalStateException("mock") coEvery { @@ -140,6 +176,7 @@ class VaultSdkSourceTest { ) } throws expectedException val result = vaultSdkSource.initializeOrganizationCrypto( + userId = userId, request = mockInitCryptoRequest, ) assertEquals( @@ -151,11 +188,13 @@ class VaultSdkSourceTest { req = mockInitCryptoRequest, ) } + verify { sdkClientManager.getOrCreateClient(userId = userId) } } @Test fun `initializeOrgCrypto with BitwardenException failure should return AuthenticationError`() = runBlocking { + val userId = "userId" val mockInitCryptoRequest = mockk() val expectedException = BitwardenException.E(message = "") coEvery { @@ -164,6 +203,7 @@ class VaultSdkSourceTest { ) } throws expectedException val result = vaultSdkSource.initializeOrganizationCrypto( + userId = userId, request = mockInitCryptoRequest, ) assertEquals( @@ -175,10 +215,12 @@ class VaultSdkSourceTest { req = mockInitCryptoRequest, ) } + verify { sdkClientManager.getOrCreateClient(userId = userId) } } @Test fun `decryptCipher should call SDK and return a Result with correct data`() = runBlocking { + val userId = "userId" val mockCipher = mockk() val expectedResult = mockk() coEvery { @@ -187,6 +229,7 @@ class VaultSdkSourceTest { ) } returns expectedResult val result = vaultSdkSource.encryptCipher( + userId = userId, cipherView = mockCipher, ) assertEquals( @@ -198,10 +241,12 @@ class VaultSdkSourceTest { cipherView = mockCipher, ) } + verify { sdkClientManager.getOrCreateClient(userId = userId) } } @Test fun `Cipher decrypt should call SDK and return a Result with correct data`() = runBlocking { + val userId = "userId" val mockCipher = mockk() val expectedResult = mockk() coEvery { @@ -210,6 +255,7 @@ class VaultSdkSourceTest { ) } returns expectedResult val result = vaultSdkSource.decryptCipher( + userId = userId, cipher = mockCipher, ) assertEquals( @@ -221,11 +267,13 @@ class VaultSdkSourceTest { cipher = mockCipher, ) } + verify { sdkClientManager.getOrCreateClient(userId = userId) } } @Test fun `Cipher decryptListCollection should call SDK and return a Result with correct data`() = runBlocking { + val userId = "userId" val mockCiphers = mockk>() val expectedResult = mockk>() coEvery { @@ -234,6 +282,7 @@ class VaultSdkSourceTest { ) } returns expectedResult val result = vaultSdkSource.decryptCipherListCollection( + userId = userId, cipherList = mockCiphers, ) assertEquals( @@ -245,10 +294,12 @@ class VaultSdkSourceTest { ciphers = mockCiphers, ) } + verify { sdkClientManager.getOrCreateClient(userId = userId) } } @Test fun `Cipher decryptList should call SDK and return a Result with correct data`() = runBlocking { + val userId = "userId" val mockCiphers = mockk() val expectedResult = mockk() coEvery { @@ -257,6 +308,7 @@ class VaultSdkSourceTest { ) } returns expectedResult val result = vaultSdkSource.decryptCipherList( + userId = userId, cipherList = listOf(mockCiphers), ) assertEquals( @@ -268,11 +320,13 @@ class VaultSdkSourceTest { cipher = mockCiphers, ) } + verify { sdkClientManager.getOrCreateClient(userId = userId) } } @Test fun `decryptCollection should call SDK and return correct data wrapped in a Result`() = runBlocking { + val userId = "userId" val mockCollection = mockk() val expectedResult = mockk() coEvery { @@ -281,6 +335,7 @@ class VaultSdkSourceTest { ) } returns expectedResult val result = vaultSdkSource.decryptCollection( + userId = userId, collection = mockCollection, ) assertEquals( @@ -291,11 +346,13 @@ class VaultSdkSourceTest { collection = mockCollection, ) } + verify { sdkClientManager.getOrCreateClient(userId = userId) } } @Test fun `decryptCollectionList should call SDK and return correct data wrapped in a Result`() = runBlocking { + val userId = "userId" val mockCollectionsList = mockk>() val expectedResult = mockk>() coEvery { @@ -304,6 +361,7 @@ class VaultSdkSourceTest { ) } returns expectedResult val result = vaultSdkSource.decryptCollectionList( + userId = userId, collectionList = mockCollectionsList, ) assertEquals( @@ -315,11 +373,13 @@ class VaultSdkSourceTest { collections = mockCollectionsList, ) } + verify { sdkClientManager.getOrCreateClient(userId = userId) } } @Test fun `decryptSendList should call SDK and return correct data wrapped in a Result`() = runBlocking { + val userId = "userId" val mockSend = mockk() val expectedResult = mockk() coEvery { @@ -328,6 +388,7 @@ class VaultSdkSourceTest { ) } returns expectedResult val result = vaultSdkSource.decryptSendList( + userId = userId, sendList = listOf(mockSend), ) assertEquals( @@ -339,11 +400,13 @@ class VaultSdkSourceTest { send = mockSend, ) } + verify { sdkClientManager.getOrCreateClient(userId = userId) } } @Test fun `decryptSend should call SDK and return correct data wrapped in a Result`() = runBlocking { + val userId = "userId" val mockSend = mockk() val expectedResult = mockk() coEvery { @@ -352,6 +415,7 @@ class VaultSdkSourceTest { ) } returns expectedResult val result = vaultSdkSource.decryptSend( + userId = userId, send = mockSend, ) assertEquals( @@ -362,10 +426,12 @@ class VaultSdkSourceTest { send = mockSend, ) } + verify { sdkClientManager.getOrCreateClient(userId = userId) } } @Test fun `Folder decrypt should call SDK and return a Result with correct data`() = runBlocking { + val userId = "userId" val mockFolder = mockk() val expectedResult = mockk() coEvery { @@ -374,6 +440,7 @@ class VaultSdkSourceTest { ) } returns expectedResult val result = vaultSdkSource.decryptFolder( + userId = userId, folder = mockFolder, ) assertEquals( @@ -385,10 +452,12 @@ class VaultSdkSourceTest { folder = mockFolder, ) } + verify { sdkClientManager.getOrCreateClient(userId = userId) } } @Test fun `Folder decryptList should call SDK and return a Result with correct data`() = runBlocking { + val userId = "userId" val mockFolders = mockk>() val expectedResult = mockk>() coEvery { @@ -397,6 +466,7 @@ class VaultSdkSourceTest { ) } returns expectedResult val result = vaultSdkSource.decryptFolderList( + userId = userId, folderList = mockFolders, ) assertEquals( @@ -408,11 +478,13 @@ class VaultSdkSourceTest { folders = mockFolders, ) } + verify { sdkClientManager.getOrCreateClient(userId = userId) } } @Test fun `encryptPasswordHistory should call SDK and return a Result with correct data`() = runBlocking { + val userId = "userId" val mockPasswordHistoryView = mockk() val expectedResult = mockk() coEvery { @@ -422,6 +494,7 @@ class VaultSdkSourceTest { } returns expectedResult val result = vaultSdkSource.encryptPasswordHistory( + userId = userId, passwordHistory = mockPasswordHistoryView, ) @@ -431,11 +504,13 @@ class VaultSdkSourceTest { passwordHistory = mockPasswordHistoryView, ) } + verify { sdkClientManager.getOrCreateClient(userId = userId) } } @Test fun `decryptPasswordHistoryList should call SDK and return a Result with correct data`() = runBlocking { + val userId = "userId" val mockPasswordHistoryList = mockk>() val expectedResult = mockk>() coEvery { @@ -445,6 +520,7 @@ class VaultSdkSourceTest { } returns expectedResult val result = vaultSdkSource.decryptPasswordHistoryList( + userId = userId, passwordHistoryList = mockPasswordHistoryList, ) @@ -454,5 +530,6 @@ class VaultSdkSourceTest { list = mockPasswordHistoryList, ) } + verify { sdkClientManager.getOrCreateClient(userId = userId) } } } diff --git a/app/src/test/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryTest.kt b/app/src/test/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryTest.kt index 76a1d37469..a63a50d43a 100644 --- a/app/src/test/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryTest.kt +++ b/app/src/test/java/com/x8bit/bitwarden/data/vault/repository/VaultRepositoryTest.kt @@ -59,6 +59,7 @@ import io.mockk.every import io.mockk.just import io.mockk.mockk import io.mockk.runs +import io.mockk.verify import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.cancel @@ -77,7 +78,9 @@ class VaultRepositoryTest { private val syncService: SyncService = mockk() private val ciphersService: CiphersService = mockk() private val vaultDiskSource: VaultDiskSource = mockk() - private val vaultSdkSource: VaultSdkSource = mockk() + private val vaultSdkSource: VaultSdkSource = mockk { + every { clearCrypto(userId = any()) } just runs + } private val vaultRepository = VaultRepositoryImpl( syncService = syncService, ciphersService = ciphersService, @@ -91,6 +94,7 @@ class VaultRepositoryTest { fun `ciphersStateFlow should emit decrypted list of ciphers when decryptCipherList succeeds`() = runTest { fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" val mockCipherList = listOf(createMockCipher(number = 1)) val mockEncryptedCipherList = mockCipherList.toEncryptedSdkCipherList() val mockCipherViewList = listOf(createMockCipherView(number = 1)) @@ -100,7 +104,10 @@ class VaultRepositoryTest { vaultDiskSource.getCiphers(userId = MOCK_USER_STATE.activeUserId) } returns mutableCiphersStateFlow coEvery { - vaultSdkSource.decryptCipherList(mockEncryptedCipherList) + vaultSdkSource.decryptCipherList( + userId = userId, + cipherList = mockEncryptedCipherList, + ) } returns mockCipherViewList.asSuccess() vaultRepository @@ -115,6 +122,7 @@ class VaultRepositoryTest { @Test fun `ciphersStateFlow should emit an error when decryptCipherList fails`() = runTest { fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" val throwable = Throwable("Fail") val mockCipherList = listOf(createMockCipher(number = 1)) val mockEncryptedCipherList = mockCipherList.toEncryptedSdkCipherList() @@ -124,7 +132,10 @@ class VaultRepositoryTest { vaultDiskSource.getCiphers(userId = MOCK_USER_STATE.activeUserId) } returns mutableCiphersStateFlow coEvery { - vaultSdkSource.decryptCipherList(mockEncryptedCipherList) + vaultSdkSource.decryptCipherList( + userId = userId, + cipherList = mockEncryptedCipherList, + ) } returns throwable.asFailure() vaultRepository @@ -141,6 +152,7 @@ class VaultRepositoryTest { fun `collectionsStateFlow should emit decrypted list of collections when decryptCollectionList succeeds`() = runTest { fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" val mockCollectionList = listOf(createMockCollection(number = 1)) val mockEncryptedCollectionList = mockCollectionList.toEncryptedSdkCollectionList() val mockCollectionViewList = listOf(createMockCollectionView(number = 1)) @@ -150,7 +162,10 @@ class VaultRepositoryTest { vaultDiskSource.getCollections(userId = MOCK_USER_STATE.activeUserId) } returns mutableCollectionsStateFlow coEvery { - vaultSdkSource.decryptCollectionList(mockEncryptedCollectionList) + vaultSdkSource.decryptCollectionList( + userId = userId, + collectionList = mockEncryptedCollectionList, + ) } returns mockCollectionViewList.asSuccess() vaultRepository @@ -165,6 +180,7 @@ class VaultRepositoryTest { @Test fun `collectionsStateFlow should emit an error when decryptCollectionList fails`() = runTest { fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" val throwable = Throwable("Fail") val mockCollectionList = listOf(createMockCollection(number = 1)) val mockEncryptedCollectionList = mockCollectionList.toEncryptedSdkCollectionList() @@ -174,7 +190,10 @@ class VaultRepositoryTest { vaultDiskSource.getCollections(userId = MOCK_USER_STATE.activeUserId) } returns mutableCollectionStateFlow coEvery { - vaultSdkSource.decryptCollectionList(mockEncryptedCollectionList) + vaultSdkSource.decryptCollectionList( + userId = userId, + collectionList = mockEncryptedCollectionList, + ) } returns throwable.asFailure() vaultRepository @@ -191,6 +210,7 @@ class VaultRepositoryTest { fun `foldersStateFlow should emit decrypted list of folders when decryptFolderList succeeds`() = runTest { fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" val mockFolderList = listOf(createMockFolder(number = 1)) val mockEncryptedFolderList = mockFolderList.toEncryptedSdkFolderList() val mockFolderViewList = listOf(createMockFolderView(number = 1)) @@ -200,7 +220,10 @@ class VaultRepositoryTest { vaultDiskSource.getFolders(userId = MOCK_USER_STATE.activeUserId) } returns mutableFoldersStateFlow coEvery { - vaultSdkSource.decryptFolderList(mockEncryptedFolderList) + vaultSdkSource.decryptFolderList( + userId = userId, + folderList = mockEncryptedFolderList, + ) } returns mockFolderViewList.asSuccess() vaultRepository @@ -215,6 +238,7 @@ class VaultRepositoryTest { @Test fun `foldersStateFlow should emit an error when decryptFolderList fails`() = runTest { fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" val throwable = Throwable("Fail") val mockFolderList = listOf(createMockFolder(number = 1)) val mockEncryptedFolderList = mockFolderList.toEncryptedSdkFolderList() @@ -224,7 +248,10 @@ class VaultRepositoryTest { vaultDiskSource.getFolders(userId = MOCK_USER_STATE.activeUserId) } returns mutableFoldersStateFlow coEvery { - vaultSdkSource.decryptFolderList(mockEncryptedFolderList) + vaultSdkSource.decryptFolderList( + userId = userId, + folderList = mockEncryptedFolderList, + ) } returns throwable.asFailure() vaultRepository @@ -240,6 +267,7 @@ class VaultRepositoryTest { fun `sendDataStateFlow should emit decrypted list of sends when decryptSendsList succeeds`() = runTest { fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" val mockSendList = listOf(createMockSend(number = 1)) val mockEncryptedSendList = mockSendList.toEncryptedSdkSendList() val mockSendViewList = listOf(createMockSendView(number = 1)) @@ -249,7 +277,10 @@ class VaultRepositoryTest { vaultDiskSource.getSends(userId = MOCK_USER_STATE.activeUserId) } returns mutableSendsStateFlow coEvery { - vaultSdkSource.decryptSendList(mockEncryptedSendList) + vaultSdkSource.decryptSendList( + userId = userId, + sendList = mockEncryptedSendList, + ) } returns mockSendViewList.asSuccess() vaultRepository @@ -264,6 +295,7 @@ class VaultRepositoryTest { @Test fun `sendDataStateFlow should emit an error when decryptSendsList fails`() = runTest { fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" val throwable = Throwable("Fail") val mockSendList = listOf(createMockSend(number = 1)) val mockEncryptedSendList = mockSendList.toEncryptedSdkSendList() @@ -273,7 +305,10 @@ class VaultRepositoryTest { vaultDiskSource.getSends(userId = MOCK_USER_STATE.activeUserId) } returns mutableSendsStateFlow coEvery { - vaultSdkSource.decryptSendList(mockEncryptedSendList) + vaultSdkSource.decryptSendList( + userId = userId, + sendList = mockEncryptedSendList, + ) } returns throwable.asFailure() vaultRepository @@ -302,10 +337,12 @@ class VaultRepositoryTest { fun `sync with syncService Success should unlock the vault for orgs if necessary and update AuthDiskSource and VaultDiskSource`() = runTest { fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" val mockSyncResponse = createMockSyncResponse(number = 1) coEvery { syncService.sync() } returns mockSyncResponse.asSuccess() coEvery { vaultSdkSource.initializeOrganizationCrypto( + userId = userId, request = InitOrgCryptoRequest( organizationKeys = createMockOrganizationKeys(1), ), @@ -356,6 +393,7 @@ class VaultRepositoryTest { vault = mockSyncResponse, ) vaultSdkSource.initializeOrganizationCrypto( + userId = userId, request = InitOrgCryptoRequest( organizationKeys = createMockOrganizationKeys(1), ), @@ -366,6 +404,7 @@ class VaultRepositoryTest { @Test fun `sync with syncService Failure should update DataStateFlow with an Error`() = runTest { fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" val mockException = IllegalStateException("sad") coEvery { syncService.sync() } returns mockException.asFailure() @@ -392,6 +431,7 @@ class VaultRepositoryTest { @Test fun `sync with syncService Failure should update vaultDataStateFlow with an Error`() = runTest { fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" val mockException = IllegalStateException("sad") coEvery { syncService.sync() } returns mockException.asFailure() setupVaultDiskSourceFlows() @@ -408,6 +448,7 @@ class VaultRepositoryTest { @Test fun `sync with NoNetwork should update DataStateFlows to NoNetwork`() = runTest { fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" coEvery { syncService.sync() } returns UnknownHostException().asFailure() vaultRepository.sync() @@ -433,6 +474,7 @@ class VaultRepositoryTest { @Test fun `sync with NoNetwork should update vaultDataStateFlow to NoNetwork`() = runTest { fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" coEvery { syncService.sync() } returns UnknownHostException().asFailure() setupVaultDiskSourceFlows() @@ -450,11 +492,15 @@ class VaultRepositoryTest { fun `sync with NoNetwork data should update sendDataStateFlow to Pending and NoNetwork with data`() = runTest { fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" coEvery { syncService.sync() } returns UnknownHostException().asFailure() val sendsFlow = bufferedMutableSharedFlow>() setupVaultDiskSourceFlows(sendsFlow = sendsFlow) coEvery { - vaultSdkSource.decryptSendList(listOf(createMockSdkSend(1))) + vaultSdkSource.decryptSendList( + userId = userId, + sendList = listOf(createMockSdkSend(1)), + ) } returns listOf(createMockSendView(1)).asSuccess() vaultRepository @@ -501,6 +547,7 @@ class VaultRepositoryTest { ), vaultRepository.vaultStateFlow.value, ) + verify { vaultSdkSource.clearCrypto(userId = userId) } } @Test @@ -524,16 +571,19 @@ class VaultRepositoryTest { ), vaultRepository.vaultStateFlow.value, ) + verify { vaultSdkSource.clearCrypto(userId = userId) } } @Suppress("MaxLineLength") @Test fun `unlockVaultAndSyncForCurrentUser with unlockVault Success should sync and return Success`() = runTest { + val userId = "mockId-1" val mockSyncResponse = createMockSyncResponse(number = 1) coEvery { syncService.sync() } returns mockSyncResponse.asSuccess() coEvery { vaultSdkSource.initializeOrganizationCrypto( + userId = userId, request = InitOrgCryptoRequest( organizationKeys = createMockOrganizationKeys(1), ), @@ -546,7 +596,10 @@ class VaultRepositoryTest { ) } just runs coEvery { - vaultSdkSource.decryptSendList(listOf(createMockSdkSend(number = 1))) + vaultSdkSource.decryptSendList( + userId = userId, + sendList = listOf(createMockSdkSend(number = 1)), + ) } returns listOf(createMockSendView(number = 1)).asSuccess() fakeAuthDiskSource.storePrivateKey( userId = "mockId-1", @@ -563,6 +616,7 @@ class VaultRepositoryTest { fakeAuthDiskSource.userState = MOCK_USER_STATE coEvery { vaultSdkSource.initializeCrypto( + userId = userId, request = InitUserCryptoRequest( kdfParams = Kdf.Pbkdf2(iterations = DEFAULT_PBKDF2_ITERATIONS.toUInt()), email = "email", @@ -601,10 +655,12 @@ class VaultRepositoryTest { @Test fun `sync should be able to be called after unlockVaultAndSyncForCurrentUser is canceled`() = runTest { + val userId = "mockId-1" val mockSyncResponse = createMockSyncResponse(number = 1) coEvery { syncService.sync() } returns mockSyncResponse.asSuccess() coEvery { vaultSdkSource.initializeOrganizationCrypto( + userId = userId, request = InitOrgCryptoRequest( organizationKeys = createMockOrganizationKeys(1), ), @@ -617,7 +673,10 @@ class VaultRepositoryTest { ) } just runs coEvery { - vaultSdkSource.decryptSendList(listOf(createMockSdkSend(number = 1))) + vaultSdkSource.decryptSendList( + userId = userId, + sendList = listOf(createMockSdkSend(number = 1)), + ) } returns listOf(createMockSendView(number = 1)).asSuccess() fakeAuthDiskSource.storePrivateKey( userId = "mockId-1", @@ -630,6 +689,7 @@ class VaultRepositoryTest { fakeAuthDiskSource.userState = MOCK_USER_STATE coEvery { vaultSdkSource.initializeCrypto( + userId = userId, request = InitUserCryptoRequest( kdfParams = Kdf.Pbkdf2(iterations = DEFAULT_PBKDF2_ITERATIONS.toUInt()), email = "email", @@ -668,8 +728,10 @@ class VaultRepositoryTest { userKey = "mockKey-1", ) fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" coEvery { vaultSdkSource.initializeCrypto( + userId = userId, request = InitUserCryptoRequest( kdfParams = Kdf.Pbkdf2(iterations = DEFAULT_PBKDF2_ITERATIONS.toUInt()), email = "email", @@ -711,8 +773,10 @@ class VaultRepositoryTest { userKey = "mockKey-1", ) fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" coEvery { vaultSdkSource.initializeCrypto( + userId = userId, request = InitUserCryptoRequest( kdfParams = Kdf.Pbkdf2(iterations = DEFAULT_PBKDF2_ITERATIONS.toUInt()), email = "email", @@ -767,8 +831,10 @@ class VaultRepositoryTest { organizationKeys = createMockOrganizationKeys(number = 1), ) fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" coEvery { vaultSdkSource.initializeCrypto( + userId = userId, request = InitUserCryptoRequest( kdfParams = Kdf.Pbkdf2(iterations = DEFAULT_PBKDF2_ITERATIONS.toUInt()), email = "email", @@ -782,6 +848,7 @@ class VaultRepositoryTest { } returns InitializeCryptoResult.Success.asSuccess() coEvery { vaultSdkSource.initializeOrganizationCrypto( + userId = userId, request = InitOrgCryptoRequest( organizationKeys = createMockOrganizationKeys(1), ), @@ -825,8 +892,10 @@ class VaultRepositoryTest { userKey = "mockKey-1", ) fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" coEvery { vaultSdkSource.initializeCrypto( + userId = userId, request = InitUserCryptoRequest( kdfParams = Kdf.Pbkdf2(iterations = DEFAULT_PBKDF2_ITERATIONS.toUInt()), email = "email", @@ -876,8 +945,10 @@ class VaultRepositoryTest { organizationKeys = createMockOrganizationKeys(number = 1), ) fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" coEvery { vaultSdkSource.initializeCrypto( + userId = userId, request = InitUserCryptoRequest( kdfParams = Kdf.Pbkdf2(iterations = DEFAULT_PBKDF2_ITERATIONS.toUInt()), email = "email", @@ -891,6 +962,7 @@ class VaultRepositoryTest { } returns InitializeCryptoResult.Success.asSuccess() coEvery { vaultSdkSource.initializeOrganizationCrypto( + userId = userId, request = InitOrgCryptoRequest( organizationKeys = createMockOrganizationKeys(1), ), @@ -963,6 +1035,7 @@ class VaultRepositoryTest { privateKey = "mockPrivateKey-1", ) fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" assertEquals( VaultUnlockResult.InvalidStateError, result, @@ -995,6 +1068,7 @@ class VaultRepositoryTest { privateKey = null, ) fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" assertEquals( VaultUnlockResult.InvalidStateError, result, @@ -1018,6 +1092,7 @@ class VaultRepositoryTest { val organizationKeys = mapOf("orgId1" to "orgKey1") coEvery { vaultSdkSource.initializeCrypto( + userId = userId, request = InitUserCryptoRequest( kdfParams = kdf, email = email, @@ -1031,6 +1106,7 @@ class VaultRepositoryTest { } returns InitializeCryptoResult.Success.asSuccess() coEvery { vaultSdkSource.initializeOrganizationCrypto( + userId = userId, request = InitOrgCryptoRequest(organizationKeys = organizationKeys), ) } returns InitializeCryptoResult.Success.asSuccess() @@ -1060,6 +1136,7 @@ class VaultRepositoryTest { ) coVerify(exactly = 1) { vaultSdkSource.initializeCrypto( + userId = userId, request = InitUserCryptoRequest( kdfParams = kdf, email = email, @@ -1073,6 +1150,7 @@ class VaultRepositoryTest { } coVerify(exactly = 1) { vaultSdkSource.initializeOrganizationCrypto( + userId = userId, request = InitOrgCryptoRequest(organizationKeys = organizationKeys), ) } @@ -1091,6 +1169,7 @@ class VaultRepositoryTest { val organizationKeys = mapOf("orgId1" to "orgKey1") coEvery { vaultSdkSource.initializeCrypto( + userId = userId, request = InitUserCryptoRequest( kdfParams = kdf, email = email, @@ -1129,6 +1208,7 @@ class VaultRepositoryTest { ) coVerify(exactly = 1) { vaultSdkSource.initializeCrypto( + userId = userId, request = InitUserCryptoRequest( kdfParams = kdf, email = email, @@ -1155,6 +1235,7 @@ class VaultRepositoryTest { val organizationKeys = mapOf("orgId1" to "orgKey1") coEvery { vaultSdkSource.initializeCrypto( + userId = userId, request = InitUserCryptoRequest( kdfParams = kdf, email = email, @@ -1168,6 +1249,7 @@ class VaultRepositoryTest { } returns InitializeCryptoResult.Success.asSuccess() coEvery { vaultSdkSource.initializeOrganizationCrypto( + userId = userId, request = InitOrgCryptoRequest(organizationKeys = organizationKeys), ) } returns InitializeCryptoResult.AuthenticationError.asSuccess() @@ -1198,6 +1280,7 @@ class VaultRepositoryTest { ) coVerify(exactly = 1) { vaultSdkSource.initializeCrypto( + userId = userId, request = InitUserCryptoRequest( kdfParams = kdf, email = email, @@ -1211,6 +1294,7 @@ class VaultRepositoryTest { } coVerify(exactly = 1) { vaultSdkSource.initializeOrganizationCrypto( + userId = userId, request = InitOrgCryptoRequest(organizationKeys = organizationKeys), ) } @@ -1228,6 +1312,7 @@ class VaultRepositoryTest { val organizationKeys = mapOf("orgId1" to "orgKey1") coEvery { vaultSdkSource.initializeCrypto( + userId = userId, request = InitUserCryptoRequest( kdfParams = kdf, email = email, @@ -1265,6 +1350,7 @@ class VaultRepositoryTest { ) coVerify(exactly = 1) { vaultSdkSource.initializeCrypto( + userId = userId, request = InitUserCryptoRequest( kdfParams = kdf, email = email, @@ -1290,6 +1376,7 @@ class VaultRepositoryTest { val organizationKeys = mapOf("orgId1" to "orgKey1") coEvery { vaultSdkSource.initializeCrypto( + userId = userId, request = InitUserCryptoRequest( kdfParams = kdf, email = email, @@ -1303,6 +1390,7 @@ class VaultRepositoryTest { } returns InitializeCryptoResult.Success.asSuccess() coEvery { vaultSdkSource.initializeOrganizationCrypto( + userId = userId, request = InitOrgCryptoRequest(organizationKeys = organizationKeys), ) } returns Throwable("Fail").asFailure() @@ -1332,6 +1420,7 @@ class VaultRepositoryTest { ) coVerify(exactly = 1) { vaultSdkSource.initializeCrypto( + userId = userId, request = InitUserCryptoRequest( kdfParams = kdf, email = email, @@ -1345,6 +1434,7 @@ class VaultRepositoryTest { } coVerify(exactly = 1) { vaultSdkSource.initializeOrganizationCrypto( + userId = userId, request = InitOrgCryptoRequest(organizationKeys = organizationKeys), ) } @@ -1361,6 +1451,7 @@ class VaultRepositoryTest { val organizationKeys = null coEvery { vaultSdkSource.initializeCrypto( + userId = userId, request = InitUserCryptoRequest( kdfParams = kdf, email = email, @@ -1392,6 +1483,7 @@ class VaultRepositoryTest { coVerify(exactly = 0) { syncService.sync() } coVerify(exactly = 1) { vaultSdkSource.initializeCrypto( + userId = userId, request = InitUserCryptoRequest( kdfParams = kdf, email = email, @@ -1408,17 +1500,30 @@ class VaultRepositoryTest { @Test fun `clearUnlockedData should update the vaultDataStateFlow to Loading`() = runTest { fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" coEvery { - vaultSdkSource.decryptCipherList(listOf(createMockSdkCipher(1))) + vaultSdkSource.decryptCipherList( + userId = userId, + cipherList = listOf(createMockSdkCipher(1)), + ) } returns listOf(createMockCipherView(number = 1)).asSuccess() coEvery { - vaultSdkSource.decryptFolderList(listOf(createMockSdkFolder(1))) + vaultSdkSource.decryptFolderList( + userId = userId, + folderList = listOf(createMockSdkFolder(1)), + ) } returns listOf(createMockFolderView(number = 1)).asSuccess() coEvery { - vaultSdkSource.decryptCollectionList(listOf(createMockSdkCollection(1))) + vaultSdkSource.decryptCollectionList( + userId = userId, + collectionList = listOf(createMockSdkCollection(1)), + ) } returns listOf(createMockCollectionView(number = 1)).asSuccess() coEvery { - vaultSdkSource.decryptSendList(listOf(createMockSdkSend(number = 1))) + vaultSdkSource.decryptSendList( + userId = userId, + sendList = listOf(createMockSdkSend(number = 1)), + ) } returns listOf(createMockSendView(number = 1)).asSuccess() val ciphersFlow = bufferedMutableSharedFlow>() val collectionsFlow = bufferedMutableSharedFlow>() @@ -1456,8 +1561,12 @@ class VaultRepositoryTest { @Test fun `clearUnlockedData should update the sendDataStateFlow to Loading`() = runTest { fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" coEvery { - vaultSdkSource.decryptSendList(listOf(createMockSdkSend(number = 1))) + vaultSdkSource.decryptSendList( + userId = userId, + sendList = listOf(createMockSdkSend(number = 1)), + ) } returns listOf(createMockSendView(number = 1)).asSuccess() val sendsFlow = bufferedMutableSharedFlow>() setupVaultDiskSourceFlows(sendsFlow = sendsFlow) @@ -1486,6 +1595,7 @@ class VaultRepositoryTest { val folderIdString = "mockId-$folderId" val throwable = Throwable("Fail") fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" coEvery { syncService.sync() } returns throwable.asFailure() setupVaultDiskSourceFlows() @@ -1506,6 +1616,7 @@ class VaultRepositoryTest { val itemId = 1234 val itemIdString = "mockId-$itemId" fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" coEvery { syncService.sync() } returns UnknownHostException().asFailure() setupVaultDiskSourceFlows() @@ -1526,6 +1637,7 @@ class VaultRepositoryTest { val folderId = 1234 val folderIdString = "mockId-$folderId" fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" coEvery { syncService.sync() } returns UnknownHostException().asFailure() setupVaultDiskSourceFlows() @@ -1546,6 +1658,7 @@ class VaultRepositoryTest { val folderIdString = "mockId-$folderId" val throwable = Throwable("Fail") fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" coEvery { syncService.sync() } returns throwable.asFailure() setupVaultDiskSourceFlows() @@ -1563,9 +1676,14 @@ class VaultRepositoryTest { @Test fun `createCipher with encryptCipher failure should return CreateCipherResult failure`() = runTest { + fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" val mockCipherView = createMockCipherView(number = 1) coEvery { - vaultSdkSource.encryptCipher(cipherView = mockCipherView) + vaultSdkSource.encryptCipher( + userId = userId, + cipherView = mockCipherView, + ) } returns IllegalStateException().asFailure() val result = vaultRepository.createCipher(cipherView = mockCipherView) @@ -1580,9 +1698,14 @@ class VaultRepositoryTest { @Suppress("MaxLineLength") fun `createCipher with ciphersService createCipher failure should return CreateCipherResult failure`() = runTest { + fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" val mockCipherView = createMockCipherView(number = 1) coEvery { - vaultSdkSource.encryptCipher(cipherView = mockCipherView) + vaultSdkSource.encryptCipher( + userId = userId, + cipherView = mockCipherView, + ) } returns createMockSdkCipher(number = 1).asSuccess() coEvery { ciphersService.createCipher( @@ -1602,9 +1725,14 @@ class VaultRepositoryTest { @Suppress("MaxLineLength") fun `createCipher with ciphersService createCipher success should return CreateCipherResult success`() = runTest { + fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" val mockCipherView = createMockCipherView(number = 1) coEvery { - vaultSdkSource.encryptCipher(cipherView = mockCipherView) + vaultSdkSource.encryptCipher( + userId = userId, + cipherView = mockCipherView, + ) } returns createMockSdkCipher(number = 1).asSuccess() coEvery { ciphersService.createCipher( @@ -1614,15 +1742,25 @@ class VaultRepositoryTest { coEvery { syncService.sync() } returns Result.success(createMockSyncResponse(1)) + coEvery { + vaultDiskSource.replaceVaultData( + userId = userId, + vault = createMockSyncResponse(1), + ) + } just runs coEvery { vaultSdkSource.initializeOrganizationCrypto( + userId = userId, request = InitOrgCryptoRequest( organizationKeys = createMockOrganizationKeys(1), ), ) } returns InitializeCryptoResult.Success.asSuccess() coEvery { - vaultSdkSource.decryptSendList(listOf(createMockSdkSend(1))) + vaultSdkSource.decryptSendList( + userId = userId, + sendList = listOf(createMockSdkSend(1)), + ) } returns listOf(createMockSendView(1)).asSuccess() val result = vaultRepository.createCipher(cipherView = mockCipherView) @@ -1636,10 +1774,15 @@ class VaultRepositoryTest { @Test fun `updateCipher with encryptCipher failure should return UpdateCipherResult failure`() = runTest { + fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" val cipherId = "cipherId1234" val mockCipherView = createMockCipherView(number = 1) coEvery { - vaultSdkSource.encryptCipher(cipherView = mockCipherView) + vaultSdkSource.encryptCipher( + userId = userId, + cipherView = mockCipherView, + ) } returns IllegalStateException().asFailure() val result = vaultRepository.updateCipher( @@ -1654,10 +1797,15 @@ class VaultRepositoryTest { @Suppress("MaxLineLength") fun `updateCipher with ciphersService updateCipher failure should return UpdateCipherResult Error with a null message`() = runTest { + fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" val cipherId = "cipherId1234" val mockCipherView = createMockCipherView(number = 1) coEvery { - vaultSdkSource.encryptCipher(cipherView = mockCipherView) + vaultSdkSource.encryptCipher( + userId = userId, + cipherView = mockCipherView, + ) } returns createMockSdkCipher(number = 1).asSuccess() coEvery { ciphersService.updateCipher( @@ -1678,10 +1826,15 @@ class VaultRepositoryTest { @Suppress("MaxLineLength") fun `updateCipher with ciphersService updateCipher Invalid response should return UpdateCipherResult Error with a non-null message`() = runTest { + fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" val cipherId = "cipherId1234" val mockCipherView = createMockCipherView(number = 1) coEvery { - vaultSdkSource.encryptCipher(cipherView = mockCipherView) + vaultSdkSource.encryptCipher( + userId = userId, + cipherView = mockCipherView, + ) } returns createMockSdkCipher(number = 1).asSuccess() coEvery { ciphersService.updateCipher( @@ -1712,10 +1865,15 @@ class VaultRepositoryTest { @Suppress("MaxLineLength") fun `updateCipher with ciphersService updateCipher Success response should return UpdateCipherResult success`() = runTest { + fakeAuthDiskSource.userState = MOCK_USER_STATE + val userId = "mockId-1" val cipherId = "cipherId1234" val mockCipherView = createMockCipherView(number = 1) coEvery { - vaultSdkSource.encryptCipher(cipherView = mockCipherView) + vaultSdkSource.encryptCipher( + userId = userId, + cipherView = mockCipherView, + ) } returns createMockSdkCipher(number = 1).asSuccess() coEvery { ciphersService.updateCipher( @@ -1728,15 +1886,25 @@ class VaultRepositoryTest { coEvery { syncService.sync() } returns Result.success(createMockSyncResponse(1)) + coEvery { + vaultDiskSource.replaceVaultData( + userId = userId, + vault = createMockSyncResponse(1), + ) + } just runs coEvery { vaultSdkSource.initializeOrganizationCrypto( + userId = userId, request = InitOrgCryptoRequest( organizationKeys = createMockOrganizationKeys(1), ), ) } returns InitializeCryptoResult.Success.asSuccess() coEvery { - vaultSdkSource.decryptSendList(listOf(createMockSdkSend(1))) + vaultSdkSource.decryptSendList( + userId = userId, + sendList = listOf(createMockSdkSend(1)), + ) } returns listOf(createMockSendView(1)).asSuccess() val result = vaultRepository.updateCipher( @@ -1779,6 +1947,7 @@ class VaultRepositoryTest { val organizationKeys = null coEvery { vaultSdkSource.initializeCrypto( + userId = userId, request = InitUserCryptoRequest( kdfParams = kdf, email = email, @@ -1804,6 +1973,7 @@ class VaultRepositoryTest { assertEquals(VaultUnlockResult.Success, result) coVerify(exactly = 1) { vaultSdkSource.initializeCrypto( + userId = userId, request = InitUserCryptoRequest( kdfParams = kdf, email = email,