initial commit

This commit is contained in:
minish 2024-11-27 21:23:18 -05:00
commit cb3a7100e5
Signed by: min
GPG Key ID: FEECFF24EF0CE9E9
8 changed files with 2220 additions and 0 deletions

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
/target

1759
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

22
Cargo.toml Normal file
View File

@ -0,0 +1,22 @@
[package]
name = "srt2rvt"
version = "0.1.0"
edition = "2021"
[dependencies]
ac-ffmpeg = "0.18.1"
anyhow = "1.0"
argh = "0.1.12"
axum = { version = "0.7", features = ["http2", "macros"] }
base64 = "0.22.1"
bytes = "1.5"
crossbeam-queue = "0.3"
dashmap = "6.1"
futures = "0.3"
rayon = "1.10.0"
serde = { version = "1.0", features = ["derive"] }
srt-tokio = { version = "0.4", features = ["ac-ffmpeg"] }
tokio = { version = "1", features = ["full"] }
tokio-stream = "0.1"
tower = "0.5.1"
zstd = { version = "0.13.2", features = ["zstdmt"] }

22
src/app_result.rs Normal file
View File

@ -0,0 +1,22 @@
use axum::{
http,
response::{IntoResponse, Response},
};
pub struct AppError(anyhow::Error);
impl<T: Into<anyhow::Error>> From<T> for AppError {
fn from(err: T) -> Self {
AppError(err.into())
}
}
impl IntoResponse for AppError {
fn into_response(self) -> Response {
eprintln!("route fail! {:?}", self.0);
http::StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
}
// pub type AppResult<T> = std::result::Result<T, AppError>;
pub type AppResult2<T, E> = std::result::Result<std::result::Result<T, E>, AppError>;

36
src/byte_receiver.rs Normal file
View File

@ -0,0 +1,36 @@
use std::{io::Read, sync::mpsc::Receiver};
pub struct ByteReceiver {
rx: Receiver<bytes::Bytes>,
prev: Option<(bytes::Bytes, usize)>,
}
impl ByteReceiver {
pub fn new(rx: Receiver<bytes::Bytes>) -> Self {
Self { rx, prev: None }
}
}
impl Read for ByteReceiver {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
if let Some(ref mut prev) = self.prev {
let limit = std::cmp::min(prev.0.len() - prev.1, buf.len());
buf[..limit].copy_from_slice(&prev.0[prev.1..prev.1 + limit]);
if prev.1 + limit < prev.0.len() {
prev.1 += limit;
} else {
self.prev = None
}
Ok(limit)
} else if let Ok(bytes) = self.rx.recv() {
let limit = std::cmp::min(bytes.len(), buf.len());
buf[..limit].copy_from_slice(&bytes[..limit]);
if buf.len() < bytes.len() {
self.prev = Some((bytes, buf.len()));
}
Ok(limit)
} else {
Ok(0)
}
}
}

326
src/main.rs Normal file
View File

@ -0,0 +1,326 @@
#![feature(portable_simd)]
use ac_ffmpeg::{
codec::{
video::{scaler::Algorithm, PixelFormat, VideoDecoder, VideoFrameScaler},
Decoder,
},
format::{
demuxer::{Demuxer, DemuxerWithStreamInfo, InputFormat},
io::IO,
},
};
use anyhow::Context;
use app_result::AppResult2;
use argh::FromArgs;
use axum::{
extract::{Path, State},
routing::get,
Json, Router,
};
use byte_receiver::ByteReceiver;
use bytes::{Bytes, BytesMut};
use crossbeam_queue::ArrayQueue;
use dashmap::DashMap;
use rayon::prelude::*;
use rvt_packet::RvtPacket;
use srt_tokio::{SrtListener, SrtSocket};
use std::{
io::Read,
str::FromStr,
sync::{mpsc, Arc},
};
use tokio::time::Instant;
use tokio_stream::StreamExt;
use zstd_payload::ZstdPayload;
mod app_result;
mod byte_receiver;
mod rvt_packet;
mod zstd_payload;
fn dump_streams(demuxer: &DemuxerWithStreamInfo<impl Read>) {
for (index, stream) in demuxer.streams().iter().enumerate() {
let params = stream.codec_parameters();
println!("Stream #{index}:");
println!(" duration: {:?}", stream.duration().as_f64());
let tb = stream.time_base();
println!(" time base: {} / {}", tb.num(), tb.den());
if let Some(params) = params.as_audio_codec_parameters() {
println!(" type: audio");
println!(" codec: {}", params.decoder_name().unwrap_or("N/A"));
println!(" sample format: {}", params.sample_format().name());
println!(" sample rate: {}", params.sample_rate());
println!(" channels: {}", params.channel_layout().channels());
println!(" bitrate: {}", params.bit_rate());
} else if let Some(params) = params.as_video_codec_parameters() {
println!(" type: video");
println!(" codec: {}", params.decoder_name().unwrap_or("N/A"));
println!(" width: {}", params.width());
println!(" height: {}", params.height());
println!(" pixel format: {}", params.pixel_format().name());
} else {
println!(" type: unknown");
}
}
}
fn handle_input(cx: Arc<RvtContext>, rx: impl Read) -> anyhow::Result<()> {
let io = IO::from_read_stream(rx);
let format = InputFormat::find_by_name("mpegts").unwrap();
let mut demuxer = Demuxer::builder()
.input_format(Some(format))
.build(io)?
.find_stream_info(None)
.map_err(|(_, err)| err)?;
// Dump stream info
dump_streams(&demuxer);
// Create decoder
let (video_stream_index, video_stream) = demuxer
.streams()
.iter()
.enumerate()
.find(|(_, s)| s.codec_parameters().is_video_codec())
.context("no video streams? what?")?;
let mut decoder = VideoDecoder::from_stream(video_stream)?.build()?;
// Create scaler
let cp = video_stream.codec_parameters();
let vcp = cp.as_video_codec_parameters().unwrap();
let pix_rgba = PixelFormat::from_str("rgba").unwrap();
let w = vcp.width();
let h = vcp.height();
let mut scaler = VideoFrameScaler::builder()
.algorithm(Algorithm::FastBilinear)
.source_width(w)
.source_height(h)
.source_pixel_format(vcp.pixel_format())
.target_width(w)
.target_height(h)
.target_pixel_format(pix_rgba)
.build()?;
// Pull from demuxer, push to decoder
while let Some(packet) = demuxer.take()? {
if packet.stream_index() != video_stream_index {
// Don't care skip it
continue;
}
if let Err(e) = decoder.push(packet) {
println!("Err: {e}");
}
while let Some(frame) = decoder.take()? {
let frame = scaler.scale(&frame)?;
let planes = frame.planes();
let data = planes.first().context("missing first frame")?.data();
let size = w * h * 4;
let data = Bytes::copy_from_slice(&data[0..size]);
let oldest = cx.kfq.force_push(data);
drop(oldest);
println!("force push to kf queue len={}", cx.kfq.len());
}
}
// No more packets to pull, flush decoder.
println!("Flushing decoder...");
decoder.flush()?;
Ok(())
}
async fn handle_socket(
mut socket: SrtSocket,
tx: mpsc::Sender<bytes::Bytes>,
) -> anyhow::Result<()> {
let client_desc = format!(
"(ip_port: {}, sockid: {}, stream_id: {:?})",
socket.settings().remote,
socket.settings().remote_sockid.0,
socket.settings().stream_id,
);
println!("New client connected: {client_desc}");
let mut count = 0;
while let Some((_inst, bytes)) = socket.try_next().await? {
tx.send(bytes)?;
count += 1;
}
println!("Client {client_desc} disconnected, received {count:?} packets");
Ok(())
}
async fn run_srt_server(st: Arc<AppState>) -> anyhow::Result<()> {
let (_binding, mut incoming) = SrtListener::builder().bind(3333).await?;
while let Some(request) = incoming.incoming().next().await {
let mut socket = request.accept(None).await?;
let stream_id = if let Some(s) = &socket.settings().stream_id {
s.clone()
} else {
socket.close_and_finish().await?;
break;
};
let new_st = st.clone();
tokio::spawn(async move {
let (tx, rx) = mpsc::channel();
let cx = Arc::new(RvtContext {
kfq: ArrayQueue::new(new_st.args.playback_size),
});
// add context to map
new_st.contexts.insert(stream_id, cx.clone());
let f1 = tokio::task::spawn_blocking(move || handle_input(cx, ByteReceiver::new(rx)));
let f2 = handle_socket(socket, tx);
let (r1, r2) = tokio::join!(f1, f2);
if let Ok(Err(e)) = r1 {
println!("Error in input handler: {e}");
}
if let Err(e) = r2 {
println!("Error in socket handler: {e}");
}
});
}
Ok(())
}
struct RvtContext {
kfq: crossbeam_queue::ArrayQueue<Bytes>,
}
#[inline(always)]
fn rgb_distance(c1: &[u8], c2: &[u8]) -> f32 {
let d0 = c2[0] as f32 - c1[0] as f32;
let d1 = c2[1] as f32 - c1[1] as f32;
let d2 = c2[2] as f32 - c1[2] as f32;
let ss = d0 * d0 + d1 * d1 + d2 * d2;
ss.sqrt()
}
async fn h_key_delta(
Path((stream_id, num_delta)): Path<(String, u16)>,
State(st): State<Arc<AppState>>,
) -> AppResult2<Json<ZstdPayload>, &'static str> {
let start = Instant::now();
let cx = if let Some(cx) = st.contexts.get(&stream_id) {
cx
} else {
return Ok(Err("stream not found"));
};
let cx = cx.value();
let kf = if let Some(kf) = cx.kfq.pop() {
kf
} else {
return Ok(Err("no keyframe"));
};
let mut deltas = Vec::new();
for _ in 1..num_delta {
let f = if let Some(f) = cx.kfq.pop() {
f
} else {
break;
};
let mut f = BytesMut::from(f);
let start2 = Instant::now();
f.par_chunks_exact_mut(4).enumerate().for_each(|(i, rgba)| {
// check if significant enough to care
// println!("par proc ch {i}");
let kfp = &kf[i..i + 4];
if rgb_distance(rgba, kfp) < 20. {
rgba.copy_from_slice(kfp);
}
});
println!("took {}ms to calc delta", start2.elapsed().as_millis());
deltas.push(f.freeze());
}
let mut frames = vec![kf];
frames.append(&mut deltas);
let rp = RvtPacket(frames);
let bin = rp.encode()?;
let start3 = Instant::now();
let zst_bin = ZstdPayload::new(bin)?;
println!("took {}ms to compress packet", start3.elapsed().as_millis());
println!(
"took {}ms to proc request. sending off now.",
start.elapsed().as_millis()
);
Ok(Ok(Json(zst_bin)))
}
async fn run_http_server(st: Arc<AppState>) -> anyhow::Result<()> {
let app = Router::new()
.route("/", get(|| async { "hi world RVT server" }))
.route("/:stream_id/kd/:num_delta", get(h_key_delta))
.with_state(st);
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
axum::serve(listener, app).await?;
Ok(())
}
/// srt to rvt transcode server
#[derive(FromArgs, Debug)]
struct Args {
/// how many frames to store in playback buffer
#[argh(option, short = 'p', arg_name = "playback-size")]
playback_size: usize,
}
struct AppState {
args: Args,
contexts: DashMap<String, Arc<RvtContext>>,
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let args = argh::from_env();
let st = Arc::new(AppState {
args,
contexts: DashMap::new(),
});
let f1 = run_srt_server(st.clone());
let f2 = run_http_server(st.clone());
let (r1, r2) = tokio::join!(f1, f2);
if let Err(e) = r1 {
println!("Failure in HTTP runner: {e}");
}
if let Err(e) = r2 {
println!("Failure in SRT runner: {e}");
}
Ok(())
}

17
src/rvt_packet.rs Normal file
View File

@ -0,0 +1,17 @@
use bytes::{Bytes, BytesMut, BufMut};
pub struct RvtPacket(pub Vec<Bytes>);
impl RvtPacket {
pub fn encode(self) -> anyhow::Result<Bytes> {
let mut out = BytesMut::new();
for f in self.0 {
let len = f.len().try_into()?;
out.put_u32_le(len);
out.put(f);
}
Ok(out.freeze())
}
}

37
src/zstd_payload.rs Normal file
View File

@ -0,0 +1,37 @@
use base64::{prelude::BASE64_STANDARD, Engine};
use bytes::{Buf, Bytes};
use serde::Serialize;
#[derive(Serialize)]
pub struct ZstdPayload {
m: Option<()>,
t: &'static str,
zbase64: String,
}
pub fn encode_all(source: Bytes, level: i32) -> std::io::Result<Vec<u8>> {
let mut result = Vec::<u8>::new();
let mut encoder = zstd::Encoder::new(&mut result, level)?;
encoder.set_pledged_src_size(Some(source.len() as u64))?;
encoder.multithread(128)?;
std::io::copy(&mut source.reader(), &mut encoder)?;
encoder.finish()?;
Ok(result)
}
impl ZstdPayload {
pub fn new(b: Bytes) -> anyhow::Result<Self> {
let zst = encode_all(b, 1)?;
let b64 = BASE64_STANDARD.encode(zst);
let p = Self {
m: None,
t: "buffer",
zbase64: b64,
};
Ok(p)
}
}