rsa/
pss.rs

1//! Support for the [Probabilistic Signature Scheme] (PSS) a.k.a. RSASSA-PSS.
2//!
3//! Designed by Mihir Bellare and Phillip Rogaway. Specified in [RFC8017 § 8.1].
4//!
5//! # Usage
6//!
7//! See [code example in the toplevel rustdoc](../index.html#pss-signatures).
8//!
9//! [Probabilistic Signature Scheme]: https://en.wikipedia.org/wiki/Probabilistic_signature_scheme
10//! [RFC8017 § 8.1]: https://datatracker.ietf.org/doc/html/rfc8017#section-8.1
11
12use alloc::vec::Vec;
13
14use core::fmt::{Debug, Display, Formatter, LowerHex, UpperHex};
15use core::marker::PhantomData;
16use core::ops::Deref;
17use digest::{Digest, DynDigest, FixedOutputReset};
18use pkcs8::{Document, EncodePrivateKey, EncodePublicKey, SecretDocument};
19use rand_core::{CryptoRng, RngCore};
20#[cfg(feature = "hazmat")]
21use signature::hazmat::{PrehashVerifier, RandomizedPrehashSigner};
22use signature::{
23    DigestVerifier, RandomizedDigestSigner, RandomizedSigner, Signature as SignSignature, Verifier,
24};
25use subtle::ConstantTimeEq;
26
27use crate::algorithms::{mgf1_xor, mgf1_xor_digest};
28use crate::errors::{Error, Result};
29use crate::key::{PrivateKey, PublicKey};
30use crate::{RsaPrivateKey, RsaPublicKey};
31
32/// RSASSA-PSS signatures as described in [RFC8017 § 8.1].
33///
34/// [RFC8017 § 8.1]: https://datatracker.ietf.org/doc/html/rfc8017#section-8.1
35#[derive(Clone)]
36pub struct Signature {
37    bytes: Vec<u8>,
38}
39
40impl signature::Signature for Signature {
41    fn from_bytes(bytes: &[u8]) -> signature::Result<Self> {
42        Ok(Signature {
43            bytes: bytes.into(),
44        })
45    }
46
47    fn as_bytes(&self) -> &[u8] {
48        self.bytes.as_slice()
49    }
50}
51
52impl From<Vec<u8>> for Signature {
53    fn from(bytes: Vec<u8>) -> Self {
54        Self { bytes }
55    }
56}
57
58impl Deref for Signature {
59    type Target = [u8];
60
61    fn deref(&self) -> &Self::Target {
62        self.as_bytes()
63    }
64}
65
66impl PartialEq for Signature {
67    fn eq(&self, other: &Self) -> bool {
68        self.as_bytes() == other.as_bytes()
69    }
70}
71
72impl Eq for Signature {}
73
74impl Debug for Signature {
75    fn fmt(&self, fmt: &mut Formatter<'_>) -> core::result::Result<(), core::fmt::Error> {
76        fmt.debug_list().entries(self.as_bytes().iter()).finish()
77    }
78}
79
80impl AsRef<[u8]> for Signature {
81    fn as_ref(&self) -> &[u8] {
82        self.as_bytes()
83    }
84}
85
86impl LowerHex for Signature {
87    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
88        for byte in self.as_bytes() {
89            write!(f, "{:02x}", byte)?;
90        }
91        Ok(())
92    }
93}
94
95impl UpperHex for Signature {
96    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
97        for byte in self.as_bytes() {
98            write!(f, "{:02X}", byte)?;
99        }
100        Ok(())
101    }
102}
103
104impl Display for Signature {
105    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
106        write!(f, "{:X}", self)
107    }
108}
109
110pub(crate) fn verify<PK: PublicKey>(
111    pub_key: &PK,
112    hashed: &[u8],
113    sig: &[u8],
114    digest: &mut dyn DynDigest,
115) -> Result<()> {
116    if sig.len() != pub_key.size() {
117        return Err(Error::Verification);
118    }
119
120    let em_bits = pub_key.n().bits() - 1;
121    let em_len = (em_bits + 7) / 8;
122    let mut em = pub_key.raw_encryption_primitive(sig, em_len)?;
123
124    emsa_pss_verify(hashed, &mut em, em_bits, None, digest)
125}
126
127pub(crate) fn verify_digest<PK, D>(pub_key: &PK, hashed: &[u8], sig: &[u8]) -> Result<()>
128where
129    PK: PublicKey,
130    D: Digest + FixedOutputReset,
131{
132    if sig.len() != pub_key.size() {
133        return Err(Error::Verification);
134    }
135
136    let em_bits = pub_key.n().bits() - 1;
137    let em_len = (em_bits + 7) / 8;
138    let mut em = pub_key.raw_encryption_primitive(sig, em_len)?;
139
140    emsa_pss_verify_digest::<D>(hashed, &mut em, em_bits, None)
141}
142
143/// SignPSS calculates the signature of hashed using RSASSA-PSS.
144///
145/// Note that hashed must be the result of hashing the input message using the
146/// given hash function. The opts argument may be nil, in which case sensible
147/// defaults are used.
148// TODO: bind T with the CryptoRng trait
149pub(crate) fn sign<T: RngCore + CryptoRng, SK: PrivateKey>(
150    rng: &mut T,
151    blind: bool,
152    priv_key: &SK,
153    hashed: &[u8],
154    salt_len: Option<usize>,
155    digest: &mut dyn DynDigest,
156) -> Result<Vec<u8>> {
157    let salt = generate_salt(rng, priv_key, salt_len, digest.output_size());
158
159    sign_pss_with_salt(blind.then(|| rng), priv_key, hashed, &salt, digest)
160}
161
162pub(crate) fn sign_digest<T: RngCore + CryptoRng, SK: PrivateKey, D: Digest + FixedOutputReset>(
163    rng: &mut T,
164    blind: bool,
165    priv_key: &SK,
166    hashed: &[u8],
167    salt_len: Option<usize>,
168) -> Result<Vec<u8>> {
169    let salt = generate_salt(rng, priv_key, salt_len, <D as Digest>::output_size());
170
171    sign_pss_with_salt_digest::<_, _, D>(blind.then(|| rng), priv_key, hashed, &salt)
172}
173
174fn generate_salt<T: RngCore + ?Sized, SK: PrivateKey>(
175    rng: &mut T,
176    priv_key: &SK,
177    salt_len: Option<usize>,
178    digest_size: usize,
179) -> Vec<u8> {
180    let salt_len = salt_len.unwrap_or_else(|| priv_key.size() - 2 - digest_size);
181
182    let mut salt = vec![0; salt_len];
183    rng.fill_bytes(&mut salt[..]);
184
185    salt
186}
187
188/// signPSSWithSalt calculates the signature of hashed using PSS with specified salt.
189///
190/// Note that hashed must be the result of hashing the input message using the
191/// given hash function. salt is a random sequence of bytes whose length will be
192/// later used to verify the signature.
193fn sign_pss_with_salt<T: CryptoRng + RngCore, SK: PrivateKey>(
194    blind_rng: Option<&mut T>,
195    priv_key: &SK,
196    hashed: &[u8],
197    salt: &[u8],
198    digest: &mut dyn DynDigest,
199) -> Result<Vec<u8>> {
200    let em_bits = priv_key.n().bits() - 1;
201    let em = emsa_pss_encode(hashed, em_bits, salt, digest)?;
202
203    priv_key.raw_decryption_primitive(blind_rng, &em, priv_key.size())
204}
205
206fn sign_pss_with_salt_digest<
207    T: CryptoRng + RngCore,
208    SK: PrivateKey,
209    D: Digest + FixedOutputReset,
210>(
211    blind_rng: Option<&mut T>,
212    priv_key: &SK,
213    hashed: &[u8],
214    salt: &[u8],
215) -> Result<Vec<u8>> {
216    let em_bits = priv_key.n().bits() - 1;
217    let em = emsa_pss_encode_digest::<D>(hashed, em_bits, salt)?;
218
219    priv_key.raw_decryption_primitive(blind_rng, &em, priv_key.size())
220}
221
222fn emsa_pss_encode(
223    m_hash: &[u8],
224    em_bits: usize,
225    salt: &[u8],
226    hash: &mut dyn DynDigest,
227) -> Result<Vec<u8>> {
228    // See [1], section 9.1.1
229    let h_len = hash.output_size();
230    let s_len = salt.len();
231    let em_len = (em_bits + 7) / 8;
232
233    // 1. If the length of M is greater than the input limitation for the
234    //     hash function (2^61 - 1 octets for SHA-1), output "message too
235    //     long" and stop.
236    //
237    // 2.  Let mHash = Hash(M), an octet string of length hLen.
238    if m_hash.len() != h_len {
239        return Err(Error::InputNotHashed);
240    }
241
242    // 3. If em_len < h_len + s_len + 2, output "encoding error" and stop.
243    if em_len < h_len + s_len + 2 {
244        // TODO: Key size too small
245        return Err(Error::Internal);
246    }
247
248    let mut em = vec![0; em_len];
249
250    let (db, h) = em.split_at_mut(em_len - h_len - 1);
251    let h = &mut h[..(em_len - 1) - db.len()];
252
253    // 4. Generate a random octet string salt of length s_len; if s_len = 0,
254    //     then salt is the empty string.
255    //
256    // 5.  Let
257    //       M' = (0x)00 00 00 00 00 00 00 00 || m_hash || salt;
258    //
259    //     M' is an octet string of length 8 + h_len + s_len with eight
260    //     initial zero octets.
261    //
262    // 6.  Let H = Hash(M'), an octet string of length h_len.
263    let prefix = [0u8; 8];
264
265    hash.update(&prefix);
266    hash.update(m_hash);
267    hash.update(salt);
268
269    let hashed = hash.finalize_reset();
270    h.copy_from_slice(&hashed);
271
272    // 7.  Generate an octet string PS consisting of em_len - s_len - h_len - 2
273    //     zero octets. The length of PS may be 0.
274    //
275    // 8.  Let DB = PS || 0x01 || salt; DB is an octet string of length
276    //     emLen - hLen - 1.
277    db[em_len - s_len - h_len - 2] = 0x01;
278    db[em_len - s_len - h_len - 1..].copy_from_slice(salt);
279
280    // 9.  Let dbMask = MGF(H, emLen - hLen - 1).
281    //
282    // 10. Let maskedDB = DB \xor dbMask.
283    mgf1_xor(db, hash, h);
284
285    // 11. Set the leftmost 8 * em_len - em_bits bits of the leftmost octet in
286    //     maskedDB to zero.
287    db[0] &= 0xFF >> (8 * em_len - em_bits);
288
289    // 12. Let EM = maskedDB || H || 0xbc.
290    em[em_len - 1] = 0xBC;
291
292    Ok(em)
293}
294
295fn emsa_pss_encode_digest<D>(m_hash: &[u8], em_bits: usize, salt: &[u8]) -> Result<Vec<u8>>
296where
297    D: Digest + FixedOutputReset,
298{
299    // See [1], section 9.1.1
300    let h_len = <D as Digest>::output_size();
301    let s_len = salt.len();
302    let em_len = (em_bits + 7) / 8;
303
304    // 1. If the length of M is greater than the input limitation for the
305    //     hash function (2^61 - 1 octets for SHA-1), output "message too
306    //     long" and stop.
307    //
308    // 2.  Let mHash = Hash(M), an octet string of length hLen.
309    if m_hash.len() != h_len {
310        return Err(Error::InputNotHashed);
311    }
312
313    // 3. If em_len < h_len + s_len + 2, output "encoding error" and stop.
314    if em_len < h_len + s_len + 2 {
315        // TODO: Key size too small
316        return Err(Error::Internal);
317    }
318
319    let mut em = vec![0; em_len];
320
321    let (db, h) = em.split_at_mut(em_len - h_len - 1);
322    let h = &mut h[..(em_len - 1) - db.len()];
323
324    // 4. Generate a random octet string salt of length s_len; if s_len = 0,
325    //     then salt is the empty string.
326    //
327    // 5.  Let
328    //       M' = (0x)00 00 00 00 00 00 00 00 || m_hash || salt;
329    //
330    //     M' is an octet string of length 8 + h_len + s_len with eight
331    //     initial zero octets.
332    //
333    // 6.  Let H = Hash(M'), an octet string of length h_len.
334    let prefix = [0u8; 8];
335
336    let mut hash = D::new();
337
338    Digest::update(&mut hash, &prefix);
339    Digest::update(&mut hash, m_hash);
340    Digest::update(&mut hash, salt);
341
342    let hashed = hash.finalize_reset();
343    h.copy_from_slice(&hashed);
344
345    // 7.  Generate an octet string PS consisting of em_len - s_len - h_len - 2
346    //     zero octets. The length of PS may be 0.
347    //
348    // 8.  Let DB = PS || 0x01 || salt; DB is an octet string of length
349    //     emLen - hLen - 1.
350    db[em_len - s_len - h_len - 2] = 0x01;
351    db[em_len - s_len - h_len - 1..].copy_from_slice(salt);
352
353    // 9.  Let dbMask = MGF(H, emLen - hLen - 1).
354    //
355    // 10. Let maskedDB = DB \xor dbMask.
356    mgf1_xor_digest(db, &mut hash, h);
357
358    // 11. Set the leftmost 8 * em_len - em_bits bits of the leftmost octet in
359    //     maskedDB to zero.
360    db[0] &= 0xFF >> (8 * em_len - em_bits);
361
362    // 12. Let EM = maskedDB || H || 0xbc.
363    em[em_len - 1] = 0xBC;
364
365    Ok(em)
366}
367
368fn emsa_pss_verify_pre<'a>(
369    m_hash: &[u8],
370    em: &'a mut [u8],
371    em_bits: usize,
372    s_len: Option<usize>,
373    h_len: usize,
374) -> Result<(&'a mut [u8], &'a mut [u8])> {
375    // 1. If the length of M is greater than the input limitation for the
376    //    hash function (2^61 - 1 octets for SHA-1), output "inconsistent"
377    //    and stop.
378    //
379    // 2. Let mHash = Hash(M), an octet string of length hLen
380    if m_hash.len() != h_len {
381        return Err(Error::Verification);
382    }
383
384    // 3. If emLen < hLen + sLen + 2, output "inconsistent" and stop.
385    let em_len = em.len(); //(em_bits + 7) / 8;
386    if em_len < h_len + s_len.unwrap_or_default() + 2 {
387        return Err(Error::Verification);
388    }
389
390    // 4. If the rightmost octet of EM does not have hexadecimal value
391    //    0xbc, output "inconsistent" and stop.
392    if em[em.len() - 1] != 0xBC {
393        return Err(Error::Verification);
394    }
395
396    // 5. Let maskedDB be the leftmost emLen - hLen - 1 octets of EM, and
397    //    let H be the next hLen octets.
398    let (db, h) = em.split_at_mut(em_len - h_len - 1);
399    let h = &mut h[..h_len];
400
401    // 6. If the leftmost 8 * em_len - em_bits bits of the leftmost octet in
402    //    maskedDB are not all equal to zero, output "inconsistent" and
403    //    stop.
404    if db[0] & (0xFF << /*uint*/(8 - (8 * em_len - em_bits))) != 0 {
405        return Err(Error::Verification);
406    }
407
408    Ok((db, h))
409}
410
411fn emsa_pss_get_salt(
412    db: &[u8],
413    em_len: usize,
414    s_len: Option<usize>,
415    h_len: usize,
416) -> Result<&[u8]> {
417    let s_len = match s_len {
418        None => (0..=em_len - (h_len + 2))
419            .rev()
420            .try_fold(None, |state, i| match (state, db[em_len - h_len - i - 2]) {
421                (Some(i), _) => Ok(Some(i)),
422                (_, 1) => Ok(Some(i)),
423                (_, 0) => Ok(None),
424                _ => Err(Error::Verification),
425            })?
426            .ok_or(Error::Verification)?,
427        Some(s_len) => {
428            // 10. If the emLen - hLen - sLen - 2 leftmost octets of DB are not zero
429            //     or if the octet at position emLen - hLen - sLen - 1 (the leftmost
430            //     position is "position 1") does not have hexadecimal value 0x01,
431            //     output "inconsistent" and stop.
432            let (zeroes, rest) = db.split_at(em_len - h_len - s_len - 2);
433            if zeroes.iter().any(|e| *e != 0x00) || rest[0] != 0x01 {
434                return Err(Error::Verification);
435            }
436
437            s_len
438        }
439    };
440
441    // 11. Let salt be the last s_len octets of DB.
442    let salt = &db[db.len() - s_len..];
443
444    Ok(salt)
445}
446
447fn emsa_pss_verify(
448    m_hash: &[u8],
449    em: &mut [u8],
450    em_bits: usize,
451    s_len: Option<usize>,
452    hash: &mut dyn DynDigest,
453) -> Result<()> {
454    let em_len = em.len(); //(em_bits + 7) / 8;
455    let h_len = hash.output_size();
456
457    let (db, h) = emsa_pss_verify_pre(m_hash, em, em_bits, s_len, h_len)?;
458
459    // 7. Let dbMask = MGF(H, em_len - h_len - 1)
460    //
461    // 8. Let DB = maskedDB \xor dbMask
462    mgf1_xor(db, hash, &*h);
463
464    // 9.  Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB
465    //     to zero.
466    db[0] &= 0xFF >> /*uint*/(8 * em_len - em_bits);
467
468    let salt = emsa_pss_get_salt(db, em_len, s_len, h_len)?;
469
470    // 12. Let
471    //          M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ;
472    //     M' is an octet string of length 8 + hLen + sLen with eight
473    //     initial zero octets.
474    //
475    // 13. Let H' = Hash(M'), an octet string of length hLen.
476    let prefix = [0u8; 8];
477
478    hash.update(&prefix[..]);
479    hash.update(m_hash);
480    hash.update(salt);
481    let h0 = hash.finalize_reset();
482
483    // 14. If H = H', output "consistent." Otherwise, output "inconsistent."
484    if h0.ct_eq(h).into() {
485        Ok(())
486    } else {
487        Err(Error::Verification)
488    }
489}
490
491fn emsa_pss_verify_digest<D>(
492    m_hash: &[u8],
493    em: &mut [u8],
494    em_bits: usize,
495    s_len: Option<usize>,
496) -> Result<()>
497where
498    D: Digest + FixedOutputReset,
499{
500    let em_len = em.len(); //(em_bits + 7) / 8;
501    let h_len = <D as Digest>::output_size();
502
503    let (db, h) = emsa_pss_verify_pre(m_hash, em, em_bits, s_len, h_len)?;
504
505    let mut hash = D::new();
506
507    // 7. Let dbMask = MGF(H, em_len - h_len - 1)
508    //
509    // 8. Let DB = maskedDB \xor dbMask
510    mgf1_xor_digest::<D>(db, &mut hash, &*h);
511
512    // 9.  Set the leftmost 8 * emLen - emBits bits of the leftmost octet in DB
513    //     to zero.
514    db[0] &= 0xFF >> /*uint*/(8 * em_len - em_bits);
515
516    let salt = emsa_pss_get_salt(db, em_len, s_len, h_len)?;
517
518    // 12. Let
519    //          M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ;
520    //     M' is an octet string of length 8 + hLen + sLen with eight
521    //     initial zero octets.
522    //
523    // 13. Let H' = Hash(M'), an octet string of length hLen.
524    let prefix = [0u8; 8];
525
526    Digest::update(&mut hash, &prefix[..]);
527    Digest::update(&mut hash, m_hash);
528    Digest::update(&mut hash, salt);
529    let h0 = hash.finalize_reset();
530
531    // 14. If H = H', output "consistent." Otherwise, output "inconsistent."
532    if h0.ct_eq(h).into() {
533        Ok(())
534    } else {
535        Err(Error::Verification)
536    }
537}
538
539/// Signing key for producing RSASSA-PSS signatures as described in
540/// [RFC8017 § 8.1].
541///
542/// [RFC8017 § 8.1]: https://datatracker.ietf.org/doc/html/rfc8017#section-8.1
543#[derive(Debug, Clone)]
544pub struct SigningKey<D>
545where
546    D: Digest,
547{
548    inner: RsaPrivateKey,
549    salt_len: Option<usize>,
550    phantom: PhantomData<D>,
551}
552
553impl<D> SigningKey<D>
554where
555    D: Digest,
556{
557    /// Create a new RSASSA-PSS signing key.
558    pub fn new(key: RsaPrivateKey) -> Self {
559        Self {
560            inner: key,
561            salt_len: None,
562            phantom: Default::default(),
563        }
564    }
565
566    /// Create a new RSASSA-PSS signing key with a salt of the given length.
567    pub fn new_with_salt_len(key: RsaPrivateKey, salt_len: usize) -> Self {
568        Self {
569            inner: key,
570            salt_len: Some(salt_len),
571            phantom: Default::default(),
572        }
573    }
574
575    pub(crate) fn key(&self) -> &RsaPrivateKey {
576        &self.inner
577    }
578}
579
580impl<D> From<RsaPrivateKey> for SigningKey<D>
581where
582    D: Digest,
583{
584    fn from(key: RsaPrivateKey) -> Self {
585        Self::new(key)
586    }
587}
588
589impl<D> From<SigningKey<D>> for RsaPrivateKey
590where
591    D: Digest,
592{
593    fn from(key: SigningKey<D>) -> Self {
594        key.inner
595    }
596}
597
598impl<D> EncodePrivateKey for SigningKey<D>
599where
600    D: Digest,
601{
602    fn to_pkcs8_der(&self) -> pkcs8::Result<SecretDocument> {
603        self.inner.to_pkcs8_der()
604    }
605}
606
607impl<D> RandomizedSigner<Signature> for SigningKey<D>
608where
609    D: Digest + FixedOutputReset,
610{
611    fn try_sign_with_rng(
612        &self,
613        mut rng: impl CryptoRng + RngCore,
614        msg: &[u8],
615    ) -> signature::Result<Signature> {
616        sign_digest::<_, _, D>(&mut rng, false, &self.inner, &D::digest(msg), self.salt_len)
617            .map(|v| v.into())
618            .map_err(|e| e.into())
619    }
620}
621
622impl<D> RandomizedDigestSigner<D, Signature> for SigningKey<D>
623where
624    D: Digest + FixedOutputReset,
625{
626    fn try_sign_digest_with_rng(
627        &self,
628        mut rng: impl CryptoRng + RngCore,
629        digest: D,
630    ) -> signature::Result<Signature> {
631        sign_digest::<_, _, D>(
632            &mut rng,
633            false,
634            &self.inner,
635            &digest.finalize(),
636            self.salt_len,
637        )
638        .map(|v| v.into())
639        .map_err(|e| e.into())
640    }
641}
642
643#[cfg(feature = "hazmat")]
644impl<D> RandomizedPrehashSigner<Signature> for SigningKey<D>
645where
646    D: Digest + FixedOutputReset,
647{
648    fn sign_prehash_with_rng(
649        &self,
650        mut rng: impl CryptoRng + RngCore,
651        prehash: &[u8],
652    ) -> signature::Result<Signature> {
653        sign_digest::<_, _, D>(&mut rng, false, &self.inner, prehash, self.salt_len)
654            .map(|v| v.into())
655            .map_err(|e| e.into())
656    }
657}
658
659impl<D> AsRef<RsaPrivateKey> for SigningKey<D>
660where
661    D: Digest,
662{
663    fn as_ref(&self) -> &RsaPrivateKey {
664        &self.inner
665    }
666}
667
668/// Signing key for producing "blinded" RSASSA-PSS signatures as described in
669/// [draft-irtf-cfrg-rsa-blind-signatures](https://datatracker.ietf.org/doc/draft-irtf-cfrg-rsa-blind-signatures/).
670#[derive(Debug, Clone)]
671pub struct BlindedSigningKey<D>
672where
673    D: Digest,
674{
675    inner: RsaPrivateKey,
676    salt_len: Option<usize>,
677    phantom: PhantomData<D>,
678}
679
680impl<D> BlindedSigningKey<D>
681where
682    D: Digest,
683{
684    /// Create a new RSASSA-PSS signing key which produces "blinded"
685    /// signatures.
686    pub fn new(key: RsaPrivateKey) -> Self {
687        Self {
688            inner: key,
689            salt_len: None,
690            phantom: Default::default(),
691        }
692    }
693
694    /// Create a new RSASSA-PSS signing key which produces "blinded"
695    /// signatures with a salt of the given length.
696    pub fn new_with_salt_len(key: RsaPrivateKey, salt_len: usize) -> Self {
697        Self {
698            inner: key,
699            salt_len: Some(salt_len),
700            phantom: Default::default(),
701        }
702    }
703
704    pub(crate) fn key(&self) -> &RsaPrivateKey {
705        &self.inner
706    }
707}
708
709impl<D> From<RsaPrivateKey> for BlindedSigningKey<D>
710where
711    D: Digest,
712{
713    fn from(key: RsaPrivateKey) -> Self {
714        Self::new(key)
715    }
716}
717
718impl<D> From<BlindedSigningKey<D>> for RsaPrivateKey
719where
720    D: Digest,
721{
722    fn from(key: BlindedSigningKey<D>) -> Self {
723        key.inner
724    }
725}
726
727impl<D> EncodePrivateKey for BlindedSigningKey<D>
728where
729    D: Digest,
730{
731    fn to_pkcs8_der(&self) -> pkcs8::Result<SecretDocument> {
732        self.inner.to_pkcs8_der()
733    }
734}
735
736impl<D> RandomizedSigner<Signature> for BlindedSigningKey<D>
737where
738    D: Digest + FixedOutputReset,
739{
740    fn try_sign_with_rng(
741        &self,
742        mut rng: impl CryptoRng + RngCore,
743        msg: &[u8],
744    ) -> signature::Result<Signature> {
745        sign_digest::<_, _, D>(&mut rng, true, &self.inner, &D::digest(msg), self.salt_len)
746            .map(|v| v.into())
747            .map_err(|e| e.into())
748    }
749}
750
751impl<D> RandomizedDigestSigner<D, Signature> for BlindedSigningKey<D>
752where
753    D: Digest + FixedOutputReset,
754{
755    fn try_sign_digest_with_rng(
756        &self,
757        mut rng: impl CryptoRng + RngCore,
758        digest: D,
759    ) -> signature::Result<Signature> {
760        sign_digest::<_, _, D>(
761            &mut rng,
762            true,
763            &self.inner,
764            &digest.finalize(),
765            self.salt_len,
766        )
767        .map(|v| v.into())
768        .map_err(|e| e.into())
769    }
770}
771
772#[cfg(feature = "hazmat")]
773impl<D> RandomizedPrehashSigner<Signature> for BlindedSigningKey<D>
774where
775    D: Digest + FixedOutputReset,
776{
777    fn sign_prehash_with_rng(
778        &self,
779        mut rng: impl CryptoRng + RngCore,
780        prehash: &[u8],
781    ) -> signature::Result<Signature> {
782        sign_digest::<_, _, D>(&mut rng, true, &self.inner, prehash, self.salt_len)
783            .map(|v| v.into())
784            .map_err(|e| e.into())
785    }
786}
787
788impl<D> AsRef<RsaPrivateKey> for BlindedSigningKey<D>
789where
790    D: Digest,
791{
792    fn as_ref(&self) -> &RsaPrivateKey {
793        &self.inner
794    }
795}
796
797/// Verifying key for checking the validity of RSASSA-PSS signatures as
798/// described in [RFC8017 § 8.1].
799///
800/// [RFC8017 § 8.1]: https://datatracker.ietf.org/doc/html/rfc8017#section-8.1
801#[derive(Debug, Clone)]
802pub struct VerifyingKey<D>
803where
804    D: Digest,
805{
806    inner: RsaPublicKey,
807    phantom: PhantomData<D>,
808}
809
810impl<D> VerifyingKey<D>
811where
812    D: Digest,
813{
814    /// Create a new RSASSA-PSS verifying key.
815    pub fn new(key: RsaPublicKey) -> Self {
816        Self {
817            inner: key,
818            phantom: Default::default(),
819        }
820    }
821}
822
823impl<D> From<RsaPublicKey> for VerifyingKey<D>
824where
825    D: Digest,
826{
827    fn from(key: RsaPublicKey) -> Self {
828        Self::new(key)
829    }
830}
831
832impl<D> From<VerifyingKey<D>> for RsaPublicKey
833where
834    D: Digest,
835{
836    fn from(key: VerifyingKey<D>) -> Self {
837        key.inner
838    }
839}
840
841impl<D> From<SigningKey<D>> for VerifyingKey<D>
842where
843    D: Digest,
844{
845    fn from(key: SigningKey<D>) -> Self {
846        Self {
847            inner: key.key().into(),
848            phantom: Default::default(),
849        }
850    }
851}
852
853impl<D> From<&SigningKey<D>> for VerifyingKey<D>
854where
855    D: Digest,
856{
857    fn from(key: &SigningKey<D>) -> Self {
858        Self {
859            inner: key.key().into(),
860            phantom: Default::default(),
861        }
862    }
863}
864
865impl<D> From<BlindedSigningKey<D>> for VerifyingKey<D>
866where
867    D: Digest,
868{
869    fn from(key: BlindedSigningKey<D>) -> Self {
870        Self {
871            inner: key.key().into(),
872            phantom: Default::default(),
873        }
874    }
875}
876
877impl<D> From<&BlindedSigningKey<D>> for VerifyingKey<D>
878where
879    D: Digest,
880{
881    fn from(key: &BlindedSigningKey<D>) -> Self {
882        Self {
883            inner: key.key().into(),
884            phantom: Default::default(),
885        }
886    }
887}
888
889impl<D> Verifier<Signature> for VerifyingKey<D>
890where
891    D: Digest + FixedOutputReset,
892{
893    fn verify(&self, msg: &[u8], signature: &Signature) -> signature::Result<()> {
894        verify_digest::<_, D>(&self.inner, &D::digest(msg), signature.as_ref())
895            .map_err(|e| e.into())
896    }
897}
898
899impl<D> DigestVerifier<D, Signature> for VerifyingKey<D>
900where
901    D: Digest + FixedOutputReset,
902{
903    fn verify_digest(&self, digest: D, signature: &Signature) -> signature::Result<()> {
904        verify_digest::<_, D>(&self.inner, &digest.finalize(), signature.as_ref())
905            .map_err(|e| e.into())
906    }
907}
908
909#[cfg(feature = "hazmat")]
910impl<D> PrehashVerifier<Signature> for VerifyingKey<D>
911where
912    D: Digest + FixedOutputReset,
913{
914    fn verify_prehash(&self, prehash: &[u8], signature: &Signature) -> signature::Result<()> {
915        verify_digest::<_, D>(&self.inner, prehash, signature.as_ref()).map_err(|e| e.into())
916    }
917}
918
919impl<D> AsRef<RsaPublicKey> for VerifyingKey<D>
920where
921    D: Digest,
922{
923    fn as_ref(&self) -> &RsaPublicKey {
924        &self.inner
925    }
926}
927
928impl<D> EncodePublicKey for VerifyingKey<D>
929where
930    D: Digest,
931{
932    fn to_public_key_der(&self) -> pkcs8::spki::Result<Document> {
933        self.inner.to_public_key_der()
934    }
935}
936
937#[cfg(test)]
938mod test {
939    use crate::pss::{BlindedSigningKey, SigningKey, VerifyingKey};
940    use crate::{PaddingScheme, PublicKey, RsaPrivateKey, RsaPublicKey};
941
942    use hex_literal::hex;
943    use num_bigint::BigUint;
944    use num_traits::{FromPrimitive, Num};
945    use rand_chacha::{rand_core::SeedableRng, ChaCha8Rng};
946    use sha1::{Digest, Sha1};
947    #[cfg(feature = "hazmat")]
948    use signature::hazmat::{PrehashVerifier, RandomizedPrehashSigner};
949    use signature::{
950        DigestVerifier, RandomizedDigestSigner, RandomizedSigner, Signature, Verifier,
951    };
952
953    fn get_private_key() -> RsaPrivateKey {
954        // In order to generate new test vectors you'll need the PEM form of this key:
955        // -----BEGIN RSA PRIVATE KEY-----
956        // MIIBOgIBAAJBALKZD0nEffqM1ACuak0bijtqE2QrI/KLADv7l3kK3ppMyCuLKoF0
957        // fd7Ai2KW5ToIwzFofvJcS/STa6HA5gQenRUCAwEAAQJBAIq9amn00aS0h/CrjXqu
958        // /ThglAXJmZhOMPVn4eiu7/ROixi9sex436MaVeMqSNf7Ex9a8fRNfWss7Sqd9eWu
959        // RTUCIQDasvGASLqmjeffBNLTXV2A5g4t+kLVCpsEIZAycV5GswIhANEPLmax0ME/
960        // EO+ZJ79TJKN5yiGBRsv5yvx5UiHxajEXAiAhAol5N4EUyq6I9w1rYdhPMGpLfk7A
961        // IU2snfRJ6Nq2CQIgFrPsWRCkV+gOYcajD17rEqmuLrdIRexpg8N1DOSXoJ8CIGlS
962        // tAboUGBxTDq3ZroNism3DaMIbKPyYrAqhKov1h5V
963        // -----END RSA PRIVATE KEY-----
964
965        RsaPrivateKey::from_components(
966            BigUint::from_str_radix("9353930466774385905609975137998169297361893554149986716853295022578535724979677252958524466350471210367835187480748268864277464700638583474144061408845077", 10).unwrap(),
967            BigUint::from_u64(65537).unwrap(),
968            BigUint::from_str_radix("7266398431328116344057699379749222532279343923819063639497049039389899328538543087657733766554155839834519529439851673014800261285757759040931985506583861", 10).unwrap(),
969            vec![
970                BigUint::from_str_radix("98920366548084643601728869055592650835572950932266967461790948584315647051443",10).unwrap(),
971                BigUint::from_str_radix("94560208308847015747498523884063394671606671904944666360068158221458669711639", 10).unwrap()
972            ],
973        ).unwrap()
974    }
975
976    #[test]
977    fn test_verify_pss() {
978        let priv_key = get_private_key();
979
980        let tests = [
981            (
982                "test\n",
983                hex!(
984                    "6f86f26b14372b2279f79fb6807c49889835c204f71e38249b4c5601462da8ae"
985                    "30f26ffdd9c13f1c75eee172bebe7b7c89f2f1526c722833b9737d6c172a962f"
986                ),
987                true,
988            ),
989            (
990                "test\n",
991                hex!(
992                    "6f86f26b14372b2279f79fb6807c49889835c204f71e38249b4c5601462da8ae"
993                    "30f26ffdd9c13f1c75eee172bebe7b7c89f2f1526c722833b9737d6c172a962e"
994                ),
995                false,
996            ),
997        ];
998        let pub_key: RsaPublicKey = priv_key.into();
999
1000        for (text, sig, expected) in &tests {
1001            let digest = Sha1::digest(text.as_bytes()).to_vec();
1002            let result = pub_key.verify(PaddingScheme::new_pss::<Sha1>(), &digest, sig);
1003            match expected {
1004                true => result.expect("failed to verify"),
1005                false => {
1006                    result.expect_err("expected verifying error");
1007                }
1008            }
1009        }
1010    }
1011
1012    #[test]
1013    fn test_verify_pss_signer() {
1014        let priv_key = get_private_key();
1015
1016        let tests = [
1017            (
1018                "test\n",
1019                hex!(
1020                    "6f86f26b14372b2279f79fb6807c49889835c204f71e38249b4c5601462da8ae"
1021                    "30f26ffdd9c13f1c75eee172bebe7b7c89f2f1526c722833b9737d6c172a962f"
1022                ),
1023                true,
1024            ),
1025            (
1026                "test\n",
1027                hex!(
1028                    "6f86f26b14372b2279f79fb6807c49889835c204f71e38249b4c5601462da8ae"
1029                    "30f26ffdd9c13f1c75eee172bebe7b7c89f2f1526c722833b9737d6c172a962e"
1030                ),
1031                false,
1032            ),
1033        ];
1034        let pub_key: RsaPublicKey = priv_key.into();
1035        let verifying_key: VerifyingKey<Sha1> = VerifyingKey::new(pub_key);
1036
1037        for (text, sig, expected) in &tests {
1038            let result =
1039                verifying_key.verify(text.as_bytes(), &Signature::from_bytes(sig).unwrap());
1040            match expected {
1041                true => result.expect("failed to verify"),
1042                false => {
1043                    result.expect_err("expected verifying error");
1044                }
1045            }
1046        }
1047    }
1048
1049    #[test]
1050    fn test_verify_pss_digest_signer() {
1051        let priv_key = get_private_key();
1052
1053        let tests = [
1054            (
1055                "test\n",
1056                hex!(
1057                    "6f86f26b14372b2279f79fb6807c49889835c204f71e38249b4c5601462da8ae"
1058                    "30f26ffdd9c13f1c75eee172bebe7b7c89f2f1526c722833b9737d6c172a962f"
1059                ),
1060                true,
1061            ),
1062            (
1063                "test\n",
1064                hex!(
1065                    "6f86f26b14372b2279f79fb6807c49889835c204f71e38249b4c5601462da8ae"
1066                    "30f26ffdd9c13f1c75eee172bebe7b7c89f2f1526c722833b9737d6c172a962e"
1067                ),
1068                false,
1069            ),
1070        ];
1071        let pub_key: RsaPublicKey = priv_key.into();
1072        let verifying_key = VerifyingKey::new(pub_key);
1073
1074        for (text, sig, expected) in &tests {
1075            let mut digest = Sha1::new();
1076            digest.update(text.as_bytes());
1077            let result = verifying_key.verify_digest(digest, &Signature::from_bytes(sig).unwrap());
1078            match expected {
1079                true => result.expect("failed to verify"),
1080                false => {
1081                    result.expect_err("expected verifying error");
1082                }
1083            }
1084        }
1085    }
1086
1087    #[test]
1088    fn test_sign_and_verify_roundtrip() {
1089        let priv_key = get_private_key();
1090
1091        let tests = ["test\n"];
1092        let rng = ChaCha8Rng::from_seed([42; 32]);
1093
1094        for test in &tests {
1095            let digest = Sha1::digest(test.as_bytes()).to_vec();
1096            let sig = priv_key
1097                .sign_with_rng(&mut rng.clone(), PaddingScheme::new_pss::<Sha1>(), &digest)
1098                .expect("failed to sign");
1099
1100            priv_key
1101                .verify(PaddingScheme::new_pss::<Sha1>(), &digest, &sig)
1102                .expect("failed to verify");
1103        }
1104    }
1105
1106    #[test]
1107    fn test_sign_blinded_and_verify_roundtrip() {
1108        let priv_key = get_private_key();
1109
1110        let tests = ["test\n"];
1111        let rng = ChaCha8Rng::from_seed([42; 32]);
1112
1113        for test in &tests {
1114            let digest = Sha1::digest(test.as_bytes()).to_vec();
1115            let sig = priv_key
1116                .sign_blinded(&mut rng.clone(), PaddingScheme::new_pss::<Sha1>(), &digest)
1117                .expect("failed to sign");
1118
1119            priv_key
1120                .verify(PaddingScheme::new_pss::<Sha1>(), &digest, &sig)
1121                .expect("failed to verify");
1122        }
1123    }
1124
1125    #[test]
1126    fn test_sign_and_verify_roundtrip_signer() {
1127        let priv_key = get_private_key();
1128
1129        let tests = ["test\n"];
1130        let mut rng = ChaCha8Rng::from_seed([42; 32]);
1131        let signing_key = SigningKey::<Sha1>::new(priv_key);
1132        let verifying_key = VerifyingKey::from(&signing_key);
1133
1134        for test in &tests {
1135            let sig = signing_key.sign_with_rng(&mut rng, test.as_bytes());
1136            verifying_key
1137                .verify(test.as_bytes(), &sig)
1138                .expect("failed to verify");
1139        }
1140    }
1141
1142    #[test]
1143    fn test_sign_and_verify_roundtrip_blinded_signer() {
1144        let priv_key = get_private_key();
1145
1146        let tests = ["test\n"];
1147        let mut rng = ChaCha8Rng::from_seed([42; 32]);
1148        let signing_key = BlindedSigningKey::<Sha1>::new(priv_key);
1149        let verifying_key = VerifyingKey::from(&signing_key);
1150
1151        for test in &tests {
1152            let sig = signing_key.sign_with_rng(&mut rng, test.as_bytes());
1153            verifying_key
1154                .verify(test.as_bytes(), &sig)
1155                .expect("failed to verify");
1156        }
1157    }
1158
1159    #[test]
1160    fn test_sign_and_verify_roundtrip_digest_signer() {
1161        let priv_key = get_private_key();
1162
1163        let tests = ["test\n"];
1164        let mut rng = ChaCha8Rng::from_seed([42; 32]);
1165        let signing_key = SigningKey::new(priv_key);
1166        let verifying_key = VerifyingKey::from(&signing_key);
1167
1168        for test in &tests {
1169            let mut digest = Sha1::new();
1170            digest.update(test.as_bytes());
1171            let sig = signing_key.sign_digest_with_rng(&mut rng, digest);
1172
1173            let mut digest = Sha1::new();
1174            digest.update(test.as_bytes());
1175            verifying_key
1176                .verify_digest(digest, &sig)
1177                .expect("failed to verify");
1178        }
1179    }
1180
1181    #[test]
1182    fn test_sign_and_verify_roundtrip_blinded_digest_signer() {
1183        let priv_key = get_private_key();
1184
1185        let tests = ["test\n"];
1186        let mut rng = ChaCha8Rng::from_seed([42; 32]);
1187        let signing_key = BlindedSigningKey::<Sha1>::new(priv_key);
1188        let verifying_key = VerifyingKey::from(&signing_key);
1189
1190        for test in &tests {
1191            let mut digest = Sha1::new();
1192            digest.update(test.as_bytes());
1193            let sig = signing_key.sign_digest_with_rng(&mut rng, digest);
1194
1195            let mut digest = Sha1::new();
1196            digest.update(test.as_bytes());
1197            verifying_key
1198                .verify_digest(digest, &sig)
1199                .expect("failed to verify");
1200        }
1201    }
1202
1203    #[cfg(feature = "hazmat")]
1204    #[test]
1205    fn test_verify_pss_hazmat() {
1206        let priv_key = get_private_key();
1207
1208        let tests = [
1209            (
1210                Sha1::digest("test\n"),
1211                hex!(
1212                    "6f86f26b14372b2279f79fb6807c49889835c204f71e38249b4c5601462da8ae"
1213                    "30f26ffdd9c13f1c75eee172bebe7b7c89f2f1526c722833b9737d6c172a962f"
1214                ),
1215                true,
1216            ),
1217            (
1218                Sha1::digest("test\n"),
1219                hex!(
1220                    "6f86f26b14372b2279f79fb6807c49889835c204f71e38249b4c5601462da8ae"
1221                    "30f26ffdd9c13f1c75eee172bebe7b7c89f2f1526c722833b9737d6c172a962e"
1222                ),
1223                false,
1224            ),
1225        ];
1226        let pub_key: RsaPublicKey = priv_key.into();
1227        let verifying_key = VerifyingKey::<Sha1>::new(pub_key);
1228
1229        for (text, sig, expected) in &tests {
1230            let result =
1231                verifying_key.verify_prehash(text.as_ref(), &Signature::from_bytes(sig).unwrap());
1232            match expected {
1233                true => result.expect("failed to verify"),
1234                false => {
1235                    result.expect_err("expected verifying error");
1236                }
1237            }
1238        }
1239    }
1240
1241    #[cfg(feature = "hazmat")]
1242    #[test]
1243    fn test_sign_and_verify_pss_hazmat() {
1244        let priv_key = get_private_key();
1245
1246        let tests = [Sha1::digest("test\n")];
1247        let mut rng = ChaCha8Rng::from_seed([42; 32]);
1248        let signing_key = SigningKey::<Sha1>::new(priv_key);
1249        let verifying_key = VerifyingKey::from(&signing_key);
1250
1251        for test in &tests {
1252            let sig = signing_key
1253                .sign_prehash_with_rng(&mut rng, &test)
1254                .expect("failed to sign");
1255            verifying_key
1256                .verify_prehash(&test, &sig)
1257                .expect("failed to verify");
1258        }
1259    }
1260
1261    #[cfg(feature = "hazmat")]
1262    #[test]
1263    fn test_sign_and_verify_pss_blinded_hazmat() {
1264        let priv_key = get_private_key();
1265
1266        let tests = [Sha1::digest("test\n")];
1267        let mut rng = ChaCha8Rng::from_seed([42; 32]);
1268        let signing_key = BlindedSigningKey::<Sha1>::new(priv_key);
1269        let verifying_key = VerifyingKey::from(&signing_key);
1270
1271        for test in &tests {
1272            let sig = signing_key
1273                .sign_prehash_with_rng(&mut rng, &test)
1274                .expect("failed to sign");
1275            verifying_key
1276                .verify_prehash(&test, &sig)
1277                .expect("failed to verify");
1278        }
1279    }
1280}