tungstenite/handshake/
server.rs

1//! Server handshake machine.
2
3use std::{
4    io::{self, Read, Write},
5    marker::PhantomData,
6    result::Result as StdResult,
7};
8
9use http::{
10    response::Builder, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
11};
12use httparse::Status;
13use log::*;
14
15use super::{
16    derive_accept_key,
17    headers::{FromHttparse, MAX_HEADERS},
18    machine::{HandshakeMachine, StageResult, TryParse},
19    HandshakeRole, MidHandshake, ProcessingResult,
20};
21use crate::{
22    error::{Error, ProtocolError, Result},
23    handshake::version_as_str,
24    protocol::{Role, WebSocket, WebSocketConfig},
25};
26
27/// Server request type.
28pub type Request = HttpRequest<()>;
29
30/// Server response type.
31pub type Response = HttpResponse<()>;
32
33/// Server error response type.
34pub type ErrorResponse = HttpResponse<Option<String>>;
35
36fn create_parts<T>(request: &HttpRequest<T>) -> Result<Builder> {
37    if request.method() != http::Method::GET {
38        return Err(Error::Protocol(ProtocolError::WrongHttpMethod));
39    }
40
41    if request.version() < http::Version::HTTP_11 {
42        return Err(Error::Protocol(ProtocolError::WrongHttpVersion));
43    }
44
45    if !request
46        .headers()
47        .get("Connection")
48        .and_then(|h| h.to_str().ok())
49        .map(|h| h.split([' ', ',']).any(|p| p.eq_ignore_ascii_case("Upgrade")))
50        .unwrap_or(false)
51    {
52        return Err(Error::Protocol(ProtocolError::MissingConnectionUpgradeHeader));
53    }
54
55    if !request
56        .headers()
57        .get("Upgrade")
58        .and_then(|h| h.to_str().ok())
59        .map(|h| h.eq_ignore_ascii_case("websocket"))
60        .unwrap_or(false)
61    {
62        return Err(Error::Protocol(ProtocolError::MissingUpgradeWebSocketHeader));
63    }
64
65    if !request.headers().get("Sec-WebSocket-Version").map(|h| h == "13").unwrap_or(false) {
66        return Err(Error::Protocol(ProtocolError::MissingSecWebSocketVersionHeader));
67    }
68
69    let key = request
70        .headers()
71        .get("Sec-WebSocket-Key")
72        .ok_or(Error::Protocol(ProtocolError::MissingSecWebSocketKey))?;
73
74    let builder = Response::builder()
75        .status(StatusCode::SWITCHING_PROTOCOLS)
76        .version(request.version())
77        .header("Connection", "Upgrade")
78        .header("Upgrade", "websocket")
79        .header("Sec-WebSocket-Accept", derive_accept_key(key.as_bytes()));
80
81    Ok(builder)
82}
83
84/// Create a response for the request.
85pub fn create_response(request: &Request) -> Result<Response> {
86    Ok(create_parts(request)?.body(())?)
87}
88
89/// Create a response for the request with a custom body.
90pub fn create_response_with_body<T1, T2>(
91    request: &HttpRequest<T1>,
92    generate_body: impl FnOnce() -> T2,
93) -> Result<HttpResponse<T2>> {
94    Ok(create_parts(request)?.body(generate_body())?)
95}
96
97/// Write `response` to the stream `w`.
98pub fn write_response<T>(mut w: impl io::Write, response: &HttpResponse<T>) -> Result<()> {
99    writeln!(
100        w,
101        "{version} {status}\r",
102        version = version_as_str(response.version())?,
103        status = response.status()
104    )?;
105
106    for (k, v) in response.headers() {
107        writeln!(w, "{}: {}\r", k, v.to_str()?)?;
108    }
109
110    writeln!(w, "\r")?;
111
112    Ok(())
113}
114
115impl TryParse for Request {
116    fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
117        let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
118        let mut req = httparse::Request::new(&mut hbuffer);
119        Ok(match req.parse(buf)? {
120            Status::Partial => None,
121            Status::Complete(size) => Some((size, Request::from_httparse(req)?)),
122        })
123    }
124}
125
126impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
127    fn from_httparse(raw: httparse::Request<'h, 'b>) -> Result<Self> {
128        if raw.method.expect("Bug: no method in header") != "GET" {
129            return Err(Error::Protocol(ProtocolError::WrongHttpMethod));
130        }
131
132        if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
133            return Err(Error::Protocol(ProtocolError::WrongHttpVersion));
134        }
135
136        let headers = HeaderMap::from_httparse(raw.headers)?;
137
138        let mut request = Request::new(());
139        *request.method_mut() = http::Method::GET;
140        *request.headers_mut() = headers;
141        *request.uri_mut() = raw.path.expect("Bug: no path in header").parse()?;
142        // TODO: httparse only supports HTTP 0.9/1.0/1.1 but not HTTP 2.0
143        // so the only valid value we could get in the response would be 1.1.
144        *request.version_mut() = http::Version::HTTP_11;
145
146        Ok(request)
147    }
148}
149
150/// The callback trait.
151///
152/// The callback is called when the server receives an incoming WebSocket
153/// handshake request from the client. Specifying a callback allows you to analyze incoming headers
154/// and add additional headers to the response that server sends to the client and/or reject the
155/// connection based on the incoming headers.
156pub trait Callback: Sized {
157    /// Called whenever the server read the request from the client and is ready to reply to it.
158    /// May return additional reply headers.
159    /// Returning an error resulting in rejecting the incoming connection.
160    fn on_request(
161        self,
162        request: &Request,
163        response: Response,
164    ) -> StdResult<Response, ErrorResponse>;
165}
166
167impl<F> Callback for F
168where
169    F: FnOnce(&Request, Response) -> StdResult<Response, ErrorResponse>,
170{
171    fn on_request(
172        self,
173        request: &Request,
174        response: Response,
175    ) -> StdResult<Response, ErrorResponse> {
176        self(request, response)
177    }
178}
179
180/// Stub for callback that does nothing.
181#[derive(Clone, Copy, Debug)]
182pub struct NoCallback;
183
184impl Callback for NoCallback {
185    fn on_request(
186        self,
187        _request: &Request,
188        response: Response,
189    ) -> StdResult<Response, ErrorResponse> {
190        Ok(response)
191    }
192}
193
194/// Server handshake role.
195#[allow(missing_copy_implementations)]
196#[derive(Debug)]
197pub struct ServerHandshake<S, C> {
198    /// Callback which is called whenever the server read the request from the client and is ready
199    /// to reply to it. The callback returns an optional headers which will be added to the reply
200    /// which the server sends to the user.
201    callback: Option<C>,
202    /// WebSocket configuration.
203    config: Option<WebSocketConfig>,
204    /// Error code/flag. If set, an error will be returned after sending response to the client.
205    error_response: Option<ErrorResponse>,
206    /// Internal stream type.
207    _marker: PhantomData<S>,
208}
209
210impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
211    /// Start server handshake. `callback` specifies a custom callback which the user can pass to
212    /// the handshake, this callback will be called when the a websocket client connects to the
213    /// server, you can specify the callback if you want to add additional header to the client
214    /// upon join based on the incoming headers.
215    pub fn start(stream: S, callback: C, config: Option<WebSocketConfig>) -> MidHandshake<Self> {
216        trace!("Server handshake initiated.");
217        MidHandshake {
218            machine: HandshakeMachine::start_read(stream),
219            role: ServerHandshake {
220                callback: Some(callback),
221                config,
222                error_response: None,
223                _marker: PhantomData,
224            },
225        }
226    }
227}
228
229impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
230    type IncomingData = Request;
231    type InternalStream = S;
232    type FinalResult = WebSocket<S>;
233
234    fn stage_finished(
235        &mut self,
236        finish: StageResult<Self::IncomingData, Self::InternalStream>,
237    ) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>> {
238        Ok(match finish {
239            StageResult::DoneReading { stream, result, tail } => {
240                if !tail.is_empty() {
241                    return Err(Error::Protocol(ProtocolError::JunkAfterRequest));
242                }
243
244                let response = create_response(&result)?;
245                let callback_result = if let Some(callback) = self.callback.take() {
246                    callback.on_request(&result, response)
247                } else {
248                    Ok(response)
249                };
250
251                match callback_result {
252                    Ok(response) => {
253                        let mut output = vec![];
254                        write_response(&mut output, &response)?;
255                        ProcessingResult::Continue(HandshakeMachine::start_write(stream, output))
256                    }
257
258                    Err(resp) => {
259                        if resp.status().is_success() {
260                            return Err(Error::Protocol(ProtocolError::CustomResponseSuccessful));
261                        }
262
263                        self.error_response = Some(resp);
264                        let resp = self.error_response.as_ref().unwrap();
265
266                        let mut output = vec![];
267                        write_response(&mut output, resp)?;
268
269                        if let Some(body) = resp.body() {
270                            output.extend_from_slice(body.as_bytes());
271                        }
272
273                        ProcessingResult::Continue(HandshakeMachine::start_write(stream, output))
274                    }
275                }
276            }
277
278            StageResult::DoneWriting(stream) => {
279                if let Some(err) = self.error_response.take() {
280                    debug!("Server handshake failed.");
281
282                    let (parts, body) = err.into_parts();
283                    let body = body.map(|b| b.as_bytes().to_vec());
284                    return Err(Error::Http(http::Response::from_parts(parts, body).into()));
285                } else {
286                    debug!("Server handshake done.");
287                    let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);
288                    ProcessingResult::Done(websocket)
289                }
290            }
291        })
292    }
293}
294
#[cfg(test)]
mod tests {
    use super::{super::machine::TryParse, create_response, write_response, Request};
    use crate::error::{Error, ProtocolError};

    // A complete, valid client opening handshake using the sample key from RFC 6455.
    const HANDSHAKE: &[u8] = b"\
        GET /script.ws HTTP/1.1\r\n\
        Host: foo.com\r\n\
        Connection: upgrade\r\n\
        Upgrade: websocket\r\n\
        Sec-WebSocket-Version: 13\r\n\
        Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
        \r\n";

    #[test]
    fn request_parsing() {
        const DATA: &[u8] = b"GET /script.ws HTTP/1.1\r\nHost: foo.com\r\n\r\n";
        let (_, req) = Request::try_parse(DATA).unwrap().unwrap();
        assert_eq!(req.uri().path(), "/script.ws");
        assert_eq!(req.headers().get("Host").unwrap(), &b"foo.com"[..]);
    }

    #[test]
    fn request_replying() {
        let (_, req) = Request::try_parse(HANDSHAKE).unwrap().unwrap();
        let response = create_response(&req).unwrap();

        assert_eq!(
            response.headers().get("Sec-WebSocket-Accept").unwrap(),
            b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=".as_ref()
        );
    }

    #[test]
    fn connection_header_list_is_accepted() {
        const DATA: &[u8] = b"\
            GET / HTTP/1.1\r\n\
            Host: foo.com\r\n\
            Connection: keep-alive, Upgrade\r\n\
            Upgrade: websocket\r\n\
            Sec-WebSocket-Version: 13\r\n\
            Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
            \r\n";
        let (_, req) = Request::try_parse(DATA).unwrap().unwrap();
        assert!(create_response(&req).is_ok());
    }

    #[test]
    fn non_get_request_is_rejected() {
        let result = Request::try_parse(b"POST / HTTP/1.1\r\nHost: foo.com\r\n\r\n");
        assert!(matches!(result, Err(Error::Protocol(ProtocolError::WrongHttpMethod))));
    }

    #[test]
    fn http_10_request_is_rejected() {
        let result = Request::try_parse(b"GET / HTTP/1.0\r\nHost: foo.com\r\n\r\n");
        assert!(matches!(result, Err(Error::Protocol(ProtocolError::WrongHttpVersion))));
    }

    #[test]
    fn missing_key_is_rejected() {
        const DATA: &[u8] = b"\
            GET / HTTP/1.1\r\n\
            Host: foo.com\r\n\
            Connection: Upgrade\r\n\
            Upgrade: websocket\r\n\
            Sec-WebSocket-Version: 13\r\n\
            \r\n";
        let (_, req) = Request::try_parse(DATA).unwrap().unwrap();
        let err = create_response(&req).unwrap_err();
        assert!(matches!(err, Error::Protocol(ProtocolError::MissingSecWebSocketKey)));
    }

    #[test]
    fn response_serialization() {
        let (_, req) = Request::try_parse(HANDSHAKE).unwrap().unwrap();
        let response = create_response(&req).unwrap();

        let mut out = Vec::new();
        write_response(&mut out, &response).unwrap();
        let text = String::from_utf8(out).unwrap();

        assert!(text.contains(" 101 Switching Protocols\r\n"));
        assert!(text.contains("sec-websocket-accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n"));
        assert!(text.ends_with("\r\n\r\n"));
    }
}