// ffdl/src/download.rs
use futures::stream::FuturesUnordered;
use futures::StreamExt;
use percent_encoding::percent_decode_str;
use std::io::SeekFrom;
use std::path::Path;
use std::time::SystemTime;
use tokio::io::{AsyncSeekExt, AsyncWriteExt};
use tokio::sync::mpsc;
use crate::dlreport::*;
use crate::errors::*;
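
/// Rolling average over a fixed number of samples, kept in a ring buffer.
/// Once more than three samples are present, the single largest sample is
/// excluded from the average to dampen one-off speed spikes.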
struct RollingAverage {
    index: usize,
    // The window size is stored explicitly: `Vec::with_capacity` only
    // guarantees *at least* the requested capacity, so `data.capacity()`
    // is not a reliable window size.
    size: usize,
    data: Vec<f64>,
}

impl RollingAverage {
    fn new(size: usize) -> Self {
        RollingAverage {
            index: 0,
            size,
            data: Vec::with_capacity(size),
        }
    }

    fn value(&self) -> f64 {
        if self.data.is_empty() {
            return 0.0;
        }
        let max = self.data.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let mut sum: f64 = self.data.iter().sum();
        let mut count = self.data.len();
        // With more than three samples, drop the single largest one as an
        // outlier. Using `>` rather than `>=` keeps small windows, as
        // exercised by the unit tests below, at a plain arithmetic mean.
        if self.data.len() > 3 {
            sum -= max;
            count -= 1;
        }
        sum / count as f64
    }

    fn add(&mut self, val: f64) {
        if self.data.len() == self.size {
            // Window is full: overwrite the oldest sample and advance the
            // write index, wrapping around at the window size.
            self.data[self.index] = val;
            self.index += 1;
            if self.index >= self.size {
                self.index = 0;
            }
        } else {
            self.data.push(val);
        }
    }
}

/// Get the filename at the end of the given URL, decoding any
/// percent-encoding and stripping a trailing query string.
pub fn url_to_filename(url: &str) -> String {
    let url_dec = percent_decode_str(url).decode_utf8_lossy().to_string();
    let file_name = std::path::Path::new(&url_dec)
        .file_name()
        .unwrap()
        .to_str()
        .unwrap();
    // Split at '?' and return the first part. If no '?' is present, this just
    // returns the full string
    file_name.split('?').next().unwrap().to_string()
}
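
/// Download `url` into `into_file` over a single connection, reporting
/// progress through `rep`. Convenience wrapper around
/// `download_feedback_chunks` without a byte range.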
pub async fn download_feedback(
    url: &str,
    into_file: &Path,
    rep: DlReporter,
    content_length: Option<u64>,
) -> ResBE<()> {
    download_feedback_chunks(url, into_file, rep, None, content_length).await
}
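
/// Download `url` (or, if `from_to = Some((from, to))` is given, the
/// inclusive byte range `from..=to`) into `into_file`. A caller-supplied
/// `content_length` avoids an extra HEAD request.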
pub async fn download_feedback_chunks(
    url: &str,
    into_file: &Path,
    rep: DlReporter,
    from_to: Option<(u64, u64)>,
    content_length: Option<u64>,
) -> ResBE<()> {
    let mut content_length = match content_length {
        Some(it) => it,
        None => {
            let (content_length, _) = http_get_filesize_and_range_support(url).await?;
            content_length
        }
    };
    // Send the HTTP request to download the given link
    let mut req = reqwest::Client::new().get(url);
    // Add range header if needed
    if let Some((from, to)) = from_to {
        req = req.header(reqwest::header::RANGE, format!("bytes={}-{}", from, to));
        content_length = to - from + 1;
    }
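    // e.g. from_to = Some((0, 999)) sends `Range: bytes=0-999` and expects
    // to - from + 1 = 1000 bytes in the response body.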
    // Actually send the request and get the response
    let mut resp = req.send().await?;
    // Error out if the server response is not a success (something went wrong)
    if !resp.status().is_success() {
        return Err(DlError::BadHttpStatus.into());
    }
    // Open the local output file. Truncate it only for whole-file downloads;
    // a range download writes into an existing, preallocated file
    let mut opts = tokio::fs::OpenOptions::new();
    let mut ofile = opts
        .create(true)
        .write(true)
        .truncate(from_to.is_none())
        .open(into_file)
        .await?;
    if let Some((from, _)) = from_to {
        ofile.seek(SeekFrom::Start(from)).await?;
    }
    let filename = into_file.file_name().unwrap().to_str().unwrap();
    // Report the download start
    rep.send(DlStatus::Init {
        bytes_total: content_length,
        filename: filename.to_string(),
    });
    let mut curr_progress = 0;
    let mut speed_mbps = 0.0;
    let t_start = SystemTime::now();
    let mut t_last_speed = SystemTime::now();
    let mut last_bytecount = 0;
    let mut average_speed = RollingAverage::new(10);
    let mut buff: Vec<u8> = Vec::new();
    // Read data from the server as long as new data is available
    while let Some(chunk) = resp.chunk().await? {
        let datalen = chunk.len() as u64;
        buff.extend(chunk);
        // Buffer in memory first and only write to disk once the threshold is
        // reached. This reduces the number of small disk writes and thereby
        // the IO bottleneck that occurs on HDDs with many small writes in
        // different files and offsets at the same time
        if buff.len() >= 1_000_000 {
            // Write the received data into the file
            ofile.write_all(&buff).await?;
            buff.clear();
        }
        // Update progress
        curr_progress += datalen;
        // Update the number of bytes downloaded since the last report
        last_bytecount += datalen;
        let t_elapsed = t_last_speed.elapsed()?.as_secs_f64();
        // Update the reported download speed after every 3 MB or every
        // 0.8 seconds, whichever happens first
        if last_bytecount >= 3_000_000 || t_elapsed >= 0.8 {
            // Update the rolling average
            average_speed.add(((last_bytecount as f64) / t_elapsed) / 1_000_000.0);
            speed_mbps = average_speed.value() as f32;
            // Reset the time and bytecount
            last_bytecount = 0;
            t_last_speed = SystemTime::now();
        }
        // Send a status update report
        rep.send(DlStatus::Update {
            speed_mbps,
            bytes_curr: curr_progress,
        });
    }
    // Write out whatever is left in the buffer
    if !buff.is_empty() {
        ofile.write_all(&buff).await?;
    }
    // The connection ended before the expected number of bytes arrived
    if curr_progress != content_length {
        return Err(DlError::HttpNoData.into());
    }
    // Ensure that IO is completed
    ofile.flush().await?;
    let duration_ms = t_start.elapsed()?.as_millis() as u64;
    // Send a report that the download is finished
    rep.send(DlStatus::Done { duration_ms });
    Ok(())
}

// This spins up multiple download tasks and manages the status updates for
// them. The combined status is reported back to the caller.
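//
// A usage sketch (assuming `DlReporter::new(id, tx)` and the `DlReport`
// channel types as they are used internally below):
//
//     let (tx, mut rx) = mpsc::unbounded_channel::<DlReport>();
//     let rep = DlReporter::new(0, tx);
//     download_feedback_multi(url, Path::new("out.bin"), rep, 4, None).await?;
//     while let Some(report) = rx.recv().await { /* drain progress reports */ }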
pub async fn download_feedback_multi(
    url: &str,
    into_file: &Path,
    rep: DlReporter,
    conn_count: u32,
    content_length: Option<u64>,
) -> ResBE<()> {
    let content_length = match content_length {
        Some(it) => it,
        None => http_get_filesize_and_range_support(url).await?.0,
    };
    // Create a zeroed file with 1 byte too much. This will be truncated on
    // download completion and can indicate that the file is not suitable for
    // continuation
    create_zeroed_file(into_file, content_length as usize + 1).await?;
    let chunksize = content_length / conn_count as u64;
    let rest = content_length % conn_count as u64;
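    // e.g. content_length = 10 and conn_count = 4 give chunksize = 2 and
    // rest = 2: ranges 0-1, 2-3, 4-5 and 6-9 (the last task takes the rest).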
    let mut joiners = Vec::new();
    let (tx, mut rx) = mpsc::unbounded_channel::<DlReport>();
    let t_start = SystemTime::now();
    for index in 0..conn_count {
        let url = url.to_owned();
        let into_file = into_file.to_owned();
        let tx = tx.clone();
        joiners.push(tokio::spawn(async move {
            let rep = DlReporter::new(index, tx.clone());
            let mut from_to = (index as u64 * chunksize, (index + 1) as u64 * chunksize - 1);
            if index == conn_count - 1 {
                from_to.1 += rest;
            }
            let specific_content_length = from_to.1 - from_to.0 + 1;
            // Delay each chunk download to reduce the number of simultaneous
            // connection attempts
            tokio::time::sleep(tokio::time::Duration::from_millis(50 * index as u64)).await;
            download_feedback_chunks(
                &url,
                &into_file,
                rep,
                Some(from_to),
                Some(specific_content_length),
            )
            .await
            .map_err(|e| e.to_string())
        }));
    }
    // Drop the original sender so the channel closes once all tasks finish
    drop(tx);
    let filename = into_file.file_name().unwrap().to_str().unwrap();
    rep.send(DlStatus::Init {
        bytes_total: content_length,
        filename: filename.to_string(),
    });
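    // A manager task merges the per-connection reports from `rx` into a
    // single combined progress and speed stream for the caller's reporter.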
    let rep_task = rep.clone();
    let mut t_last = t_start;
    let manager_handle = tokio::task::spawn(async move {
        let rep = rep_task;
        let mut progresses = vec![0; conn_count as usize];
        let mut progress_last: u64 = 0;
        let mut average_speed = RollingAverage::new(10);
        while let Some(update) = rx.recv().await {
            match update.status {
                DlStatus::Init {
                    bytes_total: _,
                    filename: _,
                } => {}
                DlStatus::Update {
                    speed_mbps: _,
                    bytes_curr,
                } => {
                    progresses[update.id as usize] = bytes_curr;
                    let progress_curr = progresses.iter().sum();
                    let progress_delta = progress_curr - progress_last;
                    let t_elapsed = t_last.elapsed().unwrap().as_secs_f64();
                    let speed_mbps = average_speed.value() as f32;
                    // Only take a new speed sample after at least 5 MB of
                    // fresh data, so that very short intervals do not distort
                    // the rolling average
                    if progress_delta >= 5_000_000 {
                        average_speed.add(((progress_delta as f64) / 1_000_000.0) / t_elapsed);
                        progress_last = progress_curr;
                        t_last = SystemTime::now();
                    }
                    rep.send(DlStatus::Update {
                        speed_mbps,
                        bytes_curr: progress_curr,
                    });
                }
                DlStatus::Done { duration_ms: _ } => {}
                // Just forward everything else to the calling receiver
                _ => rep.send(update.status),
            }
        }
    });
    let mut joiners: FuturesUnordered<_> = joiners.into_iter().collect();
    // Check whether the tasks were successful. This always grabs the next
    // completed task, independent of the original order in the joiners list
    while let Some(output) = joiners.next().await {
        // If any of the download tasks fails, abort the rest and delete the
        // file, since it is non-recoverable anyway
        if let Err(e) = output? {
            for handle in joiners.iter() {
                handle.abort();
            }
            manager_handle.abort();
            tokio::fs::remove_file(&into_file).await?;
            return Err(e.into());
        }
    }
    manager_handle.await?;
    // Remove the additional byte at the file end
    let ofile = tokio::fs::OpenOptions::new()
        .create(false)
        .write(true)
        .truncate(false)
        .open(&into_file)
        .await?;
    ofile.set_len(content_length).await?;
    rep.send(DlStatus::Done {
        duration_ms: t_start.elapsed()?.as_millis() as u64,
    });
    Ok(())
}
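
/// Preallocate `file` as `filesize` zeroed bytes so that multiple tasks can
/// seek and write into disjoint ranges of it concurrently.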
async fn create_zeroed_file(file: &Path, filesize: usize) -> ResBE<()> {
    let ofile = tokio::fs::OpenOptions::new()
        .create(true)
        // Open in write mode
        .write(true)
        // Delete and overwrite the file
        .truncate(true)
        .open(file)
        .await?;
    ofile.set_len(filesize as u64).await?;
    Ok(())
}
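
/// Issue a HEAD request for `url` and return the advertised content length
/// together with whether the server supports range requests
/// (`Accept-Ranges: bytes`).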
pub async fn http_get_filesize_and_range_support(url: &str) -> ResBE<(u64, bool)> {
    let resp = reqwest::Client::new().head(url).send().await?;
    if let Some(filesize) = resp.headers().get(reqwest::header::CONTENT_LENGTH) {
        if let Ok(val_str) = filesize.to_str() {
            if let Ok(val) = val_str.parse::<u64>() {
                let mut range_supported = false;
                if let Some(range) = resp.headers().get(reqwest::header::ACCEPT_RANGES) {
                    if let Ok(range) = range.to_str() {
                        if range == "bytes" {
                            range_supported = true;
                        }
                    }
                }
                return Ok((val, range_supported));
            }
        }
    }
    Err(DlError::ContentLengthUnknown.into())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rolling_average() {
        let mut ra = RollingAverage::new(3);
        assert_eq!(0, ra.data.len());
        assert_eq!(3, ra.data.capacity());
        assert_eq!(0.0, ra.value());
        // 10 / 1 = 10
        ra.add(10.0);
        assert_eq!(1, ra.data.len());
        assert_eq!(10.0, ra.value());
        // (10 + 20) / 2 = 15
        ra.add(20.0);
        assert_eq!(2, ra.data.len());
        assert_eq!(15.0, ra.value());
        // (10 + 20 + 30) / 3 = 20
        ra.add(30.0);
        assert_eq!(3, ra.data.len());
        assert_eq!(20.0, ra.value());
        assert_eq!(10.0, ra.data[0]);
        assert_eq!(20.0, ra.data[1]);
        assert_eq!(30.0, ra.data[2]);
        // This should replace the oldest value (index 0)
        ra.add(40.0);
        assert_eq!(3, ra.data.len());
        assert_eq!(3, ra.data.capacity());
        // (40 + 20 + 30) / 3 = 30
        assert_eq!(30.0, ra.value());
        assert_eq!(40.0, ra.data[0]);
        assert_eq!(20.0, ra.data[1]);
        assert_eq!(30.0, ra.data[2]);
        ra.add(50.0);
        ra.add(60.0);
        ra.add(70.0);
        assert_eq!(70.0, ra.data[0]);
        assert_eq!(50.0, ra.data[1]);
        assert_eq!(60.0, ra.data[2]);
    }
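
    // An illustrative check of `url_to_filename` (the URLs here are made up):
    // percent-encoding is decoded and a trailing query string is stripped.
    #[test]
    fn url_filename() {
        assert_eq!(
            "my file.zip",
            url_to_filename("https://example.com/dl/my%20file.zip?token=abc")
        );
        assert_eq!(
            "data.tar.gz",
            url_to_filename("https://example.com/data.tar.gz")
        );
    }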
}