tungstenite/handshake/
mod.rs1pub mod client;
4pub mod headers;
5pub mod machine;
6pub mod server;
7
8use std::{
9    error::Error as ErrorTrait,
10    fmt,
11    io::{Read, Write},
12};
13
14use http::Version;
15use sha1::{Digest, Sha1};
16
17use self::machine::{HandshakeMachine, RoundResult, StageResult, TryParse};
18use crate::error::{Error, ProtocolError};
19
20#[derive(Debug)]
22pub struct MidHandshake<Role: HandshakeRole> {
23    role: Role,
24    machine: HandshakeMachine<Role::InternalStream>,
25}
26
27impl<Role: HandshakeRole> MidHandshake<Role> {
28    pub fn get_ref(&self) -> &HandshakeMachine<Role::InternalStream> {
30        &self.machine
31    }
32
33    pub fn get_mut(&mut self) -> &mut HandshakeMachine<Role::InternalStream> {
35        &mut self.machine
36    }
37
38    pub fn handshake(mut self) -> Result<Role::FinalResult, HandshakeError<Role>> {
40        let mut mach = self.machine;
41        loop {
42            mach = match mach.single_round()? {
43                RoundResult::WouldBlock(m) => {
44                    return Err(HandshakeError::Interrupted(MidHandshake { machine: m, ..self }))
45                }
46                RoundResult::Incomplete(m) => m,
47                RoundResult::StageFinished(s) => match self.role.stage_finished(s)? {
48                    ProcessingResult::Continue(m) => m,
49                    ProcessingResult::Done(result) => return Ok(result),
50                },
51            }
52        }
53    }
54}
55
56pub enum HandshakeError<Role: HandshakeRole> {
58    Interrupted(MidHandshake<Role>),
60    Failure(Error),
62}
63
64impl<Role: HandshakeRole> fmt::Debug for HandshakeError<Role> {
65    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
66        match *self {
67            HandshakeError::Interrupted(_) => write!(f, "HandshakeError::Interrupted(...)"),
68            HandshakeError::Failure(ref e) => write!(f, "HandshakeError::Failure({e:?})"),
69        }
70    }
71}
72
73impl<Role: HandshakeRole> fmt::Display for HandshakeError<Role> {
74    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
75        match *self {
76            HandshakeError::Interrupted(_) => write!(f, "Interrupted handshake (WouldBlock)"),
77            HandshakeError::Failure(ref e) => write!(f, "{e}"),
78        }
79    }
80}
81
82impl<Role: HandshakeRole> ErrorTrait for HandshakeError<Role> {}
83
84impl<Role: HandshakeRole> From<Error> for HandshakeError<Role> {
85    fn from(err: Error) -> Self {
86        HandshakeError::Failure(err)
87    }
88}
89
90pub trait HandshakeRole {
92    #[doc(hidden)]
93    type IncomingData: TryParse;
94    #[doc(hidden)]
95    type InternalStream: Read + Write;
96    #[doc(hidden)]
97    type FinalResult;
98    #[doc(hidden)]
99    fn stage_finished(
100        &mut self,
101        finish: StageResult<Self::IncomingData, Self::InternalStream>,
102    ) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>, Error>;
103}
104
105#[doc(hidden)]
107#[derive(Debug)]
108pub enum ProcessingResult<Stream, FinalResult> {
109    Continue(HandshakeMachine<Stream>),
110    Done(FinalResult),
111}
112
113pub fn derive_accept_key(request_key: &[u8]) -> String {
118    const WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
121    let mut sha1 = Sha1::default();
122    sha1.update(request_key);
123    sha1.update(WS_GUID);
124    data_encoding::BASE64.encode(&sha1.finalize())
125}
126
127fn version_as_str(ver: Version) -> crate::Result<&'static str> {
128    match ver {
129        ver if ver == Version::HTTP_09 => Ok("HTTP/0.9"),
130        ver if ver == Version::HTTP_10 => Ok("HTTP/1.0"),
131        ver if ver == Version::HTTP_11 => Ok("HTTP/1.1"),
132        _ => Err(Error::Protocol(ProtocolError::WrongHttpVersion)),
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::derive_accept_key;

    /// Uses the sample key/accept pair from RFC 6455, section 1.3.
    #[test]
    fn key_conversion() {
        let client_key = b"dGhlIHNhbXBsZSBub25jZQ==";
        let expected_accept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=";
        assert_eq!(derive_accept_key(client_key), expected_accept);
    }
}