use std::{ io::{Read, Write}, net::{IpAddr, SocketAddr, TcpListener, TcpStream}, sync::{Arc, Mutex}, thread, }; use bytes::{Bytes, BytesMut}; use http::{Method, Request, Response, Version}; #[cfg(feature = "https")] use rustls::{ pki_types::{pem::PemObject, CertificateDer, PrivateKeyDer}, ServerConfig, ServerConnection, StreamOwned, }; pub use http; enum Stream { Tcp(TcpStream), #[cfg(feature = "https")] Tls(StreamOwned), } impl Read for Stream { fn read(&mut self, buf: &mut [u8]) -> std::io::Result { match self { Stream::Tcp(stream) => stream.read(buf), #[cfg(feature = "https")] Stream::Tls(stream) => stream.read(buf), } } } impl Write for Stream { fn write(&mut self, buf: &[u8]) -> std::io::Result { match self { Stream::Tcp(stream) => stream.write(buf), #[cfg(feature = "https")] Stream::Tls(stream) => stream.write(buf), } } fn flush(&mut self) -> std::io::Result<()> { match self { Stream::Tcp(stream) => stream.flush(), #[cfg(feature = "https")] Stream::Tls(stream) => stream.flush(), } } } pub struct Server { listener: TcpListener, #[cfg(feature = "https")] tls_config: Option>, request_handler: Arc) -> Response + Send + Sync>>>>, } impl Server { pub fn new(ip: IpAddr, port: u16) -> Self { let addr = SocketAddr::from((ip, port)); let listener = TcpListener::bind(addr).unwrap(); Server { listener, #[cfg(feature = "https")] tls_config: None, request_handler: Arc::new(Mutex::new(None)), } } pub fn run(&mut self) { for stream in self.listener.incoming() { let stream = stream.unwrap(); let request_handler = Arc::clone(&self.request_handler); #[cfg(feature = "https")] let tls_config = self.tls_config.clone(); thread::spawn(move || { #[cfg(feature = "https")] let mut stream = if let Some(tls_config) = tls_config { let conn = ServerConnection::new(tls_config).unwrap(); Stream::Tls(StreamOwned::new(conn, stream)) } else { Stream::Tcp(stream) }; #[cfg(not(feature = "https"))] let mut stream = Stream::Tcp(stream); let mut buffer = vec![0; 1024]; let bytes_read = stream.read(&mut buffer).unwrap(); if bytes_read > 0 { let request_bytes = BytesMut::from(&buffer[..bytes_read]); if let Some(request_handler) = &*request_handler .lock() .unwrap_or_else(|poisoned| poisoned.into_inner()) { let request = parse_request(request_bytes).unwrap(); let response = parse_response(request_handler(request)); stream.write(&response).unwrap(); stream.flush().unwrap(); } } }); } } #[cfg(feature = "https")] pub fn with_https(mut self, cert_file: &str, key_file: &str) -> Self { let certs = CertificateDer::pem_file_iter(cert_file) .unwrap() .map(|cert| cert.unwrap()) .collect::>(); let key = PrivateKeyDer::from_pem_file(key_file).unwrap(); let config = ServerConfig::builder() .with_no_client_auth() .with_single_cert(certs, key) .unwrap(); self.tls_config = Some(Arc::new(config)); self } pub fn on_request(&self, handler: F) where F: Fn(Request) -> Response + Send + Sync + 'static, { *self.request_handler.lock().unwrap() = Some(Box::new(handler)) } } fn parse_request(request: BytesMut) -> Option> { let request_str = std::str::from_utf8(&request).ok()?; let mut lines = request_str.lines(); let request_line = lines.next()?; let mut parts = request_line.split_whitespace(); let method = parts.next()?.parse::().ok()?; let uri = parts.next()?; let version = match parts.next()? { "HTTP/0.9" => Version::HTTP_09, "HTTP/1.0" => Version::HTTP_10, "HTTP/1.1" => Version::HTTP_11, "HTTP/2.0" => Version::HTTP_2, "HTTP/3.0" => Version::HTTP_3, _ => unreachable!(), }; let mut builder = Request::builder().method(method).uri(uri).version(version); for line in lines.by_ref() { if let Some((k, v)) = line.split_once(": ") { builder = builder.header(k, v); }; } let body = lines.collect::>().join("\n"); builder.body(body.into()).ok() } fn parse_response(response: Response) -> BytesMut { let mut response_bytes = BytesMut::new(); let version = match response.version() { Version::HTTP_09 => "HTTP/0.9", Version::HTTP_10 => "HTTP/1.0", Version::HTTP_11 => "HTTP/1.1", Version::HTTP_2 => "HTTP/2.0", Version::HTTP_3 => "HTTP/3.0", _ => unreachable!(), }; let status = response.status(); response_bytes.extend_from_slice(format!("{} {}\r\n", version, status).as_bytes()); for (k, v) in response.headers() { response_bytes.extend_from_slice(format!("{}: {}\r\n", k, v.to_str().unwrap()).as_bytes()) } response_bytes.extend_from_slice(b"\r\n"); response_bytes.extend_from_slice(response.body()); response_bytes }