Skip to content

Commit 2acc3be

Browse files
authored
More SSL impl (RustPython#6224)
* fix ipv6 formattig * consts * fspath * fix set_ecdh_curve * minimum/maximum version * Add SSL_CTX_security_level
1 parent b6e8a87 commit 2acc3be

File tree

2 files changed

+88
-33
lines changed

2 files changed

+88
-33
lines changed

stdlib/src/ssl.rs

Lines changed: 81 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,6 @@ mod _ssl {
9494
SSL_ERROR_WANT_CONNECT,
9595
SSL_ERROR_WANT_READ,
9696
SSL_ERROR_WANT_WRITE,
97-
// X509_V_FLAG_CRL_CHECK as VERIFY_CRL_CHECK_LEAF,
98-
// sys::X509_V_FLAG_CRL_CHECK|sys::X509_V_FLAG_CRL_CHECK_ALL as VERIFY_CRL_CHECK_CHAIN
99-
// X509_V_FLAG_X509_STRICT as VERIFY_X509_STRICT,
10097
SSL_ERROR_ZERO_RETURN,
10198
SSL_OP_CIPHER_SERVER_PREFERENCE as OP_CIPHER_SERVER_PREFERENCE,
10299
SSL_OP_ENABLE_MIDDLEBOX_COMPAT as OP_ENABLE_MIDDLEBOX_COMPAT,
@@ -114,6 +111,11 @@ mod _ssl {
114111
X509_V_FLAG_X509_STRICT as VERIFY_X509_STRICT,
115112
};
116113

114+
// CRL verification constants
115+
#[pyattr]
116+
const VERIFY_CRL_CHECK_CHAIN: libc::c_ulong =
117+
sys::X509_V_FLAG_CRL_CHECK | sys::X509_V_FLAG_CRL_CHECK_ALL;
118+
117119
// taken from CPython, should probably be kept up to date with their version if it ever changes
118120
#[pyattr]
119121
const _DEFAULT_CIPHERS: &str =
@@ -631,6 +633,12 @@ mod _ssl {
631633
Ok(())
632634
}
633635

636+
#[cfg(ossl110)]
637+
#[pygetset]
638+
fn security_level(&self) -> i32 {
639+
unsafe { SSL_CTX_get_security_level(self.ctx().as_ptr()) }
640+
}
641+
634642
#[pymethod]
635643
fn set_ciphers(&self, cipherlist: PyStrRef, vm: &VirtualMachine) -> PyResult<()> {
636644
let ciphers = cipherlist.as_str();
@@ -677,19 +685,29 @@ mod _ssl {
677685
}
678686

679687
#[pymethod]
680-
fn set_ecdh_curve(&self, name: PyStrRef, vm: &VirtualMachine) -> PyResult<()> {
688+
fn set_ecdh_curve(
689+
&self,
690+
name: Either<PyStrRef, ArgBytesLike>,
691+
vm: &VirtualMachine,
692+
) -> PyResult<()> {
681693
use openssl::ec::{EcGroup, EcKey};
682694

683-
let curve_name = name.as_str();
684-
if curve_name.contains('\0') {
685-
return Err(exceptions::cstring_error(vm));
686-
}
695+
// Convert name to CString, supporting both str and bytes
696+
let name_cstr = match name {
697+
Either::A(s) => {
698+
if s.as_str().contains('\0') {
699+
return Err(exceptions::cstring_error(vm));
700+
}
701+
s.to_cstring(vm)?
702+
}
703+
Either::B(b) => std::ffi::CString::new(b.borrow_buf().to_vec())
704+
.map_err(|_| exceptions::cstring_error(vm))?,
705+
};
687706

688707
// Find the NID for the curve name using OBJ_sn2nid
689-
let name_cstr = name.to_cstring(vm)?;
690708
let nid_raw = unsafe { sys::OBJ_sn2nid(name_cstr.as_ptr()) };
691709
if nid_raw == 0 {
692-
return Err(vm.new_value_error(format!("unknown curve name: {}", curve_name)));
710+
return Err(vm.new_value_error("unknown curve name"));
693711
}
694712
let nid = Nid::from_raw(nid_raw);
695713

@@ -794,6 +812,47 @@ mod _ssl {
794812
self.check_hostname.store(ch);
795813
}
796814

815+
// PY_PROTO_MINIMUM_SUPPORTED = -2, PY_PROTO_MAXIMUM_SUPPORTED = -1
816+
#[pygetset]
817+
fn minimum_version(&self) -> i32 {
818+
let ctx = self.ctx();
819+
let version = unsafe { sys::SSL_CTX_get_min_proto_version(ctx.as_ptr()) };
820+
if version == 0 {
821+
-2 // PY_PROTO_MINIMUM_SUPPORTED
822+
} else {
823+
version
824+
}
825+
}
826+
#[pygetset(setter)]
827+
fn set_minimum_version(&self, value: i32, vm: &VirtualMachine) -> PyResult<()> {
828+
let ctx = self.builder();
829+
let result = unsafe { sys::SSL_CTX_set_min_proto_version(ctx.as_ptr(), value) };
830+
if result == 0 {
831+
return Err(vm.new_value_error("invalid protocol version"));
832+
}
833+
Ok(())
834+
}
835+
836+
#[pygetset]
837+
fn maximum_version(&self) -> i32 {
838+
let ctx = self.ctx();
839+
let version = unsafe { sys::SSL_CTX_get_max_proto_version(ctx.as_ptr()) };
840+
if version == 0 {
841+
-1 // PY_PROTO_MAXIMUM_SUPPORTED
842+
} else {
843+
version
844+
}
845+
}
846+
#[pygetset(setter)]
847+
fn set_maximum_version(&self, value: i32, vm: &VirtualMachine) -> PyResult<()> {
848+
let ctx = self.builder();
849+
let result = unsafe { sys::SSL_CTX_set_max_proto_version(ctx.as_ptr(), value) };
850+
if result == 0 {
851+
return Err(vm.new_value_error("invalid protocol version"));
852+
}
853+
Ok(())
854+
}
855+
797856
#[pymethod]
798857
fn set_default_verify_paths(&self, vm: &VirtualMachine) -> PyResult<()> {
799858
cfg_if::cfg_if! {
@@ -852,12 +911,6 @@ mod _ssl {
852911
if let (None, None, None) = (&args.cafile, &args.capath, &args.cadata) {
853912
return Err(vm.new_type_error("cafile, capath and cadata cannot be all omitted"));
854913
}
855-
if let Some(cafile) = &args.cafile {
856-
cafile.ensure_no_nul(vm)?
857-
}
858-
if let Some(capath) = &args.capath {
859-
capath.ensure_no_nul(vm)?
860-
}
861914

862915
#[cold]
863916
fn invalid_cadata(vm: &VirtualMachine) -> PyBaseExceptionRef {
@@ -887,11 +940,10 @@ mod _ssl {
887940
}
888941

889942
if args.cafile.is_some() || args.capath.is_some() {
890-
ctx.load_verify_locations(
891-
args.cafile.as_ref().map(|s| s.as_str().as_ref()),
892-
args.capath.as_ref().map(|s| s.as_str().as_ref()),
893-
)
894-
.map_err(|e| convert_openssl_error(vm, e))?;
943+
let cafile_path = args.cafile.map(|p| p.to_path_buf(vm)).transpose()?;
944+
let capath_path = args.capath.map(|p| p.to_path_buf(vm)).transpose()?;
945+
ctx.load_verify_locations(cafile_path.as_deref(), capath_path.as_deref())
946+
.map_err(|e| convert_openssl_error(vm, e))?;
895947
}
896948

897949
Ok(())
@@ -1064,9 +1116,9 @@ mod _ssl {
10641116
#[derive(FromArgs)]
10651117
struct LoadVerifyLocationsArgs {
10661118
#[pyarg(any, default)]
1067-
cafile: Option<PyStrRef>,
1119+
cafile: Option<FsPath>,
10681120
#[pyarg(any, default)]
1069-
capath: Option<PyStrRef>,
1121+
capath: Option<FsPath>,
10701122
#[pyarg(any, default)]
10711123
cadata: Option<Either<PyStrRef, ArgBytesLike>>,
10721124
}
@@ -1794,6 +1846,11 @@ mod _ssl {
17941846
fn SSL_verify_client_post_handshake(ssl: *const sys::SSL) -> libc::c_int;
17951847
}
17961848

1849+
#[cfg(ossl110)]
1850+
unsafe extern "C" {
1851+
fn SSL_CTX_get_security_level(ctx: *const sys::SSL_CTX) -> libc::c_int;
1852+
}
1853+
17971854
// OpenSSL BIO helper functions
17981855
// These are typically macros in OpenSSL, implemented via BIO_ctrl
17991856
const BIO_CTRL_PENDING: libc::c_int = 10;
@@ -2082,7 +2139,7 @@ mod _ssl {
20822139
let lib = sys::ERR_GET_LIB(err_code);
20832140
if lib == ERR_LIB_SSL && reason == SSL_R_UNEXPECTED_EOF_WHILE_READING {
20842141
return vm.new_exception(
2085-
vm.class("_ssl", "SSLEOFError"),
2142+
PySslEOFError::class(&vm.ctx).to_owned(),
20862143
vec![
20872144
vm.ctx.new_int(SSL_ERROR_EOF).into(),
20882145
vm.ctx

stdlib/src/ssl/cert.rs

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -164,17 +164,15 @@ pub(crate) mod ssl_cert {
164164
// IPv4
165165
format!("{}.{}.{}.{}", ip[0], ip[1], ip[2], ip[3])
166166
} else if ip.len() == 16 {
167-
// IPv6 - format like: "X:X:X:X:X:X:X:X" (not compressed)
167+
// IPv6 - format with all zeros visible (not compressed)
168+
let ip_addr = std::net::Ipv6Addr::from([
169+
ip[0], ip[1], ip[2], ip[3], ip[4], ip[5], ip[6], ip[7], ip[8],
170+
ip[9], ip[10], ip[11], ip[12], ip[13], ip[14], ip[15],
171+
]);
172+
let s = ip_addr.segments();
168173
format!(
169174
"{:X}:{:X}:{:X}:{:X}:{:X}:{:X}:{:X}:{:X}",
170-
(ip[0] as u16) << 8 | ip[1] as u16,
171-
(ip[2] as u16) << 8 | ip[3] as u16,
172-
(ip[4] as u16) << 8 | ip[5] as u16,
173-
(ip[6] as u16) << 8 | ip[7] as u16,
174-
(ip[8] as u16) << 8 | ip[9] as u16,
175-
(ip[10] as u16) << 8 | ip[11] as u16,
176-
(ip[12] as u16) << 8 | ip[13] as u16,
177-
(ip[14] as u16) << 8 | ip[15] as u16
175+
s[0], s[1], s[2], s[3], s[4], s[5], s[6], s[7]
178176
)
179177
} else {
180178
// Fallback for unexpected length

0 commit comments

Comments
 (0)