initial commit
This commit is contained in:
commit
cb3a7100e5
|
@ -0,0 +1 @@
|
|||
/target
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,22 @@
|
|||
[package]
|
||||
name = "srt2rvt"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
ac-ffmpeg = "0.18.1"
|
||||
anyhow = "1.0"
|
||||
argh = "0.1.12"
|
||||
axum = { version = "0.7", features = ["http2", "macros"] }
|
||||
base64 = "0.22.1"
|
||||
bytes = "1.5"
|
||||
crossbeam-queue = "0.3"
|
||||
dashmap = "6.1"
|
||||
futures = "0.3"
|
||||
rayon = "1.10.0"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
srt-tokio = { version = "0.4", features = ["ac-ffmpeg"] }
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
tokio-stream = "0.1"
|
||||
tower = "0.5.1"
|
||||
zstd = { version = "0.13.2", features = ["zstdmt"] }
|
|
@ -0,0 +1,22 @@
|
|||
use axum::{
|
||||
http,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
|
||||
pub struct AppError(anyhow::Error);
|
||||
|
||||
impl<T: Into<anyhow::Error>> From<T> for AppError {
|
||||
fn from(err: T) -> Self {
|
||||
AppError(err.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for AppError {
|
||||
fn into_response(self) -> Response {
|
||||
eprintln!("route fail! {:?}", self.0);
|
||||
http::StatusCode::INTERNAL_SERVER_ERROR.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
// pub type AppResult<T> = std::result::Result<T, AppError>;
|
||||
pub type AppResult2<T, E> = std::result::Result<std::result::Result<T, E>, AppError>;
|
|
@ -0,0 +1,36 @@
|
|||
use std::{io::Read, sync::mpsc::Receiver};
|
||||
|
||||
pub struct ByteReceiver {
|
||||
rx: Receiver<bytes::Bytes>,
|
||||
prev: Option<(bytes::Bytes, usize)>,
|
||||
}
|
||||
|
||||
impl ByteReceiver {
|
||||
pub fn new(rx: Receiver<bytes::Bytes>) -> Self {
|
||||
Self { rx, prev: None }
|
||||
}
|
||||
}
|
||||
|
||||
impl Read for ByteReceiver {
|
||||
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
|
||||
if let Some(ref mut prev) = self.prev {
|
||||
let limit = std::cmp::min(prev.0.len() - prev.1, buf.len());
|
||||
buf[..limit].copy_from_slice(&prev.0[prev.1..prev.1 + limit]);
|
||||
if prev.1 + limit < prev.0.len() {
|
||||
prev.1 += limit;
|
||||
} else {
|
||||
self.prev = None
|
||||
}
|
||||
Ok(limit)
|
||||
} else if let Ok(bytes) = self.rx.recv() {
|
||||
let limit = std::cmp::min(bytes.len(), buf.len());
|
||||
buf[..limit].copy_from_slice(&bytes[..limit]);
|
||||
if buf.len() < bytes.len() {
|
||||
self.prev = Some((bytes, buf.len()));
|
||||
}
|
||||
Ok(limit)
|
||||
} else {
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,326 @@
|
|||
#![feature(portable_simd)]
|
||||
|
||||
use ac_ffmpeg::{
|
||||
codec::{
|
||||
video::{scaler::Algorithm, PixelFormat, VideoDecoder, VideoFrameScaler},
|
||||
Decoder,
|
||||
},
|
||||
format::{
|
||||
demuxer::{Demuxer, DemuxerWithStreamInfo, InputFormat},
|
||||
io::IO,
|
||||
},
|
||||
};
|
||||
use anyhow::Context;
|
||||
use app_result::AppResult2;
|
||||
use argh::FromArgs;
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
routing::get,
|
||||
Json, Router,
|
||||
};
|
||||
use byte_receiver::ByteReceiver;
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use crossbeam_queue::ArrayQueue;
|
||||
use dashmap::DashMap;
|
||||
use rayon::prelude::*;
|
||||
use rvt_packet::RvtPacket;
|
||||
use srt_tokio::{SrtListener, SrtSocket};
|
||||
use std::{
|
||||
io::Read,
|
||||
str::FromStr,
|
||||
sync::{mpsc, Arc},
|
||||
};
|
||||
use tokio::time::Instant;
|
||||
use tokio_stream::StreamExt;
|
||||
use zstd_payload::ZstdPayload;
|
||||
|
||||
mod app_result;
|
||||
mod byte_receiver;
|
||||
mod rvt_packet;
|
||||
mod zstd_payload;
|
||||
|
||||
fn dump_streams(demuxer: &DemuxerWithStreamInfo<impl Read>) {
|
||||
for (index, stream) in demuxer.streams().iter().enumerate() {
|
||||
let params = stream.codec_parameters();
|
||||
|
||||
println!("Stream #{index}:");
|
||||
println!(" duration: {:?}", stream.duration().as_f64());
|
||||
let tb = stream.time_base();
|
||||
println!(" time base: {} / {}", tb.num(), tb.den());
|
||||
|
||||
if let Some(params) = params.as_audio_codec_parameters() {
|
||||
println!(" type: audio");
|
||||
println!(" codec: {}", params.decoder_name().unwrap_or("N/A"));
|
||||
println!(" sample format: {}", params.sample_format().name());
|
||||
println!(" sample rate: {}", params.sample_rate());
|
||||
println!(" channels: {}", params.channel_layout().channels());
|
||||
println!(" bitrate: {}", params.bit_rate());
|
||||
} else if let Some(params) = params.as_video_codec_parameters() {
|
||||
println!(" type: video");
|
||||
println!(" codec: {}", params.decoder_name().unwrap_or("N/A"));
|
||||
println!(" width: {}", params.width());
|
||||
println!(" height: {}", params.height());
|
||||
println!(" pixel format: {}", params.pixel_format().name());
|
||||
} else {
|
||||
println!(" type: unknown");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_input(cx: Arc<RvtContext>, rx: impl Read) -> anyhow::Result<()> {
|
||||
let io = IO::from_read_stream(rx);
|
||||
let format = InputFormat::find_by_name("mpegts").unwrap();
|
||||
let mut demuxer = Demuxer::builder()
|
||||
.input_format(Some(format))
|
||||
.build(io)?
|
||||
.find_stream_info(None)
|
||||
.map_err(|(_, err)| err)?;
|
||||
|
||||
// Dump stream info
|
||||
dump_streams(&demuxer);
|
||||
|
||||
// Create decoder
|
||||
let (video_stream_index, video_stream) = demuxer
|
||||
.streams()
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(_, s)| s.codec_parameters().is_video_codec())
|
||||
.context("no video streams? what?")?;
|
||||
let mut decoder = VideoDecoder::from_stream(video_stream)?.build()?;
|
||||
|
||||
// Create scaler
|
||||
let cp = video_stream.codec_parameters();
|
||||
let vcp = cp.as_video_codec_parameters().unwrap();
|
||||
let pix_rgba = PixelFormat::from_str("rgba").unwrap();
|
||||
let w = vcp.width();
|
||||
let h = vcp.height();
|
||||
let mut scaler = VideoFrameScaler::builder()
|
||||
.algorithm(Algorithm::FastBilinear)
|
||||
.source_width(w)
|
||||
.source_height(h)
|
||||
.source_pixel_format(vcp.pixel_format())
|
||||
.target_width(w)
|
||||
.target_height(h)
|
||||
.target_pixel_format(pix_rgba)
|
||||
.build()?;
|
||||
|
||||
// Pull from demuxer, push to decoder
|
||||
while let Some(packet) = demuxer.take()? {
|
||||
if packet.stream_index() != video_stream_index {
|
||||
// Don't care skip it
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Err(e) = decoder.push(packet) {
|
||||
println!("Err: {e}");
|
||||
}
|
||||
|
||||
while let Some(frame) = decoder.take()? {
|
||||
let frame = scaler.scale(&frame)?;
|
||||
|
||||
let planes = frame.planes();
|
||||
let data = planes.first().context("missing first frame")?.data();
|
||||
|
||||
let size = w * h * 4;
|
||||
let data = Bytes::copy_from_slice(&data[0..size]);
|
||||
|
||||
let oldest = cx.kfq.force_push(data);
|
||||
drop(oldest);
|
||||
println!("force push to kf queue len={}", cx.kfq.len());
|
||||
}
|
||||
}
|
||||
|
||||
// No more packets to pull, flush decoder.
|
||||
println!("Flushing decoder...");
|
||||
decoder.flush()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_socket(
|
||||
mut socket: SrtSocket,
|
||||
tx: mpsc::Sender<bytes::Bytes>,
|
||||
) -> anyhow::Result<()> {
|
||||
let client_desc = format!(
|
||||
"(ip_port: {}, sockid: {}, stream_id: {:?})",
|
||||
socket.settings().remote,
|
||||
socket.settings().remote_sockid.0,
|
||||
socket.settings().stream_id,
|
||||
);
|
||||
|
||||
println!("New client connected: {client_desc}");
|
||||
|
||||
let mut count = 0;
|
||||
while let Some((_inst, bytes)) = socket.try_next().await? {
|
||||
tx.send(bytes)?;
|
||||
count += 1;
|
||||
}
|
||||
|
||||
println!("Client {client_desc} disconnected, received {count:?} packets");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_srt_server(st: Arc<AppState>) -> anyhow::Result<()> {
|
||||
let (_binding, mut incoming) = SrtListener::builder().bind(3333).await?;
|
||||
|
||||
while let Some(request) = incoming.incoming().next().await {
|
||||
let mut socket = request.accept(None).await?;
|
||||
|
||||
let stream_id = if let Some(s) = &socket.settings().stream_id {
|
||||
s.clone()
|
||||
} else {
|
||||
socket.close_and_finish().await?;
|
||||
break;
|
||||
};
|
||||
|
||||
let new_st = st.clone();
|
||||
tokio::spawn(async move {
|
||||
let (tx, rx) = mpsc::channel();
|
||||
|
||||
let cx = Arc::new(RvtContext {
|
||||
kfq: ArrayQueue::new(new_st.args.playback_size),
|
||||
});
|
||||
|
||||
// add context to map
|
||||
new_st.contexts.insert(stream_id, cx.clone());
|
||||
|
||||
let f1 = tokio::task::spawn_blocking(move || handle_input(cx, ByteReceiver::new(rx)));
|
||||
let f2 = handle_socket(socket, tx);
|
||||
let (r1, r2) = tokio::join!(f1, f2);
|
||||
|
||||
if let Ok(Err(e)) = r1 {
|
||||
println!("Error in input handler: {e}");
|
||||
}
|
||||
if let Err(e) = r2 {
|
||||
println!("Error in socket handler: {e}");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct RvtContext {
|
||||
kfq: crossbeam_queue::ArrayQueue<Bytes>,
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn rgb_distance(c1: &[u8], c2: &[u8]) -> f32 {
|
||||
let d0 = c2[0] as f32 - c1[0] as f32;
|
||||
let d1 = c2[1] as f32 - c1[1] as f32;
|
||||
let d2 = c2[2] as f32 - c1[2] as f32;
|
||||
|
||||
let ss = d0 * d0 + d1 * d1 + d2 * d2;
|
||||
|
||||
ss.sqrt()
|
||||
}
|
||||
|
||||
async fn h_key_delta(
|
||||
Path((stream_id, num_delta)): Path<(String, u16)>,
|
||||
State(st): State<Arc<AppState>>,
|
||||
) -> AppResult2<Json<ZstdPayload>, &'static str> {
|
||||
let start = Instant::now();
|
||||
|
||||
let cx = if let Some(cx) = st.contexts.get(&stream_id) {
|
||||
cx
|
||||
} else {
|
||||
return Ok(Err("stream not found"));
|
||||
};
|
||||
|
||||
let cx = cx.value();
|
||||
|
||||
let kf = if let Some(kf) = cx.kfq.pop() {
|
||||
kf
|
||||
} else {
|
||||
return Ok(Err("no keyframe"));
|
||||
};
|
||||
let mut deltas = Vec::new();
|
||||
|
||||
for _ in 1..num_delta {
|
||||
let f = if let Some(f) = cx.kfq.pop() {
|
||||
f
|
||||
} else {
|
||||
break;
|
||||
};
|
||||
|
||||
let mut f = BytesMut::from(f);
|
||||
|
||||
let start2 = Instant::now();
|
||||
f.par_chunks_exact_mut(4).enumerate().for_each(|(i, rgba)| {
|
||||
// check if significant enough to care
|
||||
// println!("par proc ch {i}");
|
||||
let kfp = &kf[i..i + 4];
|
||||
if rgb_distance(rgba, kfp) < 20. {
|
||||
rgba.copy_from_slice(kfp);
|
||||
}
|
||||
});
|
||||
println!("took {}ms to calc delta", start2.elapsed().as_millis());
|
||||
|
||||
deltas.push(f.freeze());
|
||||
}
|
||||
|
||||
let mut frames = vec![kf];
|
||||
frames.append(&mut deltas);
|
||||
|
||||
let rp = RvtPacket(frames);
|
||||
let bin = rp.encode()?;
|
||||
let start3 = Instant::now();
|
||||
let zst_bin = ZstdPayload::new(bin)?;
|
||||
println!("took {}ms to compress packet", start3.elapsed().as_millis());
|
||||
|
||||
println!(
|
||||
"took {}ms to proc request. sending off now.",
|
||||
start.elapsed().as_millis()
|
||||
);
|
||||
|
||||
Ok(Ok(Json(zst_bin)))
|
||||
}
|
||||
|
||||
async fn run_http_server(st: Arc<AppState>) -> anyhow::Result<()> {
|
||||
let app = Router::new()
|
||||
.route("/", get(|| async { "hi world RVT server" }))
|
||||
.route("/:stream_id/kd/:num_delta", get(h_key_delta))
|
||||
.with_state(st);
|
||||
|
||||
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
|
||||
axum::serve(listener, app).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// srt to rvt transcode server
|
||||
#[derive(FromArgs, Debug)]
|
||||
struct Args {
|
||||
/// how many frames to store in playback buffer
|
||||
#[argh(option, short = 'p', arg_name = "playback-size")]
|
||||
playback_size: usize,
|
||||
}
|
||||
|
||||
struct AppState {
|
||||
args: Args,
|
||||
contexts: DashMap<String, Arc<RvtContext>>,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let args = argh::from_env();
|
||||
|
||||
let st = Arc::new(AppState {
|
||||
args,
|
||||
contexts: DashMap::new(),
|
||||
});
|
||||
|
||||
let f1 = run_srt_server(st.clone());
|
||||
let f2 = run_http_server(st.clone());
|
||||
let (r1, r2) = tokio::join!(f1, f2);
|
||||
|
||||
if let Err(e) = r1 {
|
||||
println!("Failure in HTTP runner: {e}");
|
||||
}
|
||||
if let Err(e) = r2 {
|
||||
println!("Failure in SRT runner: {e}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
use bytes::{Bytes, BytesMut, BufMut};
|
||||
|
||||
pub struct RvtPacket(pub Vec<Bytes>);
|
||||
|
||||
impl RvtPacket {
|
||||
pub fn encode(self) -> anyhow::Result<Bytes> {
|
||||
let mut out = BytesMut::new();
|
||||
|
||||
for f in self.0 {
|
||||
let len = f.len().try_into()?;
|
||||
out.put_u32_le(len);
|
||||
out.put(f);
|
||||
}
|
||||
|
||||
Ok(out.freeze())
|
||||
}
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
use base64::{prelude::BASE64_STANDARD, Engine};
|
||||
use bytes::{Buf, Bytes};
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct ZstdPayload {
|
||||
m: Option<()>,
|
||||
t: &'static str,
|
||||
zbase64: String,
|
||||
}
|
||||
|
||||
pub fn encode_all(source: Bytes, level: i32) -> std::io::Result<Vec<u8>> {
|
||||
let mut result = Vec::<u8>::new();
|
||||
|
||||
let mut encoder = zstd::Encoder::new(&mut result, level)?;
|
||||
encoder.set_pledged_src_size(Some(source.len() as u64))?;
|
||||
encoder.multithread(128)?;
|
||||
std::io::copy(&mut source.reader(), &mut encoder)?;
|
||||
encoder.finish()?;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
impl ZstdPayload {
|
||||
pub fn new(b: Bytes) -> anyhow::Result<Self> {
|
||||
let zst = encode_all(b, 1)?;
|
||||
let b64 = BASE64_STANDARD.encode(zst);
|
||||
|
||||
let p = Self {
|
||||
m: None,
|
||||
t: "buffer",
|
||||
zbase64: b64,
|
||||
};
|
||||
|
||||
Ok(p)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue