Implement concurrent download for single files

- Added `download_feedback_multi`, which downloads a single file over
  multiple connections.
- The file is preallocated and zero-filled, then written to in parallel at
  different offsets (see the sketch below).
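In outline, the mechanism looks like this (a minimal self-contained sketch, assuming the `reqwest` and `tokio` crates this repo already uses; `fetch_ranged`, `BoxErr` and their parameters are illustrative names, not part of this codebase, and error/status handling is elided):

```rust
use std::io::SeekFrom;
use tokio::io::{AsyncSeekExt, AsyncWriteExt};

type BoxErr = Box<dyn std::error::Error + Send + Sync>;

async fn fetch_ranged(url: &str, path: &str, parts: u64) -> Result<(), BoxErr> {
    // Total size from a HEAD request (assumes the server sends Content-Length)
    let head = reqwest::Client::new().head(url).send().await?;
    let len: u64 = head
        .headers()
        .get(reqwest::header::CONTENT_LENGTH)
        .ok_or("no Content-Length")?
        .to_str()?
        .parse()?;
    // Preallocate so every task can write into its own region of the file
    tokio::fs::File::create(path).await?.set_len(len).await?;
    let chunk = len / parts;
    let mut tasks = Vec::new();
    for i in 0..parts {
        let (url, path) = (url.to_owned(), path.to_owned());
        // Ranges are inclusive; the last one runs to the end of the file
        let from = i * chunk;
        let to = if i == parts - 1 { len - 1 } else { (i + 1) * chunk - 1 };
        tasks.push(tokio::spawn(async move {
            // Fetch one byte range (buffered whole in memory for brevity;
            // real code would stream the body and report progress)
            let body = reqwest::Client::new()
                .get(&url)
                .header(reqwest::header::RANGE, format!("bytes={}-{}", from, to))
                .send()
                .await?
                .bytes()
                .await?;
            // Write the chunk at its own offset
            let mut f = tokio::fs::OpenOptions::new().write(true).open(&path).await?;
            f.seek(SeekFrom::Start(from)).await?;
            f.write_all(&body).await?;
            Ok::<(), BoxErr>(())
        }));
    }
    for t in tasks {
        t.await??;
    }
    Ok(())
}
```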
Daniel M 2021-03-25 21:35:58 +01:00
parent a8474aab1e
commit 9ca93cbeb2
3 changed files with 217 additions and 27 deletions

View File

@@ -1,7 +1,10 @@
 use std::path::Path;
-use tokio::io::AsyncWriteExt;
+use tokio::io::{ AsyncWriteExt, AsyncSeekExt };
 use std::time::SystemTime;
 use percent_encoding::percent_decode_str;
+use std::io::SeekFrom;
+use tokio::sync::mpsc;
+use futures::future::join_all;
 use crate::errors::*;
 use crate::dlreport::*;
@@ -83,35 +86,59 @@ pub async fn download(url: &str, into_file: &str) -> ResBE<()> {
     Ok(())
 }
 
 pub async fn download_feedback(url: &str, into_file: &str, rep: DlReporter) -> ResBE<()> {
+    download_feedback_chunks(url, into_file, rep, None, false).await
+}
+
+pub async fn download_feedback_chunks(url: &str, into_file: &str, rep: DlReporter, from_to: Option<(u64, u64)>, seek_from: bool) -> ResBE<()> {
     let into_file = Path::new(into_file);
+    let (mut content_length, range_supported) = http_get_filesize_and_range_support(url).await?;
+    if from_to != None && !range_supported {
+        return Err(DlError::Other("Server doesn't support range header".to_string()).into());
+    }
     // Send the HTTP request to download the given link
-    let mut resp = reqwest::Client::new()
-        .get(url)
-        .send().await?;
+    let mut req = reqwest::Client::new()
+        .get(url);
+    // Add the range header if needed
+    if let Some(from_to) = from_to {
+        req = req.header(reqwest::header::RANGE, format!("bytes={}-{}", from_to.0, from_to.1));
+        content_length = from_to.1 - from_to.0 + 1;
+    }
+    // Actually send the request and get the response
+    let mut resp = req.send().await?;
     // Error out if the server response is not a success (something went wrong)
     if !resp.status().is_success() {
         return Err(DlError::BadHttpStatus.into());
     }
-    // Get the content length for status update. If not present, error out cause
-    // without progress everything sucks anyways
-    let content_length = match resp.headers().get(reqwest::header::CONTENT_LENGTH) {
-        Some(cl) => cl.to_str()?.parse::<u64>()?,
-        None => return Err(DlError::ContentLengthUnknown.into())
-    };
     // Open the local output file
-    let mut ofile = tokio::fs::OpenOptions::new()
+    let mut ofile = tokio::fs::OpenOptions::new();
+    // Create the file if not existent
+    ofile.create(true)
         // Open in write mode
-        .write(true)
+        .write(true);
+    // If seek_from is specified, the file must not be truncated on open
+    if !seek_from {
         // Delete and overwrite the file
-        .truncate(true)
-        // Create the file if not existant
-        .create(true)
-        .open(into_file).await?;
+        ofile.truncate(true);
+    }
+    let mut ofile = ofile.open(into_file).await?;
+    if seek_from {
+        // Jump to this chunk's start offset before writing
+        ofile.seek(SeekFrom::Start(from_to.unwrap().0)).await?;
+    }
     let filename = into_file.file_name().unwrap().to_str().unwrap();
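One thing the arithmetic above relies on: HTTP `Range` bounds are inclusive at both ends, so `bytes=0-999` yields 1000 bytes, which is exactly what `from_to.1 - from_to.0 + 1` computes. A tiny check with illustrative values:

```rust
fn main() {
    // HTTP Range bounds are inclusive: bytes=from-to covers to - from + 1 bytes
    let (from, to): (u64, u64) = (0, 999);
    assert_eq!(to - from + 1, 1000);
    println!("bytes={}-{} -> {} bytes", from, to, to - from + 1);
}
```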
@@ -148,8 +175,8 @@ pub async fn download_feedback(url: &str, into_file: &str, rep: DlReporter) -> R
         // Update the number of bytes downloaded since the last report
         last_bytecount += datalen;
-        // Update the reported download speed after every 10MB
-        if last_bytecount > 10_000_000 {
+        // Update the reported download speed after every 5MB
+        if last_bytecount > 5_000_000 {
             let t_elapsed = t_last_speed.elapsed()?.as_millis();
             // Update rolling average
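The new `download_feedback_multi` in the next hunk splits `content_length` into `numparal` ranges of `chunksize` bytes and folds the division remainder into the last range. A worked example with illustrative numbers (a 10-byte file over 3 connections):

```rust
fn main() {
    let (content_length, numparal): (u64, u64) = (10, 3);
    let chunksize = content_length / numparal; // 3
    let rest = content_length % numparal;      // 1
    for index in 0..numparal {
        let mut from_to = (index * chunksize, (index + 1) * chunksize - 1);
        if index == numparal - 1 {
            from_to.1 += rest; // the last chunk absorbs the remainder
        }
        // Prints 0-2, 3-5, 6-9: inclusive ranges covering all 10 bytes
        println!("chunk {}: bytes={}-{}", index, from_to.0, from_to.1);
    }
}
```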
@@ -188,6 +215,143 @@ pub async fn download_feedback(url: &str, into_file: &str, rep: DlReporter) -> R
     Ok(())
 }
+
+// This will spin up multiple download tasks and manage the status updates for them.
+// The combined status will be reported back to the caller
+pub async fn download_feedback_multi(url: &str, into_file: &str, rep: DlReporter, numparal: i32) -> ResBE<()> {
+    let (content_length, _) = http_get_filesize_and_range_support(url).await?;
+    // Create a zeroed file with one extra byte. The byte is truncated away on
+    // download completion; its presence marks a file that is not suitable for continuation
+    create_zeroed_file(into_file, content_length as usize + 1).await?;
+    let chunksize = content_length / numparal as u64;
+    let rest = content_length % numparal as u64;
+    let mut joiners = Vec::new();
+    let (tx, mut rx) = mpsc::unbounded_channel::<DlReport>();
+    let t_start = SystemTime::now();
+    for index in 0..numparal {
+        let url = url.to_owned();
+        let into_file = into_file.to_owned();
+        let tx = tx.clone();
+        joiners.push(tokio::spawn(async move {
+            let rep = DlReporter::new(index, tx.clone());
+            let mut from_to = (
+                index as u64 * chunksize,
+                (index + 1) as u64 * chunksize - 1
+            );
+            // The last chunk also takes the division remainder
+            if index == numparal - 1 {
+                from_to.1 += rest;
+            }
+            download_feedback_chunks(&url, &into_file, rep, Some(from_to), true).await.unwrap();
+        }));
+    }
+    // Drop the spare sender so rx.recv() returns None once all tasks are done
+    drop(tx);
+    rep.send(DlStatus::Init {
+        bytes_total: content_length,
+        filename: into_file.to_string()
+    })?;
+    let mut update_counter = 0;
+    let mut dl_speeds = vec![0.0f64; numparal as usize];
+    let mut progresses = vec![0; numparal as usize];
+    while let Some(update) = rx.recv().await {
+        match update.status {
+            DlStatus::Init {
+                bytes_total: _,
+                filename: _
+            } => {},
+            DlStatus::Update {
+                speed_mbps,
+                bytes_curr
+            } => {
+                dl_speeds[update.id as usize] = speed_mbps;
+                progresses[update.id as usize] = bytes_curr;
+                // Only forward every tenth combined update to the caller
+                if update_counter == 10 {
+                    update_counter = 0;
+                    let speed = dl_speeds.iter().sum();
+                    let curr = progresses.iter().sum();
+                    rep.send(DlStatus::Update {
+                        speed_mbps: speed,
+                        bytes_curr: curr
+                    })?;
+                } else {
+                    update_counter += 1;
+                }
+            },
+            DlStatus::Done {
+                duration_ms: _
+            } => {
+                dl_speeds[update.id as usize] = 0.0;
+            }
+        }
+    }
+    join_all(joiners).await;
+    // Remove the additional byte at the file end
+    let ofile = tokio::fs::OpenOptions::new()
+        .create(false)
+        .write(true)
+        .truncate(false)
+        .open(&into_file)
+        .await?;
+    ofile.set_len(content_length).await?;
+    rep.send(DlStatus::Done {
+        duration_ms: t_start.elapsed()?.as_millis() as u64
+    })?;
+    Ok(())
+}
+
+async fn create_zeroed_file(file: &str, filesize: usize) -> ResBE<()> {
+    let ofile = tokio::fs::OpenOptions::new()
+        .create(true)
+        // Open in write mode
+        .write(true)
+        // Delete and overwrite the file
+        .truncate(true)
+        .open(file)
+        .await?;
+    ofile.set_len(filesize as u64).await?;
+    Ok(())
+}
+
+pub async fn http_get_filesize_and_range_support(url: &str) -> ResBE<(u64, bool)> {
+    let resp = reqwest::Client::new()
+        .head(url)
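One detail of the fan-in above that is easy to miss: `rx.recv()` only returns `None` once every sender handle has been dropped, which is why the spare `tx` is dropped right after the spawn loop. A stripped-down sketch of that pattern (assuming tokio with the macros feature; the worker logic and message type are illustrative):

```rust
use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::unbounded_channel::<(usize, u64)>();
    for id in 0..4usize {
        let tx = tx.clone();
        tokio::spawn(async move {
            // Each worker reports (id, progress); its sender handle is
            // dropped when the task finishes
            for step in 1..=3u64 {
                tx.send((id, step)).unwrap();
            }
        });
    }
    // Without this, recv() would never return None and the aggregation
    // loop below would hang even after all workers finished
    drop(tx);
    let mut progress = [0u64; 4];
    while let Some((id, step)) = rx.recv().await {
        progress[id] = step;
        println!("total: {}", progress.iter().sum::<u64>());
    }
}
```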

View File

@@ -8,7 +8,7 @@ pub type ResBE<T> = Result<T, Box<dyn Error>>;
 pub enum DlError {
     BadHttpStatus,
     ContentLengthUnknown,
-    Other
+    Other(String)
 }
 
 impl Error for DlError {}
@@ -19,7 +19,7 @@ impl Display for DlError {
         match self {
             DlError::BadHttpStatus => write!(f, "Bad http response status"),
             DlError::ContentLengthUnknown => write!(f, "Content-Length is unknown"),
-            DlError::Other => write!(f, "Unknown download error")
+            DlError::Other(s) => write!(f, "Unknown download error: '{}'", s)
         }
     }
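With the `String` payload, call sites can attach detail when no dedicated variant fits, as the download code now does for missing range support. A minimal usage sketch (reusing `ResBE` and `DlError` from this file; `check_range_support` is an illustrative helper, not code from this repo):

```rust
fn check_range_support(range_supported: bool) -> ResBE<()> {
    if !range_supported {
        // The String payload carries the human-readable detail
        return Err(DlError::Other("Server doesn't support range header".to_string()).into());
    }
    Ok(())
}
```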

View File

@@ -101,6 +101,7 @@ async fn main() -> ResBE<()> {
     let is_zippy = arguments.is_present("zippyshare");
     if arguments.is_present("listfile") {
         let listfile = arguments.value_of("listfile").unwrap();
@ -114,6 +115,7 @@ async fn main() -> ResBE<()> {
.collect();
if is_zippy {
println!("Pre-resolving zippyshare URLs");
let mut zippy_urls = Vec::new();
for url in urls {
zippy_urls.push(
@@ -149,7 +151,7 @@
                 url.to_string()
             };
-            download_one(&url, outdir).await?;
+            download_one(&url, outdir, numparal).await?;
         } else if arguments.is_present("resolve") {
@@ -175,7 +177,7 @@
 }
 
-async fn download_one(url: &str, outdir: &str) -> ResBE<()> {
+async fn download_one(url: &str, outdir: &str, numparal: i32) -> ResBE<()> {
     let outdir = Path::new(outdir);
     if !outdir.exists() {
@@ -210,9 +212,16 @@ async fn download_one(url: &str, outdir: &str) -> ResBE<()> {
         // Create reporter with id 0 since there is only one anyways
         let rep = DlReporter::new(0, tx);
-        if let Err(e) = download::download_feedback(&url, &into_file, rep).await {
-            eprintln!("Error while downloading");
-            eprintln!("{}", e);
+        if numparal == 1 {
+            if let Err(e) = download::download_feedback(&url, &into_file, rep).await {
+                eprintln!("Error while downloading");
+                eprintln!("{}", e);
+            }
+        } else {
+            if let Err(e) = download::download_feedback_multi(&url, &into_file, rep, numparal).await {
+                eprintln!("Error while downloading");
+                eprintln!("{}", e);
+            }
         }
     });
@@ -357,7 +366,7 @@ async fn download_multiple(urls: Vec<String>, outdir: &str, numparal: i32) -> Re
                 s.3 = speed_mbps;
             }
-            if t_last.elapsed().unwrap().as_millis() > 2000 {
+            if t_last.elapsed().unwrap().as_millis() > 500 {
                 let mut dl_speed_sum = 0.0;
@@ -368,13 +377,29 @@ async fn download_multiple(urls: Vec<String>, outdir: &str, numparal: i32) -> Re
                 let speed_mbps = v.3;
                 let percent_complete = bytes_curr as f64 / filesize as f64 * 100.0;
+                crossterm::execute!(
+                    std::io::stdout(),
+                    crossterm::terminal::Clear(crossterm::terminal::ClearType::CurrentLine)
+                );
                 println!("Status: {:6.2} mb/s {:5.2}% completed '{}'", speed_mbps, percent_complete, filename);
                 dl_speed_sum += speed_mbps;
             }
+            crossterm::execute!(
+                std::io::stdout(),
+                crossterm::terminal::Clear(crossterm::terminal::ClearType::CurrentLine)
+            );
             println!("Accumulated download speed: {:6.2} mb/s\n", dl_speed_sum);
+            crossterm::execute!(
+                std::io::stdout(),
+                crossterm::cursor::MoveUp(statuses.len() as u16 + 2)
+            );
             t_last = SystemTime::now();
@@ -397,6 +422,7 @@ async fn download_multiple(urls: Vec<String>, outdir: &str, numparal: i32) -> Re
         }
     }
     join_all(joiners).await;
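The `Clear`/`MoveUp` calls above implement an in-place redraw: clear and reprint each status line plus the summary, then move the cursor back up so the next tick overwrites the same lines. A compact sketch of that pattern (assuming a crossterm version where `execute!` returns `crossterm::Result`, as in the era of this commit; `redraw` and its signature are illustrative):

```rust
// Illustrative helper, not code from this repo
fn redraw(statuses: &[(String, f64)]) -> crossterm::Result<()> {
    let mut out = std::io::stdout();
    for (name, speed) in statuses {
        // Clear the stale line before reprinting it
        crossterm::execute!(
            out,
            crossterm::terminal::Clear(crossterm::terminal::ClearType::CurrentLine)
        )?;
        println!("{:6.2} mb/s '{}'", speed, name);
    }
    crossterm::execute!(
        out,
        crossterm::terminal::Clear(crossterm::terminal::ClearType::CurrentLine)
    )?;
    println!("sum: {:6.2} mb/s", statuses.iter().map(|s| s.1).sum::<f64>());
    // Jump back up so the next call overwrites the same lines in place
    crossterm::execute!(out, crossterm::cursor::MoveUp(statuses.len() as u16 + 1))?;
    Ok(())
}
```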