mirror of
https://codeberg.org/bunbun/fluffle/
synced 2025-04-21 22:27:56 -07:00
195 lines
5.7 KiB
Rust
195 lines
5.7 KiB
Rust
use std::{
|
|
io::{Read, Write},
|
|
net::{IpAddr, SocketAddr, TcpListener, TcpStream},
|
|
sync::{Arc, Mutex},
|
|
thread,
|
|
};
|
|
|
|
use bytes::{Bytes, BytesMut};
|
|
use http::{Method, Request, Response, Version};
|
|
#[cfg(feature = "https")]
|
|
use rustls::{
|
|
pki_types::{pem::PemObject, CertificateDer, PrivateKeyDer},
|
|
ServerConfig, ServerConnection, StreamOwned,
|
|
};
|
|
|
|
pub use http;
|
|
|
|
/// One accepted client connection: plain TCP, or (with the `https` feature) TLS over TCP.
///
/// `Read` and `Write` are implemented by forwarding to whichever variant is active, so the
/// request-handling code in `Server::run` is the same for both transports.
enum Stream {
    /// Unencrypted TCP connection.
    Tcp(TcpStream),
    /// rustls server session that owns the underlying TCP socket.
    #[cfg(feature = "https")]
    Tls(StreamOwned<ServerConnection, TcpStream>),
}
|
|
|
|
impl Read for Stream {
|
|
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
|
|
match self {
|
|
Stream::Tcp(stream) => stream.read(buf),
|
|
#[cfg(feature = "https")]
|
|
Stream::Tls(stream) => stream.read(buf),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Write for Stream {
|
|
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
|
match self {
|
|
Stream::Tcp(stream) => stream.write(buf),
|
|
#[cfg(feature = "https")]
|
|
Stream::Tls(stream) => stream.write(buf),
|
|
}
|
|
}
|
|
|
|
fn flush(&mut self) -> std::io::Result<()> {
|
|
match self {
|
|
Stream::Tcp(stream) => stream.flush(),
|
|
#[cfg(feature = "https")]
|
|
Stream::Tls(stream) => stream.flush(),
|
|
}
|
|
}
|
|
}
|
|
|
|
pub struct Server {
|
|
listener: TcpListener,
|
|
#[cfg(feature = "https")]
|
|
tls_config: Option<Arc<ServerConfig>>,
|
|
request_handler:
|
|
Arc<Mutex<Option<Box<dyn Fn(Request<Bytes>) -> Response<Bytes> + Send + Sync>>>>,
|
|
}
|
|
|
|
impl Server {
|
|
pub fn new(ip: IpAddr, port: u16) -> Self {
|
|
let addr = SocketAddr::from((ip, port));
|
|
let listener = TcpListener::bind(addr).unwrap();
|
|
|
|
Server {
|
|
listener,
|
|
#[cfg(feature = "https")]
|
|
tls_config: None,
|
|
request_handler: Arc::new(Mutex::new(None)),
|
|
}
|
|
}
|
|
|
|
pub fn run(&mut self) {
|
|
for stream in self.listener.incoming() {
|
|
let stream = stream.unwrap();
|
|
|
|
let request_handler = Arc::clone(&self.request_handler);
|
|
|
|
#[cfg(feature = "https")]
|
|
let tls_config = self.tls_config.clone();
|
|
|
|
thread::spawn(move || {
|
|
#[cfg(feature = "https")]
|
|
let mut stream = if let Some(tls_config) = tls_config {
|
|
let conn = ServerConnection::new(tls_config).unwrap();
|
|
Stream::Tls(StreamOwned::new(conn, stream))
|
|
} else {
|
|
Stream::Tcp(stream)
|
|
};
|
|
|
|
#[cfg(not(feature = "https"))]
|
|
let mut stream = Stream::Tcp(stream);
|
|
|
|
let mut buffer = vec![0; 1024];
|
|
let bytes_read = stream.read(&mut buffer).unwrap();
|
|
|
|
if bytes_read > 0 {
|
|
let request_bytes = BytesMut::from(&buffer[..bytes_read]);
|
|
if let Some(request_handler) = &*request_handler
|
|
.lock()
|
|
.unwrap_or_else(|poisoned| poisoned.into_inner())
|
|
{
|
|
let request = parse_request(request_bytes).unwrap();
|
|
let response = parse_response(request_handler(request));
|
|
|
|
stream.write(&response).unwrap();
|
|
stream.flush().unwrap();
|
|
}
|
|
}
|
|
});
|
|
}
|
|
}
|
|
|
|
#[cfg(feature = "https")]
|
|
pub fn with_https(mut self, cert_file: &str, key_file: &str) -> Self {
|
|
let certs = CertificateDer::pem_file_iter(cert_file)
|
|
.unwrap()
|
|
.map(|cert| cert.unwrap())
|
|
.collect::<Vec<_>>();
|
|
|
|
let key = PrivateKeyDer::from_pem_file(key_file).unwrap();
|
|
|
|
let config = ServerConfig::builder()
|
|
.with_no_client_auth()
|
|
.with_single_cert(certs, key)
|
|
.unwrap();
|
|
|
|
self.tls_config = Some(Arc::new(config));
|
|
|
|
self
|
|
}
|
|
|
|
pub fn on_request<F>(&self, handler: F)
|
|
where
|
|
F: Fn(Request<Bytes>) -> Response<Bytes> + Send + Sync + 'static,
|
|
{
|
|
*self.request_handler.lock().unwrap() = Some(Box::new(handler))
|
|
}
|
|
}
|
|
|
|
fn parse_request(request: BytesMut) -> Option<Request<Bytes>> {
|
|
let request_str = std::str::from_utf8(&request).ok()?;
|
|
let mut lines = request_str.lines();
|
|
let request_line = lines.next()?;
|
|
let mut parts = request_line.split_whitespace();
|
|
|
|
let method = parts.next()?.parse::<Method>().ok()?;
|
|
let uri = parts.next()?;
|
|
let version = match parts.next()? {
|
|
"HTTP/0.9" => Version::HTTP_09,
|
|
"HTTP/1.0" => Version::HTTP_10,
|
|
"HTTP/1.1" => Version::HTTP_11,
|
|
"HTTP/2.0" => Version::HTTP_2,
|
|
"HTTP/3.0" => Version::HTTP_3,
|
|
_ => unreachable!(),
|
|
};
|
|
|
|
let mut builder = Request::builder().method(method).uri(uri).version(version);
|
|
|
|
for line in lines.by_ref() {
|
|
if let Some((k, v)) = line.split_once(": ") {
|
|
builder = builder.header(k, v);
|
|
};
|
|
}
|
|
|
|
let body = lines.collect::<Vec<&str>>().join("\n");
|
|
|
|
builder.body(body.into()).ok()
|
|
}
|
|
|
|
fn parse_response(response: Response<Bytes>) -> BytesMut {
|
|
let mut response_bytes = BytesMut::new();
|
|
|
|
let version = match response.version() {
|
|
Version::HTTP_09 => "HTTP/0.9",
|
|
Version::HTTP_10 => "HTTP/1.0",
|
|
Version::HTTP_11 => "HTTP/1.1",
|
|
Version::HTTP_2 => "HTTP/2.0",
|
|
Version::HTTP_3 => "HTTP/3.0",
|
|
_ => unreachable!(),
|
|
};
|
|
let status = response.status();
|
|
|
|
response_bytes.extend_from_slice(format!("{} {}\r\n", version, status).as_bytes());
|
|
|
|
for (k, v) in response.headers() {
|
|
response_bytes.extend_from_slice(format!("{}: {}\r\n", k, v.to_str().unwrap()).as_bytes())
|
|
}
|
|
|
|
response_bytes.extend_from_slice(b"\r\n");
|
|
response_bytes.extend_from_slice(response.body());
|
|
|
|
response_bytes
|
|
}
|