switch to tokio mutexes, it compiles now

This commit is contained in:
minish 2022-12-29 00:03:14 -05:00 committed by minish
parent a2fccd1f1c
commit 9a22170dfa
6 changed files with 179 additions and 158 deletions

19
Cargo.lock generated
View File

@ -38,6 +38,7 @@ checksum = "08b108ad2665fa3f6e6a517c3d80ec3e77d224c47d605167aefaa5d7ef97fa48"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"axum-core", "axum-core",
"axum-macros",
"bitflags", "bitflags",
"bytes", "bytes",
"futures-util", "futures-util",
@ -80,6 +81,18 @@ dependencies = [
"tower-service", "tower-service",
] ]
[[package]]
name = "axum-macros"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e4df0fc33ada14a338b799002f7e8657711422b25d4e16afb032708d6b185621"
dependencies = [
"heck",
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "bitflags" name = "bitflags"
version = "1.3.2" version = "1.3.2"
@ -225,6 +238,12 @@ dependencies = [
"ahash", "ahash",
] ]
[[package]]
name = "heck"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2540771e65fc8cb83cd6e8a237f70c319bd5c29f78ed1084ba5d50eeac86f7f9"
[[package]] [[package]]
name = "hermit-abi" name = "hermit-abi"
version = "0.1.19" version = "0.1.19"

View File

@ -6,7 +6,7 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
axum = "0.6.1" axum = { version = "0.6.1", features = ["macros"] }
hyper = { version = "0.14", features = ["full"] } hyper = { version = "0.14", features = ["full"] }
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }
tokio-util = { version = "0.7.4", features = ["full"] } tokio-util = { version = "0.7.4", features = ["full"] }

View File

@ -1,4 +1,6 @@
use std::sync::{Arc, Mutex}; use std::sync::Arc;
extern crate axum;
use axum::{ use axum::{
routing::{get, post}, routing::{get, post},
@ -6,6 +8,7 @@ use axum::{
}; };
use bytes::Bytes; use bytes::Bytes;
use memory_cache::MemoryCache; use memory_cache::MemoryCache;
use tokio::sync::Mutex;
mod state; mod state;
mod new; mod new;

View File

@ -1,144 +1,145 @@
use std::{collections::HashMap, ffi::OsStr, io::Read, path::PathBuf, sync::Arc, time::Duration}; use std::{collections::HashMap, ffi::OsStr, io::Read, path::PathBuf, sync::Arc, time::Duration};
use axum::{ use axum::{
extract::{BodyStream, Query, State}, extract::{BodyStream, Query, State},
http::HeaderValue, http::HeaderValue,
}; };
use bytes::{BufMut, Bytes, BytesMut}; use bytes::{BufMut, Bytes, BytesMut};
use hyper::{header, HeaderMap, StatusCode}; use hyper::{header, HeaderMap, StatusCode};
use rand::Rng; use rand::Rng;
use tokio::{ use tokio::{
fs::File, fs::File,
io::AsyncWriteExt, io::AsyncWriteExt,
sync::mpsc::{self, Receiver, Sender}, sync::mpsc::{self, Receiver, Sender},
}; };
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
// create an upload name from an original file name // create an upload name from an original file name
fn gen_path(original_name: &String) -> PathBuf { fn gen_path(original_name: &String) -> PathBuf {
// extract extension from original name // extract extension from original name
let extension = PathBuf::from(original_name) let extension = PathBuf::from(original_name)
.extension() .extension()
.and_then(OsStr::to_str) .and_then(OsStr::to_str)
.unwrap_or_default() .unwrap_or_default()
.to_string(); .to_string();
// generate a 6-character alphanumeric string // generate a 6-character alphanumeric string
let id: String = rand::thread_rng() let id: String = rand::thread_rng()
.sample_iter(&rand::distributions::Alphanumeric) .sample_iter(&rand::distributions::Alphanumeric)
.take(6) .take(6)
.map(char::from) .map(char::from)
.collect(); .collect();
// create the path // create the path
let mut path = PathBuf::new(); let mut path = PathBuf::new();
path.push("uploads/"); path.push("uploads/");
path.push(id); path.push(id);
path.set_extension(extension); path.set_extension(extension);
// if we're already using it, try again // if we're already using it, try again
if path.exists() { if path.exists() {
gen_path(original_name) gen_path(original_name)
} else { } else {
path path
} }
} }
pub async fn new( #[axum::debug_handler]
State(state): State<Arc<crate::state::AppState>>, pub async fn new(
headers: HeaderMap, State(state): State<Arc<crate::state::AppState>>,
Query(params): Query<HashMap<String, String>>, headers: HeaderMap,
mut stream: BodyStream, Query(params): Query<HashMap<String, String>>,
) -> Result<String, StatusCode> { mut stream: BodyStream,
// require name parameter, it's used for determining the file extension ) -> Result<String, StatusCode> {
if !params.contains_key("name") { // require name parameter, it's used for determining the file extension
return Err(StatusCode::BAD_REQUEST); if !params.contains_key("name") {
} return Err(StatusCode::BAD_REQUEST);
}
// generate a path, take the name, format a url
let path = gen_path(params.get("name").unwrap()); // generate a path, take the name, format a url
let path = gen_path(params.get("name").unwrap());
let name = path
.file_name() let name = path
.and_then(OsStr::to_str) .file_name()
.unwrap_or_default() .and_then(OsStr::to_str)
.to_string(); // i hope this never happens. that would suck .unwrap_or_default()
.to_string(); // i hope this never happens. that would suck
let url = format!("http://127.0.0.1:8000/p/{}", name);
let url = format!("http://127.0.0.1:8000/p/{}", name);
// process the upload in the background so i can send the URL back immediately!
// this isn't supported by ShareX (it waits for the request to complete before handling the response) // process the upload in the background so i can send the URL back immediately!
tokio::spawn(async move { // this isn't supported by ShareX (it waits for the request to complete before handling the response)
// get the content length, and if parsing it fails, assume it's really big tokio::spawn(async move {
// it may be better to make it fully content-length not required because this feels kind of redundant // get the content length, and if parsing it fails, assume it's really big
let content_length = headers // it may be better to make it fully content-length not required because this feels kind of redundant
.get(header::CONTENT_LENGTH) let content_length = headers
.unwrap_or(&HeaderValue::from_static("")) .get(header::CONTENT_LENGTH)
.to_str() .unwrap_or(&HeaderValue::from_static(""))
.and_then(|s| Ok(usize::from_str_radix(s, 10))) .to_str()
.unwrap() .and_then(|s| Ok(usize::from_str_radix(s, 10)))
.unwrap_or(usize::MAX); .unwrap()
.unwrap_or(usize::MAX);
// if the upload size exceeds 80 MB, we skip caching!
// previously, i was going to use redis with a 500 MB max (redis's is 512MiB) // if the upload size exceeds 80 MB, we skip caching!
// with or without redis, 500 MB is still a bit much.. // previously, i was going to use redis with a 500 MB max (redis's is 512MiB)
// it could probably be read from disk before anyone could fully download it // with or without redis, 500 MB is still a bit much..
let mut use_cache = content_length < 80_000_000; // it could probably be read from disk before anyone could fully download it
let mut use_cache = content_length < 80_000_000;
println!(
"[upl] content length: {} using cache: {}", println!(
content_length, use_cache "[upl] content length: {} using cache: {}",
); content_length, use_cache
);
// create file to save upload to
let mut file = File::create(path) // create file to save upload to
.await let mut file = File::create(path)
.expect("could not open file! make sure your upload path exists"); .await
.expect("could not open file! make sure your upload path exists");
let mut data: BytesMut = if use_cache {
BytesMut::with_capacity(content_length) let mut data: BytesMut = if use_cache {
} else { BytesMut::with_capacity(content_length)
BytesMut::new() } else {
}; BytesMut::new()
};
let (tx, mut rx): (Sender<Bytes>, Receiver<Bytes>) = mpsc::channel(1);
let (tx, mut rx): (Sender<Bytes>, Receiver<Bytes>) = mpsc::channel(1);
tokio::spawn(async move {
while let Some(chunk) = rx.recv().await { tokio::spawn(async move {
println!("[io] received new chunk"); while let Some(chunk) = rx.recv().await {
file.write_all(&chunk) println!("[io] received new chunk");
.await file.write_all(&chunk)
.expect("error while writing file to disk"); .await
} .expect("error while writing file to disk");
}); }
});
while let Some(chunk) = stream.next().await {
let chunk = chunk.unwrap(); while let Some(chunk) = stream.next().await {
let chunk = chunk.unwrap();
println!("[upl] sending data to io task");
tx.send(chunk.clone()).await.unwrap(); println!("[upl] sending data to io task");
tx.send(chunk.clone()).await.unwrap();
if use_cache {
println!("[upl] receiving data into cache"); if use_cache {
if data.len() + chunk.len() > data.capacity() { println!("[upl] receiving data into buffer");
println!("[upl] too much data! the client had an invalid content-length!"); if data.len() + chunk.len() > data.capacity() {
println!("[upl] too much data! the client had an invalid content-length!");
// if we receive too much data, drop the buffer and stop using cache (it is still okay to use disk, probably)
data = BytesMut::new(); // if we receive too much data, drop the buffer and stop using cache (it is still okay to use disk, probably)
use_cache = false; data = BytesMut::new();
} else { use_cache = false;
data.put(chunk); } else {
} data.put(chunk);
} }
} }
}
let mut cache = state.cache.lock().unwrap();
let mut cache = state.cache.lock().await;
if use_cache {
println!("[upl] caching upload!!"); if use_cache {
cache.insert(name, data.freeze(), Some(Duration::from_secs(30))); println!("[upl] caching upload!!");
} cache.insert(name, data.freeze(), Some(Duration::from_secs(120)));
}); }
});
Ok(url)
} Ok(url)
}

View File

@ -1,8 +1,7 @@
use std::sync::{Mutex, Arc}; use bytes::Bytes;
use memory_cache::MemoryCache;
use bytes::Bytes; use tokio::sync::Mutex;
use memory_cache::MemoryCache;
pub struct AppState {
pub struct AppState { pub cache: Mutex<MemoryCache<String, Bytes>>
pub cache: Mutex<MemoryCache<String, Bytes>>
} }

View File

@ -7,7 +7,7 @@ use std::{
use axum::{ use axum::{
body::StreamBody, body::StreamBody,
extract::{Path, State}, extract::{Path, State},
response::{IntoResponse, Response}, response::{IntoResponse, Response}, debug_handler,
}; };
use bytes::{buf::Reader, Bytes}; use bytes::{buf::Reader, Bytes};
use hyper::StatusCode; use hyper::StatusCode;
@ -28,12 +28,11 @@ impl IntoResponse for ViewResponse {
} }
} */ } */
#[axum::debug_handler]
pub async fn view( pub async fn view(
State(state): State<Arc<crate::state::AppState>>, State(state): State<Arc<crate::state::AppState>>,
Path(original_path): Path<PathBuf>, Path(original_path): Path<PathBuf>,
) -> Response { ) -> Response {
println!("{:?}", original_path);
// (hopefully) prevent path traversal, just check for any non-file components // (hopefully) prevent path traversal, just check for any non-file components
if original_path if original_path
.components() .components()
@ -50,11 +49,11 @@ pub async fn view(
.unwrap_or_default() .unwrap_or_default()
.to_string(); .to_string();
let cache = state.cache.lock().unwrap(); let cache = state.cache.lock().await;
let cache_item = cache.get(&name); let cache_item = cache.get(&name);
if true /* cache_item.is_none() */ { if cache_item.is_none() {
let mut path = PathBuf::new(); let mut path = PathBuf::new();
path.push("uploads/"); path.push("uploads/");
path.push(name); path.push(name);