1use std::collections::HashSet;
2use std::convert::TryInto;
3
4use coarsetime::{Clock, Duration, UnixTimeStamp};
5use ct_codecs::{Base64UrlSafeNoPadding, Encoder};
6use rand::RngCore;
7use serde::{de::DeserializeOwned, Deserialize, Serialize};
8
9use crate::common::VerificationOptions;
10use crate::error::*;
11use crate::serde_additions;
12
/// Default clock-skew tolerance, in seconds (15 minutes).
///
/// NOTE(review): presumably used as the default `VerificationOptions::time_tolerance`;
/// `JWTClaims::validate` itself falls back to a zero tolerance when none is set.
pub const DEFAULT_TIME_TOLERANCE_SECS: u64 = 15 * 60;
14
15#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize)]
17pub struct NoCustomClaims {}
18
/// Value of the `aud` claim: either a set of audiences or a single audience string.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Audiences {
    /// A set of audience names (may be empty or hold a single element).
    AsSet(HashSet<String>),
    /// A single audience name.
    AsString(String),
}
26
27impl Audiences {
28 pub fn is_set(&self) -> bool {
30 matches!(self, Audiences::AsSet(_))
31 }
32
33 pub fn is_string(&self) -> bool {
35 matches!(self, Audiences::AsString(_))
36 }
37
38 pub fn contains(&self, allowed_audiences: &HashSet<String>) -> bool {
41 match self {
42 Audiences::AsString(audience) => allowed_audiences.contains(audience),
43 Audiences::AsSet(audiences) => {
44 audiences.intersection(allowed_audiences).next().is_some()
45 }
46 }
47 }
48
49 pub fn into_set(self) -> HashSet<String> {
51 match self {
52 Audiences::AsSet(audiences_set) => audiences_set,
53 Audiences::AsString(audiences) => {
54 let mut audiences_set = HashSet::new();
55 if !audiences.is_empty() {
56 audiences_set.insert(audiences);
57 }
58 audiences_set
59 }
60 }
61 }
62
63 pub fn into_string(self) -> Result<String, Error> {
67 match self {
68 Audiences::AsString(audiences_str) => Ok(audiences_str),
69 Audiences::AsSet(audiences) => {
70 if audiences.len() > 1 {
71 bail!(JWTError::TooManyAudiences);
72 }
73 Ok(audiences
74 .iter()
75 .next()
76 .map(|x| x.to_string())
77 .unwrap_or_default())
78 }
79 }
80 }
81}
82
83impl TryInto<String> for Audiences {
84 type Error = Error;
85
86 fn try_into(self) -> Result<String, Error> {
87 self.into_string()
88 }
89}
90
91impl From<Audiences> for HashSet<String> {
92 fn from(audiences: Audiences) -> HashSet<String> {
93 audiences.into_set()
94 }
95}
96
97impl<T: ToString> From<T> for Audiences {
98 fn from(audience: T) -> Self {
99 Audiences::AsString(audience.to_string())
100 }
101}
102
/// Standard JWT claims together with an application-defined `CustomClaims` payload.
///
/// Every standard claim is optional: `default` makes a missing claim deserialize
/// as `None`, and `skip_serializing_if = "Option::is_none"` omits unset claims
/// when serializing. Field order here is the order claims are serialized in.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JWTClaims<CustomClaims> {
    /// Issued-at time (`iat`), (de)serialized through `serde_additions::unix_timestamp`.
    /// `validate` rejects it if it is later than the current time plus the tolerance.
    #[serde(
        rename = "iat",
        default,
        skip_serializing_if = "Option::is_none",
        with = "self::serde_additions::unix_timestamp"
    )]
    pub issued_at: Option<UnixTimeStamp>,

    /// Expiration time (`exp`), (de)serialized through `serde_additions::unix_timestamp`.
    /// `validate` rejects the token once the current time passes it by more than the tolerance.
    #[serde(
        rename = "exp",
        default,
        skip_serializing_if = "Option::is_none",
        with = "self::serde_additions::unix_timestamp"
    )]
    pub expires_at: Option<UnixTimeStamp>,

    /// Not-before time (`nbf`), (de)serialized through `serde_additions::unix_timestamp`.
    /// Checked by `validate` unless `accept_future` is set in the options.
    #[serde(
        rename = "nbf",
        default,
        skip_serializing_if = "Option::is_none",
        with = "self::serde_additions::unix_timestamp"
    )]
    pub invalid_before: Option<UnixTimeStamp>,

    /// Issuer (`iss`); must be one of `allowed_issuers` when that option is set.
    #[serde(rename = "iss", default, skip_serializing_if = "Option::is_none")]
    pub issuer: Option<String>,

    /// Subject (`sub`); must equal `required_subject` when that option is set.
    #[serde(rename = "sub", default, skip_serializing_if = "Option::is_none")]
    pub subject: Option<String>,

    /// Audience(s) (`aud`), (de)serialized through `serde_additions::audiences`
    /// (presumably accepting either a single string or an array — confirm there).
    /// At least one must be in `allowed_audiences` when that option is set.
    #[serde(
        rename = "aud",
        default,
        skip_serializing_if = "Option::is_none",
        with = "self::serde_additions::audiences"
    )]
    pub audiences: Option<Audiences>,

    /// Token identifier (`jti`). Not checked by `validate`.
    #[serde(rename = "jti", default, skip_serializing_if = "Option::is_none")]
    pub jwt_id: Option<String>,

    /// Nonce (`nonce`); must equal `required_nonce` when that option is set.
    #[serde(rename = "nonce", default, skip_serializing_if = "Option::is_none")]
    pub nonce: Option<String>,

    /// Application-specific claims, flattened into the same top-level JSON object.
    #[serde(flatten)]
    pub custom: CustomClaims,
}
173
174impl<CustomClaims> JWTClaims<CustomClaims> {
175 pub(crate) fn validate(&self, options: &VerificationOptions) -> Result<(), Error> {
176 let now = options
177 .artificial_time
178 .unwrap_or_else(Clock::now_since_epoch);
179 let time_tolerance = options.time_tolerance.unwrap_or_default();
180
181 if let Some(reject_before) = options.reject_before {
182 ensure!(now <= reject_before, JWTError::OldTokenReused);
183 }
184 if let Some(time_issued) = self.issued_at {
185 ensure!(time_issued <= now + time_tolerance, JWTError::ClockDrift);
186 if let Some(max_validity) = options.max_validity {
187 ensure!(
188 now <= time_issued || now - time_issued <= max_validity,
189 JWTError::TokenIsTooOld
190 );
191 }
192 }
193 if !options.accept_future {
194 if let Some(invalid_before) = self.invalid_before {
195 ensure!(
196 now + time_tolerance >= invalid_before,
197 JWTError::TokenNotValidYet
198 );
199 }
200 }
201 if let Some(expires_at) = self.expires_at {
202 ensure!(
203 now - time_tolerance <= expires_at,
204 JWTError::TokenHasExpired
205 );
206 }
207 if let Some(allowed_issuers) = &options.allowed_issuers {
208 if let Some(issuer) = &self.issuer {
209 ensure!(
210 allowed_issuers.contains(issuer),
211 JWTError::RequiredIssuerMismatch
212 );
213 } else {
214 bail!(JWTError::RequiredIssuerMissing);
215 }
216 }
217 if let Some(required_subject) = &options.required_subject {
218 if let Some(subject) = &self.subject {
219 ensure!(
220 subject == required_subject,
221 JWTError::RequiredSubjectMismatch
222 );
223 } else {
224 bail!(JWTError::RequiredSubjectMissing);
225 }
226 }
227 if let Some(required_nonce) = &options.required_nonce {
228 if let Some(nonce) = &self.nonce {
229 ensure!(nonce == required_nonce, JWTError::RequiredNonceMismatch);
230 } else {
231 bail!(JWTError::RequiredNonceMissing);
232 }
233 }
234 if let Some(allowed_audiences) = &options.allowed_audiences {
235 if let Some(audiences) = &self.audiences {
236 ensure!(
237 audiences.contains(allowed_audiences),
238 JWTError::RequiredAudienceMismatch
239 );
240 } else {
241 bail!(JWTError::RequiredAudienceMissing);
242 }
243 }
244 Ok(())
245 }
246
247 pub fn invalid_before(mut self, unix_timestamp: UnixTimeStamp) -> Self {
249 self.invalid_before = Some(unix_timestamp);
250 self
251 }
252
253 pub fn with_issuer(mut self, issuer: impl ToString) -> Self {
255 self.issuer = Some(issuer.to_string());
256 self
257 }
258
259 pub fn with_subject(mut self, subject: impl ToString) -> Self {
261 self.subject = Some(subject.to_string());
262 self
263 }
264
265 pub fn with_audiences(mut self, audiences: HashSet<impl ToString>) -> Self {
268 self.audiences = Some(Audiences::AsSet(
269 audiences.iter().map(|x| x.to_string()).collect(),
270 ));
271 self
272 }
273
274 pub fn with_audience(mut self, audience: impl ToString) -> Self {
276 self.audiences = Some(Audiences::AsString(audience.to_string()));
277 self
278 }
279
280 pub fn with_jwt_id(mut self, jwt_id: impl ToString) -> Self {
282 self.jwt_id = Some(jwt_id.to_string());
283 self
284 }
285
286 pub fn with_nonce(mut self, nonce: impl ToString) -> Self {
288 self.nonce = Some(nonce.to_string());
289 self
290 }
291
292 pub fn create_nonce(&mut self) -> String {
294 let mut raw_nonce = [0u8; 24];
295 let mut rng = rand::thread_rng();
296 rng.fill_bytes(&mut raw_nonce);
297 let nonce = Base64UrlSafeNoPadding::encode_to_string(raw_nonce).unwrap();
298 self.nonce = Some(nonce);
299 self.nonce.as_deref().unwrap().to_string()
300 }
301}
302
303pub struct Claims;
304
305impl Claims {
306 pub fn create(valid_for: Duration) -> JWTClaims<NoCustomClaims> {
309 let now = Some(Clock::now_since_epoch());
310 JWTClaims {
311 issued_at: now,
312 expires_at: Some(now.unwrap() + valid_for),
313 invalid_before: now,
314 audiences: None,
315 issuer: None,
316 jwt_id: None,
317 subject: None,
318 nonce: None,
319 custom: NoCustomClaims {},
320 }
321 }
322
323 pub fn with_custom_claims<CustomClaims: Serialize + DeserializeOwned>(
325 custom_claims: CustomClaims,
326 valid_for: Duration,
327 ) -> JWTClaims<CustomClaims> {
328 let now = Some(Clock::now_since_epoch());
329 JWTClaims {
330 issued_at: now,
331 expires_at: Some(now.unwrap() + valid_for),
332 invalid_before: now,
333 audiences: None,
334 issuer: None,
335 jwt_id: None,
336 subject: None,
337 nonce: None,
338 custom: custom_claims,
339 }
340 }
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn should_set_standard_claims() {
        let expected: HashSet<String> = ["audience1", "audience2"]
            .iter()
            .map(|name| name.to_string())
            .collect();

        let claims = Claims::create(Duration::from_mins(10))
            .with_audiences(expected.clone())
            .with_issuer("issuer")
            .with_jwt_id("jwt_id")
            .with_nonce("nonce")
            .with_subject("subject");

        assert_eq!(claims.audiences, Some(Audiences::AsSet(expected)));
        assert_eq!(claims.issuer.as_deref(), Some("issuer"));
        assert_eq!(claims.jwt_id.as_deref(), Some("jwt_id"));
        assert_eq!(claims.nonce.as_deref(), Some("nonce"));
        assert_eq!(claims.subject.as_deref(), Some("subject"));
    }

    #[test]
    fn parse_floating_point_unix_time() {
        let json = r#"{"exp":1617757825.8}"#;
        let claims: JWTClaims<()> = serde_json::from_str(json).unwrap();
        let expected = UnixTimeStamp::from_secs(1617757825);
        assert_eq!(claims.expires_at, Some(expected));
    }

    #[test]
    fn should_tolerate_clock_drift() {
        // Token valid for one minute starting at `issued_at`.
        let claims = Claims::create(Duration::from_mins(1));
        let issued_at = claims.issued_at.unwrap();
        let mut options = VerificationOptions::default();

        // (verifier clock ahead of issuer?, drift, tolerance, should validate?)
        let cases = [
            (true, Duration::from_mins(2), Duration::from_mins(1), true),
            (true, Duration::from_mins(3), Duration::from_mins(1), false),
            (true, Duration::from_secs(30), Duration::from_mins(1), true),
            (false, Duration::from_mins(2), Duration::from_mins(1), false),
            (false, Duration::from_mins(2), Duration::from_mins(2), true),
        ];

        for (index, &(ahead, drift, tolerance, should_pass)) in cases.iter().enumerate() {
            let verifier_time = if ahead {
                issued_at + drift
            } else {
                issued_at - drift
            };
            options.artificial_time = Some(verifier_time);
            options.time_tolerance = Some(tolerance);

            let outcome = claims.validate(&options);
            assert_eq!(
                outcome.is_ok(),
                should_pass,
                "clock drift case #{}: {:?}",
                index,
                outcome
            );
        }
    }
}