# Copyright 2019-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

"""A CPython compatible SSLContext implementation wrapping PyOpenSSL's
context.

Due to limitations of the CPython asyncio.Protocol implementation for SSL,
the async API does not support PyOpenSSL.
"""
from __future__ import annotations

import socket as _socket
import ssl as _stdlibssl
import sys as _sys
import time as _time
from errno import EINTR as _EINTR
from ipaddress import ip_address as _ip_address
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union

import cryptography.x509 as x509
import service_identity
from OpenSSL import SSL as _SSL
from OpenSSL import crypto as _crypto

from pymongo.errors import ConfigurationError as _ConfigurationError
from pymongo.errors import _CertificateError  # type:ignore[attr-defined]
from pymongo.ocsp_cache import _OCSPCache
from pymongo.ocsp_support import _load_trusted_ca_certs, _ocsp_callback
from pymongo.socket_checker import SocketChecker as _SocketChecker
from pymongo.socket_checker import _errno_from_exception
from pymongo.write_concern import validate_boolean

if TYPE_CHECKING:
    from ssl import VerifyMode


_T = TypeVar("_T")

try:
    import certifi

    _HAVE_CERTIFI = True
except ImportError:
    _HAVE_CERTIFI = False

PROTOCOL_SSLv23 = _SSL.SSLv23_METHOD
# Always available
OP_NO_SSLv2 = _SSL.OP_NO_SSLv2
OP_NO_SSLv3 = _SSL.OP_NO_SSLv3
OP_NO_COMPRESSION = _SSL.OP_NO_COMPRESSION
# This isn't currently documented for PyOpenSSL, so fall back to 0 (a no-op
# option bit) when the installed version does not expose it.
OP_NO_RENEGOTIATION = getattr(_SSL, "OP_NO_RENEGOTIATION", 0)

# Always available
HAS_SNI = True
IS_PYOPENSSL = True

# Base Exception class
SSLError = _SSL.Error

# Translate the stdlib ssl.CERT_* constants into OpenSSL verify flags.
# https://github.com/python/cpython/blob/v3.8.0/Modules/_ssl.c#L2995-L3002
_VERIFY_MAP = {
    _stdlibssl.CERT_NONE: _SSL.VERIFY_NONE,
    _stdlibssl.CERT_OPTIONAL: _SSL.VERIFY_PEER,
    _stdlibssl.CERT_REQUIRED: _SSL.VERIFY_PEER | _SSL.VERIFY_FAIL_IF_NO_PEER_CERT,
}

# Inverse of _VERIFY_MAP, used to report verify_mode back as a stdlib constant.
_REVERSE_VERIFY_MAP = {value: key for key, value in _VERIFY_MAP.items()}


# For SNI support. According to RFC6066, section 3, IPv4 and IPv6 literals are
# not permitted for SNI hostname.
def _is_ip_address(address: Any) -> bool:
    """Return True if *address* is accepted by ipaddress.ip_address.

    Any value that ip_address rejects with ValueError or UnicodeError is
    treated as "not an IP address" (i.e. a hostname).
    """
    try:
        _ip_address(address)
        return True
    except (ValueError, UnicodeError):
        return False


# According to the docs for socket.send it can raise
# WantX509LookupError and should be retried.
BLOCKING_IO_ERRORS = (_SSL.WantReadError, _SSL.WantWriteError, _SSL.WantX509LookupError)
BLOCKING_IO_READ_ERROR = _SSL.WantReadError
BLOCKING_IO_WRITE_ERROR = _SSL.WantWriteError
BLOCKING_IO_LOOKUP_ERROR = _SSL.WantX509LookupError


def _ragged_eof(exc: BaseException) -> bool:
    """Return True if the OpenSSL.SSL.SysCallError is a ragged EOF."""
    return exc.args == (-1, "Unexpected EOF")


# https://github.com/pyca/pyopenssl/issues/168
# https://github.com/pyca/pyopenssl/issues/176
# https://docs.python.org/3/library/ssl.html#notes-on-non-blocking-sockets
class _sslConn(_SSL.Connection):
    """A PyOpenSSL Connection that behaves like a blocking stdlib SSLSocket.

    PyOpenSSL surfaces "would block" conditions as exceptions. This subclass
    waits on the underlying socket and retries, honoring the socket timeout.
    It also optionally suppresses ragged EOFs the way the stdlib does.
    """

    def __init__(
        self,
        ctx: _SSL.Context,
        sock: Optional[_socket.socket],
        suppress_ragged_eofs: bool,
    ):
        self.socket_checker = _SocketChecker()
        self.suppress_ragged_eofs = suppress_ragged_eofs
        super().__init__(ctx, sock)

    def _call(self, call: Callable[..., _T], *args: Any, **kwargs: Any) -> _T:
        """Invoke ``call(*args, **kwargs)``, retrying while OpenSSL needs I/O.

        When OpenSSL raises one of BLOCKING_IO_ERRORS, wait (via the socket
        checker) for the socket to become readable and/or writable, then
        retry.

        The socket timeout from ``gettimeout()`` is treated as one deadline
        for the whole operation, matching CPython's ``_ssl`` module. Each wait
        is given only the time remaining, so repeated partial progress cannot
        stretch the operation past the configured timeout.

        Raises:
            The original blocking error if the socket is non-blocking
            (timeout of 0); ``socket.timeout`` once the deadline has passed;
            ``SSLError`` if the underlying socket has been closed.
        """
        timeout = self.gettimeout()
        # timeout is None for a blocking socket (no deadline); 0 means
        # non-blocking and is handled in the except clause below.
        deadline: Optional[float] = _time.monotonic() + timeout if timeout else None
        while True:
            try:
                return call(*args, **kwargs)
            except BLOCKING_IO_ERRORS as exc:
                # Do not retry if the connection is in non-blocking mode.
                if timeout == 0:
                    raise exc
                # Check for closed socket.
                if self.fileno() == -1:
                    if deadline is not None and _time.monotonic() > deadline:
                        raise _socket.timeout("timed out") from None
                    raise SSLError("Underlying socket has been closed") from None
                if isinstance(exc, _SSL.WantReadError):
                    want_read = True
                    want_write = False
                elif isinstance(exc, _SSL.WantWriteError):
                    want_read = False
                    want_write = True
                else:
                    # WantX509LookupError: wait for either direction.
                    want_read = True
                    want_write = True
                if deadline is None:
                    # Blocking socket: wait without a time limit.
                    wait: Optional[float] = None
                else:
                    # Only wait for what is left of the overall deadline.
                    wait = deadline - _time.monotonic()
                    if wait <= 0:
                        raise _socket.timeout("timed out") from None
                self.socket_checker.select(self, want_read, want_write, wait)
                if deadline is not None and _time.monotonic() > deadline:
                    raise _socket.timeout("timed out") from None
                continue

    def do_handshake(self, *args: Any, **kwargs: Any) -> None:
        """Perform the TLS handshake, retrying on would-block errors."""
        return self._call(super().do_handshake, *args, **kwargs)

    def recv(self, *args: Any, **kwargs: Any) -> bytes:
        """Receive data; return b"" on a ragged EOF if suppression is on."""
        try:
            return self._call(super().recv, *args, **kwargs)
        except _SSL.SysCallError as exc:
            # Suppress ragged EOFs to match the stdlib.
            if self.suppress_ragged_eofs and _ragged_eof(exc):
                return b""
            raise

    def recv_into(self, *args: Any, **kwargs: Any) -> int:
        """Receive into a buffer; return 0 on a ragged EOF if suppression is on."""
        try:
            return self._call(super().recv_into, *args, **kwargs)
        except _SSL.SysCallError as exc:
            # Suppress ragged EOFs to match the stdlib.
            if self.suppress_ragged_eofs and _ragged_eof(exc):
                return 0
            raise

    def sendall(self, buf: bytes, flags: int = 0) -> None:  # type: ignore[override]
        """Send every byte of *buf*, looping over partial writes.

        Raises:
            OSError: "connection closed" if a write reports sending zero or
                fewer bytes.
        """
        # memoryview slicing avoids copying the remaining payload on each
        # partial write.
        view = memoryview(buf)
        total_length = len(buf)
        total_sent = 0
        while total_sent < total_length:
            try:
                sent = self._call(super().send, view[total_sent:], flags)
            # XXX: It's not clear if this can actually happen. PyOpenSSL
            # doesn't appear to have any interrupt handling, nor any interrupt
            # errors for OpenSSL connections.
            except OSError as exc:
                if _errno_from_exception(exc) == _EINTR:
                    continue
                raise
            # https://github.com/pyca/pyopenssl/blob/19.1.0/src/OpenSSL/SSL.py#L1756
            # https://www.openssl.org/docs/man1.0.2/man3/SSL_write.html
            if sent <= 0:
                raise OSError("connection closed")
            total_sent += sent


class _CallbackData:
    """Data class which is passed to the OCSP callback."""

    def __init__(self) -> None:
        # CA certs loaded manually for pyopenssl<20 (see load_verify_locations).
        self.trusted_ca_certs: Optional[list[x509.Certificate]] = None
        # Mirrors SSLContext.check_ocsp_endpoint.
        self.check_ocsp_endpoint: Optional[bool] = None
        self.ocsp_response_cache = _OCSPCache()


class SSLContext:
    """A CPython compatible SSLContext implementation wrapping PyOpenSSL's
    context.
    """

    __slots__ = ("_protocol", "_ctx", "_callback_data", "_check_hostname")

    def __init__(self, protocol: int):
        """Create a PyOpenSSL context for *protocol*.

        Hostname checking and OCSP endpoint checking start enabled, and the
        OCSP client callback is installed on the underlying context.
        """
        self._protocol = protocol
        self._ctx = _SSL.Context(self._protocol)
        self._callback_data = _CallbackData()
        self._check_hostname = True
        # OCSP
        # XXX: Find a better place to do this someday, since this is client
        # side configuration and wrap_socket tries to support both client and
        # server side sockets.
        self._callback_data.check_ocsp_endpoint = True
        self._ctx.set_ocsp_client_callback(callback=_ocsp_callback, data=self._callback_data)

    @property
    def protocol(self) -> int:
        """The protocol version chosen when constructing the context.
        This attribute is read-only.
        """
        return self._protocol

    def __get_verify_mode(self) -> VerifyMode:
        """Whether to try to verify other peers' certificates and how to
        behave if verification fails. This attribute must be one of
        ssl.CERT_NONE, ssl.CERT_OPTIONAL or ssl.CERT_REQUIRED.
        """
        return _REVERSE_VERIFY_MAP[self._ctx.get_verify_mode()]

    def __set_verify_mode(self, value: VerifyMode) -> None:
        """Setter for verify_mode."""

        def _cb(
            _connobj: _SSL.Connection,
            _x509obj: _crypto.X509,
            _errnum: int,
            _errdepth: int,
            retcode: int,
        ) -> bool:
            # It seems we don't need to do anything here. Twisted doesn't,
            # and OpenSSL's SSL_CTX_set_verify lets you pass NULL
            # for the callback option. It's weird that PyOpenSSL requires
            # this.
            # This is optional in pyopenssl >= 20 and can be removed once minimum
            # supported version is bumped
            # See: pyopenssl.org/en/latest/changelog.html#id47
            # Pass OpenSSL's own verification result through unchanged.
            return bool(retcode)

        self._ctx.set_verify(_VERIFY_MAP[value], _cb)

    verify_mode = property(__get_verify_mode, __set_verify_mode)

    def __get_check_hostname(self) -> bool:
        """Whether wrap_socket verifies the peer certificate against
        server_hostname after the handshake.
        """
        return self._check_hostname

    def __set_check_hostname(self, value: Any) -> None:
        validate_boolean("check_hostname", value)
        self._check_hostname = value

    check_hostname = property(__get_check_hostname, __set_check_hostname)

    def __get_check_ocsp_endpoint(self) -> Optional[bool]:
        """Flag stored in the data object handed to the OCSP client callback."""
        return self._callback_data.check_ocsp_endpoint

    def __set_check_ocsp_endpoint(self, value: bool) -> None:
        validate_boolean("check_ocsp", value)
        self._callback_data.check_ocsp_endpoint = value

    check_ocsp_endpoint = property(__get_check_ocsp_endpoint, __set_check_ocsp_endpoint)

    def __get_options(self) -> int:
        # Calling set_options adds the option to the existing bitmask and
        # returns the new bitmask.
        # https://www.pyopenssl.org/en/stable/api/ssl.html#OpenSSL.SSL.Context.set_options
        return self._ctx.set_options(0)

    def __set_options(self, value: int) -> None:
        # Explicitly convert to int, since newer CPython versions
        # use enum.IntFlag for options. The values are the same
        # regardless of implementation.
        self._ctx.set_options(int(value))

    options = property(__get_options, __set_options)

    def load_cert_chain(
        self,
        certfile: Union[str, bytes],
        keyfile: Union[str, bytes, None] = None,
        password: Optional[str] = None,
    ) -> None:
        """Load a private key and the corresponding certificate. The certfile
        string must be the path to a single file in PEM format containing the
        certificate as well as any number of CA certificates needed to
        establish the certificate's authenticity. The keyfile string, if
        present, must point to a file containing the private key. Otherwise
        the private key will be taken from certfile as well.
        """
        # Match CPython behavior
        # https://github.com/python/cpython/blob/v3.8.0/Modules/_ssl.c#L3930-L3971
        # Password callback MUST be set first or it will be ignored.
        if password:

            def _pwcb(_max_length: int, _prompt_twice: bool, _user_data: Optional[bytes]) -> bytes:
                # XXX: We could check the password length against what OpenSSL
                # tells us is the max, but we can't raise an exception, so...
                # warn?
                assert password is not None
                return password.encode("utf-8")

            self._ctx.set_passwd_cb(_pwcb)
        self._ctx.use_certificate_chain_file(certfile)
        self._ctx.use_privatekey_file(keyfile or certfile)
        self._ctx.check_privatekey()

    def load_verify_locations(
        self, cafile: Optional[str] = None, capath: Optional[str] = None
    ) -> None:
        """Load a set of "certification authority" (CA) certificates used to
        validate other peers' certificates when `~verify_mode` is other than
        ssl.CERT_NONE.
        """
        self._ctx.load_verify_locations(cafile, capath)
        # Manually load the CA certs when get_verified_chain is not available (pyopenssl<20).
        if not hasattr(_SSL.Connection, "get_verified_chain"):
            # NOTE(review): on pyopenssl<20 this path requires a cafile;
            # capath-only configurations fail this assertion.
            assert cafile is not None
            self._callback_data.trusted_ca_certs = _load_trusted_ca_certs(cafile)

    def _load_certifi(self) -> None:
        """Attempt to load CA certs from certifi."""
        if _HAVE_CERTIFI:
            self.load_verify_locations(certifi.where())
        else:
            raise _ConfigurationError(
                "tlsAllowInvalidCertificates is False but no system "
                "CA certificates could be loaded. Please install the "
                "certifi package, or provide a path to a CA file using "
                "the tlsCAFile option"
            )

    def _load_wincerts(self, store: str) -> None:
        """Attempt to load CA certs from Windows trust store."""
        cert_store = self._ctx.get_cert_store()
        assert cert_store is not None
        oid = _stdlibssl.Purpose.SERVER_AUTH.oid

        for cert, encoding, trust in _stdlibssl.enum_certificates(store):  # type: ignore
            if encoding == "x509_asn":
                # Only DER certificates trusted for everything (True) or
                # explicitly for server authentication are added.
                if trust is True or oid in trust:
                    cert_store.add_cert(
                        _crypto.X509.from_cryptography(x509.load_der_x509_certificate(cert))
                    )

    def load_default_certs(self) -> None:
        """A PyOpenSSL version of load_default_certs from CPython."""
        # PyOpenSSL is incapable of loading CA certs from Windows, and mostly
        # incapable on macOS.
        # https://www.pyopenssl.org/en/stable/api/ssl.html#OpenSSL.SSL.Context.set_default_verify_paths
        if _sys.platform == "win32":
            try:
                for storename in ("CA", "ROOT"):
                    self._load_wincerts(storename)
            except PermissionError:
                # Fall back to certifi
                self._load_certifi()
        elif _sys.platform == "darwin":
            self._load_certifi()
        self._ctx.set_default_verify_paths()

    def set_default_verify_paths(self) -> None:
        """Specify that the platform provided CA certificates are to be used
        for verification purposes.
        """
        # Note: See PyOpenSSL's docs for limitations, which are similar
        # but not the same as CPython's.
        self._ctx.set_default_verify_paths()

    def wrap_socket(
        self,
        sock: _socket.socket,
        server_side: bool = False,
        do_handshake_on_connect: bool = True,
        suppress_ragged_eofs: bool = True,
        server_hostname: Optional[str] = None,
        session: Optional[_SSL.Session] = None,
    ) -> _sslConn:
        """Wrap an existing Python socket connection and return a TLS socket
        object.

        For client-side sockets, SNI is sent for non-IP hostnames and a
        stapled OCSP response is requested unless verify_mode is CERT_NONE.
        When do_handshake_on_connect is true, the handshake is performed here
        and, if check_hostname is set and server_hostname is given, the peer
        certificate is matched against it.

        Raises:
            _CertificateError: if hostname/IP verification fails.
        """
        ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs)
        if session:
            ssl_conn.set_session(session)
        if server_side is True:
            ssl_conn.set_accept_state()
        else:
            # SNI
            if server_hostname and not _is_ip_address(server_hostname):
                # XXX: Do this in a callback registered with
                # SSLContext.set_info_callback? See Twisted for an example.
                ssl_conn.set_tlsext_host_name(server_hostname.encode("idna"))
            if self.verify_mode != _stdlibssl.CERT_NONE:
                # Request a stapled OCSP response.
                ssl_conn.request_ocsp()
            ssl_conn.set_connect_state()
        # If this wasn't true the caller of wrap_socket would call
        # do_handshake()
        if do_handshake_on_connect:
            # XXX: If we do hostname checking in a callback we can get rid
            # of this call to do_handshake() since the handshake
            # will happen automatically later.
            ssl_conn.do_handshake()
            # XXX: Do this in a callback registered with
            # SSLContext.set_info_callback? See Twisted for an example.
            if self.check_hostname and server_hostname is not None:
                from service_identity import pyopenssl

                try:
                    if _is_ip_address(server_hostname):
                        pyopenssl.verify_ip_address(ssl_conn, server_hostname)
                    else:
                        pyopenssl.verify_hostname(ssl_conn, server_hostname)
                except (
                    service_identity.CertificateError,
                    service_identity.VerificationError,
                ) as exc:
                    raise _CertificateError(str(exc)) from None
        return ssl_conn