1use std::fmt;
2use std::iter::FromIterator;
3use std::str::FromStr;
4use std::time::Duration;
5
6use http::{HeaderName, HeaderValue};
7
8use crate::util::{self, csv, Seconds};
9use crate::{Error, Header};
10
11#[derive(PartialEq, Clone, Debug)]
40pub struct CacheControl {
41 flags: Flags,
42 max_age: Option<Seconds>,
43 max_stale: Option<Seconds>,
44 min_fresh: Option<Seconds>,
45 s_max_age: Option<Seconds>,
46}
47
/// Bit set recording which argument-less cache directives are present.
///
/// Every associated constant occupies exactly one distinct bit of `bits`.
#[derive(Debug, Clone, PartialEq)]
struct Flags {
    bits: u64,
}

impl Flags {
    const NO_CACHE: Self = Self::bit(0);
    const NO_STORE: Self = Self::bit(1);
    const NO_TRANSFORM: Self = Self::bit(2);
    const ONLY_IF_CACHED: Self = Self::bit(3);
    const MUST_REVALIDATE: Self = Self::bit(4);
    const PUBLIC: Self = Self::bit(5);
    const PRIVATE: Self = Self::bit(6);
    const PROXY_REVALIDATE: Self = Self::bit(7);
    const IMMUTABLE: Self = Self::bit(8);
    const MUST_UNDERSTAND: Self = Self::bit(9);

    /// Builds a flag value with only bit `index` set.
    const fn bit(index: u32) -> Self {
        Flags { bits: 1 << index }
    }

    /// Returns a set with no flags.
    fn empty() -> Self {
        Flags { bits: 0 }
    }

    /// Returns `true` if any bit of `flag` is also set in `self`.
    fn contains(&self, flag: Self) -> bool {
        (flag.bits & self.bits) != 0
    }

    /// Adds every bit of `flag` to `self`.
    fn insert(&mut self, flag: Self) {
        self.bits |= flag.bits;
    }
}
77
78impl CacheControl {
79 pub fn new() -> Self {
81 CacheControl {
82 flags: Flags::empty(),
83 max_age: None,
84 max_stale: None,
85 min_fresh: None,
86 s_max_age: None,
87 }
88 }
89
90 pub fn no_cache(&self) -> bool {
94 self.flags.contains(Flags::NO_CACHE)
95 }
96
97 pub fn no_store(&self) -> bool {
99 self.flags.contains(Flags::NO_STORE)
100 }
101
102 pub fn no_transform(&self) -> bool {
104 self.flags.contains(Flags::NO_TRANSFORM)
105 }
106
107 pub fn only_if_cached(&self) -> bool {
109 self.flags.contains(Flags::ONLY_IF_CACHED)
110 }
111
112 pub fn public(&self) -> bool {
114 self.flags.contains(Flags::PUBLIC)
115 }
116
117 pub fn private(&self) -> bool {
119 self.flags.contains(Flags::PRIVATE)
120 }
121
122 pub fn immutable(&self) -> bool {
124 self.flags.contains(Flags::IMMUTABLE)
125 }
126
127 pub fn must_revalidate(&self) -> bool {
129 self.flags.contains(Flags::MUST_REVALIDATE)
130 }
131
132 pub fn must_understand(&self) -> bool {
134 self.flags.contains(Flags::MUST_UNDERSTAND)
135 }
136
137 pub fn max_age(&self) -> Option<Duration> {
139 self.max_age.map(Into::into)
140 }
141
142 pub fn max_stale(&self) -> Option<Duration> {
144 self.max_stale.map(Into::into)
145 }
146
147 pub fn min_fresh(&self) -> Option<Duration> {
149 self.min_fresh.map(Into::into)
150 }
151
152 pub fn s_max_age(&self) -> Option<Duration> {
154 self.s_max_age.map(Into::into)
155 }
156
157 pub fn with_no_cache(mut self) -> Self {
161 self.flags.insert(Flags::NO_CACHE);
162 self
163 }
164
165 pub fn with_no_store(mut self) -> Self {
167 self.flags.insert(Flags::NO_STORE);
168 self
169 }
170
171 pub fn with_no_transform(mut self) -> Self {
173 self.flags.insert(Flags::NO_TRANSFORM);
174 self
175 }
176
177 pub fn with_only_if_cached(mut self) -> Self {
179 self.flags.insert(Flags::ONLY_IF_CACHED);
180 self
181 }
182
183 pub fn with_private(mut self) -> Self {
185 self.flags.insert(Flags::PRIVATE);
186 self
187 }
188
189 pub fn with_public(mut self) -> Self {
191 self.flags.insert(Flags::PUBLIC);
192 self
193 }
194
195 pub fn with_immutable(mut self) -> Self {
197 self.flags.insert(Flags::IMMUTABLE);
198 self
199 }
200
201 pub fn with_must_revalidate(mut self) -> Self {
203 self.flags.insert(Flags::MUST_REVALIDATE);
204 self
205 }
206
207 pub fn with_must_understand(mut self) -> Self {
209 self.flags.insert(Flags::MUST_UNDERSTAND);
210 self
211 }
212
213 pub fn with_max_age(mut self, duration: Duration) -> Self {
215 self.max_age = Some(duration.into());
216 self
217 }
218
219 pub fn with_max_stale(mut self, duration: Duration) -> Self {
221 self.max_stale = Some(duration.into());
222 self
223 }
224
225 pub fn with_min_fresh(mut self, duration: Duration) -> Self {
227 self.min_fresh = Some(duration.into());
228 self
229 }
230
231 pub fn with_s_max_age(mut self, duration: Duration) -> Self {
233 self.s_max_age = Some(duration.into());
234 self
235 }
236}
237
238impl Header for CacheControl {
239 fn name() -> &'static HeaderName {
240 &::http::header::CACHE_CONTROL
241 }
242
243 fn decode<'i, I: Iterator<Item = &'i HeaderValue>>(values: &mut I) -> Result<Self, Error> {
244 csv::from_comma_delimited(values).map(|FromIter(cc)| cc)
245 }
246
247 fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
248 values.extend(::std::iter::once(util::fmt(Fmt(self))));
249 }
250}
251
252struct FromIter(CacheControl);
254
255impl FromIterator<KnownDirective> for FromIter {
256 fn from_iter<I>(iter: I) -> Self
257 where
258 I: IntoIterator<Item = KnownDirective>,
259 {
260 let mut cc = CacheControl::new();
261
262 let iter = iter.into_iter().filter_map(|dir| match dir {
264 KnownDirective::Known(dir) => Some(dir),
265 KnownDirective::Unknown => None,
266 });
267
268 for directive in iter {
269 match directive {
270 Directive::NoCache => {
271 cc.flags.insert(Flags::NO_CACHE);
272 }
273 Directive::NoStore => {
274 cc.flags.insert(Flags::NO_STORE);
275 }
276 Directive::NoTransform => {
277 cc.flags.insert(Flags::NO_TRANSFORM);
278 }
279 Directive::OnlyIfCached => {
280 cc.flags.insert(Flags::ONLY_IF_CACHED);
281 }
282 Directive::MustRevalidate => {
283 cc.flags.insert(Flags::MUST_REVALIDATE);
284 }
285 Directive::MustUnderstand => {
286 cc.flags.insert(Flags::MUST_UNDERSTAND);
287 }
288 Directive::Public => {
289 cc.flags.insert(Flags::PUBLIC);
290 }
291 Directive::Private => {
292 cc.flags.insert(Flags::PRIVATE);
293 }
294 Directive::Immutable => {
295 cc.flags.insert(Flags::IMMUTABLE);
296 }
297 Directive::ProxyRevalidate => {
298 cc.flags.insert(Flags::PROXY_REVALIDATE);
299 }
300 Directive::MaxAge(secs) => {
301 cc.max_age = Some(Duration::from_secs(secs).into());
302 }
303 Directive::MaxStale(secs) => {
304 cc.max_stale = Some(Duration::from_secs(secs).into());
305 }
306 Directive::MinFresh(secs) => {
307 cc.min_fresh = Some(Duration::from_secs(secs).into());
308 }
309 Directive::SMaxAge(secs) => {
310 cc.s_max_age = Some(Duration::from_secs(secs).into());
311 }
312 }
313 }
314
315 FromIter(cc)
316 }
317}
318
319struct Fmt<'a>(&'a CacheControl);
320
321impl fmt::Display for Fmt<'_> {
322 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
323 let if_flag = |f: Flags, dir: Directive| {
324 if self.0.flags.contains(f) {
325 Some(dir)
326 } else {
327 None
328 }
329 };
330
331 let slice = &[
332 if_flag(Flags::NO_CACHE, Directive::NoCache),
333 if_flag(Flags::NO_STORE, Directive::NoStore),
334 if_flag(Flags::NO_TRANSFORM, Directive::NoTransform),
335 if_flag(Flags::ONLY_IF_CACHED, Directive::OnlyIfCached),
336 if_flag(Flags::MUST_REVALIDATE, Directive::MustRevalidate),
337 if_flag(Flags::PUBLIC, Directive::Public),
338 if_flag(Flags::PRIVATE, Directive::Private),
339 if_flag(Flags::IMMUTABLE, Directive::Immutable),
340 if_flag(Flags::MUST_UNDERSTAND, Directive::MustUnderstand),
341 if_flag(Flags::PROXY_REVALIDATE, Directive::ProxyRevalidate),
342 self.0
343 .max_age
344 .as_ref()
345 .map(|s| Directive::MaxAge(s.as_u64())),
346 self.0
347 .max_stale
348 .as_ref()
349 .map(|s| Directive::MaxStale(s.as_u64())),
350 self.0
351 .min_fresh
352 .as_ref()
353 .map(|s| Directive::MinFresh(s.as_u64())),
354 self.0
355 .s_max_age
356 .as_ref()
357 .map(|s| Directive::SMaxAge(s.as_u64())),
358 ];
359
360 let iter = slice.iter().filter_map(|o| *o);
361
362 csv::fmt_comma_delimited(f, iter)
363 }
364}
365
/// Outcome of parsing one comma-separated `Cache-Control` element.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum KnownDirective {
    /// A directive this module understands.
    Known(Directive),
    /// Any other directive (e.g. an extension); callers skip these.
    Unknown,
}

/// A `Cache-Control` directive this module can parse and emit.
///
/// Variants carrying a `u64` hold the directive's delta-seconds argument.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Directive {
    NoCache,
    NoStore,
    NoTransform,
    OnlyIfCached,

    MaxAge(u64),
    MaxStale(u64),
    MinFresh(u64),

    MustRevalidate,
    MustUnderstand,
    Public,
    Private,
    Immutable,
    ProxyRevalidate,
    SMaxAge(u64),
}

impl fmt::Display for Directive {
    /// Writes the directive in its lowercase wire form, e.g. `no-cache` or
    /// `max-age=60`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Display::fmt(
            match *self {
                Directive::NoCache => "no-cache",
                Directive::NoStore => "no-store",
                Directive::NoTransform => "no-transform",
                Directive::OnlyIfCached => "only-if-cached",

                Directive::MaxAge(secs) => return write!(f, "max-age={}", secs),
                Directive::MaxStale(secs) => return write!(f, "max-stale={}", secs),
                Directive::MinFresh(secs) => return write!(f, "min-fresh={}", secs),

                Directive::MustRevalidate => "must-revalidate",
                Directive::MustUnderstand => "must-understand",
                Directive::Public => "public",
                Directive::Private => "private",
                Directive::Immutable => "immutable",
                Directive::ProxyRevalidate => "proxy-revalidate",
                Directive::SMaxAge(secs) => return write!(f, "s-maxage={}", secs),
            },
            f,
        )
    }
}

impl FromStr for KnownDirective {
    type Err = ();

    /// Parses a single directive such as `no-cache` or `max-age=60`.
    ///
    /// Returns `Err(())` for an empty string, and for a recognized
    /// delta-seconds directive whose (optionally quoted) argument is not a
    /// decimal number. Unrecognized directives, and recognized names followed
    /// by an empty `=` argument, yield `Ok(KnownDirective::Unknown)`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Directive names are case-insensitive (RFC 9111 §5.2). Match against a
        // lowercased copy, and only allocate one when uppercase is present.
        // ASCII lowercasing keeps byte offsets, and digit arguments are unaffected.
        let lowered;
        let s = if s.bytes().any(|b| b.is_ascii_uppercase()) {
            lowered = s.to_ascii_lowercase();
            lowered.as_str()
        } else {
            s
        };

        Ok(KnownDirective::Known(match s {
            "no-cache" => Directive::NoCache,
            "no-store" => Directive::NoStore,
            "no-transform" => Directive::NoTransform,
            "only-if-cached" => Directive::OnlyIfCached,
            "must-revalidate" => Directive::MustRevalidate,
            "public" => Directive::Public,
            "private" => Directive::Private,
            "immutable" => Directive::Immutable,
            "must-understand" => Directive::MustUnderstand,
            "proxy-revalidate" => Directive::ProxyRevalidate,
            "" => return Err(()),
            _ => match s.find('=') {
                Some(idx) if idx + 1 < s.len() => {
                    match (&s[..idx], s[idx + 1..].trim_matches('"')) {
                        ("max-age", secs) => Directive::MaxAge(parse_delta_seconds(secs)?),
                        ("max-stale", secs) => Directive::MaxStale(parse_delta_seconds(secs)?),
                        ("min-fresh", secs) => Directive::MinFresh(parse_delta_seconds(secs)?),
                        ("s-maxage", secs) => Directive::SMaxAge(parse_delta_seconds(secs)?),
                        _unknown => return Ok(KnownDirective::Unknown),
                    }
                }
                Some(_) | None => return Ok(KnownDirective::Unknown),
            },
        }))
    }
}

/// Parses a delta-seconds argument.
///
/// Non-numeric input is an error. A run of decimal digits that is too large
/// for `u64` saturates to `u64::MAX` instead of failing, so one oversized
/// value does not discard the rest of the header.
fn parse_delta_seconds(secs: &str) -> Result<u64, ()> {
    match secs.parse::<u64>() {
        Ok(n) => Ok(n),
        // RFC 9111 §1.2.2: an unrepresentably large delta-seconds value MUST be
        // treated as the greatest integer the cache can represent.
        Err(_) if !secs.is_empty() && secs.bytes().all(|b| b.is_ascii_digit()) => Ok(u64::MAX),
        Err(_) => Err(()),
    }
}
456
#[cfg(test)]
mod tests {
    use super::super::{test_decode, test_encode};
    use super::*;

    /// Directives spread over several header values are merged.
    #[test]
    fn test_parse_multiple_headers() {
        let decoded = test_decode::<CacheControl>(&["no-cache", "private"]).unwrap();
        let expected = CacheControl::new().with_no_cache().with_private();
        assert_eq!(decoded, expected);
    }

    #[test]
    fn test_parse_argument() {
        let decoded = test_decode::<CacheControl>(&["max-age=100, private"]).unwrap();
        let expected = CacheControl::new()
            .with_max_age(Duration::from_secs(100))
            .with_private();
        assert_eq!(decoded, expected);
    }

    #[test]
    fn test_parse_quote_form() {
        let decoded = test_decode::<CacheControl>(&["max-age=\"200\""]).unwrap();
        let expected = CacheControl::new().with_max_age(Duration::from_secs(200));
        assert_eq!(decoded, expected);
    }

    #[test]
    fn test_parse_extension() {
        let decoded = test_decode::<CacheControl>(&["foo, no-cache, bar=baz"]).unwrap();
        let expected = CacheControl::new().with_no_cache();
        assert_eq!(
            decoded,
            expected,
            "unknown extensions are ignored but shouldn't fail parsing",
        );
    }

    #[test]
    fn test_immutable() {
        let built = CacheControl::new().with_immutable();
        let encoded = test_encode(built.clone());
        assert_eq!(encoded["cache-control"], "immutable");

        let decoded = test_decode::<CacheControl>(&["immutable"]).unwrap();
        assert_eq!(decoded, built);
        assert!(built.immutable());
    }

    #[test]
    fn test_must_revalidate() {
        let built = CacheControl::new().with_must_revalidate();
        let encoded = test_encode(built.clone());
        assert_eq!(encoded["cache-control"], "must-revalidate");

        let decoded = test_decode::<CacheControl>(&["must-revalidate"]).unwrap();
        assert_eq!(decoded, built);
        assert!(built.must_revalidate());
    }

    #[test]
    fn test_must_understand() {
        let built = CacheControl::new().with_must_understand();
        let encoded = test_encode(built.clone());
        assert_eq!(encoded["cache-control"], "must-understand");

        let decoded = test_decode::<CacheControl>(&["must-understand"]).unwrap();
        assert_eq!(decoded, built);
        assert!(built.must_understand());
    }

    /// A non-numeric delta-seconds argument rejects the whole header.
    #[test]
    fn test_parse_bad_syntax() {
        let decoded = test_decode::<CacheControl>(&["max-age=lolz"]);
        assert_eq!(decoded, None);
    }

    #[test]
    fn encode_one_flag_directive() {
        let encoded = test_encode(CacheControl::new().with_no_cache());
        assert_eq!(encoded["cache-control"], "no-cache");
    }

    #[test]
    fn encode_one_param_directive() {
        let encoded = test_encode(CacheControl::new().with_max_age(Duration::from_secs(300)));
        assert_eq!(encoded["cache-control"], "max-age=300");
    }

    #[test]
    fn encode_two_directive() {
        let flags_only = test_encode(CacheControl::new().with_no_cache().with_private());
        assert_eq!(flags_only["cache-control"], "no-cache, private");

        let flag_and_param = CacheControl::new()
            .with_no_cache()
            .with_max_age(Duration::from_secs(100));
        let encoded = test_encode(flag_and_param);
        assert_eq!(encoded["cache-control"], "no-cache, max-age=100");
    }
}