From 464f8de5f5f3e59ea39e5a42d80a2cc36ecea6f1 Mon Sep 17 00:00:00 2001 From: Patrick Honkonen <1883101+SaintPatrck@users.noreply.github.com> Date: Fri, 24 Jan 2025 15:55:33 -0500 Subject: [PATCH] [PM-17424] Implement KeyManager for handling private keys (#4608) --- .../disk/model/ImportPrivateKeyResult.kt | 45 ++ .../disk/model/MutualTlsCertificate.kt | 36 + .../datasource/disk/model/MutualTlsKeyHost.kt | 16 + .../data/platform/manager/KeyManager.kt | 37 + .../data/platform/manager/KeyManagerImpl.kt | 188 +++++ .../manager/di/PlatformManagerModule.kt | 8 + .../data/platform/manager/KeyManagerTest.kt | 649 ++++++++++++++++++ 7 files changed, 979 insertions(+) create mode 100644 app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/disk/model/ImportPrivateKeyResult.kt create mode 100644 app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/disk/model/MutualTlsCertificate.kt create mode 100644 app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/disk/model/MutualTlsKeyHost.kt create mode 100644 app/src/main/java/com/x8bit/bitwarden/data/platform/manager/KeyManager.kt create mode 100644 app/src/main/java/com/x8bit/bitwarden/data/platform/manager/KeyManagerImpl.kt create mode 100644 app/src/test/java/com/x8bit/bitwarden/data/platform/manager/KeyManagerTest.kt diff --git a/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/disk/model/ImportPrivateKeyResult.kt b/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/disk/model/ImportPrivateKeyResult.kt new file mode 100644 index 0000000000..b0bcf9fffc --- /dev/null +++ b/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/disk/model/ImportPrivateKeyResult.kt @@ -0,0 +1,45 @@ +package com.x8bit.bitwarden.data.platform.datasource.disk.model + +/** + * Models the result of importing a private key. + */ +sealed class ImportPrivateKeyResult { + + /** + * Represents a successful result of importing a private key. + * + * @property alias The alias assigned to the imported private key. + */ + data class Success(val alias: String) : ImportPrivateKeyResult() + + /** + * Represents a generic error during the import process. + */ + sealed class Error : ImportPrivateKeyResult() { + + /** + * Indicates that the provided key is unrecoverable or the password is incorrect. + */ + data object UnrecoverableKey : Error() + + /** + * Indicates that the certificate chain associated with the key is invalid. + */ + data object InvalidCertificateChain : Error() + + /** + * Indicates that the specified alias is already in use. + */ + data object DuplicateAlias : Error() + + /** + * Indicates that an error occurred during the key store operation. + */ + data object KeyStoreOperationFailed : Error() + + /** + * Indicates the provided key is not supported. + */ + data object UnsupportedKey : Error() + } +} diff --git a/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/disk/model/MutualTlsCertificate.kt b/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/disk/model/MutualTlsCertificate.kt new file mode 100644 index 0000000000..b7d3d8031c --- /dev/null +++ b/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/disk/model/MutualTlsCertificate.kt @@ -0,0 +1,36 @@ +package com.x8bit.bitwarden.data.platform.datasource.disk.model + +import java.security.PrivateKey +import java.security.cert.X509Certificate + +/** + * Represents a mutual TLS certificate. + */ +data class MutualTlsCertificate( + val alias: String, + val privateKey: PrivateKey, + val certificateChain: List, +) { + /** + * Leaf certificate of the chain. + */ + val leafCertificate: X509Certificate? + get() = certificateChain.lastOrNull() + + /** + * Root certificate of the chain. + */ + val rootCertificate: X509Certificate? + get() = certificateChain.firstOrNull() + + override fun toString(): String = leafCertificate + ?.let { + buildString { + appendLine("Subject: ${it.subjectDN}") + appendLine("Issuer: ${it.issuerDN}") + appendLine("Valid From: ${it.notBefore}") + appendLine("Valid Until: ${it.notAfter}") + } + } + ?: "" +} diff --git a/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/disk/model/MutualTlsKeyHost.kt b/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/disk/model/MutualTlsKeyHost.kt new file mode 100644 index 0000000000..dd254b235d --- /dev/null +++ b/app/src/main/java/com/x8bit/bitwarden/data/platform/datasource/disk/model/MutualTlsKeyHost.kt @@ -0,0 +1,16 @@ +package com.x8bit.bitwarden.data.platform.datasource.disk.model + +/** + * Location of the key data. + */ +enum class MutualTlsKeyHost { + /** + * Key is stored in the system key chain. + */ + KEY_CHAIN, + + /** + * Key is stored in a private instance of the Android Key Store. + */ + ANDROID_KEY_STORE, +} diff --git a/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/KeyManager.kt b/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/KeyManager.kt new file mode 100644 index 0000000000..4ce0a9d42a --- /dev/null +++ b/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/KeyManager.kt @@ -0,0 +1,37 @@ +package com.x8bit.bitwarden.data.platform.manager + +import com.x8bit.bitwarden.data.platform.datasource.disk.model.ImportPrivateKeyResult +import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsCertificate +import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsKeyHost + +/** + * Primary access point for disk information related to key data. + */ +interface KeyManager { + + /** + * Import a private key into the application KeyStore. + * + * @param key The private key to be saved. + * @param alias Alias to be assigned to the private key. + * @param password Password used to protect the certificate. + */ + fun importMutualTlsCertificate( + key: ByteArray, + alias: String, + password: String, + ): ImportPrivateKeyResult + + /** + * Removes the mTLS key from storage. + */ + fun removeMutualTlsKey(alias: String, host: MutualTlsKeyHost) + + /** + * Retrieve the certificate chain for the selected mTLS key. + */ + fun getMutualTlsCertificateChain( + alias: String, + host: MutualTlsKeyHost, + ): MutualTlsCertificate? +} diff --git a/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/KeyManagerImpl.kt b/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/KeyManagerImpl.kt new file mode 100644 index 0000000000..19e9cb3b39 --- /dev/null +++ b/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/KeyManagerImpl.kt @@ -0,0 +1,188 @@ +package com.x8bit.bitwarden.data.platform.manager + +import android.content.Context +import android.security.KeyChain +import android.security.KeyChainException +import com.x8bit.bitwarden.data.platform.datasource.disk.model.ImportPrivateKeyResult +import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsCertificate +import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsKeyHost +import timber.log.Timber +import java.io.IOException +import java.security.KeyStore +import java.security.KeyStoreException +import java.security.NoSuchAlgorithmException +import java.security.PrivateKey +import java.security.UnrecoverableKeyException +import java.security.cert.Certificate +import java.security.cert.CertificateException +import java.security.cert.X509Certificate + +/** + * Default implementation of [KeyManager]. + */ +class KeyManagerImpl( + private val context: Context, +) : KeyManager { + + @Suppress("CyclomaticComplexMethod") + override fun importMutualTlsCertificate( + key: ByteArray, + alias: String, + password: String, + ): ImportPrivateKeyResult { + // Step 1: Load PKCS12 bytes into a KeyStore. + val pkcs12KeyStore: KeyStore = key + .inputStream() + .use { stream -> + try { + KeyStore.getInstance(KEYSTORE_TYPE_PKCS12) + .also { it.load(stream, password.toCharArray()) } + } catch (e: KeyStoreException) { + Timber.Forest.e(e, "Failed to load PKCS12 bytes") + return ImportPrivateKeyResult.Error.UnsupportedKey + } catch (e: IOException) { + Timber.Forest.e(e, "Format or password error while loading PKCS12 bytes") + return when (e.cause) { + is UnrecoverableKeyException -> { + ImportPrivateKeyResult.Error.UnrecoverableKey + } + + else -> { + ImportPrivateKeyResult.Error.KeyStoreOperationFailed + } + } + } catch (e: CertificateException) { + Timber.Forest.e(e, "Unable to load certificate chain") + return ImportPrivateKeyResult.Error.InvalidCertificateChain + } catch (e: NoSuchAlgorithmException) { + Timber.Forest.e(e, "Cryptographic algorithm not supported") + return ImportPrivateKeyResult.Error.UnsupportedKey + } + } + + // Step 2: Get a list of aliases and choose the first one. + val internalAlias = pkcs12KeyStore.aliases() + ?.takeIf { it.hasMoreElements() } + ?.nextElement() + ?: return ImportPrivateKeyResult.Error.UnsupportedKey + + // Step 3: Extract PrivateKey and X.509 certificate from the KeyStore and verify + // certificate alias. + val privateKey = try { + pkcs12KeyStore.getKey(internalAlias, password.toCharArray()) + ?: return ImportPrivateKeyResult.Error.UnrecoverableKey + } catch (e: UnrecoverableKeyException) { + Timber.Forest.e(e, "Failed to get private key") + return ImportPrivateKeyResult.Error.UnrecoverableKey + } + + val certChain: Array = pkcs12KeyStore + .getCertificateChain(internalAlias) + ?.takeUnless { it.isEmpty() } + ?: return ImportPrivateKeyResult.Error.InvalidCertificateChain + + // Step 4: Store the private key and X.509 certificate in the AndroidKeyStore if the alias + // does not exists. + with(androidKeyStore) { + if (containsAlias(alias)) { + return ImportPrivateKeyResult.Error.DuplicateAlias + } + + try { + setKeyEntry(alias, privateKey, null, certChain) + } catch (e: KeyStoreException) { + Timber.Forest.e(e, "Failed to import key into Android KeyStore") + return ImportPrivateKeyResult.Error.KeyStoreOperationFailed + } + } + return ImportPrivateKeyResult.Success(alias) + } + + override fun removeMutualTlsKey( + alias: String, + host: MutualTlsKeyHost, + ) { + when (host) { + MutualTlsKeyHost.ANDROID_KEY_STORE -> removeKeyFromAndroidKeyStore(alias) + else -> Unit + } + } + + override fun getMutualTlsCertificateChain( + alias: String, + host: MutualTlsKeyHost, + ): MutualTlsCertificate? = when (host) { + MutualTlsKeyHost.ANDROID_KEY_STORE -> getKeyFromAndroidKeyStore(alias) + + MutualTlsKeyHost.KEY_CHAIN -> getSystemKeySpecOrNull(alias) + } + + private fun removeKeyFromAndroidKeyStore(alias: String) { + try { + androidKeyStore.deleteEntry(alias) + } catch (e: KeyStoreException) { + Timber.Forest.e(e, "Failed to remove key from Android KeyStore") + } + } + + private fun getSystemKeySpecOrNull(alias: String): MutualTlsCertificate? { + val systemPrivateKey = try { + KeyChain.getPrivateKey(context, alias) + } catch (e: KeyChainException) { + Timber.Forest.e(e, "Requested alias not found in system KeyChain") + null + } + ?: return null + + val systemCertificateChain = try { + KeyChain.getCertificateChain(context, alias) + } catch (e: KeyChainException) { + Timber.Forest.e(e, "Unable to access certificate chain for provided alias") + null + } + ?: return null + + return MutualTlsCertificate( + alias = alias, + certificateChain = systemCertificateChain.toList(), + privateKey = systemPrivateKey, + ) + } + + private fun getKeyFromAndroidKeyStore(alias: String): MutualTlsCertificate? = + with(androidKeyStore) { + try { + val privateKeyRef = (getKey(alias, null) as? PrivateKey) + ?: return null + val certChain = getCertificateChain(alias) + .mapNotNull { it as? X509Certificate } + .takeUnless { it.isEmpty() } + ?: return null + MutualTlsCertificate( + alias = alias, + certificateChain = certChain, + privateKey = privateKeyRef, + ) + } catch (e: KeyStoreException) { + Timber.Forest.e(e, "Failed to load Android KeyStore") + null + } catch (e: UnrecoverableKeyException) { + Timber.Forest.e(e, "Failed to load client certificate from Android KeyStore") + null + } catch (e: NoSuchAlgorithmException) { + Timber.Forest.e(e, "Key cannot be recovered. Password may be incorrect.") + null + } catch (e: NoSuchAlgorithmException) { + Timber.Forest.e(e, "Algorithm not supported") + null + } + } + + private val androidKeyStore + get() = KeyStore + .getInstance(KEYSTORE_TYPE_ANDROID) + .also { it.load(null) } +} + +private const val KEYSTORE_TYPE_ANDROID = "AndroidKeyStore" +private const val KEYSTORE_TYPE_PKCS12 = "pkcs12" diff --git a/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/di/PlatformManagerModule.kt b/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/di/PlatformManagerModule.kt index 311458e6e1..f7832343bf 100644 --- a/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/di/PlatformManagerModule.kt +++ b/app/src/main/java/com/x8bit/bitwarden/data/platform/manager/di/PlatformManagerModule.kt @@ -28,6 +28,8 @@ import com.x8bit.bitwarden.data.platform.manager.FeatureFlagManager import com.x8bit.bitwarden.data.platform.manager.FeatureFlagManagerImpl import com.x8bit.bitwarden.data.platform.manager.FirstTimeActionManager import com.x8bit.bitwarden.data.platform.manager.FirstTimeActionManagerImpl +import com.x8bit.bitwarden.data.platform.manager.KeyManager +import com.x8bit.bitwarden.data.platform.manager.KeyManagerImpl import com.x8bit.bitwarden.data.platform.manager.LogsManager import com.x8bit.bitwarden.data.platform.manager.LogsManagerImpl import com.x8bit.bitwarden.data.platform.manager.PolicyManager @@ -329,4 +331,10 @@ object PlatformManagerModule { autofillEnabledManager = autofillEnabledManager, accessibilityEnabledManager = accessibilityEnabledManager, ) + + @Provides + @Singleton + fun provideKeyManager( + @ApplicationContext context: Context, + ): KeyManager = KeyManagerImpl(context = context) } diff --git a/app/src/test/java/com/x8bit/bitwarden/data/platform/manager/KeyManagerTest.kt b/app/src/test/java/com/x8bit/bitwarden/data/platform/manager/KeyManagerTest.kt new file mode 100644 index 0000000000..1d2b57fab6 --- /dev/null +++ b/app/src/test/java/com/x8bit/bitwarden/data/platform/manager/KeyManagerTest.kt @@ -0,0 +1,649 @@ +package com.x8bit.bitwarden.data.platform.manager + +import android.content.Context +import android.security.KeyChain +import android.security.KeyChainException +import com.x8bit.bitwarden.data.platform.datasource.disk.model.ImportPrivateKeyResult +import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsCertificate +import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsKeyHost +import io.mockk.every +import io.mockk.just +import io.mockk.mockk +import io.mockk.mockkStatic +import io.mockk.runs +import io.mockk.unmockkStatic +import io.mockk.verify +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertNull +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import java.io.IOException +import java.security.KeyStore +import java.security.KeyStoreException +import java.security.NoSuchAlgorithmException +import java.security.PrivateKey +import java.security.UnrecoverableKeyException +import java.security.cert.Certificate +import java.security.cert.CertificateException +import java.security.cert.X509Certificate + +class KeyManagerTest { + private val mockContext = mockk() + private val mockAndroidKeyStore = mockk(name = "MockAndroidKeyStore") + private val mockPkcs12KeyStore = mockk(name = "MockPKCS12KeyStore") + private val keyDiskSource = KeyManagerImpl( + context = mockContext, + ) + + @BeforeEach + fun setUp() { + mockkStatic(KeyStore::class, KeyChain::class) + } + + @AfterEach + fun tearDown() { + unmockkStatic(KeyStore::class, KeyChain::class) + } + + @Test + fun `getMutualTlsCertificateChain should return null when MutualTlsKeyAlias is not found`() { + // Verify null is returned when alias is not found in KeyChain + setupMockAndroidKeyStore() + every { KeyChain.getPrivateKey(mockContext, "mockAlias") } throws KeyChainException() + assertNull( + keyDiskSource.getMutualTlsCertificateChain( + alias = "mockAlias", + host = MutualTlsKeyHost.KEY_CHAIN, + ), + ) + + // Verify null is returned when alias is not found in AndroidKeyStore + every { mockAndroidKeyStore.getKey("mockAlias", null) } throws UnrecoverableKeyException() + assertNull( + keyDiskSource.getMutualTlsCertificateChain( + alias = "mockAlias", + host = MutualTlsKeyHost.ANDROID_KEY_STORE, + ), + ) + } + + @Suppress("MaxLineLength") + @Test + fun `getMutualTlsCertificateChain should return MutualTlsCertificateChain when using ANDROID KEY STORE and key is found`() { + setupMockAndroidKeyStore() + val mockAlias = "mockAlias" + val mockPrivateKey = mockk() + val mockCertificate1 = mockk(name = "mockCertificate1") + val mockCertificate2 = mockk(name = "mockCertificate2") + every { + mockAndroidKeyStore.getCertificateChain(mockAlias) + } returns arrayOf(mockCertificate1, mockCertificate2) + every { + mockAndroidKeyStore.getKey(mockAlias, null) + } returns mockPrivateKey + + val result = keyDiskSource.getMutualTlsCertificateChain( + alias = mockAlias, + host = MutualTlsKeyHost.ANDROID_KEY_STORE, + ) + + assertEquals( + MutualTlsCertificate( + alias = mockAlias, + certificateChain = listOf(mockCertificate1, mockCertificate2), + privateKey = mockPrivateKey, + ), + result, + ) + } + + @Suppress("MaxLineLength") + @Test + fun `getMutualTlsCertificateChain should return null when using ANDROID KEY STORE and key is not found`() { + setupMockAndroidKeyStore() + val mockAlias = "mockAlias" + val mockCertificate1 = mockk(name = "mockCertificate1") + val mockCertificate2 = mockk(name = "mockCertificate2") + every { + mockAndroidKeyStore.getCertificateChain(mockAlias) + } returns arrayOf(mockCertificate1, mockCertificate2) + every { + mockAndroidKeyStore.getKey(mockAlias, null) + } returns null + + assertNull( + keyDiskSource.getMutualTlsCertificateChain( + alias = mockAlias, + host = MutualTlsKeyHost.ANDROID_KEY_STORE, + ), + ) + } + + @Suppress("MaxLineLength") + @Test + fun `getMutualTlsCertificateChain should return null when using ANDROID KEY STORE and certificate chain is invalid`() { + setupMockAndroidKeyStore() + val mockAlias = "mockAlias" + every { + mockAndroidKeyStore.getKey(mockAlias, null) + } returns mockk() + + // Verify null is returned when certificate chain is empty + every { + mockAndroidKeyStore.getCertificateChain(mockAlias) + } returns emptyArray() + assertNull( + keyDiskSource.getMutualTlsCertificateChain( + alias = mockAlias, + host = MutualTlsKeyHost.ANDROID_KEY_STORE, + ), + ) + + // Verify null is returned when certificate chain contains non-X509Certificate objects + every { + mockAndroidKeyStore.getCertificateChain(mockAlias) + } returns arrayOf(mockk()) + assertNull( + keyDiskSource.getMutualTlsCertificateChain( + alias = mockAlias, + host = MutualTlsKeyHost.ANDROID_KEY_STORE, + ), + ) + } + + @Suppress("MaxLineLength") + @Test + fun `getMutualTlsCertificateChain should return null when using ANDROID KEY STORE and an exception occurs`() { + setupMockAndroidKeyStore() + val mockAlias = "mockAlias" + val mockCertificate1 = mockk(name = "mockCertificate1") + val mockCertificate2 = mockk(name = "mockCertificate2") + every { + mockAndroidKeyStore.getCertificateChain(mockAlias) + } returns arrayOf(mockCertificate1, mockCertificate2) + + // Verify KeyStoreException is handled + every { + mockAndroidKeyStore.getKey(mockAlias, null) + } throws KeyStoreException() + + assertNull( + keyDiskSource.getMutualTlsCertificateChain( + alias = mockAlias, + host = MutualTlsKeyHost.ANDROID_KEY_STORE, + ), + ) + + // Verify UnrecoverableKeyException is handled + every { + mockAndroidKeyStore.getKey(mockAlias, null) + } throws UnrecoverableKeyException() + assertNull( + keyDiskSource.getMutualTlsCertificateChain( + alias = mockAlias, + host = MutualTlsKeyHost.ANDROID_KEY_STORE, + ), + ) + + // Verify NoSuchAlgorithmException is handled + every { + mockAndroidKeyStore.getKey(mockAlias, null) + } throws NoSuchAlgorithmException() + assertNull( + keyDiskSource.getMutualTlsCertificateChain( + alias = mockAlias, + host = MutualTlsKeyHost.ANDROID_KEY_STORE, + ), + ) + } + + @Suppress("MaxLineLength") + @Test + fun `getMutualTlsCertificateChain should return MutualTlsCertificateChain when using KEY CHAIN and key is found`() { + val mockAlias = "mockAlias" + val mockPrivateKey = mockk() + val mockCertificate1 = mockk(name = "mockCertificate1") + val mockCertificate2 = mockk(name = "mockCertificate2") + every { + KeyChain.getCertificateChain(mockContext, mockAlias) + } returns arrayOf(mockCertificate1, mockCertificate2) + every { + KeyChain.getPrivateKey(mockContext, mockAlias) + } returns mockPrivateKey + + val result = keyDiskSource.getMutualTlsCertificateChain( + alias = mockAlias, + host = MutualTlsKeyHost.KEY_CHAIN, + ) + + assertEquals( + MutualTlsCertificate( + alias = mockAlias, + certificateChain = listOf(mockCertificate1, mockCertificate2), + privateKey = mockPrivateKey, + ), + result, + ) + } + + @Suppress("MaxLineLength") + @Test + fun `getMutualTlsCertificateChain should return null when using KEY CHAIN and key is not found`() { + val mockAlias = "mockAlias" + val mockCertificate1 = mockk(name = "mockCertificate1") + val mockCertificate2 = mockk(name = "mockCertificate2") + every { + KeyChain.getCertificateChain(mockContext, mockAlias) + } returns arrayOf(mockCertificate1, mockCertificate2) + every { + KeyChain.getPrivateKey(mockContext, mockAlias) + } returns null + + assertNull( + keyDiskSource.getMutualTlsCertificateChain( + alias = mockAlias, + host = MutualTlsKeyHost.KEY_CHAIN, + ), + ) + } + + @Suppress("MaxLineLength") + @Test + fun `getMutualTlsCertificateChain should return null when using KEY CHAIN and an exception occurs`() { + val mockAlias = "mockAlias" + val mockCertificate1 = mockk(name = "mockCertificate1") + val mockCertificate2 = mockk(name = "mockCertificate2") + + every { + KeyChain.getCertificateChain(mockContext, mockAlias) + } returns arrayOf(mockCertificate1, mockCertificate2) + + // Verify KeyChainException from getPrivateKey is handled + every { + KeyChain.getPrivateKey(mockContext, mockAlias) + } throws KeyChainException() + assertNull( + keyDiskSource.getMutualTlsCertificateChain( + alias = mockAlias, + host = MutualTlsKeyHost.KEY_CHAIN, + ), + ) + + // Verify KeyChainException from getCertificateChain is handled + every { KeyChain.getPrivateKey(mockContext, mockAlias) } returns mockk() + every { KeyChain.getCertificateChain(mockContext, mockAlias) } throws KeyChainException() + assertNull( + keyDiskSource.getMutualTlsCertificateChain( + alias = mockAlias, + host = MutualTlsKeyHost.KEY_CHAIN, + ), + ) + } + + @Suppress("MaxLineLength") + @Test + fun `removeMutualTlsKey should remove key from AndroidKeyStore when host is ANDROID_KEY_STORE`() { + setupMockAndroidKeyStore() + val mockAlias = "mockAlias" + + every { mockAndroidKeyStore.deleteEntry(mockAlias) } just runs + + keyDiskSource.removeMutualTlsKey( + alias = mockAlias, + host = MutualTlsKeyHost.ANDROID_KEY_STORE, + ) + + verify { + mockAndroidKeyStore.deleteEntry(mockAlias) + } + } + + @Test + fun `removeMutualTlsKey should do nothing when host is KEY_CHAIN`() { + keyDiskSource.removeMutualTlsKey( + alias = "mockAlias", + host = MutualTlsKeyHost.KEY_CHAIN, + ) + + verify(exactly = 0) { + mockAndroidKeyStore.deleteEntry(any()) + } + } + + @Test + fun `importMutualTlsCertificate should return Success when key is imported successfully`() { + setupMockAndroidKeyStore() + setupMockPkcs12KeyStore() + val expectedAlias = "mockAlias" + val internalAlias = "mockInternalAlias" + val privateKey = mockk() + val certChain = arrayOf(mockk()) + val pkcs12Bytes = "key.p12".toByteArray() + val password = "password" + every { mockPkcs12KeyStore.aliases() } returns mockk { + every { hasMoreElements() } returns true + every { nextElement() } returns internalAlias + } + every { + mockPkcs12KeyStore.setKeyEntry( + internalAlias, + privateKey, + null, + certChain, + ) + } just runs + every { + mockPkcs12KeyStore.getKey( + internalAlias, + password.toCharArray(), + ) + } returns privateKey + every { + mockPkcs12KeyStore.getCertificateChain(internalAlias) + } returns certChain + every { + mockAndroidKeyStore.containsAlias(expectedAlias) + } returns false + every { + mockAndroidKeyStore.setKeyEntry(expectedAlias, privateKey, null, certChain) + } just runs + + assertEquals( + ImportPrivateKeyResult.Success(alias = expectedAlias), + keyDiskSource.importMutualTlsCertificate( + key = pkcs12Bytes, + alias = expectedAlias, + password = password, + ), + ) + } + + @Test + fun `importMutualTlsCertificate should return Error when loading PKCS12 throws an exception`() { + setupMockPkcs12KeyStore() + val expectedAlias = "mockAlias" + val pkcs12Bytes = "key.p12".toByteArray() + val password = "password" + + // Verify KeyStoreException is handled + every { + mockPkcs12KeyStore.load(any(), any()) + } throws KeyStoreException() + assertEquals( + ImportPrivateKeyResult.Error.UnsupportedKey, + keyDiskSource.importMutualTlsCertificate( + key = pkcs12Bytes, + alias = expectedAlias, + password = password, + ), + ) { "KeyStoreException was not handled correctly" } + + // Verify IOException is handled + every { + mockPkcs12KeyStore.load(any(), any()) + } throws IOException() + assertEquals( + ImportPrivateKeyResult.Error.KeyStoreOperationFailed, + keyDiskSource.importMutualTlsCertificate( + key = pkcs12Bytes, + alias = expectedAlias, + password = password, + ), + ) { "IOException was not handled correctly" } + + // Verify IOException with UnrecoverableKeyException cause is handled + every { + mockPkcs12KeyStore.load(any(), any()) + } throws IOException(UnrecoverableKeyException()) + + assertEquals( + ImportPrivateKeyResult.Error.UnrecoverableKey, + keyDiskSource.importMutualTlsCertificate( + key = pkcs12Bytes, + alias = expectedAlias, + password = password, + ), + ) + + // Verify IOException with unexpected cause is handled + every { + mockPkcs12KeyStore.load(any(), any()) + } throws IOException(Exception()) + assertEquals( + ImportPrivateKeyResult.Error.KeyStoreOperationFailed, + keyDiskSource.importMutualTlsCertificate( + key = pkcs12Bytes, + alias = expectedAlias, + password = password, + ), + ) { "IOException with Unexpected exception cause was not handled correctly" } + + // Verify CertificateException is handled + every { + mockPkcs12KeyStore.load(any(), any()) + } throws CertificateException() + assertEquals( + ImportPrivateKeyResult.Error.InvalidCertificateChain, + keyDiskSource.importMutualTlsCertificate( + key = pkcs12Bytes, + alias = expectedAlias, + password = password, + ), + ) { "CertificateException was not handled correctly" } + + // Verify NoSuchAlgorithmException is handled + every { + mockPkcs12KeyStore.load(any(), any()) + } throws NoSuchAlgorithmException() + assertEquals( + ImportPrivateKeyResult.Error.UnsupportedKey, + keyDiskSource.importMutualTlsCertificate( + key = pkcs12Bytes, + alias = expectedAlias, + password = password, + ), + ) { "NoSuchAlgorithmException was not handled correctly" } + } + + @Test + fun `importMutualTlsCertificate should return UnsupportedKey when key store is empty`() { + setupMockPkcs12KeyStore() + val expectedAlias = "mockAlias" + val pkcs12Bytes = "key.p12".toByteArray() + val password = "password" + + every { mockPkcs12KeyStore.aliases() } returns mockk { + every { hasMoreElements() } returns false + } + + assertEquals( + ImportPrivateKeyResult.Error.UnsupportedKey, + keyDiskSource.importMutualTlsCertificate( + key = pkcs12Bytes, + alias = expectedAlias, + password = password, + ), + ) + } + + @Suppress("MaxLineLength") + @Test + fun `importMutualTlsCertificate should return UnrecoverableKey when unable to retrieve private key`() { + setupMockPkcs12KeyStore() + val expectedAlias = "mockAlias" + val pkcs12Bytes = "key.p12".toByteArray() + val password = "password" + + every { + mockPkcs12KeyStore.aliases() + } returns mockk { + every { hasMoreElements() } returns true + every { nextElement() } returns "mockInternalAlias" + } + every { + mockPkcs12KeyStore.getKey( + "mockInternalAlias", + password.toCharArray(), + ) + } throws UnrecoverableKeyException() + + assertEquals( + ImportPrivateKeyResult.Error.UnrecoverableKey, + keyDiskSource.importMutualTlsCertificate( + key = pkcs12Bytes, + alias = expectedAlias, + password = password, + ), + ) + + every { + mockPkcs12KeyStore.getKey( + "mockInternalAlias", + password.toCharArray(), + ) + } returns null + assertEquals( + ImportPrivateKeyResult.Error.UnrecoverableKey, + keyDiskSource.importMutualTlsCertificate( + key = pkcs12Bytes, + alias = expectedAlias, + password = password, + ), + ) + } + + @Suppress("MaxLineLength") + @Test + fun `importMutualTlsCertificate should return InvalidCertificateChain when certificate chain is empty`() { + setupMockPkcs12KeyStore() + val expectedAlias = "mockAlias" + val pkcs12Bytes = "key.p12".toByteArray() + val password = "password" + + every { mockPkcs12KeyStore.aliases() } returns mockk { + every { hasMoreElements() } returns true + every { nextElement() } returns "mockInternalAlias" + } + every { + mockPkcs12KeyStore.getKey( + "mockInternalAlias", + password.toCharArray(), + ) + } returns mockk() + + // Verify empty certificate chain is handled + every { + mockPkcs12KeyStore.getCertificateChain("mockInternalAlias") + } returns emptyArray() + assertEquals( + ImportPrivateKeyResult.Error.InvalidCertificateChain, + keyDiskSource.importMutualTlsCertificate( + key = pkcs12Bytes, + alias = expectedAlias, + password = password, + ), + ) + + // Verify null certificate chain is handled + every { + mockPkcs12KeyStore.getCertificateChain("mockInternalAlias") + } returns null + assertEquals( + ImportPrivateKeyResult.Error.InvalidCertificateChain, + keyDiskSource.importMutualTlsCertificate( + key = pkcs12Bytes, + alias = expectedAlias, + password = password, + ), + ) + } + + @Suppress("MaxLineLength") + @Test + fun `importMutualTlsCertificate should return KeyStoreOperationFailed when saving to Android KeyStore throws KeyStoreException`() { + setupMockAndroidKeyStore() + val expectedAlias = "mockAlias" + val pkcs12Bytes = "key.p12".toByteArray() + val password = "password" + + every { mockPkcs12KeyStore.aliases() } returns mockk { + every { hasMoreElements() } returns true + every { nextElement() } returns "mockInternalAlias" + } + + every { + mockPkcs12KeyStore.getKey( + "mockInternalAlias", + password.toCharArray(), + ) + } returns mockk() + every { + mockPkcs12KeyStore.getCertificateChain("mockInternalAlias") + } returns arrayOf(mockk()) + + every { + mockAndroidKeyStore.setKeyEntry( + expectedAlias, + any(), + any(), + any(), + ) + } throws KeyStoreException() + + assertEquals( + ImportPrivateKeyResult.Error.KeyStoreOperationFailed, + keyDiskSource.importMutualTlsCertificate( + key = pkcs12Bytes, + alias = expectedAlias, + password = password, + ), + ) + } + + @Suppress("MaxLineLength") + @Test + fun `importMutualTlsCertificate should return DuplicateAlias when alias already exists in AndroidKeyStore`() { + setupMockAndroidKeyStore() + setupMockPkcs12KeyStore() + val expectedAlias = "mockAlias" + val pkcs12Bytes = "key.p12".toByteArray() + val password = "password" + + every { mockPkcs12KeyStore.aliases() } returns mockk { + every { hasMoreElements() } returns true + every { nextElement() } returns "mockInternalAlias" + } + + every { + mockPkcs12KeyStore.getKey( + "mockInternalAlias", + password.toCharArray(), + ) + } returns mockk() + every { + mockPkcs12KeyStore.getCertificateChain("mockInternalAlias") + } returns arrayOf(mockk()) + + every { mockAndroidKeyStore.containsAlias(expectedAlias) } returns true + + assertEquals( + ImportPrivateKeyResult.Error.DuplicateAlias, + keyDiskSource.importMutualTlsCertificate( + key = pkcs12Bytes, + alias = expectedAlias, + password = password, + ), + ) + } + + private fun setupMockAndroidKeyStore() { + every { KeyStore.getInstance("AndroidKeyStore") } returns mockAndroidKeyStore + every { mockAndroidKeyStore.load(null) } just runs + } + + private fun setupMockPkcs12KeyStore() { + every { KeyStore.getInstance("pkcs12") } returns mockPkcs12KeyStore + every { mockPkcs12KeyStore.load(any(), any()) } just runs + } +}