1use std::convert::TryFrom;
2use std::fmt;
3
4use bytes::Bytes;
5use http::uri::{self, Authority, Scheme, Uri};
6use http::HeaderValue;
7
8use crate::util::{IterExt, TryFromValues};
9use crate::Error;
10
11#[derive(Clone, Debug, PartialEq, Eq, Hash)]
29pub struct Origin(OriginOrNull);
30
31derive_header! {
32 Origin(_),
33 name: ORIGIN
34}
35
36#[derive(Clone, Debug, PartialEq, Eq, Hash)]
37enum OriginOrNull {
38 Origin(Scheme, Authority),
39 Null,
40}
41
42impl Origin {
43 pub const NULL: Origin = Origin(OriginOrNull::Null);
45
46 #[inline]
48 pub fn is_null(&self) -> bool {
49 matches!(self.0, OriginOrNull::Null)
50 }
51
52 #[inline]
54 pub fn scheme(&self) -> &str {
55 match self.0 {
56 OriginOrNull::Origin(ref scheme, _) => scheme.as_str(),
57 OriginOrNull::Null => "",
58 }
59 }
60
61 #[inline]
63 pub fn hostname(&self) -> &str {
64 match self.0 {
65 OriginOrNull::Origin(_, ref auth) => auth.host(),
66 OriginOrNull::Null => "",
67 }
68 }
69
70 #[inline]
72 pub fn port(&self) -> Option<u16> {
73 match self.0 {
74 OriginOrNull::Origin(_, ref auth) => auth.port_u16(),
75 OriginOrNull::Null => None,
76 }
77 }
78
79 pub fn try_from_parts(
81 scheme: &str,
82 host: &str,
83 port: impl Into<Option<u16>>,
84 ) -> Result<Self, InvalidOrigin> {
85 struct MaybePort(Option<u16>);
86
87 impl fmt::Display for MaybePort {
88 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
89 if let Some(port) = self.0 {
90 write!(f, ":{}", port)
91 } else {
92 Ok(())
93 }
94 }
95 }
96
97 let bytes = Bytes::from(format!("{}://{}{}", scheme, host, MaybePort(port.into())));
98 HeaderValue::from_maybe_shared(bytes)
99 .ok()
100 .and_then(|val| Self::try_from_value(&val))
101 .ok_or(InvalidOrigin { _inner: () })
102 }
103
104 pub(super) fn try_from_value(value: &HeaderValue) -> Option<Self> {
106 OriginOrNull::try_from_value(value).map(Origin)
107 }
108
109 pub(super) fn to_value(&self) -> HeaderValue {
110 (&self.0).into()
111 }
112}
113
114impl fmt::Display for Origin {
115 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
116 match self.0 {
117 OriginOrNull::Origin(ref scheme, ref auth) => write!(f, "{}://{}", scheme, auth),
118 OriginOrNull::Null => f.write_str("null"),
119 }
120 }
121}
122
123error_type!(InvalidOrigin);
124
125impl OriginOrNull {
126 fn try_from_value(value: &HeaderValue) -> Option<Self> {
127 if value == "null" {
128 return Some(OriginOrNull::Null);
129 }
130
131 let uri = Uri::try_from(value.as_bytes()).ok()?;
132
133 let (scheme, auth) = match uri.into_parts() {
134 uri::Parts {
135 scheme: Some(scheme),
136 authority: Some(auth),
137 path_and_query: None,
138 ..
139 } => (scheme, auth),
140 uri::Parts {
141 scheme: Some(ref scheme),
142 authority: Some(ref auth),
143 path_and_query: Some(ref p),
144 ..
145 } if p == "/" => (scheme.clone(), auth.clone()),
146 _ => {
147 return None;
148 }
149 };
150
151 Some(OriginOrNull::Origin(scheme, auth))
152 }
153}
154
155impl TryFromValues for OriginOrNull {
156 fn try_from_values<'i, I>(values: &mut I) -> Result<Self, Error>
157 where
158 I: Iterator<Item = &'i HeaderValue>,
159 {
160 values
161 .just_one()
162 .and_then(OriginOrNull::try_from_value)
163 .ok_or_else(Error::invalid)
164 }
165}
166
167impl<'a> From<&'a OriginOrNull> for HeaderValue {
168 fn from(origin: &'a OriginOrNull) -> HeaderValue {
169 match origin {
170 OriginOrNull::Origin(ref scheme, ref auth) => {
171 let s = format!("{}://{}", scheme, auth);
172 let bytes = Bytes::from(s);
173 HeaderValue::from_maybe_shared(bytes)
174 .expect("Scheme and Authority are valid header values")
175 }
176 OriginOrNull::Null => HeaderValue::from_static("null"),
179 }
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::super::{test_decode, test_encode};
    use super::*;

    /// A concrete origin decodes into its parts and encodes back unchanged.
    #[test]
    fn origin() {
        let raw = "http://web-platform.test:8000";

        let decoded = test_decode::<Origin>(&[raw]).unwrap();
        assert_eq!(decoded.scheme(), "http");
        assert_eq!(decoded.hostname(), "web-platform.test");
        assert_eq!(decoded.port(), Some(8000));

        let encoded = test_encode(decoded);
        assert_eq!(encoded["origin"], raw);
    }

    /// The literal `null` decodes to `Origin::NULL` and encodes back to `null`.
    #[test]
    fn null() {
        let decoded = test_decode::<Origin>(&["null"]);
        assert_eq!(decoded, Some(Origin::NULL));

        let encoded = test_encode(Origin::NULL);
        assert_eq!(encoded["origin"], "null");
    }
}