mirror of
https://github.com/bitwarden/android.git
synced 2026-03-12 05:04:17 -05:00
[PM-17424] Implement KeyManager for handling private keys (#4608)
This commit is contained in:
@@ -0,0 +1,45 @@
|
||||
package com.x8bit.bitwarden.data.platform.datasource.disk.model
|
||||
|
||||
/**
|
||||
* Models the result of importing a private key.
|
||||
*/
|
||||
sealed class ImportPrivateKeyResult {
|
||||
|
||||
/**
|
||||
* Represents a successful result of importing a private key.
|
||||
*
|
||||
* @property alias The alias assigned to the imported private key.
|
||||
*/
|
||||
data class Success(val alias: String) : ImportPrivateKeyResult()
|
||||
|
||||
/**
|
||||
* Represents a generic error during the import process.
|
||||
*/
|
||||
sealed class Error : ImportPrivateKeyResult() {
|
||||
|
||||
/**
|
||||
* Indicates that the provided key is unrecoverable or the password is incorrect.
|
||||
*/
|
||||
data object UnrecoverableKey : Error()
|
||||
|
||||
/**
|
||||
* Indicates that the certificate chain associated with the key is invalid.
|
||||
*/
|
||||
data object InvalidCertificateChain : Error()
|
||||
|
||||
/**
|
||||
* Indicates that the specified alias is already in use.
|
||||
*/
|
||||
data object DuplicateAlias : Error()
|
||||
|
||||
/**
|
||||
* Indicates that an error occurred during the key store operation.
|
||||
*/
|
||||
data object KeyStoreOperationFailed : Error()
|
||||
|
||||
/**
|
||||
* Indicates the provided key is not supported.
|
||||
*/
|
||||
data object UnsupportedKey : Error()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
package com.x8bit.bitwarden.data.platform.datasource.disk.model
|
||||
|
||||
import java.security.PrivateKey
|
||||
import java.security.cert.X509Certificate
|
||||
|
||||
/**
|
||||
* Represents a mutual TLS certificate.
|
||||
*/
|
||||
data class MutualTlsCertificate(
|
||||
val alias: String,
|
||||
val privateKey: PrivateKey,
|
||||
val certificateChain: List<X509Certificate>,
|
||||
) {
|
||||
/**
|
||||
* Leaf certificate of the chain.
|
||||
*/
|
||||
val leafCertificate: X509Certificate?
|
||||
get() = certificateChain.lastOrNull()
|
||||
|
||||
/**
|
||||
* Root certificate of the chain.
|
||||
*/
|
||||
val rootCertificate: X509Certificate?
|
||||
get() = certificateChain.firstOrNull()
|
||||
|
||||
override fun toString(): String = leafCertificate
|
||||
?.let {
|
||||
buildString {
|
||||
appendLine("Subject: ${it.subjectDN}")
|
||||
appendLine("Issuer: ${it.issuerDN}")
|
||||
appendLine("Valid From: ${it.notBefore}")
|
||||
appendLine("Valid Until: ${it.notAfter}")
|
||||
}
|
||||
}
|
||||
?: ""
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package com.x8bit.bitwarden.data.platform.datasource.disk.model
|
||||
|
||||
/**
|
||||
* Location of the key data.
|
||||
*/
|
||||
enum class MutualTlsKeyHost {
|
||||
/**
|
||||
* Key is stored in the system key chain.
|
||||
*/
|
||||
KEY_CHAIN,
|
||||
|
||||
/**
|
||||
* Key is stored in a private instance of the Android Key Store.
|
||||
*/
|
||||
ANDROID_KEY_STORE,
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
package com.x8bit.bitwarden.data.platform.manager
|
||||
|
||||
import com.x8bit.bitwarden.data.platform.datasource.disk.model.ImportPrivateKeyResult
|
||||
import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsCertificate
|
||||
import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsKeyHost
|
||||
|
||||
/**
|
||||
* Primary access point for disk information related to key data.
|
||||
*/
|
||||
interface KeyManager {
|
||||
|
||||
/**
|
||||
* Import a private key into the application KeyStore.
|
||||
*
|
||||
* @param key The private key to be saved.
|
||||
* @param alias Alias to be assigned to the private key.
|
||||
* @param password Password used to protect the certificate.
|
||||
*/
|
||||
fun importMutualTlsCertificate(
|
||||
key: ByteArray,
|
||||
alias: String,
|
||||
password: String,
|
||||
): ImportPrivateKeyResult
|
||||
|
||||
/**
|
||||
* Removes the mTLS key from storage.
|
||||
*/
|
||||
fun removeMutualTlsKey(alias: String, host: MutualTlsKeyHost)
|
||||
|
||||
/**
|
||||
* Retrieve the certificate chain for the selected mTLS key.
|
||||
*/
|
||||
fun getMutualTlsCertificateChain(
|
||||
alias: String,
|
||||
host: MutualTlsKeyHost,
|
||||
): MutualTlsCertificate?
|
||||
}
|
||||
@@ -0,0 +1,188 @@
|
||||
package com.x8bit.bitwarden.data.platform.manager
|
||||
|
||||
import android.content.Context
|
||||
import android.security.KeyChain
|
||||
import android.security.KeyChainException
|
||||
import com.x8bit.bitwarden.data.platform.datasource.disk.model.ImportPrivateKeyResult
|
||||
import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsCertificate
|
||||
import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsKeyHost
|
||||
import timber.log.Timber
|
||||
import java.io.IOException
|
||||
import java.security.KeyStore
|
||||
import java.security.KeyStoreException
|
||||
import java.security.NoSuchAlgorithmException
|
||||
import java.security.PrivateKey
|
||||
import java.security.UnrecoverableKeyException
|
||||
import java.security.cert.Certificate
|
||||
import java.security.cert.CertificateException
|
||||
import java.security.cert.X509Certificate
|
||||
|
||||
/**
|
||||
* Default implementation of [KeyManager].
|
||||
*/
|
||||
class KeyManagerImpl(
|
||||
private val context: Context,
|
||||
) : KeyManager {
|
||||
|
||||
@Suppress("CyclomaticComplexMethod")
|
||||
override fun importMutualTlsCertificate(
|
||||
key: ByteArray,
|
||||
alias: String,
|
||||
password: String,
|
||||
): ImportPrivateKeyResult {
|
||||
// Step 1: Load PKCS12 bytes into a KeyStore.
|
||||
val pkcs12KeyStore: KeyStore = key
|
||||
.inputStream()
|
||||
.use { stream ->
|
||||
try {
|
||||
KeyStore.getInstance(KEYSTORE_TYPE_PKCS12)
|
||||
.also { it.load(stream, password.toCharArray()) }
|
||||
} catch (e: KeyStoreException) {
|
||||
Timber.Forest.e(e, "Failed to load PKCS12 bytes")
|
||||
return ImportPrivateKeyResult.Error.UnsupportedKey
|
||||
} catch (e: IOException) {
|
||||
Timber.Forest.e(e, "Format or password error while loading PKCS12 bytes")
|
||||
return when (e.cause) {
|
||||
is UnrecoverableKeyException -> {
|
||||
ImportPrivateKeyResult.Error.UnrecoverableKey
|
||||
}
|
||||
|
||||
else -> {
|
||||
ImportPrivateKeyResult.Error.KeyStoreOperationFailed
|
||||
}
|
||||
}
|
||||
} catch (e: CertificateException) {
|
||||
Timber.Forest.e(e, "Unable to load certificate chain")
|
||||
return ImportPrivateKeyResult.Error.InvalidCertificateChain
|
||||
} catch (e: NoSuchAlgorithmException) {
|
||||
Timber.Forest.e(e, "Cryptographic algorithm not supported")
|
||||
return ImportPrivateKeyResult.Error.UnsupportedKey
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: Get a list of aliases and choose the first one.
|
||||
val internalAlias = pkcs12KeyStore.aliases()
|
||||
?.takeIf { it.hasMoreElements() }
|
||||
?.nextElement()
|
||||
?: return ImportPrivateKeyResult.Error.UnsupportedKey
|
||||
|
||||
// Step 3: Extract PrivateKey and X.509 certificate from the KeyStore and verify
|
||||
// certificate alias.
|
||||
val privateKey = try {
|
||||
pkcs12KeyStore.getKey(internalAlias, password.toCharArray())
|
||||
?: return ImportPrivateKeyResult.Error.UnrecoverableKey
|
||||
} catch (e: UnrecoverableKeyException) {
|
||||
Timber.Forest.e(e, "Failed to get private key")
|
||||
return ImportPrivateKeyResult.Error.UnrecoverableKey
|
||||
}
|
||||
|
||||
val certChain: Array<Certificate> = pkcs12KeyStore
|
||||
.getCertificateChain(internalAlias)
|
||||
?.takeUnless { it.isEmpty() }
|
||||
?: return ImportPrivateKeyResult.Error.InvalidCertificateChain
|
||||
|
||||
// Step 4: Store the private key and X.509 certificate in the AndroidKeyStore if the alias
|
||||
// does not exists.
|
||||
with(androidKeyStore) {
|
||||
if (containsAlias(alias)) {
|
||||
return ImportPrivateKeyResult.Error.DuplicateAlias
|
||||
}
|
||||
|
||||
try {
|
||||
setKeyEntry(alias, privateKey, null, certChain)
|
||||
} catch (e: KeyStoreException) {
|
||||
Timber.Forest.e(e, "Failed to import key into Android KeyStore")
|
||||
return ImportPrivateKeyResult.Error.KeyStoreOperationFailed
|
||||
}
|
||||
}
|
||||
return ImportPrivateKeyResult.Success(alias)
|
||||
}
|
||||
|
||||
override fun removeMutualTlsKey(
|
||||
alias: String,
|
||||
host: MutualTlsKeyHost,
|
||||
) {
|
||||
when (host) {
|
||||
MutualTlsKeyHost.ANDROID_KEY_STORE -> removeKeyFromAndroidKeyStore(alias)
|
||||
else -> Unit
|
||||
}
|
||||
}
|
||||
|
||||
override fun getMutualTlsCertificateChain(
|
||||
alias: String,
|
||||
host: MutualTlsKeyHost,
|
||||
): MutualTlsCertificate? = when (host) {
|
||||
MutualTlsKeyHost.ANDROID_KEY_STORE -> getKeyFromAndroidKeyStore(alias)
|
||||
|
||||
MutualTlsKeyHost.KEY_CHAIN -> getSystemKeySpecOrNull(alias)
|
||||
}
|
||||
|
||||
private fun removeKeyFromAndroidKeyStore(alias: String) {
|
||||
try {
|
||||
androidKeyStore.deleteEntry(alias)
|
||||
} catch (e: KeyStoreException) {
|
||||
Timber.Forest.e(e, "Failed to remove key from Android KeyStore")
|
||||
}
|
||||
}
|
||||
|
||||
private fun getSystemKeySpecOrNull(alias: String): MutualTlsCertificate? {
|
||||
val systemPrivateKey = try {
|
||||
KeyChain.getPrivateKey(context, alias)
|
||||
} catch (e: KeyChainException) {
|
||||
Timber.Forest.e(e, "Requested alias not found in system KeyChain")
|
||||
null
|
||||
}
|
||||
?: return null
|
||||
|
||||
val systemCertificateChain = try {
|
||||
KeyChain.getCertificateChain(context, alias)
|
||||
} catch (e: KeyChainException) {
|
||||
Timber.Forest.e(e, "Unable to access certificate chain for provided alias")
|
||||
null
|
||||
}
|
||||
?: return null
|
||||
|
||||
return MutualTlsCertificate(
|
||||
alias = alias,
|
||||
certificateChain = systemCertificateChain.toList(),
|
||||
privateKey = systemPrivateKey,
|
||||
)
|
||||
}
|
||||
|
||||
private fun getKeyFromAndroidKeyStore(alias: String): MutualTlsCertificate? =
|
||||
with(androidKeyStore) {
|
||||
try {
|
||||
val privateKeyRef = (getKey(alias, null) as? PrivateKey)
|
||||
?: return null
|
||||
val certChain = getCertificateChain(alias)
|
||||
.mapNotNull { it as? X509Certificate }
|
||||
.takeUnless { it.isEmpty() }
|
||||
?: return null
|
||||
MutualTlsCertificate(
|
||||
alias = alias,
|
||||
certificateChain = certChain,
|
||||
privateKey = privateKeyRef,
|
||||
)
|
||||
} catch (e: KeyStoreException) {
|
||||
Timber.Forest.e(e, "Failed to load Android KeyStore")
|
||||
null
|
||||
} catch (e: UnrecoverableKeyException) {
|
||||
Timber.Forest.e(e, "Failed to load client certificate from Android KeyStore")
|
||||
null
|
||||
} catch (e: NoSuchAlgorithmException) {
|
||||
Timber.Forest.e(e, "Key cannot be recovered. Password may be incorrect.")
|
||||
null
|
||||
} catch (e: NoSuchAlgorithmException) {
|
||||
Timber.Forest.e(e, "Algorithm not supported")
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
private val androidKeyStore
|
||||
get() = KeyStore
|
||||
.getInstance(KEYSTORE_TYPE_ANDROID)
|
||||
.also { it.load(null) }
|
||||
}
|
||||
|
||||
private const val KEYSTORE_TYPE_ANDROID = "AndroidKeyStore"
|
||||
private const val KEYSTORE_TYPE_PKCS12 = "pkcs12"
|
||||
@@ -28,6 +28,8 @@ import com.x8bit.bitwarden.data.platform.manager.FeatureFlagManager
|
||||
import com.x8bit.bitwarden.data.platform.manager.FeatureFlagManagerImpl
|
||||
import com.x8bit.bitwarden.data.platform.manager.FirstTimeActionManager
|
||||
import com.x8bit.bitwarden.data.platform.manager.FirstTimeActionManagerImpl
|
||||
import com.x8bit.bitwarden.data.platform.manager.KeyManager
|
||||
import com.x8bit.bitwarden.data.platform.manager.KeyManagerImpl
|
||||
import com.x8bit.bitwarden.data.platform.manager.LogsManager
|
||||
import com.x8bit.bitwarden.data.platform.manager.LogsManagerImpl
|
||||
import com.x8bit.bitwarden.data.platform.manager.PolicyManager
|
||||
@@ -329,4 +331,10 @@ object PlatformManagerModule {
|
||||
autofillEnabledManager = autofillEnabledManager,
|
||||
accessibilityEnabledManager = accessibilityEnabledManager,
|
||||
)
|
||||
|
||||
@Provides
|
||||
@Singleton
|
||||
fun provideKeyManager(
|
||||
@ApplicationContext context: Context,
|
||||
): KeyManager = KeyManagerImpl(context = context)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,649 @@
|
||||
package com.x8bit.bitwarden.data.platform.manager
|
||||
|
||||
import android.content.Context
|
||||
import android.security.KeyChain
|
||||
import android.security.KeyChainException
|
||||
import com.x8bit.bitwarden.data.platform.datasource.disk.model.ImportPrivateKeyResult
|
||||
import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsCertificate
|
||||
import com.x8bit.bitwarden.data.platform.datasource.disk.model.MutualTlsKeyHost
|
||||
import io.mockk.every
|
||||
import io.mockk.just
|
||||
import io.mockk.mockk
|
||||
import io.mockk.mockkStatic
|
||||
import io.mockk.runs
|
||||
import io.mockk.unmockkStatic
|
||||
import io.mockk.verify
|
||||
import org.junit.jupiter.api.AfterEach
|
||||
import org.junit.jupiter.api.Assertions.assertEquals
|
||||
import org.junit.jupiter.api.Assertions.assertNull
|
||||
import org.junit.jupiter.api.BeforeEach
|
||||
import org.junit.jupiter.api.Test
|
||||
import java.io.IOException
|
||||
import java.security.KeyStore
|
||||
import java.security.KeyStoreException
|
||||
import java.security.NoSuchAlgorithmException
|
||||
import java.security.PrivateKey
|
||||
import java.security.UnrecoverableKeyException
|
||||
import java.security.cert.Certificate
|
||||
import java.security.cert.CertificateException
|
||||
import java.security.cert.X509Certificate
|
||||
|
||||
class KeyManagerTest {
|
||||
private val mockContext = mockk<Context>()
|
||||
private val mockAndroidKeyStore = mockk<KeyStore>(name = "MockAndroidKeyStore")
|
||||
private val mockPkcs12KeyStore = mockk<KeyStore>(name = "MockPKCS12KeyStore")
|
||||
private val keyDiskSource = KeyManagerImpl(
|
||||
context = mockContext,
|
||||
)
|
||||
|
||||
@BeforeEach
|
||||
fun setUp() {
|
||||
mockkStatic(KeyStore::class, KeyChain::class)
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
fun tearDown() {
|
||||
unmockkStatic(KeyStore::class, KeyChain::class)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `getMutualTlsCertificateChain should return null when MutualTlsKeyAlias is not found`() {
|
||||
// Verify null is returned when alias is not found in KeyChain
|
||||
setupMockAndroidKeyStore()
|
||||
every { KeyChain.getPrivateKey(mockContext, "mockAlias") } throws KeyChainException()
|
||||
assertNull(
|
||||
keyDiskSource.getMutualTlsCertificateChain(
|
||||
alias = "mockAlias",
|
||||
host = MutualTlsKeyHost.KEY_CHAIN,
|
||||
),
|
||||
)
|
||||
|
||||
// Verify null is returned when alias is not found in AndroidKeyStore
|
||||
every { mockAndroidKeyStore.getKey("mockAlias", null) } throws UnrecoverableKeyException()
|
||||
assertNull(
|
||||
keyDiskSource.getMutualTlsCertificateChain(
|
||||
alias = "mockAlias",
|
||||
host = MutualTlsKeyHost.ANDROID_KEY_STORE,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
@Suppress("MaxLineLength")
|
||||
@Test
|
||||
fun `getMutualTlsCertificateChain should return MutualTlsCertificateChain when using ANDROID KEY STORE and key is found`() {
|
||||
setupMockAndroidKeyStore()
|
||||
val mockAlias = "mockAlias"
|
||||
val mockPrivateKey = mockk<PrivateKey>()
|
||||
val mockCertificate1 = mockk<X509Certificate>(name = "mockCertificate1")
|
||||
val mockCertificate2 = mockk<X509Certificate>(name = "mockCertificate2")
|
||||
every {
|
||||
mockAndroidKeyStore.getCertificateChain(mockAlias)
|
||||
} returns arrayOf(mockCertificate1, mockCertificate2)
|
||||
every {
|
||||
mockAndroidKeyStore.getKey(mockAlias, null)
|
||||
} returns mockPrivateKey
|
||||
|
||||
val result = keyDiskSource.getMutualTlsCertificateChain(
|
||||
alias = mockAlias,
|
||||
host = MutualTlsKeyHost.ANDROID_KEY_STORE,
|
||||
)
|
||||
|
||||
assertEquals(
|
||||
MutualTlsCertificate(
|
||||
alias = mockAlias,
|
||||
certificateChain = listOf(mockCertificate1, mockCertificate2),
|
||||
privateKey = mockPrivateKey,
|
||||
),
|
||||
result,
|
||||
)
|
||||
}
|
||||
|
||||
@Suppress("MaxLineLength")
|
||||
@Test
|
||||
fun `getMutualTlsCertificateChain should return null when using ANDROID KEY STORE and key is not found`() {
|
||||
setupMockAndroidKeyStore()
|
||||
val mockAlias = "mockAlias"
|
||||
val mockCertificate1 = mockk<X509Certificate>(name = "mockCertificate1")
|
||||
val mockCertificate2 = mockk<X509Certificate>(name = "mockCertificate2")
|
||||
every {
|
||||
mockAndroidKeyStore.getCertificateChain(mockAlias)
|
||||
} returns arrayOf(mockCertificate1, mockCertificate2)
|
||||
every {
|
||||
mockAndroidKeyStore.getKey(mockAlias, null)
|
||||
} returns null
|
||||
|
||||
assertNull(
|
||||
keyDiskSource.getMutualTlsCertificateChain(
|
||||
alias = mockAlias,
|
||||
host = MutualTlsKeyHost.ANDROID_KEY_STORE,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
@Suppress("MaxLineLength")
|
||||
@Test
|
||||
fun `getMutualTlsCertificateChain should return null when using ANDROID KEY STORE and certificate chain is invalid`() {
|
||||
setupMockAndroidKeyStore()
|
||||
val mockAlias = "mockAlias"
|
||||
every {
|
||||
mockAndroidKeyStore.getKey(mockAlias, null)
|
||||
} returns mockk<PrivateKey>()
|
||||
|
||||
// Verify null is returned when certificate chain is empty
|
||||
every {
|
||||
mockAndroidKeyStore.getCertificateChain(mockAlias)
|
||||
} returns emptyArray()
|
||||
assertNull(
|
||||
keyDiskSource.getMutualTlsCertificateChain(
|
||||
alias = mockAlias,
|
||||
host = MutualTlsKeyHost.ANDROID_KEY_STORE,
|
||||
),
|
||||
)
|
||||
|
||||
// Verify null is returned when certificate chain contains non-X509Certificate objects
|
||||
every {
|
||||
mockAndroidKeyStore.getCertificateChain(mockAlias)
|
||||
} returns arrayOf(mockk<Certificate>())
|
||||
assertNull(
|
||||
keyDiskSource.getMutualTlsCertificateChain(
|
||||
alias = mockAlias,
|
||||
host = MutualTlsKeyHost.ANDROID_KEY_STORE,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
@Suppress("MaxLineLength")
|
||||
@Test
|
||||
fun `getMutualTlsCertificateChain should return null when using ANDROID KEY STORE and an exception occurs`() {
|
||||
setupMockAndroidKeyStore()
|
||||
val mockAlias = "mockAlias"
|
||||
val mockCertificate1 = mockk<X509Certificate>(name = "mockCertificate1")
|
||||
val mockCertificate2 = mockk<X509Certificate>(name = "mockCertificate2")
|
||||
every {
|
||||
mockAndroidKeyStore.getCertificateChain(mockAlias)
|
||||
} returns arrayOf(mockCertificate1, mockCertificate2)
|
||||
|
||||
// Verify KeyStoreException is handled
|
||||
every {
|
||||
mockAndroidKeyStore.getKey(mockAlias, null)
|
||||
} throws KeyStoreException()
|
||||
|
||||
assertNull(
|
||||
keyDiskSource.getMutualTlsCertificateChain(
|
||||
alias = mockAlias,
|
||||
host = MutualTlsKeyHost.ANDROID_KEY_STORE,
|
||||
),
|
||||
)
|
||||
|
||||
// Verify UnrecoverableKeyException is handled
|
||||
every {
|
||||
mockAndroidKeyStore.getKey(mockAlias, null)
|
||||
} throws UnrecoverableKeyException()
|
||||
assertNull(
|
||||
keyDiskSource.getMutualTlsCertificateChain(
|
||||
alias = mockAlias,
|
||||
host = MutualTlsKeyHost.ANDROID_KEY_STORE,
|
||||
),
|
||||
)
|
||||
|
||||
// Verify NoSuchAlgorithmException is handled
|
||||
every {
|
||||
mockAndroidKeyStore.getKey(mockAlias, null)
|
||||
} throws NoSuchAlgorithmException()
|
||||
assertNull(
|
||||
keyDiskSource.getMutualTlsCertificateChain(
|
||||
alias = mockAlias,
|
||||
host = MutualTlsKeyHost.ANDROID_KEY_STORE,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
@Suppress("MaxLineLength")
|
||||
@Test
|
||||
fun `getMutualTlsCertificateChain should return MutualTlsCertificateChain when using KEY CHAIN and key is found`() {
|
||||
val mockAlias = "mockAlias"
|
||||
val mockPrivateKey = mockk<PrivateKey>()
|
||||
val mockCertificate1 = mockk<X509Certificate>(name = "mockCertificate1")
|
||||
val mockCertificate2 = mockk<X509Certificate>(name = "mockCertificate2")
|
||||
every {
|
||||
KeyChain.getCertificateChain(mockContext, mockAlias)
|
||||
} returns arrayOf(mockCertificate1, mockCertificate2)
|
||||
every {
|
||||
KeyChain.getPrivateKey(mockContext, mockAlias)
|
||||
} returns mockPrivateKey
|
||||
|
||||
val result = keyDiskSource.getMutualTlsCertificateChain(
|
||||
alias = mockAlias,
|
||||
host = MutualTlsKeyHost.KEY_CHAIN,
|
||||
)
|
||||
|
||||
assertEquals(
|
||||
MutualTlsCertificate(
|
||||
alias = mockAlias,
|
||||
certificateChain = listOf(mockCertificate1, mockCertificate2),
|
||||
privateKey = mockPrivateKey,
|
||||
),
|
||||
result,
|
||||
)
|
||||
}
|
||||
|
||||
@Suppress("MaxLineLength")
|
||||
@Test
|
||||
fun `getMutualTlsCertificateChain should return null when using KEY CHAIN and key is not found`() {
|
||||
val mockAlias = "mockAlias"
|
||||
val mockCertificate1 = mockk<X509Certificate>(name = "mockCertificate1")
|
||||
val mockCertificate2 = mockk<X509Certificate>(name = "mockCertificate2")
|
||||
every {
|
||||
KeyChain.getCertificateChain(mockContext, mockAlias)
|
||||
} returns arrayOf(mockCertificate1, mockCertificate2)
|
||||
every {
|
||||
KeyChain.getPrivateKey(mockContext, mockAlias)
|
||||
} returns null
|
||||
|
||||
assertNull(
|
||||
keyDiskSource.getMutualTlsCertificateChain(
|
||||
alias = mockAlias,
|
||||
host = MutualTlsKeyHost.KEY_CHAIN,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
@Suppress("MaxLineLength")
|
||||
@Test
|
||||
fun `getMutualTlsCertificateChain should return null when using KEY CHAIN and an exception occurs`() {
|
||||
val mockAlias = "mockAlias"
|
||||
val mockCertificate1 = mockk<X509Certificate>(name = "mockCertificate1")
|
||||
val mockCertificate2 = mockk<X509Certificate>(name = "mockCertificate2")
|
||||
|
||||
every {
|
||||
KeyChain.getCertificateChain(mockContext, mockAlias)
|
||||
} returns arrayOf(mockCertificate1, mockCertificate2)
|
||||
|
||||
// Verify KeyChainException from getPrivateKey is handled
|
||||
every {
|
||||
KeyChain.getPrivateKey(mockContext, mockAlias)
|
||||
} throws KeyChainException()
|
||||
assertNull(
|
||||
keyDiskSource.getMutualTlsCertificateChain(
|
||||
alias = mockAlias,
|
||||
host = MutualTlsKeyHost.KEY_CHAIN,
|
||||
),
|
||||
)
|
||||
|
||||
// Verify KeyChainException from getCertificateChain is handled
|
||||
every { KeyChain.getPrivateKey(mockContext, mockAlias) } returns mockk()
|
||||
every { KeyChain.getCertificateChain(mockContext, mockAlias) } throws KeyChainException()
|
||||
assertNull(
|
||||
keyDiskSource.getMutualTlsCertificateChain(
|
||||
alias = mockAlias,
|
||||
host = MutualTlsKeyHost.KEY_CHAIN,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
@Suppress("MaxLineLength")
|
||||
@Test
|
||||
fun `removeMutualTlsKey should remove key from AndroidKeyStore when host is ANDROID_KEY_STORE`() {
|
||||
setupMockAndroidKeyStore()
|
||||
val mockAlias = "mockAlias"
|
||||
|
||||
every { mockAndroidKeyStore.deleteEntry(mockAlias) } just runs
|
||||
|
||||
keyDiskSource.removeMutualTlsKey(
|
||||
alias = mockAlias,
|
||||
host = MutualTlsKeyHost.ANDROID_KEY_STORE,
|
||||
)
|
||||
|
||||
verify {
|
||||
mockAndroidKeyStore.deleteEntry(mockAlias)
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `removeMutualTlsKey should do nothing when host is KEY_CHAIN`() {
|
||||
keyDiskSource.removeMutualTlsKey(
|
||||
alias = "mockAlias",
|
||||
host = MutualTlsKeyHost.KEY_CHAIN,
|
||||
)
|
||||
|
||||
verify(exactly = 0) {
|
||||
mockAndroidKeyStore.deleteEntry(any())
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `importMutualTlsCertificate should return Success when key is imported successfully`() {
|
||||
setupMockAndroidKeyStore()
|
||||
setupMockPkcs12KeyStore()
|
||||
val expectedAlias = "mockAlias"
|
||||
val internalAlias = "mockInternalAlias"
|
||||
val privateKey = mockk<PrivateKey>()
|
||||
val certChain = arrayOf(mockk<X509Certificate>())
|
||||
val pkcs12Bytes = "key.p12".toByteArray()
|
||||
val password = "password"
|
||||
every { mockPkcs12KeyStore.aliases() } returns mockk {
|
||||
every { hasMoreElements() } returns true
|
||||
every { nextElement() } returns internalAlias
|
||||
}
|
||||
every {
|
||||
mockPkcs12KeyStore.setKeyEntry(
|
||||
internalAlias,
|
||||
privateKey,
|
||||
null,
|
||||
certChain,
|
||||
)
|
||||
} just runs
|
||||
every {
|
||||
mockPkcs12KeyStore.getKey(
|
||||
internalAlias,
|
||||
password.toCharArray(),
|
||||
)
|
||||
} returns privateKey
|
||||
every {
|
||||
mockPkcs12KeyStore.getCertificateChain(internalAlias)
|
||||
} returns certChain
|
||||
every {
|
||||
mockAndroidKeyStore.containsAlias(expectedAlias)
|
||||
} returns false
|
||||
every {
|
||||
mockAndroidKeyStore.setKeyEntry(expectedAlias, privateKey, null, certChain)
|
||||
} just runs
|
||||
|
||||
assertEquals(
|
||||
ImportPrivateKeyResult.Success(alias = expectedAlias),
|
||||
keyDiskSource.importMutualTlsCertificate(
|
||||
key = pkcs12Bytes,
|
||||
alias = expectedAlias,
|
||||
password = password,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `importMutualTlsCertificate should return Error when loading PKCS12 throws an exception`() {
|
||||
setupMockPkcs12KeyStore()
|
||||
val expectedAlias = "mockAlias"
|
||||
val pkcs12Bytes = "key.p12".toByteArray()
|
||||
val password = "password"
|
||||
|
||||
// Verify KeyStoreException is handled
|
||||
every {
|
||||
mockPkcs12KeyStore.load(any(), any())
|
||||
} throws KeyStoreException()
|
||||
assertEquals(
|
||||
ImportPrivateKeyResult.Error.UnsupportedKey,
|
||||
keyDiskSource.importMutualTlsCertificate(
|
||||
key = pkcs12Bytes,
|
||||
alias = expectedAlias,
|
||||
password = password,
|
||||
),
|
||||
) { "KeyStoreException was not handled correctly" }
|
||||
|
||||
// Verify IOException is handled
|
||||
every {
|
||||
mockPkcs12KeyStore.load(any(), any())
|
||||
} throws IOException()
|
||||
assertEquals(
|
||||
ImportPrivateKeyResult.Error.KeyStoreOperationFailed,
|
||||
keyDiskSource.importMutualTlsCertificate(
|
||||
key = pkcs12Bytes,
|
||||
alias = expectedAlias,
|
||||
password = password,
|
||||
),
|
||||
) { "IOException was not handled correctly" }
|
||||
|
||||
// Verify IOException with UnrecoverableKeyException cause is handled
|
||||
every {
|
||||
mockPkcs12KeyStore.load(any(), any())
|
||||
} throws IOException(UnrecoverableKeyException())
|
||||
|
||||
assertEquals(
|
||||
ImportPrivateKeyResult.Error.UnrecoverableKey,
|
||||
keyDiskSource.importMutualTlsCertificate(
|
||||
key = pkcs12Bytes,
|
||||
alias = expectedAlias,
|
||||
password = password,
|
||||
),
|
||||
)
|
||||
|
||||
// Verify IOException with unexpected cause is handled
|
||||
every {
|
||||
mockPkcs12KeyStore.load(any(), any())
|
||||
} throws IOException(Exception())
|
||||
assertEquals(
|
||||
ImportPrivateKeyResult.Error.KeyStoreOperationFailed,
|
||||
keyDiskSource.importMutualTlsCertificate(
|
||||
key = pkcs12Bytes,
|
||||
alias = expectedAlias,
|
||||
password = password,
|
||||
),
|
||||
) { "IOException with Unexpected exception cause was not handled correctly" }
|
||||
|
||||
// Verify CertificateException is handled
|
||||
every {
|
||||
mockPkcs12KeyStore.load(any(), any())
|
||||
} throws CertificateException()
|
||||
assertEquals(
|
||||
ImportPrivateKeyResult.Error.InvalidCertificateChain,
|
||||
keyDiskSource.importMutualTlsCertificate(
|
||||
key = pkcs12Bytes,
|
||||
alias = expectedAlias,
|
||||
password = password,
|
||||
),
|
||||
) { "CertificateException was not handled correctly" }
|
||||
|
||||
// Verify NoSuchAlgorithmException is handled
|
||||
every {
|
||||
mockPkcs12KeyStore.load(any(), any())
|
||||
} throws NoSuchAlgorithmException()
|
||||
assertEquals(
|
||||
ImportPrivateKeyResult.Error.UnsupportedKey,
|
||||
keyDiskSource.importMutualTlsCertificate(
|
||||
key = pkcs12Bytes,
|
||||
alias = expectedAlias,
|
||||
password = password,
|
||||
),
|
||||
) { "NoSuchAlgorithmException was not handled correctly" }
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `importMutualTlsCertificate should return UnsupportedKey when key store is empty`() {
|
||||
setupMockPkcs12KeyStore()
|
||||
val expectedAlias = "mockAlias"
|
||||
val pkcs12Bytes = "key.p12".toByteArray()
|
||||
val password = "password"
|
||||
|
||||
every { mockPkcs12KeyStore.aliases() } returns mockk {
|
||||
every { hasMoreElements() } returns false
|
||||
}
|
||||
|
||||
assertEquals(
|
||||
ImportPrivateKeyResult.Error.UnsupportedKey,
|
||||
keyDiskSource.importMutualTlsCertificate(
|
||||
key = pkcs12Bytes,
|
||||
alias = expectedAlias,
|
||||
password = password,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
@Suppress("MaxLineLength")
|
||||
@Test
|
||||
fun `importMutualTlsCertificate should return UnrecoverableKey when unable to retrieve private key`() {
|
||||
setupMockPkcs12KeyStore()
|
||||
val expectedAlias = "mockAlias"
|
||||
val pkcs12Bytes = "key.p12".toByteArray()
|
||||
val password = "password"
|
||||
|
||||
every {
|
||||
mockPkcs12KeyStore.aliases()
|
||||
} returns mockk {
|
||||
every { hasMoreElements() } returns true
|
||||
every { nextElement() } returns "mockInternalAlias"
|
||||
}
|
||||
every {
|
||||
mockPkcs12KeyStore.getKey(
|
||||
"mockInternalAlias",
|
||||
password.toCharArray(),
|
||||
)
|
||||
} throws UnrecoverableKeyException()
|
||||
|
||||
assertEquals(
|
||||
ImportPrivateKeyResult.Error.UnrecoverableKey,
|
||||
keyDiskSource.importMutualTlsCertificate(
|
||||
key = pkcs12Bytes,
|
||||
alias = expectedAlias,
|
||||
password = password,
|
||||
),
|
||||
)
|
||||
|
||||
every {
|
||||
mockPkcs12KeyStore.getKey(
|
||||
"mockInternalAlias",
|
||||
password.toCharArray(),
|
||||
)
|
||||
} returns null
|
||||
assertEquals(
|
||||
ImportPrivateKeyResult.Error.UnrecoverableKey,
|
||||
keyDiskSource.importMutualTlsCertificate(
|
||||
key = pkcs12Bytes,
|
||||
alias = expectedAlias,
|
||||
password = password,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
@Suppress("MaxLineLength")
|
||||
@Test
|
||||
fun `importMutualTlsCertificate should return InvalidCertificateChain when certificate chain is empty`() {
|
||||
setupMockPkcs12KeyStore()
|
||||
val expectedAlias = "mockAlias"
|
||||
val pkcs12Bytes = "key.p12".toByteArray()
|
||||
val password = "password"
|
||||
|
||||
every { mockPkcs12KeyStore.aliases() } returns mockk {
|
||||
every { hasMoreElements() } returns true
|
||||
every { nextElement() } returns "mockInternalAlias"
|
||||
}
|
||||
every {
|
||||
mockPkcs12KeyStore.getKey(
|
||||
"mockInternalAlias",
|
||||
password.toCharArray(),
|
||||
)
|
||||
} returns mockk()
|
||||
|
||||
// Verify empty certificate chain is handled
|
||||
every {
|
||||
mockPkcs12KeyStore.getCertificateChain("mockInternalAlias")
|
||||
} returns emptyArray()
|
||||
assertEquals(
|
||||
ImportPrivateKeyResult.Error.InvalidCertificateChain,
|
||||
keyDiskSource.importMutualTlsCertificate(
|
||||
key = pkcs12Bytes,
|
||||
alias = expectedAlias,
|
||||
password = password,
|
||||
),
|
||||
)
|
||||
|
||||
// Verify null certificate chain is handled
|
||||
every {
|
||||
mockPkcs12KeyStore.getCertificateChain("mockInternalAlias")
|
||||
} returns null
|
||||
assertEquals(
|
||||
ImportPrivateKeyResult.Error.InvalidCertificateChain,
|
||||
keyDiskSource.importMutualTlsCertificate(
|
||||
key = pkcs12Bytes,
|
||||
alias = expectedAlias,
|
||||
password = password,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
@Suppress("MaxLineLength")
|
||||
@Test
|
||||
fun `importMutualTlsCertificate should return KeyStoreOperationFailed when saving to Android KeyStore throws KeyStoreException`() {
|
||||
setupMockAndroidKeyStore()
|
||||
val expectedAlias = "mockAlias"
|
||||
val pkcs12Bytes = "key.p12".toByteArray()
|
||||
val password = "password"
|
||||
|
||||
every { mockPkcs12KeyStore.aliases() } returns mockk {
|
||||
every { hasMoreElements() } returns true
|
||||
every { nextElement() } returns "mockInternalAlias"
|
||||
}
|
||||
|
||||
every {
|
||||
mockPkcs12KeyStore.getKey(
|
||||
"mockInternalAlias",
|
||||
password.toCharArray(),
|
||||
)
|
||||
} returns mockk()
|
||||
every {
|
||||
mockPkcs12KeyStore.getCertificateChain("mockInternalAlias")
|
||||
} returns arrayOf(mockk())
|
||||
|
||||
every {
|
||||
mockAndroidKeyStore.setKeyEntry(
|
||||
expectedAlias,
|
||||
any(),
|
||||
any(),
|
||||
any(),
|
||||
)
|
||||
} throws KeyStoreException()
|
||||
|
||||
assertEquals(
|
||||
ImportPrivateKeyResult.Error.KeyStoreOperationFailed,
|
||||
keyDiskSource.importMutualTlsCertificate(
|
||||
key = pkcs12Bytes,
|
||||
alias = expectedAlias,
|
||||
password = password,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
@Suppress("MaxLineLength")
|
||||
@Test
|
||||
fun `importMutualTlsCertificate should return DuplicateAlias when alias already exists in AndroidKeyStore`() {
|
||||
setupMockAndroidKeyStore()
|
||||
setupMockPkcs12KeyStore()
|
||||
val expectedAlias = "mockAlias"
|
||||
val pkcs12Bytes = "key.p12".toByteArray()
|
||||
val password = "password"
|
||||
|
||||
every { mockPkcs12KeyStore.aliases() } returns mockk {
|
||||
every { hasMoreElements() } returns true
|
||||
every { nextElement() } returns "mockInternalAlias"
|
||||
}
|
||||
|
||||
every {
|
||||
mockPkcs12KeyStore.getKey(
|
||||
"mockInternalAlias",
|
||||
password.toCharArray(),
|
||||
)
|
||||
} returns mockk()
|
||||
every {
|
||||
mockPkcs12KeyStore.getCertificateChain("mockInternalAlias")
|
||||
} returns arrayOf(mockk())
|
||||
|
||||
every { mockAndroidKeyStore.containsAlias(expectedAlias) } returns true
|
||||
|
||||
assertEquals(
|
||||
ImportPrivateKeyResult.Error.DuplicateAlias,
|
||||
keyDiskSource.importMutualTlsCertificate(
|
||||
key = pkcs12Bytes,
|
||||
alias = expectedAlias,
|
||||
password = password,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
private fun setupMockAndroidKeyStore() {
|
||||
every { KeyStore.getInstance("AndroidKeyStore") } returns mockAndroidKeyStore
|
||||
every { mockAndroidKeyStore.load(null) } just runs
|
||||
}
|
||||
|
||||
private fun setupMockPkcs12KeyStore() {
|
||||
every { KeyStore.getInstance("pkcs12") } returns mockPkcs12KeyStore
|
||||
every { mockPkcs12KeyStore.load(any(), any()) } just runs
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user