From 1b830a63fe8e1c8488b70fd4495e2b83a211bf3e Mon Sep 17 00:00:00 2001 From: Winter Hille Date: Sun, 2 Feb 2025 19:24:33 -0800 Subject: [PATCH] reworked https support --- src/client.rs | 43 ++++++++++++++++++++++++++----------------- src/server.rs | 17 ++++++++++++----- 2 files changed, 38 insertions(+), 22 deletions(-) diff --git a/src/client.rs b/src/client.rs index 5778de2..8008ed0 100644 --- a/src/client.rs +++ b/src/client.rs @@ -3,51 +3,60 @@ use std::{ net::TcpStream, }; -#[cfg(not(feature = "https"))] -pub struct Client { - stream: TcpStream, +#[cfg(feature = "https")] +use rustls::{ServerConnection, StreamOwned}; + +pub(crate) enum Stream { + Tcp(TcpStream), + #[cfg(feature = "https")] + Tls(StreamOwned), } -#[cfg(feature = "https")] pub struct Client { - stream: rustls::StreamOwned, + stream: Stream, } impl Client { - #[cfg(not(feature = "https"))] - pub(crate) fn new(stream: TcpStream) -> Self { - Client { stream } - } - - #[cfg(feature = "https")] - pub(crate) fn new(stream: rustls::StreamOwned) -> Self { + pub(crate) fn new(stream: Stream) -> Self { Client { stream } } // basically an alias for write (might remove) pub(crate) fn write_bytes(&mut self, mut bytes: Vec) -> io::Result { bytes.extend(b"\r\n"); - self.stream.write(&bytes) + self.write(&bytes) } pub(crate) fn write_string(&mut self, mut string: String) -> io::Result { string.push_str("\r\n"); - self.stream.write(string.as_bytes()) + self.write(string.as_bytes()) } } impl Read for Client { fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.stream.read(buf) + match &mut self.stream { + Stream::Tcp(stream) => stream.read(buf), + #[cfg(feature = "https")] + Stream::Tls(stream) => stream.read(buf), + } } } impl Write for Client { fn write(&mut self, buf: &[u8]) -> io::Result { - self.stream.write(buf) + match &mut self.stream { + Stream::Tcp(stream) => stream.write(buf), + #[cfg(feature = "https")] + Stream::Tls(stream) => stream.write(buf), + } } fn flush(&mut self) -> io::Result<()> { - self.stream.flush() + match &mut self.stream { + Stream::Tcp(stream) => stream.flush(), + #[cfg(feature = "https")] + Stream::Tls(stream) => stream.flush(), + } } } diff --git a/src/server.rs b/src/server.rs index 74e5d0e..e2cbdb6 100644 --- a/src/server.rs +++ b/src/server.rs @@ -10,10 +10,11 @@ use std::{ #[cfg(feature = "https")] use rustls::{ pki_types::{pem::PemObject, CertificateDer, PrivateKeyDer}, - ServerConfig, ServerConnection, + ServerConfig, ServerConnection, StreamOwned, }; use crate::{ + client, http::{request::RequestLine, Request, Response}, Client, }; @@ -49,13 +50,19 @@ impl Server { let on_request = Arc::clone(&self.on_request); #[cfg(feature = "https")] - let tls_config = self.tls_config.clone().unwrap(); + let tls_config = self.tls_config.clone(); thread::spawn(move || { #[cfg(feature = "https")] - let connection = ServerConnection::new(tls_config).unwrap(); - #[cfg(feature = "https")] - let stream = rustls::StreamOwned::new(connection, stream); + let stream = if let Some(tls_config) = tls_config { + let connection = ServerConnection::new(tls_config).unwrap(); + client::Stream::Tls(StreamOwned::new(connection, stream)) + } else { + client::Stream::Tcp(stream) + }; + + #[cfg(not(feature = "https"))] + let stream = client::Stream::Tcp(stream); let mut client = Client::new(stream);