switch to tokio mutexes, it compiles now
This commit is contained in:
parent
a2fccd1f1c
commit
9a22170dfa
|
@ -38,6 +38,7 @@ checksum = "08b108ad2665fa3f6e6a517c3d80ec3e77d224c47d605167aefaa5d7ef97fa48"
|
|||
dependencies = [
|
||||
"async-trait",
|
||||
"axum-core",
|
||||
"axum-macros",
|
||||
"bitflags",
|
||||
"bytes",
|
||||
"futures-util",
|
||||
|
@ -80,6 +81,18 @@ dependencies = [
|
|||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum-macros"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e4df0fc33ada14a338b799002f7e8657711422b25d4e16afb032708d6b185621"
|
||||
dependencies = [
|
||||
"heck",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
version = "1.3.2"
|
||||
|
@ -225,6 +238,12 @@ dependencies = [
|
|||
"ahash",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "heck"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2540771e65fc8cb83cd6e8a237f70c319bd5c29f78ed1084ba5d50eeac86f7f9"
|
||||
|
||||
[[package]]
|
||||
name = "hermit-abi"
|
||||
version = "0.1.19"
|
||||
|
|
|
@ -6,7 +6,7 @@ edition = "2021"
|
|||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
axum = "0.6.1"
|
||||
axum = { version = "0.6.1", features = ["macros"] }
|
||||
hyper = { version = "0.14", features = ["full"] }
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
tokio-util = { version = "0.7.4", features = ["full"] }
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
use std::sync::{Arc, Mutex};
|
||||
use std::sync::Arc;
|
||||
|
||||
extern crate axum;
|
||||
|
||||
use axum::{
|
||||
routing::{get, post},
|
||||
|
@ -6,6 +8,7 @@ use axum::{
|
|||
};
|
||||
use bytes::Bytes;
|
||||
use memory_cache::MemoryCache;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
mod state;
|
||||
mod new;
|
||||
|
|
289
src/new.rs
289
src/new.rs
|
@ -1,144 +1,145 @@
|
|||
use std::{collections::HashMap, ffi::OsStr, io::Read, path::PathBuf, sync::Arc, time::Duration};
|
||||
|
||||
use axum::{
|
||||
extract::{BodyStream, Query, State},
|
||||
http::HeaderValue,
|
||||
};
|
||||
use bytes::{BufMut, Bytes, BytesMut};
|
||||
use hyper::{header, HeaderMap, StatusCode};
|
||||
use rand::Rng;
|
||||
use tokio::{
|
||||
fs::File,
|
||||
io::AsyncWriteExt,
|
||||
sync::mpsc::{self, Receiver, Sender},
|
||||
};
|
||||
use tokio_stream::StreamExt;
|
||||
|
||||
// create an upload name from an original file name
|
||||
fn gen_path(original_name: &String) -> PathBuf {
|
||||
// extract extension from original name
|
||||
let extension = PathBuf::from(original_name)
|
||||
.extension()
|
||||
.and_then(OsStr::to_str)
|
||||
.unwrap_or_default()
|
||||
.to_string();
|
||||
|
||||
// generate a 6-character alphanumeric string
|
||||
let id: String = rand::thread_rng()
|
||||
.sample_iter(&rand::distributions::Alphanumeric)
|
||||
.take(6)
|
||||
.map(char::from)
|
||||
.collect();
|
||||
|
||||
// create the path
|
||||
let mut path = PathBuf::new();
|
||||
path.push("uploads/");
|
||||
path.push(id);
|
||||
path.set_extension(extension);
|
||||
|
||||
// if we're already using it, try again
|
||||
if path.exists() {
|
||||
gen_path(original_name)
|
||||
} else {
|
||||
path
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn new(
|
||||
State(state): State<Arc<crate::state::AppState>>,
|
||||
headers: HeaderMap,
|
||||
Query(params): Query<HashMap<String, String>>,
|
||||
mut stream: BodyStream,
|
||||
) -> Result<String, StatusCode> {
|
||||
// require name parameter, it's used for determining the file extension
|
||||
if !params.contains_key("name") {
|
||||
return Err(StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
// generate a path, take the name, format a url
|
||||
let path = gen_path(params.get("name").unwrap());
|
||||
|
||||
let name = path
|
||||
.file_name()
|
||||
.and_then(OsStr::to_str)
|
||||
.unwrap_or_default()
|
||||
.to_string(); // i hope this never happens. that would suck
|
||||
|
||||
let url = format!("http://127.0.0.1:8000/p/{}", name);
|
||||
|
||||
// process the upload in the background so i can send the URL back immediately!
|
||||
// this isn't supported by ShareX (it waits for the request to complete before handling the response)
|
||||
tokio::spawn(async move {
|
||||
// get the content length, and if parsing it fails, assume it's really big
|
||||
// it may be better to make it fully content-length not required because this feels kind of redundant
|
||||
let content_length = headers
|
||||
.get(header::CONTENT_LENGTH)
|
||||
.unwrap_or(&HeaderValue::from_static(""))
|
||||
.to_str()
|
||||
.and_then(|s| Ok(usize::from_str_radix(s, 10)))
|
||||
.unwrap()
|
||||
.unwrap_or(usize::MAX);
|
||||
|
||||
// if the upload size exceeds 80 MB, we skip caching!
|
||||
// previously, i was going to use redis with a 500 MB max (redis's is 512MiB)
|
||||
// with or without redis, 500 MB is still a bit much..
|
||||
// it could probably be read from disk before anyone could fully download it
|
||||
let mut use_cache = content_length < 80_000_000;
|
||||
|
||||
println!(
|
||||
"[upl] content length: {} using cache: {}",
|
||||
content_length, use_cache
|
||||
);
|
||||
|
||||
// create file to save upload to
|
||||
let mut file = File::create(path)
|
||||
.await
|
||||
.expect("could not open file! make sure your upload path exists");
|
||||
|
||||
let mut data: BytesMut = if use_cache {
|
||||
BytesMut::with_capacity(content_length)
|
||||
} else {
|
||||
BytesMut::new()
|
||||
};
|
||||
|
||||
let (tx, mut rx): (Sender<Bytes>, Receiver<Bytes>) = mpsc::channel(1);
|
||||
|
||||
tokio::spawn(async move {
|
||||
while let Some(chunk) = rx.recv().await {
|
||||
println!("[io] received new chunk");
|
||||
file.write_all(&chunk)
|
||||
.await
|
||||
.expect("error while writing file to disk");
|
||||
}
|
||||
});
|
||||
|
||||
while let Some(chunk) = stream.next().await {
|
||||
let chunk = chunk.unwrap();
|
||||
|
||||
println!("[upl] sending data to io task");
|
||||
tx.send(chunk.clone()).await.unwrap();
|
||||
|
||||
if use_cache {
|
||||
println!("[upl] receiving data into cache");
|
||||
if data.len() + chunk.len() > data.capacity() {
|
||||
println!("[upl] too much data! the client had an invalid content-length!");
|
||||
|
||||
// if we receive too much data, drop the buffer and stop using cache (it is still okay to use disk, probably)
|
||||
data = BytesMut::new();
|
||||
use_cache = false;
|
||||
} else {
|
||||
data.put(chunk);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut cache = state.cache.lock().unwrap();
|
||||
|
||||
if use_cache {
|
||||
println!("[upl] caching upload!!");
|
||||
cache.insert(name, data.freeze(), Some(Duration::from_secs(30)));
|
||||
}
|
||||
});
|
||||
|
||||
Ok(url)
|
||||
}
|
||||
use std::{collections::HashMap, ffi::OsStr, io::Read, path::PathBuf, sync::Arc, time::Duration};
|
||||
|
||||
use axum::{
|
||||
extract::{BodyStream, Query, State},
|
||||
http::HeaderValue,
|
||||
};
|
||||
use bytes::{BufMut, Bytes, BytesMut};
|
||||
use hyper::{header, HeaderMap, StatusCode};
|
||||
use rand::Rng;
|
||||
use tokio::{
|
||||
fs::File,
|
||||
io::AsyncWriteExt,
|
||||
sync::mpsc::{self, Receiver, Sender},
|
||||
};
|
||||
use tokio_stream::StreamExt;
|
||||
|
||||
// create an upload name from an original file name
|
||||
fn gen_path(original_name: &String) -> PathBuf {
|
||||
// extract extension from original name
|
||||
let extension = PathBuf::from(original_name)
|
||||
.extension()
|
||||
.and_then(OsStr::to_str)
|
||||
.unwrap_or_default()
|
||||
.to_string();
|
||||
|
||||
// generate a 6-character alphanumeric string
|
||||
let id: String = rand::thread_rng()
|
||||
.sample_iter(&rand::distributions::Alphanumeric)
|
||||
.take(6)
|
||||
.map(char::from)
|
||||
.collect();
|
||||
|
||||
// create the path
|
||||
let mut path = PathBuf::new();
|
||||
path.push("uploads/");
|
||||
path.push(id);
|
||||
path.set_extension(extension);
|
||||
|
||||
// if we're already using it, try again
|
||||
if path.exists() {
|
||||
gen_path(original_name)
|
||||
} else {
|
||||
path
|
||||
}
|
||||
}
|
||||
|
||||
#[axum::debug_handler]
|
||||
pub async fn new(
|
||||
State(state): State<Arc<crate::state::AppState>>,
|
||||
headers: HeaderMap,
|
||||
Query(params): Query<HashMap<String, String>>,
|
||||
mut stream: BodyStream,
|
||||
) -> Result<String, StatusCode> {
|
||||
// require name parameter, it's used for determining the file extension
|
||||
if !params.contains_key("name") {
|
||||
return Err(StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
// generate a path, take the name, format a url
|
||||
let path = gen_path(params.get("name").unwrap());
|
||||
|
||||
let name = path
|
||||
.file_name()
|
||||
.and_then(OsStr::to_str)
|
||||
.unwrap_or_default()
|
||||
.to_string(); // i hope this never happens. that would suck
|
||||
|
||||
let url = format!("http://127.0.0.1:8000/p/{}", name);
|
||||
|
||||
// process the upload in the background so i can send the URL back immediately!
|
||||
// this isn't supported by ShareX (it waits for the request to complete before handling the response)
|
||||
tokio::spawn(async move {
|
||||
// get the content length, and if parsing it fails, assume it's really big
|
||||
// it may be better to make it fully content-length not required because this feels kind of redundant
|
||||
let content_length = headers
|
||||
.get(header::CONTENT_LENGTH)
|
||||
.unwrap_or(&HeaderValue::from_static(""))
|
||||
.to_str()
|
||||
.and_then(|s| Ok(usize::from_str_radix(s, 10)))
|
||||
.unwrap()
|
||||
.unwrap_or(usize::MAX);
|
||||
|
||||
// if the upload size exceeds 80 MB, we skip caching!
|
||||
// previously, i was going to use redis with a 500 MB max (redis's is 512MiB)
|
||||
// with or without redis, 500 MB is still a bit much..
|
||||
// it could probably be read from disk before anyone could fully download it
|
||||
let mut use_cache = content_length < 80_000_000;
|
||||
|
||||
println!(
|
||||
"[upl] content length: {} using cache: {}",
|
||||
content_length, use_cache
|
||||
);
|
||||
|
||||
// create file to save upload to
|
||||
let mut file = File::create(path)
|
||||
.await
|
||||
.expect("could not open file! make sure your upload path exists");
|
||||
|
||||
let mut data: BytesMut = if use_cache {
|
||||
BytesMut::with_capacity(content_length)
|
||||
} else {
|
||||
BytesMut::new()
|
||||
};
|
||||
|
||||
let (tx, mut rx): (Sender<Bytes>, Receiver<Bytes>) = mpsc::channel(1);
|
||||
|
||||
tokio::spawn(async move {
|
||||
while let Some(chunk) = rx.recv().await {
|
||||
println!("[io] received new chunk");
|
||||
file.write_all(&chunk)
|
||||
.await
|
||||
.expect("error while writing file to disk");
|
||||
}
|
||||
});
|
||||
|
||||
while let Some(chunk) = stream.next().await {
|
||||
let chunk = chunk.unwrap();
|
||||
|
||||
println!("[upl] sending data to io task");
|
||||
tx.send(chunk.clone()).await.unwrap();
|
||||
|
||||
if use_cache {
|
||||
println!("[upl] receiving data into buffer");
|
||||
if data.len() + chunk.len() > data.capacity() {
|
||||
println!("[upl] too much data! the client had an invalid content-length!");
|
||||
|
||||
// if we receive too much data, drop the buffer and stop using cache (it is still okay to use disk, probably)
|
||||
data = BytesMut::new();
|
||||
use_cache = false;
|
||||
} else {
|
||||
data.put(chunk);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut cache = state.cache.lock().await;
|
||||
|
||||
if use_cache {
|
||||
println!("[upl] caching upload!!");
|
||||
cache.insert(name, data.freeze(), Some(Duration::from_secs(120)));
|
||||
}
|
||||
});
|
||||
|
||||
Ok(url)
|
||||
}
|
||||
|
|
13
src/state.rs
13
src/state.rs
|
@ -1,8 +1,7 @@
|
|||
use std::sync::{Mutex, Arc};
|
||||
|
||||
use bytes::Bytes;
|
||||
use memory_cache::MemoryCache;
|
||||
|
||||
pub struct AppState {
|
||||
pub cache: Mutex<MemoryCache<String, Bytes>>
|
||||
use bytes::Bytes;
|
||||
use memory_cache::MemoryCache;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
pub struct AppState {
|
||||
pub cache: Mutex<MemoryCache<String, Bytes>>
|
||||
}
|
|
@ -7,7 +7,7 @@ use std::{
|
|||
use axum::{
|
||||
body::StreamBody,
|
||||
extract::{Path, State},
|
||||
response::{IntoResponse, Response},
|
||||
response::{IntoResponse, Response}, debug_handler,
|
||||
};
|
||||
use bytes::{buf::Reader, Bytes};
|
||||
use hyper::StatusCode;
|
||||
|
@ -28,12 +28,11 @@ impl IntoResponse for ViewResponse {
|
|||
}
|
||||
} */
|
||||
|
||||
#[axum::debug_handler]
|
||||
pub async fn view(
|
||||
State(state): State<Arc<crate::state::AppState>>,
|
||||
Path(original_path): Path<PathBuf>,
|
||||
) -> Response {
|
||||
println!("{:?}", original_path);
|
||||
|
||||
// (hopefully) prevent path traversal, just check for any non-file components
|
||||
if original_path
|
||||
.components()
|
||||
|
@ -50,11 +49,11 @@ pub async fn view(
|
|||
.unwrap_or_default()
|
||||
.to_string();
|
||||
|
||||
let cache = state.cache.lock().unwrap();
|
||||
let cache = state.cache.lock().await;
|
||||
|
||||
let cache_item = cache.get(&name);
|
||||
|
||||
if true /* cache_item.is_none() */ {
|
||||
if cache_item.is_none() {
|
||||
let mut path = PathBuf::new();
|
||||
path.push("uploads/");
|
||||
path.push(name);
|
||||
|
|
Loading…
Reference in New Issue