[PM-17424] Implement KeyManager for handling private keys (#4608)

This commit is contained in:
Patrick Honkonen
2025-01-24 15:55:33 -05:00
committed by GitHub
parent 3a6db38172
commit 464f8de5f5
7 changed files with 979 additions and 0 deletions

View File

@@ -0,0 +1,45 @@
package com.x8bit.bitwarden.data.platform.datasource.disk.model
/**
* Models the result of importing a private key.
*/
sealed class ImportPrivateKeyResult {
/**
* Represents a successful result of importing a private key.
*
* @property alias The alias assigned to the imported private key.
*/
data class Success(val alias: String) : ImportPrivateKeyResult()
/**
* Represents a generic error during the import process.
*/
sealed class Error : ImportPrivateKeyResult() {
/**
* Indicates that the provided key is unrecoverable or the password is incorrect.
*/
data object UnrecoverableKey : Error()
/**
* Indicates that the certificate chain associated with the key is invalid.
*/
data object InvalidCertificateChain : Error()
/**
* Indicates that the specified alias is already in use.
*/
data object DuplicateAlias : Error()
/**
* Indicates that an error occurred during the key store operation.
*/
data object KeyStoreOperationFailed : Error()
/**
* Indicates the provided key is not supported.
*/
data object UnsupportedKey : Error()
}
}

View File

@@ -0,0 +1,36 @@
package com.x8bit.bitwarden.data.platform.datasource.disk.model
import java.security.PrivateKey
import java.security.cert.X509Certificate
/**
* Represents a mutual TLS certificate.
*/
data class MutualTlsCertificate(
val alias: String,
val privateKey: PrivateKey,
val certificateChain: List<X509Certificate>,
) {
/**
* Leaf certificate of the chain.
*/
val leafCertificate: X509Certificate?
get() = certificateChain.lastOrNull()
/**
* Root certificate of the chain.
*/
val rootCertificate: X509Certificate?
get() = certificateChain.firstOrNull()
override fun toString(): String = leafCertificate
?.let {
buildString {
appendLine("Subject: ${it.subjectDN}")
appendLine("Issuer: ${it.issuerDN}")
appendLine("Valid From: ${it.notBefore}")
appendLine("Valid Until: ${it.notAfter}")
}
}
?: ""
}

View File

@@ -0,0 +1,16 @@
package com.x8bit.bitwarden.data.platform.datasource.disk.model
/**
* Location of the key data.
*/
enum class MutualTlsKeyHost {
/**
* Key is stored in the system key chain.
*/
KEY_CHAIN,
/**
* Key is stored in a private instance of the Android Key Store.
*/
ANDROID_KEY_STORE,
}

View File

@@ -0,0 +1,37 @@
package com.x8bit.bitwarden.data.platform.manager
import com.x8bit.bitwarden.data.platform.datasource.disk.model.ImportPrivateKeyResult
import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsCertificate
import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsKeyHost
/**
* Primary access point for disk information related to key data.
*/
interface KeyManager {
/**
* Import a private key into the application KeyStore.
*
* @param key The private key to be saved.
* @param alias Alias to be assigned to the private key.
* @param password Password used to protect the certificate.
*/
fun importMutualTlsCertificate(
key: ByteArray,
alias: String,
password: String,
): ImportPrivateKeyResult
/**
* Removes the mTLS key from storage.
*/
fun removeMutualTlsKey(alias: String, host: MutualTlsKeyHost)
/**
* Retrieve the certificate chain for the selected mTLS key.
*/
fun getMutualTlsCertificateChain(
alias: String,
host: MutualTlsKeyHost,
): MutualTlsCertificate?
}

View File

@@ -0,0 +1,188 @@
package com.x8bit.bitwarden.data.platform.manager
import android.content.Context
import android.security.KeyChain
import android.security.KeyChainException
import com.x8bit.bitwarden.data.platform.datasource.disk.model.ImportPrivateKeyResult
import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsCertificate
import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsKeyHost
import timber.log.Timber
import java.io.IOException
import java.security.KeyStore
import java.security.KeyStoreException
import java.security.NoSuchAlgorithmException
import java.security.PrivateKey
import java.security.UnrecoverableKeyException
import java.security.cert.Certificate
import java.security.cert.CertificateException
import java.security.cert.X509Certificate
/**
* Default implementation of [KeyManager].
*/
class KeyManagerImpl(
private val context: Context,
) : KeyManager {
@Suppress("CyclomaticComplexMethod")
override fun importMutualTlsCertificate(
key: ByteArray,
alias: String,
password: String,
): ImportPrivateKeyResult {
// Step 1: Load PKCS12 bytes into a KeyStore.
val pkcs12KeyStore: KeyStore = key
.inputStream()
.use { stream ->
try {
KeyStore.getInstance(KEYSTORE_TYPE_PKCS12)
.also { it.load(stream, password.toCharArray()) }
} catch (e: KeyStoreException) {
Timber.Forest.e(e, "Failed to load PKCS12 bytes")
return ImportPrivateKeyResult.Error.UnsupportedKey
} catch (e: IOException) {
Timber.Forest.e(e, "Format or password error while loading PKCS12 bytes")
return when (e.cause) {
is UnrecoverableKeyException -> {
ImportPrivateKeyResult.Error.UnrecoverableKey
}
else -> {
ImportPrivateKeyResult.Error.KeyStoreOperationFailed
}
}
} catch (e: CertificateException) {
Timber.Forest.e(e, "Unable to load certificate chain")
return ImportPrivateKeyResult.Error.InvalidCertificateChain
} catch (e: NoSuchAlgorithmException) {
Timber.Forest.e(e, "Cryptographic algorithm not supported")
return ImportPrivateKeyResult.Error.UnsupportedKey
}
}
// Step 2: Get a list of aliases and choose the first one.
val internalAlias = pkcs12KeyStore.aliases()
?.takeIf { it.hasMoreElements() }
?.nextElement()
?: return ImportPrivateKeyResult.Error.UnsupportedKey
// Step 3: Extract PrivateKey and X.509 certificate from the KeyStore and verify
// certificate alias.
val privateKey = try {
pkcs12KeyStore.getKey(internalAlias, password.toCharArray())
?: return ImportPrivateKeyResult.Error.UnrecoverableKey
} catch (e: UnrecoverableKeyException) {
Timber.Forest.e(e, "Failed to get private key")
return ImportPrivateKeyResult.Error.UnrecoverableKey
}
val certChain: Array<Certificate> = pkcs12KeyStore
.getCertificateChain(internalAlias)
?.takeUnless { it.isEmpty() }
?: return ImportPrivateKeyResult.Error.InvalidCertificateChain
// Step 4: Store the private key and X.509 certificate in the AndroidKeyStore if the alias
// does not exists.
with(androidKeyStore) {
if (containsAlias(alias)) {
return ImportPrivateKeyResult.Error.DuplicateAlias
}
try {
setKeyEntry(alias, privateKey, null, certChain)
} catch (e: KeyStoreException) {
Timber.Forest.e(e, "Failed to import key into Android KeyStore")
return ImportPrivateKeyResult.Error.KeyStoreOperationFailed
}
}
return ImportPrivateKeyResult.Success(alias)
}
override fun removeMutualTlsKey(
alias: String,
host: MutualTlsKeyHost,
) {
when (host) {
MutualTlsKeyHost.ANDROID_KEY_STORE -> removeKeyFromAndroidKeyStore(alias)
else -> Unit
}
}
override fun getMutualTlsCertificateChain(
alias: String,
host: MutualTlsKeyHost,
): MutualTlsCertificate? = when (host) {
MutualTlsKeyHost.ANDROID_KEY_STORE -> getKeyFromAndroidKeyStore(alias)
MutualTlsKeyHost.KEY_CHAIN -> getSystemKeySpecOrNull(alias)
}
private fun removeKeyFromAndroidKeyStore(alias: String) {
try {
androidKeyStore.deleteEntry(alias)
} catch (e: KeyStoreException) {
Timber.Forest.e(e, "Failed to remove key from Android KeyStore")
}
}
private fun getSystemKeySpecOrNull(alias: String): MutualTlsCertificate? {
val systemPrivateKey = try {
KeyChain.getPrivateKey(context, alias)
} catch (e: KeyChainException) {
Timber.Forest.e(e, "Requested alias not found in system KeyChain")
null
}
?: return null
val systemCertificateChain = try {
KeyChain.getCertificateChain(context, alias)
} catch (e: KeyChainException) {
Timber.Forest.e(e, "Unable to access certificate chain for provided alias")
null
}
?: return null
return MutualTlsCertificate(
alias = alias,
certificateChain = systemCertificateChain.toList(),
privateKey = systemPrivateKey,
)
}
private fun getKeyFromAndroidKeyStore(alias: String): MutualTlsCertificate? =
with(androidKeyStore) {
try {
val privateKeyRef = (getKey(alias, null) as? PrivateKey)
?: return null
val certChain = getCertificateChain(alias)
.mapNotNull { it as? X509Certificate }
.takeUnless { it.isEmpty() }
?: return null
MutualTlsCertificate(
alias = alias,
certificateChain = certChain,
privateKey = privateKeyRef,
)
} catch (e: KeyStoreException) {
Timber.Forest.e(e, "Failed to load Android KeyStore")
null
} catch (e: UnrecoverableKeyException) {
Timber.Forest.e(e, "Failed to load client certificate from Android KeyStore")
null
} catch (e: NoSuchAlgorithmException) {
Timber.Forest.e(e, "Key cannot be recovered. Password may be incorrect.")
null
} catch (e: NoSuchAlgorithmException) {
Timber.Forest.e(e, "Algorithm not supported")
null
}
}
private val androidKeyStore
get() = KeyStore
.getInstance(KEYSTORE_TYPE_ANDROID)
.also { it.load(null) }
}
private const val KEYSTORE_TYPE_ANDROID = "AndroidKeyStore"
private const val KEYSTORE_TYPE_PKCS12 = "pkcs12"

View File

@@ -28,6 +28,8 @@ import com.x8bit.bitwarden.data.platform.manager.FeatureFlagManager
import com.x8bit.bitwarden.data.platform.manager.FeatureFlagManagerImpl
import com.x8bit.bitwarden.data.platform.manager.FirstTimeActionManager
import com.x8bit.bitwarden.data.platform.manager.FirstTimeActionManagerImpl
import com.x8bit.bitwarden.data.platform.manager.KeyManager
import com.x8bit.bitwarden.data.platform.manager.KeyManagerImpl
import com.x8bit.bitwarden.data.platform.manager.LogsManager
import com.x8bit.bitwarden.data.platform.manager.LogsManagerImpl
import com.x8bit.bitwarden.data.platform.manager.PolicyManager
@@ -329,4 +331,10 @@ object PlatformManagerModule {
autofillEnabledManager = autofillEnabledManager,
accessibilityEnabledManager = accessibilityEnabledManager,
)
@Provides
@Singleton
fun provideKeyManager(
@ApplicationContext context: Context,
): KeyManager = KeyManagerImpl(context = context)
}

View File

@@ -0,0 +1,649 @@
package com.x8bit.bitwarden.data.platform.manager
import android.content.Context
import android.security.KeyChain
import android.security.KeyChainException
import com.x8bit.bitwarden.data.platform.datasource.disk.model.ImportPrivateKeyResult
import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsCertificate
import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsKeyHost
import io.mockk.every
import io.mockk.just
import io.mockk.mockk
import io.mockk.mockkStatic
import io.mockk.runs
import io.mockk.unmockkStatic
import io.mockk.verify
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertNull
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import java.io.IOException
import java.security.KeyStore
import java.security.KeyStoreException
import java.security.NoSuchAlgorithmException
import java.security.PrivateKey
import java.security.UnrecoverableKeyException
import java.security.cert.Certificate
import java.security.cert.CertificateException
import java.security.cert.X509Certificate
class KeyManagerTest {
private val mockContext = mockk<Context>()
private val mockAndroidKeyStore = mockk<KeyStore>(name = "MockAndroidKeyStore")
private val mockPkcs12KeyStore = mockk<KeyStore>(name = "MockPKCS12KeyStore")
private val keyDiskSource = KeyManagerImpl(
context = mockContext,
)
@BeforeEach
fun setUp() {
mockkStatic(KeyStore::class, KeyChain::class)
}
@AfterEach
fun tearDown() {
unmockkStatic(KeyStore::class, KeyChain::class)
}
@Test
fun `getMutualTlsCertificateChain should return null when MutualTlsKeyAlias is not found`() {
// Verify null is returned when alias is not found in KeyChain
setupMockAndroidKeyStore()
every { KeyChain.getPrivateKey(mockContext, "mockAlias") } throws KeyChainException()
assertNull(
keyDiskSource.getMutualTlsCertificateChain(
alias = "mockAlias",
host = MutualTlsKeyHost.KEY_CHAIN,
),
)
// Verify null is returned when alias is not found in AndroidKeyStore
every { mockAndroidKeyStore.getKey("mockAlias", null) } throws UnrecoverableKeyException()
assertNull(
keyDiskSource.getMutualTlsCertificateChain(
alias = "mockAlias",
host = MutualTlsKeyHost.ANDROID_KEY_STORE,
),
)
}
@Suppress("MaxLineLength")
@Test
fun `getMutualTlsCertificateChain should return MutualTlsCertificateChain when using ANDROID KEY STORE and key is found`() {
setupMockAndroidKeyStore()
val mockAlias = "mockAlias"
val mockPrivateKey = mockk<PrivateKey>()
val mockCertificate1 = mockk<X509Certificate>(name = "mockCertificate1")
val mockCertificate2 = mockk<X509Certificate>(name = "mockCertificate2")
every {
mockAndroidKeyStore.getCertificateChain(mockAlias)
} returns arrayOf(mockCertificate1, mockCertificate2)
every {
mockAndroidKeyStore.getKey(mockAlias, null)
} returns mockPrivateKey
val result = keyDiskSource.getMutualTlsCertificateChain(
alias = mockAlias,
host = MutualTlsKeyHost.ANDROID_KEY_STORE,
)
assertEquals(
MutualTlsCertificate(
alias = mockAlias,
certificateChain = listOf(mockCertificate1, mockCertificate2),
privateKey = mockPrivateKey,
),
result,
)
}
@Suppress("MaxLineLength")
@Test
fun `getMutualTlsCertificateChain should return null when using ANDROID KEY STORE and key is not found`() {
setupMockAndroidKeyStore()
val mockAlias = "mockAlias"
val mockCertificate1 = mockk<X509Certificate>(name = "mockCertificate1")
val mockCertificate2 = mockk<X509Certificate>(name = "mockCertificate2")
every {
mockAndroidKeyStore.getCertificateChain(mockAlias)
} returns arrayOf(mockCertificate1, mockCertificate2)
every {
mockAndroidKeyStore.getKey(mockAlias, null)
} returns null
assertNull(
keyDiskSource.getMutualTlsCertificateChain(
alias = mockAlias,
host = MutualTlsKeyHost.ANDROID_KEY_STORE,
),
)
}
@Suppress("MaxLineLength")
@Test
fun `getMutualTlsCertificateChain should return null when using ANDROID KEY STORE and certificate chain is invalid`() {
setupMockAndroidKeyStore()
val mockAlias = "mockAlias"
every {
mockAndroidKeyStore.getKey(mockAlias, null)
} returns mockk<PrivateKey>()
// Verify null is returned when certificate chain is empty
every {
mockAndroidKeyStore.getCertificateChain(mockAlias)
} returns emptyArray()
assertNull(
keyDiskSource.getMutualTlsCertificateChain(
alias = mockAlias,
host = MutualTlsKeyHost.ANDROID_KEY_STORE,
),
)
// Verify null is returned when certificate chain contains non-X509Certificate objects
every {
mockAndroidKeyStore.getCertificateChain(mockAlias)
} returns arrayOf(mockk<Certificate>())
assertNull(
keyDiskSource.getMutualTlsCertificateChain(
alias = mockAlias,
host = MutualTlsKeyHost.ANDROID_KEY_STORE,
),
)
}
@Suppress("MaxLineLength")
@Test
fun `getMutualTlsCertificateChain should return null when using ANDROID KEY STORE and an exception occurs`() {
setupMockAndroidKeyStore()
val mockAlias = "mockAlias"
val mockCertificate1 = mockk<X509Certificate>(name = "mockCertificate1")
val mockCertificate2 = mockk<X509Certificate>(name = "mockCertificate2")
every {
mockAndroidKeyStore.getCertificateChain(mockAlias)
} returns arrayOf(mockCertificate1, mockCertificate2)
// Verify KeyStoreException is handled
every {
mockAndroidKeyStore.getKey(mockAlias, null)
} throws KeyStoreException()
assertNull(
keyDiskSource.getMutualTlsCertificateChain(
alias = mockAlias,
host = MutualTlsKeyHost.ANDROID_KEY_STORE,
),
)
// Verify UnrecoverableKeyException is handled
every {
mockAndroidKeyStore.getKey(mockAlias, null)
} throws UnrecoverableKeyException()
assertNull(
keyDiskSource.getMutualTlsCertificateChain(
alias = mockAlias,
host = MutualTlsKeyHost.ANDROID_KEY_STORE,
),
)
// Verify NoSuchAlgorithmException is handled
every {
mockAndroidKeyStore.getKey(mockAlias, null)
} throws NoSuchAlgorithmException()
assertNull(
keyDiskSource.getMutualTlsCertificateChain(
alias = mockAlias,
host = MutualTlsKeyHost.ANDROID_KEY_STORE,
),
)
}
@Suppress("MaxLineLength")
@Test
fun `getMutualTlsCertificateChain should return MutualTlsCertificateChain when using KEY CHAIN and key is found`() {
val mockAlias = "mockAlias"
val mockPrivateKey = mockk<PrivateKey>()
val mockCertificate1 = mockk<X509Certificate>(name = "mockCertificate1")
val mockCertificate2 = mockk<X509Certificate>(name = "mockCertificate2")
every {
KeyChain.getCertificateChain(mockContext, mockAlias)
} returns arrayOf(mockCertificate1, mockCertificate2)
every {
KeyChain.getPrivateKey(mockContext, mockAlias)
} returns mockPrivateKey
val result = keyDiskSource.getMutualTlsCertificateChain(
alias = mockAlias,
host = MutualTlsKeyHost.KEY_CHAIN,
)
assertEquals(
MutualTlsCertificate(
alias = mockAlias,
certificateChain = listOf(mockCertificate1, mockCertificate2),
privateKey = mockPrivateKey,
),
result,
)
}
@Suppress("MaxLineLength")
@Test
fun `getMutualTlsCertificateChain should return null when using KEY CHAIN and key is not found`() {
val mockAlias = "mockAlias"
val mockCertificate1 = mockk<X509Certificate>(name = "mockCertificate1")
val mockCertificate2 = mockk<X509Certificate>(name = "mockCertificate2")
every {
KeyChain.getCertificateChain(mockContext, mockAlias)
} returns arrayOf(mockCertificate1, mockCertificate2)
every {
KeyChain.getPrivateKey(mockContext, mockAlias)
} returns null
assertNull(
keyDiskSource.getMutualTlsCertificateChain(
alias = mockAlias,
host = MutualTlsKeyHost.KEY_CHAIN,
),
)
}
@Suppress("MaxLineLength")
@Test
fun `getMutualTlsCertificateChain should return null when using KEY CHAIN and an exception occurs`() {
val mockAlias = "mockAlias"
val mockCertificate1 = mockk<X509Certificate>(name = "mockCertificate1")
val mockCertificate2 = mockk<X509Certificate>(name = "mockCertificate2")
every {
KeyChain.getCertificateChain(mockContext, mockAlias)
} returns arrayOf(mockCertificate1, mockCertificate2)
// Verify KeyChainException from getPrivateKey is handled
every {
KeyChain.getPrivateKey(mockContext, mockAlias)
} throws KeyChainException()
assertNull(
keyDiskSource.getMutualTlsCertificateChain(
alias = mockAlias,
host = MutualTlsKeyHost.KEY_CHAIN,
),
)
// Verify KeyChainException from getCertificateChain is handled
every { KeyChain.getPrivateKey(mockContext, mockAlias) } returns mockk()
every { KeyChain.getCertificateChain(mockContext, mockAlias) } throws KeyChainException()
assertNull(
keyDiskSource.getMutualTlsCertificateChain(
alias = mockAlias,
host = MutualTlsKeyHost.KEY_CHAIN,
),
)
}
@Suppress("MaxLineLength")
@Test
fun `removeMutualTlsKey should remove key from AndroidKeyStore when host is ANDROID_KEY_STORE`() {
setupMockAndroidKeyStore()
val mockAlias = "mockAlias"
every { mockAndroidKeyStore.deleteEntry(mockAlias) } just runs
keyDiskSource.removeMutualTlsKey(
alias = mockAlias,
host = MutualTlsKeyHost.ANDROID_KEY_STORE,
)
verify {
mockAndroidKeyStore.deleteEntry(mockAlias)
}
}
@Test
fun `removeMutualTlsKey should do nothing when host is KEY_CHAIN`() {
keyDiskSource.removeMutualTlsKey(
alias = "mockAlias",
host = MutualTlsKeyHost.KEY_CHAIN,
)
verify(exactly = 0) {
mockAndroidKeyStore.deleteEntry(any())
}
}
@Test
fun `importMutualTlsCertificate should return Success when key is imported successfully`() {
setupMockAndroidKeyStore()
setupMockPkcs12KeyStore()
val expectedAlias = "mockAlias"
val internalAlias = "mockInternalAlias"
val privateKey = mockk<PrivateKey>()
val certChain = arrayOf(mockk<X509Certificate>())
val pkcs12Bytes = "key.p12".toByteArray()
val password = "password"
every { mockPkcs12KeyStore.aliases() } returns mockk {
every { hasMoreElements() } returns true
every { nextElement() } returns internalAlias
}
every {
mockPkcs12KeyStore.setKeyEntry(
internalAlias,
privateKey,
null,
certChain,
)
} just runs
every {
mockPkcs12KeyStore.getKey(
internalAlias,
password.toCharArray(),
)
} returns privateKey
every {
mockPkcs12KeyStore.getCertificateChain(internalAlias)
} returns certChain
every {
mockAndroidKeyStore.containsAlias(expectedAlias)
} returns false
every {
mockAndroidKeyStore.setKeyEntry(expectedAlias, privateKey, null, certChain)
} just runs
assertEquals(
ImportPrivateKeyResult.Success(alias = expectedAlias),
keyDiskSource.importMutualTlsCertificate(
key = pkcs12Bytes,
alias = expectedAlias,
password = password,
),
)
}
@Test
fun `importMutualTlsCertificate should return Error when loading PKCS12 throws an exception`() {
setupMockPkcs12KeyStore()
val expectedAlias = "mockAlias"
val pkcs12Bytes = "key.p12".toByteArray()
val password = "password"
// Verify KeyStoreException is handled
every {
mockPkcs12KeyStore.load(any(), any())
} throws KeyStoreException()
assertEquals(
ImportPrivateKeyResult.Error.UnsupportedKey,
keyDiskSource.importMutualTlsCertificate(
key = pkcs12Bytes,
alias = expectedAlias,
password = password,
),
) { "KeyStoreException was not handled correctly" }
// Verify IOException is handled
every {
mockPkcs12KeyStore.load(any(), any())
} throws IOException()
assertEquals(
ImportPrivateKeyResult.Error.KeyStoreOperationFailed,
keyDiskSource.importMutualTlsCertificate(
key = pkcs12Bytes,
alias = expectedAlias,
password = password,
),
) { "IOException was not handled correctly" }
// Verify IOException with UnrecoverableKeyException cause is handled
every {
mockPkcs12KeyStore.load(any(), any())
} throws IOException(UnrecoverableKeyException())
assertEquals(
ImportPrivateKeyResult.Error.UnrecoverableKey,
keyDiskSource.importMutualTlsCertificate(
key = pkcs12Bytes,
alias = expectedAlias,
password = password,
),
)
// Verify IOException with unexpected cause is handled
every {
mockPkcs12KeyStore.load(any(), any())
} throws IOException(Exception())
assertEquals(
ImportPrivateKeyResult.Error.KeyStoreOperationFailed,
keyDiskSource.importMutualTlsCertificate(
key = pkcs12Bytes,
alias = expectedAlias,
password = password,
),
) { "IOException with Unexpected exception cause was not handled correctly" }
// Verify CertificateException is handled
every {
mockPkcs12KeyStore.load(any(), any())
} throws CertificateException()
assertEquals(
ImportPrivateKeyResult.Error.InvalidCertificateChain,
keyDiskSource.importMutualTlsCertificate(
key = pkcs12Bytes,
alias = expectedAlias,
password = password,
),
) { "CertificateException was not handled correctly" }
// Verify NoSuchAlgorithmException is handled
every {
mockPkcs12KeyStore.load(any(), any())
} throws NoSuchAlgorithmException()
assertEquals(
ImportPrivateKeyResult.Error.UnsupportedKey,
keyDiskSource.importMutualTlsCertificate(
key = pkcs12Bytes,
alias = expectedAlias,
password = password,
),
) { "NoSuchAlgorithmException was not handled correctly" }
}
@Test
fun `importMutualTlsCertificate should return UnsupportedKey when key store is empty`() {
setupMockPkcs12KeyStore()
val expectedAlias = "mockAlias"
val pkcs12Bytes = "key.p12".toByteArray()
val password = "password"
every { mockPkcs12KeyStore.aliases() } returns mockk {
every { hasMoreElements() } returns false
}
assertEquals(
ImportPrivateKeyResult.Error.UnsupportedKey,
keyDiskSource.importMutualTlsCertificate(
key = pkcs12Bytes,
alias = expectedAlias,
password = password,
),
)
}
@Suppress("MaxLineLength")
@Test
fun `importMutualTlsCertificate should return UnrecoverableKey when unable to retrieve private key`() {
setupMockPkcs12KeyStore()
val expectedAlias = "mockAlias"
val pkcs12Bytes = "key.p12".toByteArray()
val password = "password"
every {
mockPkcs12KeyStore.aliases()
} returns mockk {
every { hasMoreElements() } returns true
every { nextElement() } returns "mockInternalAlias"
}
every {
mockPkcs12KeyStore.getKey(
"mockInternalAlias",
password.toCharArray(),
)
} throws UnrecoverableKeyException()
assertEquals(
ImportPrivateKeyResult.Error.UnrecoverableKey,
keyDiskSource.importMutualTlsCertificate(
key = pkcs12Bytes,
alias = expectedAlias,
password = password,
),
)
every {
mockPkcs12KeyStore.getKey(
"mockInternalAlias",
password.toCharArray(),
)
} returns null
assertEquals(
ImportPrivateKeyResult.Error.UnrecoverableKey,
keyDiskSource.importMutualTlsCertificate(
key = pkcs12Bytes,
alias = expectedAlias,
password = password,
),
)
}
@Suppress("MaxLineLength")
@Test
fun `importMutualTlsCertificate should return InvalidCertificateChain when certificate chain is empty`() {
setupMockPkcs12KeyStore()
val expectedAlias = "mockAlias"
val pkcs12Bytes = "key.p12".toByteArray()
val password = "password"
every { mockPkcs12KeyStore.aliases() } returns mockk {
every { hasMoreElements() } returns true
every { nextElement() } returns "mockInternalAlias"
}
every {
mockPkcs12KeyStore.getKey(
"mockInternalAlias",
password.toCharArray(),
)
} returns mockk()
// Verify empty certificate chain is handled
every {
mockPkcs12KeyStore.getCertificateChain("mockInternalAlias")
} returns emptyArray()
assertEquals(
ImportPrivateKeyResult.Error.InvalidCertificateChain,
keyDiskSource.importMutualTlsCertificate(
key = pkcs12Bytes,
alias = expectedAlias,
password = password,
),
)
// Verify null certificate chain is handled
every {
mockPkcs12KeyStore.getCertificateChain("mockInternalAlias")
} returns null
assertEquals(
ImportPrivateKeyResult.Error.InvalidCertificateChain,
keyDiskSource.importMutualTlsCertificate(
key = pkcs12Bytes,
alias = expectedAlias,
password = password,
),
)
}
@Suppress("MaxLineLength")
@Test
fun `importMutualTlsCertificate should return KeyStoreOperationFailed when saving to Android KeyStore throws KeyStoreException`() {
setupMockAndroidKeyStore()
val expectedAlias = "mockAlias"
val pkcs12Bytes = "key.p12".toByteArray()
val password = "password"
every { mockPkcs12KeyStore.aliases() } returns mockk {
every { hasMoreElements() } returns true
every { nextElement() } returns "mockInternalAlias"
}
every {
mockPkcs12KeyStore.getKey(
"mockInternalAlias",
password.toCharArray(),
)
} returns mockk()
every {
mockPkcs12KeyStore.getCertificateChain("mockInternalAlias")
} returns arrayOf(mockk())
every {
mockAndroidKeyStore.setKeyEntry(
expectedAlias,
any(),
any(),
any(),
)
} throws KeyStoreException()
assertEquals(
ImportPrivateKeyResult.Error.KeyStoreOperationFailed,
keyDiskSource.importMutualTlsCertificate(
key = pkcs12Bytes,
alias = expectedAlias,
password = password,
),
)
}
@Suppress("MaxLineLength")
@Test
fun `importMutualTlsCertificate should return DuplicateAlias when alias already exists in AndroidKeyStore`() {
setupMockAndroidKeyStore()
setupMockPkcs12KeyStore()
val expectedAlias = "mockAlias"
val pkcs12Bytes = "key.p12".toByteArray()
val password = "password"
every { mockPkcs12KeyStore.aliases() } returns mockk {
every { hasMoreElements() } returns true
every { nextElement() } returns "mockInternalAlias"
}
every {
mockPkcs12KeyStore.getKey(
"mockInternalAlias",
password.toCharArray(),
)
} returns mockk()
every {
mockPkcs12KeyStore.getCertificateChain("mockInternalAlias")
} returns arrayOf(mockk())
every { mockAndroidKeyStore.containsAlias(expectedAlias) } returns true
assertEquals(
ImportPrivateKeyResult.Error.DuplicateAlias,
keyDiskSource.importMutualTlsCertificate(
key = pkcs12Bytes,
alias = expectedAlias,
password = password,
),
)
}
private fun setupMockAndroidKeyStore() {
every { KeyStore.getInstance("AndroidKeyStore") } returns mockAndroidKeyStore
every { mockAndroidKeyStore.load(null) } just runs
}
private fun setupMockPkcs12KeyStore() {
every { KeyStore.getInstance("pkcs12") } returns mockPkcs12KeyStore
every { mockPkcs12KeyStore.load(any(), any()) } just runs
}
}