More refactoring

This commit is contained in:
Daniel M 2022-03-31 20:25:24 +02:00
parent 16d0edbbb6
commit e6360153d6
6 changed files with 147 additions and 170 deletions

View File

@ -1,4 +1,5 @@
use std::{num::NonZeroU32, path::PathBuf};
use clap::Parser;
#[derive(Parser, Clone, Debug)]

View File

@ -2,14 +2,12 @@ use std::collections::{HashMap, VecDeque};
use std::io::stdout;
use std::time::SystemTime;
use tokio::sync::mpsc;
use anyhow::Result;
use crossterm::cursor::MoveToPreviousLine;
use crossterm::execute;
use crossterm::style::Print;
use crossterm::terminal::{Clear, ClearType};
use anyhow::Result;
use tokio::sync::mpsc;
#[derive(Clone, Debug)]
pub enum DlStatus {
@ -82,7 +80,7 @@ impl DlReporter {
#[macro_export]
macro_rules! report_msg {
($rep:ident, $fmt:expr) => {
DlReporter::msg(&$rep, $fmt.to_string());
DlReporter::msg(&$rep, format!($fmt));
};
($rep:ident, $fmt:expr, $($fmt2:expr),+) => {
DlReporter::msg(&$rep, format!($fmt, $($fmt2,)+));

View File

@ -1,66 +1,17 @@
use std::io::SeekFrom;
use std::path::Path;
use std::time::SystemTime;
use anyhow::Result;
use futures::stream::FuturesUnordered;
use futures::StreamExt;
use percent_encoding::percent_decode_str;
use std::io::SeekFrom;
use std::path::Path;
use std::time::SystemTime;
use tokio::io::{AsyncSeekExt, AsyncWriteExt};
use tokio::sync::mpsc;
use crate::dlreport::*;
use crate::errors::*;
struct RollingAverage {
index: usize,
data: Vec<f64>,
}
impl RollingAverage {
fn new(size: usize) -> Self {
RollingAverage {
index: 0,
data: Vec::with_capacity(size),
}
}
fn value(&self) -> f64 {
if self.data.is_empty() {
0.0
} else {
let mut max = self.data[0];
for v in self.data.iter() {
if *v > max {
max = *v;
}
}
let mut sum: f64 = self.data.iter().sum();
let mut count = self.data.len();
if self.data.len() >= 3 {
sum -= max;
count -= 1;
}
sum / count as f64
}
}
fn add(&mut self, val: f64) {
if self.data.capacity() == self.data.len() {
self.data[self.index] = val;
self.index += 1;
if self.index >= self.data.capacity() {
self.index = 0;
}
} else {
self.data.push(val);
}
}
}
use crate::dlreport::{DlReport, DlReporter, DlStatus};
use crate::errors::DlError;
use crate::misc::RollingAverage;
/// Get the filename at the end of the given URL. This will decode the URL Encoding.
pub fn url_to_filename(url: &str) -> String {
@ -81,7 +32,7 @@ pub async fn download_feedback(
url: &str,
into_file: &Path,
rep: DlReporter,
content_length: Option<u64>,
content_length: u64,
) -> Result<()> {
download_feedback_chunks(url, into_file, rep, None, content_length).await
}
@ -91,14 +42,9 @@ pub async fn download_feedback_chunks(
into_file: &Path,
rep: DlReporter,
from_to: Option<(u64, u64)>,
content_length: Option<u64>,
mut content_length: u64,
) -> Result<()> {
let mut content_length = match content_length {
Some(it) => it,
None => http_get_filesize_and_range_support(url).await?.filesize,
};
// Send the HTTP request to download the given link
// Build the HTTP request to download the given link
let mut req = reqwest::Client::new().get(url);
// Add range header if needed
@ -213,13 +159,8 @@ pub async fn download_feedback_multi(
into_file: &Path,
rep: DlReporter,
conn_count: u32,
content_length: Option<u64>,
content_length: u64,
) -> Result<()> {
let content_length = match content_length {
Some(it) => it,
None => http_get_filesize_and_range_support(url).await?.filesize,
};
// Create zeroed file with 1 byte too much. This will be truncated on download
// completion and can indicate that the file is not suitable for continuation
create_zeroed_file(into_file, content_length as usize + 1).await?;
@ -258,7 +199,7 @@ pub async fn download_feedback_multi(
&into_file,
rep,
Some(from_to),
Some(specific_content_length),
specific_content_length,
)
.await
}))
@ -377,7 +318,7 @@ pub struct HttpFileInfo {
pub filename: String,
}
pub async fn http_get_filesize_and_range_support(url: &str) -> Result<HttpFileInfo> {
pub async fn http_file_info(url: &str) -> Result<HttpFileInfo> {
let resp = reqwest::Client::new().head(url).send().await?;
let filesize = resp
@ -402,59 +343,3 @@ pub async fn http_get_filesize_and_range_support(url: &str) -> Result<HttpFileIn
Ok(info)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn rolling_average() {
let mut ra = RollingAverage::new(3);
assert_eq!(0, ra.data.len());
assert_eq!(3, ra.data.capacity());
assert_eq!(0.0, ra.value());
// 10 / 1 = 10
ra.add(10.0);
assert_eq!(1, ra.data.len());
assert_eq!(10.0, ra.value());
// (10 + 20) / 2 = 15
ra.add(20.0);
assert_eq!(2, ra.data.len());
assert_eq!(15.0, ra.value());
// (10 + 20 + 30) / 3 = 20
ra.add(30.0);
assert_eq!(3, ra.data.len());
assert_eq!(20.0, ra.value());
assert_eq!(10.0, ra.data[0]);
assert_eq!(20.0, ra.data[1]);
assert_eq!(30.0, ra.data[2]);
// This should replace the oldest value (index 1)
ra.add(40.0);
assert_eq!(3, ra.data.len());
assert_eq!(3, ra.data.capacity());
// (40 + 20 + 30) / 3 = 30
assert_eq!(30.0, ra.value());
assert_eq!(40.0, ra.data[0]);
assert_eq!(20.0, ra.data[1]);
assert_eq!(30.0, ra.data[2]);
ra.add(50.0);
ra.add(60.0);
ra.add(70.0);
assert_eq!(70.0, ra.data[0]);
assert_eq!(50.0, ra.data[1]);
assert_eq!(60.0, ra.data[2]);
}
}

View File

@ -1,34 +1,26 @@
use std::{
collections::VecDeque,
path::{Path, PathBuf},
process::exit,
sync::Arc,
time::SystemTime,
};
use clap::Parser;
use futures::future::join_all;
use tokio::{
fs::create_dir_all,
sync::{
mpsc::{unbounded_channel, UnboundedSender},
Mutex,
},
};
use crate::{
args::CLIArgs,
dlreport::{watch_and_print_reports, DlReport, DlReporter},
download::{download_feedback, download_feedback_multi, http_get_filesize_and_range_support},
zippy::is_zippyshare_url,
};
use std::collections::VecDeque;
use std::path::Path;
use std::process::exit;
use std::sync::Arc;
use std::time::SystemTime;
use anyhow::Result;
use clap::Parser;
use futures::future::join_all;
use tokio::fs::create_dir_all;
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
use tokio::sync::Mutex;
use crate::args::CLIArgs;
use crate::dlreport::{watch_and_print_reports, DlReport, DlReporter};
use crate::download::{download_feedback, download_feedback_multi, http_file_info};
use crate::zippy::is_zippyshare_url;
mod args;
mod dlreport;
mod download;
mod errors;
mod misc;
mod zippy;
struct DlRequest {
@ -100,7 +92,7 @@ async fn download_multiple(args: CLIArgs, raw_urls: Vec<String>) -> Result<()> {
let t_start = SystemTime::now();
let jobs = (0..args.file_count.get())
.map(|_| tokio::task::spawn(download_job(urls.clone(), tx.clone(), args.clone())))
.map(|_| tokio::task::spawn(download_job(Arc::clone(&urls), tx.clone(), args.clone())))
.collect::<Vec<_>>();
drop(tx);
@ -135,7 +127,7 @@ async fn download_job(urls: SyncQueue, reporter: UnboundedSender<DlReport>, cli_
dlreq.url.to_string()
};
let info = match http_get_filesize_and_range_support(&url).await {
let info = match http_file_info(&url).await {
Ok(it) => it,
Err(_e) => {
report_msg!(reporter, "Error while querying metadata: {url}");
@ -143,13 +135,7 @@ async fn download_job(urls: SyncQueue, reporter: UnboundedSender<DlReport>, cli_
}
};
let into_file: PathBuf = cli_args
.outdir
.join(Path::new(&info.filename))
.to_str()
.unwrap()
.to_string()
.into();
let into_file = cli_args.outdir.join(Path::new(&info.filename));
// If file with same name is present locally, check filesize
if into_file.exists() {
@ -173,20 +159,20 @@ async fn download_job(urls: SyncQueue, reporter: UnboundedSender<DlReport>, cli_
}
let dl_status = if cli_args.conn_count.get() == 1 {
download_feedback(&url, &into_file, reporter.clone(), Some(info.filesize)).await
download_feedback(&url, &into_file, reporter.clone(), info.filesize).await
} else if !info.range_support {
report_msg!(
reporter,
"Server does not support range headers. Downloading with single connection: {url}"
);
download_feedback(&url, &into_file, reporter.clone(), Some(info.filesize)).await
download_feedback(&url, &into_file, reporter.clone(), info.filesize).await
} else {
download_feedback_multi(
&url,
&into_file,
reporter.clone(),
cli_args.conn_count.get(),
Some(info.filesize),
info.filesize,
)
.await
};

106
src/misc.rs Normal file
View File

@ -0,0 +1,106 @@
pub struct RollingAverage {
index: usize,
data: Vec<f64>,
}
impl RollingAverage {
pub fn new(size: usize) -> Self {
RollingAverage {
index: 0,
data: Vec::with_capacity(size),
}
}
pub fn value(&self) -> f64 {
if self.data.is_empty() {
0.0
} else {
let mut max = self.data[0];
for v in self.data.iter() {
if *v > max {
max = *v;
}
}
let mut sum: f64 = self.data.iter().sum();
let mut count = self.data.len();
if self.data.len() >= 3 {
sum -= max;
count -= 1;
}
sum / count as f64
}
}
pub fn add(&mut self, val: f64) {
if self.data.capacity() == self.data.len() {
self.data[self.index] = val;
self.index += 1;
if self.index >= self.data.capacity() {
self.index = 0;
}
} else {
self.data.push(val);
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn rolling_average() {
let mut ra = RollingAverage::new(3);
assert_eq!(0, ra.data.len());
assert_eq!(3, ra.data.capacity());
assert_eq!(0.0, ra.value());
// 10 / 1 = 10
ra.add(10.0);
assert_eq!(1, ra.data.len());
assert_eq!(10.0, ra.value());
// (10 + 20) / 2 = 15
ra.add(20.0);
assert_eq!(2, ra.data.len());
assert_eq!(15.0, ra.value());
// (10 + 20 + 30) / 3 = 20
ra.add(30.0);
assert_eq!(3, ra.data.len());
assert_eq!(20.0, ra.value());
assert_eq!(10.0, ra.data[0]);
assert_eq!(20.0, ra.data[1]);
assert_eq!(30.0, ra.data[2]);
// This should replace the oldest value (index 1)
ra.add(40.0);
assert_eq!(3, ra.data.len());
assert_eq!(3, ra.data.capacity());
// (40 + 20 + 30) / 3 = 30
assert_eq!(30.0, ra.value());
assert_eq!(40.0, ra.data[0]);
assert_eq!(20.0, ra.data[1]);
assert_eq!(30.0, ra.data[2]);
ra.add(50.0);
ra.add(60.0);
ra.add(70.0);
assert_eq!(70.0, ra.data[0]);
assert_eq!(50.0, ra.data[1]);
assert_eq!(60.0, ra.data[2]);
}
}

View File

@ -1,6 +1,7 @@
use std::io::{Error, ErrorKind};
use anyhow::Result;
use regex::Regex;
use std::io::{Error, ErrorKind};
pub fn is_zippyshare_url(url: &str) -> bool {
Regex::new(r"^https?://(?:www\d*\.)?zippyshare\.com/v/[0-9a-zA-Z]+/file\.html$")