From 6495688cbc207354c42e4f1b1ed9f1d9e7012255 Mon Sep 17 00:00:00 2001 From: Lethe Date: Wed, 6 Aug 2025 16:08:59 +0000 Subject: [PATCH 1/6] feat(cli): CLI for client, check chunk integrity --- Cargo.lock | 76 +++++++++++++++++++++++++++++ Cargo.toml | 2 + mine.txt | 39 --------------- src/bin/client.rs | 95 ++++++++++++++++++++++++++++++++++++ src/file.rs | 52 +++++++++++++++++--- src/lib.rs | 7 +++ src/main.rs | 38 ++------------- src/plan.rs | 42 +++++++++------- src/protocol/wire/packets.rs | 6 +++ test.txt | 0 10 files changed, 261 insertions(+), 96 deletions(-) delete mode 100644 mine.txt create mode 100644 src/bin/client.rs create mode 100644 src/lib.rs delete mode 100644 test.txt diff --git a/Cargo.lock b/Cargo.lock index f8b2365..edcf6f8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -52,6 +52,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "anyhow" +version = "1.0.98" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" + [[package]] name = "arrayref" version = "0.3.9" @@ -276,6 +282,27 @@ dependencies = [ "crypto-common", ] +[[package]] +name = "directories" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16f5094c54661b38d03bd7e50df373292118db60b585c08a411c6d840017fe7d" +dependencies = [ + "dirs-sys", +] + +[[package]] +name = "dirs-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e01a3366d27ee9890022452ee61b2b63a67e6f13f58900b651ff5665f0bb1fab" +dependencies = [ + "libc", + "option-ext", + "redox_users", + "windows-sys 0.60.2", +] + [[package]] name = "ed25519" version = "2.2.3" @@ -437,6 +464,16 @@ version = "0.2.174" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1171693293099992e19cddea4e8b849964e9846f4acee11b3948bcc337be8776" +[[package]] +name = "libredox" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "391290121bad3d37fbddad76d8f5d1c1c314cfc646d143d7e07a3086ddff0ce3" +dependencies = [ + "bitflags", + "libc", +] + [[package]] name = "linux-raw-sys" version = "0.9.4" @@ -517,6 +554,12 @@ version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a4895175b425cb1f87721b59f0f286c2092bd4af812243672510e1ac53e2e0ad" +[[package]] +name = "option-ext" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" + [[package]] name = "page_size" version = "0.6.0" @@ -622,6 +665,17 @@ name = "raptorq" version = "2.0.0" source = "git+https://github.com/Lethe10137/raptorq.git?branch=master#44b18bbca9972578555ba82b942aa5163c0ba712" +[[package]] +name = "redox_users" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4e608c6638b9c18977b00b475ac1f28d14e84b27d8d42f70e0bf1e3dec127ac" +dependencies = [ + "getrandom 0.2.16", + "libredox", + "thiserror", +] + [[package]] name = "rustc_version" version = "0.4.1" @@ -778,6 +832,26 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "thiserror" +version = "2.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "toml" version = "0.9.4" @@ -850,11 +924,13 @@ checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" name = "usync" version = "0.1.0" dependencies = [ + "anyhow", "base64", "blake3", "bytes", "clap", "crc", + "directories", "ed25519-dalek", "flume", "hex", diff --git a/Cargo.toml b/Cargo.toml index b50b7d9..9af8841 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,8 @@ memmap2 = "0.9.7" page_size = "0.6.0" clap = { version = "4.5.42", features = ["derive"] } serde = { version = "1.0.219", features = ["derive"] } +directories = "6.0.0" +anyhow = "1.0.98" [dev-dependencies] diff --git a/mine.txt b/mine.txt deleted file mode 100644 index 788ca9c..0000000 --- a/mine.txt +++ /dev/null @@ -1,39 +0,0 @@ - Compiling usync v0.1.0 (/home/lethe/thu/working/udp_sender) -warning: unused imports: `Decoder`, `EncodingPacket`, and `calculate_block_offsets` - --> src/bin/slice_raptorq.rs:3:5 - | -3 | Decoder, Encoder, EncodingPacket, ObjectTransmissionInformation, calculate_block_offsets, - | ^^^^^^^ ^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^ - | - = note: `#[warn(unused_imports)]` on by default - -warning: unused import: `BLOCK_LEN` - --> src/bin/slice_raptorq.rs:6:20 - | -6 | use blake3::{self, BLOCK_LEN}; - | ^^^^^^^^^ - -warning: unused import: `zerocopy::IntoBytes` - --> src/bin/slice_raptorq.rs:7:5 - | -7 | use zerocopy::IntoBytes; - | ^^^^^^^^^^^^^^^^^^^ - -warning: variable `k` is assigned to, but never used - --> src/bin/slice_raptorq.rs:31:13 - | -31 | let mut k = 0; - | ^ - | - = note: consider using `_k` instead - = note: `#[warn(unused_variables)]` on by default - -warning: `usync` (bin "slice_raptorq") generated 4 warnings (run `cargo fix --bin "slice_raptorq"` to apply 3 suggestions) - Finished `release` profile [optimized] target(s) in 0.25s - Running `target/release/slice_raptorq` -[src/bin/slice_raptorq.rs:62:5] hash1 = Hash( - "a5a7d35b3354bf6d8f3eb5b2064106f6dfa9565883d8961f6e6daa6b996060c9", -) -[src/bin/slice_raptorq.rs:62:5] hash2 = Hash( - "a5a7d35b3354bf6d8f3eb5b2064106f6dfa9565883d8961f6e6daa6b996060c9", -) diff --git a/src/bin/client.rs b/src/bin/client.rs new file mode 100644 index 0000000..859101c --- /dev/null +++ b/src/bin/client.rs @@ -0,0 +1,95 @@ +use anyhow::anyhow; +use clap::Parser; +use directories::UserDirs; +use std::{fs, path::PathBuf}; +use zerocopy::IntoBytes; + +use usync::{ + file::{check_file_exist, mmap_segment}, + plan::{FileChunk, FileConfig}, +}; + +#[derive(Parser, Debug)] +#[command(author, version, about = "Client for receiving file", long_about = None)] +struct Args { + /// The path to the plan file (TOML format). + #[arg(short, long, value_name = "PLAN_FILE")] + plan_file: PathBuf, + + /// The path to the downloading file (optional, in your download folder as default). + #[arg(short, long, value_name = "DOWNLOADING_FILE")] + downloading_file: Option, +} + +fn check_chunks<'b>(path: &PathBuf, config: &'b FileConfig) -> Vec<&'b FileChunk> { + let mut result = vec![]; + for chunk in config.chunks.iter() { + result.push(chunk); + + print!(">>> Checking chunk {:04} ... ", chunk.chunk_id); + + let hash = match mmap_segment(path, chunk.offset, chunk.length) { + Ok(chunk_data) => hex::encode(blake3::hash(chunk_data.as_bytes()).as_bytes()), + Err(err) => { + println!(" Failed to read: {err:#}"); + continue; + } + }; + + if hash.as_str() != chunk.hash { + println!( + " Hash check failed. Expected {}, actual {}", + chunk.hash, hash + ); + continue; + } + println!(" OK"); + result.pop(); + } + result +} + +fn main() -> anyhow::Result<()> { + let args = Args::parse(); + + let toml_str = fs::read_to_string(&args.plan_file)?; + let config: FileConfig = toml::from_str(&toml_str)?; + + let downloading_file = match args.downloading_file { + Some(path) => path, + None => { + let user_dir = UserDirs::new(); + let downloads_dir = user_dir.as_ref().and_then(UserDirs::document_dir) + .ok_or(anyhow!( + "Failed to determine downloading path. Please explictly designate one with --downloading-file." + ))?; + + downloads_dir.join(&config.file_name) + } + }; + + println!("Downloading file: {}", downloading_file.display()); + + if check_file_exist(&downloading_file)? { + println!( + "{} already exists, start checking.", + downloading_file.display(), + ); + } else { + println!( + "Created {} successfully as an empty file.", + downloading_file.display() + ) + } + + println!( + "{} chunks in total for file {}.", + config.chunks.len(), + downloading_file.display() + ); + + let need_to_download = check_chunks(&downloading_file, &config); + println!("{} chunks needed to be downloaded", need_to_download.len()); + + Ok(()) +} diff --git a/src/file.rs b/src/file.rs index 612dc5a..1918653 100644 --- a/src/file.rs +++ b/src/file.rs @@ -1,18 +1,58 @@ use memmap2::{Mmap, MmapOptions}; use std::fs::{File, OpenOptions}; -use std::io::Result; +use std::io::{Error, ErrorKind, Result}; use std::os::unix::fs::FileExt; use std::path::Path; -pub fn file_len>(path: P) -> Result { - Ok(std::fs::metadata(path)?.len()) +pub fn sanity_check>(path: P) -> Result<(u64, String)> { + let length = std::fs::metadata(&path)?.len(); + let is_file = std::fs::metadata(&path)?.is_file(); + let file_name = is_file + .then_some(path.as_ref().file_name()) + .flatten() + .ok_or(Error::new( + ErrorKind::IsADirectory, + "A normal file is expected.", + ))? + .to_os_string() + .into_string() + .expect("File name is not valid UTF-8."); + + Ok((length, file_name)) +} + +pub fn check_file_exist>(path: P) -> Result { + let path = path.as_ref(); + if path.exists() { + if path.is_file() { + return Ok(true); + } else { + return Err(Error::other("The path to downloading file is not a file!")); + } + } + File::create(path)?; + Ok(false) } pub fn mmap_segment>(path: P, offset: u64, length: usize) -> Result { let file = File::open(path)?; - + let metadata = file.metadata()?; + let file_size = metadata.len(); let page_size = page_size::get() as u64; - assert_eq!(offset % page_size, 0, "Unaligned offset!"); + if offset % page_size != 0 { + return Err(Error::new(ErrorKind::InvalidInput, "Unaligned offset!")); + } + + let end = offset + .checked_add(length as u64) + .ok_or_else(|| Error::new(ErrorKind::InvalidInput, "Offset + length overflow"))?; + + if end > file_size { + return Err(Error::new( + ErrorKind::UnexpectedEof, + format!("Requested mapping [{offset}..{end}) exceeds file size ({file_size})"), + )); + } let mmap = unsafe { MmapOptions::new().offset(offset).len(length).map(&file)? }; @@ -66,7 +106,7 @@ mod tests { write_at(&file_path, offset2, &block2)?; // Logical length of file = 1 GiB - let file_length = file_len(&file_path)?; + let file_length = std::fs::metadata(&file_path)?.len(); assert_eq!(file_length, file_size); println!("Logical file length: {} bytes", file_length); diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..84f333e --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,7 @@ +#![allow(dead_code)] +#![warn(unused_imports)] + +pub mod constants; +pub mod file; +pub mod plan; +pub mod protocol; diff --git a/src/main.rs b/src/main.rs index c1ccf6c..ef760de 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,19 +1,9 @@ -#![allow(dead_code)] -#![warn(unused_imports)] - -mod constants; -mod file; -mod plan; -mod protocol; - use clap::Parser; -use file::file_len; -use std::io; use std::path::PathBuf; use zerocopy::IntoBytes; -use crate::file::mmap_segment; -use crate::plan::{FileChunk, FileConfig, make_plan}; +use usync::file::{mmap_segment, sanity_check}; +use usync::plan::{FileChunk, FileConfig, make_plan}; #[derive(Parser, Debug)] #[command(author, version, about = "A simple CLI program to build transmission plan.", long_about = None)] @@ -23,33 +13,15 @@ struct Args { file: PathBuf, } -fn main() -> io::Result<()> { +fn main() -> anyhow::Result<()> { let args = Args::parse(); - // Check if the path exists and is a file. - if !args.file.exists() { - eprintln!("Error: The specified file does not exist."); - std::process::exit(1); - } - - if !args.file.is_file() { - eprintln!("Error: The specified path is not a file."); - std::process::exit(1); - } - - let file_name = args - .file - .file_name() - .expect("Failed to get file name") - .to_str() - .expect("Non UTF-8 File name provided") - .to_string(); + let (total_length, file_name) = sanity_check(&args.file)?; - let total_length = file_len(&args.file)?; let mut total_hasher = blake3::Hasher::new(); let mut chunks = vec![]; - for (chunk_id, (offset, length)) in make_plan(total_length as usize).enumerate() { + for (chunk_id, (offset, length)) in make_plan(total_length).enumerate() { let chunk = mmap_segment(&args.file, offset, length)?; let chunk_bytes = chunk.as_bytes(); assert_eq!(chunk_bytes.len(), length); diff --git a/src/plan.rs b/src/plan.rs index ab91534..4dfef14 100644 --- a/src/plan.rs +++ b/src/plan.rs @@ -19,17 +19,17 @@ pub struct FileConfig { } //output an iterator over (start_offset, length) -pub fn make_plan(file_length: usize) -> impl Iterator { - let full_chunks = file_length / CHUNK_SIZE; +pub fn make_plan(file_length: u64) -> impl Iterator { + let full_chunks = file_length / CHUNK_SIZE as u64; let full_chunks_used = full_chunks.checked_sub(1).unwrap_or_default(); - let tail_1_offset = full_chunks_used * CHUNK_SIZE; + let tail_1_offset = full_chunks_used * CHUNK_SIZE as u64; let remain_bytes = file_length - tail_1_offset; - let remain_pages = remain_bytes / DEFAULT_PAGE_SIZE; + let remain_pages = remain_bytes / DEFAULT_PAGE_SIZE as u64; - let tail_1_len = if remain_bytes > CHUNK_SIZE { - remain_pages.div_ceil(2) * DEFAULT_PAGE_SIZE + let tail_1_len = if remain_bytes > CHUNK_SIZE as u64 { + remain_pages.div_ceil(2) * DEFAULT_PAGE_SIZE as u64 } else { 0 }; @@ -38,28 +38,34 @@ pub fn make_plan(file_length: usize) -> impl Iterator { let tail_2_len = file_length - tail_2_offset; (0..full_chunks_used) - .map(|x| (x * CHUNK_SIZE, CHUNK_SIZE)) - .chain(std::iter::once((tail_1_offset, tail_1_len)).filter(|(_, len)| *len > 0)) - .chain(std::iter::once((tail_2_offset, tail_2_len))) - .map(|(offset, len)| (offset as u64, len)) + .map(|x| (x * CHUNK_SIZE as u64, CHUNK_SIZE)) + .chain(std::iter::once((tail_1_offset, tail_1_len as usize)).filter(|(_, len)| *len > 0)) + .chain(std::iter::once((tail_2_offset, tail_2_len as usize))) } // .map(|(offset, len)| (offset as usize, len)) #[cfg(test)] mod test { - use crate::plan::make_plan; + use crate::plan::make_plan as make_plan_u64; const M: usize = 1024 * 1024; const K: usize = 1024; + fn make_plan_usize(file_length: usize) -> impl Iterator { + make_plan_u64(file_length as u64) + } + #[test] fn test_make_plan() { // Case 1, file_length <= 32MiB assert_eq!( vec![(0, 17_245_233)], - make_plan(17_245_233).collect::>() + make_plan_usize(17_245_233).collect::>() ); - assert_eq!(vec![(0, 32 * M)], make_plan(32 * M).collect::>()); + assert_eq!( + vec![(0, 32 * M)], + make_plan_usize(32 * M).collect::>() + ); //Case 2, 32MiB < file_length <= 64MiB @@ -68,14 +74,14 @@ mod test { (0, 24 * M + 612 * K), // aligned to 4K (24 * M + 612 * K, 24 * M + 609 * K + 343) ], - make_plan(49 * M + 197 * K + 343) + make_plan_usize(49 * M + 197 * K + 343) .map(|(offset, len)| (offset as usize, len)) .collect::>() ); assert_eq!( vec![(0, 32 * M), (32 * M, 32 * M)], - make_plan(64 * M) + make_plan_usize(64 * M) .map(|(offset, len)| (offset as usize, len)) .collect::>() ); @@ -88,7 +94,7 @@ mod test { (32 * M, 16 * M + 52 * K), // aligned to 4K (48 * M + 52 * K, 16 * M + 48 * K) ], - make_plan(64 * M + 100 * K) + make_plan_usize(64 * M + 100 * K) .map(|(offset, len)| (offset as usize, len)) .collect::>() ); @@ -100,7 +106,7 @@ mod test { (64 * M, 32 * M), // aligned to 4K (96 * M, 32 * M - 1), ], - make_plan(128 * M - 1) + make_plan_usize(128 * M - 1) .map(|(offset, len)| (offset as usize, len)) .collect::>() ); @@ -113,7 +119,7 @@ mod test { (96 * M, 16 * M), (112 * M, 16 * M + 1) ], - make_plan(128 * M + 1) + make_plan_usize(128 * M + 1) .map(|(offset, len)| (offset as usize, len)) .collect::>() ); diff --git a/src/protocol/wire/packets.rs b/src/protocol/wire/packets.rs index 93ceee1..108c69e 100644 --- a/src/protocol/wire/packets.rs +++ b/src/protocol/wire/packets.rs @@ -137,6 +137,12 @@ pub struct TicketPacket { get_chunk: HashMap, } +impl Default for TicketPacket { + fn default() -> Self { + Self::new() + } +} + impl TicketPacket { pub fn new() -> Self { let pubkey = KEY_RING diff --git a/test.txt b/test.txt deleted file mode 100644 index e69de29..0000000 From 5f5b1f866de18d003336380ecebdb4e4de364214 Mon Sep 17 00:00:00 2001 From: Lethe Date: Thu, 7 Aug 2025 10:45:01 +0000 Subject: [PATCH 2/6] feat: wrapper for transmission layer --- Cargo.lock | 221 +++++++++++++++++++++++++++++++++++++++ Cargo.toml | 14 ++- src/bin/client.rs | 55 +++++++--- src/lib.rs | 1 + src/transmission/mock.rs | 99 ++++++++++++++++++ src/transmission/mod.rs | 11 ++ src/transmission/real.rs | 94 +++++++++++++++++ 7 files changed, 474 insertions(+), 21 deletions(-) create mode 100644 src/transmission/mock.rs create mode 100644 src/transmission/mod.rs create mode 100644 src/transmission/real.rs diff --git a/Cargo.lock b/Cargo.lock index edcf6f8..a811e29 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,21 @@ # It is not intended for manual editing. version = 4 +[[package]] +name = "addr2line" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1" +dependencies = [ + "gimli", +] + +[[package]] +name = "adler2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" + [[package]] name = "anstream" version = "0.6.19" @@ -70,12 +85,38 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" +[[package]] +name = "async-trait" +version = "0.1.88" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "autocfg" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "backtrace" +version = "0.3.75" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6806a6321ec58106fea15becdad98371e28d92ccbc7c8f1b3b6dd724fe8f1002" +dependencies = [ + "addr2line", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", + "windows-targets 0.52.6", +] + [[package]] name = "base64" version = "0.21.7" @@ -414,6 +455,12 @@ dependencies = [ "wasi 0.14.2+wasi-0.2.4", ] +[[package]] +name = "gimli" +version = "0.31.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" + [[package]] name = "hashbrown" version = "0.15.4" @@ -432,6 +479,15 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "humansize" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6cb51c9a029ddc91b07a787f1d86b53ccfa49b0e86688c946ebe8d3555685dd7" +dependencies = [ + "libm", +] + [[package]] name = "indexmap" version = "2.10.0" @@ -442,6 +498,17 @@ dependencies = [ "hashbrown", ] +[[package]] +name = "io-uring" +version = "0.7.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d93587f37623a1a17d94ef2bc9ada592f5465fe7732084ab7beefabe5c77c0c4" +dependencies = [ + "bitflags", + "cfg-if", + "libc", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.1" @@ -464,6 +531,12 @@ version = "0.2.174" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1171693293099992e19cddea4e8b849964e9846f4acee11b3948bcc337be8776" +[[package]] +name = "libm" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" + [[package]] name = "libredox" version = "0.1.9" @@ -511,6 +584,27 @@ dependencies = [ "libc", ] +[[package]] +name = "miniz_oxide" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" +dependencies = [ + "adler2", +] + +[[package]] +name = "mio" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78bed444cc8a2160f01cbcf811ef18cac863ad68ae8ca62092e8db51d51c761c" +dependencies = [ + "libc", + "log", + "wasi 0.11.1+wasi-snapshot-preview1", + "windows-sys 0.59.0", +] + [[package]] name = "nanorand" version = "0.7.0" @@ -542,6 +636,15 @@ dependencies = [ "syn", ] +[[package]] +name = "object" +version = "0.36.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62948e14d923ea95ea2c7c86c71013138b66525b86bdc08d2dcc262bdb497b87" +dependencies = [ + "memchr", +] + [[package]] name = "once_cell" version = "1.21.3" @@ -560,6 +663,12 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "owo-colors" +version = "4.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48dd4f4a2c8405440fd0462561f0e5806bd0f77e86f51c761481bdd4018b545e" + [[package]] name = "page_size" version = "0.6.0" @@ -570,6 +679,35 @@ dependencies = [ "winapi", ] +[[package]] +name = "parking_lot" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70d58bf43669b5795d1576d0641cfb6fbb2057bf629506267a92807158584a13" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets 0.52.6", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" + [[package]] name = "pkcs8" version = "0.10.2" @@ -665,6 +803,15 @@ name = "raptorq" version = "2.0.0" source = "git+https://github.com/Lethe10137/raptorq.git?branch=master#44b18bbca9972578555ba82b942aa5163c0ba712" +[[package]] +name = "redox_syscall" +version = "0.5.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5407465600fb0548f1442edf71dd20683c6ed326200ace4b1ef0763521bb3b77" +dependencies = [ + "bitflags", +] + [[package]] name = "redox_users" version = "0.5.2" @@ -676,6 +823,12 @@ dependencies = [ "thiserror", ] +[[package]] +name = "rustc-demangle" +version = "0.1.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56f7d92ca342cea22a06f2121d944b4fd82af56988c270852495420f961d4ace" + [[package]] name = "rustc_version" version = "0.4.1" @@ -762,6 +915,15 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "signal-hook-registry" +version = "1.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2a4719bff48cee6b39d12c020eeb490953ad2443b7055bd0b21fca26bd8c28b" +dependencies = [ + "libc", +] + [[package]] name = "signature" version = "2.2.0" @@ -771,6 +933,28 @@ dependencies = [ "rand_core 0.6.4", ] +[[package]] +name = "slab" +version = "0.4.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04dc19736151f35336d325007ac991178d504a119863a2fcb3758cdb5e52c50d" + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "socket2" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "233504af464074f9d066d7b5416c5f9b894a5862a6506e306f7b816cdd6f1807" +dependencies = [ + "libc", + "windows-sys 0.59.0", +] + [[package]] name = "spin" version = "0.9.8" @@ -852,6 +1036,37 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio" +version = "1.47.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89e49afdadebb872d3145a5638b59eb0691ea23e46ca484037cfab3b76b95038" +dependencies = [ + "backtrace", + "bytes", + "io-uring", + "libc", + "mio", + "parking_lot", + "pin-project-lite", + "signal-hook-registry", + "slab", + "socket2", + "tokio-macros", + "windows-sys 0.59.0", +] + +[[package]] +name = "tokio-macros" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "toml" version = "0.9.4" @@ -925,6 +1140,7 @@ name = "usync" version = "0.1.0" dependencies = [ "anyhow", + "async-trait", "base64", "blake3", "bytes", @@ -934,15 +1150,20 @@ dependencies = [ "ed25519-dalek", "flume", "hex", + "humansize", "log", "memmap2", + "mio", "num_enum", + "owo-colors", "page_size", "rand", "raptorq", "serde", + "socket2", "tap", "tempfile", + "tokio", "toml", "zerocopy", ] diff --git a/Cargo.toml b/Cargo.toml index 9af8841..a30e75b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,10 @@ description = "Using Raptorq code to transmit large file over UDP through unreli license = "MIT" authors = ["Lethe Lee "] + +[dev-dependencies] +tempfile = "3.20.0" + [dependencies] blake3 = "1.8.2" flume = "0.11.1" @@ -27,7 +31,9 @@ clap = { version = "4.5.42", features = ["derive"] } serde = { version = "1.0.219", features = ["derive"] } directories = "6.0.0" anyhow = "1.0.98" - - -[dev-dependencies] -tempfile = "3.20.0" +owo-colors = "4.2.2" +humansize = "2.1.3" +async-trait = "0.1.88" +tokio = { version = "1.47.1", features = ["full"] } +mio = { version = "1.0.4", features = ["net"] } +socket2 = { version = "0.6.0", features = ["all"] } diff --git a/src/bin/client.rs b/src/bin/client.rs index 859101c..a94708b 100644 --- a/src/bin/client.rs +++ b/src/bin/client.rs @@ -1,6 +1,8 @@ use anyhow::anyhow; use clap::Parser; use directories::UserDirs; +use humansize::{BINARY, format_size}; +use owo_colors::OwoColorize; use std::{fs, path::PathBuf}; use zerocopy::IntoBytes; @@ -26,29 +28,58 @@ fn check_chunks<'b>(path: &PathBuf, config: &'b FileConfig) -> Vec<&'b FileChunk for chunk in config.chunks.iter() { result.push(chunk); - print!(">>> Checking chunk {:04} ... ", chunk.chunk_id); + print!( + ">>> Checking chunk {:04}: ...", + chunk.chunk_id.bright_blue() + ); let hash = match mmap_segment(path, chunk.offset, chunk.length) { Ok(chunk_data) => hex::encode(blake3::hash(chunk_data.as_bytes()).as_bytes()), Err(err) => { - println!(" Failed to read: {err:#}"); + println!("\x1b[3D {}: {err:#}", "Failed to read".yellow()); continue; } }; if hash.as_str() != chunk.hash { println!( - " Hash check failed. Expected {}, actual {}", - chunk.hash, hash + "\x1b[3D {}. Expected {}, actual {}", + "Hash check failed".red(), + chunk.hash.yellow(), + hash.yellow() ); continue; } - println!(" OK"); + println!("\x1b[3D {}", "OK".green()); result.pop(); } result } +fn check_file<'a>( + downloading_file: &PathBuf, + config: &'a FileConfig, +) -> anyhow::Result> { + println!( + "{} chunks in total for file {}.", + config.chunks.len(), + downloading_file.display() + ); + + let need_to_download = check_chunks(downloading_file, config); + let download_size: usize = need_to_download.iter().map(|chunk| chunk.length).sum(); + + let print_config = BINARY.decimal_places(3).decimal_zeroes(3); + println!( + "Need to download {} / {} chunks which sized {} / {}.", + need_to_download.len().yellow(), + config.chunks.len().blue(), + format_size(download_size, print_config).yellow(), + format_size(config.total_length, print_config).blue(), + ); + Ok(need_to_download) +} + fn main() -> anyhow::Result<()> { let args = Args::parse(); @@ -71,10 +102,7 @@ fn main() -> anyhow::Result<()> { println!("Downloading file: {}", downloading_file.display()); if check_file_exist(&downloading_file)? { - println!( - "{} already exists, start checking.", - downloading_file.display(), - ); + println!("{} already exists.", downloading_file.display(),); } else { println!( "Created {} successfully as an empty file.", @@ -82,14 +110,7 @@ fn main() -> anyhow::Result<()> { ) } - println!( - "{} chunks in total for file {}.", - config.chunks.len(), - downloading_file.display() - ); - - let need_to_download = check_chunks(&downloading_file, &config); - println!("{} chunks needed to be downloaded", need_to_download.len()); + let _need_to_download = check_file(&downloading_file, &config)?; Ok(()) } diff --git a/src/lib.rs b/src/lib.rs index 84f333e..0dbf486 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,3 +5,4 @@ pub mod constants; pub mod file; pub mod plan; pub mod protocol; +pub mod transmission; diff --git a/src/transmission/mock.rs b/src/transmission/mock.rs new file mode 100644 index 0000000..f570388 --- /dev/null +++ b/src/transmission/mock.rs @@ -0,0 +1,99 @@ +use super::UdpSocketLike; +use async_trait::async_trait; +use bytes::Bytes; +use flume::{Receiver, Sender}; +use std::net::SocketAddr; + +#[derive(Clone)] +pub struct MockSocket { + sender: Sender<(Bytes, SocketAddr)>, + receiver: Receiver<(Bytes, SocketAddr)>, + local_addr: SocketAddr, +} + +impl MockSocket { + pub fn pair(addr1: SocketAddr, addr2: SocketAddr) -> (Self, Self) { + // Duplex:A -> B, B -> A + let (tx1, rx1) = flume::unbounded::<(Bytes, SocketAddr)>(); + let (tx2, rx2) = flume::unbounded::<(Bytes, SocketAddr)>(); + + let socket1 = MockSocket { + sender: tx1, + receiver: rx2, + local_addr: addr1, + }; + let socket2 = MockSocket { + sender: tx2, + receiver: rx1, + local_addr: addr2, + }; + + (socket1, socket2) + } +} + +#[async_trait] +impl UdpSocketLike for MockSocket { + async fn send_to(&self, bufs: &[Bytes], target: SocketAddr) -> std::io::Result { + // Splice the parts together to simulate the payload of a UDP packet. + let total_len: usize = bufs.iter().map(|b| b.len()).sum(); + let combined = if bufs.len() == 1 { + bufs[0].clone() + } else { + let mut v = Vec::with_capacity(total_len); + for b in bufs { + v.extend_from_slice(b); + } + Bytes::from(v) + }; + self.sender + .send_async((combined, target)) + .await + .map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + format!("channel closed: {e}"), + ) + })?; + Ok(total_len) + } + + async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> { + let (data, from) = self.receiver.recv_async().await.map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + format!("channel closed: {e}"), + ) + })?; + + let len = data.len().min(buf.len()); + buf[..len].copy_from_slice(&data[..len]); + + Ok((len, from)) + } +} + +#[tokio::test] +async fn test_mock_socket_pair() -> std::io::Result<()> { + use std::net::SocketAddr; + + let addr1: SocketAddr = "127.0.0.1:10000".parse().unwrap(); + let addr2: SocketAddr = "127.0.0.1:10001".parse().unwrap(); + + let (sock1, sock2) = MockSocket::pair(addr1, addr2); + + // sock1 Send + let send_data = vec![Bytes::from("hello"), Bytes::from(" world")]; + let sent = sock1.send_to(&send_data, addr2).await?; + assert_eq!(sent, 11); + + // sock2 Receive + let mut buf = vec![0u8; 100]; + let (len, from) = sock2.recv_from(&mut buf).await?; + let received_str = std::str::from_utf8(&buf[..len]).unwrap(); + + assert_eq!(received_str, "hello world"); + assert_eq!(from, addr2); + + Ok(()) +} diff --git a/src/transmission/mod.rs b/src/transmission/mod.rs new file mode 100644 index 0000000..3bd5ca7 --- /dev/null +++ b/src/transmission/mod.rs @@ -0,0 +1,11 @@ +pub mod mock; +pub mod real; + +use bytes::Bytes; +use std::net::SocketAddr; + +#[async_trait::async_trait] +pub trait UdpSocketLike: Send + Sync { + async fn send_to(&self, bufs: &[Bytes], target: SocketAddr) -> std::io::Result; + async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)>; +} diff --git a/src/transmission/real.rs b/src/transmission/real.rs new file mode 100644 index 0000000..309d33c --- /dev/null +++ b/src/transmission/real.rs @@ -0,0 +1,94 @@ +use bytes::Bytes; +use socket2::{Domain, Protocol, SockAddr, Socket, Type}; +use std::io::IoSlice; +use std::net::SocketAddr; +use tokio::net::UdpSocket as TokioUdpSocket; + +use super::UdpSocketLike; + +pub struct RealUdpSocket { + innner_raw: Socket, + inner_tokio: TokioUdpSocket, +} + +impl RealUdpSocket { + pub async fn bind(addr: SocketAddr) -> std::io::Result { + let domain = match addr { + SocketAddr::V4(_) => Domain::IPV4, + SocketAddr::V6(_) => Domain::IPV6, + }; + let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?; + + socket.set_reuse_address(true)?; + socket.set_nonblocking(true)?; + + socket.bind(&addr.into())?; + let std_socket = socket.try_clone()?.into(); + let tokio_socket = TokioUdpSocket::from_std(std_socket)?; + + Ok(Self { + inner_tokio: tokio_socket, + innner_raw: socket, + }) + } +} + +#[async_trait::async_trait] +impl UdpSocketLike for RealUdpSocket { + async fn send_to(&self, bufs: &[Bytes], target: SocketAddr) -> std::io::Result { + let io_slice = bufs + .iter() + .map(|slice| IoSlice::new(slice)) + .collect::>(); + + self.innner_raw + .send_to_vectored(io_slice.as_slice(), &SockAddr::from(target)) + } + + async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> { + self.inner_tokio.recv_from(buf).await + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::Bytes; + use tokio::time::{Duration, sleep}; + + #[tokio::test] + async fn test_real_udp_socket_send_recv() -> std::io::Result<()> { + // 创建两个 socket,分别绑定到不同的端口 + let recv_addr: SocketAddr = "127.0.0.1:40001".parse().unwrap(); + let send_addr: SocketAddr = "127.0.0.1:40002".parse().unwrap(); + + let receiver = RealUdpSocket::bind(recv_addr).await?; + let sender = RealUdpSocket::bind(send_addr).await?; + + // 要发送的数据 + let data = vec![Bytes::from_static(b"Hello, "), Bytes::from_static(b"UDP!")]; + + // 启动接收任务 + let recv_task = tokio::spawn(async move { + let mut buf = vec![0u8; 1024]; + let (len, from) = receiver.recv_from(&mut buf).await.unwrap(); + let received = &buf[..len]; + (received.to_vec(), from) + }); + + // 稍微等一下让接收端准备好 + sleep(Duration::from_millis(100)).await; + + // 发送数据 + let bytes_sent = sender.send_to(&data, recv_addr).await?; + assert_eq!(bytes_sent, b"Hello, UDP!".len()); + + // 获取接收结果 + let (received, from) = recv_task.await.unwrap(); + + assert_eq!(received, b"Hello, UDP!"); + assert_eq!(from.ip(), send_addr.ip()); // 可以比较 IP,端口可能是动态的 + + Ok(()) + } +} From 1c5d34ee2757bede109c909151de8bf70109e10a Mon Sep 17 00:00:00 2001 From: Lethe Date: Sat, 9 Aug 2025 10:34:15 +0000 Subject: [PATCH 3/6] feat: customized timer for sender --- Cargo.lock | 118 ++++++++++++++++++++++- Cargo.toml | 8 +- examples/timer_example.rs | 66 +++++++++++++ src/bin/client.rs | 2 +- src/{main.rs => bin/planner.rs} | 4 +- src/engine/mod.rs | 2 + src/engine/receiving.rs | 1 + src/engine/sending.rs | 94 ++++++++++++++++++ src/lib.rs | 4 +- src/protocol/coding/mod.rs | 2 +- src/protocol/wire/frames.rs | 20 ++-- src/protocol/wire/packets.rs | 2 +- src/{ => util}/file.rs | 0 src/util/mod.rs | 4 + src/{ => util}/plan.rs | 2 +- src/util/timer.rs | 165 ++++++++++++++++++++++++++++++++ src/util/timer_logger.rs | 12 +++ 17 files changed, 485 insertions(+), 21 deletions(-) create mode 100644 examples/timer_example.rs rename src/{main.rs => bin/planner.rs} (91%) create mode 100644 src/engine/mod.rs create mode 100644 src/engine/receiving.rs create mode 100644 src/engine/sending.rs rename src/{ => util}/file.rs (100%) create mode 100644 src/util/mod.rs rename src/{ => util}/plan.rs (98%) create mode 100644 src/util/timer.rs create mode 100644 src/util/timer_logger.rs diff --git a/Cargo.lock b/Cargo.lock index a811e29..bdce584 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -85,6 +85,17 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" +[[package]] +name = "async-scoped" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4042078ea593edffc452eef14e99fdb2b120caa4ad9618bcdeabc4a023b98740" +dependencies = [ + "futures", + "pin-project", + "tokio", +] + [[package]] name = "async-trait" version = "0.1.88" @@ -408,18 +419,95 @@ dependencies = [ "spin", ] +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" +dependencies = [ + "futures-core", + "futures-sink", +] + [[package]] name = "futures-core" version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" +[[package]] +name = "futures-executor" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" + +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" +[[package]] +name = "futures-task" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" + +[[package]] +name = "futures-util" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -600,7 +688,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78bed444cc8a2160f01cbcf811ef18cac863ad68ae8ca62092e8db51d51c761c" dependencies = [ "libc", - "log", "wasi 0.11.1+wasi-snapshot-preview1", "windows-sys 0.59.0", ] @@ -702,12 +789,38 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "pin-project" +version = "1.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677f1add503faace112b9f1373e43e9e054bfdd22ff1a63c1bc485eaec6a6a8a" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "pin-project-lite" version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + [[package]] name = "pkcs8" version = "0.10.2" @@ -1140,6 +1253,7 @@ name = "usync" version = "0.1.0" dependencies = [ "anyhow", + "async-scoped", "async-trait", "base64", "blake3", @@ -1153,8 +1267,8 @@ dependencies = [ "humansize", "log", "memmap2", - "mio", "num_enum", + "once_cell", "owo-colors", "page_size", "rand", diff --git a/Cargo.toml b/Cargo.toml index a30e75b..52d8c71 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,5 +35,11 @@ owo-colors = "4.2.2" humansize = "2.1.3" async-trait = "0.1.88" tokio = { version = "1.47.1", features = ["full"] } -mio = { version = "1.0.4", features = ["net"] } + socket2 = { version = "0.6.0", features = ["all"] } +async-scoped = { version = "0.9.0", features = ["use-tokio"] } +once_cell = "1.21.3" + + +[features] +slow-tests = [] diff --git a/examples/timer_example.rs b/examples/timer_example.rs new file mode 100644 index 0000000..44a516d --- /dev/null +++ b/examples/timer_example.rs @@ -0,0 +1,66 @@ +use tokio::time::{Duration, Instant}; +use usync::engine::sending::*; +use usync::protocol::coding::raptorq_code::RaptorqSender; +use usync::protocol::wire::Frame; +use usync::protocol::wire::frames::DataFrame; + +async fn stub_receiver( + send_order: flume::Sender, + receive_packet: flume::Receiver>, +) { + for i in 0..64 { + let data = receive_packet.recv_async().await.unwrap(); + let frame_id: u32 = data.header().frame_offset.into(); + let chunk_id: u32 = data.header().chunk_id.into(); + + println!("received: {} - {}", chunk_id, frame_id); + + if i % 16 == 8 { + send_order + .send(SendingOrder { + chunk_id, + sending_interval: Duration::from_millis(1).into(), + time_stamp: Instant::now(), + offset_next: frame_id + 1, + offset_no_more_than: frame_id + 100, + close_now: false, + }) + .unwrap(); + } + } + tokio::time::sleep(Duration::from_secs(10)).await; + println!("receiver exit"); +} + +#[tokio::main] +async fn main() { + let (send_order, receive_order) = flume::bounded::(16); + let (send_packet, receive_packet) = flume::unbounded::>(); + + let start_order = SendingOrder { + chunk_id: 0x19260817, + sending_interval: Duration::from_millis(500).into(), + time_stamp: Instant::now(), + offset_next: 0, + offset_no_more_than: 150, + close_now: false, + }; + let mut sender = SendingChunk::::new( + &[0; 65536], + start_order, + receive_order, + send_packet, + ); + + let sender_future = sender.run(); + let receiver_future = stub_receiver(send_order, receive_packet); + + use async_scoped::{Scope, spawner::use_tokio::Tokio}; + + unsafe { + let mut scope = Scope::create(Tokio); + scope.spawn(sender_future); + scope.spawn(receiver_future); + } + println!("Finish"); +} diff --git a/src/bin/client.rs b/src/bin/client.rs index a94708b..7112741 100644 --- a/src/bin/client.rs +++ b/src/bin/client.rs @@ -6,7 +6,7 @@ use owo_colors::OwoColorize; use std::{fs, path::PathBuf}; use zerocopy::IntoBytes; -use usync::{ +use usync::util::{ file::{check_file_exist, mmap_segment}, plan::{FileChunk, FileConfig}, }; diff --git a/src/main.rs b/src/bin/planner.rs similarity index 91% rename from src/main.rs rename to src/bin/planner.rs index ef760de..e8e7590 100644 --- a/src/main.rs +++ b/src/bin/planner.rs @@ -2,8 +2,8 @@ use clap::Parser; use std::path::PathBuf; use zerocopy::IntoBytes; -use usync::file::{mmap_segment, sanity_check}; -use usync::plan::{FileChunk, FileConfig, make_plan}; +use usync::util::file::{mmap_segment, sanity_check}; +use usync::util::plan::{FileChunk, FileConfig, make_plan}; #[derive(Parser, Debug)] #[command(author, version, about = "A simple CLI program to build transmission plan.", long_about = None)] diff --git a/src/engine/mod.rs b/src/engine/mod.rs new file mode 100644 index 0000000..120ffde --- /dev/null +++ b/src/engine/mod.rs @@ -0,0 +1,2 @@ +pub mod receiving; +pub mod sending; diff --git a/src/engine/receiving.rs b/src/engine/receiving.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/engine/receiving.rs @@ -0,0 +1 @@ + diff --git a/src/engine/sending.rs b/src/engine/sending.rs new file mode 100644 index 0000000..6d70102 --- /dev/null +++ b/src/engine/sending.rs @@ -0,0 +1,94 @@ +use crate::protocol::{coding::FrameSender, wire::frames::DataFrame}; +use crate::util::timer::{SenderTimer, SenderTimerOutput}; +use bytes::Bytes; +use flume::{Receiver, Sender}; +use tokio::time::{Duration, Instant}; + +use crate::util::timer_logger::print_relative_time; + +pub struct SendingOrder { + pub chunk_id: u32, + pub sending_interval: Option, + pub time_stamp: Instant, + pub offset_next: u32, + pub offset_no_more_than: u32, + pub close_now: bool, +} + +pub struct SendingChunk, const INFO_LENGTH: usize> { + chunk_id: u32, + encoder: FS, + transmission_info: [u8; INFO_LENGTH], + order_receiver: Receiver, + data_sender: Sender>, + max_frame_offset: u32, + max_sent_offset: u32, + timer: SenderTimer, +} + +impl, const INFO_LENGTH: usize> SendingChunk { + pub fn new( + chunk_data: &[u8], + start_order: SendingOrder, + order_receiver: Receiver, + data_sender: Sender>, + ) -> Self { + print_relative_time("Start init sender", Instant::now()); + let encoder = FS::init(chunk_data, start_order.offset_next); + let transmission_info = encoder.get_trasmission_info(); + let sender = Self { + chunk_id: start_order.chunk_id, + encoder, + transmission_info, + order_receiver, + data_sender, + timer: SenderTimer::new( + start_order + .sending_interval + .unwrap_or(Duration::from_millis(20)), + ), + max_sent_offset: 0, + max_frame_offset: start_order.offset_next + start_order.offset_no_more_than, + }; + print_relative_time("Finish init sender", Instant::now()); + sender + } + + pub async fn run(&mut self) { + loop { + tokio::select! { + Ok(order) = self.order_receiver.recv_async() => { + let now = Instant::now(); + print_relative_time("ORDER", now); + // self.timer.set_rate(, new_interval); + self.timer.set_rate(now, order.sending_interval); + if order.close_now { + print_relative_time("FINISH", now); + break; + } + }, + + output = &mut self.timer => { + match output { + SenderTimerOutput::Send(x) => { + for _ in 0..x{ + if self.max_sent_offset >= self.max_frame_offset {break;} + let (frame_offset, frame) = self.encoder.next_frame(); + if self.data_sender.send_async(DataFrame::new(self.chunk_id, frame_offset, self.transmission_info, Bytes::from(frame))).await.is_err(){ + print_relative_time("Can not send", Instant::now()); + break; + } + print_relative_time(format!("Send {frame_offset}").as_str(), Instant::now()); + self.max_sent_offset = frame_offset; + } + }, + SenderTimerOutput::Close => { + print_relative_time("CLOSE", Instant::now()); + break; + } + }; + } + } + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 0dbf486..4dbffd7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,7 +2,7 @@ #![warn(unused_imports)] pub mod constants; -pub mod file; -pub mod plan; +pub mod engine; pub mod protocol; pub mod transmission; +pub mod util; diff --git a/src/protocol/coding/mod.rs b/src/protocol/coding/mod.rs index b6c9d5d..ca7e244 100644 --- a/src/protocol/coding/mod.rs +++ b/src/protocol/coding/mod.rs @@ -10,4 +10,4 @@ pub trait FrameReceiver: Sized { fn expected_frame_id(&self) -> u32; } -mod raptorq_code; +pub mod raptorq_code; diff --git a/src/protocol/wire/frames.rs b/src/protocol/wire/frames.rs index 98d01c7..3ce9cbc 100644 --- a/src/protocol/wire/frames.rs +++ b/src/protocol/wire/frames.rs @@ -18,7 +18,7 @@ pub enum FrameType { impl FrameType { pub(super) fn try_parse<'a>(&self, data: &'a [u8]) -> Option> { match &self { - FrameType::Data => DataFrame::try_parse(data), + FrameType::Data => DataFrame::::try_parse(data), FrameType::GetChunk => GetChunkFrame::try_parse(data), FrameType::RateLimit => RateLimitFrame::try_parse(data), } @@ -34,20 +34,20 @@ pub enum ParsedFrameVariant<'a> { #[repr(C)] #[derive(IntoBytes, FromBytes, Unaligned, Immutable, KnownLayout)] -pub struct DataFrameHeader { +pub struct DataFrameHeader { pub chunk_id: U32, pub frame_offset: U32, - pub transmission_info: [u8; TRANSMISSION_INFO_LENGTH], + pub transmission_info: [u8; INFO_LENGTH], } -impl SpecificFrameHeader for DataFrameHeader { +impl SpecificFrameHeader for DataFrameHeader { fn get_frame_type(&self) -> FrameType { FrameType::Data } } -pub struct DataFrame { - header: DataFrameHeader, +pub struct DataFrame { + header: DataFrameHeader, data: Bytes, } #[derive(Debug)] @@ -58,11 +58,11 @@ pub struct ParsedDataFrame<'a> { pub data: &'a [u8], } -impl DataFrame { +impl DataFrame { pub fn new( chunk_id: u32, frame_offset: u32, - transmission_info: [u8; TRANSMISSION_INFO_LENGTH], + transmission_info: [u8; INFO_LENGTH], data: Bytes, ) -> Self { Self { @@ -76,8 +76,8 @@ impl DataFrame { } } -impl Frame for DataFrame { - type Header = DataFrameHeader; +impl Frame for DataFrame { + type Header = DataFrameHeader; fn header(&self) -> &Self::Header { &self.header diff --git a/src/protocol/wire/packets.rs b/src/protocol/wire/packets.rs index 108c69e..a35e954 100644 --- a/src/protocol/wire/packets.rs +++ b/src/protocol/wire/packets.rs @@ -84,7 +84,7 @@ impl SpecificPacketHeader for DataPacketHeader { pub struct DataPacket { header: DataPacketHeader, - data: DataFrame, + data: DataFrame, // DataFrame<12> for raptorq } impl DataPacket { diff --git a/src/file.rs b/src/util/file.rs similarity index 100% rename from src/file.rs rename to src/util/file.rs diff --git a/src/util/mod.rs b/src/util/mod.rs new file mode 100644 index 0000000..dee3277 --- /dev/null +++ b/src/util/mod.rs @@ -0,0 +1,4 @@ +pub mod file; +pub mod plan; +pub mod timer; +pub mod timer_logger; diff --git a/src/plan.rs b/src/util/plan.rs similarity index 98% rename from src/plan.rs rename to src/util/plan.rs index 4dfef14..21c4aec 100644 --- a/src/plan.rs +++ b/src/util/plan.rs @@ -46,7 +46,7 @@ pub fn make_plan(file_length: u64) -> impl Iterator { // .map(|(offset, len)| (offset as usize, len)) #[cfg(test)] mod test { - use crate::plan::make_plan as make_plan_u64; + use crate::util::plan::make_plan as make_plan_u64; const M: usize = 1024 * 1024; const K: usize = 1024; diff --git a/src/util/timer.rs b/src/util/timer.rs new file mode 100644 index 0000000..4291b22 --- /dev/null +++ b/src/util/timer.rs @@ -0,0 +1,165 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll, Waker}, + time::Duration, +}; +use tokio::time::Instant; + +pub enum SenderTimerOutput { + Send(usize), + Close, +} + +pub struct SenderTimer { + interval: Duration, + sleep_after: Instant, + exit_after: Instant, + last_send: Instant, + waker: Option, +} + +const STOP_AFTER: Duration = Duration::from_secs(10); +const EXIT_AFTER: Duration = Duration::from_secs(20); +const MAX_BURST: usize = 8; + +impl SenderTimer { + pub fn new(interval: Duration) -> Self { + let now = Instant::now(); + Self { + interval, + sleep_after: now + STOP_AFTER, + exit_after: now + EXIT_AFTER, + last_send: now, + waker: None, + } + } + + pub fn set_rate(&mut self, timestamp: Instant, new_interval: Option) { + if let Some(new_interval) = new_interval { + self.interval = new_interval; + self.last_send = self.last_send.max(timestamp - new_interval); + } + + self.sleep_after = self.sleep_after.max(timestamp + STOP_AFTER); + self.exit_after = self.exit_after.max(timestamp + EXIT_AFTER); + + if let Some(waker) = self.waker.take() { + waker.wake(); + } + } +} + +impl Future for SenderTimer { + type Output = SenderTimerOutput; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.waker = Some(cx.waker().clone()); + let now = Instant::now(); + + if now >= self.exit_after { + return Poll::Ready(SenderTimerOutput::Close); + } + + if now >= self.sleep_after { + let waker_clone = self.waker.as_ref().unwrap().clone(); + let wake_time_clone = self.exit_after; + tokio::spawn(async move { + tokio::time::sleep_until(wake_time_clone).await; + waker_clone.wake(); + }); + return Poll::Pending; + } + + let min_sendable_time = self.last_send + self.interval; + + if now >= min_sendable_time { + let can_send_num = (now.duration_since(self.last_send)).div_duration_f64(self.interval); + if can_send_num > 1.0 { + let can_send_num = can_send_num.floor(); + let advance = self.interval.mul_f64(can_send_num); + self.last_send += advance; + return Poll::Ready(SenderTimerOutput::Send( + (can_send_num as usize).min(MAX_BURST), + )); + } + } + + let waker_clone = self.waker.as_ref().unwrap().clone(); + tokio::spawn(async move { + tokio::time::sleep_until(min_sendable_time).await; + waker_clone.wake(); + }); + Poll::Pending + } +} + +#[cfg(feature = "slow-tests")] +#[cfg(test)] +mod test { + + use super::super::timer_logger::{PROGRAM_START_TIME, print_relative_time}; + use super::*; + use tokio::select; + + #[tokio::test] + async fn clock() { + println!("start"); + let (tx, rx) = flume::bounded::(16); + + let controller = tokio::spawn(async move { + tokio::time::sleep_until(*PROGRAM_START_TIME + Duration::from_secs(3)).await; + tx.send(Duration::from_millis(500)).unwrap(); + + tokio::time::sleep_until(*PROGRAM_START_TIME + Duration::from_secs(20)).await; + tx.send(Duration::from_millis(1500)).unwrap(); + }); + + let sender = tokio::spawn(async move { + let mut timer = SenderTimer::new(Duration::from_millis(900)); + let mut cnt = 0; + let mut sent_times = vec![]; + + loop { + select! { + Ok(new_interval) = rx.recv_async() => { + timer.set_rate(Instant::now(), new_interval.into()); + } + output = &mut timer => { + match output { + SenderTimerOutput::Send(x) => { + for _ in 0..x { + sent_times.push(print_relative_time(format!("send {}", cnt).as_str(), Instant::now()) / 100.0); + cnt += 1; + } + }, + SenderTimerOutput::Close => { + sent_times.push(print_relative_time("CLOSE", Instant::now()) / 100.0); + break; + } + }; + } + } + } + sent_times + }); + + controller.await.unwrap(); + let sent_time = sender + .await + .unwrap() + .into_iter() + .map(|sent_time| ((sent_time).round() as usize) * 100); + + let expected = (0..=2) + .map(|x| x * 900 + 900) + .chain((3..=22).map(|x| x * 500 + 1700)) + .chain((23..=29).map(|x| x * 1500 - 14500)) + .chain(std::iter::once(40000)); + + for (i, (expected, actual)) in expected.zip(sent_time).enumerate() { + println!("{} {}ms {}ms", i, expected, actual,); + assert_eq!(expected, actual); + } + } +} diff --git a/src/util/timer_logger.rs b/src/util/timer_logger.rs new file mode 100644 index 0000000..a643559 --- /dev/null +++ b/src/util/timer_logger.rs @@ -0,0 +1,12 @@ +use once_cell::sync::Lazy; +use owo_colors::*; +use tokio::time::Instant; + +pub static PROGRAM_START_TIME: Lazy = Lazy::new(Instant::now); + +pub fn print_relative_time(label: &str, instant: Instant) -> f64 { + let elapsed = instant.duration_since(*PROGRAM_START_TIME); + let time_ms = elapsed.as_secs_f64() * 1000.0; + println!("[{:.6}ms] {}", time_ms.red(), label.blue()); + time_ms +} From a117b87808e60bf76785c7a811dc1d536b99845b Mon Sep 17 00:00:00 2001 From: Lethe Date: Tue, 12 Aug 2025 19:33:48 +0000 Subject: [PATCH 4/6] feat: runnable local example --- Cargo.lock | 73 ++++++++++++- Cargo.toml | 2 + examples/local_transfer.rs | 100 +++++++++++++++++ examples/timer_example.rs | 66 ----------- src/bin/client.rs | 5 +- src/engine/bus_flume.rs | 119 ++++++++++++++++++++ src/engine/bus_tokio.rs | 104 ++++++++++++++++++ src/engine/decoding.rs | 82 ++++++++++++++ src/engine/encoding.rs | 129 ++++++++++++++++++++++ src/engine/mod.rs | 121 +++++++++++++++++++++ src/engine/receiving.rs | 117 ++++++++++++++++++++ src/engine/sending.rs | 163 +++++++++++++++------------- src/protocol/coding/mod.rs | 2 +- src/protocol/coding/raptorq_code.rs | 17 +-- src/protocol/key_ring.rs | 9 ++ src/protocol/mod.rs | 2 + src/protocol/wire/encoding.rs | 70 ++++++------ src/protocol/wire/frames.rs | 77 ++++++++++--- src/protocol/wire/mod.rs | 4 +- src/protocol/wire/packets.rs | 54 +++++---- src/protocol/wire/verify.rs | 4 +- src/transmission/real.rs | 8 +- src/util/file.rs | 18 +++ src/util/mod.rs | 47 ++++++++ src/util/timer.rs | 4 +- src/util/timer_logger.rs | 9 +- 26 files changed, 1163 insertions(+), 243 deletions(-) create mode 100644 examples/local_transfer.rs delete mode 100644 examples/timer_example.rs create mode 100644 src/engine/bus_flume.rs create mode 100644 src/engine/bus_tokio.rs create mode 100644 src/engine/decoding.rs create mode 100644 src/engine/encoding.rs diff --git a/Cargo.lock b/Cargo.lock index bdce584..0bf67fc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -253,6 +253,15 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" +[[package]] +name = "convert_case" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb402b8d4c85569410425650ce3eddc7d698ed96d39a73f941b08fb63082f1e7" +dependencies = [ + "unicode-segmentation", +] + [[package]] name = "cpufeatures" version = "0.2.17" @@ -277,6 +286,12 @@ version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + [[package]] name = "crypto-common" version = "0.1.6" @@ -314,6 +329,20 @@ dependencies = [ "syn", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "der" version = "0.7.10" @@ -324,6 +353,28 @@ dependencies = [ "zeroize", ] +[[package]] +name = "derive_more" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "093242cf7570c207c83073cf82f79706fe7b8317e98620a47d5be7c3d8497678" +dependencies = [ + "derive_more-impl", +] + +[[package]] +name = "derive_more-impl" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bda628edc44c4bb645fbe0f758797143e4e07926f7ebf4e9bdfbd3d2ce621df3" +dependencies = [ + "convert_case", + "proc-macro2", + "quote", + "syn", + "unicode-xid", +] + [[package]] name = "digest" version = "0.10.7" @@ -549,6 +600,12 @@ version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "hashbrown" version = "0.15.4" @@ -583,7 +640,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fe4cd85333e22411419a0bcae1297d25e58c9443848b11dc6a86fefe8c78a661" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.15.4", ] [[package]] @@ -1248,6 +1305,18 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" +[[package]] +name = "unicode-segmentation" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" + +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + [[package]] name = "usync" version = "0.1.0" @@ -1260,6 +1329,8 @@ dependencies = [ "bytes", "clap", "crc", + "dashmap", + "derive_more", "directories", "ed25519-dalek", "flume", diff --git a/Cargo.toml b/Cargo.toml index 52d8c71..12d22c7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,8 @@ tokio = { version = "1.47.1", features = ["full"] } socket2 = { version = "0.6.0", features = ["all"] } async-scoped = { version = "0.9.0", features = ["use-tokio"] } once_cell = "1.21.3" +dashmap = "6.1.0" +derive_more = { version = "2.0.1", features = ["full"] } [features] diff --git a/examples/local_transfer.rs b/examples/local_transfer.rs new file mode 100644 index 0000000..5b4df77 --- /dev/null +++ b/examples/local_transfer.rs @@ -0,0 +1,100 @@ +use blake3; +use std::collections::HashMap; +use std::ffi::OsString; +use std::io::Read; +use std::net::SocketAddr; +use std::sync::{Arc, atomic::AtomicU32}; +use std::time::Duration; +use usync::util::file::ChunkIndex; + +use tokio::sync::Semaphore; + +use usync::constants::TRANSMISSION_INFO_LENGTH; +use usync::engine::{Bus, BusAddress, BusMessage, decoding, receiving, sending}; +use usync::protocol::coding::raptorq_code::{RaptorqReceiver, RaptorqSender}; +use usync::protocol::mock_init; +use usync::transmission::mock::MockSocket; +use usync::util::{ + file::{CHUNK_INDEX, write_at}, + generate_random, +}; + +const CONCURRENCY: usize = 10; +const CHUNKS: u32 = 20; +const CHUNK_SIZE: usize = 1048576; + +#[tokio::main] +async fn main() { + debug_assert!( + false, + "Run in release mode instead for raptorq is too slow in debug mode." + ); + use tempfile::NamedTempFile; + let mut file = NamedTempFile::new().unwrap(); + + let data = generate_random(CHUNK_SIZE); + let expected_hash = blake3::hash(&data); + dbg!(&expected_hash); + + let path = OsString::from(file.path().as_os_str()); + + write_at(&path, 0, &data).unwrap(); + + let mut check_read = vec![]; + let length = file.read_to_end(&mut check_read).unwrap(); + assert_eq!(length, CHUNK_SIZE); + assert_eq!(&expected_hash, &blake3::hash(&check_read)); + + CHUNK_INDEX + .set(ChunkIndex { + files: HashMap::from([(0, path.clone())]), + chunks: HashMap::from_iter( + (0..CHUNKS).map(|chunk_id| (chunk_id, (0usize, 0u64, CHUNK_SIZE))), + ), + }) + .map_err(|_| "Failed to init OnceLock") + .unwrap(); + + let addr1: SocketAddr = "127.0.0.1:10000".parse().unwrap(); + let addr2: SocketAddr = "127.0.0.1:10001".parse().unwrap(); + let (sock1, sock2) = MockSocket::pair(addr1, addr2); + + mock_init(); + + let bus: Arc>> = Arc::new(Bus::default()); + let sender = sending::SendingSocket::new(sock1, bus.clone().register(BusAddress::SenderSocket)); + tokio::spawn(sender.run::()); + let receiver = + receiving::ReceivingSocket::new(sock2, bus.clone().register(BusAddress::ReceiverSocket)); + tokio::spawn(receiver.run(addr1)); + + let sem = Arc::new(Semaphore::new(CONCURRENCY)); + let finish = Arc::new(AtomicU32::new(CHUNKS)); + + for chunk_id in 0..CHUNKS { + let sem = sem.clone(); + let bus = bus.clone(); + let finish = finish.clone(); + + let waiting = |finish: Arc| async move { + let permit = sem.acquire().await.unwrap(); + let handler = + decoding::spawn::(chunk_id, bus.clone()); + let result = handler.await.unwrap().unwrap(); + drop(permit); + println!( + " {} Finished, length {}, hash {:?}", + chunk_id, + result.len(), + blake3::hash(&result) + ); + finish.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + }; + tokio::spawn(waiting(finish)); + } + + while finish.load(std::sync::atomic::Ordering::Relaxed) > 0 { + tokio::time::sleep(Duration::from_secs(5)).await; + bus.debug(); + } +} diff --git a/examples/timer_example.rs b/examples/timer_example.rs deleted file mode 100644 index 44a516d..0000000 --- a/examples/timer_example.rs +++ /dev/null @@ -1,66 +0,0 @@ -use tokio::time::{Duration, Instant}; -use usync::engine::sending::*; -use usync::protocol::coding::raptorq_code::RaptorqSender; -use usync::protocol::wire::Frame; -use usync::protocol::wire::frames::DataFrame; - -async fn stub_receiver( - send_order: flume::Sender, - receive_packet: flume::Receiver>, -) { - for i in 0..64 { - let data = receive_packet.recv_async().await.unwrap(); - let frame_id: u32 = data.header().frame_offset.into(); - let chunk_id: u32 = data.header().chunk_id.into(); - - println!("received: {} - {}", chunk_id, frame_id); - - if i % 16 == 8 { - send_order - .send(SendingOrder { - chunk_id, - sending_interval: Duration::from_millis(1).into(), - time_stamp: Instant::now(), - offset_next: frame_id + 1, - offset_no_more_than: frame_id + 100, - close_now: false, - }) - .unwrap(); - } - } - tokio::time::sleep(Duration::from_secs(10)).await; - println!("receiver exit"); -} - -#[tokio::main] -async fn main() { - let (send_order, receive_order) = flume::bounded::(16); - let (send_packet, receive_packet) = flume::unbounded::>(); - - let start_order = SendingOrder { - chunk_id: 0x19260817, - sending_interval: Duration::from_millis(500).into(), - time_stamp: Instant::now(), - offset_next: 0, - offset_no_more_than: 150, - close_now: false, - }; - let mut sender = SendingChunk::::new( - &[0; 65536], - start_order, - receive_order, - send_packet, - ); - - let sender_future = sender.run(); - let receiver_future = stub_receiver(send_order, receive_packet); - - use async_scoped::{Scope, spawner::use_tokio::Tokio}; - - unsafe { - let mut scope = Scope::create(Tokio); - scope.spawn(sender_future); - scope.spawn(receiver_future); - } - println!("Finish"); -} diff --git a/src/bin/client.rs b/src/bin/client.rs index 7112741..9dfa358 100644 --- a/src/bin/client.rs +++ b/src/bin/client.rs @@ -4,6 +4,7 @@ use directories::UserDirs; use humansize::{BINARY, format_size}; use owo_colors::OwoColorize; use std::{fs, path::PathBuf}; +// use tokio::sync::Semaphore; use zerocopy::IntoBytes; use usync::util::{ @@ -79,8 +80,8 @@ fn check_file<'a>( ); Ok(need_to_download) } - -fn main() -> anyhow::Result<()> { +#[tokio::main] +async fn main() -> anyhow::Result<()> { let args = Args::parse(); let toml_str = fs::read_to_string(&args.plan_file)?; diff --git a/src/engine/bus_flume.rs b/src/engine/bus_flume.rs new file mode 100644 index 0000000..a52d3f5 --- /dev/null +++ b/src/engine/bus_flume.rs @@ -0,0 +1,119 @@ +use dashmap::DashMap; +use owo_colors::OwoColorize; +use std::sync::Arc; +use std::{fmt::Debug, hash::Hash}; +// use tokio::sync::mpsc::{self, Receiver, Sender}; +use flume::{Receiver, Sender}; + +pub struct Bus +where + ADDRESS: Eq + Hash + Clone + Debug, + MESSAGE: Debug, +{ + peers: DashMap>, +} + +impl Default for Bus +where + ADDRESS: Eq + Hash + Clone + Debug, + MESSAGE: Debug, +{ + fn default() -> Self { + Self { + peers: DashMap::new(), + } + } +} +impl Bus +where + ADDRESS: Eq + Hash + Clone + Debug, + MESSAGE: Debug, +{ + pub fn debug(&self) { + eprintln!("BUS devices: {}", self.peers.len()); + + for entry in self.peers.iter() { + let address = entry.key(); + let sender = entry.value(); + let len = sender.len(); + eprintln!("Address: {:?}, unread count: {}", address, len); + } + } + + pub fn register(self: Arc, id: ADDRESS) -> BusInterface { + eprintln!("BUS: Register {:?}", &id.green()); + // let (tx, rx) = flume::bounded(100); + let (tx, rx) = flume::unbounded(); + self.peers.insert(id.clone(), tx); + BusInterface { + address: id, + bus: Arc::clone(&self), + receiver: rx, + } + } + + // Returns Err iff trying to send to an address that never existed or has been dropped. + async fn send(&self, to: ADDRESS, msg: MESSAGE) -> Result<(), MESSAGE> { + if let Some(sender) = self.peers.get(&to) { + sender.send_async(msg).await.map_err(|e| e.0)?; + Ok(()) + } else { + Err(msg) + } + } + + fn unregister(&self, id: ADDRESS) { + eprintln!("BUS: Unregister {:?}", &id.red()); + self.peers.remove(&id); + } +} + +pub struct BusInterface +where + ADDRESS: Eq + Hash + Clone + Debug, + MESSAGE: Debug, +{ + address: ADDRESS, + bus: Arc>, + receiver: Receiver, +} + +impl BusInterface +where + ADDRESS: Eq + Hash + Clone + Debug, + MESSAGE: Debug, +{ + pub async fn send(&self, to: ADDRESS, message: M) -> Result<(), Option> + where + M: Into + TryFrom, + { + let message: MESSAGE = message.into(); + // eprintln!("To {:?}: {:?}", &to.magenta(), &message.blue()); + self.bus + .send(to, message) + .await + .map_err(|err| M::try_from(err).ok()) + } + + pub async fn recv>(&mut self) -> Option { + self.receiver + .recv_async() + .await + .ok() + .and_then(|message| R::try_from(message).ok()) + } + + pub fn get_bus(&self) -> Arc> { + self.bus.clone() + } +} + +impl Drop for BusInterface +where + ADDRESS: Eq + Hash + Clone + Debug, + MESSAGE: Debug, +{ + fn drop(&mut self) { + self.bus.unregister(self.address.clone()); + } +} diff --git a/src/engine/bus_tokio.rs b/src/engine/bus_tokio.rs new file mode 100644 index 0000000..032342a --- /dev/null +++ b/src/engine/bus_tokio.rs @@ -0,0 +1,104 @@ +use dashmap::DashMap; +use owo_colors::OwoColorize; +use std::sync::Arc; +use std::{fmt::Debug, hash::Hash}; +use tokio::sync::mpsc::{self, Receiver, Sender}; + +pub struct Bus +where + ADDRESS: Eq + Hash + Clone + Debug, + MESSAGE: Debug, +{ + peers: DashMap>, +} + +impl Bus +where + ADDRESS: Eq + Hash + Clone + Debug, + MESSAGE: Debug, +{ + pub fn debug(&self) { + eprintln!("BUS devices: {}", self.peers.len()); + todo!(); + } + + pub fn new() -> Self { + Self { + peers: DashMap::new(), + } + } + pub fn register(self: Arc, id: ADDRESS) -> BusInterface { + eprintln!("BUS: Register {:?}", &id.green()); + let (tx, rx) = mpsc::channel(100); + self.peers.insert(id.clone(), tx); + BusInterface { + address: id, + bus: Arc::clone(&self), + receiver: rx, + } + } + + // Returns Err iff trying to send to an address that never existed or has been dropped. + async fn send(&self, to: ADDRESS, msg: MESSAGE) -> Result<(), MESSAGE> { + if let Some(sender) = self.peers.get(&to) { + sender.send(msg).await.map_err(|e| e.0)?; + Ok(()) + } else { + Err(msg) + } + } + + fn unregister(&self, id: ADDRESS) { + eprintln!("BUS: Unregister {:?}", &id.red()); + self.peers.remove(&id); + } +} + +pub struct BusInterface +where + ADDRESS: Eq + Hash + Clone + Debug, + MESSAGE: Debug, +{ + address: ADDRESS, + bus: Arc>, + receiver: Receiver, +} + +impl BusInterface +where + ADDRESS: Eq + Hash + Clone + Debug, + MESSAGE: Debug, +{ + pub async fn send(&self, to: ADDRESS, message: M) -> Result<(), Option> + where + M: Into + TryFrom, + { + let message: MESSAGE = message.into(); + // eprintln!("To {:?}: {:?}", &to.magenta(), &message.blue()); + self.bus + .send(to, message) + .await + .map_err(|err| M::try_from(err).ok()) + } + + pub async fn recv>(&mut self) -> Option { + self.receiver + .recv() + .await + .and_then(|message| R::try_from(message).ok()) + } + + pub fn get_bus(&self) -> Arc> { + self.bus.clone() + } +} + +impl Drop for BusInterface +where + ADDRESS: Eq + Hash + Clone + Debug, + MESSAGE: Debug, +{ + fn drop(&mut self) { + self.bus.unregister(self.address.clone()); + } +} diff --git a/src/engine/decoding.rs b/src/engine/decoding.rs new file mode 100644 index 0000000..463f338 --- /dev/null +++ b/src/engine/decoding.rs @@ -0,0 +1,82 @@ +use super::{Bus, BusAddress, BusInterface, BusMessage, ReceivingChunkReport}; +use crate::protocol::{coding::FrameReceiver, wire::frames::ParsedDataFrame}; +use std::sync::Arc; +use tokio::task::JoinHandle; + +pub fn spawn( + chunk_id: u32, + bus: Arc>>, +) -> JoinHandle>> +where + FR: FrameReceiver + std::marker::Send + 'static, +{ + let bus_interface = bus.register(BusAddress::FrameDecoder(chunk_id)); + let decoder: ChunkDecoder = ChunkDecoder::new(chunk_id, bus_interface); + + tokio::spawn(decoder.run::()) +} + +pub struct ChunkDecoder { + chunk_id: u32, + bus_interface: BusInterface>, +} + +impl ChunkDecoder { + pub fn new( + chunk_id: u32, + bus_interface: BusInterface>, + ) -> Self { + Self { + chunk_id, + bus_interface, + } + } + + pub async fn run>(mut self) -> Option> { + self.bus_interface + .send( + BusAddress::ReceiverSocket, + (self.chunk_id, ReceivingChunkReport::WantNext(0)), + ) + .await + .ok(); + + let first_chunk: ParsedDataFrame = self.bus_interface.recv().await?; + + let mut decoder = FR::try_init(&first_chunk.transmission_info)?; + + if let Some(data) = decoder.update(first_chunk.frame_offset, &first_chunk.data) { + return Some(data); + } + + drop(first_chunk); + + loop { + let frame: ParsedDataFrame = self.bus_interface.recv().await?; + + if let Some(data) = decoder.update(frame.frame_offset, &frame.data) { + self.bus_interface + .send( + BusAddress::ReceiverSocket, + ( + self.chunk_id, + ReceivingChunkReport::Finished(decoder.expected_frame_id()), + ), + ) + .await + .ok(); + return Some(data); + } + self.bus_interface + .send( + BusAddress::ReceiverSocket, + ( + self.chunk_id, + ReceivingChunkReport::WantNext(decoder.expected_frame_id()), + ), + ) + .await + .ok(); + } + } +} diff --git a/src/engine/encoding.rs b/src/engine/encoding.rs new file mode 100644 index 0000000..0d94717 --- /dev/null +++ b/src/engine/encoding.rs @@ -0,0 +1,129 @@ +use crate::protocol::{coding::FrameSender, wire::frames::DataFrame}; +use crate::util::Compare; +use crate::util::file::{CHUNK_INDEX, mmap_segment}; +use crate::util::timer::{SenderTimer, SenderTimerOutput}; +use bytes::Bytes; +use memmap2::Mmap; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::time::{Duration, Instant}; + +use super::{Bus, BusAddress, BusInterface, BusMessage, SendingOrder}; + +use crate::util::timer_logger::print_relative_time; + +pub async fn spawn( + start_order: SendingOrder, + bus: Arc>>, + sock_addr: SocketAddr, + bus_addr: BusAddress, +) where + FS: FrameSender + std::marker::Send + 'static, +{ + let chunk_info = CHUNK_INDEX + .get() + .and_then(|index| index.get(start_order.chunk_id)); + if chunk_info.is_none() { + return; + } + let chunk_info = chunk_info.unwrap(); + + let chunk_data = mmap_segment(chunk_info.0, chunk_info.1, chunk_info.2).unwrap(); + + let bus_interface = bus.register(bus_addr); + let encoder: ChunkEncoder = + ChunkEncoder::new(chunk_data, start_order, bus_interface, sock_addr).await; + + tokio::spawn(encoder.run()); +} + +pub struct ChunkEncoder, const INFO_LENGTH: usize> { + chunk_id: u32, + encoder: FS, + transmission_info: [u8; INFO_LENGTH], + bus_interface: BusInterface>, + max_frame_offset: u32, + max_sent_offset: u32, + timer: SenderTimer, + sock_addr: SocketAddr, +} + +impl, const INFO_LENGTH: usize> ChunkEncoder +where + FS: Send + 'static, +{ + pub async fn new( + chunk_data: Mmap, + start_order: SendingOrder, + bus_interface: BusInterface>, + sock_addr: SocketAddr, + ) -> Self { + print_relative_time(start_order.chunk_id, "Start init sender", Instant::now()); + let encoder = + tokio::task::spawn_blocking(move || FS::init(chunk_data, start_order.offset_next)) + .await + .unwrap(); + + let transmission_info = encoder.get_trasmission_info(); + let sender = Self { + chunk_id: start_order.chunk_id, + encoder, + transmission_info, + bus_interface, + timer: SenderTimer::new( + start_order + .sending_interval + .unwrap_or(Duration::from_millis(20)), + ), + max_sent_offset: 0, + max_frame_offset: start_order.offset_next + start_order.offset_no_more_than, + sock_addr, + }; + print_relative_time(start_order.chunk_id, "Finish init sender", Instant::now()); + sender + } + + pub async fn run(mut self) { + loop { + tokio::select! { + Some(order) = self.bus_interface.recv::() => { + let now = Instant::now(); + print_relative_time(self.chunk_id, "Got Order", now); + self.timer.set_rate(now, order.sending_interval); + self.max_frame_offset.cmax(order.offset_no_more_than); + if order.close_now { + print_relative_time(self.chunk_id, "FINISH", now); + break; + } + }, + + output = &mut self.timer => { + match output { + SenderTimerOutput::Send(x) => { + for _ in 0..x{ + if self.max_sent_offset >= self.max_frame_offset {break;} + let (frame_offset, frame) = self.encoder.next_frame(); + let data_frame = DataFrame::new(self.chunk_id, frame_offset, self.transmission_info, Bytes::from(frame)); + + if self.bus_interface.send(BusAddress::SenderSocket,(self.sock_addr, data_frame )).await.is_err(){ + print_relative_time(self.chunk_id, "Can not send", Instant::now()); + break; + } + + if frame_offset % 4096 == 0{ + print_relative_time(self.chunk_id, format!("Send {frame_offset}").as_str(), Instant::now()); + } + + self.max_sent_offset = frame_offset; + } + }, + SenderTimerOutput::Close => { + print_relative_time(self.chunk_id, "CLOSE", Instant::now()); + break; + } + }; + } + } + } + } +} diff --git a/src/engine/mod.rs b/src/engine/mod.rs index 120ffde..bb621ab 100644 --- a/src/engine/mod.rs +++ b/src/engine/mod.rs @@ -1,2 +1,123 @@ +pub mod decoding; +pub mod encoding; pub mod receiving; pub mod sending; + +// TODO +// Potential Dead load with tokio::mpsc or flume:: +mod bus_flume; +// mod bus_tokio; + +pub use bus_flume::{Bus, BusInterface}; +// pub use bus_tokio::{Bus, BusInterface}; + +use std::net::SocketAddr; +use tokio::time::{Duration, Instant}; + +use crate::protocol::wire::frames::{DataFrame, ParsedDataFrame}; +use derive_more::{self, Debug}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum BusAddress { + SenderSocket, + ReceiverSocket, + FrameEncoder(u32, SocketAddr), + FrameDecoder(u32), +} + +#[derive(derive_more::From, derive_more::TryInto, Debug)] +pub enum BusMessage { + SendingOrder(SendingOrder), + ReceivingChunkReport((u32, ReceivingChunkReport)), + SendingData((SocketAddr, DataFrame)), + ReceivingData(ParsedDataFrame), +} + +#[derive(PartialEq, Eq, Clone, Debug)] +pub enum ReceivingChunkReport { + WantNext(u32), + Finished(u32), +} + +impl Ord for ReceivingChunkReport { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + match (self, other) { + (ReceivingChunkReport::Finished(a), ReceivingChunkReport::Finished(b)) => a.cmp(b), + (ReceivingChunkReport::Finished(_), ReceivingChunkReport::WantNext(_)) => { + std::cmp::Ordering::Greater + } + (ReceivingChunkReport::WantNext(_), ReceivingChunkReport::Finished(_)) => { + std::cmp::Ordering::Less + } + (ReceivingChunkReport::WantNext(a), ReceivingChunkReport::WantNext(b)) => a.cmp(b), + } + } +} +impl PartialOrd for ReceivingChunkReport { + fn partial_cmp(&self, other: &Self) -> Option { + self.cmp(other).into() + } +} + +#[derive(Debug)] +pub struct SendingOrder { + pub chunk_id: u32, + pub sending_interval: Option, + pub time_stamp: Instant, + pub offset_next: u32, + pub offset_no_more_than: u32, + pub close_now: bool, +} + +// use dashmap::{DashMap, DashSet}; + +// struct DownloaderControlBlock { +// pub latest_want: ReceivingChunkReport, +// pub data_sender: flume::Sender>, +// } + +// pub struct Downloader { +// socket: S, +// report_sender: flume::Sender<(u32, ReceivingChunkReport)>, +// report_receiver: flume::Receiver<(u32, ReceivingChunkReport)>, +// contol_block: DashMap>, +// } + +// impl Downloader { +// pub fn new(socket: S) -> Self { +// let (report_sender, report_receiver) = flume::unbounded(); +// Self { +// socket, +// report_sender, +// report_receiver, +// contol_block: DashMap::new(), +// } +// } + +// pub fn register( +// &mut self, +// chunk_id: u32, +// ) -> ( +// flume::Receiver>, +// flume::Sender<(u32, ReceivingChunkReport)>, +// ) { +// let report_sender = self.report_sender.clone(); +// let (data_sender, data_receiver) = flume::bounded(32); +// let control_block = DownloaderControlBlock { +// latest_want: ReceivingChunkReport::WantNext(0), +// data_sender, +// }; +// self.contol_block.insert(chunk_id, control_block); +// (data_receiver, report_sender) +// } + +// pub async fn try_download(&mut self, chunk_id: u32) -> Option> +// where +// FR: FrameReceiver, +// { +// let (data_receiver, reporter) = self.register(chunk_id); +// let mut receiving_chunk = +// receiving::ReceivingChunk::::new(chunk_id, data_receiver, reporter); +// receiving_chunk.run::().await +// } +// } diff --git a/src/engine/receiving.rs b/src/engine/receiving.rs index 8b13789..0c7de2d 100644 --- a/src/engine/receiving.rs +++ b/src/engine/receiving.rs @@ -1 +1,118 @@ +use super::{BusAddress, BusInterface, BusMessage, ReceivingChunkReport}; +use crate::protocol::wire::encoding::{PacketExt, parse_packet}; +use crate::protocol::wire::frames::ParsedFrameVariant; +use crate::protocol::wire::packets::TicketPacket; +use crate::transmission::UdpSocketLike; +use crate::util::Compare; +use bytes::Bytes; +use owo_colors::*; +use std::collections::{HashMap, VecDeque}; +use std::net::SocketAddr; +use tokio::time::{Duration, interval}; +#[derive(Default)] +struct Reporter { + activate_data: HashMap, + exiting_data: VecDeque>, +} + +impl Reporter { + fn is_empty(&self) -> bool { + let exited = self.exiting_data.iter().map(|s| s.len()).sum(); + dbg!(exited); + self.activate_data.is_empty() && 0usize == exited + } + + fn update(&mut self, chunk_id: u32, report: ReceivingChunkReport) { + self.activate_data + .entry(chunk_id) + .and_modify(|x| x.cmax(report.clone())) + .or_insert_with_key(|_| report); + } + + fn generate(&mut self, rate_kbps: u32) -> TicketPacket { + if self.exiting_data.len() >= 3 { + self.exiting_data.pop_back(); + } + + self.exiting_data.push_front( + self.activate_data + .extract_if(|_k, v| *v >= ReceivingChunkReport::Finished(0)) + .collect(), + ); + + self.activate_data + .iter() + .chain(self.exiting_data.iter().flat_map(|s| s.iter())) + .fold( + TicketPacket::new().set_rate_limit(rate_kbps), + |packet: TicketPacket, (chunk_id, result)| match result { + ReceivingChunkReport::WantNext(n) => { + packet.set_get_chunk(*chunk_id, *n, 8192.max(*n / 5)) + } + ReceivingChunkReport::Finished(n) => packet.set_get_chunk(*chunk_id, *n, 0), + }, + ) + } +} + +pub struct ReceivingSocket { + socket: S, + bus_interface: BusInterface>, +} +impl ReceivingSocket { + pub fn new( + socket: S, + bus_interface: BusInterface>, + ) -> Self { + Self { + socket, + bus_interface, + } + } + + pub async fn run(mut self, server_addr: SocketAddr) { + let mut buffer = [0u8; 65537]; + let mut reporter = Reporter::default(); + let mut ticker = interval(Duration::from_secs(2)); + + loop { + tokio::select! { + biased; + + _ = ticker.tick() => { + eprintln!("{}", "Tick".yellow()); + if !reporter.is_empty() { + let packet = reporter.generate(40960).build(); // 40Mbps + if self.socket.send_to(packet.as_slice(), server_addr).await.is_err(){ + eprintln!("{}", "Failed to send report to server!".red()); + break; + } + } + }, + + Ok((length, _)) = self.socket.recv_from(&mut buffer) => { + let packet = Bytes::from(Vec::from(&buffer[0..length])); + if let Ok(packet) = parse_packet::(packet){ + for frame in packet.frames{ + if let ParsedFrameVariant::Data(data_frame) = frame{ + let _ = self.bus_interface.send(BusAddress::FrameDecoder(data_frame.chunk_id), data_frame).await; + } + } + } + }, + + Some((chunk_id, report)) = self.bus_interface.recv::<(u32, ReceivingChunkReport)>() => { + reporter.update(chunk_id, report); + }, + + + + else => { + eprintln!("{}", "SenderSocketexit".red()); + break; + } + } + } + } +} diff --git a/src/engine/sending.rs b/src/engine/sending.rs index 6d70102..30454d9 100644 --- a/src/engine/sending.rs +++ b/src/engine/sending.rs @@ -1,92 +1,109 @@ -use crate::protocol::{coding::FrameSender, wire::frames::DataFrame}; -use crate::util::timer::{SenderTimer, SenderTimerOutput}; +use std::collections::HashMap; +use std::net::SocketAddr; +use std::time::Duration; + +use super::{BusAddress, BusInterface, BusMessage, SendingOrder}; +use crate::constants::MTU; +use crate::protocol::coding::FrameSender; +use crate::protocol::wire::encoding::{PacketExt, ParsedPacket, parse_packet}; +use crate::protocol::wire::frames::ParsedFrameVariant; +use crate::protocol::wire::packets::ParsedPacketVariant; +use crate::protocol::wire::{frames::DataFrame, packets::DataPacket}; +use crate::transmission::UdpSocketLike; + use bytes::Bytes; -use flume::{Receiver, Sender}; -use tokio::time::{Duration, Instant}; -use crate::util::timer_logger::print_relative_time; +use tokio::time::Instant; -pub struct SendingOrder { - pub chunk_id: u32, - pub sending_interval: Option, - pub time_stamp: Instant, - pub offset_next: u32, - pub offset_no_more_than: u32, - pub close_now: bool, +pub struct SendingSocket { + socket: S, + bus_interface: BusInterface>, } -pub struct SendingChunk, const INFO_LENGTH: usize> { - chunk_id: u32, - encoder: FS, - transmission_info: [u8; INFO_LENGTH], - order_receiver: Receiver, - data_sender: Sender>, - max_frame_offset: u32, - max_sent_offset: u32, - timer: SenderTimer, +fn build_sending_order( + packet: ParsedPacket, + socket_addr: SocketAddr, +) -> Option> { + let ParsedPacketVariant::TicketPacket { .. } = packet.specific_packet_header else { + return None; + }; + let mut orders = HashMap::new(); + let mut sending_interval = None; + for frame in packet.frames { + match frame { + ParsedFrameVariant::GetChunk(header) => { + let chunk_id: u32 = header.chunk_id.into(); + let next_recieve: u32 = header.next_receive_offset.into(); + let receive_window: u32 = header.receive_window_frames.into(); + + let order = SendingOrder { + chunk_id, + sending_interval, + time_stamp: Instant::now(), + offset_next: next_recieve, + offset_no_more_than: next_recieve + receive_window, + close_now: receive_window == 0, + }; + orders.insert(BusAddress::FrameEncoder(chunk_id, socket_addr), order); + } + ParsedFrameVariant::RateLimit(header) => { + let rate_limit = u32::from(header.desired_max_kbps); + sending_interval = Duration::from_millis(8) + .mul_f32((MTU + 20) as f32) + .div_f64(rate_limit as f64) + .into(); + } + _ => {} + } + } + + orders.into() } -impl, const INFO_LENGTH: usize> SendingChunk { +impl SendingSocket { pub fn new( - chunk_data: &[u8], - start_order: SendingOrder, - order_receiver: Receiver, - data_sender: Sender>, + socket: S, + bus_interface: BusInterface>, ) -> Self { - print_relative_time("Start init sender", Instant::now()); - let encoder = FS::init(chunk_data, start_order.offset_next); - let transmission_info = encoder.get_trasmission_info(); - let sender = Self { - chunk_id: start_order.chunk_id, - encoder, - transmission_info, - order_receiver, - data_sender, - timer: SenderTimer::new( - start_order - .sending_interval - .unwrap_or(Duration::from_millis(20)), - ), - max_sent_offset: 0, - max_frame_offset: start_order.offset_next + start_order.offset_no_more_than, - }; - print_relative_time("Finish init sender", Instant::now()); - sender + Self { + socket, + bus_interface, + } } - pub async fn run(&mut self) { + pub async fn run(mut self) + where + FS: FrameSender + Send + 'static, + { + let mut buffer = [0u8; 65537]; loop { tokio::select! { - Ok(order) = self.order_receiver.recv_async() => { - let now = Instant::now(); - print_relative_time("ORDER", now); - // self.timer.set_rate(, new_interval); - self.timer.set_rate(now, order.sending_interval); - if order.close_now { - print_relative_time("FINISH", now); - break; + Ok((length, sock_addr)) = self.socket.recv_from(&mut buffer) => { + let packet = Bytes::from(Vec::from(&buffer[0..length])); + if let Some(parsed_packet) = parse_packet::(packet) + .inspect_err(|err| {dbg!(err);}) + .ok().map( + |parsed_packet| build_sending_order(parsed_packet, sock_addr).into_iter().flatten() + ){ + for (addr, order) in parsed_packet.into_iter(){ + if let Err(order) = self.bus_interface.send(addr.clone(), order).await{ + let start_order = order.unwrap(); + if start_order.close_now {continue;} + eprintln!("Init encoder for chunk {:?}, addr {:?}", start_order.chunk_id, &addr); + let bus = self.bus_interface.get_bus(); + super::encoding::spawn::(start_order, bus, sock_addr, addr).await; + } + } } }, - output = &mut self.timer => { - match output { - SenderTimerOutput::Send(x) => { - for _ in 0..x{ - if self.max_sent_offset >= self.max_frame_offset {break;} - let (frame_offset, frame) = self.encoder.next_frame(); - if self.data_sender.send_async(DataFrame::new(self.chunk_id, frame_offset, self.transmission_info, Bytes::from(frame))).await.is_err(){ - print_relative_time("Can not send", Instant::now()); - break; - } - print_relative_time(format!("Send {frame_offset}").as_str(), Instant::now()); - self.max_sent_offset = frame_offset; - } - }, - SenderTimerOutput::Close => { - print_relative_time("CLOSE", Instant::now()); - break; - } - }; + Some((addr, frame)) = self.bus_interface.recv::<(SocketAddr, DataFrame)>() => { + let packet = DataPacket::from(frame).build(); + self.socket.send_to(packet.as_slice(), addr).await.ok(); + }, + + else => { + break; } } } diff --git a/src/protocol/coding/mod.rs b/src/protocol/coding/mod.rs index ca7e244..d238eb0 100644 --- a/src/protocol/coding/mod.rs +++ b/src/protocol/coding/mod.rs @@ -1,5 +1,5 @@ pub trait FrameSender { - fn init(chunk_data: &[u8], next_id: u32) -> Self; + fn init(chunk_data: impl AsRef<[u8]>, next_id: u32) -> Self; fn next_frame(&mut self) -> (u32, Vec); fn get_trasmission_info(&self) -> [u8; TRANSMISSION_INFO_LENGTH]; } diff --git a/src/protocol/coding/raptorq_code.rs b/src/protocol/coding/raptorq_code.rs index 5cc9d83..6da4fa1 100644 --- a/src/protocol/coding/raptorq_code.rs +++ b/src/protocol/coding/raptorq_code.rs @@ -14,7 +14,8 @@ pub struct RaptorqSender { } impl FrameSender for RaptorqSender { - fn init(chunk_data: &[u8], next_id: u32) -> Self { + fn init(chunk_data: impl AsRef<[u8]>, next_id: u32) -> Self { + let chunk_data: &[u8] = chunk_data.as_ref(); let config = ObjectTransmissionInformation::with_defaults( chunk_data.len() as u64, DEFAULT_FRAME_LEN as u16, @@ -30,7 +31,7 @@ impl FrameSender for RaptorqSender { } fn next_frame(&mut self) -> (u32, Vec) { - const BURST: usize = 32; + const BURST: usize = 16; if self.cache.is_empty() { let encoder_cnt = self.encoder.get_block_encoders().len(); @@ -85,23 +86,13 @@ impl FrameReceiver for RaptorqReceiver { #[cfg(test)] mod test { - const CHUNK_SIZE: usize = 1048576; - use rand::Rng; - use crate::constants::MTU; use crate::protocol::coding::{ FrameReceiver, FrameSender, raptorq_code::{RaptorqReceiver, RaptorqSender}, }; - - fn generate_random(size: usize) -> Vec { - let mut data: Vec = vec![0; size]; - for byte in data.iter_mut() { - *byte = rand::rng().random(); - } - data - } + use crate::util::generate_random; #[test] fn get_gen_frames() { diff --git a/src/protocol/key_ring.rs b/src/protocol/key_ring.rs index 5cc95d6..d498b9b 100644 --- a/src/protocol/key_ring.rs +++ b/src/protocol/key_ring.rs @@ -8,6 +8,15 @@ use std::sync::OnceLock; pub static KEY_RING: OnceLock = OnceLock::new(); +pub fn mock_init() { + if KEY_RING.get().is_some() { + return; + } + const PRIKEY: &str = "fd9d88daa555f6bad0bbece8e0e4fffef190723e16aa9dfe0d18c8e4ff7a6eda"; + const PUBKEY: &str = "4ae6629e09372dd96196f35c032fd1c5da3dfe01ca40ecf8b268d78d741e9d1c"; + init(vec![String::from(PUBKEY)], Some(String::from(PRIKEY))); +} + #[derive(Debug, Default)] pub struct KeyRing { pub public_key_rings: HashSet, diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 4e4401d..9d36bce 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -2,3 +2,5 @@ pub mod coding; mod key_ring; pub mod wire; + +pub use key_ring::mock_init; diff --git a/src/protocol/wire/encoding.rs b/src/protocol/wire/encoding.rs index d956143..60da8bc 100644 --- a/src/protocol/wire/encoding.rs +++ b/src/protocol/wire/encoding.rs @@ -17,7 +17,7 @@ pub trait RawParts: IntoBytes + FromBytes + Unaligned + Sized + Immutable { } impl RawParts for T where T: IntoBytes + FromBytes + Unaligned + Immutable {} -pub(super) trait PacketExt: Packet { +pub(crate) trait PacketExt: Packet { fn build(self) -> Vec { let header_length = ( CommonPacketHeader::raw_len(), @@ -92,11 +92,18 @@ pub(super) trait FrameExt: Frame { impl FrameExt for T {} #[derive(Debug)] -pub struct ParsedPacket<'a> { - pub common_packet_header: &'a CommonPacketHeader, - pub specific_packet_header: ParsedPacketVariant<'a>, - pub frames: Vec>, - pub verification_field: &'a [u8], +pub struct ParsedPacket { + pub pkt: Bytes, + pub specific_packet_header: ParsedPacketVariant, + pub frames: Vec>, +} + +impl ParsedPacket { + pub fn get_common_packet_header(&self) -> &CommonPacketHeader { + let (header, _remain) = + CommonPacketHeader::try_ref_from_prefix(self.pkt.as_bytes()).unwrap(); + header + } } #[derive(Debug)] @@ -112,12 +119,15 @@ pub enum ParseError { FailedToParseFrame, } -fn parse_frame<'a>(mut remained_body: &'a [u8]) -> Result>, ParseError> { +fn parse_frame( + mut remained_body: Bytes, +) -> Result>, ParseError> { let mut frames = vec![]; while !remained_body.is_empty() { - let (common_frame_header, _) = CommonFrameHeader::try_ref_from_prefix(remained_body) - .map_err(|_| ParseError::BodyTooshort)?; + let (common_frame_header, _) = + CommonFrameHeader::try_ref_from_prefix(remained_body.as_bytes()) + .map_err(|_| ParseError::BodyTooshort)?; let frame_type = common_frame_header.frame_type; let frame_length = u16::from(common_frame_header.frame_length) as usize; @@ -130,7 +140,7 @@ fn parse_frame<'a>(mut remained_body: &'a [u8]) -> Result(mut remained_body: &'a [u8]) -> Result(packet: &'a [u8]) -> Result, ParseError> { - let (common_packet_header, _) = - CommonPacketHeader::try_ref_from_prefix(packet).map_err(|_| ParseError::PacketTooShort)?; +pub fn parse_packet( + packet: Bytes, +) -> Result, ParseError> { + let (common_packet_header, _) = CommonPacketHeader::try_ref_from_prefix(packet.as_bytes()) + .map_err(|_| ParseError::PacketTooShort)?; let header_length = u16::from(common_packet_header.header_length) as usize; let body_length = u16::from(common_packet_header.body_length) as usize; if common_packet_header.version != VERSION { @@ -169,7 +181,7 @@ pub fn parse_packet<'a>(packet: &'a [u8]) -> Result, ParseError let packet_variant = PacketType::try_from(common_packet_header.packet_type) .map_err(|_| ParseError::UnsupportedPacketType(common_packet_header.packet_type))? - .try_parse(specific_packet_header) + .try_parse::(packet.slice_ref(specific_packet_header)) .ok_or(ParseError::FailedToParsePacketHeader)?; KEY_RING @@ -183,11 +195,13 @@ pub fn parse_packet<'a>(packet: &'a [u8]) -> Result, ParseError ) .map_err(ParseError::Verification)?; + let remained_body = packet.slice_ref(&packet[header_length..header_length + body_length]); + + let frames = parse_frame(remained_body)?; Ok(ParsedPacket { - common_packet_header, + pkt: packet, specific_packet_header: packet_variant, - frames: parse_frame(&packet[header_length..header_length + body_length])?, - verification_field, + frames, }) } @@ -197,7 +211,7 @@ mod tests { use super::*; use crate::constants::*; - use crate::protocol::key_ring::init; + use crate::protocol::key_ring::mock_init; use crate::protocol::wire::frames::{GetChunkFrameHeader, ParsedFrameVariant}; use crate::protocol::wire::packets::current_timestamp_ms; use bytes::BytesMut; @@ -210,15 +224,6 @@ mod tests { total_packet.freeze() } - fn mock_init() { - if KEY_RING.get().is_some() { - return; - } - const PRIKEY: &str = "fd9d88daa555f6bad0bbece8e0e4fffef190723e16aa9dfe0d18c8e4ff7a6eda"; - const PUBKEY: &str = "4ae6629e09372dd96196f35c032fd1c5da3dfe01ca40ecf8b268d78d741e9d1c"; - init(vec![String::from(PUBKEY)], Some(String::from(PRIKEY))); - } - #[test] fn build_parse_data_packet() { mock_init(); @@ -243,7 +248,8 @@ mod tests { assert!(total_packet.len() <= MTU); - let parsed_packet = parse_packet(&total_packet).unwrap(); + let parsed_packet = + parse_packet::(Bytes::from(total_packet)).unwrap(); if let ParsedFrameVariant::Data(data_frame) = &parsed_packet.frames[0] { assert_eq!(19260817, data_frame.chunk_id); @@ -253,7 +259,6 @@ mod tests { } else { unreachable!() } - assert_eq!(parsed_packet.verification_field.len(), 8, "Should be CRC64"); } #[test] @@ -273,7 +278,8 @@ mod tests { let total_packet = build_into_bytes(packet); assert!(total_packet.len() <= MTU); - let parsed_packet = parse_packet(&total_packet).unwrap(); + let parsed_packet = + parse_packet::(Bytes::from(total_packet)).unwrap(); let current_time = current_timestamp_ms(); @@ -307,11 +313,11 @@ mod tests { } ParsedFrameVariant::GetChunk(GetChunkFrameHeader { chunk_id, - max_received_offset, + next_receive_offset, receive_window_frames, }) => { let expected_entry = expected.remove(&u32::from(chunk_id)).unwrap(); - assert_eq!(expected_entry.0, u32::from(max_received_offset)); + assert_eq!(expected_entry.0, u32::from(next_receive_offset)); assert_eq!(expected_entry.1, u32::from(receive_window_frames)); } _ => unreachable!(), diff --git a/src/protocol/wire/frames.rs b/src/protocol/wire/frames.rs index 3ce9cbc..b4e7d79 100644 --- a/src/protocol/wire/frames.rs +++ b/src/protocol/wire/frames.rs @@ -1,10 +1,10 @@ +use crate::constants::TRANSMISSION_INFO_LENGTH; use bytes::Bytes; use num_enum::{IntoPrimitive, TryFromPrimitive}; +use std::fmt; use zerocopy::byteorder::{BigEndian, U32}; use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Unaligned}; -use crate::constants::TRANSMISSION_INFO_LENGTH; - use super::{Frame, SpecificFrameHeader}; #[repr(u8)] @@ -16,7 +16,10 @@ pub enum FrameType { } impl FrameType { - pub(super) fn try_parse<'a>(&self, data: &'a [u8]) -> Option> { + pub(super) fn try_parse( + &self, + data: Bytes, + ) -> Option> { match &self { FrameType::Data => DataFrame::::try_parse(data), FrameType::GetChunk => GetChunkFrame::try_parse(data), @@ -26,14 +29,14 @@ impl FrameType { } #[derive(Debug)] -pub enum ParsedFrameVariant<'a> { - Data(ParsedDataFrame<'a>), +pub enum ParsedFrameVariant { + Data(ParsedDataFrame), GetChunk(GetChunkFrameHeader), RateLimit(RateLimitFrameHeader), } #[repr(C)] -#[derive(IntoBytes, FromBytes, Unaligned, Immutable, KnownLayout)] +#[derive(IntoBytes, FromBytes, Unaligned, Immutable, KnownLayout, Debug)] pub struct DataFrameHeader { pub chunk_id: U32, pub frame_offset: U32, @@ -50,12 +53,48 @@ pub struct DataFrame { header: DataFrameHeader, data: Bytes, } -#[derive(Debug)] -pub struct ParsedDataFrame<'a> { + +pub struct ParsedDataFrame { pub chunk_id: u32, pub frame_offset: u32, - pub transmission_info: [u8; TRANSMISSION_INFO_LENGTH], - pub data: &'a [u8], + pub transmission_info: [u8; INFO_LENGTH], + pub data: Bytes, +} + +fn preview_bytes(bytes: &Bytes) -> String { + let len = bytes.len(); + let preview_len = 16.min(len); + let preview: Vec = bytes + .iter() + .take(preview_len) + .map(|b| format!("{:02x}", b)) + .collect(); + format!( + "[{} bytes: {}{}]", + len, + preview.join(" "), + if len > preview_len { " ..." } else { "" } + ) +} + +impl fmt::Debug for DataFrame { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("DataFrame") + .field("header", &self.header) + .field("data", &preview_bytes(&self.data)) + .finish() + } +} + +impl fmt::Debug for ParsedDataFrame { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ParsedDataFrame") + .field("chunk_id", &self.chunk_id) + .field("frame_offset", &self.frame_offset) + .field("transmission_info", &self.transmission_info) + .field("data", &preview_bytes(&self.data)) + .finish() + } } impl DataFrame { @@ -88,13 +127,15 @@ impl Frame for DataFrame { fn take_body(self) -> Option { Some(self.data) } - fn try_parse<'a>(data: &'a [u8]) -> Option> { - let (header, data) = DataFrameHeader::read_from_prefix(data).ok()?; + fn try_parse( + frame: Bytes, + ) -> Option> { + let (header, data) = DataFrameHeader::read_from_prefix(frame.as_bytes()).ok()?; ParsedFrameVariant::Data(ParsedDataFrame { chunk_id: header.chunk_id.into(), frame_offset: header.frame_offset.into(), transmission_info: header.transmission_info, - data, + data: frame.slice_ref(data), }) .into() } @@ -104,7 +145,7 @@ impl Frame for DataFrame { #[derive(IntoBytes, FromBytes, Unaligned, Immutable, KnownLayout, Debug)] pub struct GetChunkFrameHeader { pub chunk_id: U32, - pub max_received_offset: U32, + pub next_receive_offset: U32, pub receive_window_frames: U32, // 0 means send no more! } @@ -121,8 +162,8 @@ impl Frame for GetChunkFrame { fn header(&self) -> &Self::Header { self } - fn try_parse<'a>(data: &'a [u8]) -> Option> { - let (header, remain) = GetChunkFrameHeader::read_from_prefix(data).ok()?; + fn try_parse(data: Bytes) -> Option> { + let (header, remain) = GetChunkFrameHeader::read_from_prefix(data.as_bytes()).ok()?; remain .is_empty() .then_some(ParsedFrameVariant::GetChunk(header)) @@ -148,8 +189,8 @@ impl Frame for RateLimitFrame { fn header(&self) -> &Self::Header { self } - fn try_parse<'a>(data: &'a [u8]) -> Option> { - let (header, remain) = RateLimitFrameHeader::read_from_prefix(data).ok()?; + fn try_parse(data: Bytes) -> Option> { + let (header, remain) = RateLimitFrameHeader::read_from_prefix(data.as_bytes()).ok()?; remain .is_empty() diff --git a/src/protocol/wire/mod.rs b/src/protocol/wire/mod.rs index 62e6bb9..9197cdf 100644 --- a/src/protocol/wire/mod.rs +++ b/src/protocol/wire/mod.rs @@ -54,7 +54,7 @@ pub trait Frame: Sized { fn take_body(self) -> Option { None } - fn try_parse<'a>(data: &'a [u8]) -> Option>; + fn try_parse(data: Bytes) -> Option>; } pub struct BuiltFrame { @@ -69,5 +69,5 @@ pub trait Packet: Sized { fn get_header(&self) -> &Self::Header; fn get_body(self) -> impl Iterator; - fn try_parse<'a>(data: &'a [u8]) -> Option>; + fn try_parse(data: Bytes) -> Option; } diff --git a/src/protocol/wire/packets.rs b/src/protocol/wire/packets.rs index a35e954..3c2efd0 100644 --- a/src/protocol/wire/packets.rs +++ b/src/protocol/wire/packets.rs @@ -4,7 +4,7 @@ use super::encoding::FrameExt; use super::frames::DataFrame; use super::verify::PacketVerificationData; use super::{Packet, SpecificPacketHeader}; -use crate::constants::{PUB_KEY_LENGTH, TRANSMISSION_INFO_LENGTH}; +use crate::constants::PUB_KEY_LENGTH; use crate::protocol::key_ring::KEY_RING; use crate::protocol::wire::frames::{GetChunkFrame, RateLimitFrame}; use crate::protocol::wire::verify::PacketVerifyType; @@ -35,25 +35,25 @@ pub enum PacketType { } impl PacketType { - pub(super) fn try_parse<'a>(&self, data: &'a [u8]) -> Option> { + pub(super) fn try_parse( + &self, + data: Bytes, + ) -> Option { match &self { - PacketType::Data => DataPacket::try_parse(data), + PacketType::Data => DataPacket::::try_parse(data), PacketType::Ticket => TicketPacket::try_parse(data), } } } #[derive(Debug)] -pub enum ParsedPacketVariant<'a> { +pub enum ParsedPacketVariant { DataPacket(), - TicketPacket { - pub_key: &'a [u8; PUBLIC_KEY_LENGTH], - timestamp_ms: u64, - }, + TicketPacket { pub_key: Bytes, timestamp_ms: u64 }, } -impl<'a> ParsedPacketVariant<'a> { - pub fn build_verification_data( +impl ParsedPacketVariant { + pub fn build_verification_data<'a>( &'a self, pkt: &'a [u8], verification_field: &'a [u8], @@ -82,16 +82,25 @@ impl SpecificPacketHeader for DataPacketHeader { } } -pub struct DataPacket { +pub struct DataPacket { header: DataPacketHeader, - data: DataFrame, // DataFrame<12> for raptorq + data: DataFrame, // DataFrame<12> for raptorq } -impl DataPacket { +impl From> for DataPacket { + fn from(data: DataFrame) -> Self { + Self { + header: DataPacketHeader {}, + data, + } + } +} + +impl DataPacket { pub fn new( chunk_id: u32, offset: u32, - transmission_info: [u8; TRANSMISSION_INFO_LENGTH], + transmission_info: [u8; INFO_LENGTH], data: Vec, ) -> Self { Self { @@ -101,7 +110,7 @@ impl DataPacket { } } -impl Packet for DataPacket { +impl Packet for DataPacket { type Header = DataPacketHeader; const PACKET_TYPE: PacketType = PacketType::Data; const PACKET_VERIFICATION_TYPE: PacketVerifyType = PacketVerifyType::CRC64; @@ -113,7 +122,7 @@ impl Packet for DataPacket { let built = self.data.build(); std::iter::once(built) } - fn try_parse<'a>(data: &'a [u8]) -> Option> { + fn try_parse(data: Bytes) -> Option { (data.is_empty()).then_some(ParsedPacketVariant::DataPacket()) } } @@ -168,14 +177,14 @@ impl TicketPacket { pub fn set_get_chunk( mut self, chunk_id: u32, - max_received_offset: u32, + next_received_offset: u32, receive_window: u32, ) -> Self { self.get_chunk.insert( chunk_id, GetChunkFrame { chunk_id: chunk_id.into(), - max_received_offset: max_received_offset.into(), + next_receive_offset: next_received_offset.into(), receive_window_frames: receive_window.into(), }, ); @@ -201,15 +210,16 @@ impl Packet for TicketPacket { rate_limit.chain(get_packets) } - fn try_parse<'a>(data: &'a [u8]) -> Option> { - let (pub_key, mut remain): (&'a [u8], &'a [u8]) = data.split_at_checked(PUB_KEY_LENGTH)?; - let pub_key: &'a [u8; PUB_KEY_LENGTH] = pub_key.try_into().ok()?; + fn try_parse(data: Bytes) -> Option { + let (pub_key, mut remain): (&[u8], &[u8]) = + data.as_bytes().split_at_checked(PUB_KEY_LENGTH)?; + let pub_key: &[u8; PUB_KEY_LENGTH] = pub_key.try_into().ok()?; let timestamp_ms = remain.try_get_u64().ok()?; remain .is_empty() .then_some(ParsedPacketVariant::TicketPacket { - pub_key, + pub_key: data.slice_ref(pub_key), timestamp_ms, }) } diff --git a/src/protocol/wire/verify.rs b/src/protocol/wire/verify.rs index c8854e3..dd3c899 100644 --- a/src/protocol/wire/verify.rs +++ b/src/protocol/wire/verify.rs @@ -6,7 +6,7 @@ use ed25519_dalek::{Signature, VerifyingKey}; use crate::protocol::key_ring::KeyRing; -use crate::constants::{MTU, PUB_KEY_LENGTH}; +use crate::constants::MTU; pub fn check_crc64(content: &[u8]) -> u64 { Crc::::new(&CRC_64_ECMA_182).checksum(content) } @@ -36,7 +36,7 @@ pub enum PacketVerificationData<'a> { }, Ed25519 { pkt: &'a [u8], - pub_key: &'a [u8; PUB_KEY_LENGTH], + pub_key: &'a [u8], signature: &'a [u8], }, } diff --git a/src/transmission/real.rs b/src/transmission/real.rs index 309d33c..23fd857 100644 --- a/src/transmission/real.rs +++ b/src/transmission/real.rs @@ -58,17 +58,14 @@ mod tests { #[tokio::test] async fn test_real_udp_socket_send_recv() -> std::io::Result<()> { - // 创建两个 socket,分别绑定到不同的端口 let recv_addr: SocketAddr = "127.0.0.1:40001".parse().unwrap(); let send_addr: SocketAddr = "127.0.0.1:40002".parse().unwrap(); let receiver = RealUdpSocket::bind(recv_addr).await?; let sender = RealUdpSocket::bind(send_addr).await?; - // 要发送的数据 let data = vec![Bytes::from_static(b"Hello, "), Bytes::from_static(b"UDP!")]; - // 启动接收任务 let recv_task = tokio::spawn(async move { let mut buf = vec![0u8; 1024]; let (len, from) = receiver.recv_from(&mut buf).await.unwrap(); @@ -76,18 +73,15 @@ mod tests { (received.to_vec(), from) }); - // 稍微等一下让接收端准备好 sleep(Duration::from_millis(100)).await; - // 发送数据 let bytes_sent = sender.send_to(&data, recv_addr).await?; assert_eq!(bytes_sent, b"Hello, UDP!".len()); - // 获取接收结果 let (received, from) = recv_task.await.unwrap(); assert_eq!(received, b"Hello, UDP!"); - assert_eq!(from.ip(), send_addr.ip()); // 可以比较 IP,端口可能是动态的 + assert_eq!(from.ip(), send_addr.ip()); Ok(()) } diff --git a/src/util/file.rs b/src/util/file.rs index 1918653..dc2fdfa 100644 --- a/src/util/file.rs +++ b/src/util/file.rs @@ -1,8 +1,26 @@ use memmap2::{Mmap, MmapOptions}; +use std::collections::HashMap; +use std::ffi::OsString; use std::fs::{File, OpenOptions}; use std::io::{Error, ErrorKind, Result}; use std::os::unix::fs::FileExt; use std::path::Path; +use std::sync::OnceLock; + +pub struct ChunkIndex { + pub files: HashMap, + pub chunks: HashMap, // (file, offset, length) +} + +impl ChunkIndex { + pub fn get(&self, index: u32) -> Option<(&OsString, u64, usize)> { + self.chunks.get(&index).and_then(|(file, offset, length)| { + self.files.get(file).map(|file| (file, *offset, *length)) + }) + } +} + +pub static CHUNK_INDEX: OnceLock = OnceLock::new(); pub fn sanity_check>(path: P) -> Result<(u64, String)> { let length = std::fs::metadata(&path)?.len(); diff --git a/src/util/mod.rs b/src/util/mod.rs index dee3277..4d249b5 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -2,3 +2,50 @@ pub mod file; pub mod plan; pub mod timer; pub mod timer_logger; + +use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use tokio::time::Instant; + +pub fn unix_ms_to_tokio_instant(unix_ms: u64) -> Instant { + // Current wall-clock time + let now_unix_ms = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time went backwards") + .as_millis() as u64; + + let now_instant = Instant::now(); + + if unix_ms >= now_unix_ms { + // Future timestamp: add the difference + let diff = unix_ms - now_unix_ms; + now_instant + Duration::from_millis(diff) + } else { + // Past timestamp: subtract the difference + let diff = now_unix_ms - unix_ms; + now_instant - Duration::from_millis(diff) + } +} + +pub trait Compare: Ord + Clone { + fn cmax(&mut self, other: Self) { + if *self < other { + *self = other; + } + } + fn cmin(&mut self, other: Self) { + if *self > other { + *self = other; + } + } +} + +impl Compare for T {} + +pub fn generate_random(size: usize) -> Vec { + use rand::Rng; + let mut data: Vec = vec![0; size]; + for byte in data.iter_mut() { + *byte = rand::rng().random(); + } + data +} diff --git a/src/util/timer.rs b/src/util/timer.rs index 4291b22..e61a445 100644 --- a/src/util/timer.rs +++ b/src/util/timer.rs @@ -129,12 +129,12 @@ mod test { match output { SenderTimerOutput::Send(x) => { for _ in 0..x { - sent_times.push(print_relative_time(format!("send {}", cnt).as_str(), Instant::now()) / 100.0); + sent_times.push(print_relative_time(0, format!("send {}", cnt).as_str(), Instant::now()) / 100.0); cnt += 1; } }, SenderTimerOutput::Close => { - sent_times.push(print_relative_time("CLOSE", Instant::now()) / 100.0); + sent_times.push( print_relative_time(0, "CLOSE", Instant::now()) / 100.0); break; } }; diff --git a/src/util/timer_logger.rs b/src/util/timer_logger.rs index a643559..d928273 100644 --- a/src/util/timer_logger.rs +++ b/src/util/timer_logger.rs @@ -4,9 +4,14 @@ use tokio::time::Instant; pub static PROGRAM_START_TIME: Lazy = Lazy::new(Instant::now); -pub fn print_relative_time(label: &str, instant: Instant) -> f64 { +pub fn print_relative_time(chunk_id: u32, label: &str, instant: Instant) -> f64 { let elapsed = instant.duration_since(*PROGRAM_START_TIME); let time_ms = elapsed.as_secs_f64() * 1000.0; - println!("[{:.6}ms] {}", time_ms.red(), label.blue()); + eprintln!( + "{} [{:.6}ms] {}", + chunk_id.magenta(), + time_ms.red(), + label.blue() + ); time_ms } From f6087361b751f91a6fe5061b6a9e1fcc99f583a4 Mon Sep 17 00:00:00 2001 From: Lethe Date: Tue, 12 Aug 2025 21:05:37 +0000 Subject: [PATCH 5/6] milestone: mininal runnable --- readme.md | 24 +++++++++++ src/bin/client.rs | 98 +++++++++++++++++++++++++++++++++++++++++---- src/bin/server.rs | 86 +++++++++++++++++++++++++++++++++++++++ src/protocol/mod.rs | 2 +- src/util/file.rs | 14 ++++++- src/util/plan.rs | 2 +- 6 files changed, 216 insertions(+), 10 deletions(-) create mode 100644 src/bin/server.rs diff --git a/readme.md b/readme.md index 9e798bd..a8e39b9 100644 --- a/readme.md +++ b/readme.md @@ -10,3 +10,27 @@ code is a pratical alternative method under bad network condition, thus no guara The project was developed, tested on and supports UNIX-like system only for now, especially with regard of file system operations. Yet future support for Windows is planned. + +## Example usage + +1. Generate plan file: +```bash +cargo run --bin planner -- --file ~/test.zip > test.plan +``` + +2. Generate key pairs +```bash +cargo test protocol::wire::verify::tests::test_exchange_public_key -- --no-capture +``` +Note: Signing key == Private key, Verifying key == Public key + + +3. Run Server +```bash +cargo run --release --bin server -- --plan-file test.plan --listening 0.0.0.0:7234 --public-key pub.key --folder ~ +``` + +4. Run Client +```bash +cargo run --release --bin client -- --plan-file plan.plan --server 127.0.0.1:7234 --private-key +``` diff --git a/src/bin/client.rs b/src/bin/client.rs index 9dfa358..90afec9 100644 --- a/src/bin/client.rs +++ b/src/bin/client.rs @@ -3,14 +3,20 @@ use clap::Parser; use directories::UserDirs; use humansize::{BINARY, format_size}; use owo_colors::OwoColorize; -use std::{fs, path::PathBuf}; -// use tokio::sync::Semaphore; -use zerocopy::IntoBytes; - +use std::str::FromStr; +use std::sync::{Arc, atomic::AtomicUsize}; +use std::{fs, net::SocketAddr, path::PathBuf}; +use tokio::sync::Semaphore; +use tokio::time::Duration; +use usync::constants::TRANSMISSION_INFO_LENGTH; +use usync::engine::{Bus, BusAddress, BusMessage, decoding, receiving}; +use usync::protocol::{coding::raptorq_code::RaptorqReceiver, init}; +use usync::transmission::real::RealUdpSocket; use usync::util::{ - file::{check_file_exist, mmap_segment}, + file::{check_file_exist_create, mmap_segment, write_at}, plan::{FileChunk, FileConfig}, }; +use zerocopy::IntoBytes; #[derive(Parser, Debug)] #[command(author, version, about = "Client for receiving file", long_about = None)] @@ -19,6 +25,14 @@ struct Args { #[arg(short, long, value_name = "PLAN_FILE")] plan_file: PathBuf, + /// Socket Addr of Server + #[arg(short, long, value_name = "SERVER")] + server: SocketAddr, + + /// Private Key + #[arg(short, long, value_name = "PRI_KEY")] + private_key: String, + /// The path to the downloading file (optional, in your download folder as default). #[arg(short, long, value_name = "DOWNLOADING_FILE")] downloading_file: Option, @@ -82,8 +96,16 @@ fn check_file<'a>( } #[tokio::main] async fn main() -> anyhow::Result<()> { + debug_assert!( + false, + "Run in release mode instead for raptorq is too slow in debug mode." + ); + let args = Args::parse(); + // Init key ring. + init(vec![], Some(args.private_key)); + let toml_str = fs::read_to_string(&args.plan_file)?; let config: FileConfig = toml::from_str(&toml_str)?; @@ -102,7 +124,7 @@ async fn main() -> anyhow::Result<()> { println!("Downloading file: {}", downloading_file.display()); - if check_file_exist(&downloading_file)? { + if check_file_exist_create(&downloading_file)? { println!("{} already exists.", downloading_file.display(),); } else { println!( @@ -111,7 +133,69 @@ async fn main() -> anyhow::Result<()> { ) } - let _need_to_download = check_file(&downloading_file, &config)?; + let bus: Arc>> = Arc::new(Bus::default()); + let socket = RealUdpSocket::bind(SocketAddr::from_str("127.0.0.1:0").unwrap()) + .await + .unwrap(); + let receiver = + receiving::ReceivingSocket::new(socket, bus.clone().register(BusAddress::ReceiverSocket)); + tokio::spawn(receiver.run(args.server)); + + let need_to_download = check_file(&downloading_file, &config)?; + + let semaphore = Arc::new(Semaphore::new(8)); + let finish = Arc::new(AtomicUsize::new(need_to_download.len())); + + for to_download in need_to_download { + let to_download = to_download.clone(); + let semaphore = semaphore.clone(); + let bus = bus.clone(); + let finish = finish.clone(); + let downloading_file = downloading_file.clone(); + + let chunk_id = to_download.chunk_id as u32; + + let waiting = |finish: Arc| async move { + let permit = semaphore.acquire().await.unwrap(); + let result = + decoding::spawn::(chunk_id, bus.clone()) + .await; + + drop(permit); + let Ok(Some(result)) = result else { + eprintln!( + "Downloaded chunk {} currupted.", + to_download.chunk_id.on_red(), + ); + finish.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + return; + }; + + let hash = hex::encode(blake3::hash(&result).as_bytes()); + if hash == to_download.hash && result.len() == to_download.length { + write_at(downloading_file, to_download.offset, &result).ok(); + eprintln!( + "Succeed in download chunk {}, at [{},{})", + to_download.chunk_id.green(), + to_download.offset.magenta(), + (to_download.offset + to_download.length as u64).magenta() + ) + } else { + eprintln!( + "Downloaded chunk {} currupted.", + to_download.chunk_id.on_red(), + ) + } + + finish.fetch_sub(1, std::sync::atomic::Ordering::Relaxed); + }; + tokio::spawn(waiting(finish)); + } + + while finish.load(std::sync::atomic::Ordering::Relaxed) > 0 { + tokio::time::sleep(Duration::from_secs(5)).await; + bus.debug(); + } Ok(()) } diff --git a/src/bin/server.rs b/src/bin/server.rs new file mode 100644 index 0000000..aab3f9c --- /dev/null +++ b/src/bin/server.rs @@ -0,0 +1,86 @@ +use clap::Parser; +use std::collections::HashMap; +use std::ffi::OsString; +use std::fs::File; +use std::io::BufRead; + +use std::sync::Arc; +use std::{fs, net::SocketAddr, path::PathBuf}; +use tokio::time::Duration; +use usync::constants::TRANSMISSION_INFO_LENGTH; +use usync::engine::{Bus, BusAddress, BusMessage, sending}; +use usync::protocol::{coding::raptorq_code::RaptorqSender, init}; +use usync::transmission::real::RealUdpSocket; +use usync::util::{ + file::{CHUNK_INDEX, ChunkIndex, check_file_exist}, + plan::FileConfig, +}; + +#[derive(Parser, Debug)] +#[command(author, version, about = "Server for sending file", long_about = None)] +struct Args { + /// The path to the plan file (TOML format). + #[arg(short, long, value_name = "PLAN_FILE")] + plan_file: PathBuf, + + /// Listening addr + #[arg(short, long, value_name = "LISTEN")] + listening: SocketAddr, + + /// The path to authorized public key, one per line. + #[arg(short, long, value_name = "PUB_KEY")] + public_key: PathBuf, + + /// The path to the folder that contains the file to be downloaded. + #[arg(short, long, value_name = "DOWNLOAD_FOLDER")] + folder: PathBuf, +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + debug_assert!( + false, + "Run in release mode instead for raptorq is too slow in debug mode." + ); + + let args = Args::parse(); + + let public_key_file = File::open(args.public_key).unwrap(); + let lines = std::io::BufReader::new(public_key_file) + .lines() + .collect::, _>>() + .unwrap(); + init(lines, None); + + let toml_str = fs::read_to_string(&args.plan_file)?; + let config: FileConfig = toml::from_str(&toml_str)?; + + let downloading_file = args.folder.join(config.file_name); + println!("Downloading file: {}", downloading_file.display()); + + check_file_exist(&downloading_file)?; + println!("{} already exists.", downloading_file.display()); + + CHUNK_INDEX + .set(ChunkIndex { + files: HashMap::from([(0usize, OsString::from(downloading_file))]), + chunks: HashMap::from_iter( + config + .chunks + .iter() + .map(|chunk| (chunk.chunk_id as u32, (0usize, chunk.offset, chunk.length))), + ), + }) + .map_err(|_| "Failed to init OnceLock") + .unwrap(); + + let bus: Arc>> = Arc::new(Bus::default()); + let socket = RealUdpSocket::bind(args.listening).await.unwrap(); + let sender = + sending::SendingSocket::new(socket, bus.clone().register(BusAddress::SenderSocket)); + tokio::spawn(sender.run::()); + loop { + tokio::time::sleep(Duration::from_secs(5)).await; + bus.debug(); + } +} diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 9d36bce..a172044 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -3,4 +3,4 @@ pub mod coding; mod key_ring; pub mod wire; -pub use key_ring::mock_init; +pub use key_ring::{init, mock_init}; diff --git a/src/util/file.rs b/src/util/file.rs index dc2fdfa..b40791e 100644 --- a/src/util/file.rs +++ b/src/util/file.rs @@ -39,7 +39,7 @@ pub fn sanity_check>(path: P) -> Result<(u64, String)> { Ok((length, file_name)) } -pub fn check_file_exist>(path: P) -> Result { +pub fn check_file_exist_create>(path: P) -> Result { let path = path.as_ref(); if path.exists() { if path.is_file() { @@ -52,6 +52,18 @@ pub fn check_file_exist>(path: P) -> Result { Ok(false) } +pub fn check_file_exist>(path: P) -> Result<()> { + let path = path.as_ref(); + if path.exists() { + if path.is_file() { + return Ok(()); + } else { + return Err(Error::other("The path to downloading file is not a file!")); + } + } + Err(Error::other("No such file or directory")) +} + pub fn mmap_segment>(path: P, offset: u64, length: usize) -> Result { let file = File::open(path)?; let metadata = file.metadata()?; diff --git a/src/util/plan.rs b/src/util/plan.rs index 21c4aec..4577094 100644 --- a/src/util/plan.rs +++ b/src/util/plan.rs @@ -2,7 +2,7 @@ use serde::{Deserialize, Serialize}; use crate::constants::{CHUNK_SIZE, DEFAULT_PAGE_SIZE}; -#[derive(Serialize, Deserialize, Debug)] +#[derive(Serialize, Deserialize, Debug, Clone)] pub struct FileChunk { pub chunk_id: usize, pub hash: String, From 667b7e5eee81100a664e2087ab4b1a70e9eae20e Mon Sep 17 00:00:00 2001 From: Lethe Date: Tue, 12 Aug 2025 21:09:00 +0000 Subject: [PATCH 6/6] chore: fix clippy --- src/engine/bus_flume.rs | 2 +- src/protocol/wire/frames.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/engine/bus_flume.rs b/src/engine/bus_flume.rs index a52d3f5..f3d0c5f 100644 --- a/src/engine/bus_flume.rs +++ b/src/engine/bus_flume.rs @@ -36,7 +36,7 @@ where let address = entry.key(); let sender = entry.value(); let len = sender.len(); - eprintln!("Address: {:?}, unread count: {}", address, len); + eprintln!("Address: {address:?}, unread count: {len}"); } } diff --git a/src/protocol/wire/frames.rs b/src/protocol/wire/frames.rs index b4e7d79..84388c6 100644 --- a/src/protocol/wire/frames.rs +++ b/src/protocol/wire/frames.rs @@ -67,7 +67,7 @@ fn preview_bytes(bytes: &Bytes) -> String { let preview: Vec = bytes .iter() .take(preview_len) - .map(|b| format!("{:02x}", b)) + .map(|b| format!("{b:02x}")) .collect(); format!( "[{} bytes: {}{}]",