rsa/
algorithms.rs

1//! Useful algorithms related to RSA.
2
3use digest::{Digest, DynDigest, FixedOutputReset};
4use num_bigint::traits::ModInverse;
5use num_bigint::{BigUint, RandPrime};
6#[allow(unused_imports)]
7use num_traits::Float;
8use num_traits::{FromPrimitive, One, Zero};
9use rand_core::{CryptoRng, RngCore};
10
11use crate::errors::{Error, Result};
12use crate::key::RsaPrivateKey;
13
/// Public exponent (`65537`) used when the caller does not supply one.
const EXP: u64 = 65_537;
16
17/// Generates a multi-prime RSA keypair of the given bit size,
18/// and the given random source, as suggested in [1]. Although the public
19/// keys are compatible (actually, indistinguishable) from the 2-prime case,
20/// the private keys are not. Thus it may not be possible to export multi-prime
21/// private keys in certain formats or to subsequently import them into other
22/// code.
23///
24/// Uses default public key exponent of `65537`. If you want to use a custom
25/// public key exponent value, use `algorithms::generate_multi_prime_key_with_exp`
26/// instead.
27///
28/// Table 1 in [2] suggests maximum numbers of primes for a given size.
29///
30/// [1]: https://patents.google.com/patent/US4405829A/en
31/// [2]: https://cacr.uwaterloo.ca/techreports/2006/cacr2006-16.pdf
32pub fn generate_multi_prime_key<R: RngCore + CryptoRng>(
33    rng: &mut R,
34    nprimes: usize,
35    bit_size: usize,
36) -> Result<RsaPrivateKey> {
37    let exp = BigUint::from_u64(EXP).expect("invalid static exponent");
38    generate_multi_prime_key_with_exp(rng, nprimes, bit_size, &exp)
39}
40
41/// Generates a multi-prime RSA keypair of the given bit size, public exponent,
42/// and the given random source, as suggested in [1]. Although the public
43/// keys are compatible (actually, indistinguishable) from the 2-prime case,
44/// the private keys are not. Thus it may not be possible to export multi-prime
45/// private keys in certain formats or to subsequently import them into other
46/// code.
47///
48/// Table 1 in [2] suggests maximum numbers of primes for a given size.
49///
50/// [1]: https://patents.google.com/patent/US4405829A/en
51/// [2]: http://www.cacr.math.uwaterloo.ca/techreports/2006/cacr2006-16.pdf
52pub fn generate_multi_prime_key_with_exp<R: RngCore + CryptoRng>(
53    rng: &mut R,
54    nprimes: usize,
55    bit_size: usize,
56    exp: &BigUint,
57) -> Result<RsaPrivateKey> {
58    if nprimes < 2 {
59        return Err(Error::NprimesTooSmall);
60    }
61
62    if bit_size < 64 {
63        let prime_limit = (1u64 << (bit_size / nprimes) as u64) as f64;
64
65        // pi aproximates the number of primes less than prime_limit
66        let mut pi = prime_limit / (prime_limit.ln() - 1f64);
67        // Generated primes start with 0b11, so we can only use a quarter of them.
68        pi /= 4f64;
69        // Use a factor of two to ensure that key generation terminates in a
70        // reasonable amount of time.
71        pi /= 2f64;
72
73        if pi < nprimes as f64 {
74            return Err(Error::TooFewPrimes);
75        }
76    }
77
78    let mut primes = vec![BigUint::zero(); nprimes];
79    let n_final: BigUint;
80    let d_final: BigUint;
81
82    'next: loop {
83        let mut todo = bit_size;
84        // `gen_prime` should set the top two bits in each prime.
85        // Thus each prime has the form
86        //   p_i = 2^bitlen(p_i) × 0.11... (in base 2).
87        // And the product is:
88        //   P = 2^todo × α
89        // where α is the product of nprimes numbers of the form 0.11...
90        //
91        // If α < 1/2 (which can happen for nprimes > 2), we need to
92        // shift todo to compensate for lost bits: the mean value of 0.11...
93        // is 7/8, so todo + shift - nprimes * log2(7/8) ~= bits - 1/2
94        // will give good results.
95        if nprimes >= 7 {
96            todo += (nprimes - 2) / 5;
97        }
98
99        for (i, prime) in primes.iter_mut().enumerate() {
100            *prime = rng.gen_prime(todo / (nprimes - i));
101            todo -= prime.bits();
102        }
103
104        // Makes sure that primes is pairwise unequal.
105        for (i, prime1) in primes.iter().enumerate() {
106            for prime2 in primes.iter().take(i) {
107                if prime1 == prime2 {
108                    continue 'next;
109                }
110            }
111        }
112
113        let mut n = BigUint::one();
114        let mut totient = BigUint::one();
115
116        for prime in &primes {
117            n *= prime;
118            totient *= prime - BigUint::one();
119        }
120
121        if n.bits() != bit_size {
122            // This should never happen for nprimes == 2 because
123            // gen_prime should set the top two bits in each prime.
124            // For nprimes > 2 we hope it does not happen often.
125            continue 'next;
126        }
127
128        if let Some(d) = exp.mod_inverse(totient) {
129            n_final = n;
130            d_final = d.to_biguint().unwrap();
131            break;
132        }
133    }
134
135    RsaPrivateKey::from_components(n_final, exp.clone(), d_final, primes)
136}
137
138/// Mask generation function.
139///
140/// Panics if out is larger than 2**32. This is in accordance with RFC 8017 - PKCS #1 B.2.1
141pub fn mgf1_xor(out: &mut [u8], digest: &mut dyn DynDigest, seed: &[u8]) {
142    let mut counter = [0u8; 4];
143    let mut i = 0;
144
145    const MAX_LEN: u64 = core::u32::MAX as u64 + 1;
146    assert!(out.len() as u64 <= MAX_LEN);
147
148    while i < out.len() {
149        let mut digest_input = vec![0u8; seed.len() + 4];
150        digest_input[0..seed.len()].copy_from_slice(seed);
151        digest_input[seed.len()..].copy_from_slice(&counter);
152
153        digest.update(digest_input.as_slice());
154        let digest_output = &*digest.finalize_reset();
155        let mut j = 0;
156        loop {
157            if j >= digest_output.len() || i >= out.len() {
158                break;
159            }
160
161            out[i] ^= digest_output[j];
162            j += 1;
163            i += 1;
164        }
165        inc_counter(&mut counter);
166    }
167}
168
169/// Mask generation function.
170///
171/// Panics if out is larger than 2**32. This is in accordance with RFC 8017 - PKCS #1 B.2.1
172pub fn mgf1_xor_digest<D>(out: &mut [u8], digest: &mut D, seed: &[u8])
173where
174    D: Digest + FixedOutputReset,
175{
176    let mut counter = [0u8; 4];
177    let mut i = 0;
178
179    const MAX_LEN: u64 = core::u32::MAX as u64 + 1;
180    assert!(out.len() as u64 <= MAX_LEN);
181
182    while i < out.len() {
183        Digest::update(digest, seed);
184        Digest::update(digest, counter);
185
186        let digest_output = digest.finalize_reset();
187        let mut j = 0;
188        loop {
189            if j >= digest_output.len() || i >= out.len() {
190                break;
191            }
192
193            out[i] ^= digest_output[j];
194            j += 1;
195            i += 1;
196        }
197        inc_counter(&mut counter);
198    }
199}
/// Increments `counter`, interpreted as a big-endian 32-bit integer, by one.
///
/// The value wraps around to all zeroes after `[0xff; 4]`.
fn inc_counter(counter: &mut [u8; 4]) {
    let next = u32::from_be_bytes(*counter).wrapping_add(1);
    *counter = next.to_be_bytes();
}