diff --git a/Cargo.lock b/Cargo.lock index 54c6443..f61015f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -38,6 +38,7 @@ checksum = "08b108ad2665fa3f6e6a517c3d80ec3e77d224c47d605167aefaa5d7ef97fa48" dependencies = [ "async-trait", "axum-core", + "axum-macros", "bitflags", "bytes", "futures-util", @@ -80,6 +81,18 @@ dependencies = [ "tower-service", ] +[[package]] +name = "axum-macros" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4df0fc33ada14a338b799002f7e8657711422b25d4e16afb032708d6b185621" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "bitflags" version = "1.3.2" @@ -225,6 +238,12 @@ dependencies = [ "ahash", ] +[[package]] +name = "heck" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2540771e65fc8cb83cd6e8a237f70c319bd5c29f78ed1084ba5d50eeac86f7f9" + [[package]] name = "hermit-abi" version = "0.1.19" diff --git a/Cargo.toml b/Cargo.toml index a39cd8d..350ba99 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -axum = "0.6.1" +axum = { version = "0.6.1", features = ["macros"] } hyper = { version = "0.14", features = ["full"] } tokio = { version = "1", features = ["full"] } tokio-util = { version = "0.7.4", features = ["full"] } diff --git a/src/main.rs b/src/main.rs index a685d40..fea0105 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,6 @@ -use std::sync::{Arc, Mutex}; +use std::sync::Arc; + +extern crate axum; use axum::{ routing::{get, post}, @@ -6,6 +8,7 @@ use axum::{ }; use bytes::Bytes; use memory_cache::MemoryCache; +use tokio::sync::Mutex; mod state; mod new; diff --git a/src/new.rs b/src/new.rs index 4ae8caa..44f4162 100644 --- a/src/new.rs +++ b/src/new.rs @@ -1,144 +1,145 @@ -use std::{collections::HashMap, ffi::OsStr, io::Read, path::PathBuf, sync::Arc, time::Duration}; - -use axum::{ - extract::{BodyStream, Query, State}, - http::HeaderValue, -}; -use bytes::{BufMut, Bytes, BytesMut}; -use hyper::{header, HeaderMap, StatusCode}; -use rand::Rng; -use tokio::{ - fs::File, - io::AsyncWriteExt, - sync::mpsc::{self, Receiver, Sender}, -}; -use tokio_stream::StreamExt; - -// create an upload name from an original file name -fn gen_path(original_name: &String) -> PathBuf { - // extract extension from original name - let extension = PathBuf::from(original_name) - .extension() - .and_then(OsStr::to_str) - .unwrap_or_default() - .to_string(); - - // generate a 6-character alphanumeric string - let id: String = rand::thread_rng() - .sample_iter(&rand::distributions::Alphanumeric) - .take(6) - .map(char::from) - .collect(); - - // create the path - let mut path = PathBuf::new(); - path.push("uploads/"); - path.push(id); - path.set_extension(extension); - - // if we're already using it, try again - if path.exists() { - gen_path(original_name) - } else { - path - } -} - -pub async fn new( - State(state): State>, - headers: HeaderMap, - Query(params): Query>, - mut stream: BodyStream, -) -> Result { - // require name parameter, it's used for determining the file extension - if !params.contains_key("name") { - return Err(StatusCode::BAD_REQUEST); - } - - // generate a path, take the name, format a url - let path = gen_path(params.get("name").unwrap()); - - let name = path - .file_name() - .and_then(OsStr::to_str) - .unwrap_or_default() - .to_string(); // i hope this never happens. that would suck - - let url = format!("http://127.0.0.1:8000/p/{}", name); - - // process the upload in the background so i can send the URL back immediately! - // this isn't supported by ShareX (it waits for the request to complete before handling the response) - tokio::spawn(async move { - // get the content length, and if parsing it fails, assume it's really big - // it may be better to make it fully content-length not required because this feels kind of redundant - let content_length = headers - .get(header::CONTENT_LENGTH) - .unwrap_or(&HeaderValue::from_static("")) - .to_str() - .and_then(|s| Ok(usize::from_str_radix(s, 10))) - .unwrap() - .unwrap_or(usize::MAX); - - // if the upload size exceeds 80 MB, we skip caching! - // previously, i was going to use redis with a 500 MB max (redis's is 512MiB) - // with or without redis, 500 MB is still a bit much.. - // it could probably be read from disk before anyone could fully download it - let mut use_cache = content_length < 80_000_000; - - println!( - "[upl] content length: {} using cache: {}", - content_length, use_cache - ); - - // create file to save upload to - let mut file = File::create(path) - .await - .expect("could not open file! make sure your upload path exists"); - - let mut data: BytesMut = if use_cache { - BytesMut::with_capacity(content_length) - } else { - BytesMut::new() - }; - - let (tx, mut rx): (Sender, Receiver) = mpsc::channel(1); - - tokio::spawn(async move { - while let Some(chunk) = rx.recv().await { - println!("[io] received new chunk"); - file.write_all(&chunk) - .await - .expect("error while writing file to disk"); - } - }); - - while let Some(chunk) = stream.next().await { - let chunk = chunk.unwrap(); - - println!("[upl] sending data to io task"); - tx.send(chunk.clone()).await.unwrap(); - - if use_cache { - println!("[upl] receiving data into cache"); - if data.len() + chunk.len() > data.capacity() { - println!("[upl] too much data! the client had an invalid content-length!"); - - // if we receive too much data, drop the buffer and stop using cache (it is still okay to use disk, probably) - data = BytesMut::new(); - use_cache = false; - } else { - data.put(chunk); - } - } - } - - let mut cache = state.cache.lock().unwrap(); - - if use_cache { - println!("[upl] caching upload!!"); - cache.insert(name, data.freeze(), Some(Duration::from_secs(30))); - } - }); - - Ok(url) -} +use std::{collections::HashMap, ffi::OsStr, io::Read, path::PathBuf, sync::Arc, time::Duration}; + +use axum::{ + extract::{BodyStream, Query, State}, + http::HeaderValue, +}; +use bytes::{BufMut, Bytes, BytesMut}; +use hyper::{header, HeaderMap, StatusCode}; +use rand::Rng; +use tokio::{ + fs::File, + io::AsyncWriteExt, + sync::mpsc::{self, Receiver, Sender}, +}; +use tokio_stream::StreamExt; + +// create an upload name from an original file name +fn gen_path(original_name: &String) -> PathBuf { + // extract extension from original name + let extension = PathBuf::from(original_name) + .extension() + .and_then(OsStr::to_str) + .unwrap_or_default() + .to_string(); + + // generate a 6-character alphanumeric string + let id: String = rand::thread_rng() + .sample_iter(&rand::distributions::Alphanumeric) + .take(6) + .map(char::from) + .collect(); + + // create the path + let mut path = PathBuf::new(); + path.push("uploads/"); + path.push(id); + path.set_extension(extension); + + // if we're already using it, try again + if path.exists() { + gen_path(original_name) + } else { + path + } +} + +#[axum::debug_handler] +pub async fn new( + State(state): State>, + headers: HeaderMap, + Query(params): Query>, + mut stream: BodyStream, +) -> Result { + // require name parameter, it's used for determining the file extension + if !params.contains_key("name") { + return Err(StatusCode::BAD_REQUEST); + } + + // generate a path, take the name, format a url + let path = gen_path(params.get("name").unwrap()); + + let name = path + .file_name() + .and_then(OsStr::to_str) + .unwrap_or_default() + .to_string(); // i hope this never happens. that would suck + + let url = format!("http://127.0.0.1:8000/p/{}", name); + + // process the upload in the background so i can send the URL back immediately! + // this isn't supported by ShareX (it waits for the request to complete before handling the response) + tokio::spawn(async move { + // get the content length, and if parsing it fails, assume it's really big + // it may be better to make it fully content-length not required because this feels kind of redundant + let content_length = headers + .get(header::CONTENT_LENGTH) + .unwrap_or(&HeaderValue::from_static("")) + .to_str() + .and_then(|s| Ok(usize::from_str_radix(s, 10))) + .unwrap() + .unwrap_or(usize::MAX); + + // if the upload size exceeds 80 MB, we skip caching! + // previously, i was going to use redis with a 500 MB max (redis's is 512MiB) + // with or without redis, 500 MB is still a bit much.. + // it could probably be read from disk before anyone could fully download it + let mut use_cache = content_length < 80_000_000; + + println!( + "[upl] content length: {} using cache: {}", + content_length, use_cache + ); + + // create file to save upload to + let mut file = File::create(path) + .await + .expect("could not open file! make sure your upload path exists"); + + let mut data: BytesMut = if use_cache { + BytesMut::with_capacity(content_length) + } else { + BytesMut::new() + }; + + let (tx, mut rx): (Sender, Receiver) = mpsc::channel(1); + + tokio::spawn(async move { + while let Some(chunk) = rx.recv().await { + println!("[io] received new chunk"); + file.write_all(&chunk) + .await + .expect("error while writing file to disk"); + } + }); + + while let Some(chunk) = stream.next().await { + let chunk = chunk.unwrap(); + + println!("[upl] sending data to io task"); + tx.send(chunk.clone()).await.unwrap(); + + if use_cache { + println!("[upl] receiving data into buffer"); + if data.len() + chunk.len() > data.capacity() { + println!("[upl] too much data! the client had an invalid content-length!"); + + // if we receive too much data, drop the buffer and stop using cache (it is still okay to use disk, probably) + data = BytesMut::new(); + use_cache = false; + } else { + data.put(chunk); + } + } + } + + let mut cache = state.cache.lock().await; + + if use_cache { + println!("[upl] caching upload!!"); + cache.insert(name, data.freeze(), Some(Duration::from_secs(120))); + } + }); + + Ok(url) +} diff --git a/src/state.rs b/src/state.rs index 3cbce6b..eb9f5c8 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,8 +1,7 @@ -use std::sync::{Mutex, Arc}; - -use bytes::Bytes; -use memory_cache::MemoryCache; - -pub struct AppState { - pub cache: Mutex> +use bytes::Bytes; +use memory_cache::MemoryCache; +use tokio::sync::Mutex; + +pub struct AppState { + pub cache: Mutex> } \ No newline at end of file diff --git a/src/view.rs b/src/view.rs index 0b3a97a..c6efb00 100644 --- a/src/view.rs +++ b/src/view.rs @@ -7,7 +7,7 @@ use std::{ use axum::{ body::StreamBody, extract::{Path, State}, - response::{IntoResponse, Response}, + response::{IntoResponse, Response}, debug_handler, }; use bytes::{buf::Reader, Bytes}; use hyper::StatusCode; @@ -28,12 +28,11 @@ impl IntoResponse for ViewResponse { } } */ +#[axum::debug_handler] pub async fn view( State(state): State>, Path(original_path): Path, ) -> Response { - println!("{:?}", original_path); - // (hopefully) prevent path traversal, just check for any non-file components if original_path .components() @@ -50,11 +49,11 @@ pub async fn view( .unwrap_or_default() .to_string(); - let cache = state.cache.lock().unwrap(); + let cache = state.cache.lock().await; let cache_item = cache.get(&name); - if true /* cache_item.is_none() */ { + if cache_item.is_none() { let mut path = PathBuf::new(); path.push("uploads/"); path.push(name);