diff --git a/Cargo.toml b/Cargo.toml index 0caaac81..4f730c95 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,9 +57,10 @@ winauth = { version = "0.0.4", optional = true } [target.'cfg(unix)'.dependencies] libgssapi = { version = "0.8.1", optional = true, default-features = false } +libc = "0.2" [dependencies.async-native-tls] -version = "0.4" +version = "0.5" features = ["runtime-async-std"] optional = true @@ -191,7 +192,8 @@ all = [ "bigdecimal", "native-tls", ] -default = ["tds73", "winauth", "native-tls"] +default = ["tds80", "winauth", "native-tls"] +tds80 = ["tds73"] tds73 = [] docs = [] sql-browser-async-std = ["async-std"] diff --git a/src/client/config.rs b/src/client/config.rs index fff68bc1..bdc24cc6 100644 --- a/src/client/config.rs +++ b/src/client/config.rs @@ -32,6 +32,8 @@ pub struct Config { pub(crate) trust: TrustConfig, pub(crate) auth: AuthMethod, pub(crate) readonly: bool, + pub(crate) hostname_in_certificate: Option, + pub(crate) client_name: Option, } #[derive(Clone, Debug)] @@ -65,6 +67,8 @@ impl Default for Config { trust: TrustConfig::Default, auth: AuthMethod::None, readonly: false, + hostname_in_certificate: None, + client_name: None, } } } @@ -149,14 +153,29 @@ impl Config { /// Will panic in case `trust_cert` was called before. /// /// - Defaults to validating the server certificate is validated against system's certificate storage. - pub fn trust_cert_ca(&mut self, path: impl ToString) { + pub fn trust_cert_ca(&mut self, path: impl Into) { if let TrustConfig::TrustAll = &self.trust { panic!("'trust_cert' and 'trust_cert_ca' are mutual exclusive! Only use one.") } else { - self.trust = TrustConfig::CaCertificateLocation(PathBuf::from(path.to_string())) + self.trust = TrustConfig::CaCertificateLocation(path.into()) } } + /// Sets the hostname to be used for certificate validation. + /// If not set, the hostname from `host` will be used. + /// + /// - Defaults to the value of `host`. + pub fn hostname_in_certificate(&mut self, hostname: impl ToString) { + self.hostname_in_certificate = Some(hostname.to_string()); + } + + /// Sets the client name to be sent to the server. + /// + /// - Defaults to the current workstation id (hostname). + pub fn client_name(&mut self, name: impl ToString) { + self.client_name = Some(name.to_string()); + } + /// Sets the authentication method. /// /// - Defaults to `None`. @@ -190,6 +209,12 @@ impl Config { } } + pub(crate) fn get_hostname_in_certificate(&self) -> &str { + self.hostname_in_certificate + .as_deref() + .unwrap_or_else(|| self.get_host()) + } + /// Get the host address including port pub fn get_addr(&self) -> String { format!("{}:{}", self.get_host(), self.get_port()) @@ -210,7 +235,7 @@ impl Config { /// |`database`|``|The name of the database.| /// |`TrustServerCertificate`|`true`,`false`,`yes`,`no`|Specifies whether the driver trusts the server certificate when connecting using TLS. Cannot be used toghether with `TrustServerCertificateCA`| /// |`TrustServerCertificateCA`|``|Path to a `pem`, `crt` or `der` certificate file. Cannot be used together with `TrustServerCertificate`| - /// |`encrypt`|`true`,`false`,`yes`,`no`,`DANGER_PLAINTEXT`|Specifies whether the driver uses TLS to encrypt communication.| + /// |`encrypt`|`strict`,`true`,`false`,`yes`,`no`,`DANGER_PLAINTEXT`|Specifies whether the driver uses TLS to encrypt communication.| /// |`Application Name`, `ApplicationName`|``|Sets the application name for the connection.| /// /// [ADO.NET connection string]: https://docs.microsoft.com/en-us/dotnet/framework/data/adonet/connection-strings @@ -265,10 +290,18 @@ impl Config { builder.trust_cert_ca(ca); } + if let Some(hostname_in_cert) = s.hostname_in_certificate() { + builder.hostname_in_certificate(hostname_in_cert); + } + builder.encryption(s.encrypt()?); builder.readonly(s.readonly()); + if let Some(client_name) = s.client_name() { + builder.client_name(client_name); + } + Ok(builder) } } @@ -346,6 +379,20 @@ pub(crate) trait ConfigString { .map(|ca| ca.to_string()) } + fn hostname_in_certificate(&self) -> Option { + self.dict() + .get("hostnameincertificate") + .or_else(|| self.dict().get("hostname in certificate")) + .map(|ca| ca.to_string()) + } + + fn client_name(&self) -> Option { + self.dict() + .get("workstationid") + .or_else(|| self.dict().get("workstation id")) + .map(|name| name.to_string()) + } + #[cfg(any( feature = "rustls", feature = "native-tls", @@ -358,6 +405,9 @@ pub(crate) trait ConfigString { Ok(true) => Ok(EncryptionLevel::Required), Ok(false) => Ok(EncryptionLevel::Off), Err(_) if val == "DANGER_PLAINTEXT" => Ok(EncryptionLevel::NotSupported), + Err(_) if val.eq_ignore_ascii_case("strict") && cfg!(feature = "tds80") => { + Ok(EncryptionLevel::Strict) + } Err(e) => Err(e), }) .unwrap_or(Ok(EncryptionLevel::Off)) diff --git a/src/client/config/ado_net.rs b/src/client/config/ado_net.rs index 94df9ca3..e0c1c39d 100644 --- a/src/client/config/ado_net.rs +++ b/src/client/config/ado_net.rs @@ -470,6 +470,21 @@ mod tests { Ok(()) } + #[test] + #[cfg(any( + feature = "rustls", + feature = "native-tls", + feature = "vendored-openssl" + ))] + fn encryption_parsing_strict() -> crate::Result<()> { + let test_str = "encrypt=strict"; + let ado: AdoNetConfig = test_str.parse()?; + + assert_eq!(EncryptionLevel::Strict, ado.encrypt()?); + + Ok(()) + } + #[test] fn application_name_parsing() -> crate::Result<()> { let test_str = "Application Name=meow"; @@ -484,4 +499,40 @@ mod tests { Ok(()) } + + #[test] + fn client_name_parsing() -> crate::Result<()> { + let test_str = "workstationid=meow"; + let ado: AdoNetConfig = test_str.parse()?; + + assert_eq!(Some("meow".into()), ado.client_name()); + + let test_str = "Workstation ID=meow"; + let ado: AdoNetConfig = test_str.parse()?; + + assert_eq!(Some("meow".into()), ado.client_name()); + + Ok(()) + } + + #[test] + fn hostname_in_certificate_parsing() -> crate::Result<()> { + let test_str = "HostNameInCertificate=foo.example.com"; + let ado: AdoNetConfig = test_str.parse()?; + + assert_eq!( + Some("foo.example.com".into()), + ado.hostname_in_certificate() + ); + + let test_str = "HostName In Certificate=foo.example.com"; + let ado: AdoNetConfig = test_str.parse()?; + + assert_eq!( + Some("foo.example.com".into()), + ado.hostname_in_certificate() + ); + + Ok(()) + } } diff --git a/src/client/connection.rs b/src/client/connection.rs index 09d37256..07d58ba8 100644 --- a/src/client/connection.rs +++ b/src/client/connection.rs @@ -79,7 +79,17 @@ impl Connection { context }; - let transport = Framed::new(MaybeTlsStream::Raw(tcp_stream), PacketCodec); + let transport = match config.encryption { + EncryptionLevel::Strict => { + event!(Level::INFO, "Performing a TLS handshake"); + let mut pre_login_stream = TlsPreloginWrapper::new(tcp_stream); + pre_login_stream.handshake_complete(); + let stream = create_tls_stream(&config, pre_login_stream).await?; + event!(Level::INFO, "TLS handshake successful"); + Framed::new(MaybeTlsStream::Tls(stream), PacketCodec) + } + _ => Framed::new(MaybeTlsStream::Raw(tcp_stream), PacketCodec), + }; let mut connection = Self { transport, @@ -98,17 +108,7 @@ impl Connection { let connection = connection.tls_handshake(&config, encryption).await?; - let mut connection = connection - .login( - config.auth, - encryption, - config.database, - config.host, - config.application_name, - config.readonly, - prelogin, - ) - .await?; + let mut connection = connection.login(config, encryption, prelogin).await?; connection.flush_done().await?; @@ -284,34 +284,33 @@ impl Connection { /// Defines the login record rules with SQL Server. Authentication with /// connection options. - #[allow(clippy::too_many_arguments)] - async fn login<'a>( + async fn login( mut self, - auth: AuthMethod, + config: Config, encryption: EncryptionLevel, - db: Option, - server_name: Option, - application_name: Option, - readonly: bool, prelogin: PreloginMessage, ) -> crate::Result { let mut login_message = LoginMessage::new(); - if let Some(db) = db { + if let Some(db) = config.database { login_message.db_name(db); } - if let Some(server_name) = server_name { + if let Some(server_name) = config.host { login_message.server_name(server_name); } - if let Some(app_name) = application_name { + if let Some(app_name) = config.application_name { login_message.app_name(app_name); } - login_message.readonly(readonly); + if let Some(client_name) = config.client_name { + login_message.hostname(client_name); + } + + login_message.readonly(config.readonly); - match auth { + match config.auth { #[cfg(all(windows, feature = "winauth"))] AuthMethod::Integrated => { let mut client = NtlmSspiBuilder::new() @@ -444,37 +443,47 @@ impl Connection { config: &Config, encryption: EncryptionLevel, ) -> crate::Result { - if encryption != EncryptionLevel::NotSupported { - event!(Level::INFO, "Performing a TLS handshake"); - - let Self { - transport, context, .. - } = self; - let mut stream = match transport.into_inner() { - MaybeTlsStream::Raw(tcp) => { - create_tls_stream(config, TlsPreloginWrapper::new(tcp)).await? - } - _ => unreachable!(), - }; - - stream.get_mut().handshake_complete(); - event!(Level::INFO, "TLS handshake successful"); + match encryption { + EncryptionLevel::NotSupported => { + event!( + Level::WARN, + "TLS encryption is not enabled. All traffic including the login credentials are not encrypted." + ); + Ok(self) + } + EncryptionLevel::Strict => { + // In Strict mode, we should already be in TLS stream after prelogin, so just return self. + event!( + Level::TRACE, + "Already in TLS stream due to Strict encryption level, skipping handshake." + ); + Ok(self) + } + EncryptionLevel::Off | EncryptionLevel::On | EncryptionLevel::Required => { + event!(Level::INFO, "Performing a TLS handshake"); + + let Self { + transport, context, .. + } = self; + let mut stream = match transport.into_inner() { + MaybeTlsStream::Raw(tcp) => { + create_tls_stream(config, TlsPreloginWrapper::new(tcp)).await? + } + _ => unreachable!(), + }; - let transport = Framed::new(MaybeTlsStream::Tls(stream), PacketCodec); + stream.get_mut().handshake_complete(); + event!(Level::INFO, "TLS handshake successful"); - Ok(Self { - transport, - context, - flushed: false, - buf: BytesMut::new(), - }) - } else { - event!( - Level::WARN, - "TLS encryption is not enabled. All traffic including the login credentials are not encrypted." - ); + let transport = Framed::new(MaybeTlsStream::Tls(stream), PacketCodec); - Ok(self) + Ok(Self { + transport, + context, + flushed: false, + buf: BytesMut::new(), + }) + } } } diff --git a/src/client/tls_stream.rs b/src/client/tls_stream.rs index 9eba1060..d2dc246f 100644 --- a/src/client/tls_stream.rs +++ b/src/client/tls_stream.rs @@ -1,6 +1,8 @@ use crate::Config; use futures_util::io::{AsyncRead, AsyncWrite}; +pub(crate) const TDS_ALPN_PROTOCOL_NAME: &str = "tds/8.0"; + #[cfg(feature = "native-tls")] mod native_tls_stream; diff --git a/src/client/tls_stream/native_tls_stream.rs b/src/client/tls_stream/native_tls_stream.rs index cf5591d8..a9405890 100644 --- a/src/client/tls_stream/native_tls_stream.rs +++ b/src/client/tls_stream/native_tls_stream.rs @@ -14,17 +14,21 @@ pub(crate) async fn create_tls_stream( ) -> crate::Result> { let mut builder = TlsConnector::new(); + if matches!(config.encryption, crate::EncryptionLevel::Strict) { + builder = builder.request_alpns(&[super::TDS_ALPN_PROTOCOL_NAME]); + } + match &config.trust { TrustConfig::CaCertificateLocation(path) => { if let Ok(buf) = fs::read(path) { let cert = match path.extension() { Some(ext) - if ext.to_ascii_lowercase() == "pem" - || ext.to_ascii_lowercase() == "crt" => + if ext.eq_ignore_ascii_case("pem") + || ext.eq_ignore_ascii_case("crt") => { Some(Certificate::from_pem(&buf)?) } - Some(ext) if ext.to_ascii_lowercase() == "der" => { + Some(ext) if ext.eq_ignore_ascii_case("der") => { Some(Certificate::from_der(&buf)?) } Some(_) | None => return Err(Error::Io { @@ -56,5 +60,7 @@ pub(crate) async fn create_tls_stream( } } - Ok(builder.connect(config.get_host(), stream).await?) + Ok(builder + .connect(config.get_hostname_in_certificate(), stream) + .await?) } diff --git a/src/client/tls_stream/opentls_tls_stream.rs b/src/client/tls_stream/opentls_tls_stream.rs index 1f028669..537b235c 100644 --- a/src/client/tls_stream/opentls_tls_stream.rs +++ b/src/client/tls_stream/opentls_tls_stream.rs @@ -14,6 +14,10 @@ pub(crate) async fn create_tls_stream( ) -> crate::Result> { let mut builder = TlsConnector::new(); + if matches!(config.encryption, crate::EncryptionLevel::Strict) { + event!(Level::WARN, "OpenTLS does not support ALPN, so the TDS 8.0 ALPN protocol will not be requested. SQL Server will assume TDS 8.0"); + } + match &config.trust { TrustConfig::CaCertificateLocation(path) => { if let Ok(buf) = fs::read(path) { @@ -56,5 +60,7 @@ pub(crate) async fn create_tls_stream( } } - Ok(builder.connect(config.get_host(), stream).await?) + Ok(builder + .connect(config.get_hostname_in_certificate(), stream) + .await?) } diff --git a/src/client/tls_stream/rustls_tls_stream.rs b/src/client/tls_stream/rustls_tls_stream.rs index e417583a..77521f77 100644 --- a/src/client/tls_stream/rustls_tls_stream.rs +++ b/src/client/tls_stream/rustls_tls_stream.rs @@ -58,13 +58,25 @@ impl ServerCertVerifier for NoCertVerifier { ) -> Result { Ok(HandshakeSignatureValid::assertion()) } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &Certificate, + _dss: &DigitallySignedStruct, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } } fn get_server_name(config: &Config) -> crate::Result { - match (ServerName::try_from(config.get_host()), &config.trust) { + match ( + ServerName::try_from(config.get_hostname_in_certificate()), + &config.trust, + ) { (Ok(sn), _) => Ok(sn), (Err(_), TrustConfig::TrustAll) => { - Ok(ServerName::try_from("placeholder.domain.com").unwrap()) + Ok(ServerName::try_from("placeholder.example.com").unwrap()) } (Err(e), _) => Err(crate::Error::Tls(e.to_string())), } @@ -76,13 +88,13 @@ impl TlsStream { let builder = ClientConfig::builder().with_safe_defaults(); - let client_config = match &config.trust { + let mut client_config = match &config.trust { TrustConfig::CaCertificateLocation(path) => { if let Ok(buf) = fs::read(path) { let cert = match path.extension() { Some(ext) - if ext.to_ascii_lowercase() == "pem" - || ext.to_ascii_lowercase() == "crt" => + if ext.eq_ignore_ascii_case("pem") + || ext.eq_ignore_ascii_case("crt") => { let pem_cert = rustls_pemfile::certs(&mut buf.as_slice())?; if pem_cert.len() != 1 { @@ -94,7 +106,7 @@ impl TlsStream { Certificate(pem_cert.into_iter().next().unwrap()) } - Some(ext) if ext.to_ascii_lowercase() == "der" => { + Some(ext) if ext.eq_ignore_ascii_case("der") => { Certificate(buf) } Some(_) | None => return Err(crate::Error::Io { @@ -134,6 +146,12 @@ impl TlsStream { } }; + if matches!(config.encryption, crate::EncryptionLevel::Strict) { + client_config + .alpn_protocols + .push(super::TDS_ALPN_PROTOCOL_NAME.as_bytes().to_vec()); + } + let connector = TlsConnector::from(Arc::new(client_config)); let tls_stream = connector diff --git a/src/lib.rs b/src/lib.rs index 882f5ad3..60df87bc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -250,6 +250,16 @@ #![doc(test(attr(deny(rust_2018_idioms, warnings))))] #![doc(test(attr(allow(unused_extern_crates, unused_variables))))] +#[cfg(all( + feature = "tds80", + not(any( + feature = "rustls", + feature = "native-tls", + feature = "vendored-openssl" + )) +))] +compile_error!("The `tds80` feature requires one of the TLS features to be enabled."); + #[cfg(feature = "bigdecimal")] pub(crate) extern crate bigdecimal_ as bigdecimal; diff --git a/src/tds.rs b/src/tds.rs index f4b6f925..804b05c7 100644 --- a/src/tds.rs +++ b/src/tds.rs @@ -25,6 +25,31 @@ uint_enum! { NotSupported = 2, /// Encrypt everything and fail if not possible Required = 3, + /// Start encryption before TDS prelogin and encrypt everything, fail if not possible + Strict = 4, } } + +impl EncryptionLevel { + pub(crate) fn as_wire_value(&self) -> u8 { + match self { + EncryptionLevel::Strict => EncryptionLevel::Required as u8, + other => *other as u8, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn encryption_level_as_wire_value() { + assert_eq!(EncryptionLevel::Off.as_wire_value(), 0); + assert_eq!(EncryptionLevel::On.as_wire_value(), 1); + assert_eq!(EncryptionLevel::NotSupported.as_wire_value(), 2); + assert_eq!(EncryptionLevel::Required.as_wire_value(), 3); + assert_eq!(EncryptionLevel::Strict.as_wire_value(), 3); + } +} diff --git a/src/tds/codec/login.rs b/src/tds/codec/login.rs index 265db381..dcd11b60 100644 --- a/src/tds/codec/login.rs +++ b/src/tds/codec/login.rs @@ -183,10 +183,59 @@ impl<'a> LoginMessage<'a> { option_flags_2: OptionFlag2::InitLangFatal | OptionFlag2::OdbcDriver, option_flags_3: BitFlags::from_flag(OptionFlag3::UnknownCollationHandling), app_name: "tiberius".into(), + hostname: Self::get_hostname(), ..Default::default() } } + fn get_hostname() -> Cow<'static, str> { + #[cfg(windows)] + fn get_computer_name() -> io::Result { + unsafe extern "system" { + // https://learn.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-getcomputernamew + fn GetComputerNameW(lpBuffer: *mut u16, nSize: *mut u32) -> i32; + } + + // MAX_COMPUTERNAME_LENGTH is 15 and we need 1 byte for the null terminator + let mut buffer = [0u16; 15 + 1]; + let mut size = buffer.len() as u32; + let result = unsafe { GetComputerNameW(buffer.as_mut_ptr(), &mut size) }; + if result == 0 { + let lerr = io::Error::last_os_error(); + tracing::error!("GetComputerNameW failed: {lerr}"); + Err(lerr) + } else { + Ok(String::from_utf16_lossy(&buffer[..size as usize])) + } + } + + #[cfg(target_family = "unix")] + fn get_computer_name() -> io::Result { + // Extract from the man page of gethostname(): + // gethostname() returns the null-terminated hostname in the character array name, which has a length of len bytes. If the null-terminated hostname is too large to fit, then the + // name is truncated, and no error is returned (but see NOTES below). POSIX.1 says that if such truncation occurs, then it is unspecified whether the returned buffer includes a + // terminating null byte. + + let mut buffer = [0u8; 255 + 1]; + let result = + unsafe { libc::gethostname(buffer.as_mut_ptr(), buffer.len() as libc::size_t) }; + if result != 0 { + let lerr = io::Error::last_os_error(); + tracing::error!("gethostname failed: {lerr}"); + Err(lerr) + } else { + // Since the buffer *MAY* or *MAY NOT* be null-terminated, we need to either + // find the first null-byte or assume the entire buffer is the host name + match buffer.split(|b| *b == 0).next() { + Some(hostname) => Ok(String::from_utf8_lossy(hostname).into_owned()), + None => Ok(String::from_utf8_lossy(&buffer).into_owned()), + } + } + } + + get_computer_name().map(Cow::Owned).unwrap_or_default() + } + #[cfg(any(all(unix, feature = "integrated-auth-gssapi"), windows))] pub fn integrated_security(&mut self, bytes: Option>) { if bytes.is_some() { @@ -240,6 +289,10 @@ impl<'a> LoginMessage<'a> { self.type_flags.remove(LoginTypeFlag::ReadOnlyIntent); } } + + pub fn hostname(&mut self, hostname: impl Into>) { + self.hostname = hostname.into(); + } } impl<'a> Encode for LoginMessage<'a> { diff --git a/src/tds/codec/pre_login.rs b/src/tds/codec/pre_login.rs index eb4c27e6..da0b585f 100644 --- a/src/tds/codec/pre_login.rs +++ b/src/tds/codec/pre_login.rs @@ -72,6 +72,7 @@ impl PreloginMessage { | (EncryptionLevel::On, EncryptionLevel::NotSupported) => { panic!("Server does not allow the requested encryption level.") } + (EncryptionLevel::Strict, _) => EncryptionLevel::Strict, (_, _) => EncryptionLevel::On, } } @@ -110,7 +111,7 @@ impl Encode for PreloginMessage { // encryption fields.push((PRELOGIN_ENCRYPTION, 0x01)); // encryption - data_cursor.write_u8(self.encryption as u8)?; + data_cursor.write_u8(self.encryption.as_wire_value())?; // threadid fields.push((PRELOGIN_THREADID, 0x04)); // thread id