diff --git a/Cargo.toml b/Cargo.toml index 0caaac815..d0b0b9107 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,6 +51,7 @@ async-trait = "0.1" connection-string = "0.2" num-traits = "0.2" uuid = "1.0" +zeroize = "1.8.2" [target.'cfg(windows)'.dependencies] winauth = { version = "0.0.4", optional = true } diff --git a/src/client/auth.rs b/src/client/auth.rs index 208d8d060..3abf42df8 100644 --- a/src/client/auth.rs +++ b/src/client/auth.rs @@ -1,18 +1,15 @@ use std::fmt::Debug; +use zeroize::Zeroizing; #[derive(Clone, PartialEq, Eq)] pub struct SqlServerAuth { user: String, - password: String, + password: Zeroizing, } impl SqlServerAuth { - pub(crate) fn user(&self) -> &str { - &self.user - } - - pub(crate) fn password(&self) -> &str { - &self.password + pub(crate) fn into_credentials(self) -> (String, Zeroizing) { + (self.user, self.password) } } @@ -79,7 +76,7 @@ impl AuthMethod { pub fn sql_server(user: impl ToString, password: impl ToString) -> Self { Self::SqlServer(SqlServerAuth { user: user.to_string(), - password: password.to_string(), + password: Zeroizing::new(password.to_string()), }) } @@ -104,3 +101,25 @@ impl AuthMethod { Self::AADToken(token.to_string()) } } + +#[cfg(test)] +mod tests { + use super::AuthMethod; + use zeroize::Zeroize; + + #[test] + fn sql_server_password_can_be_consumed_and_zeroized() { + let AuthMethod::SqlServer(auth) = AuthMethod::sql_server("sa", "secret") else { + unreachable!(); + }; + + let (user, mut password) = auth.into_credentials(); + + assert_eq!("sa", user); + assert_eq!("secret", password.as_str()); + + password.zeroize(); + + assert!(password.is_empty()); + } +} diff --git a/src/client/connection.rs b/src/client/connection.rs index 09d372561..9cfa6044c 100644 --- a/src/client/connection.rs +++ b/src/client/connection.rs @@ -20,7 +20,7 @@ use asynchronous_codec::Framed; use bytes::BytesMut; #[cfg(any(windows, feature = "integrated-auth-gssapi"))] use codec::TokenSspi; -use futures_util::io::{AsyncRead, AsyncWrite}; +use futures_util::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use futures_util::ready; use futures_util::sink::SinkExt; use futures_util::stream::{Stream, TryStream, TryStreamExt}; @@ -39,6 +39,7 @@ use task::Poll; use tracing::{event, Level}; #[cfg(all(windows, feature = "winauth"))] use winauth::{windows::NtlmSspiBuilder, NextBytes}; +use zeroize::{Zeroize, Zeroizing}; /// A `Connection` is an abstraction between the [`Client`] and the server. It /// can be used as a `Stream` to fetch [`Packet`]s from and to `send` packets @@ -196,6 +197,45 @@ impl Connection { Ok(()) } + async fn send_sensitive_login( + &mut self, + mut header: PacketHeader, + mut payload: Zeroizing>, + ) -> crate::Result<()> { + self.flushed = false; + let packet_size = (self.context.packet_size() as usize) - HEADER_BYTES; + let mut offset = 0; + + while offset < payload.len() { + let end = cmp::min(payload.len(), offset + packet_size); + + if end == payload.len() { + header.set_status(PacketStatus::EndOfMessage); + } else { + header.set_status(PacketStatus::NormalMessage); + } + + let mut frame = Zeroizing::new(Vec::with_capacity(HEADER_BYTES + end - offset)); + header.encode(&mut *frame)?; + frame.extend_from_slice(&payload[offset..end]); + + let size = (frame.len() as u16).to_be_bytes(); + frame[2] = size[0]; + frame[3] = size[1]; + + event!(Level::TRACE, "Sending a packet ({} bytes)", frame.len(),); + + (&mut *self.transport).write_all(frame.as_slice()).await?; + frame.zeroize(); + payload[offset..end].zeroize(); + offset = end; + } + + (&mut *self.transport).flush().await?; + + Ok(()) + } + /// Sends a packet of data to the database. /// /// # Warning @@ -415,11 +455,16 @@ impl Connection { self = self.post_login_encryption(encryption); } AuthMethod::SqlServer(auth) => { - login_message.user_name(auth.user()); - login_message.password(auth.password()); + let (user, mut password) = auth.into_credentials(); + + login_message.user_name(user); + login_message.password(password.as_str()); + let payload = login_message.encode_to_vec()?; + password.zeroize(); let id = self.context.next_packet_id(); - self.send(PacketHeader::login(id), login_message).await?; + self.send_sensitive_login(PacketHeader::login(id), payload) + .await?; self = self.post_login_encryption(encryption); } AuthMethod::AADToken(token) => { diff --git a/src/tds/codec/login.rs b/src/tds/codec/login.rs index 265db381e..16e14d1e7 100644 --- a/src/tds/codec/login.rs +++ b/src/tds/codec/login.rs @@ -5,6 +5,7 @@ use enumflags2::{bitflags, BitFlags}; use io::{Cursor, Write}; use std::fmt::Debug; use std::{borrow::Cow, io}; +use zeroize::{Zeroize, Zeroizing}; uint_enum! { #[repr(u32)] @@ -240,10 +241,8 @@ impl<'a> LoginMessage<'a> { self.type_flags.remove(LoginTypeFlag::ReadOnlyIntent); } } -} -impl<'a> Encode for LoginMessage<'a> { - fn encode(self, dst: &mut BytesMut) -> crate::Result<()> { + pub(crate) fn encode_to_vec(self) -> crate::Result>> { let mut cursor = Cursor::new(Vec::with_capacity(512)); // Space for the length @@ -366,7 +365,7 @@ impl<'a> Encode for LoginMessage<'a> { for codepoint in fed_auth_ext.fed_auth_token.encode_utf16() { token.write_u16::(codepoint)?; } - let token = token.into_inner(); + let mut token = token.into_inner(); // options (1) + TokenLength(4) + Token.length + nonce.length let feature_ext_length = @@ -383,6 +382,7 @@ impl<'a> Encode for LoginMessage<'a> { cursor.write_u32::(token.len() as u32)?; cursor.write_all(token.as_slice())?; + token.zeroize(); if let Some(nonce) = fed_auth_ext.nonce { cursor.write_all(nonce.as_ref())?; @@ -394,7 +394,15 @@ impl<'a> Encode for LoginMessage<'a> { cursor.set_position(0); cursor.write_u32::(cursor.get_ref().len() as u32)?; - dst.extend(cursor.into_inner()); + Ok(Zeroizing::new(cursor.into_inner())) + } +} + +impl<'a> Encode for LoginMessage<'a> { + fn encode(self, dst: &mut BytesMut) -> crate::Result<()> { + let mut encoded = self.encode_to_vec()?; + dst.extend_from_slice(encoded.as_slice()); + encoded.zeroize(); Ok(()) }