jwt_simple/
claims.rs

1use std::collections::HashSet;
2use std::convert::TryInto;
3
4use coarsetime::{Clock, Duration, UnixTimeStamp};
5use ct_codecs::{Base64UrlSafeNoPadding, Encoder};
6use rand::RngCore;
7use serde::{de::DeserializeOwned, Deserialize, Serialize};
8
9use crate::common::VerificationOptions;
10use crate::error::*;
11use crate::serde_additions;
12
/// Default time tolerance, in seconds (15 minutes).
// NOTE(review): not referenced in this file; presumably the default applied to
// `VerificationOptions::time_tolerance` — confirm in `common.rs`.
pub const DEFAULT_TIME_TOLERANCE_SECS: u64 = 15 * 60;
14
15/// Type representing the fact that no application-defined claims is necessary.
16#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize)]
17pub struct NoCustomClaims {}
18
/// Value of the `aud` claim.
///
/// Depending on the application, the audience may be encoded either as a
/// single string or as a set of strings; both representations are supported.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Audiences {
    /// Audiences encoded as a set of strings.
    AsSet(HashSet<String>),
    /// A single audience encoded as a plain string.
    AsString(String),
}
26
27impl Audiences {
28    /// Return `true` if the audiences are represented as a set.
29    pub fn is_set(&self) -> bool {
30        matches!(self, Audiences::AsSet(_))
31    }
32
33    /// Return `true` if the audiences are represented as a string.
34    pub fn is_string(&self) -> bool {
35        matches!(self, Audiences::AsString(_))
36    }
37
38    /// Return `true` if the audiences include any of the `allowed_audiences`
39    /// entries
40    pub fn contains(&self, allowed_audiences: &HashSet<String>) -> bool {
41        match self {
42            Audiences::AsString(audience) => allowed_audiences.contains(audience),
43            Audiences::AsSet(audiences) => {
44                audiences.intersection(allowed_audiences).next().is_some()
45            }
46        }
47    }
48
49    /// Get the audiences as a set
50    pub fn into_set(self) -> HashSet<String> {
51        match self {
52            Audiences::AsSet(audiences_set) => audiences_set,
53            Audiences::AsString(audiences) => {
54                let mut audiences_set = HashSet::new();
55                if !audiences.is_empty() {
56                    audiences_set.insert(audiences);
57                }
58                audiences_set
59            }
60        }
61    }
62
63    /// Get the audiences as a string.
64    /// If it was originally serialized as a set, it can be only converted to a
65    /// string if it contains at most one element.
66    pub fn into_string(self) -> Result<String, Error> {
67        match self {
68            Audiences::AsString(audiences_str) => Ok(audiences_str),
69            Audiences::AsSet(audiences) => {
70                if audiences.len() > 1 {
71                    bail!(JWTError::TooManyAudiences);
72                }
73                Ok(audiences
74                    .iter()
75                    .next()
76                    .map(|x| x.to_string())
77                    .unwrap_or_default())
78            }
79        }
80    }
81}
82
83impl TryInto<String> for Audiences {
84    type Error = Error;
85
86    fn try_into(self) -> Result<String, Error> {
87        self.into_string()
88    }
89}
90
91impl From<Audiences> for HashSet<String> {
92    fn from(audiences: Audiences) -> HashSet<String> {
93        audiences.into_set()
94    }
95}
96
97impl<T: ToString> From<T> for Audiences {
98    fn from(audience: T) -> Self {
99        Audiences::AsString(audience.to_string())
100    }
101}
102
103/// A set of JWT claims.
104///
105/// The `CustomClaims` parameter can be set to `NoCustomClaims` if only standard
106/// claims are used, or to a user-defined type that must be `serde`-serializable
107/// if custom claims are required.
108#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct JWTClaims<CustomClaims> {
110    /// Time the claims were created at
111    #[serde(
112        rename = "iat",
113        default,
114        skip_serializing_if = "Option::is_none",
115        with = "self::serde_additions::unix_timestamp"
116    )]
117    pub issued_at: Option<UnixTimeStamp>,
118
119    /// Time the claims expire at
120    #[serde(
121        rename = "exp",
122        default,
123        skip_serializing_if = "Option::is_none",
124        with = "self::serde_additions::unix_timestamp"
125    )]
126    pub expires_at: Option<UnixTimeStamp>,
127
128    /// Time the claims will be invalid until
129    #[serde(
130        rename = "nbf",
131        default,
132        skip_serializing_if = "Option::is_none",
133        with = "self::serde_additions::unix_timestamp"
134    )]
135    pub invalid_before: Option<UnixTimeStamp>,
136
137    /// Issuer - This can be set to anything application-specific
138    #[serde(rename = "iss", default, skip_serializing_if = "Option::is_none")]
139    pub issuer: Option<String>,
140
141    /// Subject - This can be set to anything application-specific
142    #[serde(rename = "sub", default, skip_serializing_if = "Option::is_none")]
143    pub subject: Option<String>,
144
145    /// Audience
146    #[serde(
147        rename = "aud",
148        default,
149        skip_serializing_if = "Option::is_none",
150        with = "self::serde_additions::audiences"
151    )]
152    pub audiences: Option<Audiences>,
153
154    /// JWT identifier
155    ///
156    /// That property was originally designed to avoid replay attacks, but
157    /// keeping all previously sent JWT token IDs is unrealistic.
158    ///
159    /// Replay attacks are better addressed by keeping only the timestamp of the
160    /// last valid token for a user, and rejecting anything older in future
161    /// tokens.
162    #[serde(rename = "jti", default, skip_serializing_if = "Option::is_none")]
163    pub jwt_id: Option<String>,
164
165    /// Nonce
166    #[serde(rename = "nonce", default, skip_serializing_if = "Option::is_none")]
167    pub nonce: Option<String>,
168
169    /// Custom (application-defined) claims
170    #[serde(flatten)]
171    pub custom: CustomClaims,
172}
173
174impl<CustomClaims> JWTClaims<CustomClaims> {
175    pub(crate) fn validate(&self, options: &VerificationOptions) -> Result<(), Error> {
176        let now = options
177            .artificial_time
178            .unwrap_or_else(Clock::now_since_epoch);
179        let time_tolerance = options.time_tolerance.unwrap_or_default();
180
181        if let Some(reject_before) = options.reject_before {
182            ensure!(now <= reject_before, JWTError::OldTokenReused);
183        }
184        if let Some(time_issued) = self.issued_at {
185            ensure!(time_issued <= now + time_tolerance, JWTError::ClockDrift);
186            if let Some(max_validity) = options.max_validity {
187                ensure!(
188                    now <= time_issued || now - time_issued <= max_validity,
189                    JWTError::TokenIsTooOld
190                );
191            }
192        }
193        if !options.accept_future {
194            if let Some(invalid_before) = self.invalid_before {
195                ensure!(
196                    now + time_tolerance >= invalid_before,
197                    JWTError::TokenNotValidYet
198                );
199            }
200        }
201        if let Some(expires_at) = self.expires_at {
202            ensure!(
203                now - time_tolerance <= expires_at,
204                JWTError::TokenHasExpired
205            );
206        }
207        if let Some(allowed_issuers) = &options.allowed_issuers {
208            if let Some(issuer) = &self.issuer {
209                ensure!(
210                    allowed_issuers.contains(issuer),
211                    JWTError::RequiredIssuerMismatch
212                );
213            } else {
214                bail!(JWTError::RequiredIssuerMissing);
215            }
216        }
217        if let Some(required_subject) = &options.required_subject {
218            if let Some(subject) = &self.subject {
219                ensure!(
220                    subject == required_subject,
221                    JWTError::RequiredSubjectMismatch
222                );
223            } else {
224                bail!(JWTError::RequiredSubjectMissing);
225            }
226        }
227        if let Some(required_nonce) = &options.required_nonce {
228            if let Some(nonce) = &self.nonce {
229                ensure!(nonce == required_nonce, JWTError::RequiredNonceMismatch);
230            } else {
231                bail!(JWTError::RequiredNonceMissing);
232            }
233        }
234        if let Some(allowed_audiences) = &options.allowed_audiences {
235            if let Some(audiences) = &self.audiences {
236                ensure!(
237                    audiences.contains(allowed_audiences),
238                    JWTError::RequiredAudienceMismatch
239                );
240            } else {
241                bail!(JWTError::RequiredAudienceMissing);
242            }
243        }
244        Ok(())
245    }
246
247    /// Set the token as not being valid until `unix_timestamp`
248    pub fn invalid_before(mut self, unix_timestamp: UnixTimeStamp) -> Self {
249        self.invalid_before = Some(unix_timestamp);
250        self
251    }
252
253    /// Set the issuer
254    pub fn with_issuer(mut self, issuer: impl ToString) -> Self {
255        self.issuer = Some(issuer.to_string());
256        self
257    }
258
259    /// Set the subject
260    pub fn with_subject(mut self, subject: impl ToString) -> Self {
261        self.subject = Some(subject.to_string());
262        self
263    }
264
265    /// Register one or more audiences (optional recipient identifiers), as a
266    /// set
267    pub fn with_audiences(mut self, audiences: HashSet<impl ToString>) -> Self {
268        self.audiences = Some(Audiences::AsSet(
269            audiences.iter().map(|x| x.to_string()).collect(),
270        ));
271        self
272    }
273
274    /// Set a unique audience (an optional recipient identifier), as a string
275    pub fn with_audience(mut self, audience: impl ToString) -> Self {
276        self.audiences = Some(Audiences::AsString(audience.to_string()));
277        self
278    }
279
280    /// Set the JWT identifier
281    pub fn with_jwt_id(mut self, jwt_id: impl ToString) -> Self {
282        self.jwt_id = Some(jwt_id.to_string());
283        self
284    }
285
286    /// Set the nonce
287    pub fn with_nonce(mut self, nonce: impl ToString) -> Self {
288        self.nonce = Some(nonce.to_string());
289        self
290    }
291
292    /// Create a nonce, attach it and return it
293    pub fn create_nonce(&mut self) -> String {
294        let mut raw_nonce = [0u8; 24];
295        let mut rng = rand::thread_rng();
296        rng.fill_bytes(&mut raw_nonce);
297        let nonce = Base64UrlSafeNoPadding::encode_to_string(raw_nonce).unwrap();
298        self.nonce = Some(nonce);
299        self.nonce.as_deref().unwrap().to_string()
300    }
301}
302
303pub struct Claims;
304
305impl Claims {
306    /// Create a new set of claims, without custom data, expiring in
307    /// `valid_for`.
308    pub fn create(valid_for: Duration) -> JWTClaims<NoCustomClaims> {
309        let now = Some(Clock::now_since_epoch());
310        JWTClaims {
311            issued_at: now,
312            expires_at: Some(now.unwrap() + valid_for),
313            invalid_before: now,
314            audiences: None,
315            issuer: None,
316            jwt_id: None,
317            subject: None,
318            nonce: None,
319            custom: NoCustomClaims {},
320        }
321    }
322
323    /// Create a new set of claims, with custom data, expiring in `valid_for`.
324    pub fn with_custom_claims<CustomClaims: Serialize + DeserializeOwned>(
325        custom_claims: CustomClaims,
326        valid_for: Duration,
327    ) -> JWTClaims<CustomClaims> {
328        let now = Some(Clock::now_since_epoch());
329        JWTClaims {
330            issued_at: now,
331            expires_at: Some(now.unwrap() + valid_for),
332            invalid_before: now,
333            audiences: None,
334            issuer: None,
335            jwt_id: None,
336            subject: None,
337            nonce: None,
338            custom: custom_claims,
339        }
340    }
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn should_set_standard_claims() {
        let exp = Duration::from_mins(10);
        let mut audiences = HashSet::new();
        audiences.insert("audience1".to_string());
        audiences.insert("audience2".to_string());
        let claims = Claims::create(exp)
            .with_audiences(audiences.clone())
            .with_issuer("issuer")
            .with_jwt_id("jwt_id")
            .with_nonce("nonce")
            .with_subject("subject");

        assert_eq!(claims.audiences, Some(Audiences::AsSet(audiences)));
        assert_eq!(claims.issuer, Some("issuer".to_owned()));
        assert_eq!(claims.jwt_id, Some("jwt_id".to_owned()));
        assert_eq!(claims.nonce, Some("nonce".to_owned()));
        assert_eq!(claims.subject, Some("subject".to_owned()));
    }

    #[test]
    fn should_convert_audiences() {
        let mut allowed = HashSet::new();
        allowed.insert("audience1".to_string());

        let single = Audiences::AsString("audience1".to_string());
        assert!(single.is_string());
        assert!(!single.is_set());
        assert!(single.contains(&allowed));
        assert_eq!(single.clone().into_set(), allowed);
        assert_eq!(single.into_string().unwrap(), "audience1");

        // An empty string means "no audience".
        assert!(Audiences::AsString(String::new()).into_set().is_empty());

        let mut many = HashSet::new();
        many.insert("a".to_string());
        many.insert("b".to_string());
        let set = Audiences::AsSet(many);
        assert!(set.is_set());
        assert!(!set.contains(&allowed));
        assert!(set.into_string().is_err());
    }

    #[test]
    fn parse_floating_point_unix_time() {
        let claims: JWTClaims<()> = serde_json::from_str(r#"{"exp":1617757825.8}"#).unwrap();
        assert_eq!(
            claims.expires_at,
            Some(UnixTimeStamp::from_secs(1617757825))
        );
    }

    #[test]
    fn should_tolerate_clock_drift() {
        let exp = Duration::from_mins(1);
        let claims = Claims::create(exp);
        let mut options = VerificationOptions::default();

        // Verifier clock is 2 minutes ahead of the token clock.
        // The token is valid for 1 minute, with an extra tolerance of 1 minute:
        // it expired exactly `tolerance` ago, which is still accepted.
        // Verification should pass.
        let drift = Duration::from_mins(2);
        options.artificial_time = Some(claims.issued_at.unwrap() + drift);
        options.time_tolerance = Some(Duration::from_mins(1));
        claims.validate(&options).unwrap();

        // Verifier clock is 3 minutes ahead of the token clock.
        // The token is valid for 1 minute, with an extra tolerance of 1 minute,
        // so it expired 1 minute beyond the tolerance.
        // Verification must not pass.
        let drift = Duration::from_mins(3);
        options.artificial_time = Some(claims.issued_at.unwrap() + drift);
        options.time_tolerance = Some(Duration::from_mins(1));
        assert!(claims.validate(&options).is_err());

        // Verifier clock is 30 seconds ahead of the token clock.
        // The token is valid for 1 minute, so it has not expired yet.
        // Verification should pass.
        let drift = Duration::from_secs(30);
        options.artificial_time = Some(claims.issued_at.unwrap() + drift);
        options.time_tolerance = Some(Duration::from_mins(1));
        claims.validate(&options).unwrap();

        // Verifier clock is 2 minutes behind the token clock, so the token
        // appears to have been issued 2 minutes in the future.
        // We have a tolerance of only 1 minute.
        // Verification must not pass.
        let drift = Duration::from_mins(2);
        options.artificial_time = Some(claims.issued_at.unwrap() - drift);
        options.time_tolerance = Some(Duration::from_mins(1));
        assert!(claims.validate(&options).is_err());

        // Verifier clock is 2 minutes behind the token clock, so the token
        // appears to have been issued 2 minutes in the future.
        // We have a tolerance of 2 minutes, which covers that drift.
        // Verification should pass.
        let drift = Duration::from_mins(2);
        options.artificial_time = Some(claims.issued_at.unwrap() - drift);
        options.time_tolerance = Some(Duration::from_mins(2));
        claims.validate(&options).unwrap();
    }
}