1use std::{
4 io::{self, Read, Write},
5 marker::PhantomData,
6 result::Result as StdResult,
7};
8
9use http::{
10 response::Builder, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
11};
12use httparse::Status;
13use log::*;
14
15use super::{
16 derive_accept_key,
17 headers::{FromHttparse, MAX_HEADERS},
18 machine::{HandshakeMachine, StageResult, TryParse},
19 HandshakeRole, MidHandshake, ProcessingResult,
20};
21use crate::{
22 error::{Error, ProtocolError, Result},
23 handshake::version_as_str,
24 protocol::{Role, WebSocket, WebSocketConfig},
25};
26
/// Client handshake request as seen by the server: an HTTP request with an empty body.
pub type Request = HttpRequest<()>;

/// Server handshake response: an HTTP response with an empty body.
pub type Response = HttpResponse<()>;

/// Response a [`Callback`] returns to reject the handshake. The optional string, when
/// present, is written to the client as the response body after the headers.
pub type ErrorResponse = HttpResponse<Option<String>>;
35
36fn create_parts<T>(request: &HttpRequest<T>) -> Result<Builder> {
37 if request.method() != http::Method::GET {
38 return Err(Error::Protocol(ProtocolError::WrongHttpMethod));
39 }
40
41 if request.version() < http::Version::HTTP_11 {
42 return Err(Error::Protocol(ProtocolError::WrongHttpVersion));
43 }
44
45 if !request
46 .headers()
47 .get("Connection")
48 .and_then(|h| h.to_str().ok())
49 .map(|h| h.split([' ', ',']).any(|p| p.eq_ignore_ascii_case("Upgrade")))
50 .unwrap_or(false)
51 {
52 return Err(Error::Protocol(ProtocolError::MissingConnectionUpgradeHeader));
53 }
54
55 if !request
56 .headers()
57 .get("Upgrade")
58 .and_then(|h| h.to_str().ok())
59 .map(|h| h.eq_ignore_ascii_case("websocket"))
60 .unwrap_or(false)
61 {
62 return Err(Error::Protocol(ProtocolError::MissingUpgradeWebSocketHeader));
63 }
64
65 if !request.headers().get("Sec-WebSocket-Version").map(|h| h == "13").unwrap_or(false) {
66 return Err(Error::Protocol(ProtocolError::MissingSecWebSocketVersionHeader));
67 }
68
69 let key = request
70 .headers()
71 .get("Sec-WebSocket-Key")
72 .ok_or(Error::Protocol(ProtocolError::MissingSecWebSocketKey))?;
73
74 let builder = Response::builder()
75 .status(StatusCode::SWITCHING_PROTOCOLS)
76 .version(request.version())
77 .header("Connection", "Upgrade")
78 .header("Upgrade", "websocket")
79 .header("Sec-WebSocket-Accept", derive_accept_key(key.as_bytes()));
80
81 Ok(builder)
82}
83
84pub fn create_response(request: &Request) -> Result<Response> {
86 Ok(create_parts(request)?.body(())?)
87}
88
89pub fn create_response_with_body<T1, T2>(
91 request: &HttpRequest<T1>,
92 generate_body: impl FnOnce() -> T2,
93) -> Result<HttpResponse<T2>> {
94 Ok(create_parts(request)?.body(generate_body())?)
95}
96
97pub fn write_response<T>(mut w: impl io::Write, response: &HttpResponse<T>) -> Result<()> {
99 writeln!(
100 w,
101 "{version} {status}\r",
102 version = version_as_str(response.version())?,
103 status = response.status()
104 )?;
105
106 for (k, v) in response.headers() {
107 writeln!(w, "{}: {}\r", k, v.to_str()?)?;
108 }
109
110 writeln!(w, "\r")?;
111
112 Ok(())
113}
114
115impl TryParse for Request {
116 fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
117 let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
118 let mut req = httparse::Request::new(&mut hbuffer);
119 Ok(match req.parse(buf)? {
120 Status::Partial => None,
121 Status::Complete(size) => Some((size, Request::from_httparse(req)?)),
122 })
123 }
124}
125
126impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
127 fn from_httparse(raw: httparse::Request<'h, 'b>) -> Result<Self> {
128 if raw.method.expect("Bug: no method in header") != "GET" {
129 return Err(Error::Protocol(ProtocolError::WrongHttpMethod));
130 }
131
132 if raw.version.expect("Bug: no HTTP version") < 1 {
133 return Err(Error::Protocol(ProtocolError::WrongHttpVersion));
134 }
135
136 let headers = HeaderMap::from_httparse(raw.headers)?;
137
138 let mut request = Request::new(());
139 *request.method_mut() = http::Method::GET;
140 *request.headers_mut() = headers;
141 *request.uri_mut() = raw.path.expect("Bug: no path in header").parse()?;
142 *request.version_mut() = http::Version::HTTP_11;
145
146 Ok(request)
147 }
148}
149
/// Server-side hook that can inspect the client's upgrade request and accept or reject it.
///
/// `on_request` takes `self` by value, so a callback is used at most once per handshake.
pub trait Callback: Sized {
    /// Called with the parsed client `request` and the default `101 Switching Protocols`
    /// `response` prepared for it.
    ///
    /// - Return `Ok(response)` (possibly modified) to accept the upgrade.
    /// - Return `Err(error_response)` to reject it. The error response is sent to the
    ///   client and the handshake then fails.
    ///
    /// Returning an `Err` whose status is a 2xx success code makes the handshake fail with
    /// `ProtocolError::CustomResponseSuccessful` instead.
    fn on_request(
        self,
        request: &Request,
        response: Response,
    ) -> StdResult<Response, ErrorResponse>;
}
166
167impl<F> Callback for F
168where
169 F: FnOnce(&Request, Response) -> StdResult<Response, ErrorResponse>,
170{
171 fn on_request(
172 self,
173 request: &Request,
174 response: Response,
175 ) -> StdResult<Response, ErrorResponse> {
176 self(request, response)
177 }
178}
179
180#[derive(Clone, Copy, Debug)]
182pub struct NoCallback;
183
184impl Callback for NoCallback {
185 fn on_request(
186 self,
187 _request: &Request,
188 response: Response,
189 ) -> StdResult<Response, ErrorResponse> {
190 Ok(response)
191 }
192}
193
/// State of an in-progress server-side WebSocket handshake over a stream of type `S`,
/// using callback `C` to vet the client's request.
// NOTE(review): the reason for allowing `missing_copy_implementations` is not visible
// here — presumably this state is not meant to be copied; confirm and document.
#[allow(missing_copy_implementations)]
#[derive(Debug)]
pub struct ServerHandshake<S, C> {
    /// User callback. It is `take`n (and consumed) when the client's request is processed.
    callback: Option<C>,
    /// Configuration handed to the `WebSocket` created once the handshake succeeds.
    config: Option<WebSocketConfig>,
    /// Rejection produced by the callback. When set, the handshake returns it as an
    /// error after the response has been written to the client.
    error_response: Option<ErrorResponse>,
    /// Ties the type to its stream type `S` without storing a stream.
    _marker: PhantomData<S>,
}
209
210impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
211 pub fn start(stream: S, callback: C, config: Option<WebSocketConfig>) -> MidHandshake<Self> {
216 trace!("Server handshake initiated.");
217 MidHandshake {
218 machine: HandshakeMachine::start_read(stream),
219 role: ServerHandshake {
220 callback: Some(callback),
221 config,
222 error_response: None,
223 _marker: PhantomData,
224 },
225 }
226 }
227}
228
229impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
230 type IncomingData = Request;
231 type InternalStream = S;
232 type FinalResult = WebSocket<S>;
233
234 fn stage_finished(
235 &mut self,
236 finish: StageResult<Self::IncomingData, Self::InternalStream>,
237 ) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>> {
238 Ok(match finish {
239 StageResult::DoneReading { stream, result, tail } => {
240 if !tail.is_empty() {
241 return Err(Error::Protocol(ProtocolError::JunkAfterRequest));
242 }
243
244 let response = create_response(&result)?;
245 let callback_result = if let Some(callback) = self.callback.take() {
246 callback.on_request(&result, response)
247 } else {
248 Ok(response)
249 };
250
251 match callback_result {
252 Ok(response) => {
253 let mut output = vec![];
254 write_response(&mut output, &response)?;
255 ProcessingResult::Continue(HandshakeMachine::start_write(stream, output))
256 }
257
258 Err(resp) => {
259 if resp.status().is_success() {
260 return Err(Error::Protocol(ProtocolError::CustomResponseSuccessful));
261 }
262
263 self.error_response = Some(resp);
264 let resp = self.error_response.as_ref().unwrap();
265
266 let mut output = vec![];
267 write_response(&mut output, resp)?;
268
269 if let Some(body) = resp.body() {
270 output.extend_from_slice(body.as_bytes());
271 }
272
273 ProcessingResult::Continue(HandshakeMachine::start_write(stream, output))
274 }
275 }
276 }
277
278 StageResult::DoneWriting(stream) => {
279 if let Some(err) = self.error_response.take() {
280 debug!("Server handshake failed.");
281
282 let (parts, body) = err.into_parts();
283 let body = body.map(|b| b.as_bytes().to_vec());
284 return Err(Error::Http(http::Response::from_parts(parts, body).into()));
285 } else {
286 debug!("Server handshake done.");
287 let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);
288 ProcessingResult::Done(websocket)
289 }
290 }
291 })
292 }
293}
294
#[cfg(test)]
mod tests {
    use super::{super::machine::TryParse, create_response, Request};

    /// A minimal request head parses into the expected path and headers.
    #[test]
    fn request_parsing() {
        const DATA: &[u8] = b"GET /script.ws HTTP/1.1\r\nHost: foo.com\r\n\r\n";
        let parsed = Request::try_parse(DATA).unwrap();
        let (_, request) = parsed.unwrap();

        assert_eq!(request.uri().path(), "/script.ws");

        let host = request.headers().get("Host").unwrap();
        assert_eq!(host, &b"foo.com"[..]);
    }

    /// The accept key for the RFC 6455 sample nonce matches the RFC's expected value.
    #[test]
    fn request_replying() {
        const DATA: &[u8] = b"\
            GET /script.ws HTTP/1.1\r\n\
            Host: foo.com\r\n\
            Connection: upgrade\r\n\
            Upgrade: websocket\r\n\
            Sec-WebSocket-Version: 13\r\n\
            Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
            \r\n";
        let parsed = Request::try_parse(DATA).unwrap();
        let (_, request) = parsed.unwrap();

        let reply = create_response(&request).unwrap();
        let accept = reply.headers().get("Sec-WebSocket-Accept").unwrap();
        let expected: &[u8] = b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=";
        assert_eq!(accept, expected);
    }
}