1use std::collections::HashMap;
6use std::net::{Ipv4Addr, Ipv6Addr};
7use std::num::NonZeroU64;
8use std::sync::LazyLock;
9use std::time::Duration;
10
11use embedder_traits::resources::{self, Resource};
12use fst::{Map, MapBuilder};
13use headers::{HeaderMapExt, StrictTransportSecurity};
14use http::HeaderMap;
15use log::{debug, error, info};
16use malloc_size_of::{MallocSizeOf, MallocSizeOfOps};
17use malloc_size_of_derive::MallocSizeOf;
18use net_traits::IncludeSubdomains;
19use net_traits::pub_domains::reg_suffix;
20use serde::{Deserialize, Serialize};
21use servo_config::pref;
22use servo_url::{Host, ServoUrl};
23use time::UtcDateTime;
24
/// A single HSTS policy record for one host.
#[derive(Clone, Debug, Deserialize, MallocSizeOf, Serialize)]
pub struct HstsEntry {
    /// The host name the policy applies to.
    pub host: String,
    /// Whether the policy also covers every subdomain of `host`.
    pub include_subdomains: bool,
    /// Expiry time as a Unix timestamp in seconds; `None` means the entry never expires.
    pub expires_at: Option<NonZeroU64>,
}
32
/// Converts a Unix timestamp (seconds) into a `NonZeroU64`.
///
/// Every non-positive timestamp is clamped to `1`, so the result can always be
/// stored in the niche-optimized `Option<NonZeroU64>` used by `HstsEntry`.
fn unix_timestamp_to_nonzerou64(timestamp: i64) -> NonZeroU64 {
    // Negative values fail the conversion and zero fails `NonZeroU64::new`;
    // both fall back to the smallest non-zero value.
    u64::try_from(timestamp)
        .ok()
        .and_then(NonZeroU64::new)
        .unwrap_or(NonZeroU64::MIN)
}
41
42impl HstsEntry {
43 pub fn new(
44 host: String,
45 subdomains: IncludeSubdomains,
46 max_age: Option<Duration>,
47 ) -> Option<HstsEntry> {
48 let expires_at = max_age.map(|duration| {
49 unix_timestamp_to_nonzerou64((UtcDateTime::now() + duration).unix_timestamp())
50 });
51 if host.parse::<Ipv4Addr>().is_ok() || host.parse::<Ipv6Addr>().is_ok() {
52 None
53 } else {
54 Some(HstsEntry {
55 host,
56 include_subdomains: (subdomains == IncludeSubdomains::Included),
57 expires_at,
58 })
59 }
60 }
61
62 pub fn is_expired(&self) -> bool {
63 match self.expires_at {
64 Some(timestamp) => {
65 unix_timestamp_to_nonzerou64(UtcDateTime::now().unix_timestamp()) >= timestamp
66 },
67 _ => false,
68 }
69 }
70
71 fn matches_domain(&self, host: &str) -> bool {
72 self.host == host
73 }
74
75 fn matches_subdomain(&self, host: &str) -> bool {
76 host.ends_with(&format!(".{}", self.host))
77 }
78}
79
/// HSTS entries kept in addition to the static preload list.
#[derive(Clone, Debug, Default, Deserialize, MallocSizeOf, Serialize)]
pub struct HstsList {
    /// Entries grouped by the `reg_suffix` of their host (see `HstsList::push`).
    pub entries_map: HashMap<String, Vec<HstsEntry>>,
}
85
/// The HSTS preload list, stored as an `fst::Map` keyed by host name.
///
/// `is_host_secure` treats an odd map value as "this policy also covers
/// subdomains" and any value as covering the exact host.
/// NOTE(review): confirm this encoding against the tool that generates the list.
#[derive(Clone, Debug)]
pub struct HstsPreloadList(pub fst::Map<Vec<u8>>);
93
/// Reports the heap memory used by the byte buffer backing the preload `fst::Map`.
impl MallocSizeOf for HstsPreloadList {
    // Measuring the buffer requires the unsafe raw-pointer `malloc_size_of` call below.
    #[allow(unsafe_code)]
    fn size_of(&self, ops: &mut malloc_size_of::MallocSizeOfOps) -> usize {
        // SAFETY: the pointer is the start of the `Vec<u8>` that owns the map's bytes,
        // i.e. the base of its heap allocation, and `self` keeps it alive for the call.
        // NOTE(review): for an empty Vec this is a dangling pointer; confirm that
        // `malloc_size_of` treats such pointers as size 0 rather than dereferencing them.
        unsafe { ops.malloc_size_of(self.0.as_fst().as_inner().as_ptr()) }
    }
}
100
/// Process-wide preload list, built by `HstsPreloadList::from_servo_preload`
/// the first time it is accessed.
static PRELOAD_LIST_ENTRIES: LazyLock<HstsPreloadList> =
    LazyLock::new(HstsPreloadList::from_servo_preload);
103
104pub fn hsts_preload_size_of(ops: &mut MallocSizeOfOps) -> usize {
105 PRELOAD_LIST_ENTRIES.size_of(ops)
106}
107
108impl HstsPreloadList {
109 pub fn from_preload(preload_content: Vec<u8>) -> Option<HstsPreloadList> {
111 Map::new(preload_content).map(HstsPreloadList).ok()
112 }
113
114 pub fn from_servo_preload() -> HstsPreloadList {
115 debug!("Intializing HSTS Preload list");
116 let map_bytes = resources::read_bytes(Resource::HstsPreloadList);
117 HstsPreloadList::from_preload(map_bytes).unwrap_or_else(|| {
118 error!("HSTS preload file is invalid. Setting HSTS list to default values");
119 HstsPreloadList(MapBuilder::memory().into_map())
120 })
121 }
122
123 pub fn is_host_secure(&self, host: &str) -> bool {
124 let base_domain = reg_suffix(host);
125 let parts = host[..host.len() - base_domain.len()].rsplit_terminator('.');
126 let mut domain_to_test = base_domain.to_owned();
127
128 if self.0.get(&domain_to_test).is_some_and(|id| {
129 id % 2 == 1 || domain_to_test == host
131 }) {
132 return true;
133 }
134
135 for part in parts {
137 domain_to_test = format!("{}.{}", part, domain_to_test);
138 if self.0.get(&domain_to_test).is_some_and(|id| {
139 id % 2 == 1 || domain_to_test == host
141 }) {
142 return true;
143 }
144 }
145 false
146 }
147}
148
149impl HstsList {
150 pub fn is_host_secure(&self, host: &str) -> bool {
151 if PRELOAD_LIST_ENTRIES.is_host_secure(host) {
152 info!("{host} is in the preload list");
153 return true;
154 }
155
156 let base_domain = reg_suffix(host);
157 self.entries_map.get(base_domain).is_some_and(|entries| {
158 entries.iter().filter(|e| !e.is_expired()).any(|e| {
159 if e.include_subdomains {
160 e.matches_subdomain(host) || e.matches_domain(host)
161 } else {
162 e.matches_domain(host)
163 }
164 })
165 })
166 }
167
168 fn has_domain(&self, host: &str, base_domain: &str) -> bool {
169 self.entries_map
170 .get(base_domain)
171 .is_some_and(|entries| entries.iter().any(|e| e.matches_domain(host)))
172 }
173
174 fn has_subdomain(&self, host: &str, base_domain: &str) -> bool {
175 self.entries_map.get(base_domain).is_some_and(|entries| {
176 entries
177 .iter()
178 .any(|e| e.include_subdomains && e.matches_subdomain(host))
179 })
180 }
181
182 pub fn push(&mut self, entry: HstsEntry) {
183 let host = entry.host.clone();
184 let base_domain = reg_suffix(&host);
185 let have_domain = self.has_domain(&entry.host, base_domain);
186 let have_subdomain = self.has_subdomain(&entry.host, base_domain);
187
188 let entries = self.entries_map.entry(base_domain.to_owned()).or_default();
189 if !have_domain && !have_subdomain {
190 entries.push(entry);
191 } else if !have_subdomain {
192 for e in entries.iter_mut() {
193 if e.matches_domain(&entry.host) {
194 e.include_subdomains = entry.include_subdomains;
195 e.expires_at = entry.expires_at;
196 }
197 }
198 }
199 entries.retain(|e| !e.is_expired());
200 }
201
202 pub fn apply_hsts_rules(&self, url: &mut ServoUrl) {
204 if url.scheme() != "http" && url.scheme() != "ws" {
205 return;
206 }
207
208 let upgrade_scheme = if pref!(network_enforce_tls_enabled) {
209 if (!pref!(network_enforce_tls_localhost) &&
210 match url.host() {
211 Some(Host::Domain(domain)) => {
212 domain.ends_with(".localhost") || domain == "localhost"
213 },
214 Some(Host::Ipv4(ipv4)) => ipv4.is_loopback(),
215 Some(Host::Ipv6(ipv6)) => ipv6.is_loopback(),
216 _ => false,
217 }) ||
218 (!pref!(network_enforce_tls_onion) &&
219 url.domain()
220 .is_some_and(|domain| domain.ends_with(".onion")))
221 {
222 url.domain()
223 .is_some_and(|domain| self.is_host_secure(domain))
224 } else {
225 true
226 }
227 } else {
228 url.domain()
229 .is_some_and(|domain| self.is_host_secure(domain))
230 };
231
232 if upgrade_scheme {
233 let upgraded_scheme = match url.scheme() {
234 "ws" => "wss",
235 _ => "https",
236 };
237 url.as_mut_url().set_scheme(upgraded_scheme).unwrap();
238 }
239 }
240
241 pub fn update_hsts_list_from_response(&mut self, url: &ServoUrl, headers: &HeaderMap) {
242 if url.scheme() != "https" && url.scheme() != "wss" {
243 return;
244 }
245
246 if let Some(header) = headers.typed_get::<StrictTransportSecurity>() {
247 if let Some(host) = url.domain() {
248 let include_subdomains = if header.include_subdomains() {
249 IncludeSubdomains::Included
250 } else {
251 IncludeSubdomains::NotIncluded
252 };
253
254 if let Some(entry) =
255 HstsEntry::new(host.to_owned(), include_subdomains, Some(header.max_age()))
256 {
257 info!("adding host {} to the strict transport security list", host);
258 info!("- max-age {}", header.max_age().as_secs());
259 if header.include_subdomains() {
260 info!("- includeSubdomains");
261 }
262
263 self.push(entry);
264 }
265 }
266 }
267 }
268}