diff --git a/src/args.rs b/src/args.rs
index 8084f5d..c9123cb 100644
--- a/src/args.rs
+++ b/src/args.rs
@@ -1,4 +1,4 @@
-use std::num::NonZeroU32;
+use std::{num::NonZeroU32, path::PathBuf};
 use clap::Parser;
 
 #[derive(Parser, Clone, Debug)]
@@ -16,7 +16,7 @@ pub struct CLIArgs {
         default_value = "./",
         help = "Set the output directory. The directory will be created if it doesn't exit yet",
     )]
-    pub outdir: String,
+    pub outdir: PathBuf,
 
     #[clap(
         short = 'i',
@@ -24,7 +24,7 @@ pub struct CLIArgs {
         value_name = "FILENAME",
         help = "Force filename. This only works for single file downloads",
     )]
-    pub into_file: Option<String>,
+    pub into_file: Option<PathBuf>,
 
     #[clap(
         short = 'n',
@@ -61,7 +61,7 @@ pub struct CLIArgs {
         value_name = "URL LISTFILE",
         help = "Download all files from the specified url list file",
     )]
-    pub listfile: Vec<String>,
+    pub listfile: Vec<PathBuf>,
 
     #[clap(
         short = 'd',
diff --git a/src/main.rs b/src/main.rs
index b77a858..758f2f4 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,18 +1,19 @@
-use std::{
-    collections::VecDeque,
-    path::{Path, PathBuf},
-    process::exit,
-    sync::Arc,
-    time::SystemTime,
-};
+use std::{collections::VecDeque, path::Path, process::exit, sync::Arc, time::SystemTime};
 
 use clap::Parser;
+use download::{download_feedback, download_feedback_multi, http_get_filesize_and_range_support};
 use futures::future::join_all;
-use tokio::sync::{mpsc, Mutex};
+use tokio::{
+    fs::create_dir_all,
+    sync::{
+        mpsc::{unbounded_channel, UnboundedSender},
+        Mutex,
+    },
+};
 
 use crate::{
     args::CLIArgs,
-    dlreport::{DlReport, DlReporter, DlStatus},
+    dlreport::{watch_and_print_reports, DlReport, DlReporter, DlStatus},
     errors::ResBE,
 };
 
@@ -22,25 +23,39 @@
 mod download;
 mod errors;
 mod zippy;
 
+struct DlRequest {
+    id: usize,
+    url: String,
+}
+
+type SyncQueue = Arc<Mutex<VecDeque<DlRequest>>>;
+
 #[tokio::main]
 async fn main() -> ResBE<()> {
     let args = CLIArgs::parse();
+    // Combine all urls taken from files and the ones provided on the command line
     let mut urls = args.download.clone();
     for file in args.listfile.iter() {
-        match read_urls_from_listfile(file).await {
+        match urls_from_listfile(file).await {
             Ok(listfile_urls) => urls.extend(listfile_urls),
             Err(_) => {
-                eprintln!("Failed to read urls from file: {}", file);
+                eprintln!("Failed to read urls from file: {}", file.display());
                 exit(1);
             }
         }
     }
 
+    if urls.is_empty() {
+        eprintln!("No URLs provided");
+        return Ok(());
+    }
+
     download_multiple(args, urls).await
 }
 
-async fn read_urls_from_listfile(listfile: &str) -> ResBE<Vec<String>> {
+/// Parse a listfile and return all urls found in it
+async fn urls_from_listfile(listfile: &Path) -> ResBE<Vec<String>> {
     let text = tokio::fs::read_to_string(listfile).await?;
     let urls = text
         .lines()
@@ -51,158 +66,138 @@ async fn read_urls_from_listfile(listfile: &str) -> ResBE<Vec<String>> {
     Ok(urls)
 }
 
-async fn download_multiple(cli_args: CLIArgs, raw_urls: Vec<String>) -> ResBE<()> {
-    let outdir = Path::new(&cli_args.outdir);
-
+// Download all files in parallel according to the provided CLI arguments
+async fn download_multiple(args: CLIArgs, raw_urls: Vec<String>) -> ResBE<()> {
     let num_urls = raw_urls.len();
-    let parallel_file_count = cli_args.file_count.get();
-    let conn_count = cli_args.conn_count.get();
-    let zippy = cli_args.zippy;
+    let urls: SyncQueue = Default::default();
 
-    let urls = Arc::new(Mutex::new(VecDeque::<(usize, String)>::new()));
+    let enumerated_urls = raw_urls
+        .into_iter()
+        .enumerate()
+        .map(|(id, url)| DlRequest { id, url });
+    urls.lock().await.extend(enumerated_urls);
 
-    urls.lock().await.extend(raw_urls.into_iter().enumerate());
-
-    if !outdir.exists() {
-        if let Err(_e) = tokio::fs::create_dir_all(outdir).await {
-            eprintln!("Error creating output directory '{}'", outdir.display());
+    if !args.outdir.exists() {
+        if let Err(_e) = create_dir_all(&args.outdir).await {
+            eprintln!(
+                "Error creating output directory '{}'",
+                args.outdir.display()
+            );
             exit(1);
         }
     }
 
-    let (tx, rx) = mpsc::unbounded_channel::<DlReport>();
+    let (tx, rx) = unbounded_channel::<DlReport>();
 
     let t_start = SystemTime::now();
 
-    let joiners = (0..parallel_file_count)
-        .map(|_| {
-            tokio::task::spawn(download_job(
-                urls.clone(),
-                tx.clone(),
-                conn_count,
-                zippy,
-                outdir.to_owned(),
-                cli_args.into_file.clone(),
-            ))
-        })
+    let jobs = (0..args.file_count.get())
+        .map(|_| tokio::task::spawn(download_job(urls.clone(), tx.clone(), args.clone())))
         .collect::<Vec<_>>();
 
     drop(tx);
 
-    dlreport::watch_and_print_reports(rx, num_urls as i32).await?;
+    watch_and_print_reports(rx, num_urls as i32).await?;
 
-    join_all(joiners).await;
+    join_all(jobs).await;
 
     println!("Total time: {}s", t_start.elapsed()?.as_secs());
 
     Ok(())
 }
 
-async fn download_job(
-    urls: Arc<Mutex<VecDeque<(usize, String)>>>,
-    tx: mpsc::UnboundedSender<DlReport>,
-    conn_count: u32,
-    zippy: bool,
-    outdir: PathBuf,
-    arg_filename: Option<String>,
-) {
+async fn download_job(urls: SyncQueue, reporter: UnboundedSender<DlReport>, cli_args: CLIArgs) {
     loop {
-        let (global_url_index, url) = match urls.lock().await.pop_front() {
+        // Get the next url to download or break if there are no more urls
+        let dlreq = match urls.lock().await.pop_front() {
             Some(it) => it,
             None => break,
         };
 
-        let tx = tx.clone();
+        let reporter = DlReporter::new(dlreq.id as u32, reporter.clone());
 
-        let rep = DlReporter::new(global_url_index as u32, tx);
-
-        let url = if zippy {
-            match zippy::resolve_link(&url).await {
+        // Resolve the zippy url to the direct download url if necessary
+        let url = if cli_args.zippy {
+            match zippy::resolve_link(&dlreq.url).await {
                 Ok(url) => url,
                 Err(_e) => {
-                    rep.send(DlStatus::Message(format!(
+                    reporter.send(DlStatus::Message(format!(
                         "Zippyshare link could not be resolved: {}",
-                        url
+                        dlreq.url
                     )));
                     continue;
                 }
             }
         } else {
-            url.to_string()
+            dlreq.url.to_string()
        };
 
-        let file_name = arg_filename
+        let file_name = cli_args
+            .into_file
             .clone()
-            .unwrap_or_else(|| download::url_to_filename(&url));
+            .unwrap_or_else(|| download::url_to_filename(&url).into());
 
-        let into_file = outdir
+        let into_file = cli_args
+            .outdir
             .join(Path::new(&file_name))
             .to_str()
             .unwrap()
             .to_string();
 
         let path_into_file = Path::new(&into_file);
 
-        let (filesize, range_supported) =
-            match download::http_get_filesize_and_range_support(&url).await {
-                Ok((filesize, range_supported)) => (filesize, range_supported),
-                Err(_e) => {
-                    rep.send(DlStatus::Message(format!(
-                        "Error while querying metadata: {}",
-                        url
-                    )));
-                    continue;
-                }
-            };
+        let (filesize, range_supported) = match http_get_filesize_and_range_support(&url).await {
+            Ok((filesize, range_supported)) => (filesize, range_supported),
+            Err(_e) => {
+                reporter.send(DlStatus::Message(format!(
+                    "Error while querying metadata: {}",
+                    url
+                )));
+                continue;
+            }
+        };
 
         // If file with same name is present locally, check filesize
        if path_into_file.exists() {
             let local_filesize = std::fs::metadata(path_into_file).unwrap().len();
 
             if filesize == local_filesize {
-                rep.send(DlStatus::Message(format!(
+                reporter.send(DlStatus::Message(format!(
                     "Skipping file '{}': already present",
-                    &file_name
+                    file_name.display()
                 )));
-                rep.send(DlStatus::Skipped);
+                reporter.send(DlStatus::Skipped);
                 continue;
             } else {
-                rep.send(DlStatus::Message(format!(
+                reporter.send(DlStatus::Message(format!(
                     "Replacing file '{}': present but not completed",
-                    &file_name
+                    &file_name.display()
                 )));
             }
         }
 
-        if conn_count == 1 {
-            if let Err(_e) =
-                download::download_feedback(&url, &into_file, rep.clone(), Some(filesize)).await
-            {
-                rep.send(DlStatus::DoneErr {
-                    filename: file_name.to_string(),
-                });
-            }
+        let dl_status = if cli_args.conn_count.get() == 1 {
+            download_feedback(&url, &into_file, reporter.clone(), Some(filesize)).await
+        } else if !range_supported {
+            reporter.send(DlStatus::Message(format!(
+                "Server does not support range headers. Downloading with single connection: {}",
+                url
+            )));
+            download_feedback(&url, &into_file, reporter.clone(), Some(filesize)).await
         } else {
-            if !range_supported {
-                rep.send(DlStatus::Message(format!(
-                    "Error Server does not support range header: {}",
-                    url
-                )));
-                continue;
-            }
-
-            if let Err(_e) = download::download_feedback_multi(
+            download_feedback_multi(
                 &url,
                 &into_file,
-                rep.clone(),
-                conn_count,
+                reporter.clone(),
+                cli_args.conn_count.get(),
                 Some(filesize),
             )
             .await
-            {
-                rep.send(DlStatus::DoneErr {
-                    filename: file_name.to_string(),
-                });
-            }
         };
+
+        if dl_status.is_err() {
+            reporter.send(DlStatus::DoneErr {
+                filename: file_name.to_str().unwrap().to_string(),
+            });
+        }
     }
 }
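
Reviewer note on the concurrency shape this diff moves to: a fixed pool of `file_count` Tokio tasks drains one shared `Arc<Mutex<VecDeque<DlRequest>>>` queue, and every task reports progress over a single unbounded mpsc channel; `main` drops its own sender before awaiting the printer, so `watch_and_print_reports` sees the channel close exactly when the last worker exits. Below is a minimal, self-contained sketch of that pattern, not code from this crate: `Job`, `worker`, and the plain-string reports are illustrative stand-ins for `DlRequest`, `download_job`, and `DlReport`, and it assumes only the `tokio` (features "rt-multi-thread", "macros", "sync") and `futures` crates.

    use std::{collections::VecDeque, sync::Arc};

    use tokio::sync::{
        mpsc::{unbounded_channel, UnboundedSender},
        Mutex,
    };

    // Illustrative stand-in for DlRequest.
    struct Job {
        id: usize,
        url: String,
    }

    type Queue = Arc<Mutex<VecDeque<Job>>>;

    // Illustrative stand-in for download_job: pop jobs until the queue is empty.
    async fn worker(queue: Queue, reporter: UnboundedSender<String>) {
        loop {
            // The lock guard is dropped at the end of this statement,
            // so the queue is never held across an .await point.
            let job = match queue.lock().await.pop_front() {
                Some(job) => job,
                None => break,
            };
            // A real worker would download here; the sketch only reports.
            let _ = reporter.send(format!("finished #{}: {}", job.id, job.url));
        }
    }

    #[tokio::main]
    async fn main() {
        let queue: Queue = Default::default();
        queue.lock().await.extend(
            ["https://example.com/a", "https://example.com/b"]
                .into_iter()
                .enumerate()
                .map(|(id, url)| Job { id, url: url.to_string() }),
        );

        let (tx, mut rx) = unbounded_channel::<String>();
        let workers = (0..4)
            .map(|_| tokio::spawn(worker(queue.clone(), tx.clone())))
            .collect::<Vec<_>>();

        // Drop the last sender so recv() returns None once every worker is done.
        drop(tx);
        while let Some(report) = rx.recv().await {
            println!("{report}");
        }

        futures::future::join_all(workers).await;
    }

The drop(tx)-before-recv ordering is the same detail that lets the real `main` await `watch_and_print_reports` before `join_all` without hanging.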