1use std::{
4    future::Future,
5    mem,
6    pin::Pin,
7    task::{Context, Poll},
8};
9
10use bytes::Bytes;
11use futures_core::ready;
12use pin_project_lite::pin_project;
13
14use crate::{
15    body::EitherBody,
16    dev,
17    web::{Form, Json},
18    Error, FromRequest, HttpRequest, HttpResponse, Responder,
19};
20
/// A value that holds exactly one of two types.
///
/// Within this module `Either` implements [`Responder`] when both `L` and `R`
/// do, and [`FromRequest`] by running the `L` extractor first and falling back
/// to the `R` extractor when `L` fails.
#[derive(Debug, PartialEq, Eq)]
pub enum Either<L, R> {
    /// Holds a value of the first type, `L`.
    Left(L),

    /// Holds a value of the second type, `R`.
    Right(R),
}
84
85impl<T> Either<Form<T>, Json<T>> {
86    pub fn into_inner(self) -> T {
87        match self {
88            Either::Left(form) => form.into_inner(),
89            Either::Right(form) => form.into_inner(),
90        }
91    }
92}
93
94impl<T> Either<Json<T>, Form<T>> {
95    pub fn into_inner(self) -> T {
96        match self {
97            Either::Left(form) => form.into_inner(),
98            Either::Right(form) => form.into_inner(),
99        }
100    }
101}
102
/// Test-only helpers for pulling a specific variant out of an [`Either`].
#[cfg(test)]
impl<L, R> Either<L, R> {
    /// Returns the `Left` value.
    ///
    /// # Panics
    ///
    /// Panics if `self` is `Either::Right`. The panic is attributed to the
    /// caller so a failing test points at the offending assertion.
    #[track_caller]
    fn unwrap_left(self) -> L {
        match self {
            Either::Left(data) => data,
            Either::Right(_) => {
                panic!("Cannot unwrap Left branch. Either contains an `R` type.")
            }
        }
    }

    /// Returns the `Right` value.
    ///
    /// # Panics
    ///
    /// Panics if `self` is `Either::Left`. The panic is attributed to the
    /// caller so a failing test points at the offending assertion.
    #[track_caller]
    fn unwrap_right(self) -> R {
        match self {
            Either::Left(_) => {
                panic!("Cannot unwrap Right branch. Either contains an `L` type.")
            }
            Either::Right(data) => data,
        }
    }
}
123
124impl<L, R> Responder for Either<L, R>
126where
127    L: Responder,
128    R: Responder,
129{
130    type Body = EitherBody<L::Body, R::Body>;
131
132    fn respond_to(self, req: &HttpRequest) -> HttpResponse<Self::Body> {
133        match self {
134            Either::Left(a) => a.respond_to(req).map_into_left_body(),
135            Either::Right(b) => b.respond_to(req).map_into_right_body(),
136        }
137    }
138}
139
/// Error produced when extracting an [`Either`] from a request fails.
#[derive(Debug)]
pub enum EitherExtractError<L, R> {
    /// The `Bytes` extractor failed, so neither the left nor the right
    /// extractor was attempted.
    Bytes(Error),

    /// Both extractors failed; holds the left extractor's error followed by
    /// the right extractor's error.
    Extract(L, R),
}
152
153impl<L, R> From<EitherExtractError<L, R>> for Error
154where
155    L: Into<Error>,
156    R: Into<Error>,
157{
158    fn from(err: EitherExtractError<L, R>) -> Error {
159        match err {
160            EitherExtractError::Bytes(err) => err,
161            EitherExtractError::Extract(a_err, _b_err) => a_err.into(),
162        }
163    }
164}
165
166impl<L, R> FromRequest for Either<L, R>
168where
169    L: FromRequest + 'static,
170    R: FromRequest + 'static,
171{
172    type Error = EitherExtractError<L::Error, R::Error>;
173    type Future = EitherExtractFut<L, R>;
174
175    fn from_request(req: &HttpRequest, payload: &mut dev::Payload) -> Self::Future {
176        EitherExtractFut {
177            req: req.clone(),
178            state: EitherExtractState::Bytes {
179                bytes: Bytes::from_request(req, payload),
180            },
181        }
182    }
183}
184
pin_project! {
    /// Future returned by the [`FromRequest`] implementation of [`Either`].
    pub struct EitherExtractFut<L, R>
    where
        R: FromRequest,
        L: FromRequest,
    {
        // Clone of the request, handed to each extractor when it is started.
        req: HttpRequest,
        // Current step of the extraction; replaced in place as `poll` advances.
        #[pin]
        state: EitherExtractState<L, R>,
    }
}
196
pin_project! {
    // Steps of `EitherExtractFut`, entered in order: `Bytes`, then `Left`,
    // then `Right` (only when the left extractor fails).
    #[project = EitherExtractProj]
    pub enum EitherExtractState<L, R>
    where
        L: FromRequest,
        R: FromRequest,
    {
        // Waiting for the `Bytes` extractor to yield the request payload.
        Bytes {
            #[pin]
            bytes: <Bytes as FromRequest>::Future,
        },
        // Running the left extractor on a payload built from the extracted bytes.
        Left {
            #[pin]
            left: L::Future,
            // Second handle to the extracted bytes, kept for the right extractor.
            fallback: Bytes,
        },
        // Running the right extractor after the left one returned an error.
        Right {
            #[pin]
            right: R::Future,
            // Error from the left extractor; taken out when the right one also fails.
            left_err: Option<L::Error>,
        },
    }
}
220
221impl<R, RF, RE, L, LF, LE> Future for EitherExtractFut<L, R>
222where
223    L: FromRequest<Future = LF, Error = LE>,
224    R: FromRequest<Future = RF, Error = RE>,
225    LF: Future<Output = Result<L, LE>> + 'static,
226    RF: Future<Output = Result<R, RE>> + 'static,
227    LE: Into<Error>,
228    RE: Into<Error>,
229{
230    type Output = Result<Either<L, R>, EitherExtractError<LE, RE>>;
231
232    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
233        let mut this = self.project();
234        let ready = loop {
235            let next = match this.state.as_mut().project() {
236                EitherExtractProj::Bytes { bytes } => {
237                    let res = ready!(bytes.poll(cx));
238                    match res {
239                        Ok(bytes) => {
240                            let fallback = bytes.clone();
241                            let left = L::from_request(this.req, &mut dev::Payload::from(bytes));
242                            EitherExtractState::Left { left, fallback }
243                        }
244                        Err(err) => break Err(EitherExtractError::Bytes(err)),
245                    }
246                }
247                EitherExtractProj::Left { left, fallback } => {
248                    let res = ready!(left.poll(cx));
249                    match res {
250                        Ok(extracted) => break Ok(Either::Left(extracted)),
251                        Err(left_err) => {
252                            let right = R::from_request(
253                                this.req,
254                                &mut dev::Payload::from(mem::take(fallback)),
255                            );
256                            EitherExtractState::Right {
257                                left_err: Some(left_err),
258                                right,
259                            }
260                        }
261                    }
262                }
263                EitherExtractProj::Right { right, left_err } => {
264                    let res = ready!(right.poll(cx));
265                    match res {
266                        Ok(data) => break Ok(Either::Right(data)),
267                        Err(err) => {
268                            break Err(EitherExtractError::Extract(left_err.take().unwrap(), err));
269                        }
270                    }
271                }
272            };
273            this.state.set(next);
274        };
275        Poll::Ready(ready)
276    }
277}
278
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::test::TestRequest;

    /// Extractor under test: `Form` is tried first, `Json` second.
    type FormOrJson = Either<Form<TestForm>, Json<TestForm>>;

    /// Nested extractor: `FormOrJson` first, raw `Bytes` as the last resort.
    type FormOrJsonOrBytes = Either<FormOrJson, Bytes>;

    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestForm {
        hello: String,
    }

    /// Fixture value used as the request body in the form/JSON tests.
    fn hello_world() -> TestForm {
        TestForm {
            hello: "world".to_owned(),
        }
    }

    // A form body is accepted by the left (`Form`) extractor directly.
    #[actix_rt::test]
    async fn test_either_extract_first_try() {
        let (req, mut pl) = TestRequest::default()
            .set_form(hello_world())
            .to_http_parts();

        let extracted = FormOrJson::from_request(&req, &mut pl).await.unwrap();
        let form = extracted.unwrap_left().into_inner();

        assert_eq!(form.hello, "world");
    }

    // A JSON body is rejected by `Form` and picked up by the right (`Json`) extractor.
    #[actix_rt::test]
    async fn test_either_extract_fallback() {
        let (req, mut pl) = TestRequest::default()
            .set_json(hello_world())
            .to_http_parts();

        let extracted = FormOrJson::from_request(&req, &mut pl).await.unwrap();
        let form = extracted.unwrap_right().into_inner();

        assert_eq!(form.hello, "world");
    }

    // A body that is neither form nor JSON falls all the way through to `Bytes`.
    #[actix_rt::test]
    async fn test_either_extract_recursive_fallback() {
        let (req, mut pl) = TestRequest::default()
            .set_payload(Bytes::from_static(b"!@$%^&*()"))
            .to_http_parts();

        let extracted = FormOrJsonOrBytes::from_request(&req, &mut pl)
            .await
            .unwrap();
        let payload = extracted.unwrap_right();

        assert_eq!(&payload.as_ref(), &b"!@$%^&*()");
    }

    // A JSON body is resolved by the inner `Either`'s own fallback, so the
    // outer `Either` ends up `Left` and the inner one `Right`.
    #[actix_rt::test]
    async fn test_either_extract_recursive_fallback_inner() {
        let (req, mut pl) = TestRequest::default()
            .set_json(hello_world())
            .to_http_parts();

        let extracted = FormOrJsonOrBytes::from_request(&req, &mut pl)
            .await
            .unwrap();
        let form = extracted.unwrap_left().unwrap_right().into_inner();

        assert_eq!(form.hello, "world");
    }
}