1use std::{
4    future::Future,
5    pin::Pin,
6    rc::Rc,
7    task::{Context, Poll},
8};
9
10use actix_service::{Service, Transform};
11use foldhash::HashMap as FoldHashMap;
12use futures_core::{future::LocalBoxFuture, ready};
13use pin_project_lite::pin_project;
14
15use crate::{
16    body::EitherBody,
17    dev::{ServiceRequest, ServiceResponse},
18    http::StatusCode,
19    Error, Result,
20};
21
/// Return type for custom handlers registered with [`ErrorHandlers`].
///
/// A handler either produces the (possibly modified) response immediately, or hands back a
/// future that resolves to it. In both cases the response body is an [`EitherBody`], so a
/// handler may keep the original body (left variant) or substitute a different one (right
/// variant).
pub enum ErrorHandlerResponse<B> {
    /// Immediate HTTP response.
    Response(ServiceResponse<EitherBody<B>>),

    /// A future that resolves to an HTTP response or an error; the middleware polls it to
    /// completion before yielding.
    Future(LocalBoxFuture<'static, Result<ServiceResponse<EitherBody<B>>, Error>>),
}
30
/// Type-erased error handler: receives the inner service's response and returns either a
/// replacement [`ErrorHandlerResponse`] or an error, which the middleware propagates as-is.
type ErrorHandler<B> = dyn Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>>;
32
/// Optional fallback handler. Stored behind `Rc` so a single handler can occupy both the
/// client- and server-error slots, and so it can be cloned cheaply into every middleware
/// instance and request future.
type DefaultHandler<B> = Option<Rc<ErrorHandler<B>>>;
34
/// Middleware for registering custom status-code-based error handlers.
///
/// When the wrapped service produces a response, a handler is chosen in this order:
///
/// 1. a handler registered for that exact status code via [`ErrorHandlers::handler`];
/// 2. for client errors (4xx), the client default set by [`ErrorHandlers::default_handler`]
///    or [`ErrorHandlers::default_handler_client`];
/// 3. for server errors (5xx), the server default set by [`ErrorHandlers::default_handler`]
///    or [`ErrorHandlers::default_handler_server`].
///
/// If no handler matches, the response is passed through unchanged (as the left variant of
/// [`EitherBody`]). Errors returned by the wrapped service itself are not routed through any
/// handler.
pub struct ErrorHandlers<B> {
    /// Fallback for 4xx responses that have no status-specific handler.
    default_client: DefaultHandler<B>,
    /// Fallback for 5xx responses that have no status-specific handler.
    default_server: DefaultHandler<B>,
    /// Status-specific handlers; the table is shared (via `Rc`) with every middleware instance
    /// created from this value.
    handlers: Handlers<B>,
}
187
/// Status-code-to-handler table. Wrapped in `Rc` so middleware instances and in-flight request
/// futures can share it without copying the boxed handlers.
type Handlers<B> = Rc<FoldHashMap<StatusCode, Box<ErrorHandler<B>>>>;
189
190impl<B> Default for ErrorHandlers<B> {
191    fn default() -> Self {
192        ErrorHandlers {
193            default_client: Default::default(),
194            default_server: Default::default(),
195            handlers: Default::default(),
196        }
197    }
198}
199
200impl<B> ErrorHandlers<B> {
201    pub fn new() -> Self {
203        ErrorHandlers::default()
204    }
205
206    pub fn handler<F>(mut self, status: StatusCode, handler: F) -> Self
208    where
209        F: Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> + 'static,
210    {
211        Rc::get_mut(&mut self.handlers)
212            .unwrap()
213            .insert(status, Box::new(handler));
214        self
215    }
216
217    pub fn default_handler<F>(self, handler: F) -> Self
230    where
231        F: Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> + 'static,
232    {
233        let handler = Rc::new(handler);
234        let handler2 = Rc::clone(&handler);
235        Self {
236            default_server: Some(handler2),
237            default_client: Some(handler),
238            ..self
239        }
240    }
241
242    pub fn default_handler_client<F>(self, handler: F) -> Self
244    where
245        F: Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> + 'static,
246    {
247        Self {
248            default_client: Some(Rc::new(handler)),
249            ..self
250        }
251    }
252
253    pub fn default_handler_server<F>(self, handler: F) -> Self
255    where
256        F: Fn(ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> + 'static,
257    {
258        Self {
259            default_server: Some(Rc::new(handler)),
260            ..self
261        }
262    }
263
264    fn get_handler<'a>(
269        status: &StatusCode,
270        default_client: Option<&'a ErrorHandler<B>>,
271        default_server: Option<&'a ErrorHandler<B>>,
272        handlers: &'a Handlers<B>,
273    ) -> Option<&'a ErrorHandler<B>> {
274        handlers
275            .get(status)
276            .map(|h| h.as_ref())
277            .or_else(|| status.is_client_error().then_some(default_client).flatten())
278            .or_else(|| status.is_server_error().then_some(default_server).flatten())
279    }
280}
281
282impl<S, B> Transform<S, ServiceRequest> for ErrorHandlers<B>
283where
284    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
285    S::Future: 'static,
286    B: 'static,
287{
288    type Response = ServiceResponse<EitherBody<B>>;
289    type Error = Error;
290    type Transform = ErrorHandlersMiddleware<S, B>;
291    type InitError = ();
292    type Future = LocalBoxFuture<'static, Result<Self::Transform, Self::InitError>>;
293
294    fn new_transform(&self, service: S) -> Self::Future {
295        let handlers = Rc::clone(&self.handlers);
296        let default_client = self.default_client.clone();
297        let default_server = self.default_server.clone();
298        Box::pin(async move {
299            Ok(ErrorHandlersMiddleware {
300                service,
301                default_client,
302                default_server,
303                handlers,
304            })
305        })
306    }
307}
308
/// Service produced by [`ErrorHandlers`]; wraps the inner service and routes its responses
/// through the registered error handlers.
#[doc(hidden)]
pub struct ErrorHandlersMiddleware<S, B> {
    /// The wrapped service.
    service: S,
    /// Fallback for 4xx responses without a status-specific handler.
    default_client: DefaultHandler<B>,
    /// Fallback for 5xx responses without a status-specific handler.
    default_server: DefaultHandler<B>,
    /// Status-specific handlers, shared with the `ErrorHandlers` that built this middleware.
    handlers: Handlers<B>,
}
316
317impl<S, B> Service<ServiceRequest> for ErrorHandlersMiddleware<S, B>
318where
319    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
320    S::Future: 'static,
321    B: 'static,
322{
323    type Response = ServiceResponse<EitherBody<B>>;
324    type Error = Error;
325    type Future = ErrorHandlersFuture<S::Future, B>;
326
327    actix_service::forward_ready!(service);
328
329    fn call(&self, req: ServiceRequest) -> Self::Future {
330        let handlers = Rc::clone(&self.handlers);
331        let default_client = self.default_client.clone();
332        let default_server = self.default_server.clone();
333        let fut = self.service.call(req);
334        ErrorHandlersFuture::ServiceFuture {
335            fut,
336            default_client,
337            default_server,
338            handlers,
339        }
340    }
341}
342
pin_project! {
    // Future returned by `ErrorHandlersMiddleware::call`.
    //
    // It starts in `ServiceFuture`, polling the inner service. If the resulting status maps to
    // a handler that returns `ErrorHandlerResponse::Future`, it switches to
    // `ErrorHandlerFuture` and polls that boxed future to completion.
    #[project = ErrorHandlersProj]
    pub enum ErrorHandlersFuture<Fut, B>
    where
        Fut: Future,
    {
        // Waiting on the wrapped service. It also carries the handler tables needed to pick a
        // handler once the response arrives.
        ServiceFuture {
            #[pin]
            fut: Fut,
            default_client: DefaultHandler<B>,
            default_server: DefaultHandler<B>,
            handlers: Handlers<B>,
        },
        // Waiting on the future returned by an async error handler. It is already a
        // `Pin<Box<..>>`, so it needs no structural pinning.
        ErrorHandlerFuture {
            fut: LocalBoxFuture<'static, Result<ServiceResponse<EitherBody<B>>, Error>>,
        },
    }
}
361
362impl<Fut, B> Future for ErrorHandlersFuture<Fut, B>
363where
364    Fut: Future<Output = Result<ServiceResponse<B>, Error>>,
365{
366    type Output = Result<ServiceResponse<EitherBody<B>>, Error>;
367
368    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
369        match self.as_mut().project() {
370            ErrorHandlersProj::ServiceFuture {
371                fut,
372                default_client,
373                default_server,
374                handlers,
375            } => {
376                let res = ready!(fut.poll(cx))?;
377                let status = res.status();
378
379                let handler = ErrorHandlers::get_handler(
380                    &status,
381                    default_client.as_mut().map(|f| Rc::as_ref(f)),
382                    default_server.as_mut().map(|f| Rc::as_ref(f)),
383                    handlers,
384                );
385                match handler {
386                    Some(handler) => match handler(res)? {
387                        ErrorHandlerResponse::Response(res) => Poll::Ready(Ok(res)),
388                        ErrorHandlerResponse::Future(fut) => {
389                            self.as_mut()
390                                .set(ErrorHandlersFuture::ErrorHandlerFuture { fut });
391
392                            self.poll(cx)
393                        }
394                    },
395                    None => Poll::Ready(Ok(res.map_into_left_body())),
396                }
397            }
398
399            ErrorHandlersProj::ErrorHandlerFuture { fut } => fut.as_mut().poll(cx),
400        }
401    }
402}
403
// Unit tests for `ErrorHandlers`: synchronous and async handlers, body replacement, handler
// errors, and the status-specific / client / server fallback lookup order.
#[cfg(test)]
mod tests {
    use actix_service::IntoService;
    use actix_utils::future::ok;
    use bytes::Bytes;
    use futures_util::FutureExt as _;

    use super::*;
    use crate::{
        body,
        http::header::{HeaderValue, CONTENT_TYPE},
        test::{self, TestRequest},
    };

    // A synchronous handler registered for 500 runs, and its header change is visible in the
    // final response.
    #[actix_rt::test]
    async fn add_header_error_handler() {
        // The handler signature required by `ErrorHandlers` returns `Result`, even though this
        // handler never fails; the same applies to the other test handlers below.
        #[allow(clippy::unnecessary_wraps)]
        fn error_handler<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));

            Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
        }

        let srv = test::status_service(StatusCode::INTERNAL_SERVER_ERROR);

        let mw = ErrorHandlers::new()
            .handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
    }

    // Same as above, but the handler returns `ErrorHandlerResponse::Future`, which exercises
    // the `ErrorHandlerFuture` state of the middleware future.
    #[actix_rt::test]
    async fn add_header_error_handler_async() {
        #[allow(clippy::unnecessary_wraps)]
        fn error_handler<B: 'static>(
            mut res: ServiceResponse<B>,
        ) -> Result<ErrorHandlerResponse<B>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));

            Ok(ErrorHandlerResponse::Future(
                ok(res.map_into_left_body()).boxed_local(),
            ))
        }

        let srv = test::status_service(StatusCode::INTERNAL_SERVER_ERROR);

        let mw = ErrorHandlers::new()
            .handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let resp = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
    }

    // A handler can replace the body entirely by returning the right variant of `EitherBody`.
    #[actix_rt::test]
    async fn changes_body_type() {
        #[allow(clippy::unnecessary_wraps)]
        fn error_handler<B>(res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
            let (req, res) = res.into_parts();
            let res = res.set_body(Bytes::from("sorry, that's no bueno"));

            let res = ServiceResponse::new(req, res)
                .map_into_boxed_body()
                .map_into_right_body();

            Ok(ErrorHandlerResponse::Response(res))
        }

        let srv = test::status_service(StatusCode::INTERNAL_SERVER_ERROR);

        let mw = ErrorHandlers::new()
            .handler(StatusCode::INTERNAL_SERVER_ERROR, error_handler)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let res = test::call_service(&mw, TestRequest::default().to_srv_request()).await;
        assert_eq!(test::read_body(res).await, "sorry, that's no bueno");
    }

    // An `Err` returned by a handler is propagated out of the middleware as the call's error.
    #[actix_rt::test]
    async fn error_thrown() {
        #[allow(clippy::unnecessary_wraps)]
        fn error_handler<B>(_res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
            Err(crate::error::ErrorInternalServerError(
                "error in error handler",
            ))
        }

        let srv = test::status_service(StatusCode::BAD_REQUEST);

        let mw = ErrorHandlers::new()
            .handler(StatusCode::BAD_REQUEST, error_handler)
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let err = mw
            .call(TestRequest::default().to_srv_request())
            .await
            .unwrap_err();
        let res = err.error_response();

        assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(
            body::to_bytes(res.into_body()).await.unwrap(),
            "error in error handler"
        );
    }

    // `default_handler` installs one handler for both 4xx and 5xx responses.
    #[actix_rt::test]
    async fn default_error_handler() {
        #[allow(clippy::unnecessary_wraps)]
        fn error_handler<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));
            Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
        }

        let make_mw = |status| async move {
            ErrorHandlers::new()
                .default_handler(error_handler)
                .new_transform(test::status_service(status).into_service())
                .await
                .unwrap()
        };
        let mw_server = make_mw(StatusCode::INTERNAL_SERVER_ERROR).await;
        let mw_client = make_mw(StatusCode::BAD_REQUEST).await;

        let resp = test::call_service(&mw_client, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");

        let resp = test::call_service(&mw_server, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");
    }

    // Separate client and server defaults are dispatched by status class.
    #[actix_rt::test]
    async fn default_handlers_separate_client_server() {
        #[allow(clippy::unnecessary_wraps)]
        fn error_handler_client<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));
            Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
        }

        #[allow(clippy::unnecessary_wraps)]
        fn error_handler_server<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0002"));
            Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
        }

        let make_mw = |status| async move {
            ErrorHandlers::new()
                .default_handler_server(error_handler_server)
                .default_handler_client(error_handler_client)
                .new_transform(test::status_service(status).into_service())
                .await
                .unwrap()
        };
        let mw_server = make_mw(StatusCode::INTERNAL_SERVER_ERROR).await;
        let mw_client = make_mw(StatusCode::BAD_REQUEST).await;

        let resp = test::call_service(&mw_client, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");

        let resp = test::call_service(&mw_server, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0002");
    }

    // A status-specific handler takes precedence over the client default for that status,
    // while other 4xx statuses still use the default.
    #[actix_rt::test]
    async fn default_handlers_specialization() {
        #[allow(clippy::unnecessary_wraps)]
        fn error_handler_client<B>(mut res: ServiceResponse<B>) -> Result<ErrorHandlerResponse<B>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0001"));
            Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
        }

        #[allow(clippy::unnecessary_wraps)]
        fn error_handler_specific<B>(
            mut res: ServiceResponse<B>,
        ) -> Result<ErrorHandlerResponse<B>> {
            res.response_mut()
                .headers_mut()
                .insert(CONTENT_TYPE, HeaderValue::from_static("0003"));
            Ok(ErrorHandlerResponse::Response(res.map_into_left_body()))
        }

        let make_mw = |status| async move {
            ErrorHandlers::new()
                .default_handler_client(error_handler_client)
                .handler(StatusCode::UNPROCESSABLE_ENTITY, error_handler_specific)
                .new_transform(test::status_service(status).into_service())
                .await
                .unwrap()
        };
        let mw_client = make_mw(StatusCode::BAD_REQUEST).await;
        let mw_specific = make_mw(StatusCode::UNPROCESSABLE_ENTITY).await;

        let resp = test::call_service(&mw_client, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0001");

        let resp = test::call_service(&mw_specific, TestRequest::default().to_srv_request()).await;
        assert_eq!(resp.headers().get(CONTENT_TYPE).unwrap(), "0003");
    }
}