use futures::stream::FuturesUnordered;
use futures::StreamExt;
use percent_encoding::percent_decode_str;
use std::io::SeekFrom;
use std::path::Path;
use std::time::SystemTime;
use tokio::io::{AsyncSeekExt, AsyncWriteExt};
use tokio::sync::mpsc;

use crate::dlreport::*;
use crate::errors::*;

struct RollingAverage {
    index: usize,
    data: Vec<f64>,
}

impl RollingAverage {
    fn new(size: usize) -> Self {
        RollingAverage {
            index: 0,
            data: Vec::with_capacity(size),
        }
    }

    fn value(&self) -> f64 {
        if self.data.is_empty() {
            0.0
        } else {
            let mut max = self.data[0];
            for v in self.data.iter() {
                if *v > max {
                    max = *v;
                }
            }
            let mut sum: f64 = self.data.iter().sum();
            let mut count = self.data.len();
            // Drop the single highest sample to dampen spikes, but only once
            // more than three samples are available (see the unit tests below)
            if self.data.len() > 3 {
                sum -= max;
                count -= 1;
            }
            sum / count as f64
        }
    }

    fn add(&mut self, val: f64) {
        if self.data.capacity() == self.data.len() {
            // The window is full: overwrite the oldest sample, ring-buffer style
            self.data[self.index] = val;
            self.index += 1;
            if self.index >= self.data.capacity() {
                self.index = 0;
            }
        } else {
            self.data.push(val);
        }
    }
}

/// Get the filename at the end of the given URL. This will decode the URL encoding.
pub fn url_to_filename(url: &str) -> String {
    let url_dec = percent_decode_str(url).decode_utf8_lossy().to_string();
    let file_name = Path::new(&url_dec).file_name().unwrap().to_str().unwrap();
    // Split at '?' and return the first part. If no '?' is present, this just
    // returns the full string
    file_name.split('?').next().unwrap().to_string()
}

pub async fn download_feedback(
    url: &str,
    into_file: &Path,
    rep: DlReporter,
    content_length: Option<u64>,
) -> ResBE<()> {
    download_feedback_chunks(url, into_file, rep, None, content_length).await
}

pub async fn download_feedback_chunks(
    url: &str,
    into_file: &Path,
    rep: DlReporter,
    from_to: Option<(u64, u64)>,
    content_length: Option<u64>,
) -> ResBE<()> {
    let mut content_length = match content_length {
        Some(it) => it,
        None => {
            let (content_length, _) = http_get_filesize_and_range_support(url).await?;
            content_length
        }
    };

    // Send the HTTP request to download the given link
    let mut req = reqwest::Client::new().get(url);

    // Add a range header if only a part of the file was requested
    if let Some((from, to)) = from_to {
        req = req.header(reqwest::header::RANGE, format!("bytes={}-{}", from, to));
        content_length = to - from + 1;
    }

    // Actually send the request and get the response
    let mut resp = req.send().await?;

    // Error out if the server response is not a success (something went wrong)
    if !resp.status().is_success() {
        return Err(DlError::BadHttpStatus.into());
    }

    // Open the local output file. Only truncate when downloading the whole
    // file; a ranged download writes into an existing file
    let mut opts = tokio::fs::OpenOptions::new();
    let mut ofile = opts
        .create(true)
        .write(true)
        .truncate(from_to.is_none())
        .open(into_file)
        .await?;

    if let Some((from, _)) = from_to {
        ofile.seek(SeekFrom::Start(from)).await?;
    }

    let filename = into_file.file_name().unwrap().to_str().unwrap();

    // Report the download start
    rep.send(DlStatus::Init {
        bytes_total: content_length,
        filename: filename.to_string(),
    });

    let mut curr_progress = 0;
    let mut speed_mbps = 0.0;
    let t_start = SystemTime::now();
    let mut t_last_speed = SystemTime::now();
    let mut last_bytecount = 0;
    let mut average_speed = RollingAverage::new(10);
    let mut buff: Vec<u8> = Vec::new();

    // Read data from the server as long as new data is available
    while let Some(chunk) = resp.chunk().await? {
        let datalen = chunk.len() as u64;
        buff.extend(chunk);

        // Buffer in memory first and only write to disk once the threshold is
        // reached. This reduces the number of small disk writes and thereby
        // reduces the IO bottleneck that occurs on HDDs with many small writes
        // to different files and offsets at the same time
        if buff.len() >= 1_000_000 {
            // Write the received data into the file
            ofile.write_all(&buff).await?;
            buff.clear();
        }

        // Update progress
        curr_progress += datalen;
        // Update the number of bytes downloaded since the last report
        last_bytecount += datalen;

        let t_elapsed = t_last_speed.elapsed()?.as_secs_f64();

        // Update the reported download speed after every 3 MB or every
        // 0.8 seconds, whichever happens first
        if last_bytecount >= 3_000_000 || t_elapsed >= 0.8 {
            // Fold the new sample (in MB/s) into the rolling average
            average_speed.add(((last_bytecount as f64) / t_elapsed) / 1_000_000.0);
            speed_mbps = average_speed.value() as f32;
            // Reset the time and bytecount
            last_bytecount = 0;
            t_last_speed = SystemTime::now();
        }

        // Send a status update report
        rep.send(DlStatus::Update {
            speed_mbps,
            bytes_curr: curr_progress,
        });
    }

    // Write out whatever is left in the buffer
    if !buff.is_empty() {
        ofile.write_all(&buff).await?;
    }

    if curr_progress != content_length {
        return Err(DlError::HttpNoData.into());
    }

    // Ensure that IO is completed
    ofile.flush().await?;

    let duration_ms = t_start.elapsed()?.as_millis() as u64;

    // Report that the download is finished
    rep.send(DlStatus::Done { duration_ms });
    Ok(())
}
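// A minimal single-connection usage sketch (illustrative only, not part of the
// original file). It assumes `DlUpdate` as the reporter message type, as in
// `download_feedback_multi` below, and elides error handling:
//
//     let (tx, mut rx) = mpsc::unbounded_channel::<DlUpdate>();
//     // Drain status reports in a separate task so the download never blocks
//     tokio::spawn(async move { while rx.recv().await.is_some() {} });
//     let rep = DlReporter::new(0, tx);
//     download_feedback("https://example.com/file.bin", Path::new("file.bin"), rep, None).await?;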
// This will spin up multiple download tasks and manage the status updates for
// them. The combined status will be reported back to the caller
pub async fn download_feedback_multi(
    url: &str,
    into_file: &Path,
    rep: DlReporter,
    conn_count: u32,
    content_length: Option<u64>,
) -> ResBE<()> {
    let content_length = match content_length {
        Some(it) => it,
        None => http_get_filesize_and_range_support(url).await?.0,
    };

    // Create a zeroed file with one byte too much. This byte is truncated on
    // download completion, so its presence indicates that the file is not
    // suitable for continuation
    create_zeroed_file(into_file, content_length as usize + 1).await?;

    let chunksize = content_length / conn_count as u64;
    let rest = content_length % conn_count as u64;

    let mut joiners = Vec::new();
    // The channel's message type was lost in the source; `DlUpdate` (a task id
    // plus a `DlStatus`, matching the `update.id` / `update.status` accesses
    // below) is assumed here
    let (tx, mut rx) = mpsc::unbounded_channel::<DlUpdate>();
    let t_start = SystemTime::now();

    for index in 0..conn_count {
        let url = url.to_owned();
        let into_file = into_file.to_owned();
        let tx = tx.clone();
        joiners.push(tokio::spawn(async move {
            let rep = DlReporter::new(index, tx.clone());
            let mut from_to = (index as u64 * chunksize, (index + 1) as u64 * chunksize - 1);
            // The last chunk also covers the division remainder
            if index == conn_count - 1 {
                from_to.1 += rest;
            }
            let specific_content_length = from_to.1 - from_to.0 + 1;

            // Delay each chunk download to reduce the number of simultaneous
            // connection attempts
            tokio::time::sleep(tokio::time::Duration::from_millis(50 * index as u64)).await;

            download_feedback_chunks(
                &url,
                &into_file,
                rep,
                Some(from_to),
                Some(specific_content_length),
            )
            .await
            .map_err(|e| e.to_string())
        }));
    }
    drop(tx);

    let filename = into_file.file_name().unwrap().to_str().unwrap();
    rep.send(DlStatus::Init {
        bytes_total: content_length,
        filename: filename.to_string(),
    });

    let rep_task = rep.clone();
    let mut t_last = t_start;
    let manager_handle = tokio::task::spawn(async move {
        let rep = rep_task;
        let mut progresses = vec![0u64; conn_count as usize];
        let mut progress_last: u64 = 0;
        let mut average_speed = RollingAverage::new(10);

        while let Some(update) = rx.recv().await {
            match update.status {
                // The combined init report was already sent above
                DlStatus::Init { .. } => {}
                DlStatus::Update {
                    speed_mbps: _,
                    bytes_curr,
                } => {
                    progresses[update.id as usize] = bytes_curr;
                    let progress_curr: u64 = progresses.iter().sum();
                    let progress_delta = progress_curr - progress_last;
                    let t_elapsed = t_last.elapsed().unwrap().as_secs_f64();
                    let speed_mbps = average_speed.value() as f32;

                    // Only fold a new sample into the rolling average after
                    // another 5 MB of combined progress
                    if progress_delta >= 5_000_000 {
                        average_speed.add(((progress_delta as f64) / 1_000_000.0) / t_elapsed);
                        progress_last = progress_curr;
                        t_last = SystemTime::now();
                    }

                    rep.send(DlStatus::Update {
                        speed_mbps,
                        bytes_curr: progress_curr,
                    });
                }
                // The combined done report is sent after all tasks have joined
                DlStatus::Done { duration_ms: _ } => {}
                // Just forward everything else to the calling receiver
                _ => rep.send(update.status),
            }
        }
    });

    let mut joiners: FuturesUnordered<_> = joiners.into_iter().collect();

    // Validate that the tasks were successful. This will always grab the next
    // completed task, independent of the original order in the joiners list
    while let Some(output) = joiners.next().await {
        // If any of the download tasks fails, abort the rest and delete the
        // file, since it is non-recoverable anyway
        if let Err(e) = output? {
            for handle in joiners.iter() {
                handle.abort();
            }
            manager_handle.abort();
            tokio::fs::remove_file(into_file).await?;
            return Err(e.into());
        }
    }
    manager_handle.await?;

    // Remove the additional byte at the end of the file
    let ofile = tokio::fs::OpenOptions::new()
        .create(false)
        .write(true)
        .truncate(false)
        .open(into_file)
        .await?;
    ofile.set_len(content_length).await?;

    rep.send(DlStatus::Done {
        duration_ms: t_start.elapsed()?.as_millis() as u64,
    });
    Ok(())
}
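// A sketch of how the pieces above compose (illustrative only, not part of the
// original file): probe the server once, then take the multi-connection path
// only when byte ranges are actually supported. `url`, `dest` and `rep` are
// hypothetical locals:
//
//     let (size, ranges_ok) = http_get_filesize_and_range_support(url).await?;
//     if ranges_ok {
//         download_feedback_multi(url, dest, rep, 4, Some(size)).await?;
//     } else {
//         download_feedback(url, dest, rep, Some(size)).await?;
//     }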
async fn create_zeroed_file(file: &Path, filesize: usize) -> ResBE<()> {
    let ofile = tokio::fs::OpenOptions::new()
        // Create the file if it does not exist yet
        .create(true)
        // Open in write mode
        .write(true)
        // Delete and overwrite any previous content
        .truncate(true)
        .open(file)
        .await?;
    ofile.set_len(filesize as u64).await?;
    Ok(())
}

pub async fn http_get_filesize_and_range_support(url: &str) -> ResBE<(u64, bool)> {
    let resp = reqwest::Client::new().head(url).send().await?;
    if let Some(filesize) = resp.headers().get(reqwest::header::CONTENT_LENGTH) {
        if let Ok(val_str) = filesize.to_str() {
            if let Ok(val) = val_str.parse::<u64>() {
                // The server advertises range support via "Accept-Ranges: bytes"
                let mut range_supported = false;
                if let Some(range) = resp.headers().get(reqwest::header::ACCEPT_RANGES) {
                    if let Ok(range) = range.to_str() {
                        if range == "bytes" {
                            range_supported = true;
                        }
                    }
                }
                return Ok((val, range_supported));
            }
        }
    }
    Err(DlError::ContentLengthUnknown.into())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rolling_average() {
        let mut ra = RollingAverage::new(3);
        assert_eq!(0, ra.data.len());
        assert_eq!(3, ra.data.capacity());
        assert_eq!(0.0, ra.value());

        // 10 / 1 = 10
        ra.add(10.0);
        assert_eq!(1, ra.data.len());
        assert_eq!(10.0, ra.value());

        // (10 + 20) / 2 = 15
        ra.add(20.0);
        assert_eq!(2, ra.data.len());
        assert_eq!(15.0, ra.value());

        // (10 + 20 + 30) / 3 = 20
        ra.add(30.0);
        assert_eq!(3, ra.data.len());
        assert_eq!(20.0, ra.value());
        assert_eq!(10.0, ra.data[0]);
        assert_eq!(20.0, ra.data[1]);
        assert_eq!(30.0, ra.data[2]);

        // This should replace the oldest value (index 0)
        ra.add(40.0);
        assert_eq!(3, ra.data.len());
        assert_eq!(3, ra.data.capacity());
        // (40 + 20 + 30) / 3 = 30
        assert_eq!(30.0, ra.value());
        assert_eq!(40.0, ra.data[0]);
        assert_eq!(20.0, ra.data[1]);
        assert_eq!(30.0, ra.data[2]);

        ra.add(50.0);
        ra.add(60.0);
        ra.add(70.0);
        assert_eq!(70.0, ra.data[0]);
        assert_eq!(50.0, ra.data[1]);
        assert_eq!(60.0, ra.data[2]);
    }
}
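// A hedged example of the expected `url_to_filename` behaviour, derived from
// the implementation above: percent-decoding happens first, then the query
// string is stripped at the first '?'
#[cfg(test)]
mod url_tests {
    use super::*;

    #[test]
    fn url_to_filename_decodes_and_strips_query() {
        assert_eq!(
            "my file.zip",
            url_to_filename("https://example.com/dl/my%20file.zip?token=abc")
        );
        assert_eq!("data.bin", url_to_filename("https://example.com/data.bin"));
    }
}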