diff --git a/src/download.rs b/src/download.rs index d9766e9..32a82a1 100644 --- a/src/download.rs +++ b/src/download.rs @@ -1,7 +1,10 @@ use std::path::Path; -use tokio::io::AsyncWriteExt; +use tokio::io::{ AsyncWriteExt, AsyncSeekExt }; use std::time::SystemTime; use percent_encoding::percent_decode_str; +use std::io::SeekFrom; +use tokio::sync::mpsc; +use futures::future::join_all; use crate::errors::*; use crate::dlreport::*; @@ -83,35 +86,59 @@ pub async fn download(url: &str, into_file: &str) -> ResBE<()> { Ok(()) } + pub async fn download_feedback(url: &str, into_file: &str, rep: DlReporter) -> ResBE<()> { + + download_feedback_chunks(url, into_file, rep, None, false).await + +} + +pub async fn download_feedback_chunks(url: &str, into_file: &str, rep: DlReporter, from_to: Option<(u64, u64)>, seek_from: bool) -> ResBE<()> { let into_file = Path::new(into_file); + let (mut content_length, range_supported) = http_get_filesize_and_range_support(url).await?; + + if from_to != None && !range_supported{ + return Err(DlError::Other("Server doesn't support range header".to_string()).into()); + } + // Send the HTTP request to download the given link - let mut resp = reqwest::Client::new() - .get(url) - .send().await?; + let mut req = reqwest::Client::new() + .get(url); + + // Add range header if needed + if let Some(from_to) = from_to { + req = req.header(reqwest::header::RANGE, format!("bytes={}-{}", from_to.0, from_to.1)); + content_length = from_to.1 - from_to.0 + 1; + } + + // Actually send the request and get the response + let mut resp = req.send().await?; // Error out if the server response is not success (something went wrong) if !resp.status().is_success() { return Err(DlError::BadHttpStatus.into()); } - - // Get the content length for status update. If not present, error out cause - // without progress everything sucks anyways - let content_length = match resp.headers().get(reqwest::header::CONTENT_LENGTH) { - Some(cl) => cl.to_str()?.parse::()?, - None => return Err(DlError::ContentLengthUnknown.into()) - }; // Open the local output file - let mut ofile = tokio::fs::OpenOptions::new() + let mut ofile = tokio::fs::OpenOptions::new(); + + // Create the file if not existant + ofile.create(true) // Open in write mode - .write(true) + .write(true); + + // If seek_from is specified, the file cant be overwritten + if !seek_from { // Delete and overwrite the file - .truncate(true) - // Create the file if not existant - .create(true) - .open(into_file).await?; + ofile.truncate(true); + } + + let mut ofile = ofile.open(into_file).await?; + + if seek_from { + ofile.seek(SeekFrom::Start(from_to.unwrap().0)).await?; + } let filename = into_file.file_name().unwrap().to_str().unwrap(); @@ -148,8 +175,8 @@ pub async fn download_feedback(url: &str, into_file: &str, rep: DlReporter) -> R // Update the number of bytes downloaded since the last report last_bytecount += datalen; - // Update the reported download speed after every 10MB - if last_bytecount > 10_000_000 { + // Update the reported download speed after every 5MB + if last_bytecount > 5_000_000 { let t_elapsed = t_last_speed.elapsed()?.as_millis(); // Update rolling average @@ -188,6 +215,143 @@ pub async fn download_feedback(url: &str, into_file: &str, rep: DlReporter) -> R Ok(()) } +// This will spin up multiple tasks that and manage the status updates for them. +// The combined status will be reported back to the caller +pub async fn download_feedback_multi(url: &str, into_file: &str, rep: DlReporter, numparal: i32) -> ResBE<()> { + + let (content_length, _) = http_get_filesize_and_range_support(url).await?; + + // Create zeroed file with 1 byte too much. This will be truncated on download + // completion and can indicate that the file is not suitable for continuation + create_zeroed_file(into_file, content_length as usize + 1).await?; + + let chunksize = content_length / numparal as u64; + let rest = content_length % numparal as u64; + + let mut joiners = Vec::new(); + + let (tx, mut rx) = mpsc::unbounded_channel::(); + + let t_start = SystemTime::now(); + + for index in 0 .. numparal { + + let url = url.clone().to_owned(); + let into_file = into_file.clone().to_owned(); + + let tx = tx.clone(); + + joiners.push(tokio::spawn(async move { + + let rep = DlReporter::new(index, tx.clone()); + + let mut from_to = ( + index as u64 * chunksize, + (index+1) as u64 * chunksize - 1 + ); + + if index == numparal - 1 { + from_to.1 += rest; + } + + download_feedback_chunks(&url, &into_file, rep, Some(from_to), true).await.unwrap(); + + })) + } + + drop(tx); + + rep.send(DlStatus::Init { + bytes_total: content_length, + filename: into_file.to_string() + })?; + + let mut update_counter = 0; + let mut dl_speeds = vec![0.0f64; numparal as usize]; + let mut progresses = vec![0; numparal as usize]; + + while let Some(update) = rx.recv().await { + match update.status { + + DlStatus::Init { + bytes_total: _, + filename: _ + } => { + + }, + DlStatus::Update { + speed_mbps, + bytes_curr + } => { + + dl_speeds[update.id as usize] = speed_mbps; + progresses[update.id as usize] = bytes_curr; + + if update_counter == 10 { + update_counter = 0; + + let speed = dl_speeds.iter().sum(); + let curr = progresses.iter().sum(); + + rep.send(DlStatus::Update { + speed_mbps: speed, + bytes_curr: curr + })?; + + } else { + update_counter += 1; + } + + }, + DlStatus::Done { + duration_ms: _ + } => { + + dl_speeds[update.id as usize] = 0.0; + + } + + } + } + + join_all(joiners).await; + + + // Remove the additional byte at the file end + let ofile = tokio::fs::OpenOptions::new() + .create(false) + .write(true) + .truncate(false) + .open(&into_file) + .await?; + + ofile.set_len(content_length).await?; + + + + rep.send(DlStatus::Done { + duration_ms: t_start.elapsed()?.as_millis() as u64 + })?; + + Ok(()) +} + +async fn create_zeroed_file(file: &str, filesize: usize) -> ResBE<()> { + + let ofile = tokio::fs::OpenOptions::new() + .create(true) + // Open in write mode + .write(true) + // Delete and overwrite the file + .truncate(true) + .open(file) + .await?; + + ofile.set_len(filesize as u64).await?; + + Ok(()) +} + pub async fn http_get_filesize_and_range_support(url: &str) -> ResBE<(u64, bool)> { let resp = reqwest::Client::new() .head(url) diff --git a/src/errors.rs b/src/errors.rs index ae06f1b..781e70a 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -8,7 +8,7 @@ pub type ResBE = Result>; pub enum DlError { BadHttpStatus, ContentLengthUnknown, - Other + Other(String) } impl Error for DlError {} @@ -19,7 +19,7 @@ impl Display for DlError { match self { DlError::BadHttpStatus => write!(f, "Bad http response status"), DlError::ContentLengthUnknown => write!(f, "Content-Length is unknown"), - DlError::Other => write!(f, "Unknown download error") + DlError::Other(s) => write!(f, "Unknown download error: '{}'", s) } } diff --git a/src/main.rs b/src/main.rs index 33c2567..60ada6b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -101,6 +101,7 @@ async fn main() -> ResBE<()> { let is_zippy = arguments.is_present("zippyshare"); + if arguments.is_present("listfile") { let listfile = arguments.value_of("listfile").unwrap(); @@ -114,6 +115,7 @@ async fn main() -> ResBE<()> { .collect(); if is_zippy { + println!("Pre-resolving zippyshare URLs"); let mut zippy_urls = Vec::new(); for url in urls { zippy_urls.push( @@ -149,7 +151,7 @@ async fn main() -> ResBE<()> { url.to_string() }; - download_one(&url, outdir).await?; + download_one(&url, outdir, numparal).await?; } else if arguments.is_present("resolve") { @@ -175,7 +177,7 @@ async fn main() -> ResBE<()> { } -async fn download_one(url: &str, outdir: &str) -> ResBE<()> { +async fn download_one(url: &str, outdir: &str, numparal: i32) -> ResBE<()> { let outdir = Path::new(outdir); if !outdir.exists() { @@ -210,9 +212,16 @@ async fn download_one(url: &str, outdir: &str) -> ResBE<()> { // Create reporter with id 0 since there is only one anyways let rep = DlReporter::new(0, tx); - if let Err(e) = download::download_feedback(&url, &into_file, rep).await { - eprintln!("Error while downloading"); - eprintln!("{}", e); + if numparal == 1 { + if let Err(e) = download::download_feedback(&url, &into_file, rep).await { + eprintln!("Error while downloading"); + eprintln!("{}", e); + } + } else { + if let Err(e) = download::download_feedback_multi(&url, &into_file, rep, numparal).await { + eprintln!("Error while downloading"); + eprintln!("{}", e); + } } }); @@ -357,7 +366,7 @@ async fn download_multiple(urls: Vec, outdir: &str, numparal: i32) -> Re s.3 = speed_mbps; } - if t_last.elapsed().unwrap().as_millis() > 2000 { + if t_last.elapsed().unwrap().as_millis() > 500 { let mut dl_speed_sum = 0.0; @@ -368,13 +377,29 @@ async fn download_multiple(urls: Vec, outdir: &str, numparal: i32) -> Re let speed_mbps = v.3; let percent_complete = bytes_curr as f64 / filesize as f64 * 100.0; + + + crossterm::execute!( + std::io::stdout(), + crossterm::terminal::Clear(crossterm::terminal::ClearType::CurrentLine) + ); + println!("Status: {:6.2} mb/s {:5.2}% completed '{}'", speed_mbps, percent_complete, filename); dl_speed_sum += speed_mbps; } + crossterm::execute!( + std::io::stdout(), + crossterm::terminal::Clear(crossterm::terminal::ClearType::CurrentLine) + ); println!("Accumulated download speed: {:6.2} mb/s\n", dl_speed_sum); + crossterm::execute!( + std::io::stdout(), + crossterm::cursor::MoveUp(statuses.len() as u16 + 2) + ); + t_last = SystemTime::now(); } @@ -396,6 +421,7 @@ async fn download_multiple(urls: Vec, outdir: &str, numparal: i32) -> Re } } + join_all(joiners).await;