switch to tokio mutexes, it compiles now

minish authored 2022-12-29 00:03:14 -05:00, committed by minish
parent a2fccd1f1c
commit 9a22170dfa
6 changed files with 179 additions and 158 deletions
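Context for the title: with std::sync::Mutex, holding the guard across an .await point makes the future !Send, and tokio::spawn then refuses to compile ("future cannot be sent between threads safely"). tokio::sync::Mutex hands out a guard that may be held across awaits. A minimal standalone sketch of the difference (not code from this repo):

use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;

#[tokio::main]
async fn main() {
    // shared state behind an async-aware mutex
    let counter = Arc::new(Mutex::new(0u32));

    let c = Arc::clone(&counter);
    let task = tokio::spawn(async move {
        // holding this guard across the .await below is fine;
        // swap in std::sync::Mutex and the guard becomes !Send,
        // turning this tokio::spawn call into a compile error
        let mut n = c.lock().await;
        tokio::time::sleep(Duration::from_millis(10)).await;
        *n += 1;
    });

    task.await.unwrap();
    println!("counter = {}", *counter.lock().await);
}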

Cargo.lock (generated)

@@ -38,6 +38,7 @@ checksum = "08b108ad2665fa3f6e6a517c3d80ec3e77d224c47d605167aefaa5d7ef97fa48"
 dependencies = [
  "async-trait",
  "axum-core",
+ "axum-macros",
  "bitflags",
  "bytes",
  "futures-util",
@@ -80,6 +81,18 @@ dependencies = [
  "tower-service",
 ]
 
+[[package]]
+name = "axum-macros"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e4df0fc33ada14a338b799002f7e8657711422b25d4e16afb032708d6b185621"
+dependencies = [
+ "heck",
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
 [[package]]
 name = "bitflags"
 version = "1.3.2"
@@ -225,6 +238,12 @@ dependencies = [
  "ahash",
 ]
 
+[[package]]
+name = "heck"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2540771e65fc8cb83cd6e8a237f70c319bd5c29f78ed1084ba5d50eeac86f7f9"
+
 [[package]]
 name = "hermit-abi"
 version = "0.1.19"

Cargo.toml

@@ -6,7 +6,7 @@ edition = "2021"
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 [dependencies]
-axum = "0.6.1"
+axum = { version = "0.6.1", features = ["macros"] }
 hyper = { version = "0.14", features = ["full"] }
 tokio = { version = "1", features = ["full"] }
 tokio-util = { version = "0.7.4", features = ["full"] }

src/main.rs

@@ -1,4 +1,6 @@
-use std::sync::{Arc, Mutex};
+use std::sync::Arc;
 extern crate axum;
 use axum::{
     routing::{get, post},
@@ -6,6 +8,7 @@ use axum::{
 };
 use bytes::Bytes;
 use memory_cache::MemoryCache;
+use tokio::sync::Mutex;
 mod state;
 mod new;

src/new.rs

@@ -1,144 +1,145 @@
 use std::{collections::HashMap, ffi::OsStr, io::Read, path::PathBuf, sync::Arc, time::Duration};
 use axum::{
     extract::{BodyStream, Query, State},
     http::HeaderValue,
 };
 use bytes::{BufMut, Bytes, BytesMut};
 use hyper::{header, HeaderMap, StatusCode};
 use rand::Rng;
 use tokio::{
     fs::File,
     io::AsyncWriteExt,
     sync::mpsc::{self, Receiver, Sender},
 };
 use tokio_stream::StreamExt;
 
 // create an upload name from an original file name
 fn gen_path(original_name: &String) -> PathBuf {
     // extract extension from original name
     let extension = PathBuf::from(original_name)
         .extension()
         .and_then(OsStr::to_str)
         .unwrap_or_default()
         .to_string();
 
     // generate a 6-character alphanumeric string
     let id: String = rand::thread_rng()
         .sample_iter(&rand::distributions::Alphanumeric)
         .take(6)
         .map(char::from)
         .collect();
 
     // create the path
     let mut path = PathBuf::new();
     path.push("uploads/");
     path.push(id);
     path.set_extension(extension);
 
     // if we're already using it, try again
     if path.exists() {
         gen_path(original_name)
     } else {
         path
     }
 }
 
+#[axum::debug_handler]
 pub async fn new(
     State(state): State<Arc<crate::state::AppState>>,
     headers: HeaderMap,
     Query(params): Query<HashMap<String, String>>,
     mut stream: BodyStream,
 ) -> Result<String, StatusCode> {
     // require name parameter, it's used for determining the file extension
     if !params.contains_key("name") {
         return Err(StatusCode::BAD_REQUEST);
     }
 
     // generate a path, take the name, format a url
     let path = gen_path(params.get("name").unwrap());
     let name = path
         .file_name()
         .and_then(OsStr::to_str)
         .unwrap_or_default()
         .to_string(); // i hope this never happens. that would suck
     let url = format!("http://127.0.0.1:8000/p/{}", name);
 
     // process the upload in the background so i can send the URL back immediately!
     // this isn't supported by ShareX (it waits for the request to complete before handling the response)
     tokio::spawn(async move {
         // get the content length, and if parsing it fails, assume it's really big
         // it may be better to make it fully content-length not required because this feels kind of redundant
         let content_length = headers
             .get(header::CONTENT_LENGTH)
             .unwrap_or(&HeaderValue::from_static(""))
             .to_str()
             .and_then(|s| Ok(usize::from_str_radix(s, 10)))
             .unwrap()
             .unwrap_or(usize::MAX);
 
         // if the upload size exceeds 80 MB, we skip caching!
         // previously, i was going to use redis with a 500 MB max (redis's is 512MiB)
         // with or without redis, 500 MB is still a bit much..
         // it could probably be read from disk before anyone could fully download it
         let mut use_cache = content_length < 80_000_000;
 
         println!(
             "[upl] content length: {} using cache: {}",
             content_length, use_cache
         );
 
         // create file to save upload to
         let mut file = File::create(path)
             .await
             .expect("could not open file! make sure your upload path exists");
 
         let mut data: BytesMut = if use_cache {
             BytesMut::with_capacity(content_length)
         } else {
             BytesMut::new()
         };
 
         let (tx, mut rx): (Sender<Bytes>, Receiver<Bytes>) = mpsc::channel(1);
 
         tokio::spawn(async move {
             while let Some(chunk) = rx.recv().await {
                 println!("[io] received new chunk");
                 file.write_all(&chunk)
                     .await
                     .expect("error while writing file to disk");
             }
         });
 
         while let Some(chunk) = stream.next().await {
             let chunk = chunk.unwrap();
 
             println!("[upl] sending data to io task");
             tx.send(chunk.clone()).await.unwrap();
 
             if use_cache {
-                println!("[upl] receiving data into cache");
+                println!("[upl] receiving data into buffer");
                 if data.len() + chunk.len() > data.capacity() {
                     println!("[upl] too much data! the client had an invalid content-length!");
                     // if we receive too much data, drop the buffer and stop using cache (it is still okay to use disk, probably)
                     data = BytesMut::new();
                     use_cache = false;
                 } else {
                     data.put(chunk);
                 }
             }
         }
 
-        let mut cache = state.cache.lock().unwrap();
+        let mut cache = state.cache.lock().await;
 
         if use_cache {
             println!("[upl] caching upload!!");
-            cache.insert(name, data.freeze(), Some(Duration::from_secs(30)));
+            cache.insert(name, data.freeze(), Some(Duration::from_secs(120)));
         }
     });
 
     Ok(url)
 }
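The upload handler above hands chunks to a second task over an mpsc channel so the receive loop is never blocked on disk writes; the capacity of 1 means a slow disk backpressures the sender. A reduced standalone sketch of that pattern (file name and chunk contents are placeholders):

use tokio::{io::AsyncWriteExt, sync::mpsc};

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel::<Vec<u8>>(1);

    // writer task: drains the channel onto disk
    let writer = tokio::spawn(async move {
        let mut file = tokio::fs::File::create("out.bin").await.unwrap();
        while let Some(chunk) = rx.recv().await {
            file.write_all(&chunk).await.unwrap();
        }
    });

    // producer side: stand-in for the BodyStream loop above
    for i in 0..3u8 {
        tx.send(vec![i; 4]).await.unwrap();
    }
    drop(tx); // closing the channel lets the writer task exit
    writer.await.unwrap();
}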

src/state.rs

@@ -1,8 +1,7 @@
-use std::sync::{Mutex, Arc};
 use bytes::Bytes;
 use memory_cache::MemoryCache;
+use tokio::sync::Mutex;
 
 pub struct AppState {
     pub cache: Mutex<MemoryCache<String, Bytes>>
 }
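With AppState holding a tokio::sync::Mutex, handlers acquire the lock with .lock().await rather than .lock().unwrap() (the tokio mutex is acquired asynchronously and has no poisoning). A self-contained sketch of the pattern against the axum 0.6 API; the counter route is illustrative, not this repo's:

use std::sync::Arc;
use axum::{extract::State, routing::get, Router};
use tokio::sync::Mutex;

struct AppState {
    hits: Mutex<u64>,
}

async fn hits(State(state): State<Arc<AppState>>) -> String {
    // .await instead of .unwrap(): waiting for the lock yields to the runtime
    let mut n = state.hits.lock().await;
    *n += 1;
    n.to_string()
}

#[tokio::main]
async fn main() {
    let state = Arc::new(AppState { hits: Mutex::new(0) });
    let app = Router::new().route("/", get(hits)).with_state(state);
    axum::Server::bind(&"127.0.0.1:8000".parse().unwrap())
        .serve(app.into_make_service())
        .await
        .unwrap();
}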

src/view.rs

@@ -7,7 +7,7 @@ use std::{
 use axum::{
     body::StreamBody,
     extract::{Path, State},
-    response::{IntoResponse, Response},
+    response::{IntoResponse, Response}, debug_handler,
 };
 use bytes::{buf::Reader, Bytes};
 use hyper::StatusCode;
@@ -28,12 +28,11 @@ impl IntoResponse for ViewResponse {
     }
 } */
 
+#[axum::debug_handler]
 pub async fn view(
     State(state): State<Arc<crate::state::AppState>>,
     Path(original_path): Path<PathBuf>,
 ) -> Response {
-    println!("{:?}", original_path);
-
     // (hopefully) prevent path traversal, just check for any non-file components
     if original_path
         .components()
@@ -50,11 +49,11 @@ pub async fn view(
         .unwrap_or_default()
         .to_string();
 
-    let cache = state.cache.lock().unwrap();
+    let cache = state.cache.lock().await;
     let cache_item = cache.get(&name);
 
-    if true /* cache_item.is_none() */ {
+    if cache_item.is_none() {
         let mut path = PathBuf::new();
         path.push("uploads/");
         path.push(name);
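The view handler now consults the cache for real (the `if true` bypass is gone) and, on a miss, presumably streams the file from disk. A hypothetical sketch of that fallback using ReaderStream from tokio-util (already a dependency); the helper name and Option<Bytes> argument are illustrative, not this repo's signatures:

use std::path::PathBuf;
use axum::{body::StreamBody, response::{IntoResponse, Response}};
use bytes::Bytes;
use hyper::StatusCode;
use tokio_util::io::ReaderStream;

async fn serve(cached: Option<Bytes>, path: PathBuf) -> Response {
    match cached {
        // cache hit: the upload is still in memory
        Some(data) => data.into_response(),
        // cache miss: stream from disk without buffering the whole file
        None => match tokio::fs::File::open(path).await {
            Ok(file) => StreamBody::new(ReaderStream::new(file)).into_response(),
            Err(_) => StatusCode::NOT_FOUND.into_response(),
        },
    }
}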