diff --git a/app/src/main/java/com/x8bit/bitwarden/data/auth/datasource/disk/model/EnvironmentUrlDataJson.kt b/app/src/main/java/com/x8bit/bitwarden/data/auth/datasource/disk/model/EnvironmentUrlDataJson.kt index 738091f49b..b8f5426d5d 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/auth/datasource/disk/model/EnvironmentUrlDataJson.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/auth/datasource/disk/model/EnvironmentUrlDataJson.kt @@ -7,6 +7,7 @@ import kotlinx.serialization.Serializable * Represents URLs for various Bitwarden domains. * * @property base The overall base URL. + * @property keyUri A Uri containing the alias and host of the key used for mutual TLS. * @property api Separate base URL for the "/api" domain (if applicable). * @property identity Separate base URL for the "/identity" domain (if applicable). * @property icon Separate base URL for the icon domain (if applicable). @@ -19,6 +20,9 @@ data class EnvironmentUrlDataJson( @SerialName("base") val base: String, + @SerialName("keyUri") + val keyUri: String? = null, + @SerialName("api") val api: String? = null, @@ -51,6 +55,7 @@ data class EnvironmentUrlDataJson( */ val DEFAULT_LEGACY_US: EnvironmentUrlDataJson = EnvironmentUrlDataJson( base = "https://vault.bitwarden.com", + keyUri = null, api = "https://api.bitwarden.com", identity = "https://identity.bitwarden.com", icon = "https://icons.bitwarden.net", @@ -71,6 +76,7 @@ data class EnvironmentUrlDataJson( */ val DEFAULT_LEGACY_EU: EnvironmentUrlDataJson = EnvironmentUrlDataJson( base = "https://vault.bitwarden.eu", + keyUri = null, api = "https://api.bitwarden.eu", identity = "https://identity.bitwarden.eu", icon = "https://icons.bitwarden.eu", diff --git a/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/di/PlatformNetworkModule.kt b/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/di/PlatformNetworkModule.kt index ff47617734..8ad8fa3d56 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/di/PlatformNetworkModule.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/di/PlatformNetworkModule.kt @@ -14,6 +14,10 @@ import com.x8bit.bitwarden.data.platform.datasource.network.service.EventService import com.x8bit.bitwarden.data.platform.datasource.network.service.EventServiceImpl import com.x8bit.bitwarden.data.platform.datasource.network.service.PushService import com.x8bit.bitwarden.data.platform.datasource.network.service.PushServiceImpl +import com.x8bit.bitwarden.data.platform.datasource.network.ssl.SslManager +import com.x8bit.bitwarden.data.platform.datasource.network.ssl.SslManagerImpl +import com.x8bit.bitwarden.data.platform.manager.KeyManager +import com.x8bit.bitwarden.data.platform.repository.EnvironmentRepository import dagger.Module import dagger.Provides import dagger.hilt.InstallIn @@ -70,6 +74,17 @@ object PlatformNetworkModule { @Singleton fun providesRefreshAuthenticator(): RefreshAuthenticator = RefreshAuthenticator() + @Provides + @Singleton + fun provideSslManager( + keyManager: KeyManager, + environmentRepository: EnvironmentRepository, + ): SslManager = + SslManagerImpl( + keyManager = keyManager, + environmentRepository = environmentRepository, + ) + @Provides @Singleton fun provideRetrofits( @@ -77,6 +92,7 @@ object PlatformNetworkModule { baseUrlInterceptors: BaseUrlInterceptors, headersInterceptor: HeadersInterceptor, refreshAuthenticator: RefreshAuthenticator, + sslManager: SslManager, json: Json, ): Retrofits = RetrofitsImpl( @@ -84,6 +100,7 @@ object PlatformNetworkModule { baseUrlInterceptors = baseUrlInterceptors, headersInterceptor = headersInterceptor, refreshAuthenticator = refreshAuthenticator, + sslManager = sslManager, json = json, ) diff --git a/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/retrofit/RetrofitsImpl.kt b/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/retrofit/RetrofitsImpl.kt index 7ee544457c..a473dda8dd 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/retrofit/RetrofitsImpl.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/retrofit/RetrofitsImpl.kt @@ -6,6 +6,7 @@ import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.AuthToke import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.BaseUrlInterceptor import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.BaseUrlInterceptors import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.HeadersInterceptor +import com.x8bit.bitwarden.data.platform.datasource.network.ssl.SslManager import com.x8bit.bitwarden.data.platform.datasource.network.util.HEADER_KEY_AUTHORIZATION import kotlinx.serialization.json.Json import okhttp3.MediaType.Companion.toMediaType @@ -14,6 +15,9 @@ import okhttp3.logging.HttpLoggingInterceptor import retrofit2.Retrofit import retrofit2.converter.kotlinx.serialization.asConverterFactory import timber.log.Timber +import javax.net.ssl.SSLContext +import javax.net.ssl.TrustManager +import javax.net.ssl.X509TrustManager /** * Primary implementation of [Retrofits]. @@ -24,6 +28,7 @@ class RetrofitsImpl( headersInterceptor: HeadersInterceptor, refreshAuthenticator: RefreshAuthenticator, json: Json, + private val sslManager: SslManager, ) : Retrofits { //region Authenticated Retrofits @@ -67,6 +72,10 @@ class RetrofitsImpl( baseClient .newBuilder() .addInterceptor(loggingInterceptor) + .setSslSocketFactory( + sslContext = sslManager.sslContext, + trustManagers = sslManager.trustManagers, + ) .build(), ) .build() @@ -93,6 +102,10 @@ class RetrofitsImpl( .newBuilder() .authenticator(refreshAuthenticator) .addInterceptor(authTokenInterceptor) + .setSslSocketFactory( + sslContext = sslManager.sslContext, + trustManagers = sslManager.trustManagers, + ) .build() } @@ -133,9 +146,22 @@ class RetrofitsImpl( .newBuilder() .addInterceptor(baseUrlInterceptor) .addInterceptor(loggingInterceptor) + .setSslSocketFactory( + sslContext = sslManager.sslContext, + trustManagers = sslManager.trustManagers, + ) .build(), ) .build() + private fun OkHttpClient.Builder.setSslSocketFactory( + sslContext: SSLContext, + trustManagers: Array, + ): OkHttpClient.Builder = + sslSocketFactory( + sslContext.socketFactory, + trustManagers.first() as X509TrustManager, + ) + //endregion Helper properties and functions } diff --git a/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/ssl/SslManager.kt b/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/ssl/SslManager.kt new file mode 100644 index 0000000000..bee538c080 --- /dev/null +++ b/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/ssl/SslManager.kt @@ -0,0 +1,20 @@ +package com.x8bit.bitwarden.data.platform.datasource.network.ssl + +import javax.net.ssl.SSLContext +import javax.net.ssl.TrustManager + +/** + * Interface for managing SSL connections. + */ +interface SslManager { + + /** + * The SSL context to use for SSL connections. + */ + val sslContext: SSLContext + + /** + * The trust managers to use for SSL connections. + */ + val trustManagers: Array +} diff --git a/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/ssl/SslManagerImpl.kt b/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/ssl/SslManagerImpl.kt new file mode 100644 index 0000000000..2785be891e --- /dev/null +++ b/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/network/ssl/SslManagerImpl.kt @@ -0,0 +1,116 @@ +package com.x8bit.bitwarden.data.platform.datasource.network.ssl + +import android.net.Uri +import androidx.annotation.VisibleForTesting +import androidx.annotation.WorkerThread +import androidx.core.net.toUri +import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsCertificate +import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsKeyHost +import com.x8bit.bitwarden.data.platform.manager.KeyManager +import com.x8bit.bitwarden.data.platform.repository.EnvironmentRepository +import java.net.Socket +import java.security.KeyStore +import java.security.Principal +import java.security.PrivateKey +import java.security.cert.X509Certificate +import javax.net.ssl.SSLContext +import javax.net.ssl.TrustManager +import javax.net.ssl.TrustManagerFactory +import javax.net.ssl.X509ExtendedKeyManager + +/** + * Primary implementation of [SslManager]. + */ +class SslManagerImpl( + private val keyManager: KeyManager, + private val environmentRepository: EnvironmentRepository, +) : SslManager { + + /* + This property must only be accessed from a background thread. Accessing this property from + the main thread will result in an exception being thrown when retrieving the mutual TLS + certificate from [KeyManager]. + */ + @VisibleForTesting(otherwise = VisibleForTesting.PRIVATE) + @get:WorkerThread + internal val mutualTlsCertificate: MutualTlsCertificate? + get() { + val keyUri = getKeyUri() + ?: return null + + val host = MutualTlsKeyHost + .entries + .find { it.name == keyUri.authority } + ?: return null + + val alias = keyUri.path + ?.trim('/') + ?.takeUnless { it.isEmpty() } + ?: return null + + return keyManager.getMutualTlsCertificateChain( + alias = alias, + host = host, + ) + } + + override val trustManagers: Array + get() = TrustManagerFactory + .getInstance(TrustManagerFactory.getDefaultAlgorithm()) + .apply { init(null as KeyStore?) } + .trustManagers + + override val sslContext: SSLContext + get() = SSLContext + .getInstance("TLS") + .apply { + init( + arrayOf(X509ExtendedKeyManagerImpl()), + trustManagers, + null, + ) + } + + private fun getKeyUri(): Uri? = environmentRepository + .environment + .environmentUrlData + .keyUri + ?.toUri() + + private inner class X509ExtendedKeyManagerImpl : X509ExtendedKeyManager() { + override fun chooseClientAlias( + keyType: Array?, + issuers: Array?, + socket: Socket?, + ): String = mutualTlsCertificate?.alias ?: "" + + override fun getCertificateChain( + alias: String?, + ): Array? = + mutualTlsCertificate + ?.certificateChain + ?.toTypedArray() + + override fun getPrivateKey(alias: String?): PrivateKey? = + mutualTlsCertificate + ?.privateKey + + //region Unused server side methods + override fun getServerAliases( + alias: String?, + issuers: Array?, + ): Array = arrayOf() + + override fun getClientAliases( + keyType: String?, + issuers: Array?, + ): Array = emptyArray() + + override fun chooseServerAlias( + alias: String?, + issuers: Array?, + socket: Socket?, + ): String = "" + //endregion Unused server side methods + } +} diff --git a/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/KeyManager.kt b/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/KeyManager.kt index 4ce0a9d42a..5a24ec75a9 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/KeyManager.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/KeyManager.kt @@ -1,6 +1,7 @@ package com.x8bit.bitwarden.data.platform.manager -import com.x8bit.bitwarden.data.platform.datasource.disk.model.ImportPrivateKeyResult +import androidx.annotation.WorkerThread +import com.x8bit.bitwarden.data.platform.manager.model.ImportPrivateKeyResult import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsCertificate import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsKeyHost @@ -29,7 +30,10 @@ interface KeyManager { /** * Retrieve the certificate chain for the selected mTLS key. + * + * Must be called from a background thread to prevent possible deadlocks on the main thread. */ + @WorkerThread fun getMutualTlsCertificateChain( alias: String, host: MutualTlsKeyHost, diff --git a/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/KeyManagerImpl.kt b/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/KeyManagerImpl.kt index 19e9cb3b39..ef23908bdd 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/KeyManagerImpl.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/KeyManagerImpl.kt @@ -3,9 +3,9 @@ package com.x8bit.bitwarden.data.platform.manager import android.content.Context import android.security.KeyChain import android.security.KeyChainException -import com.x8bit.bitwarden.data.platform.datasource.disk.model.ImportPrivateKeyResult import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsCertificate import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsKeyHost +import com.x8bit.bitwarden.data.platform.manager.model.ImportPrivateKeyResult import timber.log.Timber import java.io.IOException import java.security.KeyStore diff --git a/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/disk/model/ImportPrivateKeyResult.kt b/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/model/ImportPrivateKeyResult.kt similarity index 94% rename from app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/disk/model/ImportPrivateKeyResult.kt rename to app/src/main/java/com/x8bit/bitwarden/data/platform/manager/model/ImportPrivateKeyResult.kt index b0bcf9fffc..2cf360ddb6 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/disk/model/ImportPrivateKeyResult.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/model/ImportPrivateKeyResult.kt @@ -1,4 +1,4 @@ -package com.x8bit.bitwarden.data.platform.datasource.disk.model +package com.x8bit.bitwarden.data.platform.manager.model /** * Models the result of importing a private key. diff --git a/app/src/main/java/com/x8bit/bitwarden/ui/auth/feature/environment/EnvironmentScreen.kt b/app/src/main/java/com/x8bit/bitwarden/ui/auth/feature/environment/EnvironmentScreen.kt index 3508ebd44e..696704008e 100644 --- a/app/src/main/java/com/x8bit/bitwarden/ui/auth/feature/environment/EnvironmentScreen.kt +++ b/app/src/main/java/com/x8bit/bitwarden/ui/auth/feature/environment/EnvironmentScreen.kt @@ -18,6 +18,7 @@ import androidx.compose.runtime.Composable import androidx.compose.runtime.getValue import androidx.compose.runtime.remember import androidx.compose.ui.Modifier +import androidx.compose.ui.focus.focusProperties import androidx.compose.ui.input.nestedscroll.nestedScroll import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.testTag @@ -31,13 +32,21 @@ import com.x8bit.bitwarden.R import com.x8bit.bitwarden.ui.platform.base.util.EventsEffect import com.x8bit.bitwarden.ui.platform.base.util.standardHorizontalMargin import com.x8bit.bitwarden.ui.platform.components.appbar.BitwardenTopAppBar +import com.x8bit.bitwarden.ui.platform.components.button.BitwardenFilledButton +import com.x8bit.bitwarden.ui.platform.components.button.BitwardenOutlinedButton import com.x8bit.bitwarden.ui.platform.components.button.BitwardenTextButton import com.x8bit.bitwarden.ui.platform.components.dialog.BitwardenBasicDialog +import com.x8bit.bitwarden.ui.platform.components.dialog.BitwardenClientCertificateDialog +import com.x8bit.bitwarden.ui.platform.components.dialog.BitwardenTwoButtonDialog import com.x8bit.bitwarden.ui.platform.components.field.BitwardenTextField import com.x8bit.bitwarden.ui.platform.components.header.BitwardenListHeaderText import com.x8bit.bitwarden.ui.platform.components.model.CardStyle import com.x8bit.bitwarden.ui.platform.components.scaffold.BitwardenScaffold import com.x8bit.bitwarden.ui.platform.components.util.rememberVectorPainter +import com.x8bit.bitwarden.ui.platform.composition.LocalIntentManager +import com.x8bit.bitwarden.ui.platform.composition.LocalKeyChainManager +import com.x8bit.bitwarden.ui.platform.manager.intent.IntentManager +import com.x8bit.bitwarden.ui.platform.manager.keychain.KeyChainManager import kotlinx.collections.immutable.persistentListOf /** @@ -48,27 +57,100 @@ import kotlinx.collections.immutable.persistentListOf @Composable fun EnvironmentScreen( onNavigateBack: () -> Unit, + intentManager: IntentManager = LocalIntentManager.current, + keyChainManager: KeyChainManager = LocalKeyChainManager.current, viewModel: EnvironmentViewModel = hiltViewModel(), ) { val state by viewModel.stateFlow.collectAsStateWithLifecycle() val context = LocalContext.current + val certificateImportFilePickerLauncher = intentManager.getActivityResultLauncher { result -> + intentManager.getFileDataFromActivityResult(result)?.let { + viewModel.trySendAction( + EnvironmentAction.ImportCertificateFilePickerResultReceive(it), + ) + } + } EventsEffect(viewModel = viewModel) { event -> when (event) { is EnvironmentEvent.NavigateBack -> onNavigateBack.invoke() is EnvironmentEvent.ShowToast -> { Toast.makeText(context, event.message(context.resources), Toast.LENGTH_SHORT).show() } + + is EnvironmentEvent.ShowCertificateImportFileChooser -> { + certificateImportFilePickerLauncher.launch( + intentManager.createFileChooserIntent(withCameraIntents = false), + ) + } + + is EnvironmentEvent.ShowSystemCertificateSelectionDialog -> { + viewModel.trySendAction( + EnvironmentAction.SystemCertificateSelectionResultReceive( + keyChainManager.choosePrivateKeyAlias( + currentServerUrl = event.serverUrl?.takeUnless { it.isEmpty() }, + ), + ), + ) + } } } - if (state.shouldShowErrorDialog) { - BitwardenBasicDialog( - title = stringResource(id = R.string.an_error_has_occurred), - message = stringResource(id = R.string.environment_page_urls_error), - onDismissRequest = remember(viewModel) { - { viewModel.trySendAction(EnvironmentAction.ErrorDialogDismiss) } - }, - ) + when (val dialog = state.dialog) { + is EnvironmentState.DialogState.Error -> { + BitwardenBasicDialog( + title = stringResource(id = R.string.an_error_has_occurred), + message = dialog.message(), + onDismissRequest = remember(viewModel) { + { viewModel.trySendAction(EnvironmentAction.ErrorDialogDismiss) } + }, + ) + } + + is EnvironmentState.DialogState.SetCertificateData -> { + BitwardenClientCertificateDialog( + onConfirmClick = remember(viewModel) { + { alias, password -> + viewModel.trySendAction( + EnvironmentAction.SetCertificateInfoResultReceive( + certificateFileData = dialog.certificateBytes, + password = password, + alias = alias, + ), + ) + } + }, + onDismissRequest = remember(viewModel) { + { + viewModel.trySendAction( + action = EnvironmentAction.SetCertificatePasswordDialogDismiss, + ) + } + }, + ) + } + + is EnvironmentState.DialogState.SystemCertificateWarningDialog -> { + @Suppress("MaxLineLength") + BitwardenTwoButtonDialog( + title = stringResource(R.string.warning), + message = stringResource( + R.string.system_certificates_are_not_as_secure_as_importing_certificates_to_bitwarden, + ), + confirmButtonText = stringResource(R.string.continue_text), + onConfirmClick = remember(viewModel) { + { viewModel.trySendAction(EnvironmentAction.ConfirmChooseSystemCertificateClick) } + }, + dismissButtonText = stringResource(R.string.cancel), + onDismissClick = remember(viewModel) { + { viewModel.trySendAction(EnvironmentAction.ErrorDialogDismiss) } + }, + onDismissRequest = remember(viewModel) { + { viewModel.trySendAction(EnvironmentAction.ErrorDialogDismiss) } + }, + ) + } + + null -> Unit } val scrollBehavior = TopAppBarDefaults.pinnedScrollBehavior(rememberTopAppBarState()) BitwardenScaffold( @@ -138,7 +220,7 @@ fun EnvironmentScreen( .standardHorizontalMargin(), ) - Spacer(modifier = Modifier.height(height = 16.dp)) + Spacer(modifier = Modifier.height(16.dp)) BitwardenListHeaderText( label = stringResource(id = R.string.custom_environment), @@ -213,6 +295,59 @@ fun EnvironmentScreen( .standardHorizontalMargin(), ) + if (state.showMutualTlsOptions) { + Spacer(modifier = Modifier.height(height = 16.dp)) + + BitwardenListHeaderText( + label = stringResource(id = R.string.client_certificate_mtls), + modifier = Modifier + .fillMaxWidth() + .standardHorizontalMargin() + .padding(horizontal = 16.dp), + ) + Spacer(modifier = Modifier.height(height = 8.dp)) + + BitwardenTextField( + label = stringResource(id = R.string.certificate_alias), + value = state.keyAlias, + supportingText = stringResource( + id = R.string.certificate_used_for_client_authentication, + ), + onValueChange = {}, + readOnly = true, + cardStyle = CardStyle.Full, + textFieldTestTag = "KeyAliasEntry", + modifier = Modifier + .fillMaxWidth() + .focusProperties { canFocus = false } + .standardHorizontalMargin(), + ) + Spacer(modifier = Modifier.height(height = 16.dp)) + + BitwardenFilledButton( + label = stringResource(id = R.string.import_certificate), + onClick = remember(viewModel) { + { viewModel.trySendAction(EnvironmentAction.ImportCertificateClick) } + }, + modifier = Modifier + .fillMaxWidth() + .standardHorizontalMargin() + .testTag("ImportCertificateButton"), + ) + + Spacer(modifier = Modifier.height(height = 12.dp)) + + BitwardenOutlinedButton( + label = stringResource(id = R.string.choose_system_certificate), + onClick = remember(viewModel) { + { viewModel.trySendAction(EnvironmentAction.ChooseSystemCertificateClick) } + }, + modifier = Modifier + .fillMaxWidth() + .standardHorizontalMargin() + .testTag("ChooseSystemCertificateButton"), + ) + } Spacer(modifier = Modifier.height(height = 16.dp)) Spacer(modifier = Modifier.navigationBarsPadding()) } diff --git a/app/src/main/java/com/x8bit/bitwarden/ui/auth/feature/environment/EnvironmentViewModel.kt b/app/src/main/java/com/x8bit/bitwarden/ui/auth/feature/environment/EnvironmentViewModel.kt index 4c93113f89..6e7e9593db 100644 --- a/app/src/main/java/com/x8bit/bitwarden/ui/auth/feature/environment/EnvironmentViewModel.kt +++ b/app/src/main/java/com/x8bit/bitwarden/ui/auth/feature/environment/EnvironmentViewModel.kt @@ -1,21 +1,32 @@ package com.x8bit.bitwarden.ui.auth.feature.environment import android.os.Parcelable +import androidx.core.net.toUri import androidx.lifecycle.SavedStateHandle import androidx.lifecycle.viewModelScope import com.x8bit.bitwarden.R import com.x8bit.bitwarden.data.auth.datasource.disk.model.EnvironmentUrlDataJson +import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsKeyHost +import com.x8bit.bitwarden.data.platform.manager.FeatureFlagManager +import com.x8bit.bitwarden.data.platform.manager.KeyManager +import com.x8bit.bitwarden.data.platform.manager.model.FlagKey +import com.x8bit.bitwarden.data.platform.manager.model.ImportPrivateKeyResult import com.x8bit.bitwarden.data.platform.repository.EnvironmentRepository import com.x8bit.bitwarden.data.platform.repository.model.Environment +import com.x8bit.bitwarden.data.vault.manager.FileManager import com.x8bit.bitwarden.ui.platform.base.BaseViewModel import com.x8bit.bitwarden.ui.platform.base.util.Text import com.x8bit.bitwarden.ui.platform.base.util.asText import com.x8bit.bitwarden.ui.platform.base.util.isValidUri import com.x8bit.bitwarden.ui.platform.base.util.orNullIfBlank +import com.x8bit.bitwarden.ui.platform.manager.intent.IntentManager +import com.x8bit.bitwarden.ui.platform.manager.keychain.model.PrivateKeyAliasSelectionResult import dagger.hilt.android.lifecycle.HiltViewModel import kotlinx.coroutines.flow.launchIn +import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.onEach import kotlinx.coroutines.flow.update +import kotlinx.coroutines.launch import kotlinx.parcelize.Parcelize import javax.inject.Inject @@ -24,9 +35,13 @@ private const val KEY_STATE = "state" /** * View model for the self-hosted/custom environment screen. */ +@Suppress("TooManyFunctions") @HiltViewModel class EnvironmentViewModel @Inject constructor( private val environmentRepository: EnvironmentRepository, + private val fileManager: FileManager, + private val keyManager: KeyManager, + private val featureFlagManager: FeatureFlagManager, private val savedStateHandle: SavedStateHandle, ) : BaseViewModel( initialState = savedStateHandle[KEY_STATE] ?: run { @@ -37,13 +52,19 @@ class EnvironmentViewModel @Inject constructor( is Environment.SelfHosted -> environment.environmentUrlData } + val keyUri = environmentUrlData.keyUri?.toUri() + val keyAlias = keyUri?.path?.trim('/').orEmpty() + val keyHost = MutualTlsKeyHost.entries.find { it.name == keyUri?.authority } EnvironmentState( serverUrl = environmentUrlData.base, webVaultServerUrl = environmentUrlData.webVault.orEmpty(), apiServerUrl = environmentUrlData.api.orEmpty(), identityServerUrl = environmentUrlData.identity.orEmpty(), iconsServerUrl = environmentUrlData.icon.orEmpty(), - shouldShowErrorDialog = false, + keyAlias = keyAlias, + keyHost = keyHost, + dialog = null, + showMutualTlsOptions = featureFlagManager.getFeatureFlag(FlagKey.MutualTls), ) }, ) { @@ -54,6 +75,11 @@ class EnvironmentViewModel @Inject constructor( savedStateHandle[KEY_STATE] = it } .launchIn(viewModelScope) + + featureFlagManager.getFeatureFlagFlow(FlagKey.MutualTls) + .map { EnvironmentAction.Internal.MutualTlsFeatureFlagUpdate(it) } + .onEach(::handleAction) + .launchIn(viewModelScope) } override fun handleAction(action: EnvironmentAction): Unit = when (action) { @@ -65,6 +91,36 @@ class EnvironmentViewModel @Inject constructor( is EnvironmentAction.ApiServerUrlChange -> handleApiServerUrlChangeAction(action) is EnvironmentAction.IdentityServerUrlChange -> handleIdentityServerUrlChangeAction(action) is EnvironmentAction.IconsServerUrlChange -> handleIconsServerUrlChangeAction(action) + is EnvironmentAction.ImportCertificateClick -> handleImportCertificateClick() + is EnvironmentAction.ImportCertificateFilePickerResultReceive -> { + handleCertificateFilePickerResultReceive(action) + } + + is EnvironmentAction.SetCertificatePasswordDialogDismiss -> { + handleSetCertificatePasswordDialogDismiss() + } + + is EnvironmentAction.CertificateInstallationResultReceive -> { + handleCertificateInstallationResultReceive(action) + } + + is EnvironmentAction.SetCertificateInfoResultReceive -> { + handleSetCertificateInfoResultReceive(action) + } + + is EnvironmentAction.ChooseSystemCertificateClick -> { + handleChooseSystemCertificateClickAction() + } + + is EnvironmentAction.ConfirmChooseSystemCertificateClick -> { + handleConfirmChooseSystemCertificateClick() + } + + is EnvironmentAction.SystemCertificateSelectionResultReceive -> { + handleSystemCertificateSelectionResultReceive(action) + } + + is EnvironmentAction.Internal -> handleInternalAction(action) } private fun handleCloseClickAction() { @@ -85,7 +141,7 @@ class EnvironmentViewModel @Inject constructor( } if (!urlsAreAllNullOrValid) { - mutableStateFlow.update { it.copy(shouldShowErrorDialog = true) } + showErrorDialog(message = R.string.environment_page_urls_error.asText()) return } @@ -95,7 +151,6 @@ class EnvironmentViewModel @Inject constructor( val updatedApiServerUrl = state.apiServerUrl.prefixHttpsIfNecessaryOrNull() val updatedIdentityServerUrl = state.identityServerUrl.prefixHttpsIfNecessaryOrNull() val updatedIconsServerUrl = state.iconsServerUrl.prefixHttpsIfNecessaryOrNull() - environmentRepository.environment = Environment.SelfHosted( environmentUrlData = EnvironmentUrlDataJson( base = updatedServerUrl, @@ -103,15 +158,16 @@ class EnvironmentViewModel @Inject constructor( identity = updatedIdentityServerUrl, icon = updatedIconsServerUrl, webVault = updatedWebVaultServerUrl, + keyUri = state.keyUri, ), ) - sendEvent(EnvironmentEvent.ShowToast(message = R.string.environment_saved.asText())) + showToast(message = R.string.environment_saved.asText()) sendEvent(EnvironmentEvent.NavigateBack) } private fun handleErrorDialogDismiss() { - mutableStateFlow.update { it.copy(shouldShowErrorDialog = false) } + mutableStateFlow.update { it.copy(dialog = null) } } private fun handleServerUrlChangeAction( @@ -122,6 +178,45 @@ class EnvironmentViewModel @Inject constructor( } } + private fun handleCertificateInstallationResultReceive( + action: EnvironmentAction.CertificateInstallationResultReceive, + ) { + showToast( + message = if (action.success) { + R.string.certificate_installed.asText() + } else { + R.string.certificate_installation_failed.asText() + }, + ) + } + + private fun handleSetCertificatePasswordDialogDismiss() { + mutableStateFlow.update { it.copy(dialog = null) } + } + + private fun handleCertificateFilePickerResultReceive( + action: EnvironmentAction.ImportCertificateFilePickerResultReceive, + ) { + mutableStateFlow.update { + it.copy( + dialog = EnvironmentState.DialogState.SetCertificateData( + certificateBytes = action.certificateFileData, + ), + ) + } + } + + private fun handleConfirmChooseSystemCertificateClick() { + mutableStateFlow.update { + it.copy(dialog = null) + } + sendEvent( + EnvironmentEvent.ShowSystemCertificateSelectionDialog( + serverUrl = state.serverUrl.prefixHttpsIfNecessaryOrNull(), + ), + ) + } + private fun handleWebVaultServerUrlChangeAction( action: EnvironmentAction.WebVaultServerUrlChange, ) { @@ -153,6 +248,150 @@ class EnvironmentViewModel @Inject constructor( it.copy(iconsServerUrl = action.iconsServerUrl) } } + + private fun handleImportCertificateClick() { + sendEvent(EnvironmentEvent.ShowCertificateImportFileChooser) + } + + private fun handleInternalAction(action: EnvironmentAction.Internal) { + when (action) { + is EnvironmentAction.Internal.ImportKeyResultReceive -> { + handleSaveKeyResultReceive(action) + } + + is EnvironmentAction.Internal.MutualTlsFeatureFlagUpdate -> { + handleMutualTlsFeatureFlagUpdate(action) + } + } + } + + private fun handleSetCertificateInfoResultReceive( + action: EnvironmentAction.SetCertificateInfoResultReceive, + ) { + if (action.password.isBlank()) { + showErrorDialog( + message = R.string.validation_field_required.asText( + R.string.password.asText(), + ), + ) + return + } + + if (action.alias.isBlank()) { + showErrorDialog( + message = R.string.validation_field_required.asText( + R.string.alias.asText(), + ), + ) + return + } + + mutableStateFlow.update { it.copy(dialog = null) } + + viewModelScope.launch { + fileManager + .uriToByteArray(action.certificateFileData.uri) + .map { bytes -> + keyManager.importMutualTlsCertificate( + key = bytes, + alias = action.alias, + password = action.password, + ) + } + .map { result -> + sendAction( + EnvironmentAction.Internal.ImportKeyResultReceive(result), + ) + } + } + } + + private fun handleSaveKeyResultReceive( + action: EnvironmentAction.Internal.ImportKeyResultReceive, + ) { + when (val result = action.result) { + is ImportPrivateKeyResult.Success -> { + mutableStateFlow.update { state -> + state.copy( + keyAlias = result.alias, + keyHost = MutualTlsKeyHost.ANDROID_KEY_STORE, + ) + } + } + + is ImportPrivateKeyResult.Error.UnsupportedKey -> { + showToast(message = R.string.unsupported_certificate_type.asText()) + } + + is ImportPrivateKeyResult.Error.KeyStoreOperationFailed -> { + showToast(message = R.string.certificate_installation_failed.asText()) + } + + is ImportPrivateKeyResult.Error.UnrecoverableKey -> { + showToast(message = R.string.certificate_password_incorrect.asText()) + } + + is ImportPrivateKeyResult.Error.InvalidCertificateChain -> { + showToast(R.string.invalid_certificate_chain.asText()) + } + + ImportPrivateKeyResult.Error.DuplicateAlias -> { + // TODO [PM-17686] Improve duplicate alias handling. + showToast(R.string.certificate_alias_already_exists.asText()) + } + } + } + + private fun handleMutualTlsFeatureFlagUpdate( + action: EnvironmentAction.Internal.MutualTlsFeatureFlagUpdate, + ) { + mutableStateFlow.update { + it.copy( + showMutualTlsOptions = action.enabled, + ) + } + } + + private fun handleChooseSystemCertificateClickAction() { + mutableStateFlow.update { + it.copy( + dialog = EnvironmentState.DialogState.SystemCertificateWarningDialog, + ) + } + } + + private fun handleSystemCertificateSelectionResultReceive( + action: EnvironmentAction.SystemCertificateSelectionResultReceive, + ) { + when (val result = action.privateKeyAliasSelectionResult) { + is PrivateKeyAliasSelectionResult.Success -> { + mutableStateFlow.update { + it.copy( + keyAlias = result.alias.orEmpty(), + keyHost = result.alias?.let { MutualTlsKeyHost.KEY_CHAIN }, + ) + } + } + + is PrivateKeyAliasSelectionResult.Error -> { + sendEvent( + EnvironmentEvent.ShowToast( + message = R.string.error_loading_certificate.asText(), + ), + ) + } + } + } + + private fun showToast(message: Text) { + sendEvent(EnvironmentEvent.ShowToast(message)) + } + + private fun showErrorDialog(message: Text) { + mutableStateFlow.update { + it.copy(dialog = EnvironmentState.DialogState.Error(message = message)) + } + } } /** @@ -165,8 +404,43 @@ data class EnvironmentState( val apiServerUrl: String, val identityServerUrl: String, val iconsServerUrl: String, - val shouldShowErrorDialog: Boolean, -) : Parcelable + val keyAlias: String, + val dialog: DialogState?, + val showMutualTlsOptions: Boolean, + // internal + private val keyHost: MutualTlsKeyHost?, +) : Parcelable { + + val keyUri: String? + get() = "cert://$keyHost/$keyAlias" + .takeUnless { keyHost == null || keyAlias.isEmpty() } + + /** + * Models the dialog states of the environment screen. + */ + @Parcelize + sealed class DialogState : Parcelable { + + /** + * Show an error dialog. + */ + data class Error( + val message: Text, + ) : DialogState() + + /** + * Show a dialog to capture the certificate alias and password. + */ + data class SetCertificateData( + val certificateBytes: IntentManager.FileData, + ) : DialogState() + + /** + * Show a dialog warning the user that system certificates are not as secure. + */ + data object SystemCertificateWarningDialog : DialogState() + } +} /** * Models events for the environment screen. @@ -177,6 +451,18 @@ sealed class EnvironmentEvent { */ data object NavigateBack : EnvironmentEvent() + /** + * Show the File chooser dialog for certificate import. + */ + data object ShowCertificateImportFileChooser : EnvironmentEvent() + + /** + * Show the system certificate selection dialog. + */ + data class ShowSystemCertificateSelectionDialog( + val serverUrl: String?, + ) : EnvironmentEvent() + /** * Show a toast with the given message. */ @@ -204,6 +490,26 @@ sealed class EnvironmentAction { */ data object ErrorDialogDismiss : EnvironmentAction() + /** + * User clicked the import certificate button. + */ + data object ImportCertificateClick : EnvironmentAction() + + /** + * User dismissed the set certificate password dialog without providing a password. + */ + data object SetCertificatePasswordDialogDismiss : EnvironmentAction() + + /** + * User clicked the choose system certificate button. + */ + data object ChooseSystemCertificateClick : EnvironmentAction() + + /** + * User confirmed choosing the system certificate. + */ + data object ConfirmChooseSystemCertificateClick : EnvironmentAction() + /** * Indicates that the overall server URL has changed. */ @@ -211,6 +517,13 @@ sealed class EnvironmentAction { val serverUrl: String, ) : EnvironmentAction() + /** + * Indicates that the certificate installation result was received. + */ + data class CertificateInstallationResultReceive( + val success: Boolean, + ) : EnvironmentAction() + /** * Indicates that the web vault server URL has changed. */ @@ -238,6 +551,48 @@ sealed class EnvironmentAction { data class IconsServerUrlChange( val iconsServerUrl: String, ) : EnvironmentAction() + + /** + * Indicates that the certificate file selection result was received. + */ + data class ImportCertificateFilePickerResultReceive( + val certificateFileData: IntentManager.FileData, + ) : EnvironmentAction() + + /** + * Indicates the certificate info data was received. + */ + data class SetCertificateInfoResultReceive( + val certificateFileData: IntentManager.FileData, + val password: String, + val alias: String, + ) : EnvironmentAction() + + /** + * User has selected a system certificate alias. + */ + data class SystemCertificateSelectionResultReceive( + val privateKeyAliasSelectionResult: PrivateKeyAliasSelectionResult, + ) : EnvironmentAction() + + /** + * Models actions the EnvironmentViewModel itself may trigger. + */ + sealed class Internal : EnvironmentAction() { + /** + * Indicates the result of importing a key was received. + */ + data class ImportKeyResultReceive( + val result: ImportPrivateKeyResult, + ) : Internal() + + /** + * Indicates the mutual TLS feature flag was updated. + */ + data class MutualTlsFeatureFlagUpdate( + val enabled: Boolean, + ) : Internal() + } } /** diff --git a/app/src/main/java/com/x8bit/bitwarden/ui/platform/composition/LocalManagerProvider.kt b/app/src/main/java/com/x8bit/bitwarden/ui/platform/composition/LocalManagerProvider.kt index 335b41f17a..2b82b3ef36 100644 --- a/app/src/main/java/com/x8bit/bitwarden/ui/platform/composition/LocalManagerProvider.kt +++ b/app/src/main/java/com/x8bit/bitwarden/ui/platform/composition/LocalManagerProvider.kt @@ -22,6 +22,8 @@ import com.x8bit.bitwarden.ui.platform.manager.exit.ExitManager import com.x8bit.bitwarden.ui.platform.manager.exit.ExitManagerImpl import com.x8bit.bitwarden.ui.platform.manager.intent.IntentManager import com.x8bit.bitwarden.ui.platform.manager.intent.IntentManagerImpl +import com.x8bit.bitwarden.ui.platform.manager.keychain.KeyChainManager +import com.x8bit.bitwarden.ui.platform.manager.keychain.KeyChainManagerImpl import com.x8bit.bitwarden.ui.platform.manager.nfc.NfcManager import com.x8bit.bitwarden.ui.platform.manager.nfc.NfcManagerImpl import com.x8bit.bitwarden.ui.platform.manager.permissions.PermissionsManager @@ -52,6 +54,7 @@ fun LocalManagerProvider( LocalFido2CompletionManager provides fido2CompletionManager, LocalAppReviewManager provides AppReviewManagerImpl(activity), LocalAppResumeStateManager provides AppResumeStateManagerImpl(), + LocalKeyChainManager provides KeyChainManagerImpl(activity), ) { content() } @@ -110,3 +113,7 @@ val LocalAppReviewManager: ProvidableCompositionLocal = compos val LocalAppResumeStateManager = compositionLocalOf { error("CompositionLocal AppResumeStateManager not present") } + +val LocalKeyChainManager: ProvidableCompositionLocal = compositionLocalOf { + error("CompositionLocal KeyChainManager not present") +} diff --git a/app/src/main/java/com/x8bit/bitwarden/ui/platform/manager/keychain/KeyChainManager.kt b/app/src/main/java/com/x8bit/bitwarden/ui/platform/manager/keychain/KeyChainManager.kt new file mode 100644 index 0000000000..099e9a89a5 --- /dev/null +++ b/app/src/main/java/com/x8bit/bitwarden/ui/platform/manager/keychain/KeyChainManager.kt @@ -0,0 +1,18 @@ +package com.x8bit.bitwarden.ui.platform.manager.keychain + +import com.x8bit.bitwarden.ui.platform.manager.keychain.model.PrivateKeyAliasSelectionResult + +/** + * Responsible for managing keys stored in the system KeyChain. + */ +interface KeyChainManager { + + /** + * Display the system private key alias selection dialog. + * + * @param currentServerUrl The currently selected server URL. + */ + suspend fun choosePrivateKeyAlias( + currentServerUrl: String?, + ): PrivateKeyAliasSelectionResult +} diff --git a/app/src/main/java/com/x8bit/bitwarden/ui/platform/manager/keychain/KeyChainManagerImpl.kt b/app/src/main/java/com/x8bit/bitwarden/ui/platform/manager/keychain/KeyChainManagerImpl.kt new file mode 100644 index 0000000000..5da57f42e3 --- /dev/null +++ b/app/src/main/java/com/x8bit/bitwarden/ui/platform/manager/keychain/KeyChainManagerImpl.kt @@ -0,0 +1,42 @@ +package com.x8bit.bitwarden.ui.platform.manager.keychain + +import android.app.Activity +import android.security.KeyChain +import androidx.core.net.toUri +import com.x8bit.bitwarden.ui.platform.manager.keychain.model.PrivateKeyAliasSelectionResult +import kotlinx.coroutines.channels.awaitClose +import kotlinx.coroutines.flow.callbackFlow +import kotlinx.coroutines.flow.first + +/** + * Default implementation of [KeyChainManager]. + */ +class KeyChainManagerImpl( + private val activity: Activity, +) : KeyChainManager { + + override suspend fun choosePrivateKeyAlias( + currentServerUrl: String?, + ): PrivateKeyAliasSelectionResult = + callbackFlow { + try { + KeyChain.choosePrivateKeyAlias( + activity, + { alias -> + trySend(PrivateKeyAliasSelectionResult.Success(alias)) + close() + }, + null, + null, + currentServerUrl?.toUri(), + null, + ) + } catch (_: IllegalArgumentException) { + trySend(PrivateKeyAliasSelectionResult.Error) + close() + } + + awaitClose() + } + .first() +} diff --git a/app/src/main/java/com/x8bit/bitwarden/ui/platform/manager/keychain/model/PrivateKeyAliasSelectionResult.kt b/app/src/main/java/com/x8bit/bitwarden/ui/platform/manager/keychain/model/PrivateKeyAliasSelectionResult.kt new file mode 100644 index 0000000000..7cf93b5034 --- /dev/null +++ b/app/src/main/java/com/x8bit/bitwarden/ui/platform/manager/keychain/model/PrivateKeyAliasSelectionResult.kt @@ -0,0 +1,17 @@ +package com.x8bit.bitwarden.ui.platform.manager.keychain.model + +/** + * Represents the result of an operation to select a private key alias from the system KeyChain. + */ +sealed class PrivateKeyAliasSelectionResult { + /** + * Indicates that the operation was successful and an alias was selected (or null if none was + * selected). + */ + data class Success(val alias: String?) : PrivateKeyAliasSelectionResult() + + /** + * Indicates that an error occurred during the operation. + */ + object Error : PrivateKeyAliasSelectionResult() +} diff --git a/app/src/main/res/values/strings.xml b/app/src/main/res/values/strings.xml index edeead4c57..14f1cf57fe 100644 --- a/app/src/main/res/values/strings.xml +++ b/app/src/main/res/values/strings.xml @@ -281,6 +281,9 @@ Scanning will happen automatically. Encryption key migration required. Please login through the web vault to update your encryption key. Learn more API server URL + Client certificate (mTLS) + Certificate used for client authentication. + Use system certificate Custom environment For advanced users. You can specify the base URL of each service independently. The environment URLs have been saved. @@ -1165,5 +1168,17 @@ Do you want to switch to this account? 6 OF 6 Use these options to adjust your password to meet your account website\'s requirements. "After you save your new password to Bitwarden, don’t forget to update it on your account website. " - Link + Link + Error loading certificate + Certificate alias + Import certificate + Choose system certificate + Certificate installed + Certificate installation failed + Certificate selection failed + Unsupported certificate type + Certificate password incorrect + Invalid certificate chain + Using a system certificate is less secure than storing the certificate with Bitwarden. Continuing will display a list of available system certificates if one is already installed. + Certificate alias already exists diff --git a/app/src/test/java/com/x8bit/bitwarden/data/platform/datasource/network/retrofit/RetrofitsTest.kt b/app/src/test/java/com/x8bit/bitwarden/data/platform/datasource/network/retrofit/RetrofitsTest.kt index fe9cbb4b07..ff980d744f 100644 --- a/app/src/test/java/com/x8bit/bitwarden/data/platform/datasource/network/retrofit/RetrofitsTest.kt +++ b/app/src/test/java/com/x8bit/bitwarden/data/platform/datasource/network/retrofit/RetrofitsTest.kt @@ -5,14 +5,21 @@ import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.AuthToke import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.BaseUrlInterceptors import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.HeadersInterceptor import com.x8bit.bitwarden.data.platform.datasource.network.model.NetworkResult +import com.x8bit.bitwarden.data.platform.datasource.network.ssl.SslManager +import com.x8bit.bitwarden.data.util.mockBuilder import io.mockk.every import io.mockk.mockk +import io.mockk.mockkConstructor import io.mockk.slot +import io.mockk.unmockkConstructor +import io.mockk.verify import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.test.runTest import kotlinx.serialization.json.Json import kotlinx.serialization.json.JsonObject import okhttp3.Authenticator import okhttp3.Interceptor +import okhttp3.OkHttpClient import okhttp3.mockwebserver.MockResponse import okhttp3.mockwebserver.MockWebServer import org.junit.After @@ -23,6 +30,10 @@ import org.junit.jupiter.api.Assertions.assertTrue import retrofit2.Retrofit import retrofit2.create import retrofit2.http.GET +import javax.net.ssl.SSLContext +import javax.net.ssl.SSLSocketFactory +import javax.net.ssl.TrustManager +import javax.net.ssl.X509TrustManager class RetrofitsTest { private val authTokenInterceptor = mockk { @@ -47,12 +58,17 @@ class RetrofitsTest { } private val json = Json private val server = MockWebServer() + private val mockSslManager = mockk { + every { sslContext } returns mockk(relaxed = true) + every { trustManagers } returns arrayOf(mockk(relaxed = true)) + } private val retrofits = RetrofitsImpl( authTokenInterceptor = authTokenInterceptor, baseUrlInterceptors = baseUrlInterceptors, headersInterceptor = headersInterceptors, refreshAuthenticator = refreshAuthenticator, + sslManager = mockSslManager, json = json, ) @@ -71,6 +87,7 @@ class RetrofitsTest { @After fun tearDown() { server.shutdown() + unmockkConstructor(OkHttpClient.Builder::class) } @Test @@ -256,6 +273,94 @@ class RetrofitsTest { assertFalse(isEventsInterceptorCalled) } + @Suppress("MaxLineLength") + @Test + fun `createStaticRetrofit should set sslSocketFactory`() = + runTest { + val mockTrustManager = mockk(relaxed = true) + val mockSocketFactory = mockk() + val mockSslContext = mockk { + every { socketFactory } returns mockSocketFactory + } + setupMockOkHttpClientBuilder( + sslContext = mockSslContext, + trustManagers = arrayOf(mockTrustManager), + ) + + retrofits.createStaticRetrofit(isAuthenticated = false) + + verify { + anyConstructed() + .sslSocketFactory( + sslSocketFactory = mockSocketFactory, + trustManager = mockTrustManager, + ) + } + } + + @Suppress("MaxLineLength") + @Test + fun `authenticatedOkHttpClient should set sslSocketFactory`() = + runTest { + val mockTrustManager = mockk(relaxed = true) + val mockSocketFactory = mockk() + val mockSslContext = mockk { + every { socketFactory } returns mockSocketFactory + } + setupMockOkHttpClientBuilder( + sslContext = mockSslContext, + trustManagers = arrayOf(mockTrustManager), + ) + + retrofits.authenticatedApiRetrofit + + verify { + anyConstructed() + .sslSocketFactory( + sslSocketFactory = mockSocketFactory, + trustManager = mockTrustManager, + ) + } + } + + @Suppress("MaxLineLength") + @Test + fun `unauthenticatedOkHttpClient should set sslSocketFactory`() = + runTest { + val mockTrustManager = mockk(relaxed = true) + val mockSocketFactory = mockk() + val mockSslContext = mockk { + every { socketFactory } returns mockSocketFactory + } + setupMockOkHttpClientBuilder( + sslContext = mockSslContext, + trustManagers = arrayOf(mockTrustManager), + ) + + retrofits.unauthenticatedApiRetrofit + + verify { + anyConstructed() + .sslSocketFactory( + sslSocketFactory = mockSocketFactory, + trustManager = mockTrustManager, + ) + } + } + + private fun setupMockOkHttpClientBuilder( + sslContext: SSLContext = mockk(), + trustManagers: Array = emptyArray(), + ) { + mockkConstructor(OkHttpClient.Builder::class) + every { mockSslManager.sslContext } returns sslContext + every { mockSslManager.trustManagers } returns trustManagers + mockBuilder { + it.sslSocketFactory(any(), any()) + } + every { anyConstructed().build() } returns mockk(relaxed = true) + } + private fun Retrofit.createMockRetrofit(): Retrofit = this .newBuilder() diff --git a/app/src/test/java/com/x8bit/bitwarden/data/platform/datasource/network/ssl/SslManagerTest.kt b/app/src/test/java/com/x8bit/bitwarden/data/platform/datasource/network/ssl/SslManagerTest.kt new file mode 100644 index 0000000000..7c71943936 --- /dev/null +++ b/app/src/test/java/com/x8bit/bitwarden/data/platform/datasource/network/ssl/SslManagerTest.kt @@ -0,0 +1,218 @@ +package com.x8bit.bitwarden.data.platform.datasource.network.ssl + +import android.net.Uri +import com.x8bit.bitwarden.data.auth.datasource.disk.model.EnvironmentUrlDataJson +import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsCertificate +import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsKeyHost +import com.x8bit.bitwarden.data.platform.manager.KeyManager +import com.x8bit.bitwarden.data.platform.repository.EnvironmentRepository +import com.x8bit.bitwarden.data.platform.repository.model.Environment +import io.mockk.every +import io.mockk.just +import io.mockk.mockk +import io.mockk.mockkStatic +import io.mockk.runs +import io.mockk.slot +import io.mockk.unmockkStatic +import io.mockk.verify +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertNotNull +import org.junit.jupiter.api.Assertions.assertNull +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import java.security.KeyStore +import java.security.PrivateKey +import java.security.cert.X509Certificate +import javax.net.ssl.SSLContext +import javax.net.ssl.TrustManager +import javax.net.ssl.TrustManagerFactory +import javax.net.ssl.X509ExtendedKeyManager + +class SslManagerTest { + + private val mockEnvironment = mockk { + every { environmentUrlData } returns DEFAULT_ENV_URL_DATA + } + private val mockEnvironmentRepository = mockk { + every { environment } returns mockEnvironment + } + private val mockMutualTlsCertificate = mockk { + every { alias } returns "mockAlias" + } + private val mockKeyManager = mockk { + every { getMutualTlsCertificateChain(any(), any()) } returns mockMutualTlsCertificate + } + private val mockTrustManagerFactory = mockk { + every { init(null as? KeyStore?) } just runs + every { trustManagers } returns DEFAULT_TRUST_MANAGERS + } + + private val sslManager: SslManagerImpl = SslManagerImpl( + keyManager = mockKeyManager, + environmentRepository = mockEnvironmentRepository, + ) + + @BeforeEach + fun setUp() { + mockkStatic(TrustManagerFactory::class) + every { + TrustManagerFactory.getDefaultAlgorithm() + } returns "defaultAlgorithm" + every { + TrustManagerFactory.getInstance("defaultAlgorithm") + } returns mockTrustManagerFactory + } + + @AfterEach + fun tearDown() { + unmockkStatic(TrustManagerFactory::class, Uri::class, SSLContext::class) + } + + @Test + fun `sslContext should be initialized with default TLS protocol`() { + setupMockUri() + assertTrue(sslManager.sslContext.protocol == "TLS") + } + + @Test + fun `X509ExtendedKeyManagerImpl should initialize with mutualTlsCertificate`() { + setupMockUri() + mockkStatic(SSLContext::class) + val keyManagersCaptor = slot>() + val trustManagersCaptor = slot>() + every { SSLContext.getInstance("TLS") } returns mockk { + every { + init( + capture(keyManagersCaptor), + capture(trustManagersCaptor), + any(), + ) + } just runs + } + every { mockEnvironment.environmentUrlData } returns DEFAULT_ENV_URL_DATA + every { mockMutualTlsCertificate.alias } returns "mockAlias" + every { mockMutualTlsCertificate.certificateChain } returns listOf( + mockk(name = "MockCertificate1"), + mockk(name = "MockCertificate2"), + ) + every { + mockMutualTlsCertificate.privateKey + } returns mockk() + every { + mockKeyManager.getMutualTlsCertificateChain( + alias = "mockAlias", + host = MutualTlsKeyHost.ANDROID_KEY_STORE, + ) + } returns mockMutualTlsCertificate + + assertNotNull(sslManager.sslContext) + + val keyManager = keyManagersCaptor.captured.first() + assertEquals( + mockMutualTlsCertificate.alias, + keyManager.chooseClientAlias(null, null, null), + ) + assertTrue( + keyManager + .getCertificateChain("mockAlias") + .contentEquals( + mockMutualTlsCertificate + .certificateChain + .toTypedArray(), + ), + ) + assertEquals( + mockMutualTlsCertificate.privateKey, + keyManager.getPrivateKey("mockAlias"), + ) + } + + @Test + fun `mutualTlsCertificate should return null when keyUri is null`() { + every { + mockEnvironment.environmentUrlData + } returns DEFAULT_ENV_URL_DATA.copy(keyUri = null) + assertNull(sslManager.mutualTlsCertificate) + } + + @Test + fun `mutualTlsCertificate should be null when host is invalid`() { + setupMockUri(authority = "UNKNOWN_HOST") + assertNull(sslManager.mutualTlsCertificate) + } + + @Test + fun `mutualTlsCertificate should be null when alias is null`() { + setupMockUri(path = null) + assertNull(sslManager.mutualTlsCertificate) + } + + @Test + fun `mutualTlsCertificate should trim path when it is not null`() { + setupMockUri(path = "/mockAlias/") + assertEquals("mockAlias", sslManager.mutualTlsCertificate?.alias) + } + + @Test + fun `mutualTlsCertificate should be null when alias is empty after trim`() { + setupMockUri(path = "/") + assertNull(sslManager.mutualTlsCertificate) + } + + @Test + fun `mutualTlsCertificate should call keyManager with correct alias and host`() { + // Set host to ANDROID_KEY_STORE + setupMockUri() + assertNotNull(sslManager.mutualTlsCertificate) + verify { + mockKeyManager.getMutualTlsCertificateChain( + alias = "mockAlias", + host = MutualTlsKeyHost.ANDROID_KEY_STORE, + ) + } + + // Set host to KEY_CHAIN + setupMockUri(authority = "KEY_CHAIN") + assertNotNull(sslManager.mutualTlsCertificate) + verify { + mockKeyManager.getMutualTlsCertificateChain( + alias = "mockAlias", + host = MutualTlsKeyHost.KEY_CHAIN, + ) + } + } + + @Suppress("MaxLineLength") + @Test + fun `trustManagers should return TrustManager array initialized with default algorithm and null keystore`() { + assertTrue(sslManager.trustManagers.contentEquals(DEFAULT_TRUST_MANAGERS)) + verify { + TrustManagerFactory.getInstance("defaultAlgorithm") + TrustManagerFactory.getDefaultAlgorithm() + mockTrustManagerFactory.init(null as? KeyStore?) + } + } + + private fun setupMockUri( + authority: String = "ANDROID_KEY_STORE", + path: String? = "/mockAlias", + ) { + mockkStatic(Uri::class) + val uriMock = mockk() + every { Uri.parse(any()) } returns uriMock + every { uriMock.authority } returns authority + every { uriMock.path } returns path + } +} + +private val DEFAULT_TRUST_MANAGERS = arrayOf( + mockk(name = "MockTrustManager1"), + mockk(name = "MockTrustManager2"), +) + +val DEFAULT_ENV_URL_DATA = EnvironmentUrlDataJson( + base = "https://example.com", + keyUri = "cert://ANDROID_KEY_STORE/mockAlias", +) diff --git a/app/src/test/java/com/x8bit/bitwarden/data/platform/manager/KeyManagerTest.kt b/app/src/test/java/com/x8bit/bitwarden/data/platform/manager/KeyManagerTest.kt index 1d2b57fab6..3c70f1a47b 100644 --- a/app/src/test/java/com/x8bit/bitwarden/data/platform/manager/KeyManagerTest.kt +++ b/app/src/test/java/com/x8bit/bitwarden/data/platform/manager/KeyManagerTest.kt @@ -3,7 +3,7 @@ package com.x8bit.bitwarden.data.platform.manager import android.content.Context import android.security.KeyChain import android.security.KeyChainException -import com.x8bit.bitwarden.data.platform.datasource.disk.model.ImportPrivateKeyResult +import com.x8bit.bitwarden.data.platform.manager.model.ImportPrivateKeyResult import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsCertificate import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsKeyHost import io.mockk.every diff --git a/app/src/test/java/com/x8bit/bitwarden/ui/auth/feature/environment/EnvironmentScreenTest.kt b/app/src/test/java/com/x8bit/bitwarden/ui/auth/feature/environment/EnvironmentScreenTest.kt index 56fad26a2f..71d12437b6 100644 --- a/app/src/test/java/com/x8bit/bitwarden/ui/auth/feature/environment/EnvironmentScreenTest.kt +++ b/app/src/test/java/com/x8bit/bitwarden/ui/auth/feature/environment/EnvironmentScreenTest.kt @@ -10,9 +10,17 @@ import androidx.compose.ui.test.onAllNodesWithText import androidx.compose.ui.test.onNodeWithContentDescription import androidx.compose.ui.test.onNodeWithText import androidx.compose.ui.test.performClick +import androidx.compose.ui.test.performScrollTo import androidx.compose.ui.test.performTextInput import com.x8bit.bitwarden.data.platform.repository.util.bufferedMutableSharedFlow +import com.x8bit.bitwarden.ui.auth.feature.environment.EnvironmentState.DialogState import com.x8bit.bitwarden.ui.platform.base.BaseComposeTest +import com.x8bit.bitwarden.ui.platform.base.util.asText +import com.x8bit.bitwarden.ui.platform.manager.intent.IntentManager +import com.x8bit.bitwarden.ui.platform.manager.keychain.KeyChainManager +import com.x8bit.bitwarden.ui.platform.manager.keychain.model.PrivateKeyAliasSelectionResult +import io.mockk.coEvery +import io.mockk.coVerify import io.mockk.every import io.mockk.mockk import io.mockk.verify @@ -26,6 +34,12 @@ class EnvironmentScreenTest : BaseComposeTest() { private var onNavigateBackCalled = false private val mutableEventFlow = bufferedMutableSharedFlow() private val mutableStateFlow = MutableStateFlow(DEFAULT_STATE) + private val mockIntentManager = mockk(relaxed = true) + private val mockKeyChainManager = mockk { + coEvery { + choosePrivateKeyAlias(any()) + } returns PrivateKeyAliasSelectionResult.Success("mockAlias") + } private val viewModel = mockk(relaxed = true) { every { eventFlow } returns mutableEventFlow every { stateFlow } returns mutableStateFlow @@ -36,6 +50,8 @@ class EnvironmentScreenTest : BaseComposeTest() { composeTestRule.setContent { EnvironmentScreen( onNavigateBack = { onNavigateBackCalled = true }, + intentManager = mockIntentManager, + keyChainManager = mockKeyChainManager, viewModel = viewModel, ) } @@ -69,7 +85,11 @@ class EnvironmentScreenTest : BaseComposeTest() { mutableStateFlow.update { it.copy( - shouldShowErrorDialog = true, + dialog = DialogState.Error( + ("One or more of the URLs entered are invalid. " + + "Please revise it and try to save again.") + .asText(), + ), ) } @@ -96,7 +116,7 @@ class EnvironmentScreenTest : BaseComposeTest() { fun `error dialog OK click should send ErrorDialogDismiss action`() { mutableStateFlow.update { it.copy( - shouldShowErrorDialog = true, + dialog = DialogState.Error("Error".asText()), ) } composeTestRule @@ -131,6 +151,62 @@ class EnvironmentScreenTest : BaseComposeTest() { } } + @Test + fun `use system certificate click should send UseSystemKeyCertificateClick`() { + composeTestRule + .onNodeWithText("Choose system certificate") + .performScrollTo() + .assertIsDisplayed() + .performClick() + + verify { + viewModel.trySendAction(EnvironmentAction.ChooseSystemCertificateClick) + } + } + + @Test + fun `ShowSystemCertificateSelection event should show system certificate selection dialog`() { + mutableEventFlow.tryEmit( + EnvironmentEvent.ShowSystemCertificateSelectionDialog(serverUrl = ""), + ) + coVerify { mockKeyChainManager.choosePrivateKeyAlias(null) } + } + + @Suppress("MaxLineLength") + @Test + fun `system certificate selection should send SystemCertificateSelectionResultReceive action`() { + coEvery { + mockKeyChainManager.choosePrivateKeyAlias(null) + } returns PrivateKeyAliasSelectionResult.Success("alias") + + mutableEventFlow.tryEmit( + EnvironmentEvent.ShowSystemCertificateSelectionDialog(serverUrl = ""), + ) + + verify { + viewModel.trySendAction( + EnvironmentAction.SystemCertificateSelectionResultReceive( + privateKeyAliasSelectionResult = PrivateKeyAliasSelectionResult.Success( + alias = "alias", + ), + ), + ) + } + } + + @Test + fun `key alias should change according to the state`() { + composeTestRule + .onNodeWithText("Certificate alias") + .assertTextEquals("Certificate alias", "") + + mutableStateFlow.update { it.copy(keyAlias = "mock-alias") } + + composeTestRule + .onNodeWithText("Certificate alias") + .assertTextEquals("Certificate alias", "mock-alias") + } + @Test fun `web vault URL should change according to the state`() { composeTestRule @@ -238,11 +314,14 @@ class EnvironmentScreenTest : BaseComposeTest() { companion object { val DEFAULT_STATE = EnvironmentState( serverUrl = "", + keyAlias = "", webVaultServerUrl = "", apiServerUrl = "", identityServerUrl = "", iconsServerUrl = "", - shouldShowErrorDialog = false, + keyHost = null, + dialog = null, + showMutualTlsOptions = true, ) } } diff --git a/app/src/test/java/com/x8bit/bitwarden/ui/auth/feature/environment/EnvironmentViewModelTest.kt b/app/src/test/java/com/x8bit/bitwarden/ui/auth/feature/environment/EnvironmentViewModelTest.kt index 8d7c3a3ed0..4cde15cfe9 100644 --- a/app/src/test/java/com/x8bit/bitwarden/ui/auth/feature/environment/EnvironmentViewModelTest.kt +++ b/app/src/test/java/com/x8bit/bitwarden/ui/auth/feature/environment/EnvironmentViewModelTest.kt @@ -1,13 +1,28 @@ package com.x8bit.bitwarden.ui.auth.feature.environment +import android.net.Uri import androidx.lifecycle.SavedStateHandle import app.cash.turbine.test import com.x8bit.bitwarden.R import com.x8bit.bitwarden.data.auth.datasource.disk.model.EnvironmentUrlDataJson +import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsKeyHost +import com.x8bit.bitwarden.data.platform.manager.FeatureFlagManager +import com.x8bit.bitwarden.data.platform.manager.KeyManager +import com.x8bit.bitwarden.data.platform.manager.model.FlagKey +import com.x8bit.bitwarden.data.platform.manager.model.ImportPrivateKeyResult import com.x8bit.bitwarden.data.platform.repository.model.Environment import com.x8bit.bitwarden.data.platform.repository.util.FakeEnvironmentRepository +import com.x8bit.bitwarden.data.platform.util.asSuccess +import com.x8bit.bitwarden.data.vault.manager.FileManager import com.x8bit.bitwarden.ui.platform.base.BaseViewModelTest import com.x8bit.bitwarden.ui.platform.base.util.asText +import com.x8bit.bitwarden.ui.platform.manager.intent.IntentManager +import com.x8bit.bitwarden.ui.platform.manager.keychain.model.PrivateKeyAliasSelectionResult +import io.mockk.coEvery +import io.mockk.coVerify +import io.mockk.every +import io.mockk.mockk +import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.test.runTest import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Test @@ -15,6 +30,13 @@ import org.junit.jupiter.api.Test class EnvironmentViewModelTest : BaseViewModelTest() { private val fakeEnvironmentRepository = FakeEnvironmentRepository() + private val mutableMutualTlsFeatureFlagFlow = MutableStateFlow(true) + private val mockFeatureFlagManager = mockk { + every { getFeatureFlag(FlagKey.MutualTls) } returns true + every { getFeatureFlagFlow(FlagKey.MutualTls) } returns mutableMutualTlsFeatureFlagFlow + } + private val mockKeyManager = mockk() + private val mockFileManager = mockk() @Suppress("MaxLineLength") @Test @@ -60,6 +82,8 @@ class EnvironmentViewModelTest : BaseViewModelTest() { apiServerUrl = "saved-api", identityServerUrl = "saved-identity", iconsServerUrl = "saved-icons", + keyHost = MutualTlsKeyHost.ANDROID_KEY_STORE, + keyAlias = "saved-key-alias", ) val viewModel = createViewModel( savedStateHandle = SavedStateHandle( @@ -75,6 +99,8 @@ class EnvironmentViewModelTest : BaseViewModelTest() { apiServerUrl = "saved-api", identityServerUrl = "saved-identity", iconsServerUrl = "saved-icons", + keyHost = MutualTlsKeyHost.ANDROID_KEY_STORE, + keyAlias = "saved-key-alias", ), viewModel.stateFlow.value, ) @@ -115,7 +141,9 @@ class EnvironmentViewModelTest : BaseViewModelTest() { assertEquals( initialState.copy( - shouldShowErrorDialog = true, + dialog = EnvironmentState.DialogState.Error( + message = R.string.environment_page_urls_error.asText(), + ), ), viewModel.stateFlow.value, ) @@ -154,6 +182,11 @@ class EnvironmentViewModelTest : BaseViewModelTest() { EnvironmentAction.IconsServerUrlChange( iconsServerUrl = "icons-url", ), + EnvironmentAction.SystemCertificateSelectionResultReceive( + privateKeyAliasSelectionResult = PrivateKeyAliasSelectionResult.Success( + alias = "mockAlias", + ), + ), ) .forEach { viewModel.trySendAction(it) } @@ -179,6 +212,7 @@ class EnvironmentViewModelTest : BaseViewModelTest() { notifications = null, webVault = "http://web-vault-url", events = null, + keyUri = "cert://KEY_CHAIN/mockAlias", ), ), fakeEnvironmentRepository.environment, @@ -226,6 +260,7 @@ class EnvironmentViewModelTest : BaseViewModelTest() { notifications = null, webVault = "http://web-vault-url", events = null, + keyUri = null, ), ), fakeEnvironmentRepository.environment, @@ -293,6 +328,337 @@ class EnvironmentViewModelTest : BaseViewModelTest() { ) } + @Suppress("MaxLineLength") + @Test + fun `SystemCertificateSelectionResultReceive should update key alias and key host when successful`() { + val viewModel = createViewModel() + viewModel.trySendAction( + EnvironmentAction.SystemCertificateSelectionResultReceive( + privateKeyAliasSelectionResult = PrivateKeyAliasSelectionResult.Success( + alias = "mockAlias", + ), + ), + ) + assertEquals( + DEFAULT_STATE.copy( + keyAlias = "mockAlias", + keyHost = MutualTlsKeyHost.KEY_CHAIN, + ), + viewModel.stateFlow.value, + ) + + viewModel.trySendAction( + EnvironmentAction.SystemCertificateSelectionResultReceive( + privateKeyAliasSelectionResult = PrivateKeyAliasSelectionResult.Success( + alias = null, + ), + ), + ) + assertEquals( + DEFAULT_STATE.copy( + keyAlias = "", + keyHost = null, + ), + viewModel.stateFlow.value, + ) + } + + @Test + fun `SystemCertificateSelectionResultReceive should show toast when error`() = runTest { + val viewModel = createViewModel() + viewModel.trySendAction( + EnvironmentAction.SystemCertificateSelectionResultReceive( + privateKeyAliasSelectionResult = PrivateKeyAliasSelectionResult.Error, + ), + ) + viewModel.eventFlow.test { + assertEquals( + EnvironmentEvent.ShowToast(R.string.error_loading_certificate.asText()), + awaitItem(), + ) + } + } + + @Test + fun `ChooseSystemCertificate should show system certificate warning dialog`() = + runTest { + val viewModel = createViewModel() + + viewModel.trySendAction(EnvironmentAction.ChooseSystemCertificateClick) + + assertEquals( + DEFAULT_STATE.copy( + dialog = EnvironmentState.DialogState.SystemCertificateWarningDialog, + ), + viewModel.stateFlow.value, + ) + } + + @Test + fun `ErrorDialogDismiss should clear the dialog`() = runTest { + val viewModel = createViewModel() + viewModel.trySendAction(EnvironmentAction.ErrorDialogDismiss) + assertEquals( + DEFAULT_STATE.copy(dialog = null), + viewModel.stateFlow.value, + ) + } + + @Test + fun `ImportCertificateClick should emit ShowCertificateImportFileChooser`() = runTest { + val viewModel = createViewModel() + viewModel.eventFlow.test { + viewModel.trySendAction(EnvironmentAction.ImportCertificateClick) + assertEquals(EnvironmentEvent.ShowCertificateImportFileChooser, awaitItem()) + } + } + + @Test + fun `ImportCertificateFilePickerResultReceive should show SetCertificateData dialog`() = + runTest { + val viewModel = createViewModel() + val mockFileData = mockk() + viewModel.trySendAction( + EnvironmentAction.ImportCertificateFilePickerResultReceive( + certificateFileData = mockFileData, + ), + ) + assertEquals( + DEFAULT_STATE.copy( + dialog = EnvironmentState.DialogState.SetCertificateData( + certificateBytes = mockFileData, + ), + ), + viewModel.stateFlow.value, + ) + } + + @Test + fun `SetCertificatePasswordDialogDismiss should clear the dialog`() = runTest { + val viewModel = createViewModel() + viewModel.trySendAction(EnvironmentAction.SetCertificatePasswordDialogDismiss) + assertEquals( + DEFAULT_STATE.copy(dialog = null), + viewModel.stateFlow.value, + ) + } + + @Test + fun `CertificateInstallationResultReceive should show toast based on result`() = runTest { + val viewModel = createViewModel() + + viewModel.eventFlow.test { + viewModel.trySendAction( + EnvironmentAction.CertificateInstallationResultReceive( + success = true, + ), + ) + assertEquals( + EnvironmentEvent.ShowToast(R.string.certificate_installed.asText()), + awaitItem(), + ) + + viewModel.trySendAction( + EnvironmentAction.CertificateInstallationResultReceive( + success = false, + ), + ) + assertEquals( + EnvironmentEvent.ShowToast(R.string.certificate_installation_failed.asText()), + awaitItem(), + ) + } + } + + @Suppress("MaxLineLength") + @Test + fun `ConfirmChooseSystemCertificateClick should clear the dialog and emit ShowSystemCertificateSelectionDialog`() = + runTest { + val viewModel = createViewModel( + savedStateHandle = SavedStateHandle( + initialState = mapOf( + "state" to DEFAULT_STATE.copy(serverUrl = "https://mockServerUrl"), + ), + ), + ) + viewModel.trySendAction(EnvironmentAction.ConfirmChooseSystemCertificateClick) + assertEquals( + DEFAULT_STATE.copy(dialog = null, serverUrl = "https://mockServerUrl"), + viewModel.stateFlow.value, + ) + viewModel.eventFlow.test { + assertEquals( + EnvironmentEvent.ShowSystemCertificateSelectionDialog( + serverUrl = "https://mockServerUrl", + ), + awaitItem(), + ) + } + } + + @Test + fun `ImportKeyResultReceive should update key alias and key host on success`() = runTest { + val viewModel = createViewModel() + viewModel.trySendAction( + EnvironmentAction.Internal.ImportKeyResultReceive( + result = ImportPrivateKeyResult.Success(alias = "mockAlias"), + ), + ) + assertEquals( + DEFAULT_STATE.copy( + keyAlias = "mockAlias", + keyHost = MutualTlsKeyHost.ANDROID_KEY_STORE, + ), + viewModel.stateFlow.value, + ) + } + + @Test + fun `ImportKeyResultReceive should show toast with correct message on error`() = runTest { + val viewModel = createViewModel() + viewModel.eventFlow.test { + viewModel.trySendAction( + EnvironmentAction.Internal.ImportKeyResultReceive( + result = ImportPrivateKeyResult.Error.UnsupportedKey, + ), + ) + assertEquals( + EnvironmentEvent.ShowToast(R.string.unsupported_certificate_type.asText()), + awaitItem(), + ) + + viewModel.trySendAction( + EnvironmentAction.Internal.ImportKeyResultReceive( + result = ImportPrivateKeyResult.Error.KeyStoreOperationFailed, + ), + ) + assertEquals( + EnvironmentEvent.ShowToast(R.string.certificate_installation_failed.asText()), + awaitItem(), + ) + + viewModel.trySendAction( + EnvironmentAction.Internal.ImportKeyResultReceive( + result = ImportPrivateKeyResult.Error.UnrecoverableKey, + ), + ) + assertEquals( + EnvironmentEvent.ShowToast(R.string.certificate_password_incorrect.asText()), + awaitItem(), + ) + + viewModel.trySendAction( + EnvironmentAction.Internal.ImportKeyResultReceive( + result = ImportPrivateKeyResult.Error.InvalidCertificateChain, + ), + ) + assertEquals( + EnvironmentEvent.ShowToast(R.string.invalid_certificate_chain.asText()), + awaitItem(), + ) + + viewModel.trySendAction( + EnvironmentAction.Internal.ImportKeyResultReceive( + result = ImportPrivateKeyResult.Error.DuplicateAlias, + ), + ) + assertEquals( + EnvironmentEvent.ShowToast(R.string.certificate_alias_already_exists.asText()), + awaitItem(), + ) + } + } + + @Suppress("MaxLineLength") + @Test + fun `SetCertificateInfoResultReceive should clear the dialog update key alias and key host after successful import`() = + runTest { + val viewModel = createViewModel() + val mockUri = mockk() + val mockFileData = IntentManager.FileData( + fileName = "mockFileName", + uri = mockUri, + sizeBytes = 0, + ) + val keyBytes = byteArrayOf() + coEvery { + mockFileManager.uriToByteArray(mockFileData.uri) + } returns keyBytes.asSuccess() + coEvery { + mockKeyManager.importMutualTlsCertificate( + key = keyBytes, + alias = "mockAlias", + password = "mockPassword", + ) + } returns ImportPrivateKeyResult.Success(alias = "mockAlias") + + viewModel.trySendAction( + EnvironmentAction.SetCertificateInfoResultReceive( + certificateFileData = mockFileData, + alias = "mockAlias", + password = "mockPassword", + ), + ) + assertEquals( + DEFAULT_STATE.copy( + dialog = null, + keyAlias = "mockAlias", + keyHost = MutualTlsKeyHost.ANDROID_KEY_STORE, + ), + viewModel.stateFlow.value, + ) + + coVerify { + mockFileManager.uriToByteArray(mockFileData.uri) + mockKeyManager.importMutualTlsCertificate( + key = byteArrayOf(), + alias = "mockAlias", + password = "mockPassword", + ) + } + } + + @Test + fun `SetCertificateInfoResultReceive should show error dialog if input is invalid`() = runTest { + val viewModel = createViewModel() + + viewModel.trySendAction( + EnvironmentAction.SetCertificateInfoResultReceive( + certificateFileData = mockk(), + alias = "mockAlias", + password = "", + ), + ) + assertEquals( + DEFAULT_STATE.copy( + dialog = EnvironmentState.DialogState.Error( + R.string.validation_field_required.asText( + R.string.password.asText(), + ), + ), + ), + viewModel.stateFlow.value, + ) + + viewModel.trySendAction( + EnvironmentAction.SetCertificateInfoResultReceive( + certificateFileData = mockk(), + alias = "", + password = "mockPassword", + ), + ) + assertEquals( + DEFAULT_STATE.copy( + dialog = EnvironmentState.DialogState.Error( + R.string.validation_field_required.asText( + R.string.alias.asText(), + ), + ), + ), + viewModel.stateFlow.value, + ) + } + //region Helper methods private fun createViewModel( @@ -300,6 +666,9 @@ class EnvironmentViewModelTest : BaseViewModelTest() { ): EnvironmentViewModel = EnvironmentViewModel( environmentRepository = fakeEnvironmentRepository, + featureFlagManager = mockFeatureFlagManager, + keyManager = mockKeyManager, + fileManager = mockFileManager, savedStateHandle = savedStateHandle, ) @@ -308,11 +677,14 @@ class EnvironmentViewModelTest : BaseViewModelTest() { companion object { private val DEFAULT_STATE = EnvironmentState( serverUrl = "", + keyAlias = "", webVaultServerUrl = "", apiServerUrl = "", identityServerUrl = "", iconsServerUrl = "", - shouldShowErrorDialog = false, + keyHost = null, + dialog = null, + showMutualTlsOptions = true, ) } } diff --git a/app/src/test/java/com/x8bit/bitwarden/ui/platform/manager/keychain/KeyChainManagerTest.kt b/app/src/test/java/com/x8bit/bitwarden/ui/platform/manager/keychain/KeyChainManagerTest.kt new file mode 100644 index 0000000000..8e560dd2c6 --- /dev/null +++ b/app/src/test/java/com/x8bit/bitwarden/ui/platform/manager/keychain/KeyChainManagerTest.kt @@ -0,0 +1,135 @@ +package com.x8bit.bitwarden.ui.platform.manager.keychain + +import android.app.Activity +import android.net.Uri +import android.security.KeyChain +import android.security.KeyChainAliasCallback +import androidx.core.net.toUri +import com.x8bit.bitwarden.ui.platform.manager.keychain.model.PrivateKeyAliasSelectionResult +import io.mockk.every +import io.mockk.mockk +import io.mockk.mockkStatic +import io.mockk.slot +import io.mockk.unmockkStatic +import kotlinx.coroutines.test.runTest +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertNull +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertInstanceOf + +class KeyChainManagerTest { + + private val mockActivity = mockk() + private val keyChainManager = KeyChainManagerImpl(activity = mockActivity) + + @BeforeEach + fun setUp() { + mockkStatic(KeyChain::class) + } + + @AfterEach + fun tearDown() { + unmockkStatic(KeyChain::class) + } + + @Test + fun `choosePrivateKeyAlias should return Success with alias when key is selected`() = runTest { + setupMockUri() + val systemCallbackCaptor = slot() + every { + KeyChain.choosePrivateKeyAlias( + /* activity = */ mockActivity, + /* response = */ capture(systemCallbackCaptor), + /* keyTypes = */ null, + /* issuers = */ null, + /* uri = */ null, + /* alias = */ null, + ) + } answers { + systemCallbackCaptor.captured.alias("mockAlias") + } + + val result = keyChainManager.choosePrivateKeyAlias(currentServerUrl = null) + + assertEquals( + PrivateKeyAliasSelectionResult.Success("mockAlias"), + result, + ) + } + + @Test + fun `choosePrivateKeyAlias should return Error when IllegalArgumentException is thrown`() = + runTest { + setupMockUri() + every { + KeyChain.choosePrivateKeyAlias( + /* activity = */ mockActivity, + /* response = */ any(), + /* keyTypes = */ null, + /* issuers = */ null, + /* uri = */ null, + /* alias = */ null, + ) + } throws IllegalArgumentException() + + val result = keyChainManager.choosePrivateKeyAlias(currentServerUrl = null) + + assertInstanceOf(result) + } + + @Test + fun `choosePrivateKeyAlias should pass currentServerUrl to system KeyChain`() = runTest { + setupMockUri() + val systemCallbackCaptor = slot() + every { + KeyChain.choosePrivateKeyAlias( + /* activity = */ mockActivity, + /* response = */ capture(systemCallbackCaptor), + /* keyTypes = */ null, + /* issuers = */ null, + /* uri = */ "www.mockuri.com".toUri(), + /* alias = */ null, + ) + } answers { + systemCallbackCaptor.captured.alias("mockAlias") + } + + val result = keyChainManager.choosePrivateKeyAlias(currentServerUrl = "www.mockuri.com") + + assertInstanceOf(result) + .also { assertEquals("mockAlias", it.alias) } + } + + @Test + fun `choosePrivateKeyAlias should return Success with null alias when no key is selected`() = + runTest { + setupMockUri() + val systemCallbackCaptor = slot() + every { + KeyChain.choosePrivateKeyAlias( + /* activity = */ mockActivity, + /* response = */ capture(systemCallbackCaptor), + /* keyTypes = */ null, + /* issuers = */ null, + /* uri = */ null, + /* alias = */ null, + ) + } answers { + systemCallbackCaptor.captured.alias(null) + } + + val result = keyChainManager.choosePrivateKeyAlias(currentServerUrl = null) + + assertInstanceOf(result) + .also { assertNull(it.alias) } + } + + private fun setupMockUri() { + mockkStatic(Uri::class) + val uriMock = mockk() + every { Uri.parse(any()) } returns uriMock + every { uriMock.host } returns "www.mockuri.com" + } +}