fluffle/src/lib.rs
2025-03-15 23:09:34 -07:00

195 lines
5.7 KiB
Rust

use std::{
io::{Read, Write},
net::{IpAddr, SocketAddr, TcpListener, TcpStream},
sync::{Arc, Mutex},
thread,
};
use bytes::{Bytes, BytesMut};
use http::{Method, Request, Response, Version};
#[cfg(feature = "https")]
use rustls::{
pki_types::{pem::PemObject, CertificateDer, PrivateKeyDer},
ServerConfig, ServerConnection, StreamOwned,
};
pub use http;
/// An accepted client connection, either plain TCP or (with the `https` feature) TLS over TCP.
///
/// The connection-handling code reads and writes through this type (via the `Read` / `Write`
/// impls) without caring which transport is in use.
enum Stream {
    /// Unencrypted connection.
    Tcp(TcpStream),
    /// rustls server-side TLS session that owns the accepted TCP socket.
    #[cfg(feature = "https")]
    Tls(StreamOwned<ServerConnection, TcpStream>),
}
impl Read for Stream {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
match self {
Stream::Tcp(stream) => stream.read(buf),
#[cfg(feature = "https")]
Stream::Tls(stream) => stream.read(buf),
}
}
}
impl Write for Stream {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
match self {
Stream::Tcp(stream) => stream.write(buf),
#[cfg(feature = "https")]
Stream::Tls(stream) => stream.write(buf),
}
}
fn flush(&mut self) -> std::io::Result<()> {
match self {
Stream::Tcp(stream) => stream.flush(),
#[cfg(feature = "https")]
Stream::Tls(stream) => stream.flush(),
}
}
}
pub struct Server {
listener: TcpListener,
#[cfg(feature = "https")]
tls_config: Option<Arc<ServerConfig>>,
request_handler:
Arc<Mutex<Option<Box<dyn Fn(Request<Bytes>) -> Response<Bytes> + Send + Sync>>>>,
}
impl Server {
pub fn new(ip: IpAddr, port: u16) -> Self {
let addr = SocketAddr::from((ip, port));
let listener = TcpListener::bind(addr).unwrap();
Server {
listener,
#[cfg(feature = "https")]
tls_config: None,
request_handler: Arc::new(Mutex::new(None)),
}
}
pub fn run(&mut self) {
for stream in self.listener.incoming() {
let stream = stream.unwrap();
let request_handler = Arc::clone(&self.request_handler);
#[cfg(feature = "https")]
let tls_config = self.tls_config.clone();
thread::spawn(move || {
#[cfg(feature = "https")]
let mut stream = if let Some(tls_config) = tls_config {
let conn = ServerConnection::new(tls_config).unwrap();
Stream::Tls(StreamOwned::new(conn, stream))
} else {
Stream::Tcp(stream)
};
#[cfg(not(feature = "https"))]
let mut stream = Stream::Tcp(stream);
let mut buffer = vec![0; 1024];
let bytes_read = stream.read(&mut buffer).unwrap();
if bytes_read > 0 {
let request_bytes = BytesMut::from(&buffer[..bytes_read]);
if let Some(request_handler) = &*request_handler
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
{
let request = parse_request(request_bytes).unwrap();
let response = parse_response(request_handler(request));
stream.write(&response).unwrap();
stream.flush().unwrap();
}
}
});
}
}
#[cfg(feature = "https")]
pub fn with_https(mut self, cert_file: &str, key_file: &str) -> Self {
let certs = CertificateDer::pem_file_iter(cert_file)
.unwrap()
.map(|cert| cert.unwrap())
.collect::<Vec<_>>();
let key = PrivateKeyDer::from_pem_file(key_file).unwrap();
let config = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(certs, key)
.unwrap();
self.tls_config = Some(Arc::new(config));
self
}
pub fn on_request<F>(&self, handler: F)
where
F: Fn(Request<Bytes>) -> Response<Bytes> + Send + Sync + 'static,
{
*self.request_handler.lock().unwrap() = Some(Box::new(handler))
}
}
/// Parses a raw HTTP/1.x request (head plus optional body) into a [`Request`].
///
/// The head runs up to the first blank line (CRLF or bare LF) and must be UTF-8. Everything after
/// the blank line is passed through unchanged as the body, so binary payloads survive. Header
/// lines are `name:value`, and spaces or tabs around the value are trimmed. Lines without a colon
/// are skipped.
///
/// Returns `None` in any of these cases:
/// - the head is not UTF-8;
/// - the request line is incomplete;
/// - the method or version is not recognised;
/// - the URI or a header is rejected by the `http` request builder.
fn parse_request(request: BytesMut) -> Option<Request<Bytes>> {
    let request = request.freeze();
    let head_end = request
        .iter()
        .enumerate()
        .find_map(|(i, &byte)| match (byte, &request[i + 1..]) {
            (b'\n', [b'\n', ..]) => Some(i + 2),
            (b'\n', [b'\r', b'\n', ..]) => Some(i + 3),
            _ => None,
        })
        .unwrap_or(request.len());
    let head = std::str::from_utf8(&request[..head_end]).ok()?;
    // Zero-copy view of the body bytes.
    let body = request.slice(head_end..);

    let mut lines = head.lines();
    let mut parts = lines.next()?.split_whitespace();
    let method = parts.next()?.parse::<Method>().ok()?;
    let uri = parts.next()?;
    let version = match parts.next()? {
        "HTTP/0.9" => Version::HTTP_09,
        "HTTP/1.0" => Version::HTTP_10,
        "HTTP/1.1" => Version::HTTP_11,
        "HTTP/2.0" => Version::HTTP_2,
        "HTTP/3.0" => Version::HTTP_3,
        // Client-controlled input: reject rather than panic.
        _ => return None,
    };

    let mut builder = Request::builder().method(method).uri(uri).version(version);
    // Stop at the blank line so nothing after the head is mistaken for a header.
    for line in lines.take_while(|line| !line.is_empty()) {
        if let Some((name, value)) = line.split_once(':') {
            builder = builder.header(name, value.trim_matches(|c| c == ' ' || c == '\t'));
        }
    }
    builder.body(body).ok()
}
/// Serializes `response` into HTTP/1.x wire format: status line, headers, blank line, body.
///
/// The reason phrase is the canonical one for the status code, or empty for codes without one.
/// Header values are written as raw bytes, so values that are not visible ASCII are sent as-is
/// instead of causing a panic. No headers are added: `Content-Length` is present only if the
/// handler set it.
fn parse_response(response: Response<Bytes>) -> BytesMut {
    let version = match response.version() {
        Version::HTTP_09 => "HTTP/0.9",
        Version::HTTP_10 => "HTTP/1.0",
        Version::HTTP_11 => "HTTP/1.1",
        Version::HTTP_2 => "HTTP/2.0",
        Version::HTTP_3 => "HTTP/3.0",
        // `Version` is an opaque struct, so the compiler cannot see that the list is complete.
        other => unreachable!("unsupported HTTP version {other:?}"),
    };
    let status = response.status();
    let reason = status.canonical_reason().unwrap_or("");

    let headers_len: usize = response
        .headers()
        .iter()
        .map(|(name, value)| name.as_str().len() + value.len() + 4)
        .sum();
    let mut out = BytesMut::with_capacity(32 + reason.len() + headers_len + response.body().len());

    out.extend_from_slice(version.as_bytes());
    out.extend_from_slice(b" ");
    out.extend_from_slice(status.as_str().as_bytes());
    out.extend_from_slice(b" ");
    out.extend_from_slice(reason.as_bytes());
    out.extend_from_slice(b"\r\n");
    for (name, value) in response.headers() {
        out.extend_from_slice(name.as_str().as_bytes());
        out.extend_from_slice(b": ");
        out.extend_from_slice(value.as_bytes());
        out.extend_from_slice(b"\r\n");
    }
    out.extend_from_slice(b"\r\n");
    out.extend_from_slice(response.body());
    out
}