Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ async-trait = "0.1"
connection-string = "0.2"
num-traits = "0.2"
uuid = "1.0"
zeroize = "1.8.2"

[target.'cfg(windows)'.dependencies]
winauth = { version = "0.0.4", optional = true }
Expand Down
35 changes: 27 additions & 8 deletions src/client/auth.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
use std::fmt::Debug;
use zeroize::Zeroizing;

#[derive(Clone, PartialEq, Eq)]
pub struct SqlServerAuth {
user: String,
password: String,
password: Zeroizing<String>,
}

impl SqlServerAuth {
pub(crate) fn user(&self) -> &str {
&self.user
}

pub(crate) fn password(&self) -> &str {
&self.password
pub(crate) fn into_credentials(self) -> (String, Zeroizing<String>) {
(self.user, self.password)
}
}

Expand Down Expand Up @@ -79,7 +76,7 @@ impl AuthMethod {
pub fn sql_server(user: impl ToString, password: impl ToString) -> Self {
Self::SqlServer(SqlServerAuth {
user: user.to_string(),
password: password.to_string(),
password: Zeroizing::new(password.to_string()),
})
}

Expand All @@ -104,3 +101,25 @@ impl AuthMethod {
Self::AADToken(token.to_string())
}
}

#[cfg(test)]
mod tests {
use super::AuthMethod;
use zeroize::Zeroize;

#[test]
fn sql_server_password_can_be_consumed_and_zeroized() {
let AuthMethod::SqlServer(auth) = AuthMethod::sql_server("sa", "secret") else {
unreachable!();
};

let (user, mut password) = auth.into_credentials();

assert_eq!("sa", user);
assert_eq!("secret", password.as_str());

password.zeroize();

assert!(password.is_empty());
}
}
53 changes: 49 additions & 4 deletions src/client/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use asynchronous_codec::Framed;
use bytes::BytesMut;
#[cfg(any(windows, feature = "integrated-auth-gssapi"))]
use codec::TokenSspi;
use futures_util::io::{AsyncRead, AsyncWrite};
use futures_util::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use futures_util::ready;
use futures_util::sink::SinkExt;
use futures_util::stream::{Stream, TryStream, TryStreamExt};
Expand All @@ -39,6 +39,7 @@ use task::Poll;
use tracing::{event, Level};
#[cfg(all(windows, feature = "winauth"))]
use winauth::{windows::NtlmSspiBuilder, NextBytes};
use zeroize::{Zeroize, Zeroizing};

/// A `Connection` is an abstraction between the [`Client`] and the server. It
/// can be used as a `Stream` to fetch [`Packet`]s from and to `send` packets
Expand Down Expand Up @@ -196,6 +197,45 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Connection<S> {
Ok(())
}

async fn send_sensitive_login(
&mut self,
mut header: PacketHeader,
mut payload: Zeroizing<Vec<u8>>,
) -> crate::Result<()> {
self.flushed = false;
let packet_size = (self.context.packet_size() as usize) - HEADER_BYTES;
let mut offset = 0;

while offset < payload.len() {
let end = cmp::min(payload.len(), offset + packet_size);

if end == payload.len() {
header.set_status(PacketStatus::EndOfMessage);
} else {
header.set_status(PacketStatus::NormalMessage);
}

let mut frame = Zeroizing::new(Vec::with_capacity(HEADER_BYTES + end - offset));
header.encode(&mut *frame)?;
frame.extend_from_slice(&payload[offset..end]);

let size = (frame.len() as u16).to_be_bytes();
frame[2] = size[0];
frame[3] = size[1];

event!(Level::TRACE, "Sending a packet ({} bytes)", frame.len(),);

(&mut *self.transport).write_all(frame.as_slice()).await?;
frame.zeroize();
payload[offset..end].zeroize();
offset = end;
}

(&mut *self.transport).flush().await?;

Ok(())
}

/// Sends a packet of data to the database.
///
/// # Warning
Expand Down Expand Up @@ -415,11 +455,16 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Connection<S> {
self = self.post_login_encryption(encryption);
}
AuthMethod::SqlServer(auth) => {
login_message.user_name(auth.user());
login_message.password(auth.password());
let (user, mut password) = auth.into_credentials();

login_message.user_name(user);
login_message.password(password.as_str());
let payload = login_message.encode_to_vec()?;
password.zeroize();

let id = self.context.next_packet_id();
self.send(PacketHeader::login(id), login_message).await?;
self.send_sensitive_login(PacketHeader::login(id), payload)
.await?;
self = self.post_login_encryption(encryption);
}
AuthMethod::AADToken(token) => {
Expand Down
18 changes: 13 additions & 5 deletions src/tds/codec/login.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use enumflags2::{bitflags, BitFlags};
use io::{Cursor, Write};
use std::fmt::Debug;
use std::{borrow::Cow, io};
use zeroize::{Zeroize, Zeroizing};

uint_enum! {
#[repr(u32)]
Expand Down Expand Up @@ -240,10 +241,8 @@ impl<'a> LoginMessage<'a> {
self.type_flags.remove(LoginTypeFlag::ReadOnlyIntent);
}
}
}

impl<'a> Encode<BytesMut> for LoginMessage<'a> {
fn encode(self, dst: &mut BytesMut) -> crate::Result<()> {
pub(crate) fn encode_to_vec(self) -> crate::Result<Zeroizing<Vec<u8>>> {
let mut cursor = Cursor::new(Vec::with_capacity(512));

// Space for the length
Expand Down Expand Up @@ -366,7 +365,7 @@ impl<'a> Encode<BytesMut> for LoginMessage<'a> {
for codepoint in fed_auth_ext.fed_auth_token.encode_utf16() {
token.write_u16::<LittleEndian>(codepoint)?;
}
let token = token.into_inner();
let mut token = token.into_inner();

// options (1) + TokenLength(4) + Token.length + nonce.length
let feature_ext_length =
Expand All @@ -383,6 +382,7 @@ impl<'a> Encode<BytesMut> for LoginMessage<'a> {

cursor.write_u32::<LittleEndian>(token.len() as u32)?;
cursor.write_all(token.as_slice())?;
token.zeroize();

if let Some(nonce) = fed_auth_ext.nonce {
cursor.write_all(nonce.as_ref())?;
Expand All @@ -394,7 +394,15 @@ impl<'a> Encode<BytesMut> for LoginMessage<'a> {
cursor.set_position(0);
cursor.write_u32::<LittleEndian>(cursor.get_ref().len() as u32)?;

dst.extend(cursor.into_inner());
Ok(Zeroizing::new(cursor.into_inner()))
}
}

impl<'a> Encode<BytesMut> for LoginMessage<'a> {
fn encode(self, dst: &mut BytesMut) -> crate::Result<()> {
let mut encoded = self.encode_to_vec()?;
dst.extend_from_slice(encoded.as_slice());
encoded.zeroize();

Ok(())
}
Expand Down