From e6360153d62682dad7787c1717e220014f909453 Mon Sep 17 00:00:00 2001 From: Daniel M Date: Thu, 31 Mar 2022 20:25:24 +0200 Subject: [PATCH] More refactoring --- src/args.rs | 1 + src/dlreport.rs | 8 ++- src/download.rs | 141 +++++------------------------------------------- src/main.rs | 58 ++++++++------------ src/misc.rs | 106 ++++++++++++++++++++++++++++++++++++ src/zippy.rs | 3 +- 6 files changed, 147 insertions(+), 170 deletions(-) create mode 100644 src/misc.rs diff --git a/src/args.rs b/src/args.rs index cd3b9a1..4ef1463 100644 --- a/src/args.rs +++ b/src/args.rs @@ -1,4 +1,5 @@ use std::{num::NonZeroU32, path::PathBuf}; + use clap::Parser; #[derive(Parser, Clone, Debug)] diff --git a/src/dlreport.rs b/src/dlreport.rs index ace25fc..d9684c0 100644 --- a/src/dlreport.rs +++ b/src/dlreport.rs @@ -2,14 +2,12 @@ use std::collections::{HashMap, VecDeque}; use std::io::stdout; use std::time::SystemTime; -use tokio::sync::mpsc; - +use anyhow::Result; use crossterm::cursor::MoveToPreviousLine; use crossterm::execute; use crossterm::style::Print; use crossterm::terminal::{Clear, ClearType}; - -use anyhow::Result; +use tokio::sync::mpsc; #[derive(Clone, Debug)] pub enum DlStatus { @@ -82,7 +80,7 @@ impl DlReporter { #[macro_export] macro_rules! report_msg { ($rep:ident, $fmt:expr) => { - DlReporter::msg(&$rep, $fmt.to_string()); + DlReporter::msg(&$rep, format!($fmt)); }; ($rep:ident, $fmt:expr, $($fmt2:expr),+) => { DlReporter::msg(&$rep, format!($fmt, $($fmt2,)+)); diff --git a/src/download.rs b/src/download.rs index 9e3bbc4..042c327 100644 --- a/src/download.rs +++ b/src/download.rs @@ -1,66 +1,17 @@ +use std::io::SeekFrom; +use std::path::Path; +use std::time::SystemTime; + use anyhow::Result; use futures::stream::FuturesUnordered; use futures::StreamExt; use percent_encoding::percent_decode_str; -use std::io::SeekFrom; -use std::path::Path; -use std::time::SystemTime; use tokio::io::{AsyncSeekExt, AsyncWriteExt}; use tokio::sync::mpsc; -use crate::dlreport::*; -use crate::errors::*; - -struct RollingAverage { - index: usize, - data: Vec, -} - -impl RollingAverage { - fn new(size: usize) -> Self { - RollingAverage { - index: 0, - data: Vec::with_capacity(size), - } - } - - fn value(&self) -> f64 { - if self.data.is_empty() { - 0.0 - } else { - let mut max = self.data[0]; - - for v in self.data.iter() { - if *v > max { - max = *v; - } - } - - let mut sum: f64 = self.data.iter().sum(); - let mut count = self.data.len(); - - if self.data.len() >= 3 { - sum -= max; - count -= 1; - } - - sum / count as f64 - } - } - - fn add(&mut self, val: f64) { - if self.data.capacity() == self.data.len() { - self.data[self.index] = val; - - self.index += 1; - if self.index >= self.data.capacity() { - self.index = 0; - } - } else { - self.data.push(val); - } - } -} +use crate::dlreport::{DlReport, DlReporter, DlStatus}; +use crate::errors::DlError; +use crate::misc::RollingAverage; /// Get the filename at the end of the given URL. This will decode the URL Encoding. pub fn url_to_filename(url: &str) -> String { @@ -81,7 +32,7 @@ pub async fn download_feedback( url: &str, into_file: &Path, rep: DlReporter, - content_length: Option, + content_length: u64, ) -> Result<()> { download_feedback_chunks(url, into_file, rep, None, content_length).await } @@ -91,14 +42,9 @@ pub async fn download_feedback_chunks( into_file: &Path, rep: DlReporter, from_to: Option<(u64, u64)>, - content_length: Option, + mut content_length: u64, ) -> Result<()> { - let mut content_length = match content_length { - Some(it) => it, - None => http_get_filesize_and_range_support(url).await?.filesize, - }; - - // Send the HTTP request to download the given link + // Build the HTTP request to download the given link let mut req = reqwest::Client::new().get(url); // Add range header if needed @@ -213,13 +159,8 @@ pub async fn download_feedback_multi( into_file: &Path, rep: DlReporter, conn_count: u32, - content_length: Option, + content_length: u64, ) -> Result<()> { - let content_length = match content_length { - Some(it) => it, - None => http_get_filesize_and_range_support(url).await?.filesize, - }; - // Create zeroed file with 1 byte too much. This will be truncated on download // completion and can indicate that the file is not suitable for continuation create_zeroed_file(into_file, content_length as usize + 1).await?; @@ -258,7 +199,7 @@ pub async fn download_feedback_multi( &into_file, rep, Some(from_to), - Some(specific_content_length), + specific_content_length, ) .await })) @@ -377,7 +318,7 @@ pub struct HttpFileInfo { pub filename: String, } -pub async fn http_get_filesize_and_range_support(url: &str) -> Result { +pub async fn http_file_info(url: &str) -> Result { let resp = reqwest::Client::new().head(url).send().await?; let filesize = resp @@ -402,59 +343,3 @@ pub async fn http_get_filesize_and_range_support(url: &str) -> Result) -> Result<()> { let t_start = SystemTime::now(); let jobs = (0..args.file_count.get()) - .map(|_| tokio::task::spawn(download_job(urls.clone(), tx.clone(), args.clone()))) + .map(|_| tokio::task::spawn(download_job(Arc::clone(&urls), tx.clone(), args.clone()))) .collect::>(); drop(tx); @@ -135,7 +127,7 @@ async fn download_job(urls: SyncQueue, reporter: UnboundedSender, cli_ dlreq.url.to_string() }; - let info = match http_get_filesize_and_range_support(&url).await { + let info = match http_file_info(&url).await { Ok(it) => it, Err(_e) => { report_msg!(reporter, "Error while querying metadata: {url}"); @@ -143,13 +135,7 @@ async fn download_job(urls: SyncQueue, reporter: UnboundedSender, cli_ } }; - let into_file: PathBuf = cli_args - .outdir - .join(Path::new(&info.filename)) - .to_str() - .unwrap() - .to_string() - .into(); + let into_file = cli_args.outdir.join(Path::new(&info.filename)); // If file with same name is present locally, check filesize if into_file.exists() { @@ -173,20 +159,20 @@ async fn download_job(urls: SyncQueue, reporter: UnboundedSender, cli_ } let dl_status = if cli_args.conn_count.get() == 1 { - download_feedback(&url, &into_file, reporter.clone(), Some(info.filesize)).await + download_feedback(&url, &into_file, reporter.clone(), info.filesize).await } else if !info.range_support { report_msg!( reporter, "Server does not support range headers. Downloading with single connection: {url}" ); - download_feedback(&url, &into_file, reporter.clone(), Some(info.filesize)).await + download_feedback(&url, &into_file, reporter.clone(), info.filesize).await } else { download_feedback_multi( &url, &into_file, reporter.clone(), cli_args.conn_count.get(), - Some(info.filesize), + info.filesize, ) .await }; diff --git a/src/misc.rs b/src/misc.rs new file mode 100644 index 0000000..d267e93 --- /dev/null +++ b/src/misc.rs @@ -0,0 +1,106 @@ +pub struct RollingAverage { + index: usize, + data: Vec, +} + +impl RollingAverage { + pub fn new(size: usize) -> Self { + RollingAverage { + index: 0, + data: Vec::with_capacity(size), + } + } + + pub fn value(&self) -> f64 { + if self.data.is_empty() { + 0.0 + } else { + let mut max = self.data[0]; + + for v in self.data.iter() { + if *v > max { + max = *v; + } + } + + let mut sum: f64 = self.data.iter().sum(); + let mut count = self.data.len(); + + if self.data.len() >= 3 { + sum -= max; + count -= 1; + } + + sum / count as f64 + } + } + + pub fn add(&mut self, val: f64) { + if self.data.capacity() == self.data.len() { + self.data[self.index] = val; + + self.index += 1; + if self.index >= self.data.capacity() { + self.index = 0; + } + } else { + self.data.push(val); + } + } +} + +#[cfg(test)] +mod tests { + + use super::*; + + #[test] + fn rolling_average() { + let mut ra = RollingAverage::new(3); + + assert_eq!(0, ra.data.len()); + assert_eq!(3, ra.data.capacity()); + + assert_eq!(0.0, ra.value()); + + // 10 / 1 = 10 + ra.add(10.0); + assert_eq!(1, ra.data.len()); + assert_eq!(10.0, ra.value()); + + // (10 + 20) / 2 = 15 + ra.add(20.0); + assert_eq!(2, ra.data.len()); + assert_eq!(15.0, ra.value()); + + // (10 + 20 + 30) / 3 = 20 + ra.add(30.0); + assert_eq!(3, ra.data.len()); + assert_eq!(20.0, ra.value()); + + assert_eq!(10.0, ra.data[0]); + assert_eq!(20.0, ra.data[1]); + assert_eq!(30.0, ra.data[2]); + + // This should replace the oldest value (index 1) + ra.add(40.0); + + assert_eq!(3, ra.data.len()); + assert_eq!(3, ra.data.capacity()); + + // (40 + 20 + 30) / 3 = 30 + assert_eq!(30.0, ra.value()); + + assert_eq!(40.0, ra.data[0]); + assert_eq!(20.0, ra.data[1]); + assert_eq!(30.0, ra.data[2]); + + ra.add(50.0); + ra.add(60.0); + ra.add(70.0); + + assert_eq!(70.0, ra.data[0]); + assert_eq!(50.0, ra.data[1]); + assert_eq!(60.0, ra.data[2]); + } +} diff --git a/src/zippy.rs b/src/zippy.rs index 7f726bc..377f41d 100644 --- a/src/zippy.rs +++ b/src/zippy.rs @@ -1,6 +1,7 @@ +use std::io::{Error, ErrorKind}; + use anyhow::Result; use regex::Regex; -use std::io::{Error, ErrorKind}; pub fn is_zippyshare_url(url: &str) -> bool { Regex::new(r"^https?://(?:www\d*\.)?zippyshare\.com/v/[0-9a-zA-Z]+/file\.html$")