Add VaultState.unlockingVaultUserIds and clean up the vault unlock logic (#546)

This commit is contained in:
Brian Yencho
2024-01-08 22:56:35 -06:00
committed by Álison Fernandes
parent 940979599e
commit d95e5df2a7
5 changed files with 169 additions and 67 deletions

View File

@@ -91,14 +91,17 @@ class VaultRepositoryImpl(
private var syncJob: Job = Job().apply { complete() }
private var willSyncAfterUnlock = false
private val activeUserId: String? get() = authDiskSource.userState?.activeUserId
private val mutableTotpCodeFlow = bufferedMutableSharedFlow<String>()
private val mutableVaultStateStateFlow =
MutableStateFlow(VaultState(unlockedVaultUserIds = emptySet()))
MutableStateFlow(
VaultState(
unlockedVaultUserIds = emptySet(),
unlockingVaultUserIds = emptySet(),
),
)
private val mutableSendDataStateFlow = MutableStateFlow<DataState<SendData>>(DataState.Loading)
@@ -194,8 +197,8 @@ class VaultRepositoryImpl(
}
override fun sync() {
if (!syncJob.isCompleted || willSyncAfterUnlock) return
val userId = activeUserId ?: return
if (!syncJob.isCompleted || isVaultUnlocking(userId)) return
mutableCiphersStateFlow.updateToPendingOrLoading()
mutableFoldersStateFlow.updateToPendingOrLoading()
mutableCollectionsStateFlow.updateToPendingOrLoading()
@@ -290,22 +293,15 @@ class VaultRepositoryImpl(
override suspend fun unlockVaultAndSyncForCurrentUser(
masterPassword: String,
): VaultUnlockResult {
val userState = authDiskSource.userState
val userId = activeUserId ?: return VaultUnlockResult.InvalidStateError
val userKey = authDiskSource.getUserKey(userId = userId)
?: return VaultUnlockResult.InvalidStateError
val userKey = authDiskSource.getUserKey(userId = userState.activeUserId)
?: return VaultUnlockResult.InvalidStateError
val privateKey = authDiskSource.getPrivateKey(userId = userState.activeUserId)
?: return VaultUnlockResult.InvalidStateError
val organizationKeys = authDiskSource
.getOrganizationKeys(userId = userState.activeUserId)
return unlockVault(
userId = userState.activeUserId,
masterPassword = masterPassword,
email = userState.activeAccount.profile.email,
kdf = userState.activeAccount.profile.toSdkParams(),
userKey = userKey,
privateKey = privateKey,
organizationKeys = organizationKeys,
return unlockVaultForUser(
userId = userId,
initUserCryptoMethod = InitUserCryptoMethod.Password(
password = masterPassword,
userKey = userKey,
),
)
.also {
if (it is VaultUnlockResult.Success) {
@@ -323,53 +319,17 @@ class VaultRepositoryImpl(
privateKey: String,
organizationKeys: Map<String, String>?,
): VaultUnlockResult =
flow {
willSyncAfterUnlock = true
emit(
vaultSdkSource
.initializeCrypto(
userId = userId,
request = InitUserCryptoRequest(
kdfParams = kdf,
email = email,
privateKey = privateKey,
method = InitUserCryptoMethod.Password(
password = masterPassword,
userKey = userKey,
),
),
)
.flatMap { result ->
// Initialize the SDK for organizations if necessary
if (organizationKeys != null &&
result is InitializeCryptoResult.Success
) {
vaultSdkSource.initializeOrganizationCrypto(
userId = userId,
request = InitOrgCryptoRequest(
organizationKeys = organizationKeys,
),
)
} else {
result.asSuccess()
}
}
.fold(
onFailure = { VaultUnlockResult.GenericError },
onSuccess = { initializeCryptoResult ->
initializeCryptoResult
.toVaultUnlockResult()
.also {
if (it is VaultUnlockResult.Success) {
setVaultToUnlocked(userId = userId)
}
}
},
),
)
}
.onCompletion { willSyncAfterUnlock = false }
.first()
unlockVaultInternal(
userId = userId,
email = email,
kdf = kdf,
privateKey = privateKey,
initUserCryptoMethod = InitUserCryptoMethod.Password(
password = masterPassword,
userKey = userKey,
),
organizationKeys = organizationKeys,
)
override suspend fun createCipher(cipherView: CipherView): CreateCipherResult {
val userId = requireNotNull(activeUserId)
@@ -506,6 +466,25 @@ class VaultRepositoryImpl(
}
}
private fun setVaultToUnlocking(userId: String) {
mutableVaultStateStateFlow.update {
it.copy(
unlockingVaultUserIds = it.unlockingVaultUserIds + userId,
)
}
}
private fun setVaultToNotUnlocking(userId: String) {
mutableVaultStateStateFlow.update {
it.copy(
unlockingVaultUserIds = it.unlockingVaultUserIds - userId,
)
}
}
private fun isVaultUnlocking(userId: String) =
userId in mutableVaultStateStateFlow.value.unlockingVaultUserIds
private fun storeProfileData(
syncResponse: SyncResponseJson,
) {
@@ -536,6 +515,80 @@ class VaultRepositoryImpl(
}
}
@Suppress("ReturnCount")
private suspend fun unlockVaultForUser(
userId: String,
initUserCryptoMethod: InitUserCryptoMethod,
): VaultUnlockResult {
val account = authDiskSource.userState?.accounts?.get(userId)
?: return VaultUnlockResult.InvalidStateError
val privateKey = authDiskSource.getPrivateKey(userId = userId)
?: return VaultUnlockResult.InvalidStateError
val organizationKeys = authDiskSource
.getOrganizationKeys(userId = userId)
return unlockVaultInternal(
userId = userId,
email = account.profile.email,
kdf = account.profile.toSdkParams(),
privateKey = privateKey,
initUserCryptoMethod = initUserCryptoMethod,
organizationKeys = organizationKeys,
)
}
private suspend fun unlockVaultInternal(
userId: String,
email: String,
kdf: Kdf,
privateKey: String,
initUserCryptoMethod: InitUserCryptoMethod,
organizationKeys: Map<String, String>?,
): VaultUnlockResult =
flow {
setVaultToUnlocking(userId = userId)
emit(
vaultSdkSource
.initializeCrypto(
userId = userId,
request = InitUserCryptoRequest(
kdfParams = kdf,
email = email,
privateKey = privateKey,
method = initUserCryptoMethod,
),
)
.flatMap { result ->
// Initialize the SDK for organizations if necessary
if (organizationKeys != null &&
result is InitializeCryptoResult.Success
) {
vaultSdkSource.initializeOrganizationCrypto(
userId = userId,
request = InitOrgCryptoRequest(
organizationKeys = organizationKeys,
),
)
} else {
result.asSuccess()
}
}
.fold(
onFailure = { VaultUnlockResult.GenericError },
onSuccess = { initializeCryptoResult ->
initializeCryptoResult
.toVaultUnlockResult()
.also {
if (it is VaultUnlockResult.Success) {
setVaultToUnlocked(userId = userId)
}
}
},
),
)
}
.onCompletion { setVaultToNotUnlocking(userId = userId) }
.first()
private suspend fun unlockVaultForOrganizationsIfNecessary(
syncResponse: SyncResponseJson,
) {

View File

@@ -4,7 +4,10 @@ package com.x8bit.bitwarden.data.vault.repository.model
* General description of the vault across multiple users.
*
* @property unlockedVaultUserIds The user IDs for all users that currently have unlocked vaults.
* @property unlockedVaultUserIds The user IDs for all users that are actively unlocking their
* vaults.
*/
data class VaultState(
val unlockedVaultUserIds: Set<String>,
val unlockingVaultUserIds: Set<String>,
)