rsa/algorithms.rs
1//! Useful algorithms related to RSA.
2
3use digest::{Digest, DynDigest, FixedOutputReset};
4use num_bigint::traits::ModInverse;
5use num_bigint::{BigUint, RandPrime};
6#[allow(unused_imports)]
7use num_traits::Float;
8use num_traits::{FromPrimitive, One, Zero};
9use rand_core::{CryptoRng, RngCore};
10
11use crate::errors::{Error, Result};
12use crate::key::RsaPrivateKey;
13
/// Default public exponent for RSA keys: `65537`, i.e. `2^16 + 1`.
const EXP: u64 = 65_537;
16
17/// Generates a multi-prime RSA keypair of the given bit size,
18/// and the given random source, as suggested in [1]. Although the public
19/// keys are compatible (actually, indistinguishable) from the 2-prime case,
20/// the private keys are not. Thus it may not be possible to export multi-prime
21/// private keys in certain formats or to subsequently import them into other
22/// code.
23///
24/// Uses default public key exponent of `65537`. If you want to use a custom
25/// public key exponent value, use `algorithms::generate_multi_prime_key_with_exp`
26/// instead.
27///
28/// Table 1 in [2] suggests maximum numbers of primes for a given size.
29///
30/// [1]: https://patents.google.com/patent/US4405829A/en
31/// [2]: https://cacr.uwaterloo.ca/techreports/2006/cacr2006-16.pdf
32pub fn generate_multi_prime_key<R: RngCore + CryptoRng>(
33 rng: &mut R,
34 nprimes: usize,
35 bit_size: usize,
36) -> Result<RsaPrivateKey> {
37 let exp = BigUint::from_u64(EXP).expect("invalid static exponent");
38 generate_multi_prime_key_with_exp(rng, nprimes, bit_size, &exp)
39}
40
41/// Generates a multi-prime RSA keypair of the given bit size, public exponent,
42/// and the given random source, as suggested in [1]. Although the public
43/// keys are compatible (actually, indistinguishable) from the 2-prime case,
44/// the private keys are not. Thus it may not be possible to export multi-prime
45/// private keys in certain formats or to subsequently import them into other
46/// code.
47///
48/// Table 1 in [2] suggests maximum numbers of primes for a given size.
49///
50/// [1]: https://patents.google.com/patent/US4405829A/en
51/// [2]: http://www.cacr.math.uwaterloo.ca/techreports/2006/cacr2006-16.pdf
52pub fn generate_multi_prime_key_with_exp<R: RngCore + CryptoRng>(
53 rng: &mut R,
54 nprimes: usize,
55 bit_size: usize,
56 exp: &BigUint,
57) -> Result<RsaPrivateKey> {
58 if nprimes < 2 {
59 return Err(Error::NprimesTooSmall);
60 }
61
62 if bit_size < 64 {
63 let prime_limit = (1u64 << (bit_size / nprimes) as u64) as f64;
64
65 // pi aproximates the number of primes less than prime_limit
66 let mut pi = prime_limit / (prime_limit.ln() - 1f64);
67 // Generated primes start with 0b11, so we can only use a quarter of them.
68 pi /= 4f64;
69 // Use a factor of two to ensure that key generation terminates in a
70 // reasonable amount of time.
71 pi /= 2f64;
72
73 if pi < nprimes as f64 {
74 return Err(Error::TooFewPrimes);
75 }
76 }
77
78 let mut primes = vec![BigUint::zero(); nprimes];
79 let n_final: BigUint;
80 let d_final: BigUint;
81
82 'next: loop {
83 let mut todo = bit_size;
84 // `gen_prime` should set the top two bits in each prime.
85 // Thus each prime has the form
86 // p_i = 2^bitlen(p_i) × 0.11... (in base 2).
87 // And the product is:
88 // P = 2^todo × α
89 // where α is the product of nprimes numbers of the form 0.11...
90 //
91 // If α < 1/2 (which can happen for nprimes > 2), we need to
92 // shift todo to compensate for lost bits: the mean value of 0.11...
93 // is 7/8, so todo + shift - nprimes * log2(7/8) ~= bits - 1/2
94 // will give good results.
95 if nprimes >= 7 {
96 todo += (nprimes - 2) / 5;
97 }
98
99 for (i, prime) in primes.iter_mut().enumerate() {
100 *prime = rng.gen_prime(todo / (nprimes - i));
101 todo -= prime.bits();
102 }
103
104 // Makes sure that primes is pairwise unequal.
105 for (i, prime1) in primes.iter().enumerate() {
106 for prime2 in primes.iter().take(i) {
107 if prime1 == prime2 {
108 continue 'next;
109 }
110 }
111 }
112
113 let mut n = BigUint::one();
114 let mut totient = BigUint::one();
115
116 for prime in &primes {
117 n *= prime;
118 totient *= prime - BigUint::one();
119 }
120
121 if n.bits() != bit_size {
122 // This should never happen for nprimes == 2 because
123 // gen_prime should set the top two bits in each prime.
124 // For nprimes > 2 we hope it does not happen often.
125 continue 'next;
126 }
127
128 if let Some(d) = exp.mod_inverse(totient) {
129 n_final = n;
130 d_final = d.to_biguint().unwrap();
131 break;
132 }
133 }
134
135 RsaPrivateKey::from_components(n_final, exp.clone(), d_final, primes)
136}
137
138/// Mask generation function.
139///
140/// Panics if out is larger than 2**32. This is in accordance with RFC 8017 - PKCS #1 B.2.1
141pub fn mgf1_xor(out: &mut [u8], digest: &mut dyn DynDigest, seed: &[u8]) {
142 let mut counter = [0u8; 4];
143 let mut i = 0;
144
145 const MAX_LEN: u64 = core::u32::MAX as u64 + 1;
146 assert!(out.len() as u64 <= MAX_LEN);
147
148 while i < out.len() {
149 let mut digest_input = vec![0u8; seed.len() + 4];
150 digest_input[0..seed.len()].copy_from_slice(seed);
151 digest_input[seed.len()..].copy_from_slice(&counter);
152
153 digest.update(digest_input.as_slice());
154 let digest_output = &*digest.finalize_reset();
155 let mut j = 0;
156 loop {
157 if j >= digest_output.len() || i >= out.len() {
158 break;
159 }
160
161 out[i] ^= digest_output[j];
162 j += 1;
163 i += 1;
164 }
165 inc_counter(&mut counter);
166 }
167}
168
169/// Mask generation function.
170///
171/// Panics if out is larger than 2**32. This is in accordance with RFC 8017 - PKCS #1 B.2.1
172pub fn mgf1_xor_digest<D>(out: &mut [u8], digest: &mut D, seed: &[u8])
173where
174 D: Digest + FixedOutputReset,
175{
176 let mut counter = [0u8; 4];
177 let mut i = 0;
178
179 const MAX_LEN: u64 = core::u32::MAX as u64 + 1;
180 assert!(out.len() as u64 <= MAX_LEN);
181
182 while i < out.len() {
183 Digest::update(digest, seed);
184 Digest::update(digest, counter);
185
186 let digest_output = digest.finalize_reset();
187 let mut j = 0;
188 loop {
189 if j >= digest_output.len() || i >= out.len() {
190 break;
191 }
192
193 out[i] ^= digest_output[j];
194 j += 1;
195 i += 1;
196 }
197 inc_counter(&mut counter);
198 }
199}
/// Increments `counter`, interpreted as a 4-byte big-endian unsigned integer,
/// by one. The value wraps around to all zeroes after `[0xff; 4]`.
fn inc_counter(counter: &mut [u8; 4]) {
    let next = u32::from_be_bytes(*counter).wrapping_add(1);
    *counter = next.to_be_bytes();
}