diff --git a/src/client/mod.rs b/src/client/mod.rs index 8c3f08c6..cf0f1e2e 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -585,19 +585,31 @@ impl ClientState { self.send_nick_password()?; self.send_umodes()?; - let config_chans = self.config().channels(); - for chan in config_chans { - match self.config().channel_key(chan) { - Some(key) => self.send_join_with_keys::<&str, &str>(chan, key)?, - None => self.send_join(chan)?, + // Batch autojoin: keyed channels first, then keyless (RFC 2812 §3.2.1) + let config = self.config(); + let batches = Self::build_batched_joins(config.channels(), &config.channel_keys); + for (chanlist, keylist) in &batches { + match keylist { + Some(keys) => self.send_join_with_keys(chanlist, keys)?, + None => self.send_join(chanlist)?, } } - let joined_chans = self.chanlists.read(); - for chan in joined_chans + + // Re-join previously joined channels not in config + let config_chans = config.channels(); + let rejoin: Vec = self + .chanlists + .read() .keys() .filter(|x| !config_chans.iter().any(|c| c == *x)) - { - self.send_join(chan)? + .cloned() + .collect(); + if !rejoin.is_empty() { + let no_keys = HashMap::new(); + let rejoin_batches = Self::build_batched_joins(&rejoin, &no_keys); + for (chanlist, _) in &rejoin_batches { + self.send_join(chanlist)?; + } } } Command::Response(Response::ERR_NICKNAMEINUSE, _) @@ -806,6 +818,135 @@ impl ClientState { Ok(()) } + /// Builds batched JOIN commands from a list of channels and optional keys. + /// + /// Channels with keys are placed first in each batch so that positional key + /// matching works correctly per RFC 2812 §3.2.1. Commands are split into + /// multiple JOINs when they would exceed the 512-byte IRC line limit. + fn build_batched_joins( + channels: &[String], + channel_keys: &HashMap, + ) -> Vec<(String, Option)> { + if channels.is_empty() { + return Vec::new(); + } + + // Partition into keyed and keyless, preserving config order within groups + let mut keyed: Vec<(&str, &str)> = Vec::new(); + let mut keyless: Vec<&str> = Vec::new(); + for chan in channels { + match channel_keys.get(chan.as_str()) { + Some(key) => keyed.push((chan, key)), + None => keyless.push(chan), + } + } + + // "JOIN " = 5 bytes, "\r\n" = 2 bytes → 505 bytes for payload + const BUDGET: usize = 512 - 7; + + let mut batches: Vec<(String, Option)> = Vec::new(); + let mut batch_chans: Vec<&str> = Vec::new(); + let mut batch_keys: Vec<&str> = Vec::new(); + let mut chan_len: usize = 0; + let mut key_len: usize = 0; + + // Returns total payload size: chanlist [+ " " + keylist] + let payload = |cl: usize, kl: usize, has_keys: bool| -> usize { + if has_keys { + cl + 1 + kl + } else { + cl + } + }; + + // Flush current batch + let flush = |chans: &mut Vec<&str>, + keys: &mut Vec<&str>, + cl: &mut usize, + kl: &mut usize, + out: &mut Vec<(String, Option)>| { + if !chans.is_empty() { + let chanlist = chans.join(","); + let keylist = if keys.is_empty() { + None + } else { + Some(keys.join(",")) + }; + out.push((chanlist, keylist)); + chans.clear(); + keys.clear(); + *cl = 0; + *kl = 0; + } + }; + + // Process keyed channels first (must precede keyless for positional keys) + for (chan, key) in &keyed { + let new_cl = if batch_chans.is_empty() { + chan.len() + } else { + chan_len + 1 + chan.len() + }; + let new_kl = if batch_keys.is_empty() { + key.len() + } else { + key_len + 1 + key.len() + }; + + if !batch_chans.is_empty() && payload(new_cl, new_kl, true) > BUDGET { + flush( + &mut batch_chans, + &mut batch_keys, + &mut chan_len, + &mut key_len, + &mut batches, + ); + chan_len = chan.len(); + key_len = key.len(); + } else { + chan_len = new_cl; + key_len = new_kl; + } + batch_chans.push(chan); + batch_keys.push(key); + } + + // Append keyless channels, filling remaining space in current batch + for chan in &keyless { + let new_cl = if batch_chans.is_empty() { + chan.len() + } else { + chan_len + 1 + chan.len() + }; + let has_keys = !batch_keys.is_empty(); + + if !batch_chans.is_empty() && payload(new_cl, key_len, has_keys) > BUDGET { + flush( + &mut batch_chans, + &mut batch_keys, + &mut chan_len, + &mut key_len, + &mut batches, + ); + chan_len = chan.len(); + } else { + chan_len = new_cl; + } + batch_chans.push(chan); + } + + // Flush remaining + flush( + &mut batch_chans, + &mut batch_keys, + &mut chan_len, + &mut key_len, + &mut batches, + ); + + batches + } + pub_state_base!(); } @@ -1093,7 +1234,7 @@ impl Client { mod test { use std::{collections::HashMap, default::Default, thread, time::Duration}; - use super::Client; + use super::{Client, ClientState}; #[cfg(feature = "channel-lists")] use crate::client::data::User; use crate::{ @@ -1163,10 +1304,7 @@ mod test { }) .await?; client.stream()?.collect().await?; - assert_eq!( - &get_client_value(client)[..], - "JOIN #test\r\nJOIN #test2\r\n" - ); + assert_eq!(&get_client_value(client)[..], "JOIN #test,#test2\r\n"); Ok(()) } @@ -1183,8 +1321,7 @@ mod test { client.stream()?.collect().await?; assert_eq!( &get_client_value(client)[..], - "NICKSERV IDENTIFY password\r\nJOIN #test\r\n\ - JOIN #test2\r\n" + "NICKSERV IDENTIFY password\r\nJOIN #test,#test2\r\n" ); Ok(()) } @@ -1207,7 +1344,7 @@ mod test { client.stream()?.collect().await?; assert_eq!( &get_client_value(client)[..], - "JOIN #test\r\nJOIN #test2 password\r\n" + "JOIN #test2,#test password\r\n" ); Ok(()) } @@ -1230,7 +1367,7 @@ mod test { assert_eq!( &get_client_value(client)[..], "NICK test2\r\nNICKSERV GHOST test password\r\n\ - NICK test\r\nNICKSERV IDENTIFY password\r\nJOIN #test\r\nJOIN #test2\r\n" + NICK test\r\nNICKSERV IDENTIFY password\r\nJOIN #test,#test2\r\n" ); Ok(()) } @@ -1255,7 +1392,7 @@ mod test { &get_client_value(client)[..], "NICK test2\r\nNICKSERV RECOVER test password\ \r\nNICKSERV RELEASE test password\r\nNICK test\r\nNICKSERV IDENTIFY password\ - \r\nJOIN #test\r\nJOIN #test2\r\n" + \r\nJOIN #test,#test2\r\n" ); Ok(()) } @@ -1274,7 +1411,7 @@ mod test { client.stream()?.collect().await?; assert_eq!( &get_client_value(client)[..], - "MODE test +B\r\nJOIN #test\r\nJOIN #test2\r\n" + "MODE test +B\r\nJOIN #test,#test2\r\n" ); Ok(()) } @@ -1974,4 +2111,88 @@ mod test { ); Ok(()) } + + #[test] + fn batch_joins_all_keyless() { + let chans: Vec = vec!["#a".into(), "#b".into(), "#c".into()]; + let keys = HashMap::new(); + let batches = ClientState::build_batched_joins(&chans, &keys); + assert_eq!(batches.len(), 1); + assert_eq!(batches[0].0, "#a,#b,#c"); + assert!(batches[0].1.is_none()); + } + + #[test] + fn batch_joins_keyed_first() { + let chans: Vec = vec!["#plain".into(), "#secret".into(), "#open".into()]; + let mut keys = HashMap::new(); + keys.insert("#secret".to_string(), "pass".to_string()); + let batches = ClientState::build_batched_joins(&chans, &keys); + assert_eq!(batches.len(), 1); + assert_eq!(batches[0].0, "#secret,#plain,#open"); + assert_eq!(batches[0].1.as_deref(), Some("pass")); + } + + #[test] + fn batch_joins_multiple_keys() { + let chans: Vec = vec!["#a".into(), "#b".into(), "#c".into(), "#d".into()]; + let mut keys = HashMap::new(); + keys.insert("#b".to_string(), "kb".to_string()); + keys.insert("#d".to_string(), "kd".to_string()); + let batches = ClientState::build_batched_joins(&chans, &keys); + assert_eq!(batches.len(), 1); + // Keyed channels first (preserving config order: #b before #d), then keyless + assert_eq!(batches[0].0, "#b,#d,#a,#c"); + assert_eq!(batches[0].1.as_deref(), Some("kb,kd")); + } + + #[test] + fn batch_joins_empty() { + let chans: Vec = Vec::new(); + let keys = HashMap::new(); + let batches = ClientState::build_batched_joins(&chans, &keys); + assert!(batches.is_empty()); + } + + #[test] + fn batch_joins_respects_line_limit() { + // Create channels that exceed 512-byte line limit when combined + // "JOIN " (5) + channels + "\r\n" (2) must be <= 512, so payload <= 505 + // Each channel is ~50 chars, so ~10 channels per batch + let chans: Vec = (0..15) + .map(|i| format!("#channel-with-a-long-name-for-testing-{:02}", i)) + .collect(); + let keys = HashMap::new(); + let batches = ClientState::build_batched_joins(&chans, &keys); + assert!(batches.len() > 1, "should split into multiple batches"); + for (chanlist, keylist) in &batches { + // "JOIN " + chanlist + "\r\n" must fit in 512 + let line_len = 5 + chanlist.len() + 2; + assert!(line_len <= 512, "batch line too long: {} bytes", line_len); + assert!(keylist.is_none()); + } + // All channels should be present + let all_chans: Vec<&str> = batches.iter().flat_map(|(cl, _)| cl.split(',')).collect(); + assert_eq!(all_chans.len(), 15); + } + + #[test] + fn batch_joins_keyed_with_line_limit() { + // Keyed channels that exceed the limit with their keys + let chans: Vec = (0..15) + .map(|i| format!("#keyed-channel-long-name-{:02}", i)) + .collect(); + let keys: HashMap = chans + .iter() + .map(|c| (c.clone(), "a-somewhat-long-key-value".to_string())) + .collect(); + let batches = ClientState::build_batched_joins(&chans, &keys); + assert!(batches.len() > 1, "should split into multiple batches"); + for (ref chanlist, ref keylist) in &batches { + let kl = keylist.as_ref().expect("all keyed"); + // "JOIN " + chanlist + " " + keylist + "\r\n" must fit in 512 + let line_len: usize = 5 + chanlist.len() + 1 + kl.len() + 2; + assert!(line_len <= 512, "batch line too long: {} bytes", line_len); + } + } }