1use std::{
53    cell::{Ref, RefMut},
54    rc::Rc,
55};
56
57use actix_http::{header, Extensions, Method as HttpMethod, RequestHead};
58
59use crate::{http::header::Header, service::ServiceRequest, HttpMessage as _};
60
61mod acceptable;
62mod host;
63
64pub use self::{
65    acceptable::Acceptable,
66    host::{Host, HostGuard},
67};
68
/// Read-only view of an incoming request handed to [`Guard::check`].
///
/// Wraps a borrowed [`ServiceRequest`] and exposes the parts that are useful while deciding
/// whether a route matches (head, headers, request extensions, app data).
#[derive(Debug)]
pub struct GuardContext<'a> {
    pub(crate) req: &'a ServiceRequest,
}
74
75impl<'a> GuardContext<'a> {
76    #[inline]
78    pub fn head(&self) -> &RequestHead {
79        self.req.head()
80    }
81
82    #[inline]
84    pub fn req_data(&self) -> Ref<'a, Extensions> {
85        self.req.extensions()
86    }
87
88    #[inline]
90    pub fn req_data_mut(&self) -> RefMut<'a, Extensions> {
91        self.req.extensions_mut()
92    }
93
94    #[inline]
110    pub fn header<H: Header>(&self) -> Option<H> {
111        H::parse(self.req).ok()
112    }
113
114    #[inline]
116    pub fn app_data<T: 'static>(&self) -> Option<&T> {
117        self.req.app_data()
118    }
119}
120
/// Interface for routing guards.
///
/// A guard inspects a request through a [`GuardContext`] and reports whether it matches.
/// Besides implementing this trait directly, any closure of type
/// `Fn(&GuardContext<'_>) -> bool` implements it through a blanket impl, and [`fn_guard`]
/// wraps such a closure explicitly.
pub trait Guard {
    /// Returns `true` if the request described by `ctx` satisfies this guard.
    fn check(&self, ctx: &GuardContext<'_>) -> bool;
}
128
129impl Guard for Rc<dyn Guard> {
130    fn check(&self, ctx: &GuardContext<'_>) -> bool {
131        (**self).check(ctx)
132    }
133}
134
135pub fn fn_guard<F>(f: F) -> impl Guard
148where
149    F: Fn(&GuardContext<'_>) -> bool,
150{
151    FnGuard(f)
152}
153
/// Adapter that implements [`Guard`] for a closure; constructed by [`fn_guard`].
///
/// The `Fn(&GuardContext<'_>) -> bool` requirement lives on the `Guard` impl, not on the
/// struct, so the type itself stays unconstrained.
struct FnGuard<F>(F);
155
156impl<F> Guard for FnGuard<F>
157where
158    F: Fn(&GuardContext<'_>) -> bool,
159{
160    fn check(&self, ctx: &GuardContext<'_>) -> bool {
161        (self.0)(ctx)
162    }
163}
164
165impl<F> Guard for F
166where
167    F: Fn(&GuardContext<'_>) -> bool,
168{
169    fn check(&self, ctx: &GuardContext<'_>) -> bool {
170        (self)(ctx)
171    }
172}
173
174#[allow(non_snake_case)]
188pub fn Any<F: Guard + 'static>(guard: F) -> AnyGuard {
189    AnyGuard {
190        guards: vec![Box::new(guard)],
191    }
192}
193
/// A collection of guards that matches if any one of them matches (logical OR).
///
/// Created with [`Any`] and extended with [`AnyGuard::or`].
pub struct AnyGuard {
    // Alternatives in insertion order; boxed so guards of different types can be mixed.
    guards: Vec<Box<dyn Guard>>,
}
202
203impl AnyGuard {
204    pub fn or<F: Guard + 'static>(mut self, guard: F) -> Self {
206        self.guards.push(Box::new(guard));
207        self
208    }
209}
210
211impl Guard for AnyGuard {
212    #[inline]
213    fn check(&self, ctx: &GuardContext<'_>) -> bool {
214        for guard in &self.guards {
215            if guard.check(ctx) {
216                return true;
217            }
218        }
219
220        false
221    }
222}
223
224#[allow(non_snake_case)]
240pub fn All<F: Guard + 'static>(guard: F) -> AllGuard {
241    AllGuard {
242        guards: vec![Box::new(guard)],
243    }
244}
245
/// A collection of guards that matches only if all of them match (logical AND).
///
/// Created with [`All`] and extended with [`AllGuard::and`].
pub struct AllGuard {
    // Requirements in insertion order; boxed so guards of different types can be mixed.
    guards: Vec<Box<dyn Guard>>,
}
254
255impl AllGuard {
256    pub fn and<F: Guard + 'static>(mut self, guard: F) -> Self {
258        self.guards.push(Box::new(guard));
259        self
260    }
261}
262
263impl Guard for AllGuard {
264    #[inline]
265    fn check(&self, ctx: &GuardContext<'_>) -> bool {
266        for guard in &self.guards {
267            if !guard.check(ctx) {
268                return false;
269            }
270        }
271
272        true
273    }
274}
275
/// Wraps a guard and inverts its result: `Not(g)` matches exactly when `g` does not.
///
/// Nesting is allowed; `Not(Not(g))` behaves like `g`.
pub struct Not<G>(pub G);
288
289impl<G: Guard> Guard for Not<G> {
290    #[inline]
291    fn check(&self, ctx: &GuardContext<'_>) -> bool {
292        !self.0.check(ctx)
293    }
294}
295
/// Creates a guard that matches if the request method equals `method`.
///
/// Each check also records `method` in the request's extensions (see `MethodGuard`), whether
/// or not it matches.
#[allow(non_snake_case)] // PascalCase constructor name is part of the public API.
pub fn Method(method: HttpMethod) -> impl Guard {
    MethodGuard(method)
}
301
/// Request extension listing the HTTP methods that `MethodGuard`s have checked against the
/// current request, in the order they were first checked.
///
/// NOTE(review): presumably read elsewhere in the crate (e.g. to build an `Allow` header when
/// no route matches) — confirm against callers.
#[derive(Debug, Clone)]
pub(crate) struct RegisteredMethods(pub(crate) Vec<HttpMethod>);
304
/// Guard matching a single HTTP method; built by [`Method`] and by the per-method helpers
/// generated with `method_guard!` (e.g. [`Get`]).
#[derive(Debug)]
pub(crate) struct MethodGuard(HttpMethod);
308
309impl Guard for MethodGuard {
310    fn check(&self, ctx: &GuardContext<'_>) -> bool {
311        let registered = ctx.req_data_mut().remove::<RegisteredMethods>();
312
313        if let Some(mut methods) = registered {
314            methods.0.push(self.0.clone());
315            ctx.req_data_mut().insert(methods);
316        } else {
317            ctx.req_data_mut()
318                .insert(RegisteredMethods(vec![self.0.clone()]));
319        }
320
321        ctx.head().method == self.0
322    }
323}
324
// Generates a public, PascalCase guard constructor for one standard HTTP method.
//
// `$method_fn` is the name of the generated function and `$method_const` is the matching
// associated constant on `HttpMethod` (e.g. `Get` / `GET`). The generated docs interleave
// `#[doc = ...]` attributes with `///` lines so the method name appears in the example.
macro_rules! method_guard {
    ($method_fn:ident, $method_const:ident) => {
        #[doc = concat!("Creates a guard that matches the `", stringify!($method_const), "` request method.")]
        ///
        /// # Examples
        #[doc = concat!("The route in this example will only respond to `", stringify!($method_const), "` requests.")]
        /// ```
        /// use actix_web::{guard, web, HttpResponse};
        ///
        /// web::route()
        #[doc = concat!("    .guard(guard::", stringify!($method_fn), "())")]
        ///     .to(|| HttpResponse::Ok());
        /// ```
        #[allow(non_snake_case)]
        pub fn $method_fn() -> impl Guard {
            MethodGuard(HttpMethod::$method_const)
        }
    };
}
344
// One guard constructor per standard HTTP method: `Get()`, `Post()`, ... `Trace()`.
method_guard!(Get, GET);
method_guard!(Post, POST);
method_guard!(Put, PUT);
method_guard!(Delete, DELETE);
method_guard!(Head, HEAD);
method_guard!(Options, OPTIONS);
method_guard!(Connect, CONNECT);
method_guard!(Patch, PATCH);
method_guard!(Trace, TRACE);
354
355#[allow(non_snake_case)]
368pub fn Header(name: &'static str, value: &'static str) -> impl Guard {
369    HeaderGuard(
370        header::HeaderName::try_from(name).unwrap(),
371        header::HeaderValue::from_static(value),
372    )
373}
374
/// Guard matching one header name/value pair; built by [`Header`].
struct HeaderGuard(header::HeaderName, header::HeaderValue);
376
377impl Guard for HeaderGuard {
378    fn check(&self, ctx: &GuardContext<'_>) -> bool {
379        if let Some(val) = ctx.head().headers.get(&self.0) {
380            return val == self.1;
381        }
382
383        false
384    }
385}
386
#[cfg(test)]
mod tests {
    use actix_http::Method;

    use super::*;
    use crate::test::TestRequest;

    #[test]
    fn header_match() {
        let req = TestRequest::default()
            .insert_header((header::TRANSFER_ENCODING, "chunked"))
            .to_srv_request();

        let hdr = Header("transfer-encoding", "chunked");
        assert!(hdr.check(&req.guard_ctx()));

        let hdr = Header("transfer-encoding", "other");
        assert!(!hdr.check(&req.guard_ctx()));

        let hdr = Header("content-type", "chunked");
        assert!(!hdr.check(&req.guard_ctx()));

        let hdr = Header("content-type", "other");
        assert!(!hdr.check(&req.guard_ctx()));
    }

    #[test]
    fn method_guards() {
        let get_req = TestRequest::get().to_srv_request();
        let post_req = TestRequest::post().to_srv_request();

        assert!(Get().check(&get_req.guard_ctx()));
        assert!(!Get().check(&post_req.guard_ctx()));

        assert!(Post().check(&post_req.guard_ctx()));
        assert!(!Post().check(&get_req.guard_ctx()));

        let req = TestRequest::put().to_srv_request();
        assert!(Put().check(&req.guard_ctx()));
        assert!(!Put().check(&get_req.guard_ctx()));

        let req = TestRequest::patch().to_srv_request();
        assert!(Patch().check(&req.guard_ctx()));
        assert!(!Patch().check(&get_req.guard_ctx()));

        let r = TestRequest::delete().to_srv_request();
        assert!(Delete().check(&r.guard_ctx()));
        assert!(!Delete().check(&get_req.guard_ctx()));

        let req = TestRequest::default().method(Method::HEAD).to_srv_request();
        assert!(Head().check(&req.guard_ctx()));
        assert!(!Head().check(&get_req.guard_ctx()));

        let req = TestRequest::default()
            .method(Method::OPTIONS)
            .to_srv_request();
        assert!(Options().check(&req.guard_ctx()));
        assert!(!Options().check(&get_req.guard_ctx()));

        let req = TestRequest::default()
            .method(Method::CONNECT)
            .to_srv_request();
        assert!(Connect().check(&req.guard_ctx()));
        assert!(!Connect().check(&get_req.guard_ctx()));

        let req = TestRequest::default()
            .method(Method::TRACE)
            .to_srv_request();
        assert!(Trace().check(&req.guard_ctx()));
        assert!(!Trace().check(&get_req.guard_ctx()));
    }

    #[test]
    fn method_guards_record_checked_methods() {
        let req = TestRequest::get().to_srv_request();

        // Both matching and non-matching checks are recorded, in check order.
        assert!(Get().check(&req.guard_ctx()));
        assert!(!Post().check(&req.guard_ctx()));

        let registered = req
            .guard_ctx()
            .req_data()
            .get::<RegisteredMethods>()
            .map(|methods| methods.0.clone());

        assert_eq!(registered, Some(vec![Method::GET, Method::POST]));
    }

    #[test]
    fn aggregate_any() {
        let req = TestRequest::default()
            .method(Method::TRACE)
            .to_srv_request();

        assert!(Any(Trace()).check(&req.guard_ctx()));
        assert!(Any(Trace()).or(Get()).check(&req.guard_ctx()));
        assert!(!Any(Get()).or(Get()).check(&req.guard_ctx()));
    }

    #[test]
    fn aggregate_all() {
        let req = TestRequest::default()
            .method(Method::TRACE)
            .to_srv_request();

        assert!(All(Trace()).check(&req.guard_ctx()));
        assert!(All(Trace()).and(Trace()).check(&req.guard_ctx()));
        assert!(!All(Trace()).and(Get()).check(&req.guard_ctx()));
    }

    #[test]
    fn nested_not() {
        let req = TestRequest::default().to_srv_request();

        let get = Get();
        assert!(get.check(&req.guard_ctx()));

        let not_get = Not(get);
        assert!(!not_get.check(&req.guard_ctx()));

        let not_not_get = Not(not_get);
        assert!(not_not_get.check(&req.guard_ctx()));
    }

    #[test]
    fn function_guard() {
        let domain = "rust-lang.org".to_owned();
        let guard = fn_guard(|ctx| ctx.head().uri.host().unwrap().ends_with(&domain));

        let req = TestRequest::default()
            .uri("blog.rust-lang.org")
            .to_srv_request();
        assert!(guard.check(&req.guard_ctx()));

        let req = TestRequest::default().uri("crates.io").to_srv_request();
        assert!(!guard.check(&req.guard_ctx()));
    }

    #[test]
    fn mega_nesting() {
        let guard = fn_guard(|ctx| All(Not(Any(Not(Trace())))).check(ctx));

        let req = TestRequest::default().to_srv_request();
        assert!(!guard.check(&req.guard_ctx()));

        let req = TestRequest::default()
            .method(Method::TRACE)
            .to_srv_request();
        assert!(guard.check(&req.guard_ctx()));
    }

    #[test]
    fn app_data() {
        const TEST_VALUE: u32 = 42;
        let guard = fn_guard(|ctx| ctx.app_data::<u32>() == Some(&TEST_VALUE));

        let req = TestRequest::default().app_data(TEST_VALUE).to_srv_request();
        assert!(guard.check(&req.guard_ctx()));

        let req = TestRequest::default()
            .app_data(TEST_VALUE * 2)
            .to_srv_request();
        assert!(!guard.check(&req.guard_ctx()));
    }
}
535}