Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ async-trait = "0.1"
connection-string = "0.2"
num-traits = "0.2"
uuid = "1.0"
zeroize = "1.8.2"

[target.'cfg(windows)'.dependencies]
winauth = { version = "0.0.4", optional = true }
Expand Down
35 changes: 27 additions & 8 deletions src/client/auth.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
use std::fmt::Debug;
use zeroize::Zeroizing;

#[derive(Clone, PartialEq, Eq)]
pub struct SqlServerAuth {
user: String,
password: String,
password: Zeroizing<String>,
}

impl SqlServerAuth {
pub(crate) fn user(&self) -> &str {
&self.user
}

pub(crate) fn password(&self) -> &str {
&self.password
pub(crate) fn into_credentials(self) -> (String, Zeroizing<String>) {
(self.user, self.password)
}
}

Expand Down Expand Up @@ -79,7 +76,7 @@ impl AuthMethod {
pub fn sql_server(user: impl ToString, password: impl ToString) -> Self {
Self::SqlServer(SqlServerAuth {
user: user.to_string(),
password: password.to_string(),
password: Zeroizing::new(password.to_string()),
})
}

Expand All @@ -104,3 +101,25 @@ impl AuthMethod {
Self::AADToken(token.to_string())
}
}

#[cfg(test)]
mod tests {
use super::AuthMethod;
use zeroize::Zeroize;

#[test]
fn sql_server_password_can_be_consumed_and_zeroized() {
let AuthMethod::SqlServer(auth) = AuthMethod::sql_server("sa", "secret") else {
unreachable!();
};

let (user, mut password) = auth.into_credentials();

assert_eq!("sa", user);
assert_eq!("secret", password.as_str());

password.zeroize();

assert!(password.is_empty());
}
}
53 changes: 49 additions & 4 deletions src/client/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use asynchronous_codec::Framed;
use bytes::BytesMut;
#[cfg(any(windows, feature = "integrated-auth-gssapi"))]
use codec::TokenSspi;
use futures_util::io::{AsyncRead, AsyncWrite};
use futures_util::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use futures_util::ready;
use futures_util::sink::SinkExt;
use futures_util::stream::{Stream, TryStream, TryStreamExt};
Expand All @@ -39,6 +39,7 @@ use task::Poll;
use tracing::{event, Level};
#[cfg(all(windows, feature = "winauth"))]
use winauth::{windows::NtlmSspiBuilder, NextBytes};
use zeroize::{Zeroize, Zeroizing};

/// A `Connection` is an abstraction between the [`Client`] and the server. It
/// can be used as a `Stream` to fetch [`Packet`]s from and to `send` packets
Expand Down Expand Up @@ -196,6 +197,46 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Connection<S> {
Ok(())
}

async fn send_sensitive_login<'a>(
&mut self,
mut header: PacketHeader,
item: LoginMessage<'a>,
) -> crate::Result<()> {
self.flushed = false;
let packet_size = (self.context.packet_size() as usize) - HEADER_BYTES;
let mut payload = item.encode_to_vec()?;
let mut offset = 0;

while offset < payload.len() {
let end = cmp::min(payload.len(), offset + packet_size);

if end == payload.len() {
header.set_status(PacketStatus::EndOfMessage);
} else {
header.set_status(PacketStatus::NormalMessage);
}

let mut frame = Zeroizing::new(Vec::with_capacity(HEADER_BYTES + end - offset));
header.encode(&mut *frame)?;
frame.extend_from_slice(&payload[offset..end]);

let size = (frame.len() as u16).to_be_bytes();
frame[2] = size[0];
frame[3] = size[1];

event!(Level::TRACE, "Sending a packet ({} bytes)", frame.len(),);

(&mut *self.transport).write_all(frame.as_slice()).await?;
frame.zeroize();
payload[offset..end].zeroize();
offset = end;
}

(&mut *self.transport).flush().await?;

Ok(())
}

/// Sends a packet of data to the database.
///
/// # Warning
Expand Down Expand Up @@ -415,11 +456,15 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> Connection<S> {
self = self.post_login_encryption(encryption);
}
AuthMethod::SqlServer(auth) => {
login_message.user_name(auth.user());
login_message.password(auth.password());
let (user, mut password) = auth.into_credentials();

login_message.user_name(user);
login_message.password(password.as_str());

let id = self.context.next_packet_id();
self.send(PacketHeader::login(id), login_message).await?;
self.send_sensitive_login(PacketHeader::login(id), login_message)
.await?;
password.zeroize();
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
self = self.post_login_encryption(encryption);
}
AuthMethod::AADToken(token) => {
Expand Down
18 changes: 13 additions & 5 deletions src/tds/codec/login.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use enumflags2::{bitflags, BitFlags};
use io::{Cursor, Write};
use std::fmt::Debug;
use std::{borrow::Cow, io};
use zeroize::{Zeroize, Zeroizing};

uint_enum! {
#[repr(u32)]
Expand Down Expand Up @@ -240,10 +241,8 @@ impl<'a> LoginMessage<'a> {
self.type_flags.remove(LoginTypeFlag::ReadOnlyIntent);
}
}
}

impl<'a> Encode<BytesMut> for LoginMessage<'a> {
fn encode(self, dst: &mut BytesMut) -> crate::Result<()> {
pub(crate) fn encode_to_vec(self) -> crate::Result<Zeroizing<Vec<u8>>> {
let mut cursor = Cursor::new(Vec::with_capacity(512));

// Space for the length
Expand Down Expand Up @@ -366,7 +365,7 @@ impl<'a> Encode<BytesMut> for LoginMessage<'a> {
for codepoint in fed_auth_ext.fed_auth_token.encode_utf16() {
token.write_u16::<LittleEndian>(codepoint)?;
}
let token = token.into_inner();
let mut token = token.into_inner();

// options (1) + TokenLength(4) + Token.length + nonce.length
let feature_ext_length =
Expand All @@ -383,6 +382,7 @@ impl<'a> Encode<BytesMut> for LoginMessage<'a> {

cursor.write_u32::<LittleEndian>(token.len() as u32)?;
cursor.write_all(token.as_slice())?;
token.zeroize();

if let Some(nonce) = fed_auth_ext.nonce {
cursor.write_all(nonce.as_ref())?;
Expand All @@ -394,7 +394,15 @@ impl<'a> Encode<BytesMut> for LoginMessage<'a> {
cursor.set_position(0);
cursor.write_u32::<LittleEndian>(cursor.get_ref().len() as u32)?;

dst.extend(cursor.into_inner());
Ok(Zeroizing::new(cursor.into_inner()))
}
}

impl<'a> Encode<BytesMut> for LoginMessage<'a> {
fn encode(self, dst: &mut BytesMut) -> crate::Result<()> {
let mut encoded = self.encode_to_vec()?;
dst.extend_from_slice(encoded.as_slice());
encoded.zeroize();

Ok(())
}
Expand Down