1use std::{fmt, str};
2
3use super::common_header;
4use crate::http::header;
5
common_header! {
    // `Cache-Control` header (registered under `header::CACHE_CONTROL`) holding a list
    // of `CacheDirective`s. NOTE(review): the trailing `+` presumably means at least
    // one directive is required — confirm against the `common_header!` definition.
    (CacheControl, header::CACHE_CONTROL) => (CacheDirective)+

    test_parse_and_format {
        // Zero header values, or a single empty value, yield no `CacheControl`.
        common_header_test!(no_headers, [b""; 0], None);
        common_header_test!(empty_header, [b""; 1], None);
        // `foo=` has nothing after `=`, which `CacheDirective::from_str` rejects,
        // so the whole header fails to parse.
        common_header_test!(bad_syntax, [b"foo="], None);

        // Two separate header values are combined into one directive list, in order.
        common_header_test!(
            multiple_headers,
            [&b"no-cache"[..], &b"private"[..]],
            Some(CacheControl(vec![
                CacheDirective::NoCache,
                CacheDirective::Private,
            ]))
        );

        // Comma-separated directives in a single value, one with a numeric argument.
        common_header_test!(
            argument,
            [b"max-age=100, private"],
            Some(CacheControl(vec![
                CacheDirective::MaxAge(100),
                CacheDirective::Private,
            ]))
        );

        // Unrecognised directives are kept as `Extension`, with or without an argument.
        common_header_test!(
            extension,
            [b"foo, bar=baz"],
            Some(CacheControl(vec![
                CacheDirective::Extension("foo".to_owned(), None),
                CacheDirective::Extension("bar".to_owned(), Some("baz".to_owned())),
            ]))
        );

        // The quoted form of an argument (`max-age="200"`) is accepted on parse.
        // NOTE(review): presumably written as a plain parse test because re-formatting
        // yields the unquoted `max-age=200`, which would not match the input.
        #[test]
        fn parse_quote_form() {
            let req = test::TestRequest::default()
                .insert_header((header::CACHE_CONTROL, "max-age=\"200\""))
                .finish();

            assert_eq!(
                Header::parse(&req).ok(),
                Some(CacheControl(vec![CacheDirective::MaxAge(200)]))
            )
        }
    }
}
94
/// One directive of a `Cache-Control` header value (see RFC 7234 §5.2).
///
/// Directives with a dedicated variant are listed below with their wire form;
/// any other directive is carried by [`CacheDirective::Extension`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CacheDirective {
    // -- argument-less directives --------------------------------------------
    /// `no-cache`
    NoCache,
    /// `no-store`
    NoStore,
    /// `no-transform`
    NoTransform,
    /// `only-if-cached`
    OnlyIfCached,

    // -- directives carrying a number of seconds -----------------------------
    /// `max-age=N`, where `N` is a number of seconds.
    MaxAge(u32),
    /// `max-stale=N`, where `N` is a number of seconds.
    MaxStale(u32),
    /// `min-fresh=N`, where `N` is a number of seconds.
    MinFresh(u32),

    // -- further directives ---------------------------------------------------
    /// `must-revalidate`
    MustRevalidate,
    /// `public`
    Public,
    /// `private`
    Private,
    /// `proxy-revalidate`
    ProxyRevalidate,
    /// `s-maxage=N`, where `N` is a number of seconds.
    SMaxAge(u32),

    /// Any other directive: its name and, when written as `name=arg`, its argument.
    Extension(String, Option<String>),
}
130
131impl fmt::Display for CacheDirective {
132    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
133        use self::CacheDirective::*;
134
135        let dir_str = match self {
136            NoCache => "no-cache",
137            NoStore => "no-store",
138            NoTransform => "no-transform",
139            OnlyIfCached => "only-if-cached",
140
141            MaxAge(secs) => return write!(f, "max-age={}", secs),
142            MaxStale(secs) => return write!(f, "max-stale={}", secs),
143            MinFresh(secs) => return write!(f, "min-fresh={}", secs),
144
145            MustRevalidate => "must-revalidate",
146            Public => "public",
147            Private => "private",
148            ProxyRevalidate => "proxy-revalidate",
149            SMaxAge(secs) => return write!(f, "s-maxage={}", secs),
150
151            Extension(name, None) => name.as_str(),
152            Extension(name, Some(arg)) => return write!(f, "{}={}", name, arg),
153        };
154
155        f.write_str(dir_str)
156    }
157}
158
159impl str::FromStr for CacheDirective {
160    type Err = Option<<u32 as str::FromStr>::Err>;
161
162    fn from_str(s: &str) -> Result<Self, Self::Err> {
163        use self::CacheDirective::*;
164
165        match s {
166            "" => Err(None),
167
168            "no-cache" => Ok(NoCache),
169            "no-store" => Ok(NoStore),
170            "no-transform" => Ok(NoTransform),
171            "only-if-cached" => Ok(OnlyIfCached),
172            "must-revalidate" => Ok(MustRevalidate),
173            "public" => Ok(Public),
174            "private" => Ok(Private),
175            "proxy-revalidate" => Ok(ProxyRevalidate),
176
177            _ => match s.find('=') {
178                Some(idx) if idx + 1 < s.len() => {
179                    match (&s[..idx], s[idx + 1..].trim_matches('"')) {
180                        ("max-age", secs) => secs.parse().map(MaxAge).map_err(Some),
181                        ("max-stale", secs) => secs.parse().map(MaxStale).map_err(Some),
182                        ("min-fresh", secs) => secs.parse().map(MinFresh).map_err(Some),
183                        ("s-maxage", secs) => secs.parse().map(SMaxAge).map_err(Some),
184                        (left, right) => Ok(Extension(left.to_owned(), Some(right.to_owned()))),
185                    }
186                }
187                Some(_) => Err(None),
188                None => Ok(Extension(s.to_owned(), None)),
189            },
190        }
191    }
192}