549 lines
20 KiB
Python
549 lines
20 KiB
Python
# Copyright 2025-present MongoDB, Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Pool utilities and shared helper methods."""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import functools
|
|
import socket
|
|
import ssl
|
|
import sys
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Any,
|
|
NoReturn,
|
|
Optional,
|
|
Union,
|
|
)
|
|
|
|
from pymongo import _csot
|
|
from pymongo.asynchronous.helpers import _getaddrinfo
|
|
from pymongo.errors import ( # type:ignore[attr-defined]
|
|
AutoReconnect,
|
|
ConnectionFailure,
|
|
NetworkTimeout,
|
|
_CertificateError,
|
|
)
|
|
from pymongo.network_layer import AsyncNetworkingInterface, NetworkingInterface, PyMongoProtocol
|
|
from pymongo.pool_options import PoolOptions
|
|
from pymongo.ssl_support import PYSSLError, SSLError, _has_sni
|
|
|
|
# The two SSL error classes exported by pymongo.ssl_support, grouped so they
# can be used in isinstance() checks and ``except (OSError, *SSLErrors)``
# clauses throughout this module.
SSLErrors = (PYSSLError, SSLError)
|
|
if TYPE_CHECKING:
|
|
from pymongo.pyopenssl_context import _sslConn
|
|
from pymongo.typings import _Address
|
|
|
|
try:
    from fcntl import F_GETFD, F_SETFD, FD_CLOEXEC, fcntl

    def _set_non_inheritable_non_atomic(fd: int) -> None:
        """Turn on the close-on-exec flag of file descriptor *fd*.

        The current descriptor flags are read and then written back with
        FD_CLOEXEC OR-ed in; the two fcntl() calls are separate, so the
        update is not atomic.
        """
        fcntl(fd, F_SETFD, fcntl(fd, F_GETFD) | FD_CLOEXEC)

except ImportError:
    # Windows, various platforms we don't claim to support
    # (Jython, IronPython, ..), systems that don't provide
    # everything we need from fcntl, etc.
    def _set_non_inheritable_non_atomic(fd: int) -> None:  # noqa: ARG001
        """No-op stand-in for platforms where fcntl is unavailable."""
|
|
|
|
|
|
# Upper bounds for the operating system's TCP keepalive settings. The idle
# time and probe interval are in seconds (the Windows branch multiplies them
# by 1000 to compare against millisecond registry values); the last one caps
# the TCP_KEEPCNT socket option.
_MAX_TCP_KEEPIDLE = 120
_MAX_TCP_KEEPINTVL = 10
_MAX_TCP_KEEPCNT = 9

if sys.platform == "win32":
    try:
        import _winreg as winreg
    except ImportError:
        import winreg

    def _query(key, name, default):
        """Return registry value *name* under *key* as an int, or *default*.

        *default* is returned when the value does not exist or cannot be
        converted to an integer.
        """
        try:
            value, _ = winreg.QueryValueEx(key, name)
            # Ensure the value is a number. int() raises ValueError for
            # non-numeric strings/bytes, but TypeError for other value types
            # (e.g. the list returned for REG_MULTI_SZ, or None); treat both
            # as "not configured" rather than failing at import time.
            return int(value)
        except (OSError, ValueError, TypeError):
            # QueryValueEx raises OSError when the key does not exist (i.e.
            # the system is using the Windows default value).
            return default

    try:
        with winreg.OpenKey(
            winreg.HKEY_LOCAL_MACHINE, r"SYSTEM\CurrentControlSet\Services\Tcpip\Parameters"
        ) as key:
            _WINDOWS_TCP_IDLE_MS = _query(key, "KeepAliveTime", 7200000)
            _WINDOWS_TCP_INTERVAL_MS = _query(key, "KeepAliveInterval", 1000)
    except OSError:
        # We could not check the default values because winreg.OpenKey failed.
        # Assume the system is using the default values.
        _WINDOWS_TCP_IDLE_MS = 7200000
        _WINDOWS_TCP_INTERVAL_MS = 1000

    def _set_keepalive_times(sock: socket.socket) -> None:
        """Cap the keepalive idle time and probe interval of *sock*.

        The system-wide values read from the registry are only overridden
        (via SIO_KEEPALIVE_VALS) when at least one of them exceeds its cap.
        """
        idle_ms = min(_WINDOWS_TCP_IDLE_MS, _MAX_TCP_KEEPIDLE * 1000)
        interval_ms = min(_WINDOWS_TCP_INTERVAL_MS, _MAX_TCP_KEEPINTVL * 1000)
        if idle_ms < _WINDOWS_TCP_IDLE_MS or interval_ms < _WINDOWS_TCP_INTERVAL_MS:
            # (onoff, keepalivetime_ms, keepaliveinterval_ms)
            sock.ioctl(socket.SIO_KEEPALIVE_VALS, (1, idle_ms, interval_ms))

else:

    def _set_tcp_option(sock: socket.socket, tcp_option: str, max_value: int) -> None:
        """Clamp the IPPROTO_TCP option named *tcp_option* on *sock* to *max_value*.

        Best effort: nothing happens if this platform's socket module does not
        define the option, if the current value is already within the cap, or
        if reading/writing the option raises OSError.
        """
        if hasattr(socket, tcp_option):
            sockopt = getattr(socket, tcp_option)
            try:
                # PYTHON-1350 - NetBSD doesn't implement getsockopt for
                # TCP_KEEPIDLE and friends. Don't attempt to set the
                # values there.
                default = sock.getsockopt(socket.IPPROTO_TCP, sockopt)
                if default > max_value:
                    sock.setsockopt(socket.IPPROTO_TCP, sockopt, max_value)
            except OSError:
                pass

    def _set_keepalive_times(sock: socket.socket) -> None:
        """Cap the keepalive idle time, probe interval and probe count of *sock*."""
        _set_tcp_option(sock, "TCP_KEEPIDLE", _MAX_TCP_KEEPIDLE)
        _set_tcp_option(sock, "TCP_KEEPINTVL", _MAX_TCP_KEEPINTVL)
        _set_tcp_option(sock, "TCP_KEEPCNT", _MAX_TCP_KEEPCNT)
|
|
|
|
|
|
def _raise_connection_failure(
    address: Any,
    error: Exception,
    msg_prefix: Optional[str] = None,
    timeout_details: Optional[dict[str, float]] = None,
) -> NoReturn:
    """Convert a socket.error to ConnectionFailure and raise it.

    The message reads ``[msg_prefix]host[:port]: error``, followed by the
    formatted *timeout_details* unless the text already mentions
    "configured timeouts". *error* is chained as ``__cause__``.

    :param address: ``(host, port)``; *port* is ``None`` for Unix domain sockets.
    :param error: the low-level exception being converted.
    :param msg_prefix: optional text placed in front of the message.
    :param timeout_details: timeout settings appended via
        ``format_timeout_details``.
    :raises NetworkTimeout: when *error* is a socket timeout, or an SSL error
        whose text contains "timed out".
    :raises AutoReconnect: for any other *error*.
    """
    host, port = address
    # A Unix domain socket address carries no port.
    msg = f"{host}: {error}" if port is None else "%s:%d: %s" % (host, port, error)
    msg = msg_prefix + msg if msg_prefix else msg
    if "configured timeouts" not in msg:
        msg += format_timeout_details(timeout_details)

    timed_out = isinstance(error, socket.timeout)
    if not timed_out and isinstance(error, SSLErrors):
        # Eventlet does not distinguish TLS network timeouts from other
        # SSLErrors (https://github.com/eventlet/eventlet/issues/692).
        # Luckily, we can work around this limitation because the phrase
        # 'timed out' appears in all the timeout related SSLErrors raised.
        timed_out = "timed out" in str(error)

    exc_class = NetworkTimeout if timed_out else AutoReconnect
    raise exc_class(msg) from error
|
|
|
|
|
|
def _get_timeout_details(options: PoolOptions) -> dict[str, float]:
    """Collect the timeouts in effect for an operation, in milliseconds.

    ``timeoutMS`` (the CSOT deadline) takes precedence over
    ``socketTimeoutMS``: at most one of the two is reported.
    ``connectTimeoutMS`` is reported whenever it is set. Unset (falsy)
    timeouts are omitted.
    """
    csot_timeout = _csot.get_timeout()
    socket_timeout = options.socket_timeout
    connect_timeout = options.connect_timeout

    details: dict[str, float] = {}
    if csot_timeout:
        details["timeoutMS"] = csot_timeout * 1000
    elif socket_timeout:
        details["socketTimeoutMS"] = socket_timeout * 1000
    if connect_timeout:
        details["connectTimeoutMS"] = connect_timeout * 1000
    return details
|
|
|
|
|
|
def format_timeout_details(details: Optional[dict[str, float]]) -> str:
    """Render configured timeouts as a suffix for error messages.

    Only ``socketTimeoutMS``, ``timeoutMS`` and ``connectTimeoutMS`` are
    rendered, always in that order, e.g.
    ``" (configured timeouts: timeoutMS: 500.0ms, connectTimeoutMS: 20000.0ms)"``.

    :param details: mapping of timeout option name to its value in
        milliseconds (as built by ``_get_timeout_details``), or ``None``.
    :returns: the suffix, starting with a space, or ``""`` when *details* is
        empty/``None`` or contains none of the known keys.
    """
    if not details:
        return ""
    parts = [
        f"{name}: {details[name]}ms"
        for name in ("socketTimeoutMS", "timeoutMS", "connectTimeoutMS")
        if name in details
    ]
    if not parts:
        # Without this guard the output would be the malformed
        # " (configured timeouts)".
        return ""
    return f" (configured timeouts: {', '.join(parts)})"
|
|
|
|
|
|
class _CancellationContext:
    """A one-way flag recording whether :meth:`cancel` has been called."""

    def __init__(self) -> None:
        # Starts un-cancelled; nothing in this class ever resets it.
        self._cancelled: bool = False

    @property
    def cancelled(self) -> bool:
        """True once :meth:`cancel` has been called on this context."""
        return self._cancelled

    def cancel(self) -> None:
        """Mark this context as cancelled; repeated calls have no further effect."""
        self._cancelled = True
|
|
|
|
|
|
async def _async_create_connection(address: _Address, options: PoolOptions) -> socket.socket:
    """Given (host, port) and PoolOptions, connect and return a raw socket object.

    Can raise socket.error.

    This is a modified version of create_connection from CPython >= 2.7.

    A host ending in ``.sock`` is treated as the path of a Unix domain socket.
    Otherwise every address returned by getaddrinfo is tried in order and the
    first successful connection wins; if none succeeds, the last error is
    raised. The connect runs non-blocking on the event loop, bounded by the
    remaining CSOT budget or ``options.connect_timeout``. The returned socket
    is taken out of non-blocking mode: TCP sockets keep that timeout as their
    socket timeout, Unix sockets become blocking. The socket being connected
    is closed if connecting fails, including on task cancellation.
    """
    host, port = address

    # Check if dealing with a unix domain socket
    if host.endswith(".sock"):
        if not hasattr(socket, "AF_UNIX"):
            raise ConnectionFailure("UNIX-sockets are not supported on this system")
        sock = socket.socket(socket.AF_UNIX)
        # SOCK_CLOEXEC not supported for Unix sockets.
        _set_non_inheritable_non_atomic(sock.fileno())
        try:
            # Socket needs to be non-blocking during connection to not block the event loop
            sock.setblocking(False)
            await asyncio.get_running_loop().sock_connect(sock, host)
            # Leave non-blocking mode, as the TCP path below (settimeout) and
            # the synchronous _create_connection do, so blocking consumers
            # such as SSLContext.wrap_socket can use the socket.
            sock.setblocking(True)
            return sock
        except BaseException:
            # Includes asyncio.CancelledError: close instead of leaking the fd.
            sock.close()
            raise

    # Don't try IPv6 if we don't support it. Also skip it if host
    # is 'localhost' (::1 is fine). Avoids slow connect issues
    # like PYTHON-356.
    family = socket.AF_INET
    if socket.has_ipv6 and host != "localhost":
        family = socket.AF_UNSPEC

    err: Optional[OSError] = None
    for res in await _getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM):
        af, socktype, proto, _, sa = res
        # SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited
        # number of platforms (newer Linux and *BSD). Starting with CPython 3.4
        # all file descriptors are created non-inheritable. See PEP 446.
        try:
            sock = socket.socket(af, socktype | getattr(socket, "SOCK_CLOEXEC", 0), proto)
        except OSError:
            # Can SOCK_CLOEXEC be defined even if the kernel doesn't support
            # it? Retry without the flag.
            try:
                sock = socket.socket(af, socktype, proto)
            except OSError as e:
                # This address family is unusable here (e.g. IPv6 disabled in
                # the kernel although getaddrinfo returned an IPv6 address).
                # Remember the error and try the next address, as
                # socket.create_connection does, instead of aborting.
                err = e
                continue
            # Fallback when SOCK_CLOEXEC isn't available.
            _set_non_inheritable_non_atomic(sock.fileno())
        try:
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            # CSOT: apply timeout to socket connect.
            timeout = _csot.remaining()
            if timeout is None:
                timeout = options.connect_timeout
            elif timeout <= 0:
                raise socket.timeout("timed out")
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True)
            _set_keepalive_times(sock)
            # Socket needs to be non-blocking during connection to not block the event loop
            sock.setblocking(False)
            await asyncio.wait_for(
                asyncio.get_running_loop().sock_connect(sock, sa), timeout=timeout
            )
            # Leaves non-blocking mode and applies the connect timeout.
            sock.settimeout(timeout)
            return sock
        except asyncio.TimeoutError as e:
            sock.close()
            err = socket.timeout("timed out")
            err.__cause__ = e
        except OSError as e:
            sock.close()
            err = e
        except BaseException:
            # E.g. asyncio.CancelledError while connecting: close the socket
            # instead of leaking it until garbage collection.
            sock.close()
            raise

    if err is not None:
        raise err
    # This likely means we tried to connect to an IPv6 only
    # host with an OS/kernel or Python interpreter that doesn't
    # support IPv6. The test case is Jython2.5.1 which doesn't
    # support IPv6 at all.
    raise OSError("getaddrinfo failed")
|
|
|
|
|
|
async def _async_configured_socket(
    address: _Address, options: PoolOptions
) -> Union[socket.socket, _sslConn]:
    """Given (host, port) and PoolOptions, return a raw configured socket.

    Can raise socket.error, ConnectionFailure, or _CertificateError.

    Sets socket's SSL and timeout options.

    Without TLS the plain socket is returned with ``options.socket_timeout``
    applied. With TLS the socket is wrapped by ``options._ssl_context`` in
    the default executor (passing the host as ``server_hostname`` when SNI is
    available); the peer certificate's hostname is matched manually when the
    context verifies certificates but does not check hostnames itself; and
    ``options.socket_timeout`` is applied to the wrapped socket.
    """
    sock = await _async_create_connection(address, options)
    ssl_context = options._ssl_context
    if ssl_context is None:
        sock.settimeout(options.socket_timeout)
        return sock

    host = address[0]
    try:
        # We have to pass hostname / ip address to wrap_socket
        # to use SSLContext.check_hostname.
        sni_kwargs: dict[str, Any] = {"server_hostname": host} if _has_sni(False) else {}
        # wrap_socket can block (e.g. on the TLS handshake), so it runs in the
        # default executor instead of on the event loop thread.
        ssl_sock = await asyncio.get_running_loop().run_in_executor(  # type: ignore[assignment, misc, unused-ignore]
            None, functools.partial(ssl_context.wrap_socket, sock, **sni_kwargs)
        )
    except _CertificateError:
        sock.close()
        # Raise _CertificateError directly like we do after match_hostname
        # below.
        raise
    except (OSError, *SSLErrors) as exc:
        sock.close()
        # We raise AutoReconnect for transient and permanent SSL handshake
        # failures alike. Permanent handshake failures, like protocol
        # mismatch, will be turned into ServerSelectionTimeoutErrors later.
        _raise_connection_failure(
            address, exc, "SSL handshake failed: ", timeout_details=_get_timeout_details(options)
        )

    must_match_hostname = (
        ssl_context.verify_mode
        and not ssl_context.check_hostname
        and not options.tls_allow_invalid_hostnames
    )
    if must_match_hostname:
        try:
            ssl.match_hostname(ssl_sock.getpeercert(), hostname=host)  # type:ignore[attr-defined, unused-ignore]
        except _CertificateError:
            ssl_sock.close()
            raise

    ssl_sock.settimeout(options.socket_timeout)
    return ssl_sock
|
|
|
|
|
|
async def _configured_protocol_interface(
    address: _Address, options: PoolOptions
) -> AsyncNetworkingInterface:
    """Given (host, port) and PoolOptions, return a configured AsyncNetworkingInterface.

    Can raise socket.error, ConnectionFailure, or _CertificateError.

    Sets protocol's SSL and timeout options.

    The connected socket is handed to the running loop's create_connection
    together with a PyMongoProtocol built with ``options.socket_timeout``.
    With TLS, create_connection performs the handshake (passing the host as
    ``server_hostname``). The peer certificate's hostname is then matched
    manually when the context verifies certificates but does not check
    hostnames itself, and the transport is aborted if that check fails.
    """
    sock = await _async_create_connection(address, options)
    ssl_context = options._ssl_context
    timeout = options.socket_timeout
    loop = asyncio.get_running_loop()

    def protocol_factory():
        # Each connection gets a fresh protocol using the configured socket timeout.
        return PyMongoProtocol(timeout=timeout)

    if ssl_context is None:
        return AsyncNetworkingInterface(await loop.create_connection(protocol_factory, sock=sock))

    host = address[0]
    try:
        # We have to pass hostname / ip address to wrap_socket
        # to use SSLContext.check_hostname.
        transport, protocol = await loop.create_connection(  # type: ignore[call-overload]
            protocol_factory,
            sock=sock,
            server_hostname=host,
            ssl=ssl_context,
        )
    except _CertificateError:
        # Raise _CertificateError directly like we do after match_hostname
        # below.
        raise
    except (OSError, *SSLErrors) as exc:
        # We raise AutoReconnect for transient and permanent SSL handshake
        # failures alike. Permanent handshake failures, like protocol
        # mismatch, will be turned into ServerSelectionTimeoutErrors later.
        _raise_connection_failure(
            address, exc, "SSL handshake failed: ", timeout_details=_get_timeout_details(options)
        )

    must_match_hostname = (
        ssl_context.verify_mode
        and not ssl_context.check_hostname
        and not options.tls_allow_invalid_hostnames
    )
    if must_match_hostname:
        try:
            ssl.match_hostname(transport.get_extra_info("peercert"), hostname=host)  # type:ignore[attr-defined,unused-ignore]
        except _CertificateError:
            transport.abort()
            raise

    return AsyncNetworkingInterface((transport, protocol))
|
|
|
|
|
|
def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
    """Given (host, port) and PoolOptions, connect and return a raw socket object.

    Can raise socket.error.

    This is a modified version of create_connection from CPython >= 2.7.

    A host ending in ``.sock`` is treated as the path of a Unix domain socket.
    Otherwise every address returned by getaddrinfo is tried in order and the
    first successful connection wins; if none succeeds, the last error is
    raised. TCP connects are bounded by the remaining CSOT budget or
    ``options.connect_timeout``, which remains set as the socket's timeout.
    The socket being connected is closed if connecting fails, including on
    BaseExceptions such as KeyboardInterrupt.
    """
    host, port = address

    # Check if dealing with a unix domain socket
    if host.endswith(".sock"):
        if not hasattr(socket, "AF_UNIX"):
            raise ConnectionFailure("UNIX-sockets are not supported on this system")
        sock = socket.socket(socket.AF_UNIX)
        # SOCK_CLOEXEC not supported for Unix sockets.
        _set_non_inheritable_non_atomic(sock.fileno())
        try:
            sock.connect(host)
            return sock
        except BaseException:
            sock.close()
            raise

    # Don't try IPv6 if we don't support it. Also skip it if host
    # is 'localhost' (::1 is fine). Avoids slow connect issues
    # like PYTHON-356.
    family = socket.AF_INET
    if socket.has_ipv6 and host != "localhost":
        family = socket.AF_UNSPEC

    err: Optional[OSError] = None
    for res in socket.getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM):  # type: ignore[attr-defined, unused-ignore]
        af, socktype, proto, _, sa = res
        # SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited
        # number of platforms (newer Linux and *BSD). Starting with CPython 3.4
        # all file descriptors are created non-inheritable. See PEP 446.
        try:
            sock = socket.socket(af, socktype | getattr(socket, "SOCK_CLOEXEC", 0), proto)
        except OSError:
            # Can SOCK_CLOEXEC be defined even if the kernel doesn't support
            # it? Retry without the flag.
            try:
                sock = socket.socket(af, socktype, proto)
            except OSError as e:
                # This address family is unusable here (e.g. IPv6 disabled in
                # the kernel although getaddrinfo returned an IPv6 address).
                # Remember the error and try the next address, as
                # socket.create_connection does, instead of aborting.
                err = e
                continue
            # Fallback when SOCK_CLOEXEC isn't available.
            _set_non_inheritable_non_atomic(sock.fileno())
        try:
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            # CSOT: apply timeout to socket connect.
            timeout = _csot.remaining()
            if timeout is None:
                timeout = options.connect_timeout
            elif timeout <= 0:
                raise socket.timeout("timed out")
            sock.settimeout(timeout)
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True)
            _set_keepalive_times(sock)
            sock.connect(sa)
            return sock
        except OSError as e:
            err = e
            sock.close()
        except BaseException:
            # Not a connection error (e.g. KeyboardInterrupt): close the
            # socket instead of leaking it, then propagate.
            sock.close()
            raise

    if err is not None:
        raise err
    # This likely means we tried to connect to an IPv6 only
    # host with an OS/kernel or Python interpreter that doesn't
    # support IPv6. The test case is Jython2.5.1 which doesn't
    # support IPv6 at all.
    raise OSError("getaddrinfo failed")
|
|
|
|
|
|
def _configured_socket(address: _Address, options: PoolOptions) -> Union[socket.socket, _sslConn]:
    """Given (host, port) and PoolOptions, return a raw configured socket.

    Can raise socket.error, ConnectionFailure, or _CertificateError.

    Sets socket's SSL and timeout options.

    Without TLS the plain socket is returned with ``options.socket_timeout``
    applied. With TLS the socket is wrapped by ``options._ssl_context``
    (passing the host as ``server_hostname`` when SNI is available); the peer
    certificate's hostname is matched manually when the context verifies
    certificates but does not check hostnames itself; and
    ``options.socket_timeout`` is applied to the wrapped socket.
    """
    sock = _create_connection(address, options)
    ssl_context = options._ssl_context
    if ssl_context is None:
        sock.settimeout(options.socket_timeout)
        return sock

    host = address[0]
    try:
        # We have to pass hostname / ip address to wrap_socket
        # to use SSLContext.check_hostname.
        sni_kwargs: dict[str, Any] = {"server_hostname": host} if _has_sni(True) else {}
        ssl_sock = ssl_context.wrap_socket(sock, **sni_kwargs)  # type: ignore[assignment, misc, unused-ignore]
    except _CertificateError:
        sock.close()
        # Raise _CertificateError directly like we do after match_hostname
        # below.
        raise
    except (OSError, *SSLErrors) as exc:
        sock.close()
        # We raise AutoReconnect for transient and permanent SSL handshake
        # failures alike. Permanent handshake failures, like protocol
        # mismatch, will be turned into ServerSelectionTimeoutErrors later.
        _raise_connection_failure(
            address, exc, "SSL handshake failed: ", timeout_details=_get_timeout_details(options)
        )

    must_match_hostname = (
        ssl_context.verify_mode
        and not ssl_context.check_hostname
        and not options.tls_allow_invalid_hostnames
    )
    if must_match_hostname:
        try:
            ssl.match_hostname(ssl_sock.getpeercert(), hostname=host)  # type:ignore[attr-defined, unused-ignore]
        except _CertificateError:
            ssl_sock.close()
            raise

    ssl_sock.settimeout(options.socket_timeout)
    return ssl_sock
|
|
|
|
|
|
def _configured_socket_interface(address: _Address, options: PoolOptions) -> NetworkingInterface:
    """Given (host, port) and PoolOptions, return a NetworkingInterface wrapping a configured socket.

    Can raise socket.error, ConnectionFailure, or _CertificateError.

    Sets socket's SSL and timeout options.

    Connecting, TLS wrapping, manual hostname verification and applying
    ``options.socket_timeout`` are delegated to ``_configured_socket``. This
    function previously duplicated that logic line for line; sharing it keeps
    the two sync code paths from drifting apart.
    """
    return NetworkingInterface(_configured_socket(address, options))
|