[PM-16157] Support self-host servers using TLS with Client Authentication (mTLS) (#4486)

This commit is contained in:
rohm1
2025-02-10 19:33:28 +01:00
committed by GitHub
parent fd26472f71
commit 571b8368e1
21 changed files with 1713 additions and 26 deletions

View File

@@ -7,6 +7,7 @@ import kotlinx.serialization.Serializable
* Represents URLs for various Bitwarden domains.
*
* @property base The overall base URL.
* @property keyUri A Uri containing the alias and host of the key used for mutual TLS.
* @property api Separate base URL for the "/api" domain (if applicable).
* @property identity Separate base URL for the "/identity" domain (if applicable).
* @property icon Separate base URL for the icon domain (if applicable).
@@ -19,6 +20,9 @@ data class EnvironmentUrlDataJson(
@SerialName("base")
val base: String,
@SerialName("keyUri")
val keyUri: String? = null,
@SerialName("api")
val api: String? = null,
@@ -51,6 +55,7 @@ data class EnvironmentUrlDataJson(
*/
val DEFAULT_LEGACY_US: EnvironmentUrlDataJson = EnvironmentUrlDataJson(
base = "https://vault.bitwarden.com",
keyUri = null,
api = "https://api.bitwarden.com",
identity = "https://identity.bitwarden.com",
icon = "https://icons.bitwarden.net",
@@ -71,6 +76,7 @@ data class EnvironmentUrlDataJson(
*/
val DEFAULT_LEGACY_EU: EnvironmentUrlDataJson = EnvironmentUrlDataJson(
base = "https://vault.bitwarden.eu",
keyUri = null,
api = "https://api.bitwarden.eu",
identity = "https://identity.bitwarden.eu",
icon = "https://icons.bitwarden.eu",

View File

@@ -14,6 +14,10 @@ import com.x8bit.bitwarden.data.platform.datasource.network.service.EventService
import com.x8bit.bitwarden.data.platform.datasource.network.service.EventServiceImpl
import com.x8bit.bitwarden.data.platform.datasource.network.service.PushService
import com.x8bit.bitwarden.data.platform.datasource.network.service.PushServiceImpl
import com.x8bit.bitwarden.data.platform.datasource.network.ssl.SslManager
import com.x8bit.bitwarden.data.platform.datasource.network.ssl.SslManagerImpl
import com.x8bit.bitwarden.data.platform.manager.KeyManager
import com.x8bit.bitwarden.data.platform.repository.EnvironmentRepository
import dagger.Module
import dagger.Provides
import dagger.hilt.InstallIn
@@ -70,6 +74,17 @@ object PlatformNetworkModule {
@Singleton
fun providesRefreshAuthenticator(): RefreshAuthenticator = RefreshAuthenticator()
@Provides
@Singleton
fun provideSslManager(
keyManager: KeyManager,
environmentRepository: EnvironmentRepository,
): SslManager =
SslManagerImpl(
keyManager = keyManager,
environmentRepository = environmentRepository,
)
@Provides
@Singleton
fun provideRetrofits(
@@ -77,6 +92,7 @@ object PlatformNetworkModule {
baseUrlInterceptors: BaseUrlInterceptors,
headersInterceptor: HeadersInterceptor,
refreshAuthenticator: RefreshAuthenticator,
sslManager: SslManager,
json: Json,
): Retrofits =
RetrofitsImpl(
@@ -84,6 +100,7 @@ object PlatformNetworkModule {
baseUrlInterceptors = baseUrlInterceptors,
headersInterceptor = headersInterceptor,
refreshAuthenticator = refreshAuthenticator,
sslManager = sslManager,
json = json,
)

View File

@@ -6,6 +6,7 @@ import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.AuthToke
import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.BaseUrlInterceptor
import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.BaseUrlInterceptors
import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.HeadersInterceptor
import com.x8bit.bitwarden.data.platform.datasource.network.ssl.SslManager
import com.x8bit.bitwarden.data.platform.datasource.network.util.HEADER_KEY_AUTHORIZATION
import kotlinx.serialization.json.Json
import okhttp3.MediaType.Companion.toMediaType
@@ -14,6 +15,9 @@ import okhttp3.logging.HttpLoggingInterceptor
import retrofit2.Retrofit
import retrofit2.converter.kotlinx.serialization.asConverterFactory
import timber.log.Timber
import javax.net.ssl.SSLContext
import javax.net.ssl.TrustManager
import javax.net.ssl.X509TrustManager
/**
* Primary implementation of [Retrofits].
@@ -24,6 +28,7 @@ class RetrofitsImpl(
headersInterceptor: HeadersInterceptor,
refreshAuthenticator: RefreshAuthenticator,
json: Json,
private val sslManager: SslManager,
) : Retrofits {
//region Authenticated Retrofits
@@ -67,6 +72,10 @@ class RetrofitsImpl(
baseClient
.newBuilder()
.addInterceptor(loggingInterceptor)
.setSslSocketFactory(
sslContext = sslManager.sslContext,
trustManagers = sslManager.trustManagers,
)
.build(),
)
.build()
@@ -93,6 +102,10 @@ class RetrofitsImpl(
.newBuilder()
.authenticator(refreshAuthenticator)
.addInterceptor(authTokenInterceptor)
.setSslSocketFactory(
sslContext = sslManager.sslContext,
trustManagers = sslManager.trustManagers,
)
.build()
}
@@ -133,9 +146,22 @@ class RetrofitsImpl(
.newBuilder()
.addInterceptor(baseUrlInterceptor)
.addInterceptor(loggingInterceptor)
.setSslSocketFactory(
sslContext = sslManager.sslContext,
trustManagers = sslManager.trustManagers,
)
.build(),
)
.build()
private fun OkHttpClient.Builder.setSslSocketFactory(
sslContext: SSLContext,
trustManagers: Array<TrustManager>,
): OkHttpClient.Builder =
sslSocketFactory(
sslContext.socketFactory,
trustManagers.first() as X509TrustManager,
)
//endregion Helper properties and functions
}

View File

@@ -0,0 +1,20 @@
package com.x8bit.bitwarden.data.platform.datasource.network.ssl
import javax.net.ssl.SSLContext
import javax.net.ssl.TrustManager
/**
* Interface for managing SSL connections.
*/
interface SslManager {
/**
* The SSL context to use for SSL connections.
*/
val sslContext: SSLContext
/**
* The trust managers to use for SSL connections.
*/
val trustManagers: Array<TrustManager>
}

View File

@@ -0,0 +1,116 @@
package com.x8bit.bitwarden.data.platform.datasource.network.ssl
import android.net.Uri
import androidx.annotation.VisibleForTesting
import androidx.annotation.WorkerThread
import androidx.core.net.toUri
import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsCertificate
import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsKeyHost
import com.x8bit.bitwarden.data.platform.manager.KeyManager
import com.x8bit.bitwarden.data.platform.repository.EnvironmentRepository
import java.net.Socket
import java.security.KeyStore
import java.security.Principal
import java.security.PrivateKey
import java.security.cert.X509Certificate
import javax.net.ssl.SSLContext
import javax.net.ssl.TrustManager
import javax.net.ssl.TrustManagerFactory
import javax.net.ssl.X509ExtendedKeyManager
/**
* Primary implementation of [SslManager].
*/
class SslManagerImpl(
private val keyManager: KeyManager,
private val environmentRepository: EnvironmentRepository,
) : SslManager {
/*
This property must only be accessed from a background thread. Accessing this property from
the main thread will result in an exception being thrown when retrieving the mutual TLS
certificate from [KeyManager].
*/
@VisibleForTesting(otherwise = VisibleForTesting.PRIVATE)
@get:WorkerThread
internal val mutualTlsCertificate: MutualTlsCertificate?
get() {
val keyUri = getKeyUri()
?: return null
val host = MutualTlsKeyHost
.entries
.find { it.name == keyUri.authority }
?: return null
val alias = keyUri.path
?.trim('/')
?.takeUnless { it.isEmpty() }
?: return null
return keyManager.getMutualTlsCertificateChain(
alias = alias,
host = host,
)
}
override val trustManagers: Array<TrustManager>
get() = TrustManagerFactory
.getInstance(TrustManagerFactory.getDefaultAlgorithm())
.apply { init(null as KeyStore?) }
.trustManagers
override val sslContext: SSLContext
get() = SSLContext
.getInstance("TLS")
.apply {
init(
arrayOf(X509ExtendedKeyManagerImpl()),
trustManagers,
null,
)
}
private fun getKeyUri(): Uri? = environmentRepository
.environment
.environmentUrlData
.keyUri
?.toUri()
private inner class X509ExtendedKeyManagerImpl : X509ExtendedKeyManager() {
override fun chooseClientAlias(
keyType: Array<out String>?,
issuers: Array<out Principal>?,
socket: Socket?,
): String = mutualTlsCertificate?.alias ?: ""
override fun getCertificateChain(
alias: String?,
): Array<X509Certificate>? =
mutualTlsCertificate
?.certificateChain
?.toTypedArray()
override fun getPrivateKey(alias: String?): PrivateKey? =
mutualTlsCertificate
?.privateKey
//region Unused server side methods
override fun getServerAliases(
alias: String?,
issuers: Array<out Principal>?,
): Array<String> = arrayOf()
override fun getClientAliases(
keyType: String?,
issuers: Array<out Principal>?,
): Array<String> = emptyArray()
override fun chooseServerAlias(
alias: String?,
issuers: Array<out Principal>?,
socket: Socket?,
): String = ""
//endregion Unused server side methods
}
}

View File

@@ -1,6 +1,7 @@
package com.x8bit.bitwarden.data.platform.manager
import com.x8bit.bitwarden.data.platform.datasource.disk.model.ImportPrivateKeyResult
import androidx.annotation.WorkerThread
import com.x8bit.bitwarden.data.platform.manager.model.ImportPrivateKeyResult
import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsCertificate
import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsKeyHost
@@ -29,7 +30,10 @@ interface KeyManager {
/**
* Retrieve the certificate chain for the selected mTLS key.
*
* Must be called from a background thread to prevent possible deadlocks on the main thread.
*/
@WorkerThread
fun getMutualTlsCertificateChain(
alias: String,
host: MutualTlsKeyHost,

View File

@@ -3,9 +3,9 @@ package com.x8bit.bitwarden.data.platform.manager
import android.content.Context
import android.security.KeyChain
import android.security.KeyChainException
import com.x8bit.bitwarden.data.platform.datasource.disk.model.ImportPrivateKeyResult
import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsCertificate
import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsKeyHost
import com.x8bit.bitwarden.data.platform.manager.model.ImportPrivateKeyResult
import timber.log.Timber
import java.io.IOException
import java.security.KeyStore

View File

@@ -1,4 +1,4 @@
package com.x8bit.bitwarden.data.platform.datasource.disk.model
package com.x8bit.bitwarden.data.platform.manager.model
/**
* Models the result of importing a private key.

View File

@@ -18,6 +18,7 @@ import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue
import androidx.compose.runtime.remember
import androidx.compose.ui.Modifier
import androidx.compose.ui.focus.focusProperties
import androidx.compose.ui.input.nestedscroll.nestedScroll
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.platform.testTag
@@ -31,13 +32,21 @@ import com.x8bit.bitwarden.R
import com.x8bit.bitwarden.ui.platform.base.util.EventsEffect
import com.x8bit.bitwarden.ui.platform.base.util.standardHorizontalMargin
import com.x8bit.bitwarden.ui.platform.components.appbar.BitwardenTopAppBar
import com.x8bit.bitwarden.ui.platform.components.button.BitwardenFilledButton
import com.x8bit.bitwarden.ui.platform.components.button.BitwardenOutlinedButton
import com.x8bit.bitwarden.ui.platform.components.button.BitwardenTextButton
import com.x8bit.bitwarden.ui.platform.components.dialog.BitwardenBasicDialog
import com.x8bit.bitwarden.ui.platform.components.dialog.BitwardenClientCertificateDialog
import com.x8bit.bitwarden.ui.platform.components.dialog.BitwardenTwoButtonDialog
import com.x8bit.bitwarden.ui.platform.components.field.BitwardenTextField
import com.x8bit.bitwarden.ui.platform.components.header.BitwardenListHeaderText
import com.x8bit.bitwarden.ui.platform.components.model.CardStyle
import com.x8bit.bitwarden.ui.platform.components.scaffold.BitwardenScaffold
import com.x8bit.bitwarden.ui.platform.components.util.rememberVectorPainter
import com.x8bit.bitwarden.ui.platform.composition.LocalIntentManager
import com.x8bit.bitwarden.ui.platform.composition.LocalKeyChainManager
import com.x8bit.bitwarden.ui.platform.manager.intent.IntentManager
import com.x8bit.bitwarden.ui.platform.manager.keychain.KeyChainManager
import kotlinx.collections.immutable.persistentListOf
/**
@@ -48,27 +57,100 @@ import kotlinx.collections.immutable.persistentListOf
@Composable
fun EnvironmentScreen(
onNavigateBack: () -> Unit,
intentManager: IntentManager = LocalIntentManager.current,
keyChainManager: KeyChainManager = LocalKeyChainManager.current,
viewModel: EnvironmentViewModel = hiltViewModel(),
) {
val state by viewModel.stateFlow.collectAsStateWithLifecycle()
val context = LocalContext.current
val certificateImportFilePickerLauncher = intentManager.getActivityResultLauncher { result ->
intentManager.getFileDataFromActivityResult(result)?.let {
viewModel.trySendAction(
EnvironmentAction.ImportCertificateFilePickerResultReceive(it),
)
}
}
EventsEffect(viewModel = viewModel) { event ->
when (event) {
is EnvironmentEvent.NavigateBack -> onNavigateBack.invoke()
is EnvironmentEvent.ShowToast -> {
Toast.makeText(context, event.message(context.resources), Toast.LENGTH_SHORT).show()
}
is EnvironmentEvent.ShowCertificateImportFileChooser -> {
certificateImportFilePickerLauncher.launch(
intentManager.createFileChooserIntent(withCameraIntents = false),
)
}
is EnvironmentEvent.ShowSystemCertificateSelectionDialog -> {
viewModel.trySendAction(
EnvironmentAction.SystemCertificateSelectionResultReceive(
keyChainManager.choosePrivateKeyAlias(
currentServerUrl = event.serverUrl?.takeUnless { it.isEmpty() },
),
),
)
}
}
}
if (state.shouldShowErrorDialog) {
BitwardenBasicDialog(
title = stringResource(id = R.string.an_error_has_occurred),
message = stringResource(id = R.string.environment_page_urls_error),
onDismissRequest = remember(viewModel) {
{ viewModel.trySendAction(EnvironmentAction.ErrorDialogDismiss) }
},
)
when (val dialog = state.dialog) {
is EnvironmentState.DialogState.Error -> {
BitwardenBasicDialog(
title = stringResource(id = R.string.an_error_has_occurred),
message = dialog.message(),
onDismissRequest = remember(viewModel) {
{ viewModel.trySendAction(EnvironmentAction.ErrorDialogDismiss) }
},
)
}
is EnvironmentState.DialogState.SetCertificateData -> {
BitwardenClientCertificateDialog(
onConfirmClick = remember(viewModel) {
{ alias, password ->
viewModel.trySendAction(
EnvironmentAction.SetCertificateInfoResultReceive(
certificateFileData = dialog.certificateBytes,
password = password,
alias = alias,
),
)
}
},
onDismissRequest = remember(viewModel) {
{
viewModel.trySendAction(
action = EnvironmentAction.SetCertificatePasswordDialogDismiss,
)
}
},
)
}
is EnvironmentState.DialogState.SystemCertificateWarningDialog -> {
@Suppress("MaxLineLength")
BitwardenTwoButtonDialog(
title = stringResource(R.string.warning),
message = stringResource(
R.string.system_certificates_are_not_as_secure_as_importing_certificates_to_bitwarden,
),
confirmButtonText = stringResource(R.string.continue_text),
onConfirmClick = remember(viewModel) {
{ viewModel.trySendAction(EnvironmentAction.ConfirmChooseSystemCertificateClick) }
},
dismissButtonText = stringResource(R.string.cancel),
onDismissClick = remember(viewModel) {
{ viewModel.trySendAction(EnvironmentAction.ErrorDialogDismiss) }
},
onDismissRequest = remember(viewModel) {
{ viewModel.trySendAction(EnvironmentAction.ErrorDialogDismiss) }
},
)
}
null -> Unit
}
val scrollBehavior = TopAppBarDefaults.pinnedScrollBehavior(rememberTopAppBarState())
BitwardenScaffold(
@@ -138,7 +220,7 @@ fun EnvironmentScreen(
.standardHorizontalMargin(),
)
Spacer(modifier = Modifier.height(height = 16.dp))
Spacer(modifier = Modifier.height(16.dp))
BitwardenListHeaderText(
label = stringResource(id = R.string.custom_environment),
@@ -213,6 +295,59 @@ fun EnvironmentScreen(
.standardHorizontalMargin(),
)
if (state.showMutualTlsOptions) {
Spacer(modifier = Modifier.height(height = 16.dp))
BitwardenListHeaderText(
label = stringResource(id = R.string.client_certificate_mtls),
modifier = Modifier
.fillMaxWidth()
.standardHorizontalMargin()
.padding(horizontal = 16.dp),
)
Spacer(modifier = Modifier.height(height = 8.dp))
BitwardenTextField(
label = stringResource(id = R.string.certificate_alias),
value = state.keyAlias,
supportingText = stringResource(
id = R.string.certificate_used_for_client_authentication,
),
onValueChange = {},
readOnly = true,
cardStyle = CardStyle.Full,
textFieldTestTag = "KeyAliasEntry",
modifier = Modifier
.fillMaxWidth()
.focusProperties { canFocus = false }
.standardHorizontalMargin(),
)
Spacer(modifier = Modifier.height(height = 16.dp))
BitwardenFilledButton(
label = stringResource(id = R.string.import_certificate),
onClick = remember(viewModel) {
{ viewModel.trySendAction(EnvironmentAction.ImportCertificateClick) }
},
modifier = Modifier
.fillMaxWidth()
.standardHorizontalMargin()
.testTag("ImportCertificateButton"),
)
Spacer(modifier = Modifier.height(height = 12.dp))
BitwardenOutlinedButton(
label = stringResource(id = R.string.choose_system_certificate),
onClick = remember(viewModel) {
{ viewModel.trySendAction(EnvironmentAction.ChooseSystemCertificateClick) }
},
modifier = Modifier
.fillMaxWidth()
.standardHorizontalMargin()
.testTag("ChooseSystemCertificateButton"),
)
}
Spacer(modifier = Modifier.height(height = 16.dp))
Spacer(modifier = Modifier.navigationBarsPadding())
}

View File

@@ -1,21 +1,32 @@
package com.x8bit.bitwarden.ui.auth.feature.environment
import android.os.Parcelable
import androidx.core.net.toUri
import androidx.lifecycle.SavedStateHandle
import androidx.lifecycle.viewModelScope
import com.x8bit.bitwarden.R
import com.x8bit.bitwarden.data.auth.datasource.disk.model.EnvironmentUrlDataJson
import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsKeyHost
import com.x8bit.bitwarden.data.platform.manager.FeatureFlagManager
import com.x8bit.bitwarden.data.platform.manager.KeyManager
import com.x8bit.bitwarden.data.platform.manager.model.FlagKey
import com.x8bit.bitwarden.data.platform.manager.model.ImportPrivateKeyResult
import com.x8bit.bitwarden.data.platform.repository.EnvironmentRepository
import com.x8bit.bitwarden.data.platform.repository.model.Environment
import com.x8bit.bitwarden.data.vault.manager.FileManager
import com.x8bit.bitwarden.ui.platform.base.BaseViewModel
import com.x8bit.bitwarden.ui.platform.base.util.Text
import com.x8bit.bitwarden.ui.platform.base.util.asText
import com.x8bit.bitwarden.ui.platform.base.util.isValidUri
import com.x8bit.bitwarden.ui.platform.base.util.orNullIfBlank
import com.x8bit.bitwarden.ui.platform.manager.intent.IntentManager
import com.x8bit.bitwarden.ui.platform.manager.keychain.model.PrivateKeyAliasSelectionResult
import dagger.hilt.android.lifecycle.HiltViewModel
import kotlinx.coroutines.flow.launchIn
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch
import kotlinx.parcelize.Parcelize
import javax.inject.Inject
@@ -24,9 +35,13 @@ private const val KEY_STATE = "state"
/**
* View model for the self-hosted/custom environment screen.
*/
@Suppress("TooManyFunctions")
@HiltViewModel
class EnvironmentViewModel @Inject constructor(
private val environmentRepository: EnvironmentRepository,
private val fileManager: FileManager,
private val keyManager: KeyManager,
private val featureFlagManager: FeatureFlagManager,
private val savedStateHandle: SavedStateHandle,
) : BaseViewModel<EnvironmentState, EnvironmentEvent, EnvironmentAction>(
initialState = savedStateHandle[KEY_STATE] ?: run {
@@ -37,13 +52,19 @@ class EnvironmentViewModel @Inject constructor(
is Environment.SelfHosted -> environment.environmentUrlData
}
val keyUri = environmentUrlData.keyUri?.toUri()
val keyAlias = keyUri?.path?.trim('/').orEmpty()
val keyHost = MutualTlsKeyHost.entries.find { it.name == keyUri?.authority }
EnvironmentState(
serverUrl = environmentUrlData.base,
webVaultServerUrl = environmentUrlData.webVault.orEmpty(),
apiServerUrl = environmentUrlData.api.orEmpty(),
identityServerUrl = environmentUrlData.identity.orEmpty(),
iconsServerUrl = environmentUrlData.icon.orEmpty(),
shouldShowErrorDialog = false,
keyAlias = keyAlias,
keyHost = keyHost,
dialog = null,
showMutualTlsOptions = featureFlagManager.getFeatureFlag(FlagKey.MutualTls),
)
},
) {
@@ -54,6 +75,11 @@ class EnvironmentViewModel @Inject constructor(
savedStateHandle[KEY_STATE] = it
}
.launchIn(viewModelScope)
featureFlagManager.getFeatureFlagFlow(FlagKey.MutualTls)
.map { EnvironmentAction.Internal.MutualTlsFeatureFlagUpdate(it) }
.onEach(::handleAction)
.launchIn(viewModelScope)
}
override fun handleAction(action: EnvironmentAction): Unit = when (action) {
@@ -65,6 +91,36 @@ class EnvironmentViewModel @Inject constructor(
is EnvironmentAction.ApiServerUrlChange -> handleApiServerUrlChangeAction(action)
is EnvironmentAction.IdentityServerUrlChange -> handleIdentityServerUrlChangeAction(action)
is EnvironmentAction.IconsServerUrlChange -> handleIconsServerUrlChangeAction(action)
is EnvironmentAction.ImportCertificateClick -> handleImportCertificateClick()
is EnvironmentAction.ImportCertificateFilePickerResultReceive -> {
handleCertificateFilePickerResultReceive(action)
}
is EnvironmentAction.SetCertificatePasswordDialogDismiss -> {
handleSetCertificatePasswordDialogDismiss()
}
is EnvironmentAction.CertificateInstallationResultReceive -> {
handleCertificateInstallationResultReceive(action)
}
is EnvironmentAction.SetCertificateInfoResultReceive -> {
handleSetCertificateInfoResultReceive(action)
}
is EnvironmentAction.ChooseSystemCertificateClick -> {
handleChooseSystemCertificateClickAction()
}
is EnvironmentAction.ConfirmChooseSystemCertificateClick -> {
handleConfirmChooseSystemCertificateClick()
}
is EnvironmentAction.SystemCertificateSelectionResultReceive -> {
handleSystemCertificateSelectionResultReceive(action)
}
is EnvironmentAction.Internal -> handleInternalAction(action)
}
private fun handleCloseClickAction() {
@@ -85,7 +141,7 @@ class EnvironmentViewModel @Inject constructor(
}
if (!urlsAreAllNullOrValid) {
mutableStateFlow.update { it.copy(shouldShowErrorDialog = true) }
showErrorDialog(message = R.string.environment_page_urls_error.asText())
return
}
@@ -95,7 +151,6 @@ class EnvironmentViewModel @Inject constructor(
val updatedApiServerUrl = state.apiServerUrl.prefixHttpsIfNecessaryOrNull()
val updatedIdentityServerUrl = state.identityServerUrl.prefixHttpsIfNecessaryOrNull()
val updatedIconsServerUrl = state.iconsServerUrl.prefixHttpsIfNecessaryOrNull()
environmentRepository.environment = Environment.SelfHosted(
environmentUrlData = EnvironmentUrlDataJson(
base = updatedServerUrl,
@@ -103,15 +158,16 @@ class EnvironmentViewModel @Inject constructor(
identity = updatedIdentityServerUrl,
icon = updatedIconsServerUrl,
webVault = updatedWebVaultServerUrl,
keyUri = state.keyUri,
),
)
sendEvent(EnvironmentEvent.ShowToast(message = R.string.environment_saved.asText()))
showToast(message = R.string.environment_saved.asText())
sendEvent(EnvironmentEvent.NavigateBack)
}
private fun handleErrorDialogDismiss() {
mutableStateFlow.update { it.copy(shouldShowErrorDialog = false) }
mutableStateFlow.update { it.copy(dialog = null) }
}
private fun handleServerUrlChangeAction(
@@ -122,6 +178,45 @@ class EnvironmentViewModel @Inject constructor(
}
}
private fun handleCertificateInstallationResultReceive(
action: EnvironmentAction.CertificateInstallationResultReceive,
) {
showToast(
message = if (action.success) {
R.string.certificate_installed.asText()
} else {
R.string.certificate_installation_failed.asText()
},
)
}
private fun handleSetCertificatePasswordDialogDismiss() {
mutableStateFlow.update { it.copy(dialog = null) }
}
private fun handleCertificateFilePickerResultReceive(
action: EnvironmentAction.ImportCertificateFilePickerResultReceive,
) {
mutableStateFlow.update {
it.copy(
dialog = EnvironmentState.DialogState.SetCertificateData(
certificateBytes = action.certificateFileData,
),
)
}
}
private fun handleConfirmChooseSystemCertificateClick() {
mutableStateFlow.update {
it.copy(dialog = null)
}
sendEvent(
EnvironmentEvent.ShowSystemCertificateSelectionDialog(
serverUrl = state.serverUrl.prefixHttpsIfNecessaryOrNull(),
),
)
}
private fun handleWebVaultServerUrlChangeAction(
action: EnvironmentAction.WebVaultServerUrlChange,
) {
@@ -153,6 +248,150 @@ class EnvironmentViewModel @Inject constructor(
it.copy(iconsServerUrl = action.iconsServerUrl)
}
}
private fun handleImportCertificateClick() {
sendEvent(EnvironmentEvent.ShowCertificateImportFileChooser)
}
private fun handleInternalAction(action: EnvironmentAction.Internal) {
when (action) {
is EnvironmentAction.Internal.ImportKeyResultReceive -> {
handleSaveKeyResultReceive(action)
}
is EnvironmentAction.Internal.MutualTlsFeatureFlagUpdate -> {
handleMutualTlsFeatureFlagUpdate(action)
}
}
}
private fun handleSetCertificateInfoResultReceive(
action: EnvironmentAction.SetCertificateInfoResultReceive,
) {
if (action.password.isBlank()) {
showErrorDialog(
message = R.string.validation_field_required.asText(
R.string.password.asText(),
),
)
return
}
if (action.alias.isBlank()) {
showErrorDialog(
message = R.string.validation_field_required.asText(
R.string.alias.asText(),
),
)
return
}
mutableStateFlow.update { it.copy(dialog = null) }
viewModelScope.launch {
fileManager
.uriToByteArray(action.certificateFileData.uri)
.map { bytes ->
keyManager.importMutualTlsCertificate(
key = bytes,
alias = action.alias,
password = action.password,
)
}
.map { result ->
sendAction(
EnvironmentAction.Internal.ImportKeyResultReceive(result),
)
}
}
}
private fun handleSaveKeyResultReceive(
action: EnvironmentAction.Internal.ImportKeyResultReceive,
) {
when (val result = action.result) {
is ImportPrivateKeyResult.Success -> {
mutableStateFlow.update { state ->
state.copy(
keyAlias = result.alias,
keyHost = MutualTlsKeyHost.ANDROID_KEY_STORE,
)
}
}
is ImportPrivateKeyResult.Error.UnsupportedKey -> {
showToast(message = R.string.unsupported_certificate_type.asText())
}
is ImportPrivateKeyResult.Error.KeyStoreOperationFailed -> {
showToast(message = R.string.certificate_installation_failed.asText())
}
is ImportPrivateKeyResult.Error.UnrecoverableKey -> {
showToast(message = R.string.certificate_password_incorrect.asText())
}
is ImportPrivateKeyResult.Error.InvalidCertificateChain -> {
showToast(R.string.invalid_certificate_chain.asText())
}
ImportPrivateKeyResult.Error.DuplicateAlias -> {
// TODO [PM-17686] Improve duplicate alias handling.
showToast(R.string.certificate_alias_already_exists.asText())
}
}
}
private fun handleMutualTlsFeatureFlagUpdate(
action: EnvironmentAction.Internal.MutualTlsFeatureFlagUpdate,
) {
mutableStateFlow.update {
it.copy(
showMutualTlsOptions = action.enabled,
)
}
}
private fun handleChooseSystemCertificateClickAction() {
mutableStateFlow.update {
it.copy(
dialog = EnvironmentState.DialogState.SystemCertificateWarningDialog,
)
}
}
private fun handleSystemCertificateSelectionResultReceive(
action: EnvironmentAction.SystemCertificateSelectionResultReceive,
) {
when (val result = action.privateKeyAliasSelectionResult) {
is PrivateKeyAliasSelectionResult.Success -> {
mutableStateFlow.update {
it.copy(
keyAlias = result.alias.orEmpty(),
keyHost = result.alias?.let { MutualTlsKeyHost.KEY_CHAIN },
)
}
}
is PrivateKeyAliasSelectionResult.Error -> {
sendEvent(
EnvironmentEvent.ShowToast(
message = R.string.error_loading_certificate.asText(),
),
)
}
}
}
private fun showToast(message: Text) {
sendEvent(EnvironmentEvent.ShowToast(message))
}
private fun showErrorDialog(message: Text) {
mutableStateFlow.update {
it.copy(dialog = EnvironmentState.DialogState.Error(message = message))
}
}
}
/**
@@ -165,8 +404,43 @@ data class EnvironmentState(
val apiServerUrl: String,
val identityServerUrl: String,
val iconsServerUrl: String,
val shouldShowErrorDialog: Boolean,
) : Parcelable
val keyAlias: String,
val dialog: DialogState?,
val showMutualTlsOptions: Boolean,
// internal
private val keyHost: MutualTlsKeyHost?,
) : Parcelable {
val keyUri: String?
get() = "cert://$keyHost/$keyAlias"
.takeUnless { keyHost == null || keyAlias.isEmpty() }
/**
* Models the dialog states of the environment screen.
*/
@Parcelize
sealed class DialogState : Parcelable {
/**
* Show an error dialog.
*/
data class Error(
val message: Text,
) : DialogState()
/**
* Show a dialog to capture the certificate alias and password.
*/
data class SetCertificateData(
val certificateBytes: IntentManager.FileData,
) : DialogState()
/**
* Show a dialog warning the user that system certificates are not as secure.
*/
data object SystemCertificateWarningDialog : DialogState()
}
}
/**
* Models events for the environment screen.
@@ -177,6 +451,18 @@ sealed class EnvironmentEvent {
*/
data object NavigateBack : EnvironmentEvent()
/**
* Show the File chooser dialog for certificate import.
*/
data object ShowCertificateImportFileChooser : EnvironmentEvent()
/**
* Show the system certificate selection dialog.
*/
data class ShowSystemCertificateSelectionDialog(
val serverUrl: String?,
) : EnvironmentEvent()
/**
* Show a toast with the given message.
*/
@@ -204,6 +490,26 @@ sealed class EnvironmentAction {
*/
data object ErrorDialogDismiss : EnvironmentAction()
/**
* User clicked the import certificate button.
*/
data object ImportCertificateClick : EnvironmentAction()
/**
* User dismissed the set certificate password dialog without providing a password.
*/
data object SetCertificatePasswordDialogDismiss : EnvironmentAction()
/**
* User clicked the choose system certificate button.
*/
data object ChooseSystemCertificateClick : EnvironmentAction()
/**
* User confirmed choosing the system certificate.
*/
data object ConfirmChooseSystemCertificateClick : EnvironmentAction()
/**
* Indicates that the overall server URL has changed.
*/
@@ -211,6 +517,13 @@ sealed class EnvironmentAction {
val serverUrl: String,
) : EnvironmentAction()
/**
* Indicates that the certificate installation result was received.
*/
data class CertificateInstallationResultReceive(
val success: Boolean,
) : EnvironmentAction()
/**
* Indicates that the web vault server URL has changed.
*/
@@ -238,6 +551,48 @@ sealed class EnvironmentAction {
data class IconsServerUrlChange(
val iconsServerUrl: String,
) : EnvironmentAction()
/**
* Indicates that the certificate file selection result was received.
*/
data class ImportCertificateFilePickerResultReceive(
val certificateFileData: IntentManager.FileData,
) : EnvironmentAction()
/**
* Indicates the certificate info data was received.
*/
data class SetCertificateInfoResultReceive(
val certificateFileData: IntentManager.FileData,
val password: String,
val alias: String,
) : EnvironmentAction()
/**
* User has selected a system certificate alias.
*/
data class SystemCertificateSelectionResultReceive(
val privateKeyAliasSelectionResult: PrivateKeyAliasSelectionResult,
) : EnvironmentAction()
/**
* Models actions the EnvironmentViewModel itself may trigger.
*/
sealed class Internal : EnvironmentAction() {
/**
* Indicates the result of importing a key was received.
*/
data class ImportKeyResultReceive(
val result: ImportPrivateKeyResult,
) : Internal()
/**
* Indicates the mutual TLS feature flag was updated.
*/
data class MutualTlsFeatureFlagUpdate(
val enabled: Boolean,
) : Internal()
}
}
/**

View File

@@ -22,6 +22,8 @@ import com.x8bit.bitwarden.ui.platform.manager.exit.ExitManager
import com.x8bit.bitwarden.ui.platform.manager.exit.ExitManagerImpl
import com.x8bit.bitwarden.ui.platform.manager.intent.IntentManager
import com.x8bit.bitwarden.ui.platform.manager.intent.IntentManagerImpl
import com.x8bit.bitwarden.ui.platform.manager.keychain.KeyChainManager
import com.x8bit.bitwarden.ui.platform.manager.keychain.KeyChainManagerImpl
import com.x8bit.bitwarden.ui.platform.manager.nfc.NfcManager
import com.x8bit.bitwarden.ui.platform.manager.nfc.NfcManagerImpl
import com.x8bit.bitwarden.ui.platform.manager.permissions.PermissionsManager
@@ -52,6 +54,7 @@ fun LocalManagerProvider(
LocalFido2CompletionManager provides fido2CompletionManager,
LocalAppReviewManager provides AppReviewManagerImpl(activity),
LocalAppResumeStateManager provides AppResumeStateManagerImpl(),
LocalKeyChainManager provides KeyChainManagerImpl(activity),
) {
content()
}
@@ -110,3 +113,7 @@ val LocalAppReviewManager: ProvidableCompositionLocal<AppReviewManager> = compos
val LocalAppResumeStateManager = compositionLocalOf<AppResumeStateManager> {
error("CompositionLocal AppResumeStateManager not present")
}
val LocalKeyChainManager: ProvidableCompositionLocal<KeyChainManager> = compositionLocalOf {
error("CompositionLocal KeyChainManager not present")
}

View File

@@ -0,0 +1,18 @@
package com.x8bit.bitwarden.ui.platform.manager.keychain
import com.x8bit.bitwarden.ui.platform.manager.keychain.model.PrivateKeyAliasSelectionResult
/**
* Responsible for managing keys stored in the system KeyChain.
*/
interface KeyChainManager {
/**
* Display the system private key alias selection dialog.
*
* @param currentServerUrl The currently selected server URL.
*/
suspend fun choosePrivateKeyAlias(
currentServerUrl: String?,
): PrivateKeyAliasSelectionResult
}

View File

@@ -0,0 +1,42 @@
package com.x8bit.bitwarden.ui.platform.manager.keychain
import android.app.Activity
import android.security.KeyChain
import androidx.core.net.toUri
import com.x8bit.bitwarden.ui.platform.manager.keychain.model.PrivateKeyAliasSelectionResult
import kotlinx.coroutines.channels.awaitClose
import kotlinx.coroutines.flow.callbackFlow
import kotlinx.coroutines.flow.first
/**
* Default implementation of [KeyChainManager].
*/
class KeyChainManagerImpl(
private val activity: Activity,
) : KeyChainManager {
override suspend fun choosePrivateKeyAlias(
currentServerUrl: String?,
): PrivateKeyAliasSelectionResult =
callbackFlow<PrivateKeyAliasSelectionResult> {
try {
KeyChain.choosePrivateKeyAlias(
activity,
{ alias ->
trySend(PrivateKeyAliasSelectionResult.Success(alias))
close()
},
null,
null,
currentServerUrl?.toUri(),
null,
)
} catch (_: IllegalArgumentException) {
trySend(PrivateKeyAliasSelectionResult.Error)
close()
}
awaitClose()
}
.first()
}

View File

@@ -0,0 +1,17 @@
package com.x8bit.bitwarden.ui.platform.manager.keychain.model
/**
* Represents the result of an operation to select a private key alias from the system KeyChain.
*/
sealed class PrivateKeyAliasSelectionResult {
/**
* Indicates that the operation was successful and an alias was selected (or null if none was
* selected).
*/
data class Success(val alias: String?) : PrivateKeyAliasSelectionResult()
/**
* Indicates that an error occurred during the operation.
*/
object Error : PrivateKeyAliasSelectionResult()
}

View File

@@ -281,6 +281,9 @@ Scanning will happen automatically.</string>
<string name="encryption_key_migration_required_description_long">Encryption key migration required. Please login through the web vault to update your encryption key.</string>
<string name="learn_more">Learn more</string>
<string name="api_url">API server URL</string>
<string name="client_certificate_mtls">Client certificate (mTLS)</string>
<string name="certificate_used_for_client_authentication">Certificate used for client authentication.</string>
<string name="use_system_certificate">Use system certificate</string>
<string name="custom_environment">Custom environment</string>
<string name="custom_environment_footer">For advanced users. You can specify the base URL of each service independently.</string>
<string name="environment_saved">The environment URLs have been saved.</string>
@@ -1165,5 +1168,17 @@ Do you want to switch to this account?</string>
<string name="coachmark_6_of_6">6 OF 6</string>
<string name="use_these_options_to_adjust_your_password_to_your_account_requirements">Use these options to adjust your password to meet your account website\'s requirements.</string>
<string name="after_you_save_your_new_password_to_bitwarden_don_t_forget_to_update_it_on_your_account_website">"After you save your new password to Bitwarden, dont forget to update it on your account website. "</string>
<string name="link">Link</string>
<string name="link">Link</string>
<string name="error_loading_certificate">Error loading certificate</string>
<string name="certificate_alias">Certificate alias</string>
<string name="import_certificate">Import certificate</string>
<string name="choose_system_certificate">Choose system certificate</string>
<string name="certificate_installed">Certificate installed</string>
<string name="certificate_installation_failed">Certificate installation failed</string>
<string name="certificate_selection_failed">Certificate selection failed</string>
<string name="unsupported_certificate_type">Unsupported certificate type</string>
<string name="certificate_password_incorrect">Certificate password incorrect</string>
<string name="invalid_certificate_chain">Invalid certificate chain</string>
<string name="system_certificates_are_not_as_secure_as_importing_certificates_to_bitwarden">Using a system certificate is less secure than storing the certificate with Bitwarden. Continuing will display a list of available system certificates if one is already installed.</string>
<string name="certificate_alias_already_exists">Certificate alias already exists</string>
</resources>

View File

@@ -5,14 +5,21 @@ import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.AuthToke
import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.BaseUrlInterceptors
import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.HeadersInterceptor
import com.x8bit.bitwarden.data.platform.datasource.network.model.NetworkResult
import com.x8bit.bitwarden.data.platform.datasource.network.ssl.SslManager
import com.x8bit.bitwarden.data.util.mockBuilder
import io.mockk.every
import io.mockk.mockk
import io.mockk.mockkConstructor
import io.mockk.slot
import io.mockk.unmockkConstructor
import io.mockk.verify
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.runTest
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonObject
import okhttp3.Authenticator
import okhttp3.Interceptor
import okhttp3.OkHttpClient
import okhttp3.mockwebserver.MockResponse
import okhttp3.mockwebserver.MockWebServer
import org.junit.After
@@ -23,6 +30,10 @@ import org.junit.jupiter.api.Assertions.assertTrue
import retrofit2.Retrofit
import retrofit2.create
import retrofit2.http.GET
import javax.net.ssl.SSLContext
import javax.net.ssl.SSLSocketFactory
import javax.net.ssl.TrustManager
import javax.net.ssl.X509TrustManager
class RetrofitsTest {
private val authTokenInterceptor = mockk<AuthTokenInterceptor> {
@@ -47,12 +58,17 @@ class RetrofitsTest {
}
private val json = Json
private val server = MockWebServer()
private val mockSslManager = mockk<SslManager> {
every { sslContext } returns mockk(relaxed = true)
every { trustManagers } returns arrayOf(mockk<X509TrustManager>(relaxed = true))
}
private val retrofits = RetrofitsImpl(
authTokenInterceptor = authTokenInterceptor,
baseUrlInterceptors = baseUrlInterceptors,
headersInterceptor = headersInterceptors,
refreshAuthenticator = refreshAuthenticator,
sslManager = mockSslManager,
json = json,
)
@@ -71,6 +87,7 @@ class RetrofitsTest {
@After
fun tearDown() {
server.shutdown()
unmockkConstructor(OkHttpClient.Builder::class)
}
@Test
@@ -256,6 +273,94 @@ class RetrofitsTest {
assertFalse(isEventsInterceptorCalled)
}
@Suppress("MaxLineLength")
@Test
fun `createStaticRetrofit should set sslSocketFactory`() =
runTest {
val mockTrustManager = mockk<X509TrustManager>(relaxed = true)
val mockSocketFactory = mockk<SSLSocketFactory>()
val mockSslContext = mockk<SSLContext> {
every { socketFactory } returns mockSocketFactory
}
setupMockOkHttpClientBuilder(
sslContext = mockSslContext,
trustManagers = arrayOf(mockTrustManager),
)
retrofits.createStaticRetrofit(isAuthenticated = false)
verify {
anyConstructed<OkHttpClient.Builder>()
.sslSocketFactory(
sslSocketFactory = mockSocketFactory,
trustManager = mockTrustManager,
)
}
}
@Suppress("MaxLineLength")
@Test
fun `authenticatedOkHttpClient should set sslSocketFactory`() =
runTest {
val mockTrustManager = mockk<X509TrustManager>(relaxed = true)
val mockSocketFactory = mockk<SSLSocketFactory>()
val mockSslContext = mockk<SSLContext> {
every { socketFactory } returns mockSocketFactory
}
setupMockOkHttpClientBuilder(
sslContext = mockSslContext,
trustManagers = arrayOf(mockTrustManager),
)
retrofits.authenticatedApiRetrofit
verify {
anyConstructed<OkHttpClient.Builder>()
.sslSocketFactory(
sslSocketFactory = mockSocketFactory,
trustManager = mockTrustManager,
)
}
}
@Suppress("MaxLineLength")
@Test
fun `unauthenticatedOkHttpClient should set sslSocketFactory`() =
runTest {
val mockTrustManager = mockk<X509TrustManager>(relaxed = true)
val mockSocketFactory = mockk<SSLSocketFactory>()
val mockSslContext = mockk<SSLContext> {
every { socketFactory } returns mockSocketFactory
}
setupMockOkHttpClientBuilder(
sslContext = mockSslContext,
trustManagers = arrayOf(mockTrustManager),
)
retrofits.unauthenticatedApiRetrofit
verify {
anyConstructed<OkHttpClient.Builder>()
.sslSocketFactory(
sslSocketFactory = mockSocketFactory,
trustManager = mockTrustManager,
)
}
}
private fun setupMockOkHttpClientBuilder(
sslContext: SSLContext = mockk<SSLContext>(),
trustManagers: Array<TrustManager> = emptyArray(),
) {
mockkConstructor(OkHttpClient.Builder::class)
every { mockSslManager.sslContext } returns sslContext
every { mockSslManager.trustManagers } returns trustManagers
mockBuilder<OkHttpClient.Builder> {
it.sslSocketFactory(any(), any())
}
every { anyConstructed<OkHttpClient.Builder>().build() } returns mockk(relaxed = true)
}
private fun Retrofit.createMockRetrofit(): Retrofit =
this
.newBuilder()

View File

@@ -0,0 +1,218 @@
package com.x8bit.bitwarden.data.platform.datasource.network.ssl
import android.net.Uri
import com.x8bit.bitwarden.data.auth.datasource.disk.model.EnvironmentUrlDataJson
import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsCertificate
import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsKeyHost
import com.x8bit.bitwarden.data.platform.manager.KeyManager
import com.x8bit.bitwarden.data.platform.repository.EnvironmentRepository
import com.x8bit.bitwarden.data.platform.repository.model.Environment
import io.mockk.every
import io.mockk.just
import io.mockk.mockk
import io.mockk.mockkStatic
import io.mockk.runs
import io.mockk.slot
import io.mockk.unmockkStatic
import io.mockk.verify
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertNotNull
import org.junit.jupiter.api.Assertions.assertNull
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import java.security.KeyStore
import java.security.PrivateKey
import java.security.cert.X509Certificate
import javax.net.ssl.SSLContext
import javax.net.ssl.TrustManager
import javax.net.ssl.TrustManagerFactory
import javax.net.ssl.X509ExtendedKeyManager
class SslManagerTest {
private val mockEnvironment = mockk<Environment> {
every { environmentUrlData } returns DEFAULT_ENV_URL_DATA
}
private val mockEnvironmentRepository = mockk<EnvironmentRepository> {
every { environment } returns mockEnvironment
}
private val mockMutualTlsCertificate = mockk<MutualTlsCertificate> {
every { alias } returns "mockAlias"
}
private val mockKeyManager = mockk<KeyManager> {
every { getMutualTlsCertificateChain(any(), any()) } returns mockMutualTlsCertificate
}
private val mockTrustManagerFactory = mockk<TrustManagerFactory> {
every { init(null as? KeyStore?) } just runs
every { trustManagers } returns DEFAULT_TRUST_MANAGERS
}
private val sslManager: SslManagerImpl = SslManagerImpl(
keyManager = mockKeyManager,
environmentRepository = mockEnvironmentRepository,
)
@BeforeEach
fun setUp() {
mockkStatic(TrustManagerFactory::class)
every {
TrustManagerFactory.getDefaultAlgorithm()
} returns "defaultAlgorithm"
every {
TrustManagerFactory.getInstance("defaultAlgorithm")
} returns mockTrustManagerFactory
}
@AfterEach
fun tearDown() {
unmockkStatic(TrustManagerFactory::class, Uri::class, SSLContext::class)
}
@Test
fun `sslContext should be initialized with default TLS protocol`() {
setupMockUri()
assertTrue(sslManager.sslContext.protocol == "TLS")
}
@Test
fun `X509ExtendedKeyManagerImpl should initialize with mutualTlsCertificate`() {
setupMockUri()
mockkStatic(SSLContext::class)
val keyManagersCaptor = slot<Array<X509ExtendedKeyManager>>()
val trustManagersCaptor = slot<Array<TrustManager>>()
every { SSLContext.getInstance("TLS") } returns mockk<SSLContext> {
every {
init(
capture(keyManagersCaptor),
capture(trustManagersCaptor),
any(),
)
} just runs
}
every { mockEnvironment.environmentUrlData } returns DEFAULT_ENV_URL_DATA
every { mockMutualTlsCertificate.alias } returns "mockAlias"
every { mockMutualTlsCertificate.certificateChain } returns listOf(
mockk<X509Certificate>(name = "MockCertificate1"),
mockk<X509Certificate>(name = "MockCertificate2"),
)
every {
mockMutualTlsCertificate.privateKey
} returns mockk<PrivateKey>()
every {
mockKeyManager.getMutualTlsCertificateChain(
alias = "mockAlias",
host = MutualTlsKeyHost.ANDROID_KEY_STORE,
)
} returns mockMutualTlsCertificate
assertNotNull(sslManager.sslContext)
val keyManager = keyManagersCaptor.captured.first()
assertEquals(
mockMutualTlsCertificate.alias,
keyManager.chooseClientAlias(null, null, null),
)
assertTrue(
keyManager
.getCertificateChain("mockAlias")
.contentEquals(
mockMutualTlsCertificate
.certificateChain
.toTypedArray(),
),
)
assertEquals(
mockMutualTlsCertificate.privateKey,
keyManager.getPrivateKey("mockAlias"),
)
}
@Test
fun `mutualTlsCertificate should return null when keyUri is null`() {
every {
mockEnvironment.environmentUrlData
} returns DEFAULT_ENV_URL_DATA.copy(keyUri = null)
assertNull(sslManager.mutualTlsCertificate)
}
@Test
fun `mutualTlsCertificate should be null when host is invalid`() {
setupMockUri(authority = "UNKNOWN_HOST")
assertNull(sslManager.mutualTlsCertificate)
}
@Test
fun `mutualTlsCertificate should be null when alias is null`() {
setupMockUri(path = null)
assertNull(sslManager.mutualTlsCertificate)
}
@Test
fun `mutualTlsCertificate should trim path when it is not null`() {
setupMockUri(path = "/mockAlias/")
assertEquals("mockAlias", sslManager.mutualTlsCertificate?.alias)
}
@Test
fun `mutualTlsCertificate should be null when alias is empty after trim`() {
setupMockUri(path = "/")
assertNull(sslManager.mutualTlsCertificate)
}
@Test
fun `mutualTlsCertificate should call keyManager with correct alias and host`() {
// Set host to ANDROID_KEY_STORE
setupMockUri()
assertNotNull(sslManager.mutualTlsCertificate)
verify {
mockKeyManager.getMutualTlsCertificateChain(
alias = "mockAlias",
host = MutualTlsKeyHost.ANDROID_KEY_STORE,
)
}
// Set host to KEY_CHAIN
setupMockUri(authority = "KEY_CHAIN")
assertNotNull(sslManager.mutualTlsCertificate)
verify {
mockKeyManager.getMutualTlsCertificateChain(
alias = "mockAlias",
host = MutualTlsKeyHost.KEY_CHAIN,
)
}
}
@Suppress("MaxLineLength")
@Test
fun `trustManagers should return TrustManager array initialized with default algorithm and null keystore`() {
assertTrue(sslManager.trustManagers.contentEquals(DEFAULT_TRUST_MANAGERS))
verify {
TrustManagerFactory.getInstance("defaultAlgorithm")
TrustManagerFactory.getDefaultAlgorithm()
mockTrustManagerFactory.init(null as? KeyStore?)
}
}
private fun setupMockUri(
authority: String = "ANDROID_KEY_STORE",
path: String? = "/mockAlias",
) {
mockkStatic(Uri::class)
val uriMock = mockk<Uri>()
every { Uri.parse(any()) } returns uriMock
every { uriMock.authority } returns authority
every { uriMock.path } returns path
}
}
private val DEFAULT_TRUST_MANAGERS = arrayOf<TrustManager>(
mockk(name = "MockTrustManager1"),
mockk(name = "MockTrustManager2"),
)
val DEFAULT_ENV_URL_DATA = EnvironmentUrlDataJson(
base = "https://example.com",
keyUri = "cert://ANDROID_KEY_STORE/mockAlias",
)

View File

@@ -3,7 +3,7 @@ package com.x8bit.bitwarden.data.platform.manager
import android.content.Context
import android.security.KeyChain
import android.security.KeyChainException
import com.x8bit.bitwarden.data.platform.datasource.disk.model.ImportPrivateKeyResult
import com.x8bit.bitwarden.data.platform.manager.model.ImportPrivateKeyResult
import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsCertificate
import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsKeyHost
import io.mockk.every

View File

@@ -10,9 +10,17 @@ import androidx.compose.ui.test.onAllNodesWithText
import androidx.compose.ui.test.onNodeWithContentDescription
import androidx.compose.ui.test.onNodeWithText
import androidx.compose.ui.test.performClick
import androidx.compose.ui.test.performScrollTo
import androidx.compose.ui.test.performTextInput
import com.x8bit.bitwarden.data.platform.repository.util.bufferedMutableSharedFlow
import com.x8bit.bitwarden.ui.auth.feature.environment.EnvironmentState.DialogState
import com.x8bit.bitwarden.ui.platform.base.BaseComposeTest
import com.x8bit.bitwarden.ui.platform.base.util.asText
import com.x8bit.bitwarden.ui.platform.manager.intent.IntentManager
import com.x8bit.bitwarden.ui.platform.manager.keychain.KeyChainManager
import com.x8bit.bitwarden.ui.platform.manager.keychain.model.PrivateKeyAliasSelectionResult
import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.every
import io.mockk.mockk
import io.mockk.verify
@@ -26,6 +34,12 @@ class EnvironmentScreenTest : BaseComposeTest() {
private var onNavigateBackCalled = false
private val mutableEventFlow = bufferedMutableSharedFlow<EnvironmentEvent>()
private val mutableStateFlow = MutableStateFlow(DEFAULT_STATE)
private val mockIntentManager = mockk<IntentManager>(relaxed = true)
private val mockKeyChainManager = mockk<KeyChainManager> {
coEvery {
choosePrivateKeyAlias(any())
} returns PrivateKeyAliasSelectionResult.Success("mockAlias")
}
private val viewModel = mockk<EnvironmentViewModel>(relaxed = true) {
every { eventFlow } returns mutableEventFlow
every { stateFlow } returns mutableStateFlow
@@ -36,6 +50,8 @@ class EnvironmentScreenTest : BaseComposeTest() {
composeTestRule.setContent {
EnvironmentScreen(
onNavigateBack = { onNavigateBackCalled = true },
intentManager = mockIntentManager,
keyChainManager = mockKeyChainManager,
viewModel = viewModel,
)
}
@@ -69,7 +85,11 @@ class EnvironmentScreenTest : BaseComposeTest() {
mutableStateFlow.update {
it.copy(
shouldShowErrorDialog = true,
dialog = DialogState.Error(
("One or more of the URLs entered are invalid. " +
"Please revise it and try to save again.")
.asText(),
),
)
}
@@ -96,7 +116,7 @@ class EnvironmentScreenTest : BaseComposeTest() {
fun `error dialog OK click should send ErrorDialogDismiss action`() {
mutableStateFlow.update {
it.copy(
shouldShowErrorDialog = true,
dialog = DialogState.Error("Error".asText()),
)
}
composeTestRule
@@ -131,6 +151,62 @@ class EnvironmentScreenTest : BaseComposeTest() {
}
}
@Test
fun `use system certificate click should send UseSystemKeyCertificateClick`() {
composeTestRule
.onNodeWithText("Choose system certificate")
.performScrollTo()
.assertIsDisplayed()
.performClick()
verify {
viewModel.trySendAction(EnvironmentAction.ChooseSystemCertificateClick)
}
}
@Test
fun `ShowSystemCertificateSelection event should show system certificate selection dialog`() {
mutableEventFlow.tryEmit(
EnvironmentEvent.ShowSystemCertificateSelectionDialog(serverUrl = ""),
)
coVerify { mockKeyChainManager.choosePrivateKeyAlias(null) }
}
@Suppress("MaxLineLength")
@Test
fun `system certificate selection should send SystemCertificateSelectionResultReceive action`() {
coEvery {
mockKeyChainManager.choosePrivateKeyAlias(null)
} returns PrivateKeyAliasSelectionResult.Success("alias")
mutableEventFlow.tryEmit(
EnvironmentEvent.ShowSystemCertificateSelectionDialog(serverUrl = ""),
)
verify {
viewModel.trySendAction(
EnvironmentAction.SystemCertificateSelectionResultReceive(
privateKeyAliasSelectionResult = PrivateKeyAliasSelectionResult.Success(
alias = "alias",
),
),
)
}
}
@Test
fun `key alias should change according to the state`() {
composeTestRule
.onNodeWithText("Certificate alias")
.assertTextEquals("Certificate alias", "")
mutableStateFlow.update { it.copy(keyAlias = "mock-alias") }
composeTestRule
.onNodeWithText("Certificate alias")
.assertTextEquals("Certificate alias", "mock-alias")
}
@Test
fun `web vault URL should change according to the state`() {
composeTestRule
@@ -238,11 +314,14 @@ class EnvironmentScreenTest : BaseComposeTest() {
companion object {
val DEFAULT_STATE = EnvironmentState(
serverUrl = "",
keyAlias = "",
webVaultServerUrl = "",
apiServerUrl = "",
identityServerUrl = "",
iconsServerUrl = "",
shouldShowErrorDialog = false,
keyHost = null,
dialog = null,
showMutualTlsOptions = true,
)
}
}

View File

@@ -1,13 +1,28 @@
package com.x8bit.bitwarden.ui.auth.feature.environment
import android.net.Uri
import androidx.lifecycle.SavedStateHandle
import app.cash.turbine.test
import com.x8bit.bitwarden.R
import com.x8bit.bitwarden.data.auth.datasource.disk.model.EnvironmentUrlDataJson
import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsKeyHost
import com.x8bit.bitwarden.data.platform.manager.FeatureFlagManager
import com.x8bit.bitwarden.data.platform.manager.KeyManager
import com.x8bit.bitwarden.data.platform.manager.model.FlagKey
import com.x8bit.bitwarden.data.platform.manager.model.ImportPrivateKeyResult
import com.x8bit.bitwarden.data.platform.repository.model.Environment
import com.x8bit.bitwarden.data.platform.repository.util.FakeEnvironmentRepository
import com.x8bit.bitwarden.data.platform.util.asSuccess
import com.x8bit.bitwarden.data.vault.manager.FileManager
import com.x8bit.bitwarden.ui.platform.base.BaseViewModelTest
import com.x8bit.bitwarden.ui.platform.base.util.asText
import com.x8bit.bitwarden.ui.platform.manager.intent.IntentManager
import com.x8bit.bitwarden.ui.platform.manager.keychain.model.PrivateKeyAliasSelectionResult
import io.mockk.coEvery
import io.mockk.coVerify
import io.mockk.every
import io.mockk.mockk
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Test
@@ -15,6 +30,13 @@ import org.junit.jupiter.api.Test
class EnvironmentViewModelTest : BaseViewModelTest() {
private val fakeEnvironmentRepository = FakeEnvironmentRepository()
private val mutableMutualTlsFeatureFlagFlow = MutableStateFlow(true)
private val mockFeatureFlagManager = mockk<FeatureFlagManager> {
every { getFeatureFlag(FlagKey.MutualTls) } returns true
every { getFeatureFlagFlow(FlagKey.MutualTls) } returns mutableMutualTlsFeatureFlagFlow
}
private val mockKeyManager = mockk<KeyManager>()
private val mockFileManager = mockk<FileManager>()
@Suppress("MaxLineLength")
@Test
@@ -60,6 +82,8 @@ class EnvironmentViewModelTest : BaseViewModelTest() {
apiServerUrl = "saved-api",
identityServerUrl = "saved-identity",
iconsServerUrl = "saved-icons",
keyHost = MutualTlsKeyHost.ANDROID_KEY_STORE,
keyAlias = "saved-key-alias",
)
val viewModel = createViewModel(
savedStateHandle = SavedStateHandle(
@@ -75,6 +99,8 @@ class EnvironmentViewModelTest : BaseViewModelTest() {
apiServerUrl = "saved-api",
identityServerUrl = "saved-identity",
iconsServerUrl = "saved-icons",
keyHost = MutualTlsKeyHost.ANDROID_KEY_STORE,
keyAlias = "saved-key-alias",
),
viewModel.stateFlow.value,
)
@@ -115,7 +141,9 @@ class EnvironmentViewModelTest : BaseViewModelTest() {
assertEquals(
initialState.copy(
shouldShowErrorDialog = true,
dialog = EnvironmentState.DialogState.Error(
message = R.string.environment_page_urls_error.asText(),
),
),
viewModel.stateFlow.value,
)
@@ -154,6 +182,11 @@ class EnvironmentViewModelTest : BaseViewModelTest() {
EnvironmentAction.IconsServerUrlChange(
iconsServerUrl = "icons-url",
),
EnvironmentAction.SystemCertificateSelectionResultReceive(
privateKeyAliasSelectionResult = PrivateKeyAliasSelectionResult.Success(
alias = "mockAlias",
),
),
)
.forEach { viewModel.trySendAction(it) }
@@ -179,6 +212,7 @@ class EnvironmentViewModelTest : BaseViewModelTest() {
notifications = null,
webVault = "http://web-vault-url",
events = null,
keyUri = "cert://KEY_CHAIN/mockAlias",
),
),
fakeEnvironmentRepository.environment,
@@ -226,6 +260,7 @@ class EnvironmentViewModelTest : BaseViewModelTest() {
notifications = null,
webVault = "http://web-vault-url",
events = null,
keyUri = null,
),
),
fakeEnvironmentRepository.environment,
@@ -293,6 +328,337 @@ class EnvironmentViewModelTest : BaseViewModelTest() {
)
}
@Suppress("MaxLineLength")
@Test
fun `SystemCertificateSelectionResultReceive should update key alias and key host when successful`() {
val viewModel = createViewModel()
viewModel.trySendAction(
EnvironmentAction.SystemCertificateSelectionResultReceive(
privateKeyAliasSelectionResult = PrivateKeyAliasSelectionResult.Success(
alias = "mockAlias",
),
),
)
assertEquals(
DEFAULT_STATE.copy(
keyAlias = "mockAlias",
keyHost = MutualTlsKeyHost.KEY_CHAIN,
),
viewModel.stateFlow.value,
)
viewModel.trySendAction(
EnvironmentAction.SystemCertificateSelectionResultReceive(
privateKeyAliasSelectionResult = PrivateKeyAliasSelectionResult.Success(
alias = null,
),
),
)
assertEquals(
DEFAULT_STATE.copy(
keyAlias = "",
keyHost = null,
),
viewModel.stateFlow.value,
)
}
@Test
fun `SystemCertificateSelectionResultReceive should show toast when error`() = runTest {
val viewModel = createViewModel()
viewModel.trySendAction(
EnvironmentAction.SystemCertificateSelectionResultReceive(
privateKeyAliasSelectionResult = PrivateKeyAliasSelectionResult.Error,
),
)
viewModel.eventFlow.test {
assertEquals(
EnvironmentEvent.ShowToast(R.string.error_loading_certificate.asText()),
awaitItem(),
)
}
}
@Test
fun `ChooseSystemCertificate should show system certificate warning dialog`() =
runTest {
val viewModel = createViewModel()
viewModel.trySendAction(EnvironmentAction.ChooseSystemCertificateClick)
assertEquals(
DEFAULT_STATE.copy(
dialog = EnvironmentState.DialogState.SystemCertificateWarningDialog,
),
viewModel.stateFlow.value,
)
}
@Test
fun `ErrorDialogDismiss should clear the dialog`() = runTest {
val viewModel = createViewModel()
viewModel.trySendAction(EnvironmentAction.ErrorDialogDismiss)
assertEquals(
DEFAULT_STATE.copy(dialog = null),
viewModel.stateFlow.value,
)
}
@Test
fun `ImportCertificateClick should emit ShowCertificateImportFileChooser`() = runTest {
val viewModel = createViewModel()
viewModel.eventFlow.test {
viewModel.trySendAction(EnvironmentAction.ImportCertificateClick)
assertEquals(EnvironmentEvent.ShowCertificateImportFileChooser, awaitItem())
}
}
@Test
fun `ImportCertificateFilePickerResultReceive should show SetCertificateData dialog`() =
runTest {
val viewModel = createViewModel()
val mockFileData = mockk<IntentManager.FileData>()
viewModel.trySendAction(
EnvironmentAction.ImportCertificateFilePickerResultReceive(
certificateFileData = mockFileData,
),
)
assertEquals(
DEFAULT_STATE.copy(
dialog = EnvironmentState.DialogState.SetCertificateData(
certificateBytes = mockFileData,
),
),
viewModel.stateFlow.value,
)
}
@Test
fun `SetCertificatePasswordDialogDismiss should clear the dialog`() = runTest {
val viewModel = createViewModel()
viewModel.trySendAction(EnvironmentAction.SetCertificatePasswordDialogDismiss)
assertEquals(
DEFAULT_STATE.copy(dialog = null),
viewModel.stateFlow.value,
)
}
@Test
fun `CertificateInstallationResultReceive should show toast based on result`() = runTest {
val viewModel = createViewModel()
viewModel.eventFlow.test {
viewModel.trySendAction(
EnvironmentAction.CertificateInstallationResultReceive(
success = true,
),
)
assertEquals(
EnvironmentEvent.ShowToast(R.string.certificate_installed.asText()),
awaitItem(),
)
viewModel.trySendAction(
EnvironmentAction.CertificateInstallationResultReceive(
success = false,
),
)
assertEquals(
EnvironmentEvent.ShowToast(R.string.certificate_installation_failed.asText()),
awaitItem(),
)
}
}
@Suppress("MaxLineLength")
@Test
fun `ConfirmChooseSystemCertificateClick should clear the dialog and emit ShowSystemCertificateSelectionDialog`() =
runTest {
val viewModel = createViewModel(
savedStateHandle = SavedStateHandle(
initialState = mapOf(
"state" to DEFAULT_STATE.copy(serverUrl = "https://mockServerUrl"),
),
),
)
viewModel.trySendAction(EnvironmentAction.ConfirmChooseSystemCertificateClick)
assertEquals(
DEFAULT_STATE.copy(dialog = null, serverUrl = "https://mockServerUrl"),
viewModel.stateFlow.value,
)
viewModel.eventFlow.test {
assertEquals(
EnvironmentEvent.ShowSystemCertificateSelectionDialog(
serverUrl = "https://mockServerUrl",
),
awaitItem(),
)
}
}
@Test
fun `ImportKeyResultReceive should update key alias and key host on success`() = runTest {
val viewModel = createViewModel()
viewModel.trySendAction(
EnvironmentAction.Internal.ImportKeyResultReceive(
result = ImportPrivateKeyResult.Success(alias = "mockAlias"),
),
)
assertEquals(
DEFAULT_STATE.copy(
keyAlias = "mockAlias",
keyHost = MutualTlsKeyHost.ANDROID_KEY_STORE,
),
viewModel.stateFlow.value,
)
}
@Test
fun `ImportKeyResultReceive should show toast with correct message on error`() = runTest {
val viewModel = createViewModel()
viewModel.eventFlow.test {
viewModel.trySendAction(
EnvironmentAction.Internal.ImportKeyResultReceive(
result = ImportPrivateKeyResult.Error.UnsupportedKey,
),
)
assertEquals(
EnvironmentEvent.ShowToast(R.string.unsupported_certificate_type.asText()),
awaitItem(),
)
viewModel.trySendAction(
EnvironmentAction.Internal.ImportKeyResultReceive(
result = ImportPrivateKeyResult.Error.KeyStoreOperationFailed,
),
)
assertEquals(
EnvironmentEvent.ShowToast(R.string.certificate_installation_failed.asText()),
awaitItem(),
)
viewModel.trySendAction(
EnvironmentAction.Internal.ImportKeyResultReceive(
result = ImportPrivateKeyResult.Error.UnrecoverableKey,
),
)
assertEquals(
EnvironmentEvent.ShowToast(R.string.certificate_password_incorrect.asText()),
awaitItem(),
)
viewModel.trySendAction(
EnvironmentAction.Internal.ImportKeyResultReceive(
result = ImportPrivateKeyResult.Error.InvalidCertificateChain,
),
)
assertEquals(
EnvironmentEvent.ShowToast(R.string.invalid_certificate_chain.asText()),
awaitItem(),
)
viewModel.trySendAction(
EnvironmentAction.Internal.ImportKeyResultReceive(
result = ImportPrivateKeyResult.Error.DuplicateAlias,
),
)
assertEquals(
EnvironmentEvent.ShowToast(R.string.certificate_alias_already_exists.asText()),
awaitItem(),
)
}
}
@Suppress("MaxLineLength")
@Test
fun `SetCertificateInfoResultReceive should clear the dialog update key alias and key host after successful import`() =
runTest {
val viewModel = createViewModel()
val mockUri = mockk<Uri>()
val mockFileData = IntentManager.FileData(
fileName = "mockFileName",
uri = mockUri,
sizeBytes = 0,
)
val keyBytes = byteArrayOf()
coEvery {
mockFileManager.uriToByteArray(mockFileData.uri)
} returns keyBytes.asSuccess()
coEvery {
mockKeyManager.importMutualTlsCertificate(
key = keyBytes,
alias = "mockAlias",
password = "mockPassword",
)
} returns ImportPrivateKeyResult.Success(alias = "mockAlias")
viewModel.trySendAction(
EnvironmentAction.SetCertificateInfoResultReceive(
certificateFileData = mockFileData,
alias = "mockAlias",
password = "mockPassword",
),
)
assertEquals(
DEFAULT_STATE.copy(
dialog = null,
keyAlias = "mockAlias",
keyHost = MutualTlsKeyHost.ANDROID_KEY_STORE,
),
viewModel.stateFlow.value,
)
coVerify {
mockFileManager.uriToByteArray(mockFileData.uri)
mockKeyManager.importMutualTlsCertificate(
key = byteArrayOf(),
alias = "mockAlias",
password = "mockPassword",
)
}
}
@Test
fun `SetCertificateInfoResultReceive should show error dialog if input is invalid`() = runTest {
val viewModel = createViewModel()
viewModel.trySendAction(
EnvironmentAction.SetCertificateInfoResultReceive(
certificateFileData = mockk(),
alias = "mockAlias",
password = "",
),
)
assertEquals(
DEFAULT_STATE.copy(
dialog = EnvironmentState.DialogState.Error(
R.string.validation_field_required.asText(
R.string.password.asText(),
),
),
),
viewModel.stateFlow.value,
)
viewModel.trySendAction(
EnvironmentAction.SetCertificateInfoResultReceive(
certificateFileData = mockk(),
alias = "",
password = "mockPassword",
),
)
assertEquals(
DEFAULT_STATE.copy(
dialog = EnvironmentState.DialogState.Error(
R.string.validation_field_required.asText(
R.string.alias.asText(),
),
),
),
viewModel.stateFlow.value,
)
}
//region Helper methods
private fun createViewModel(
@@ -300,6 +666,9 @@ class EnvironmentViewModelTest : BaseViewModelTest() {
): EnvironmentViewModel =
EnvironmentViewModel(
environmentRepository = fakeEnvironmentRepository,
featureFlagManager = mockFeatureFlagManager,
keyManager = mockKeyManager,
fileManager = mockFileManager,
savedStateHandle = savedStateHandle,
)
@@ -308,11 +677,14 @@ class EnvironmentViewModelTest : BaseViewModelTest() {
companion object {
private val DEFAULT_STATE = EnvironmentState(
serverUrl = "",
keyAlias = "",
webVaultServerUrl = "",
apiServerUrl = "",
identityServerUrl = "",
iconsServerUrl = "",
shouldShowErrorDialog = false,
keyHost = null,
dialog = null,
showMutualTlsOptions = true,
)
}
}

View File

@@ -0,0 +1,135 @@
package com.x8bit.bitwarden.ui.platform.manager.keychain
import android.app.Activity
import android.net.Uri
import android.security.KeyChain
import android.security.KeyChainAliasCallback
import androidx.core.net.toUri
import com.x8bit.bitwarden.ui.platform.manager.keychain.model.PrivateKeyAliasSelectionResult
import io.mockk.every
import io.mockk.mockk
import io.mockk.mockkStatic
import io.mockk.slot
import io.mockk.unmockkStatic
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertNull
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertInstanceOf
class KeyChainManagerTest {
private val mockActivity = mockk<Activity>()
private val keyChainManager = KeyChainManagerImpl(activity = mockActivity)
@BeforeEach
fun setUp() {
mockkStatic(KeyChain::class)
}
@AfterEach
fun tearDown() {
unmockkStatic(KeyChain::class)
}
@Test
fun `choosePrivateKeyAlias should return Success with alias when key is selected`() = runTest {
setupMockUri()
val systemCallbackCaptor = slot<KeyChainAliasCallback>()
every {
KeyChain.choosePrivateKeyAlias(
/* activity = */ mockActivity,
/* response = */ capture(systemCallbackCaptor),
/* keyTypes = */ null,
/* issuers = */ null,
/* uri = */ null,
/* alias = */ null,
)
} answers {
systemCallbackCaptor.captured.alias("mockAlias")
}
val result = keyChainManager.choosePrivateKeyAlias(currentServerUrl = null)
assertEquals(
PrivateKeyAliasSelectionResult.Success("mockAlias"),
result,
)
}
@Test
fun `choosePrivateKeyAlias should return Error when IllegalArgumentException is thrown`() =
runTest {
setupMockUri()
every {
KeyChain.choosePrivateKeyAlias(
/* activity = */ mockActivity,
/* response = */ any(),
/* keyTypes = */ null,
/* issuers = */ null,
/* uri = */ null,
/* alias = */ null,
)
} throws IllegalArgumentException()
val result = keyChainManager.choosePrivateKeyAlias(currentServerUrl = null)
assertInstanceOf<PrivateKeyAliasSelectionResult.Error>(result)
}
@Test
fun `choosePrivateKeyAlias should pass currentServerUrl to system KeyChain`() = runTest {
setupMockUri()
val systemCallbackCaptor = slot<KeyChainAliasCallback>()
every {
KeyChain.choosePrivateKeyAlias(
/* activity = */ mockActivity,
/* response = */ capture(systemCallbackCaptor),
/* keyTypes = */ null,
/* issuers = */ null,
/* uri = */ "www.mockuri.com".toUri(),
/* alias = */ null,
)
} answers {
systemCallbackCaptor.captured.alias("mockAlias")
}
val result = keyChainManager.choosePrivateKeyAlias(currentServerUrl = "www.mockuri.com")
assertInstanceOf<PrivateKeyAliasSelectionResult.Success>(result)
.also { assertEquals("mockAlias", it.alias) }
}
@Test
fun `choosePrivateKeyAlias should return Success with null alias when no key is selected`() =
runTest {
setupMockUri()
val systemCallbackCaptor = slot<KeyChainAliasCallback>()
every {
KeyChain.choosePrivateKeyAlias(
/* activity = */ mockActivity,
/* response = */ capture(systemCallbackCaptor),
/* keyTypes = */ null,
/* issuers = */ null,
/* uri = */ null,
/* alias = */ null,
)
} answers {
systemCallbackCaptor.captured.alias(null)
}
val result = keyChainManager.choosePrivateKeyAlias(currentServerUrl = null)
assertInstanceOf<PrivateKeyAliasSelectionResult.Success>(result)
.also { assertNull(it.alias) }
}
private fun setupMockUri() {
mockkStatic(Uri::class)
val uriMock = mockk<Uri>()
every { Uri.parse(any()) } returns uriMock
every { uriMock.host } returns "www.mockuri.com"
}
}