Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weโ€™ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PM-16157] Support self-host servers using TLS with Client Authentication (mTLS) #4486

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import kotlinx.serialization.Serializable
* Represents URLs for various Bitwarden domains.
*
* @property base The overall base URL.
* @property keyAlias A key alias to use for connections with the server.
* @property api Separate base URL for the "/api" domain (if applicable).
* @property identity Separate base URL for the "/identity" domain (if applicable).
* @property icon Separate base URL for the icon domain (if applicable).
Expand All @@ -19,6 +20,9 @@ data class EnvironmentUrlDataJson(
@SerialName("base")
val base: String,

@SerialName("keyAlias")
val keyAlias: String? = null,

@SerialName("api")
val api: String? = null,

Expand Down Expand Up @@ -51,6 +55,7 @@ data class EnvironmentUrlDataJson(
*/
val DEFAULT_LEGACY_US: EnvironmentUrlDataJson = EnvironmentUrlDataJson(
base = "https://vault.bitwarden.com",
keyAlias = null,
api = "https://api.bitwarden.com",
identity = "https://identity.bitwarden.com",
icon = "https://icons.bitwarden.net",
Expand All @@ -71,6 +76,7 @@ data class EnvironmentUrlDataJson(
*/
val DEFAULT_LEGACY_EU: EnvironmentUrlDataJson = EnvironmentUrlDataJson(
base = "https://vault.bitwarden.eu",
keyAlias = null,
api = "https://api.bitwarden.eu",
identity = "https://identity.bitwarden.eu",
icon = "https://icons.bitwarden.eu",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package com.x8bit.bitwarden.data.platform.datasource.network.di

import android.content.Context
import com.x8bit.bitwarden.data.auth.datasource.disk.AuthDiskSource
import com.x8bit.bitwarden.data.platform.datasource.disk.EnvironmentDiskSource
import com.x8bit.bitwarden.data.platform.datasource.network.authenticator.RefreshAuthenticator
import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.AuthTokenInterceptor
import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.BaseUrlInterceptors
Expand All @@ -14,9 +16,13 @@ import com.x8bit.bitwarden.data.platform.datasource.network.service.EventService
import com.x8bit.bitwarden.data.platform.datasource.network.service.EventServiceImpl
import com.x8bit.bitwarden.data.platform.datasource.network.service.PushService
import com.x8bit.bitwarden.data.platform.datasource.network.service.PushServiceImpl
import com.x8bit.bitwarden.data.platform.datasource.network.util.TLSHelper
import com.x8bit.bitwarden.data.platform.repository.KeyChainRepository
import com.x8bit.bitwarden.data.platform.repository.KeyChainRepositoryImpl
import dagger.Module
import dagger.Provides
import dagger.hilt.InstallIn
import dagger.hilt.android.qualifiers.ApplicationContext
import dagger.hilt.components.SingletonComponent
import kotlinx.serialization.json.Json
import kotlinx.serialization.modules.SerializersModule
Expand Down Expand Up @@ -70,6 +76,19 @@ object PlatformNetworkModule {
@Singleton
fun providesRefreshAuthenticator(): RefreshAuthenticator = RefreshAuthenticator()

@Provides
@Singleton
fun providesKeyChainRepository(
environmentDiskSource: EnvironmentDiskSource,
@ApplicationContext context: Context,
): KeyChainRepository =
KeyChainRepositoryImpl(environmentDiskSource = environmentDiskSource, context = context)

@Provides
@Singleton
fun providesTlsHelper(keyChainRepository: KeyChainRepository): TLSHelper =
TLSHelper(keyChainRepository = keyChainRepository)

@Provides
@Singleton
fun provideRetrofits(
Expand All @@ -78,13 +97,15 @@ object PlatformNetworkModule {
headersInterceptor: HeadersInterceptor,
refreshAuthenticator: RefreshAuthenticator,
json: Json,
tlsHelper: TLSHelper,
): Retrofits =
RetrofitsImpl(
authTokenInterceptor = authTokenInterceptor,
baseUrlInterceptors = baseUrlInterceptors,
headersInterceptor = headersInterceptor,
refreshAuthenticator = refreshAuthenticator,
json = json,
tlsHelper = tlsHelper,
)

@Provides
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.BaseUrlI
import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.BaseUrlInterceptors
import com.x8bit.bitwarden.data.platform.datasource.network.interceptor.HeadersInterceptor
import com.x8bit.bitwarden.data.platform.datasource.network.util.HEADER_KEY_AUTHORIZATION
import com.x8bit.bitwarden.data.platform.datasource.network.util.TLSHelper
import kotlinx.serialization.json.Json
import okhttp3.MediaType.Companion.toMediaType
import okhttp3.OkHttpClient
Expand All @@ -24,6 +25,7 @@ class RetrofitsImpl(
headersInterceptor: HeadersInterceptor,
refreshAuthenticator: RefreshAuthenticator,
json: Json,
tlsHelper: TLSHelper,
) : Retrofits {
//region Authenticated Retrofits

Expand Down Expand Up @@ -84,7 +86,7 @@ class RetrofitsImpl(
}

private val baseOkHttpClient: OkHttpClient =
OkHttpClient.Builder()
tlsHelper.setupOkHttpClientSSLSocketFactory(OkHttpClient.Builder())
.addInterceptor(headersInterceptor)
.build()

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package com.x8bit.bitwarden.data.platform.datasource.network.util

import com.x8bit.bitwarden.data.platform.repository.KeyChainRepository
import okhttp3.OkHttpClient
import java.net.Socket
import java.security.KeyStore
import java.security.Principal
import java.security.PrivateKey
import java.security.cert.X509Certificate
import javax.inject.Inject
import javax.inject.Named
import javax.net.ssl.SSLContext
import javax.net.ssl.TrustManagerFactory
import javax.net.ssl.X509ExtendedKeyManager
import javax.net.ssl.X509TrustManager

/**
* Helper class for setting up TLS (Transport Layer Security) for OkHttp client.
* It provides functionality to setup a SSL Socket factory using KeyChainRepository certificate and
* key.
*
* @param keyChainRepository repository for access to certificate and private key.
*/
class TLSHelper @Inject constructor(
@Named("keyChainRepository") private val keyChainRepository: KeyChainRepository,
) {

/**
* Sets up a SSL Socket factory using KeyChainRepository certificate and key.
*/
fun setupOkHttpClientSSLSocketFactory(builder: OkHttpClient.Builder): OkHttpClient.Builder {
val trustManagerFactory =
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
trustManagerFactory.init(null as KeyStore?)
val trustManagers = trustManagerFactory.trustManagers

val sslContext = SSLContext.getInstance("TLS")
sslContext.init(arrayOf(getMTLSKeyManagerForOKHTTP()), trustManagers, null)

builder.sslSocketFactory(sslContext.socketFactory, trustManagers[0] as X509TrustManager)

return builder
}

private fun getMTLSKeyManagerForOKHTTP(): X509ExtendedKeyManager {
return object : X509ExtendedKeyManager() {
override fun getClientAliases(
p0: String?,
p1: Array<out Principal>?,
): Array<String> {
return emptyArray()
}

override fun chooseClientAlias(
p0: Array<out String>?,
p1: Array<out Principal>?,
p2: Socket?,
): String {
return ""
}

override fun getServerAliases(
p0: String?,
p1: Array<out Principal>?,
): Array<String> {
return arrayOf()
}

override fun chooseServerAlias(
p0: String?,
p1: Array<out Principal>?,
p2: Socket?,
): String {
return ""
}

override fun getCertificateChain(p0: String?): Array<X509Certificate>? {
return keyChainRepository.getCertificateChain()
}

override fun getPrivateKey(p0: String?): PrivateKey? {
return keyChainRepository.getPrivateKey()
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package com.x8bit.bitwarden.data.platform.repository

import java.security.PrivateKey
import java.security.cert.X509Certificate

/**
* Repository for accessing the KeyChain.
*/
interface KeyChainRepository {

/**
* Returns the private key.
*/
fun getPrivateKey(): PrivateKey?

/**
* Returns the certificate chain.
*/
fun getCertificateChain(): Array<X509Certificate>?
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package com.x8bit.bitwarden.data.platform.repository

import android.content.Context
import android.security.KeyChain
import android.security.KeyChainException
import com.x8bit.bitwarden.data.platform.datasource.disk.EnvironmentDiskSource
import com.x8bit.bitwarden.data.platform.repository.util.toEnvironmentUrlsOrDefault
import java.security.PrivateKey
import java.security.cert.X509Certificate
import javax.inject.Inject

/**
* Default implementation of [KeyChainRepository].
*/
class KeyChainRepositoryImpl @Inject constructor(
environmentDiskSource: EnvironmentDiskSource,
val context: Context,
) : KeyChainRepository {
private var alias: String? = null
private var key: PrivateKey? = null
private var chain: Array<X509Certificate>? = null

init {
alias = environmentDiskSource
.preAuthEnvironmentUrlData
.toEnvironmentUrlsOrDefault()
.environmentUrlData
.keyAlias
}

override fun getPrivateKey(): PrivateKey? {
if (key == null && !alias.isNullOrEmpty()) {
key = try {
KeyChain.getPrivateKey(context, alias!!)
} catch (_: KeyChainException) {
null
}
}

return key
}

override fun getCertificateChain(): Array<X509Certificate>? {
if (chain == null && !alias.isNullOrEmpty()) {
chain = try {
KeyChain.getCertificateChain(context, alias!!)
} catch (_: KeyChainException) {
null
}
}

return chain
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ import com.x8bit.bitwarden.ui.platform.components.header.BitwardenListHeaderText
import com.x8bit.bitwarden.ui.platform.components.model.CardStyle
import com.x8bit.bitwarden.ui.platform.components.scaffold.BitwardenScaffold
import com.x8bit.bitwarden.ui.platform.components.util.rememberVectorPainter
import com.x8bit.bitwarden.ui.platform.composition.LocalKeyChainManager
import com.x8bit.bitwarden.ui.platform.manager.keychain.KeyChainManager
import kotlinx.collections.immutable.persistentListOf

/**
Expand All @@ -48,6 +50,7 @@ import kotlinx.collections.immutable.persistentListOf
@Composable
fun EnvironmentScreen(
onNavigateBack: () -> Unit,
keyChainManager: KeyChainManager = LocalKeyChainManager.current,
viewModel: EnvironmentViewModel = hiltViewModel(),
) {
val state by viewModel.stateFlow.collectAsStateWithLifecycle()
Expand All @@ -58,6 +61,14 @@ fun EnvironmentScreen(
is EnvironmentEvent.ShowToast -> {
Toast.makeText(context, event.message(context.resources), Toast.LENGTH_SHORT).show()
}

is EnvironmentEvent.ShowSystemCertificateSelectionDialog -> {
viewModel.trySendAction(
EnvironmentAction.SystemCertificateSelectionResultReceive(
result = keyChainManager.choosePrivateKeyAlias(state.serverUrl),
),
)
}
}
}

Expand Down Expand Up @@ -138,6 +149,38 @@ fun EnvironmentScreen(
.standardHorizontalMargin(),
)

Spacer(modifier = Modifier.height(16.dp))

BitwardenListHeaderText(
label = stringResource(id = R.string.client_certificate_mtls),
modifier = Modifier
.fillMaxWidth()
.padding(horizontal = 16.dp),
)

BitwardenTextField(
label = stringResource(id = R.string.certificate_alias),
value = state.keyAlias,
supportingText = stringResource(
id = R.string.certificate_used_for_client_authentication,
),
onValueChange = {},
readOnly = true,
cardStyle = CardStyle.Full,
modifier = Modifier
.fillMaxWidth()
.testTag("KeyAliasEntry")
.padding(horizontal = 16.dp),
)

BitwardenTextButton(
label = stringResource(id = R.string.use_system_certificate),
onClick = remember(viewModel) {
{ viewModel.trySendAction(EnvironmentAction.UseSystemCertificateClick) }
},
modifier = Modifier.standardHorizontalMargin(),
)

Spacer(modifier = Modifier.height(height = 16.dp))

BitwardenListHeaderText(
Expand Down
Loading