use std::{ collections::HashMap, io::{Read, Write}, net::TcpListener, str::from_utf8, sync::{Arc, Mutex}, thread, }; #[cfg(feature = "https")] use rustls::{ pki_types::{pem::PemObject, CertificateDer, PrivateKeyDer}, ServerConfig, ServerConnection, StreamOwned, }; use crate::{ client, http::{request::RequestLine, Request, Response}, Client, }; pub struct Server { listener: TcpListener, on_request: Arc Response + Send + 'static>>>>, #[cfg(feature = "https")] tls_config: Option>, } impl Server { pub fn new(ip: &str, port: i32) -> Self { let addr = format!("{}:{}", ip, port); let listener = TcpListener::bind(addr).unwrap(); Server { listener, on_request: Arc::new(Mutex::new(None)), #[cfg(feature = "https")] tls_config: None, } } pub fn run(&mut self) { for stream in self.listener.incoming() { let stream = stream.unwrap(); let on_request = Arc::clone(&self.on_request); #[cfg(feature = "https")] let tls_config = self.tls_config.clone(); thread::spawn(move || { #[cfg(feature = "https")] let stream = if let Some(tls_config) = tls_config { let connection = ServerConnection::new(tls_config).unwrap(); client::Stream::Tls(StreamOwned::new(connection, stream)) } else { client::Stream::Tcp(stream) }; #[cfg(not(feature = "https"))] let stream = client::Stream::Tcp(stream); let mut client = Client::new(stream); let mut buffer = vec![0; 1024]; let bytes_read = client.read(&mut buffer).unwrap(); if bytes_read > 0 { if let Ok(request_str) = from_utf8(&buffer[..bytes_read]) { // TODO: support proper error handling if let Some(on_request) = &*on_request.lock().unwrap_or_else(|e| e.into_inner()) { let request; { let mut lines = request_str.lines(); let request_line = RequestLine::from_str(lines.next().unwrap()); let mut headers = HashMap::new(); for line in lines { if let Some((k, v)) = line.split_once(": ") { headers.insert(k, v); } } request = Request { request_line, headers, }; } let response = on_request(request); { let response_line = response.response_line.to_string(); client.write_string(response_line).unwrap(); if let Some(headers) = response.headers { for (k, v) in headers { client.write_string(format!("{}: {}", k, v)).unwrap(); } } else { client.write(b"\r\n").unwrap(); } if let Some(content) = response.content { client.write_all(&content).unwrap(); } else { client.write(b"\r\n").unwrap(); } } client.flush().unwrap(); } } } }); } } #[cfg(feature = "https")] pub fn with_https(mut self, cert_file: &str, key_file: &str) -> Self { let certs = CertificateDer::pem_file_iter(cert_file) .unwrap() .map(|cert| cert.unwrap()) .collect::>(); let key = PrivateKeyDer::from_pem_file(key_file).unwrap(); let config = ServerConfig::builder() .with_no_client_auth() .with_single_cert(certs, key) .unwrap(); self.tls_config = Some(Arc::new(config)); self } pub fn on_request(&mut self, f: F) where F: Fn(Request) -> Response + Send + 'static, { *self.on_request.lock().unwrap() = Some(Box::new(f)) } }