// curl/easy/windows.rs
#![allow(non_camel_case_types, non_snake_case)]

use libc::c_void;

#[cfg(target_env = "msvc")]
mod win {
    use schannel::cert_context::ValidUses;
    use schannel::cert_store::CertStore;
    use std::ffi::*;
    use std::mem;
    use std::ptr;
    use windows_sys::Win32::Security::Cryptography::*;
    use windows_sys::Win32::System::LibraryLoader::*;

    /// Resolves the export `symbol` from the module named `module`.
    ///
    /// Uses `GetModuleHandleW`, which per the Win32 docs only finds modules
    /// already loaded into the process (it never loads a DLL). Returns `None`
    /// when `GetProcAddress` cannot resolve the symbol.
    fn lookup(module: &str, symbol: &str) -> Option<*const c_void> {
        // Both Win32 calls require NUL-terminated strings: UTF-16 for the
        // module name, bytes for the export name. A Rust `&str` carries no
        // terminator, so handing `symbol.as_ptr()` straight to
        // `GetProcAddress` would make it read past the end of the string.
        let mut mod_buf: Vec<u16> = module.encode_utf16().collect();
        mod_buf.push(0);
        let mut sym_buf: Vec<u8> = symbol.as_bytes().to_vec();
        sym_buf.push(0);
        // SAFETY: both buffers are NUL-terminated and live until after the calls return.
        unsafe {
            let handle = GetModuleHandleW(mod_buf.as_mut_ptr());
            GetProcAddress(handle, sym_buf.as_ptr()).map(|f| f as *const c_void)
        }
    }

    // Opaque OpenSSL types; only ever handled through raw pointers.
    pub enum X509_STORE {}
    pub enum X509 {}
    pub enum SSL_CTX {}

    // C signatures of the OpenSSL entry points resolved at runtime.
    type d2i_X509_fn = unsafe extern "C" fn(
        a: *mut *mut X509,
        pp: *mut *const c_uchar,
        length: c_long,
    ) -> *mut X509;
    type X509_free_fn = unsafe extern "C" fn(x: *mut X509);
    type X509_STORE_add_cert_fn =
        unsafe extern "C" fn(store: *mut X509_STORE, x: *mut X509) -> c_int;
    type SSL_CTX_get_cert_store_fn = unsafe extern "C" fn(ctx: *const SSL_CTX) -> *mut X509_STORE;

    /// Function pointers resolved from the OpenSSL DLLs loaded in this process.
    struct OpenSSL {
        d2i_X509: d2i_X509_fn,
        X509_free: X509_free_fn,
        X509_STORE_add_cert: X509_STORE_add_cert_fn,
        SSL_CTX_get_cert_store: SSL_CTX_get_cert_store_fn,
    }

    /// Resolves every OpenSSL function this module needs, or returns `None`
    /// if any one of them is missing.
    ///
    /// # Safety
    ///
    /// The named modules must be OpenSSL builds whose exports match the
    /// `*_fn` signatures above: the raw addresses are transmuted to those
    /// function-pointer types without any check.
    unsafe fn lookup_functions(crypto_module: &str, ssl_module: &str) -> Option<OpenSSL> {
        // Binds a local named after each symbol; bails out with `None` on the first miss.
        macro_rules! get {
            ($(let $sym:ident in $module:expr;)*) => ($(
                let $sym = lookup($module, stringify!($sym))?;
            )*)
        }
        get! {
            let d2i_X509 in crypto_module;
            let X509_free in crypto_module;
            let X509_STORE_add_cert in crypto_module;
            let SSL_CTX_get_cert_store in ssl_module;
        }
        Some(OpenSSL {
            d2i_X509: mem::transmute(d2i_X509),
            X509_free: mem::transmute(X509_free),
            X509_STORE_add_cert: mem::transmute(X509_STORE_add_cert),
            SSL_CTX_get_cert_store: mem::transmute(SSL_CTX_get_cert_store),
        })
    }

    /// Copies certificates from the current user's Windows "ROOT" store into
    /// the OpenSSL `X509_STORE` attached to `ssl_ctx`.
    ///
    /// Only certificates valid for all uses, or explicitly for "Server
    /// Authentication", are copied. This is best effort: an unsupported OpenSSL
    /// version, missing OpenSSL DLLs, a null store, an unopenable Windows
    /// store or an unparsable certificate all silently skip the affected work.
    ///
    /// # Safety
    ///
    /// `ssl_ctx` must point to a live `SSL_CTX` created by the OpenSSL
    /// library whose DLLs are loaded in this process.
    pub unsafe fn add_certs_to_context(ssl_ctx: *mut c_void) {
        // Pick the DLL names matching the OpenSSL build libcurl reports at
        // runtime. The whole 1.1.x series is ABI-compatible with 1.1.0 and
        // ships the same DLL names, so accept all of it rather than only 1.1.0.
        let openssl = match crate::version::Version::get().ssl_version() {
            Some(ssl_ver) if ssl_ver.starts_with("OpenSSL/1.1.") => {
                lookup_functions("libcrypto", "libssl")
            }
            Some(ssl_ver) if ssl_ver.starts_with("OpenSSL/1.0.2") => {
                lookup_functions("libeay32", "ssleay32")
            }
            _ => return,
        };
        let openssl = match openssl {
            Some(s) => s,
            None => return,
        };

        let openssl_store = (openssl.SSL_CTX_get_cert_store)(ssl_ctx as *const SSL_CTX);
        // Never hand a null store to X509_STORE_add_cert below.
        if openssl_store.is_null() {
            return;
        }
        let store = match CertStore::open_current_user("ROOT") {
            Ok(s) => s,
            Err(_) => return,
        };

        // The "Server Authentication" extended-key-usage OID is constant, so
        // convert it once instead of once per certificate.
        let server_auth_oid = CStr::from_ptr(szOID_PKIX_KP_SERVER_AUTH as *const _)
            .to_string_lossy()
            .into_owned();

        for cert in store.certs() {
            let valid_uses = match cert.valid_uses() {
                Ok(v) => v,
                Err(_) => continue,
            };

            // Keep only certificates usable for TLS server authentication.
            match valid_uses {
                ValidUses::All => {}
                ValidUses::Oids(ref oids) => {
                    if !oids.contains(&server_auth_oid) {
                        continue;
                    }
                }
            }

            let der = cert.to_der();
            // d2i_X509 advances the pointer it is given, so pass a local copy.
            let mut der_ptr = der.as_ptr();
            let x509 = (openssl.d2i_X509)(ptr::null_mut(), &mut der_ptr, der.len() as c_long);
            if !x509.is_null() {
                // Return value ignored: a failure (e.g. a duplicate) only means
                // this certificate is not added.
                (openssl.X509_STORE_add_cert)(openssl_store, x509);
                // The store takes its own reference, so release ours.
                (openssl.X509_free)(x509);
            }
        }
    }
}

/// Imports trusted Windows root certificates into the OpenSSL context `ssl_ctx`.
///
/// Delegates to the MSVC-only implementation in the private `win` module.
#[cfg(target_env = "msvc")]
pub fn add_certs_to_context(ssl_ctx: *mut c_void) {
    // SAFETY: NOTE(review): this safe fn forwards a caller-supplied raw pointer
    // to OpenSSL; callers must pass a valid `SSL_CTX` pointer.
    unsafe { win::add_certs_to_context(ssl_ctx.cast()) }
}

/// No-op on non-MSVC targets: importing the Windows certificate store is only
/// implemented for MSVC builds.
#[cfg(not(target_env = "msvc"))]
pub fn add_certs_to_context(_ssl_ctx: *mut c_void) {
    // Intentionally empty.
}