diff --git a/dashscope/__init__.py b/dashscope/__init__.py index e439c13..744b269 100644 --- a/dashscope/__init__.py +++ b/dashscope/__init__.py @@ -25,8 +25,6 @@ base_http_api_url, base_websocket_api_url, ) -from dashscope.common.aio_session_manager import AioSessionManager -from dashscope.common.session_manager import SessionManager from dashscope.customize.deployments import Deployments from dashscope.customize.finetunes import FineTunes from dashscope.embeddings.batch_text_embedding import BatchTextEmbedding @@ -66,276 +64,6 @@ list_tokenizers, ) - -def enable_http_connection_pool( - pool_connections: int = None, - pool_maxsize: int = None, - max_retries: int = None, - pool_block: bool = None, -): - """ - 启用 HTTP 连接池复用 - - 启用后,所有同步 HTTP 请求将复用连接,显著减少延迟。 - - Args: - pool_connections: 连接池大小,默认 10 - - 低并发(< 10 req/s): 10 - - 中并发(10-50 req/s): 20-30 - - 高并发(> 50 req/s): 50-100 - - pool_maxsize: 最大连接数,默认 20 - - 应该 >= pool_connections - - 低并发: 20 - - 中并发: 50 - - 高并发: 100-200 - - max_retries: 重试次数,默认 3 - - 网络稳定: 3 - - 网络不稳定: 5-10 - - pool_block: 连接池满时是否阻塞,默认 False - - False: 连接池满时创建新连接(推荐) - - True: 连接池满时等待可用连接 - - Examples: - >>> import dashscope - >>> - >>> # 使用默认配置 - >>> dashscope.enable_http_connection_pool() - >>> - >>> # 自定义配置 - >>> dashscope.enable_http_connection_pool( - ... pool_connections=20, - ... pool_maxsize=50 - ... ) - >>> - >>> # 之后的所有请求都会复用连接 - >>> Generation.call(model='qwen-turbo', prompt='Hello') - """ - SessionManager.get_instance().enable( - pool_connections=pool_connections, - pool_maxsize=pool_maxsize, - max_retries=max_retries, - pool_block=pool_block, - ) - - -def disable_http_connection_pool(): - """ - 禁用 HTTP 连接池复用 - - 恢复到原有的每次请求创建新连接的行为。 - - Example: - >>> import dashscope - >>> dashscope.disable_http_connection_pool() - """ - SessionManager.get_instance().disable() - - -def reset_http_connection_pool(): - """ - 重置 HTTP 连接池 - - 用于处理连接问题或网络切换场景。 - - Example: - >>> import dashscope - >>> dashscope.reset_http_connection_pool() - """ - SessionManager.get_instance().reset() - - -def configure_http_connection_pool( - pool_connections: int = None, - pool_maxsize: int = None, - max_retries: int = None, - pool_block: bool = None, -): - """ - 配置 HTTP 连接池参数 - - 运行时动态调整连接池配置。 - - Args: - pool_connections: 连接池大小 - pool_maxsize: 最大连接数 - max_retries: 重试次数 - pool_block: 连接池满时是否阻塞 - - Examples: - >>> import dashscope - >>> - >>> # 调整单个参数 - >>> dashscope.configure_http_connection_pool(pool_maxsize=100) - >>> - >>> # 调整多个参数 - >>> dashscope.configure_http_connection_pool( - ... pool_connections=50, - ... pool_maxsize=100 - ... ) - """ - SessionManager.get_instance().configure( - pool_connections=pool_connections, - pool_maxsize=pool_maxsize, - max_retries=max_retries, - pool_block=pool_block, - ) - - -async def enable_aio_http_connection_pool( - limit: int = None, - limit_per_host: int = None, - ttl_dns_cache: int = None, - keepalive_timeout: int = None, - force_close: bool = None, -): - """ - 启用异步 HTTP 连接池复用 - - 启用后,所有异步 HTTP 请求将复用连接,显著减少延迟。 - - Args: - limit: 总连接数限制,默认 100 - - 低并发(< 10 req/s): 100 - - 中并发(10-50 req/s): 200 - - 高并发(> 50 req/s): 300-500 - - limit_per_host: 每个主机的连接数限制,默认 30 - - 应该 <= limit - - 低并发: 30 - - 中并发: 50 - - 高并发: 100 - - ttl_dns_cache: DNS 缓存 TTL(秒),默认 300 - - DNS 稳定: 300-600 - - DNS 变化频繁: 60-120 - - keepalive_timeout: Keep-Alive 超时(秒),默认 30 - - 短连接: 15-30 - - 长连接: 60-120 - - force_close: 是否强制关闭连接,默认 False - - False: 复用连接(推荐) - - True: 每次关闭连接 - - Examples: - >>> import asyncio - >>> import dashscope - >>> from dashscope import AioGeneration - >>> - >>> async def main(): - ... # 使用默认配置 - ... await dashscope.enable_aio_http_connection_pool() - ... - ... # 之后的所有异步请求都会复用连接 - ... response = await AioGeneration.call( - ... model='qwen-turbo', - ... prompt='Hello' - ... ) - ... - ... # 自定义配置 - ... await dashscope.enable_aio_http_connection_pool( - ... limit=200, - ... limit_per_host=50 - ... ) - >>> - >>> asyncio.run(main()) - """ - manager = await AioSessionManager.get_instance() - await manager.enable( - limit=limit, - limit_per_host=limit_per_host, - ttl_dns_cache=ttl_dns_cache, - keepalive_timeout=keepalive_timeout, - force_close=force_close, - ) - - -async def disable_aio_http_connection_pool(): - """ - 禁用异步 HTTP 连接池复用 - - 恢复到原有的每次请求创建新连接的行为。 - - Examples: - >>> import asyncio - >>> import dashscope - >>> - >>> async def main(): - ... await dashscope.disable_aio_http_connection_pool() - >>> - >>> asyncio.run(main()) - """ - manager = await AioSessionManager.get_instance() - await manager.disable() - - -async def reset_aio_http_connection_pool(): - """ - 重置异步 HTTP 连接池 - - 用于处理连接问题或网络切换场景。 - - Examples: - >>> import asyncio - >>> import dashscope - >>> - >>> async def main(): - ... await dashscope.reset_aio_http_connection_pool() - >>> - >>> asyncio.run(main()) - """ - manager = await AioSessionManager.get_instance() - await manager.reset() - - -async def configure_aio_http_connection_pool( - limit: int = None, - limit_per_host: int = None, - ttl_dns_cache: int = None, - keepalive_timeout: int = None, - force_close: bool = None, -): - """ - 配置异步 HTTP 连接池参数 - - 运行时动态调整连接池配置。 - - Args: - limit: 总连接数限制 - limit_per_host: 每个主机的连接数限制 - ttl_dns_cache: DNS 缓存 TTL(秒) - keepalive_timeout: Keep-Alive 超时(秒) - force_close: 是否强制关闭连接 - - Examples: - >>> import asyncio - >>> import dashscope - >>> - >>> async def main(): - ... # 调整单个参数 - ... await dashscope.configure_aio_http_connection_pool(limit=200) - ... - ... # 调整多个参数 - ... await dashscope.configure_aio_http_connection_pool( - ... limit=200, - ... limit_per_host=50 - ... ) - >>> - >>> asyncio.run(main()) - """ - manager = await AioSessionManager.get_instance() - await manager.configure( - limit=limit, - limit_per_host=limit_per_host, - ttl_dns_cache=ttl_dns_cache, - keepalive_timeout=keepalive_timeout, - force_close=force_close, - ) - - __all__ = [ "base_http_api_url", "base_websocket_api_url", @@ -390,14 +118,6 @@ async def configure_aio_http_connection_pool( "MessageFile", "AssistantFile", "VideoSynthesis", - "enable_http_connection_pool", - "disable_http_connection_pool", - "reset_http_connection_pool", - "configure_http_connection_pool", - "enable_aio_http_connection_pool", - "disable_aio_http_connection_pool", - "reset_aio_http_connection_pool", - "configure_aio_http_connection_pool", ] logging.getLogger(__name__).addHandler(NullHandler()) diff --git a/dashscope/api_entities/api_request_factory.py b/dashscope/api_entities/api_request_factory.py index 591966f..cc58479 100644 --- a/dashscope/api_entities/api_request_factory.py +++ b/dashscope/api_entities/api_request_factory.py @@ -37,7 +37,6 @@ def _get_protocol_params(kwargs): flattened_output = kwargs.pop("flattened_output", False) extra_url_parameters = kwargs.pop("extra_url_parameters", None) session = kwargs.pop("session", None) - aio_session = kwargs.pop("aio_session", None) # Extract user-agent from headers if present user_agent = "" @@ -61,7 +60,6 @@ def _get_protocol_params(kwargs): extra_url_parameters, user_agent, session, - aio_session, ) @@ -92,7 +90,6 @@ def _build_api_request( # pylint: disable=too-many-branches extra_url_parameters, user_agent, session, - aio_session, ) = _get_protocol_params(kwargs) task_id = kwargs.pop("task_id", None) enable_encryption = kwargs.pop("enable_encryption", False) @@ -137,7 +134,6 @@ def _build_api_request( # pylint: disable=too-many-branches encryption=encryption, user_agent=user_agent, session=session, - aio_session=aio_session, ) elif api_protocol == ApiProtocol.WEBSOCKET: if base_address is not None: diff --git a/dashscope/api_entities/http_request.py b/dashscope/api_entities/http_request.py index 42343a7..052ca83 100644 --- a/dashscope/api_entities/http_request.py +++ b/dashscope/api_entities/http_request.py @@ -4,7 +4,7 @@ import json import ssl from http import HTTPStatus -from typing import Optional +from typing import Optional, Dict, Union import aiohttp import certifi @@ -42,8 +42,9 @@ def __init__( flattened_output: bool = False, encryption: Optional[Encryption] = None, user_agent: str = "", - session: Optional[requests.Session] = None, - aio_session: Optional[aiohttp.ClientSession] = None, + session: Optional[ + Union[requests.Session, aiohttp.ClientSession] + ] = None, ) -> None: """HttpSSERequest, processing http server sent event stream. @@ -56,10 +57,11 @@ def __init__( Defaults to DEFAULT_REQUEST_TIMEOUT_SECONDS. user_agent (str, optional): Additional user agent string to append. Defaults to ''. - session (Optional[requests.Session]): External session for - connection reuse (sync). Defaults to None. - aio_session (Optional[aiohttp.ClientSession]): External session - for connection reuse (async). Defaults to None. + session (Optional[Union[requests.Session, + aiohttp.ClientSession]], optional): + Custom Session for connection reuse. Can be either + requests.Session for sync calls or aiohttp.ClientSession + for async calls. Defaults to None. """ super().__init__(user_agent=user_agent) @@ -67,13 +69,29 @@ def __init__( self.flattened_output = flattened_output self.async_request = async_request self.encryption = encryption - self._external_session = session - self._external_aio_session = aio_session - base_headers = getattr(self, "headers", {}) - self.headers = { + + # Auto-detect session type and store accordingly + if session is not None: + session_type = type(session).__name__ + session_module = type(session).__module__ + + # Check if it's an aiohttp ClientSession + if ( + session_type == "ClientSession" and "aiohttp" in session_module + ) or isinstance(session, aiohttp.ClientSession): + self._external_session = None + self._external_aio_session = session + else: + # Treat as requests Session + self._external_session = session + self._external_aio_session = None + else: + self._external_session = None + self._external_aio_session = None + self.headers: Dict = { "Accept": "application/json", "Authorization": f"Bearer {api_key}", - **base_headers, + **self.headers, } if encryption and encryption.is_valid(): @@ -111,24 +129,6 @@ def __init__( else: self.timeout = timeout # type: ignore[has-type] - def get_external_session(self) -> Optional[requests.Session]: - """ - 获取外部传入的同步 Session - - Returns: - Optional[requests.Session]: 外部 Session,如果未设置则返回 None - """ - return self._external_session - - def get_external_aio_session(self) -> Optional[aiohttp.ClientSession]: - """ - 获取外部传入的异步 Session - - Returns: - Optional[aiohttp.ClientSession]: 外部异步 Session,如果未设置则返回 None - """ - return self._external_aio_session - def add_header(self, key, value): self.headers[key] = value @@ -159,119 +159,70 @@ async def aio_call(self): pass return result - async def _get_aio_session(self): - """获取异步 Session(优先级:外部 > 全局 > 临时)""" - # 1. 检查是否有外部传入的 Session(最高优先级) - if self._external_aio_session is not None: - logger.debug( - "Using external async session for request: %s", - self.url, - ) - return self._external_aio_session, False - - # 2. 尝试获取全局异步 Session - from dashscope.common.aio_session_manager import AioSessionManager - - manager = await AioSessionManager.get_instance() - global_session = await manager.get_session() - - if global_session is not None: - logger.debug( - "Using global async session for request: %s", - self.url, - ) - return global_session, False - - # 3. 创建临时 Session(保持向后兼容) - connector = aiohttp.TCPConnector( - ssl=ssl.create_default_context( - cafile=certifi.where(), - ), - ) - session = aiohttp.ClientSession( - connector=connector, - timeout=aiohttp.ClientTimeout(total=self.timeout), - headers=self.headers, - ) - logger.debug( - "Using temporary async session for request: %s", - self.url, - ) - return session, True - - async def _execute_aio_request(self, session, timeout): - """执行异步 HTTP 请求""" - logger.debug("Starting request: %s", self.url) - - if self.method == HTTPMethod.POST: - return await self._execute_post_request(session, timeout) - if self.method == HTTPMethod.GET: - return await self._execute_get_request(session, timeout) - - raise UnsupportedHTTPMethod( - f"Unsupported http method: {self.method}", - ) - - async def _execute_post_request(self, session, timeout): - """执行 POST 请求""" - is_form, obj = False, {} - if hasattr(self, "data") and self.data is not None: - is_form, obj = self.data.get_aiohttp_payload() - - if is_form: - headers = {**self.headers, **obj.headers} - return await session.post( - url=self.url, - data=obj, - headers=headers, - timeout=timeout, - ) - - return await session.request( - "POST", - url=self.url, - json=obj, - headers=self.headers, - timeout=timeout, - ) - - async def _execute_get_request(self, session, timeout): - """执行 GET 请求""" - params = {} - if hasattr(self, "data") and self.data is not None: - params = getattr(self.data, "parameters", {}) - if params: - params = self.__handle_parameters(params) - - return await session.get( - url=self.url, - params=params, - headers=self.headers, - timeout=timeout, - ) - - async def _handle_aio_request(self): + async def _handle_aio_request(self): # pylint: disable=too-many-branches try: - # 获取 Session(优先级:外部 > 全局 > 临时) - session, should_close = await self._get_aio_session() + # Use external aio_session if provided, + # otherwise create temporary session + if self._external_aio_session is not None: + session = self._external_aio_session + should_close = False + else: + connector = aiohttp.TCPConnector( + ssl=ssl.create_default_context( + cafile=certifi.where(), + ), + ) + session = aiohttp.ClientSession( + connector=connector, + timeout=aiohttp.ClientTimeout(total=self.timeout), + headers=self.headers, + ) + should_close = True try: - # 设置超时 - timeout = aiohttp.ClientTimeout(total=self.timeout) - - # 执行请求 - response = await self._execute_aio_request(session, timeout) - + logger.debug("Starting request: %s", self.url) + if self.method == HTTPMethod.POST: + is_form, obj = False, {} + if hasattr(self, "data") and self.data is not None: + is_form, obj = self.data.get_aiohttp_payload() + if is_form: + headers = {**self.headers, **obj.headers} + response = await session.post( + url=self.url, + data=obj, + headers=headers, + ) + else: + response = await session.request( + "POST", + url=self.url, + json=obj, + headers=self.headers, + ) + elif self.method == HTTPMethod.GET: + # 添加条件判断 + params = {} + if hasattr(self, "data") and self.data is not None: + params = getattr(self.data, "parameters", {}) + if params: + params = self.__handle_parameters(params) + response = await session.get( + url=self.url, + params=params, + headers=self.headers, + ) + else: + raise UnsupportedHTTPMethod( + f"Unsupported http method: {self.method}", + ) logger.debug("Response returned: %s", self.url) async with response: async for rsp in self._handle_aio_response(response): yield rsp finally: - # 只关闭临时 Session + # Only close if we created the session if should_close: await session.close() - logger.debug("Temporary async session closed") - except aiohttp.ClientConnectorError as e: logger.error(e) raise e @@ -496,34 +447,17 @@ def _handle_response( # pylint: disable=too-many-branches yield _handle_http_failed_response(response) def _handle_request(self): - """ - 处理 HTTP 请求 - - 优先级: - 1. 外部传入的 session(用户自定义) - 2. 全局 SessionManager(如果启用) - 3. 临时 session(保持原有行为) - """ try: - from dashscope.common.session_manager import SessionManager - - # 优先使用外部传入的 session + # Use external session if provided, + # otherwise create temporary session if self._external_session is not None: session = self._external_session should_close = False else: - # 尝试使用全局 SessionManager - session_manager = SessionManager.get_instance() - session = session_manager.get_session() - should_close = False - - # 如果未启用连接复用,创建临时 session - if session is None: - session = requests.Session() - should_close = True + session = requests.Session() + should_close = True try: - # 执行请求 if self.method == HTTPMethod.POST: is_form, form, obj = self.data.get_http_payload() if is_form: @@ -556,14 +490,12 @@ def _handle_request(self): raise UnsupportedHTTPMethod( f"Unsupported http method: {self.method}", ) - for rsp in self._handle_response(response): yield rsp finally: - # 只关闭临时创建的 session + # Only close if we created the session if should_close: session.close() - except BaseException as e: logger.error(e) raise e diff --git a/dashscope/common/aio_session_manager.py b/dashscope/common/aio_session_manager.py deleted file mode 100644 index ceaa7a2..0000000 --- a/dashscope/common/aio_session_manager.py +++ /dev/null @@ -1,410 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright (c) Alibaba, Inc. and its affiliates. - -"""异步 HTTP Session 管理器,用于管理 aiohttp.ClientSession 的连接池复用""" - -import asyncio -import ssl -from typing import Optional - -import aiohttp -import certifi - -from dashscope.common.logging import logger - - -class AioConnectionPoolConfig: - """异步连接池配置类""" - - def __init__( - self, - limit: int = 100, - limit_per_host: int = 30, - ttl_dns_cache: int = 300, - keepalive_timeout: int = 30, - force_close: bool = False, - ): - """ - 初始化异步连接池配置 - - Args: - limit: 总连接数限制,默认 100 - limit_per_host: 每个主机的连接数限制,默认 30 - ttl_dns_cache: DNS 缓存 TTL(秒),默认 300 - keepalive_timeout: Keep-Alive 超时(秒),默认 30 - force_close: 是否强制关闭连接,默认 False - - Raises: - ValueError: 当参数值不合法时 - """ - if limit <= 0: - raise ValueError(f"limit ({limit}) 必须 > 0") - if limit_per_host <= 0: - raise ValueError(f"limit_per_host ({limit_per_host}) 必须 > 0") - if limit_per_host > limit: - raise ValueError( - f"limit_per_host ({limit_per_host}) 必须 <= " f"limit ({limit})", - ) - if ttl_dns_cache < 0: - raise ValueError(f"ttl_dns_cache ({ttl_dns_cache}) 必须 >= 0") - if keepalive_timeout < 0: - raise ValueError( - f"keepalive_timeout ({keepalive_timeout}) 必须 >= 0", - ) - - self.limit = limit - self.limit_per_host = limit_per_host - self.ttl_dns_cache = ttl_dns_cache - self.keepalive_timeout = keepalive_timeout - self.force_close = force_close - - def __repr__(self): - return ( - f"AioConnectionPoolConfig(limit={self.limit}, " - f"limit_per_host={self.limit_per_host}, " - f"ttl_dns_cache={self.ttl_dns_cache}, " - f"keepalive_timeout={self.keepalive_timeout}, " - f"force_close={self.force_close})" - ) - - -class AioSessionManager: - """ - 异步 HTTP Session 管理器(单例模式) - - 用于管理全局的 aiohttp.ClientSession 实例,实现异步 HTTP 连接复用。 - - 特性: - - 单例模式:全局唯一实例 - - 异步锁保护:使用 asyncio.Lock 保护并发访问 - - 连接池配置:支持自定义 TCPConnector 参数 - - 生命周期管理:支持启用、禁用、重置 - - 向后兼容:默认禁用,不影响现有代码 - - Examples: - >>> import asyncio - >>> from dashscope.common.aio_session_manager import AioSessionManager - >>> - >>> async def main(): - ... manager = await AioSessionManager.get_instance() - ... await manager.enable(limit=200, limit_per_host=50) - ... session = await manager.get_session() - ... # 使用 session 进行请求 - ... await manager.disable() - >>> - >>> asyncio.run(main()) - """ - - _instance: Optional["AioSessionManager"] = None - _lock = asyncio.Lock() - - def __init__(self): - """初始化 Session 管理器(私有,通过 get_instance 获取)""" - self._enabled = False - self._session: Optional[aiohttp.ClientSession] = None - self._session_lock = asyncio.Lock() - self._config = AioConnectionPoolConfig() - logger.debug("AioSessionManager initialized") - - @classmethod - async def get_instance(cls) -> "AioSessionManager": - """ - 获取单例实例(异步) - - Returns: - AioSessionManager: 单例实例 - """ - if cls._instance is None: - async with cls._lock: - if cls._instance is None: - cls._instance = cls() - logger.debug( - "AioSessionManager singleton instance created", - ) - return cls._instance - - @classmethod - async def reset_instance(cls): - """ - 重置单例实例(仅用于测试) - - 警告:此方法仅应在测试环境中使用 - """ - async with cls._lock: - if cls._instance is not None: - await cls._instance.disable() - await cls._instance.reset() - cls._instance = None - logger.debug("AioSessionManager singleton instance reset") - - async def enable( - self, - limit: int = None, - limit_per_host: int = None, - ttl_dns_cache: int = None, - keepalive_timeout: int = None, - force_close: bool = None, - ): - """ - 启用异步连接池复用 - - Args: - limit: 总连接数限制,默认 100 - limit_per_host: 每个主机的连接数限制,默认 30 - ttl_dns_cache: DNS 缓存 TTL(秒),默认 300 - keepalive_timeout: Keep-Alive 超时(秒),默认 30 - force_close: 是否强制关闭连接,默认 False - - Examples: - >>> manager = await AioSessionManager.get_instance() - >>> await manager.enable(limit=200, limit_per_host=50) - """ - async with self._session_lock: - # 如果提供了配置参数,先配置 - if any( - param is not None - for param in [ - limit, - limit_per_host, - ttl_dns_cache, - keepalive_timeout, - force_close, - ] - ): - await self._configure( - limit=limit, - limit_per_host=limit_per_host, - ttl_dns_cache=ttl_dns_cache, - keepalive_timeout=keepalive_timeout, - force_close=force_close, - ) - - self._enabled = True - await self._ensure_session() - logger.info( - "Async HTTP connection pool enabled with config: %s", - self._config, - ) - - async def disable(self): - """ - 禁用异步连接池复用 - - 关闭当前 Session 并禁用连接池功能 - - Examples: - >>> manager = await AioSessionManager.get_instance() - >>> await manager.disable() - """ - async with self._session_lock: - self._enabled = False - if self._session and not self._session.closed: - await self._session.close() - logger.debug("Async ClientSession closed") - self._session = None - logger.info("Async HTTP connection pool disabled") - - async def configure( - self, - limit: int = None, - limit_per_host: int = None, - ttl_dns_cache: int = None, - keepalive_timeout: int = None, - force_close: bool = None, - ): - """ - 配置连接池参数 - - Args: - limit: 总连接数限制 - limit_per_host: 每个主机的连接数限制 - ttl_dns_cache: DNS 缓存 TTL(秒) - keepalive_timeout: Keep-Alive 超时(秒) - force_close: 是否强制关闭连接 - - Examples: - >>> manager = await AioSessionManager.get_instance() - >>> await manager.configure(limit=200, limit_per_host=50) - """ - async with self._session_lock: - await self._configure( - limit=limit, - limit_per_host=limit_per_host, - ttl_dns_cache=ttl_dns_cache, - keepalive_timeout=keepalive_timeout, - force_close=force_close, - ) - - async def _configure( - self, - limit: int = None, - limit_per_host: int = None, - ttl_dns_cache: int = None, - keepalive_timeout: int = None, - force_close: bool = None, - ): - """内部配置方法(无锁)""" - config_params = {} - if limit is not None: - config_params["limit"] = limit - if limit_per_host is not None: - config_params["limit_per_host"] = limit_per_host - if ttl_dns_cache is not None: - config_params["ttl_dns_cache"] = ttl_dns_cache - if keepalive_timeout is not None: - config_params["keepalive_timeout"] = keepalive_timeout - if force_close is not None: - config_params["force_close"] = force_close - - if config_params: - # 创建新配置 - limit = config_params.get("limit", self._config.limit) - limit_per_host = config_params.get( - "limit_per_host", - self._config.limit_per_host, - ) - ttl_dns_cache = config_params.get( - "ttl_dns_cache", - self._config.ttl_dns_cache, - ) - keepalive_timeout = config_params.get( - "keepalive_timeout", - self._config.keepalive_timeout, - ) - force_close = config_params.get( - "force_close", - self._config.force_close, - ) - - new_config = AioConnectionPoolConfig( - limit=limit, - limit_per_host=limit_per_host, - ttl_dns_cache=ttl_dns_cache, - keepalive_timeout=keepalive_timeout, - force_close=bool(force_close), - ) - self._config = new_config - - # 如果已启用,重新创建 Session - if self._enabled: - if self._session and not self._session.closed: - await self._session.close() - self._session = None - await self._ensure_session() - logger.info( - "Async connection pool reconfigured: %s", - self._config, - ) - - async def _ensure_session(self): - """确保 Session 存在且有效(内部方法,无锁)""" - if self._session is None or self._session.closed: - connector = aiohttp.TCPConnector( - limit=self._config.limit, - limit_per_host=self._config.limit_per_host, - ttl_dns_cache=self._config.ttl_dns_cache, - keepalive_timeout=self._config.keepalive_timeout, - force_close=self._config.force_close, - ssl=ssl.create_default_context(cafile=certifi.where()), - ) - self._session = aiohttp.ClientSession(connector=connector) - logger.debug( - "New async ClientSession created with config: %s", - self._config, - ) - - async def get_session(self) -> Optional[aiohttp.ClientSession]: - """ - 获取 Session(如果启用) - - Returns: - Optional[aiohttp.ClientSession]: 如果启用返回 Session,否则返回 None - - Examples: - >>> manager = await AioSessionManager.get_instance() - >>> await manager.enable() - >>> session = await manager.get_session() - >>> if session: - ... # 使用 session 进行请求 - ... pass - """ - async with self._session_lock: - if self._enabled: - await self._ensure_session() - return self._session - return None - - async def get_session_direct(self) -> Optional[aiohttp.ClientSession]: - """ - 直接获取 Session(不检查启用状态) - - Returns: - Optional[aiohttp.ClientSession]: 当前 Session 或 None - - Note: - 此方法主要用于测试,一般应使用 get_session() - """ - async with self._session_lock: - return self._session - - async def reset(self): - """ - 重置 Session - - 关闭当前 Session 并根据启用状态重新创建 - - Examples: - >>> manager = await AioSessionManager.get_instance() - >>> await manager.reset() - """ - async with self._session_lock: - if self._session and not self._session.closed: - await self._session.close() - logger.debug("Async ClientSession closed during reset") - self._session = None - if self._enabled: - await self._ensure_session() - logger.info("Async HTTP connection pool reset") - - def get_config(self) -> AioConnectionPoolConfig: - """ - 获取当前连接池配置 - - Returns: - AioConnectionPoolConfig: 当前配置 - - Examples: - >>> manager = await AioSessionManager.get_instance() - >>> config = manager.get_config() - >>> print(config.limit) - """ - return self._config - - def is_enabled(self) -> bool: - """ - 检查连接池是否已启用 - - Returns: - bool: 是否已启用 - - Examples: - >>> manager = await AioSessionManager.get_instance() - >>> if manager.is_enabled(): - ... print("Connection pool is enabled") - """ - return self._enabled - - async def has_active_session(self) -> bool: - """ - 检查是否有活跃的 Session - - Returns: - bool: 是否有活跃的 Session - - Examples: - >>> manager = await AioSessionManager.get_instance() - >>> if await manager.has_active_session(): - ... print("Active session exists") - """ - async with self._session_lock: - return self._session is not None and not self._session.closed diff --git a/dashscope/common/session_manager.py b/dashscope/common/session_manager.py deleted file mode 100644 index cdf321e..0000000 --- a/dashscope/common/session_manager.py +++ /dev/null @@ -1,315 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright (c) Alibaba, Inc. and its affiliates. - -import threading -from typing import Optional - -import requests -from requests.adapters import HTTPAdapter - -from dashscope.common.logging import logger - - -class ConnectionPoolConfig: - """ - 连接池配置类 - - 提供类型安全和参数验证的配置方式 - """ - - def __init__( - self, - pool_connections: int = 10, - pool_maxsize: int = 20, - max_retries: int = 3, - pool_block: bool = False, - ): - """ - 初始化连接池配置 - - Args: - pool_connections: 连接池大小,默认 10 - - 低并发(< 10 req/s): 10 - - 中并发(10-50 req/s): 20-30 - - 高并发(> 50 req/s): 50-100 - - pool_maxsize: 最大连接数,默认 20 - - 应该 >= pool_connections - - 低并发: 20 - - 中并发: 50 - - 高并发: 100-200 - - max_retries: 重试次数,默认 3 - - 网络稳定: 3 - - 网络不稳定: 5-10 - - pool_block: 连接池满时是否阻塞,默认 False - - False: 连接池满时创建新连接(推荐) - - True: 连接池满时等待可用连接 - """ - # 参数验证 - if pool_connections < 1: - raise ValueError("pool_connections 必须 >= 1") - if pool_maxsize < pool_connections: - raise ValueError("pool_maxsize 必须 >= pool_connections") - if max_retries < 0: - raise ValueError("max_retries 必须 >= 0") - - self.pool_connections = pool_connections - self.pool_maxsize = pool_maxsize - self.max_retries = max_retries - self.pool_block = pool_block - - def to_dict(self): - """转换为字典格式""" - return { - "pool_connections": self.pool_connections, - "pool_maxsize": self.pool_maxsize, - "max_retries": self.max_retries, - "pool_block": self.pool_block, - } - - def __repr__(self): - return ( - f"ConnectionPoolConfig(" - f"pool_connections={self.pool_connections}, " - f"pool_maxsize={self.pool_maxsize}, " - f"max_retries={self.max_retries}, " - f"pool_block={self.pool_block})" - ) - - -class SessionManager: - """ - 全局 HTTP Session 管理器 - - 特性: - 1. 线程安全的 Session 池 - 2. 支持全局启用/禁用连接复用 - 3. 支持自定义 Session 配置 - 4. 自动清理和重建机制 - """ - - _instance = None - _lock = threading.Lock() - - def __init__(self): - self._enabled = False # 默认关闭,保持向后兼容 - self._session = None - self._session_lock = threading.RLock() - self._config = ConnectionPoolConfig() # 使用配置类 - - @classmethod - def get_instance(cls): - """单例模式获取实例""" - if cls._instance is None: - with cls._lock: - if cls._instance is None: - cls._instance = cls() - return cls._instance - - @classmethod - def reset_instance(cls): - """ - 重置单例实例(仅用于测试) - - 警告:此方法仅应在测试环境中使用 - """ - with cls._lock: - if cls._instance is not None: - cls._instance.disable() - cls._instance.reset() - cls._instance = None - - def enable( - self, - pool_connections: Optional[int] = None, - pool_maxsize: Optional[int] = None, - max_retries: Optional[int] = None, - pool_block: Optional[bool] = None, - ): - """ - 启用连接复用 - - Args: - pool_connections: 连接池大小,默认 10 - pool_maxsize: 最大连接数,默认 20 - max_retries: 重试次数,默认 3 - pool_block: 连接池满时是否阻塞,默认 False - - Examples: - # 使用默认配置 - enable() - - # 使用命名参数 - enable(pool_connections=50, pool_maxsize=100) - """ - with self._session_lock: - # 使用命名参数更新配置 - if pool_connections is not None: - self._config.pool_connections = pool_connections - if pool_maxsize is not None: - self._config.pool_maxsize = pool_maxsize - if max_retries is not None: - self._config.max_retries = max_retries - if pool_block is not None: - self._config.pool_block = pool_block - - # 参数验证 - if self._config.pool_maxsize < self._config.pool_connections: - raise ValueError( - f"pool_maxsize ({self._config.pool_maxsize}) 必须 >= " - f"pool_connections ({self._config.pool_connections})", - ) - - self._enabled = True - self._ensure_session() - logger.info( - "HTTP connection pool enabled with config: %s", - self._config, - ) - - def disable(self): - """禁用连接复用,关闭现有 Session""" - with self._session_lock: - self._enabled = False - if self._session: - try: - self._session.close() - except Exception as e: - logger.warning("Error closing session: %s", e) - finally: - self._session = None - logger.info("HTTP connection pool disabled") - - def is_enabled(self): - """检查是否启用连接复用""" - return self._enabled - - def get_config(self) -> ConnectionPoolConfig: - """ - 获取当前连接池配置 - - Returns: - ConnectionPoolConfig: 当前配置对象 - """ - return self._config - - def has_active_session(self) -> bool: - """ - 检查是否有活跃的 Session - - Returns: - bool: 如果存在活跃的 Session 返回 True,否则返回 False - """ - with self._session_lock: - return self._session is not None - - def _ensure_session(self): - """确保 Session 存在且有效(需要持有锁)""" - if self._session is None: - self._session = requests.Session() - - # 配置连接池 - adapter = HTTPAdapter( - pool_connections=self._config.pool_connections, - pool_maxsize=self._config.pool_maxsize, - max_retries=self._config.max_retries, - pool_block=self._config.pool_block, - ) - - self._session.mount("http://", adapter) - self._session.mount("https://", adapter) - logger.debug("Created new HTTP session with connection pool") - - def get_session(self) -> Optional[requests.Session]: - """ - 获取 Session 对象 - - Returns: - 如果启用了连接复用,返回全局 Session - 否则返回 None - - Examples: - >>> manager = SessionManager.get_instance() - >>> manager.enable() - >>> session = manager.get_session() - >>> if session: - ... response = session.get(url) - """ - if not self._enabled: - return None - - with self._session_lock: - self._ensure_session() - return self._session - - def reset(self): - """重置 Session(用于处理连接问题)""" - with self._session_lock: - if self._session: - try: - self._session.close() - except Exception as e: - logger.warning("Error closing session during reset: %s", e) - finally: - self._session = None - if self._enabled: - self._ensure_session() - logger.info("HTTP connection pool reset") - - def configure( - self, - pool_connections: Optional[int] = None, - pool_maxsize: Optional[int] = None, - max_retries: Optional[int] = None, - pool_block: Optional[bool] = None, - ): - """ - 更新配置并重建 Session - - Args: - pool_connections: 连接池大小 - pool_maxsize: 最大连接数 - max_retries: 重试次数 - pool_block: 连接池满时是否阻塞 - - Examples: - # 调整单个参数 - configure(pool_maxsize=100) - - # 调整多个参数 - configure(pool_connections=50, pool_maxsize=100) - """ - with self._session_lock: - # 使用命名参数更新配置 - if pool_connections is not None: - self._config.pool_connections = pool_connections - if pool_maxsize is not None: - self._config.pool_maxsize = pool_maxsize - if max_retries is not None: - self._config.max_retries = max_retries - if pool_block is not None: - self._config.pool_block = pool_block - - # 参数验证 - if self._config.pool_maxsize < self._config.pool_connections: - raise ValueError( - f"pool_maxsize ({self._config.pool_maxsize}) 必须 >= " - f"pool_connections ({self._config.pool_connections})", - ) - - if self._enabled: - # 重建 Session 以应用新配置 - if self._session: - try: - self._session.close() - except Exception as e: - logger.warning( - "Error closing session during configure: %s", - e, - ) - finally: - self._session = None - self._ensure_session() - logger.info("HTTP connection pool configured: %s", self._config) diff --git a/samples/test_aio_generation.py b/samples/test_aio_generation.py index 54728b1..271bf47 100644 --- a/samples/test_aio_generation.py +++ b/samples/test_aio_generation.py @@ -3,6 +3,9 @@ import asyncio import os +import aiohttp +import ssl +import certifi from dashscope.aigc.generation import AioGeneration @@ -439,6 +442,118 @@ async def process_responses(responses, step_name): await call_deep_research_model(messages, "第二步:深入研究") print("\n 研究完成!") + @staticmethod + async def test_with_custom_session(): + """示例:使用自定义 ClientSession 进行连接复用""" + print("\n=== 使用自定义 ClientSession 示例 ===") + + # 配置 TCPConnector + connector = aiohttp.TCPConnector( + limit=100, + limit_per_host=30, + ssl=ssl.create_default_context(cafile=certifi.where()), + ) + + # 创建自定义 ClientSession + async with aiohttp.ClientSession(connector=connector) as session: + # 使用同一个 session 进行多次请求 + for i in range(3): + print(f"\n--- 请求 {i+1} ---") + + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": f"你好"}, + ] + + response = await AioGeneration.call( + api_key=os.getenv("DASHSCOPE_API_KEY"), + model="qwen-turbo", + messages=messages, + result_format="message", + session=session, # ← 传入自定义 session + ) + + print(f"响应: {response.output.choices[0].message.content}") + + print("\n✅ ClientSession 已自动关闭") + + @staticmethod + async def test_with_custom_session_streaming(): + """示例:使用自定义 ClientSession 进行流式输出""" + print("\n=== 使用自定义 ClientSession 流式输出示例 ===") + + # 配置连接池 + connector = aiohttp.TCPConnector( + limit=100, + ssl=ssl.create_default_context(cafile=certifi.where()), + ) + + async with aiohttp.ClientSession(connector=connector) as session: + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "请写一首关于秋天的诗"}, + ] + + print("\n流式输出:") + response = await AioGeneration.call( + api_key=os.getenv("DASHSCOPE_API_KEY"), + model="qwen-turbo", + messages=messages, + result_format="message", + stream=True, + incremental_output=True, + session=session, # ← 传入自定义 session + ) + + async for chunk in response: + if chunk.status_code == 200: + print(chunk.output.choices[0].message.content, end='', flush=True) + + print("\n") + + print("✅ ClientSession 已自动关闭") + + @staticmethod + async def test_with_custom_session_concurrent(): + """示例:使用自定义 ClientSession 进行并发请求""" + print("\n=== 使用自定义 ClientSession 并发请求示例 ===") + + # 配置连接池 + connector = aiohttp.TCPConnector( + limit=100, + ssl=ssl.create_default_context(cafile=certifi.where()), + ) + + async with aiohttp.ClientSession(connector=connector) as session: + # 创建多个并发任务 + tasks = [] + topics = ["Python", "JavaScript", "Go"] + + for topic in topics: + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": f"请用一句话介绍:{topic}"}, + ] + + task = AioGeneration.call( + api_key=os.getenv("DASHSCOPE_API_KEY"), + model="qwen-turbo", + messages=messages, + result_format="message", + session=session, # ← 所有请求共享同一个 session + ) + tasks.append(task) + + # 并发执行所有任务 + print("\n开始并发请求...") + responses = await asyncio.gather(*tasks) + + # 处理响应 + for i, response in enumerate(responses): + print(f"\n请求 {i+1} 响应: {response.output.choices[0].message.content}") + + print("\n✅ ClientSession 已自动关闭") + async def main(): """Main function to run all async tests.""" @@ -450,6 +565,11 @@ async def main(): # await TestAioGeneration.test_response_with_search_info() # await TestAioGeneration.test_response_with_reasoning_content() + # 自定义 Session 示例 + # await TestAioGeneration.test_with_custom_session() + # await TestAioGeneration.test_with_custom_session_streaming() + # await TestAioGeneration.test_with_custom_session_concurrent() + print("\n所有异步测试用例执行完成!") diff --git a/samples/test_generation.py b/samples/test_generation.py index d7f42f4..f112c06 100644 --- a/samples/test_generation.py +++ b/samples/test_generation.py @@ -2,8 +2,8 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os - -import dashscope +import requests +from requests.adapters import HTTPAdapter from dashscope import Generation @@ -12,9 +12,6 @@ class TestGeneration: @staticmethod def test_response_with_content(): - - dashscope.enable_http_connection_pool() - messages = [ {"role": "system", "content": "You are a helpful assistant."}, { @@ -336,9 +333,85 @@ def process_responses(responses, step_name): call_deep_research_model(messages, "第二步:深入研究") print("\n 研究完成!") + @staticmethod + def test_with_custom_session(): + """示例:使用自定义 Session 进行连接复用""" + print("\n=== 使用自定义 Session 示例 ===") + + # 创建自定义 Session 并配置连接池 + with requests.Session() as session: + # 配置连接池参数 + adapter = HTTPAdapter( + pool_connections=10, + pool_maxsize=20, + max_retries=3, + ) + session.mount('http://', adapter) + session.mount('https://', adapter) + + # 使用同一个 session 进行多次请求 + for i in range(3): + print(f"\n--- 请求 {i+1} ---") + + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": f"请用一句话介绍:主题 {i+1}"}, + ] + + response = Generation.call( + api_key=os.getenv("DASHSCOPE_API_KEY"), + model="qwen-turbo", + messages=messages, + result_format="message", + session=session, # ← 传入自定义 session + ) + + print(f"响应: {response.output.choices[0].message.content}") + + print("\n✅ Session 已自动关闭") + + @staticmethod + def test_with_custom_session_streaming(): + """示例:使用自定义 Session 进行流式输出""" + print("\n=== 使用自定义 Session 流式输出示例 ===") + + with requests.Session() as session: + # 配置连接池 + adapter = HTTPAdapter( + pool_connections=10, + pool_maxsize=20, + ) + session.mount('http://', adapter) + session.mount('https://', adapter) + + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "你好!"}, + ] + + print("\n流式输出:") + response = Generation.call( + api_key=os.getenv("DASHSCOPE_API_KEY"), + model="qwen-turbo", + messages=messages, + result_format="message", + stream=True, + incremental_output=True, + session=session, # ← 传入自定义 session + ) + + for chunk in response: + print(f"chunk: {chunk}") + + print("✅ Session 已自动关闭") + if __name__ == "__main__": TestGeneration.test_response_with_content() # TestGeneration.test_response_with_tool_calls() # TestGeneration.test_response_with_search_info() # TestGeneration.test_response_with_reasoning_content() + + # 自定义 Session 示例 + # TestGeneration.test_with_custom_session() + # TestGeneration.test_with_custom_session_streaming() diff --git a/tests/unit/test_aio_connection_pool.py b/tests/unit/test_aio_connection_pool.py deleted file mode 100644 index dd8ae51..0000000 --- a/tests/unit/test_aio_connection_pool.py +++ /dev/null @@ -1,510 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright (c) Alibaba, Inc. and its affiliates. - -"""异步连接池单元测试""" - -import asyncio - -import aiohttp -import pytest - -from dashscope.common.aio_session_manager import ( - AioConnectionPoolConfig, - AioSessionManager, -) - - -class TestAioConnectionPoolConfig: - """测试 AioConnectionPoolConfig 类""" - - def test_default_config(self): - """测试默认配置""" - config = AioConnectionPoolConfig() - assert config.limit == 100 - assert config.limit_per_host == 30 - assert config.ttl_dns_cache == 300 - assert config.keepalive_timeout == 30 - assert config.force_close is False - - def test_custom_config(self): - """测试自定义配置""" - config = AioConnectionPoolConfig( - limit=200, - limit_per_host=50, - ttl_dns_cache=600, - keepalive_timeout=60, - force_close=True, - ) - assert config.limit == 200 - assert config.limit_per_host == 50 - assert config.ttl_dns_cache == 600 - assert config.keepalive_timeout == 60 - assert config.force_close is True - - def test_config_validation(self): - """测试配置参数验证""" - # limit 必须 > 0 - with pytest.raises(ValueError, match=r"limit.*必须 > 0"): - AioConnectionPoolConfig(limit=0) - - # limit_per_host 必须 > 0 - with pytest.raises(ValueError, match=r"limit_per_host.*必须 > 0"): - AioConnectionPoolConfig(limit_per_host=0) - - # limit_per_host 必须 <= limit - with pytest.raises(ValueError, match=r"limit_per_host.*必须 <="): - AioConnectionPoolConfig(limit=50, limit_per_host=100) - - # ttl_dns_cache 必须 >= 0 - with pytest.raises(ValueError, match=r"ttl_dns_cache.*必须 >= 0"): - AioConnectionPoolConfig(ttl_dns_cache=-1) - - # keepalive_timeout 必须 >= 0 - with pytest.raises(ValueError, match=r"keepalive_timeout.*必须 >= 0"): - AioConnectionPoolConfig(keepalive_timeout=-1) - - def test_config_repr(self): - """测试配置的字符串表示""" - config = AioConnectionPoolConfig(limit=200, limit_per_host=50) - repr_str = repr(config) - assert "AioConnectionPoolConfig" in repr_str - assert "limit=200" in repr_str - assert "limit_per_host=50" in repr_str - - -class TestAioSessionManager: - """测试 AioSessionManager 类""" - - @pytest.fixture(autouse=True) - async def cleanup(self): - """每个测试后清理单例实例""" - yield - await AioSessionManager.reset_instance() - - @pytest.mark.asyncio - async def test_singleton_pattern(self): - """测试单例模式""" - manager1 = await AioSessionManager.get_instance() - manager2 = await AioSessionManager.get_instance() - assert manager1 is manager2 - - @pytest.mark.asyncio - async def test_default_state(self): - """测试默认状态""" - manager = await AioSessionManager.get_instance() - assert not manager.is_enabled() - assert not await manager.has_active_session() - config = manager.get_config() - assert config.limit == 100 - assert config.limit_per_host == 30 - - @pytest.mark.asyncio - async def test_enable(self): - """测试启用连接池""" - manager = await AioSessionManager.get_instance() - await manager.enable() - assert manager.is_enabled() - assert await manager.has_active_session() - - @pytest.mark.asyncio - async def test_enable_with_config(self): - """测试启用时配置参数""" - manager = await AioSessionManager.get_instance() - await manager.enable(limit=200, limit_per_host=50) - config = manager.get_config() - assert config.limit == 200 - assert config.limit_per_host == 50 - - @pytest.mark.asyncio - async def test_disable(self): - """测试禁用连接池""" - manager = await AioSessionManager.get_instance() - await manager.enable() - assert manager.is_enabled() - - await manager.disable() - assert not manager.is_enabled() - assert not await manager.has_active_session() - - @pytest.mark.asyncio - async def test_get_session(self): - """测试获取 Session""" - manager = await AioSessionManager.get_instance() - - # 禁用时返回 None - session = await manager.get_session() - assert session is None - - # 启用后返回 Session - await manager.enable() - session = await manager.get_session() - assert session is not None - assert isinstance(session, aiohttp.ClientSession) - assert not session.closed - - @pytest.mark.asyncio - async def test_get_session_direct(self): - """测试直接获取 Session""" - manager = await AioSessionManager.get_instance() - - # 禁用时返回 None - session = await manager.get_session_direct() - assert session is None - - # 启用后返回 Session - await manager.enable() - session = await manager.get_session_direct() - assert session is not None - - # 禁用后 Session 被关闭 - await manager.disable() - session = await manager.get_session_direct() - assert session is None - - @pytest.mark.asyncio - async def test_configure(self): - """测试配置连接池""" - manager = await AioSessionManager.get_instance() - await manager.enable() - - # 配置参数 - await manager.configure(limit=200, limit_per_host=50) - config = manager.get_config() - assert config.limit == 200 - assert config.limit_per_host == 50 - - @pytest.mark.asyncio - async def test_configure_before_enable(self): - """测试启用前配置""" - manager = await AioSessionManager.get_instance() - - # 启用前配置不会创建 Session - await manager.configure(limit=200) - assert not await manager.has_active_session() - - # 启用后使用配置的参数 - await manager.enable() - config = manager.get_config() - assert config.limit == 200 - - @pytest.mark.asyncio - async def test_reset(self): - """测试重置连接池""" - manager = await AioSessionManager.get_instance() - await manager.enable() - - old_session = await manager.get_session_direct() - assert old_session is not None - - # 重置后创建新 Session - await manager.reset() - new_session = await manager.get_session_direct() - assert new_session is not None - assert new_session is not old_session - - @pytest.mark.asyncio - async def test_reset_when_disabled(self): - """测试禁用状态下重置""" - manager = await AioSessionManager.get_instance() - await manager.enable() - await manager.disable() - - # 禁用状态下重置不会创建 Session - await manager.reset() - assert not await manager.has_active_session() - - @pytest.mark.asyncio - async def test_reset_instance(self): - """测试重置单例实例""" - manager1 = await AioSessionManager.get_instance() - await manager1.enable() - - await AioSessionManager.reset_instance() - - manager2 = await AioSessionManager.get_instance() - assert not manager2.is_enabled() - assert not await manager2.has_active_session() - - @pytest.mark.asyncio - async def test_session_reuse(self): - """测试 Session 复用""" - manager = await AioSessionManager.get_instance() - await manager.enable() - - session1 = await manager.get_session() - session2 = await manager.get_session() - assert session1 is session2 - - @pytest.mark.asyncio - async def test_session_recreation_on_configure(self): - """测试配置变更时重新创建 Session""" - manager = await AioSessionManager.get_instance() - await manager.enable() - - old_session = await manager.get_session_direct() - - # 配置变更后 Session 被重新创建 - await manager.configure(limit=200) - new_session = await manager.get_session_direct() - assert new_session is not old_session - - @pytest.mark.asyncio - async def test_concurrent_enable(self): - """测试并发启用""" - manager = await AioSessionManager.get_instance() - - # 并发启用 - await asyncio.gather( - manager.enable(), - manager.enable(), - manager.enable(), - ) - - assert manager.is_enabled() - assert await manager.has_active_session() - - @pytest.mark.asyncio - async def test_concurrent_get_session(self): - """测试并发获取 Session""" - manager = await AioSessionManager.get_instance() - await manager.enable() - - # 并发获取 Session - sessions = await asyncio.gather( - manager.get_session(), - manager.get_session(), - manager.get_session(), - ) - - # 所有 Session 应该是同一个实例 - assert all(s is sessions[0] for s in sessions) - - @pytest.mark.asyncio - async def test_concurrent_enable_disable(self): - """测试并发启用和禁用""" - manager = await AioSessionManager.get_instance() - - async def enable_disable(): - await manager.enable() - await asyncio.sleep(0.01) - await manager.disable() - - # 并发执行启用和禁用 - await asyncio.gather( - enable_disable(), - enable_disable(), - enable_disable(), - ) - - # 最终状态应该是禁用 - assert not manager.is_enabled() - - @pytest.mark.asyncio - async def test_session_closed_detection(self): - """测试 Session 关闭检测""" - manager = await AioSessionManager.get_instance() - await manager.enable() - - session = await manager.get_session_direct() - assert not session.closed - - # 手动关闭 Session - await session.close() - - # get_session 应该创建新的 Session - new_session = await manager.get_session() - assert new_session is not session - assert not new_session.closed - - -class TestAioConnectionPoolIntegration: - """测试异步连接池集成""" - - @pytest.fixture(autouse=True) - async def cleanup(self): - """每个测试后清理""" - yield - await AioSessionManager.reset_instance() - - @pytest.mark.asyncio - async def test_default_behavior_unchanged(self): - """测试默认行为不变""" - manager = await AioSessionManager.get_instance() - - # 默认禁用,不影响现有代码 - session = await manager.get_session() - assert session is None - - @pytest.mark.asyncio - async def test_enable_affects_all_requests(self): - """测试启用后影响所有请求""" - manager = await AioSessionManager.get_instance() - await manager.enable() - - # 所有请求应该使用同一个 Session - session1 = await manager.get_session() - session2 = await manager.get_session() - assert session1 is session2 - - @pytest.mark.asyncio - async def test_disable_stops_reuse(self): - """测试禁用后停止复用""" - manager = await AioSessionManager.get_instance() - await manager.enable() - - session_before = await manager.get_session() - assert session_before is not None - - await manager.disable() - - session_after = await manager.get_session() - assert session_after is None - - @pytest.mark.asyncio - async def test_multiple_enable_disable_cycles(self): - """测试多次启用/禁用循环""" - manager = await AioSessionManager.get_instance() - - for _ in range(3): - await manager.enable() - assert manager.is_enabled() - session = await manager.get_session() - assert session is not None - - await manager.disable() - assert not manager.is_enabled() - session = await manager.get_session() - assert session is None - - -class TestAioCustomSession: - """测试自定义异步 Session""" - - @pytest.fixture(autouse=True) - async def cleanup(self): - """每个测试后清理""" - yield - await AioSessionManager.reset_instance() - - @pytest.mark.asyncio - async def test_external_session_priority(self): - """测试外部 Session 优先级最高""" - from dashscope.api_entities.http_request import HttpRequest - - # 创建外部 Session - external_session = aiohttp.ClientSession() - - # 创建 HttpRequest(传入外部 Session) - http_request = HttpRequest( - url="https://example.com", - api_key="test_key", - http_method="POST", - aio_session=external_session, - ) - - # 验证外部 Session 被存储 - assert http_request.get_external_aio_session() is external_session - - await external_session.close() - - @pytest.mark.asyncio - async def test_external_session_overrides_global(self): - """测试外部 Session 覆盖全局连接池""" - from dashscope.api_entities.http_request import HttpRequest - - # 启用全局连接池 - manager = await AioSessionManager.get_instance() - await manager.enable() - - # 创建外部 Session - external_session = aiohttp.ClientSession() - - # 创建 HttpRequest(传入外部 Session) - http_request = HttpRequest( - url="https://example.com", - api_key="test_key", - http_method="POST", - aio_session=external_session, - ) - - # 验证使用外部 Session - assert http_request.get_external_aio_session() is external_session - - await external_session.close() - - -class TestAioConnectionPoolEdgeCases: - """测试异步连接池边界情况""" - - @pytest.fixture(autouse=True) - async def cleanup(self): - """每个测试后清理""" - yield - await AioSessionManager.reset_instance() - - @pytest.mark.asyncio - async def test_configure_partial_params(self): - """测试部分配置参数""" - manager = await AioSessionManager.get_instance() - await manager.enable() - - # 只配置部分参数 - await manager.configure(limit=200) - config = manager.get_config() - assert config.limit == 200 - assert config.limit_per_host == 30 # 保持默认值 - - @pytest.mark.asyncio - async def test_enable_multiple_times(self): - """测试多次启用""" - manager = await AioSessionManager.get_instance() - - await manager.enable() - session1 = await manager.get_session_direct() - - await manager.enable() - session2 = await manager.get_session_direct() - - # 多次启用不会重新创建 Session - assert session1 is session2 - - @pytest.mark.asyncio - async def test_disable_multiple_times(self): - """测试多次禁用""" - manager = await AioSessionManager.get_instance() - await manager.enable() - - await manager.disable() - await manager.disable() # 不应该报错 - - assert not manager.is_enabled() - - @pytest.mark.asyncio - async def test_reset_multiple_times(self): - """测试多次重置""" - manager = await AioSessionManager.get_instance() - await manager.enable() - - await manager.reset() - await manager.reset() # 不应该报错 - - assert manager.is_enabled() - assert await manager.has_active_session() - - @pytest.mark.asyncio - async def test_configure_with_no_params(self): - """测试无参数配置""" - manager = await AioSessionManager.get_instance() - await manager.enable() - - old_config = manager.get_config() - await manager.configure() # 不传参数 - new_config = manager.get_config() - - # 配置应该保持不变 - assert old_config.limit == new_config.limit - assert old_config.limit_per_host == new_config.limit_per_host - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/tests/unit/test_async_custom_session.py b/tests/unit/test_async_custom_session.py new file mode 100644 index 0000000..180b478 --- /dev/null +++ b/tests/unit/test_async_custom_session.py @@ -0,0 +1,659 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. + +""" +异步 HTTP 自定义 Session 功能单元测试 + +测试范围: +1. HttpRequest 接受自定义 aiohttp.ClientSession 参数 +2. 自定义 aio_session 的使用和资源管理 +3. 临时 aio_session 的创建和清理 +4. Session 优先级逻辑 +5. 不同场景下的异步 Session 行为 + +注意:所有测试都不依赖真实的 API Key +""" + +# pylint: disable=protected-access,unused-argument,unused-variable +# pylint: disable=broad-exception-raised + +import ssl +from unittest.mock import patch, AsyncMock + +import aiohttp +import certifi +import pytest + +from dashscope.api_entities.http_request import HttpRequest +from dashscope.api_entities.api_request_data import ApiRequestData +from dashscope.common.constants import ApiProtocol, HTTPMethod + + +class TestAsyncSessionBasics: + """测试异步 Session 基本功能""" + + @pytest.mark.asyncio + async def test_http_request_accepts_aio_session_parameter(self): + """测试 HttpRequest 接受 aio_session 参数""" + connector = aiohttp.TCPConnector() + custom_session = aiohttp.ClientSession(connector=connector) + + try: + http_request = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.POST, + session=custom_session, + ) + + assert http_request._external_aio_session is custom_session + assert http_request._external_aio_session is not None + finally: + await custom_session.close() + + @pytest.mark.asyncio + async def test_http_request_without_aio_session_parameter(self): + """测试 HttpRequest 不传 aio_session 参数""" + http_request = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.POST, + ) + + assert http_request._external_aio_session is None + + @pytest.mark.asyncio + async def test_aio_session_parameter_is_optional(self): + """测试 aio_session 参数是可选的""" + # 不传 aio_session 参数应该正常工作 + http_request = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.POST, + stream=False, + ) + + assert http_request._external_aio_session is None + assert http_request.url == "http://example.com/api" + + +class TestAsyncSessionUsage: + """测试异步 Session 的实际使用""" + + @pytest.mark.asyncio + async def test_custom_aio_session_is_used_for_request(self): + """测试自定义 aio_session 被实际用于请求""" + # 创建 mock session + mock_session = AsyncMock() + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.headers = {"content-type": "application/json"} + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=None) + + # Make request() return an awaitable + async def mock_request(*_args, **_kwargs): + return mock_response + + mock_session.request = mock_request + + # 创建 HttpRequest 并传入自定义 session + http_request = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.POST, + stream=False, + session=mock_session, + ) + + # 添加请求数据 + request_data = ApiRequestData( + model="test-model", + task_group="test", + task="test", + function="test", + input_data={"test": "data"}, + form=None, + is_binary_input=False, + api_protocol=ApiProtocol.HTTPS, + ) + http_request.data = request_data + + # 执行请求 + async def mock_handle_response(_response): + yield mock_response + + with patch.object( + http_request, + "_handle_aio_response", + side_effect=mock_handle_response, + ): + _ = await http_request.aio_call() + + # 验证自定义 session 没有被关闭 + mock_session.close.assert_not_called() + + @pytest.mark.asyncio + async def test_temporary_aio_session_is_created_when_no_custom_session( + self, + ): + """测试没有自定义 aio_session 时会创建临时 aio_session""" + # 创建 mock session + mock_session = AsyncMock() + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.headers = {"content-type": "application/json"} + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=None) + mock_session.request.return_value = mock_response + + # 创建 HttpRequest 不传 aio_session + http_request = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.POST, + stream=False, + ) + + # 添加请求数据 + request_data = ApiRequestData( + model="test-model", + task_group="test", + task="test", + function="test", + input_data={"test": "data"}, + form=None, + is_binary_input=False, + api_protocol=ApiProtocol.HTTPS, + ) + http_request.data = request_data + + # 执行请求 + async def mock_handle_response(_response): + yield mock_response + + with patch("aiohttp.ClientSession", return_value=mock_session): + with patch.object( + http_request, + "_handle_aio_response", + side_effect=mock_handle_response, + ): + _ = await http_request.aio_call() + + # 验证临时 aio_session 被关闭 + mock_session.close.assert_called_once() + + +class TestAsyncSessionResourceManagement: + """测试异步 Session 资源管理""" + + @pytest.mark.asyncio + async def test_custom_aio_session_not_closed_by_http_request(self): + """测试自定义 aio_session 不会被 HttpRequest 关闭""" + custom_session = AsyncMock(spec=aiohttp.ClientSession) + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.headers = {"content-type": "application/json"} + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=None) + + # Make request() return an awaitable + async def mock_request(*_args, **_kwargs): + return mock_response + + custom_session.request = mock_request + + http_request = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.POST, + stream=False, + session=custom_session, + ) + + request_data = ApiRequestData( + model="test-model", + task_group="test", + task="test", + function="test", + input_data={"test": "data"}, + form=None, + is_binary_input=False, + api_protocol=ApiProtocol.HTTPS, + ) + http_request.data = request_data + + async def mock_handle_response(_response): + yield mock_response + + with patch.object( + http_request, + "_handle_aio_response", + side_effect=mock_handle_response, + ): + _ = await http_request.aio_call() + + # 验证自定义 aio_session 没有被关闭 + custom_session.close.assert_not_called() + + @pytest.mark.asyncio + async def test_temporary_aio_session_closed_on_success(self): + """测试临时 aio_session 在成功后被关闭""" + mock_session = AsyncMock() + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.headers = {"content-type": "application/json"} + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=None) + mock_session.request.return_value = mock_response + + http_request = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.POST, + stream=False, + ) + + request_data = ApiRequestData( + model="test-model", + task_group="test", + task="test", + function="test", + input_data={"test": "data"}, + form=None, + is_binary_input=False, + api_protocol=ApiProtocol.HTTPS, + ) + http_request.data = request_data + + async def mock_handle_response(_response): + yield mock_response + + with patch("aiohttp.ClientSession", return_value=mock_session): + with patch.object( + http_request, + "_handle_aio_response", + side_effect=mock_handle_response, + ): + _ = await http_request.aio_call() + + # 验证临时 aio_session 被关闭 + mock_session.close.assert_called_once() + + @pytest.mark.asyncio + async def test_temporary_aio_session_closed_on_exception(self): + """测试临时 aio_session 在异常时也被关闭""" + mock_session = AsyncMock() + + # Make request() raise an exception + async def mock_request(*_args, **_kwargs): + raise Exception("Network error") + + mock_session.request = mock_request + + http_request = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.POST, + stream=False, + ) + + request_data = ApiRequestData( + model="test-model", + task_group="test", + task="test", + function="test", + input_data={"test": "data"}, + form=None, + is_binary_input=False, + api_protocol=ApiProtocol.HTTPS, + ) + http_request.data = request_data + + # 执行请求应该抛出异常 + with patch("aiohttp.ClientSession", return_value=mock_session): + with pytest.raises(Exception, match="Network error"): + _ = await http_request.aio_call() + + # 验证临时 aio_session 仍然被关闭 + mock_session.close.assert_called_once() + + +class TestAsyncSessionWithCustomConfiguration: + """测试自定义配置的异步 Session""" + + @pytest.mark.asyncio + async def test_custom_aio_session_with_connector(self): + """测试带自定义 connector 的 aio_session""" + connector = aiohttp.TCPConnector( + limit=100, + limit_per_host=30, + ssl=ssl.create_default_context(cafile=certifi.where()), + ) + custom_session = aiohttp.ClientSession(connector=connector) + + try: + http_request = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.POST, + session=custom_session, + ) + + assert http_request._external_aio_session is custom_session + # 验证 connector 已配置 + assert custom_session.connector is not None + finally: + await custom_session.close() + + @pytest.mark.asyncio + async def test_custom_aio_session_with_headers(self): + """测试带自定义 headers 的 aio_session""" + custom_headers = { + "User-Agent": "Custom-Agent/1.0", + "X-Custom-Header": "custom-value", + } + custom_session = aiohttp.ClientSession(headers=custom_headers) + + try: + http_request = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.POST, + session=custom_session, + ) + + assert http_request._external_aio_session is custom_session + # 验证 headers 已配置 + assert "User-Agent" in custom_session.headers + finally: + await custom_session.close() + + @pytest.mark.asyncio + async def test_custom_aio_session_with_timeout(self): + """测试带自定义 timeout 的 aio_session""" + timeout = aiohttp.ClientTimeout(total=60) + custom_session = aiohttp.ClientSession(timeout=timeout) + + try: + http_request = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.POST, + session=custom_session, + ) + + assert http_request._external_aio_session is custom_session + # 验证 timeout 已配置 + assert custom_session.timeout is not None + finally: + await custom_session.close() + + +class TestAsyncSessionPriority: + """测试异步 Session 优先级""" + + @pytest.mark.asyncio + async def test_custom_aio_session_has_priority(self): + """测试自定义 aio_session 优先于临时 aio_session""" + connector = aiohttp.TCPConnector() + custom_session = aiohttp.ClientSession(connector=connector) + + try: + http_request = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.POST, + session=custom_session, + ) + + # 验证存储了自定义 aio_session + assert http_request._external_aio_session is custom_session + assert http_request._external_aio_session is not None + finally: + await custom_session.close() + + +class TestAsyncSessionWithDifferentMethods: + """测试不同 HTTP 方法的异步 Session 使用""" + + @pytest.mark.asyncio + async def test_custom_aio_session_with_post_request(self): + """测试 POST 请求使用自定义 session""" + mock_session = AsyncMock() + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=None) + + # Make request() return an awaitable + async def mock_request(*_args, **_kwargs): + return mock_response + + mock_session.request = mock_request + + http_request = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.POST, + stream=False, + session=mock_session, + ) + + request_data = ApiRequestData( + model="test-model", + task_group="test", + task="test", + function="test", + input_data={"test": "data"}, + form=None, + is_binary_input=False, + api_protocol=ApiProtocol.HTTPS, + ) + http_request.data = request_data + + async def mock_handle_response(_response): + yield mock_response + + with patch.object( + http_request, + "_handle_aio_response", + side_effect=mock_handle_response, + ): + _ = await http_request.aio_call() + + # Test passed if no exception + + @pytest.mark.asyncio + async def test_custom_aio_session_with_get_request(self): + """测试 GET 请求使用自定义 session""" + mock_session = AsyncMock() + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=None) + + # Make get() return an awaitable + async def mock_get(*args, **kwargs): + return mock_response + + mock_session.get = mock_get + + http_request = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.GET, + stream=False, + session=mock_session, + ) + + request_data = ApiRequestData( + model="test-model", + task_group="test", + task="test", + function="test", + input_data={"test": "data"}, + form=None, + is_binary_input=False, + api_protocol=ApiProtocol.HTTPS, + ) + http_request.data = request_data + + async def mock_handle_response(_response): + yield mock_response + + with patch.object( + http_request, + "_handle_aio_response", + side_effect=mock_handle_response, + ): + _ = await http_request.aio_call() + + # Test passed if no exception + + +class TestAsyncBackwardCompatibility: + """测试异步向后兼容性""" + + @pytest.mark.asyncio + async def test_works_without_aio_session_parameter(self): + """测试不传 aio_session 参数时保持原有行为""" + http_request = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.POST, + stream=False, + ) + + # 验证不传 aio_session 时,_external_aio_session 为 None + assert http_request._external_aio_session is None + + # 验证其他参数正常 + assert http_request.url == "http://example.com/api" + assert http_request.method == HTTPMethod.POST + + @pytest.mark.asyncio + async def test_default_behavior_unchanged(self): + """测试默认行为未改变""" + mock_session = AsyncMock() + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.headers = {"content-type": "application/json"} + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=None) + mock_session.request.return_value = mock_response + + # 不传 aio_session 参数 + http_request = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.POST, + stream=False, + ) + + request_data = ApiRequestData( + model="test-model", + task_group="test", + task="test", + function="test", + input_data={"test": "data"}, + form=None, + is_binary_input=False, + api_protocol=ApiProtocol.HTTPS, + ) + http_request.data = request_data + + async def mock_handle_response(_response): + yield mock_response + + with patch("aiohttp.ClientSession", return_value=mock_session): + with patch.object( + http_request, + "_handle_aio_response", + side_effect=mock_handle_response, + ): + _ = await http_request.aio_call() + + # 验证临时 aio_session 被关闭(原有行为) + mock_session.close.assert_called_once() + + +class TestAsyncSessionLifecycle: + """测试异步 Session 生命周期""" + + @pytest.mark.asyncio + async def test_multiple_requests_with_same_custom_session(self): + """测试使用同一个自定义 session 进行多次请求""" + mock_session = AsyncMock() + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=None) + + # Make request() return an awaitable + async def mock_request(*_args, **_kwargs): + return mock_response + + mock_session.request = mock_request + + # 第一次请求 + http_request1 = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.POST, + stream=False, + session=mock_session, + ) + + request_data1 = ApiRequestData( + model="test-model", + task_group="test", + task="test", + function="test", + input_data={"test": "data1"}, + form=None, + is_binary_input=False, + api_protocol=ApiProtocol.HTTPS, + ) + http_request1.data = request_data1 + + async def mock_handle_response(_response): + yield mock_response + + with patch.object( + http_request1, + "_handle_aio_response", + side_effect=mock_handle_response, + ): + _ = await http_request1.aio_call() + + # 第二次请求 + http_request2 = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.POST, + stream=False, + session=mock_session, + ) + + request_data2 = ApiRequestData( + model="test-model", + task_group="test", + task="test", + function="test", + input_data={"test": "data2"}, + form=None, + is_binary_input=False, + api_protocol=ApiProtocol.HTTPS, + ) + http_request2.data = request_data2 + + with patch.object( + http_request2, + "_handle_aio_response", + side_effect=mock_handle_response, + ): + _ = await http_request2.aio_call() + + # 验证 session 没有被关闭(因为是外部 session) + mock_session.close.assert_not_called() diff --git a/tests/unit/test_connection_pool.py b/tests/unit/test_connection_pool.py deleted file mode 100644 index 48ec833..0000000 --- a/tests/unit/test_connection_pool.py +++ /dev/null @@ -1,675 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright (c) Alibaba, Inc. and its affiliates. - -""" -HTTP 连接池功能单元测试 - -测试范围: -1. SessionManager 基本功能 -2. ConnectionPoolConfig 配置类 -3. HttpRequest 与 Session 集成 -4. 全局连接池 API -5. 自定义 Session 支持 -6. 线程安全性 -""" - -import threading -import time - -import pytest -import requests -from requests.adapters import HTTPAdapter - -import dashscope -from dashscope.common.session_manager import ( - SessionManager, - ConnectionPoolConfig, -) -from dashscope.api_entities.http_request import HttpRequest -from tests.unit.base_test import BaseTestEnvironment - - -class TestConnectionPoolConfig: - """测试 ConnectionPoolConfig 配置类""" - - def test_default_config(self): - """测试默认配置""" - config = ConnectionPoolConfig() - assert config.pool_connections == 10 - assert config.pool_maxsize == 20 - assert config.max_retries == 3 - assert config.pool_block is False - - def test_custom_config(self): - """测试自定义配置""" - config = ConnectionPoolConfig( - pool_connections=20, - pool_maxsize=50, - max_retries=5, - pool_block=True, - ) - assert config.pool_connections == 20 - assert config.pool_maxsize == 50 - assert config.max_retries == 5 - assert config.pool_block is True - - def test_config_validation(self): - """测试配置验证""" - # 测试负数验证 - with pytest.raises(ValueError, match="pool_connections 必须"): - ConnectionPoolConfig(pool_connections=0) - - with pytest.raises(ValueError, match="pool_maxsize 必须"): - ConnectionPoolConfig(pool_maxsize=0) - - with pytest.raises(ValueError, match="max_retries 必须"): - ConnectionPoolConfig(max_retries=-1) - - # 测试 pool_maxsize >= pool_connections - with pytest.raises( - ValueError, - match="pool_maxsize.*必须.*pool_connections", - ): - ConnectionPoolConfig(pool_connections=30, pool_maxsize=20) - - def test_config_to_dict(self): - """测试配置转换为字典""" - config = ConnectionPoolConfig( - pool_connections=15, - pool_maxsize=30, - max_retries=5, - pool_block=True, - ) - config_dict = config.to_dict() - assert config_dict == { - "pool_connections": 15, - "pool_maxsize": 30, - "max_retries": 5, - "pool_block": True, - } - - def test_config_str(self): - """测试配置字符串表示""" - config = ConnectionPoolConfig() - config_str = str(config) - assert "pool_connections=10" in config_str - assert "pool_maxsize=20" in config_str - assert "max_retries=3" in config_str - assert "pool_block=False" in config_str - - -class TestSessionManager: - """测试 SessionManager 单例类""" - - def setup_method(self): - """每个测试前重置 SessionManager""" - SessionManager.reset_instance() - - def teardown_method(self): - """每个测试后清理""" - manager = SessionManager.get_instance() - manager.reset() - - def test_singleton_pattern(self): - """测试单例模式""" - manager1 = SessionManager.get_instance() - manager2 = SessionManager.get_instance() - assert manager1 is manager2 - - def test_singleton_thread_safe(self): - """测试单例模式的线程安全性""" - instances = [] - - def get_instance(): - instances.append(SessionManager.get_instance()) - - threads = [threading.Thread(target=get_instance) for _ in range(10)] - for t in threads: - t.start() - for t in threads: - t.join() - - # 所有实例应该是同一个 - assert all(inst is instances[0] for inst in instances) - - def test_enable_disable(self): - """测试启用和禁用连接池""" - manager = SessionManager.get_instance() - - # 默认禁用 - assert not manager.is_enabled() - - # 启用 - manager.enable() - assert manager.is_enabled() - - # 禁用 - manager.disable() - assert not manager.is_enabled() - - def test_enable_with_config(self): - """测试使用配置启用连接池""" - manager = SessionManager.get_instance() - - manager.enable( - pool_connections=15, - pool_maxsize=30, - max_retries=5, - pool_block=True, - ) - - assert manager.is_enabled() - config = manager.get_config() - assert config.pool_connections == 15 - assert config.pool_maxsize == 30 - assert config.max_retries == 5 - assert config.pool_block is True - - def test_configure(self): - """测试配置连接池""" - manager = SessionManager.get_instance() - manager.enable() - - # 配置连接池 - manager.configure( - pool_connections=25, - pool_maxsize=50, - ) - - config = manager.get_config() - assert config.pool_connections == 25 - assert config.pool_maxsize == 50 - assert config.max_retries == 3 # 保持默认值 - - def test_get_session_when_disabled(self): - """测试禁用时获取 Session(直接方式)""" - manager = SessionManager.get_instance() - manager.disable() - - session = manager.get_session() - assert session is None - - def test_get_session_when_enabled(self): - """测试启用时获取 Session(直接方式)""" - manager = SessionManager.get_instance() - manager.enable() - - session = manager.get_session() - assert session is not None - assert isinstance(session, requests.Session) - - def test_get_session_returns_same_instance(self): - """测试获取 Session 返回同一实例""" - manager = SessionManager.get_instance() - manager.enable() - - session1 = manager.get_session() - session2 = manager.get_session() - assert session1 is session2 - - def test_get_session(self): - """测试直接获取 Session""" - manager = SessionManager.get_instance() - - # 启用时能获取 - manager.enable() - session = manager.get_session() - assert session is not None - assert isinstance(session, requests.Session) - - # 禁用时返回 None - manager.disable() - session = manager.get_session() - assert session is None - - def test_reset(self): - """测试重置连接池""" - manager = SessionManager.get_instance() - manager.enable() - - old_session = manager.get_session() - assert old_session is not None - - # 禁用后重置 - manager.disable() - manager.reset() - - # Session 应该被清理 - assert not manager.has_active_session() - assert not manager.is_enabled() - - # 重新启用后应该是新的 Session - manager.enable() - new_session = manager.get_session() - assert new_session is not old_session - - def test_session_has_adapter(self): - """测试 Session 配置了 HTTPAdapter""" - manager = SessionManager.get_instance() - manager.enable(pool_connections=15, pool_maxsize=30) - - session = manager.get_session() - assert session is not None - - # 检查是否配置了 HTTPAdapter - http_adapter = session.get_adapter("http://") - https_adapter = session.get_adapter("https://") - - assert isinstance(http_adapter, HTTPAdapter) - assert isinstance(https_adapter, HTTPAdapter) - - def test_thread_safe_session_creation(self): - """测试多线程环境下 Session 创建的线程安全性""" - manager = SessionManager.get_instance() - manager.enable() - - sessions = [] - - def get_session(): - sessions.append(manager.get_session()) - - threads = [threading.Thread(target=get_session) for _ in range(20)] - for t in threads: - t.start() - for t in threads: - t.join() - - # 所有线程应该获取到同一个 Session - assert all(s is sessions[0] for s in sessions) - - -class TestHttpRequestSessionIntegration: - """测试 HttpRequest 与 Session 的集成""" - - def test_http_request_accepts_session(self): - """测试 HttpRequest 接受 session 参数""" - custom_session = requests.Session() - - http_request = HttpRequest( - url="http://example.com/api", - api_key="test-key", - http_method="POST", - session=custom_session, - ) - - assert http_request.get_external_session() is custom_session - - def test_http_request_without_session(self): - """测试 HttpRequest 不传 session 参数""" - http_request = HttpRequest( - url="http://example.com/api", - api_key="test-key", - http_method="POST", - ) - - assert http_request.get_external_session() is None - - def test_http_request_uses_external_session_priority(self): - """测试 HttpRequest 优先使用外部传入的 Session""" - # 创建自定义 Session - custom_session = requests.Session() - custom_session.headers.update({"X-Test": "custom"}) - - # 创建 HttpRequest,传入自定义 Session - http_request = HttpRequest( - url="http://example.com/api", - api_key="test-key", - http_method="POST", - session=custom_session, - ) - - # 验证使用了自定义 Session - assert http_request.get_external_session() is custom_session - assert ( - http_request.get_external_session().headers.get("X-Test") - == "custom" - ) - - def test_http_request_session_priority(self): - """测试 Session 优先级:外部 > 全局 > 临时""" - # 1. 外部 Session 优先级最高 - custom_session = requests.Session() - http_request = HttpRequest( - url="http://example.com/api", - api_key="test-key", - http_method="POST", - session=custom_session, - ) - assert http_request.get_external_session() is custom_session - - # 2. 没有外部 Session 时,应该尝试使用全局 Session - http_request_no_session = HttpRequest( - url="http://example.com/api", - api_key="test-key", - http_method="POST", - ) - assert http_request_no_session.get_external_session() is None - - -class TestGlobalConnectionPoolAPI(BaseTestEnvironment): - """测试全局连接池 API""" - - def setup_method(self): - """每个测试前重置""" - super().setup_class() - SessionManager.reset_instance() - - def teardown_method(self): - """每个测试后清理""" - dashscope.disable_http_connection_pool() - super().teardown_class() - - def test_enable_http_connection_pool(self): - """测试启用 HTTP 连接池""" - dashscope.enable_http_connection_pool() - - manager = SessionManager.get_instance() - assert manager.is_enabled() - - def test_enable_http_connection_pool_with_params(self): - """测试使用参数启用 HTTP 连接池""" - dashscope.enable_http_connection_pool( - pool_connections=15, - pool_maxsize=30, - max_retries=5, - pool_block=True, - ) - - manager = SessionManager.get_instance() - assert manager.is_enabled() - - config = manager.get_config() - assert config.pool_connections == 15 - assert config.pool_maxsize == 30 - assert config.max_retries == 5 - assert config.pool_block is True - - def test_disable_http_connection_pool(self): - """测试禁用 HTTP 连接池""" - dashscope.enable_http_connection_pool() - assert SessionManager.get_instance().is_enabled() - - dashscope.disable_http_connection_pool() - assert not SessionManager.get_instance().is_enabled() - - def test_reset_http_connection_pool(self): - """测试重置 HTTP 连接池""" - dashscope.enable_http_connection_pool() - # 验证 session 存在 - assert SessionManager.get_instance().get_session() is not None - - # 禁用后重置 - dashscope.disable_http_connection_pool() - dashscope.reset_http_connection_pool() - - # Session 应该被清理 - manager = SessionManager.get_instance() - assert not manager.has_active_session() - assert not manager.is_enabled() - - def test_configure_http_connection_pool(self): - """测试配置 HTTP 连接池""" - dashscope.enable_http_connection_pool() - - dashscope.configure_http_connection_pool( - pool_connections=25, - pool_maxsize=50, - ) - - config = SessionManager.get_instance().get_config() - assert config.pool_connections == 25 - assert config.pool_maxsize == 50 - - def test_configure_before_enable(self): - """测试在启用前配置""" - # 先启用 - dashscope.enable_http_connection_pool() - - # 然后配置 - dashscope.configure_http_connection_pool( - pool_connections=20, - pool_maxsize=40, - ) - - manager = SessionManager.get_instance() - assert manager.is_enabled() - - config = manager.get_config() - assert config.pool_connections == 20 - assert config.pool_maxsize == 40 - - -class TestCustomSessionSupport: - """测试自定义 Session 支持""" - - def test_custom_session_with_headers(self): - """测试自定义 Session 带请求头""" - session = requests.Session() - session.headers.update({"X-Custom-Header": "TestValue"}) - - http_request = HttpRequest( - url="http://example.com/api", - api_key="test-key", - http_method="POST", - session=session, - ) - - assert http_request.get_external_session() is session - assert session.headers.get("X-Custom-Header") == "TestValue" - - def test_custom_session_with_proxies(self): - """测试自定义 Session 带代理""" - session = requests.Session() - session.proxies = {"https": "https://proxy.example.com:8080"} - - http_request = HttpRequest( - url="http://example.com/api", - api_key="test-key", - http_method="POST", - session=session, - ) - - assert http_request.get_external_session() is session - assert session.proxies.get("https") == "https://proxy.example.com:8080" - - def test_custom_session_with_adapter(self): - """测试自定义 Session 带自定义 Adapter""" - session = requests.Session() - adapter = HTTPAdapter( - pool_connections=50, - pool_maxsize=100, - ) - session.mount("https://", adapter) - session.mount("http://", adapter) - - http_request = HttpRequest( - url="http://example.com/api", - api_key="test-key", - http_method="POST", - session=session, - ) - - assert http_request.get_external_session() is session - - # 验证 Adapter 配置 - http_adapter = session.get_adapter("http://") - assert isinstance(http_adapter, HTTPAdapter) - - -class TestThreadSafety: - """测试线程安全性""" - - def setup_method(self): - """每个测试前重置""" - SessionManager.reset_instance() - - def teardown_method(self): - """每个测试后清理""" - manager = SessionManager.get_instance() - manager.reset() - - def test_concurrent_enable_disable(self): - """测试并发启用和禁用""" - manager = SessionManager.get_instance() - - def toggle_enable(): - for _ in range(10): - manager.enable() - time.sleep(0.001) - manager.disable() - time.sleep(0.001) - - threads = [threading.Thread(target=toggle_enable) for _ in range(5)] - for t in threads: - t.start() - for t in threads: - t.join() - - # 不应该抛出异常 - assert True - - def test_concurrent_get_session(self): - """测试并发获取 Session""" - manager = SessionManager.get_instance() - manager.enable() - - sessions = [] - - def get_session(): - for _ in range(10): - s = manager.get_session() - sessions.append(s) - time.sleep(0.001) - - threads = [threading.Thread(target=get_session) for _ in range(10)] - for t in threads: - t.start() - for t in threads: - t.join() - - # 所有获取的 Session 应该是同一个 - assert all(s is sessions[0] for s in sessions) - - def test_concurrent_configure(self): - """测试并发配置""" - manager = SessionManager.get_instance() - manager.enable() - - def configure(): - for i in range(5): - manager.configure( - pool_connections=10 + i, - pool_maxsize=20 + i * 2, - ) - time.sleep(0.001) - - threads = [threading.Thread(target=configure) for _ in range(5)] - for thread in threads: - thread.start() - for thread in threads: - thread.join() - - # 不应该抛出异常,最终配置应该是有效的 - config = manager.get_config() - assert config.pool_connections > 0 - assert config.pool_maxsize >= config.pool_connections - - -class TestEdgeCases: - """测试边界情况""" - - def setup_method(self): - """每个测试前重置""" - SessionManager.reset_instance() - - def teardown_method(self): - """每个测试后清理""" - manager = SessionManager.get_instance() - manager.reset() - - def test_enable_multiple_times(self): - """测试多次启用""" - manager = SessionManager.get_instance() - - manager.enable() - session1 = manager.get_session() - - manager.enable() - session2 = manager.get_session() - - # 应该返回同一个 Session - assert session1 is session2 - - def test_configure_with_partial_params(self): - """测试部分参数配置""" - manager = SessionManager.get_instance() - manager.enable() - - # 只配置部分参数 - manager.configure(pool_connections=15) - - config = manager.get_config() - assert config.pool_connections == 15 - assert config.pool_maxsize == 20 # 保持默认值 - assert config.max_retries == 3 # 保持默认值 - - def test_reset_when_disabled(self): - """测试禁用状态下重置""" - manager = SessionManager.get_instance() - manager.disable() - - # 不应该抛出异常 - manager.reset() - assert not manager.is_enabled() - - def test_get_session_after_reset(self): - """测试重置后获取 Session""" - manager = SessionManager.get_instance() - manager.enable() - - old_session = manager.get_session() - - # 禁用后重置 - manager.disable() - manager.reset() - - # 重置后应该返回 None - assert manager.get_session() is None - - # 重新启用后应该是新的 Session - manager.enable() - new_session = manager.get_session() - assert new_session is not None - assert new_session is not old_session - - -class TestBackwardCompatibility: - """测试向后兼容性""" - - def test_http_request_without_session_param(self): - """测试不传 session 参数的 HttpRequest(向后兼容)""" - # 不传 session 参数应该正常工作 - http_request = HttpRequest( - url="http://example.com/api", - api_key="test-key", - http_method="POST", - ) - - assert http_request.get_external_session() is None - - def test_default_behavior_unchanged(self): - """测试默认行为未改变(需要在干净环境中测试)""" - # 重置到初始状态 - manager = SessionManager.get_instance() - manager.disable() - manager.reset() - - # 默认应该是禁用状态 - assert not manager.is_enabled() - - # 默认获取 Session 应该返回 None - assert manager.get_session() is None - - -if __name__ == "__main__": - pytest.main([__file__, "-v", "-s"]) diff --git a/tests/unit/test_sync_custom_session.py b/tests/unit/test_sync_custom_session.py new file mode 100644 index 0000000..7b493f3 --- /dev/null +++ b/tests/unit/test_sync_custom_session.py @@ -0,0 +1,515 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. + +""" +同步 HTTP 自定义 Session 功能单元测试 + +测试范围: +1. HttpRequest 接受自定义 Session 参数 +2. 自定义 Session 的使用和资源管理 +3. 临时 Session 的创建和清理 +4. Session 优先级逻辑 +5. 不同场景下的 Session 行为 + +注意:所有测试都不依赖真实的 API Key +""" + +# pylint: disable=protected-access,unused-argument,unused-variable +# pylint: disable=broad-exception-raised + +from unittest.mock import Mock, patch + +import pytest +import requests +from requests.adapters import HTTPAdapter + +from dashscope.api_entities.http_request import HttpRequest +from dashscope.api_entities.api_request_data import ApiRequestData +from dashscope.common.constants import ApiProtocol, HTTPMethod + + +class TestSyncSessionBasics: + """测试同步 Session 基本功能""" + + def test_http_request_accepts_session_parameter(self): + """测试 HttpRequest 接受 session 参数""" + custom_session = requests.Session() + + http_request = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.POST, + session=custom_session, + ) + + assert http_request._external_session is custom_session + assert http_request._external_session is not None + + def test_http_request_without_session_parameter(self): + """测试 HttpRequest 不传 session 参数""" + http_request = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.POST, + ) + + assert http_request._external_session is None + + def test_session_parameter_is_optional(self): + """测试 session 参数是可选的""" + # 不传 session 参数应该正常工作 + http_request = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.POST, + stream=False, + ) + + assert http_request._external_session is None + assert http_request.url == "http://example.com/api" + + +class TestSyncSessionUsage: + """测试同步 Session 的实际使用""" + + @patch("requests.Session") + def test_custom_session_is_used_for_request(self, _mock_session_class): + """测试自定义 session 被实际用于请求""" + # 创建 mock session + mock_session = Mock() + mock_response = Mock() + mock_response.status_code = 200 + mock_response.headers = {"content-type": "application/json"} + mock_response.text = '{"status": "success"}' + mock_session.post.return_value = mock_response + + # 创建 HttpRequest 并传入自定义 session + http_request = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.POST, + stream=False, + session=mock_session, + ) + + # 添加请求数据 + request_data = ApiRequestData( + model="test-model", + task_group="test", + task="test", + function="test", + input_data={"test": "data"}, + form=None, + is_binary_input=False, + api_protocol=ApiProtocol.HTTPS, + ) + http_request.data = request_data + + # 执行请求 + with patch.object( + http_request, + "_handle_response", + return_value=iter([mock_response]), + ): + _ = http_request.call() + + # 验证自定义 session 被使用 + mock_session.post.assert_called_once() + + # 验证自定义 session 没有被关闭 + mock_session.close.assert_not_called() + + @patch("requests.Session") + def test_temporary_session_is_created_when_no_custom_session( + self, + mock_session_class, + ): + """测试没有自定义 session 时会创建临时 session""" + # 创建 mock session + mock_session = Mock() + mock_response = Mock() + mock_response.status_code = 200 + mock_response.headers = {"content-type": "application/json"} + mock_response.text = '{"status": "success"}' + mock_session.post.return_value = mock_response + mock_session_class.return_value = mock_session + + # 创建 HttpRequest 不传 session + http_request = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.POST, + stream=False, + ) + + # 添加请求数据 + request_data = ApiRequestData( + model="test-model", + task_group="test", + task="test", + function="test", + input_data={"test": "data"}, + form=None, + is_binary_input=False, + api_protocol=ApiProtocol.HTTPS, + ) + http_request.data = request_data + + # 执行请求 + with patch.object( + http_request, + "_handle_response", + return_value=iter([mock_response]), + ): + _ = http_request.call() + + # 验证临时 session 被创建 + mock_session_class.assert_called_once() + + # 验证临时 session 被关闭 + mock_session.close.assert_called_once() + + +class TestSyncSessionResourceManagement: + """测试同步 Session 资源管理""" + + def test_custom_session_not_closed_by_http_request(self): + """测试自定义 session 不会被 HttpRequest 关闭""" + custom_session = Mock(spec=requests.Session) + mock_response = Mock() + mock_response.status_code = 200 + mock_response.headers = {"content-type": "application/json"} + mock_response.text = '{"status": "success"}' + custom_session.post.return_value = mock_response + + http_request = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.POST, + stream=False, + session=custom_session, + ) + + request_data = ApiRequestData( + model="test-model", + task_group="test", + task="test", + function="test", + input_data={"test": "data"}, + form=None, + is_binary_input=False, + api_protocol=ApiProtocol.HTTPS, + ) + http_request.data = request_data + + with patch.object( + http_request, + "_handle_response", + return_value=iter([mock_response]), + ): + _ = http_request.call() + + # 验证自定义 session 没有被关闭 + custom_session.close.assert_not_called() + + @patch("requests.Session") + def test_temporary_session_closed_on_success(self, mock_session_class): + """测试临时 session 在成功后被关闭""" + mock_session = Mock() + mock_response = Mock() + mock_response.status_code = 200 + mock_response.headers = {"content-type": "application/json"} + mock_response.text = '{"status": "success"}' + mock_session.post.return_value = mock_response + mock_session_class.return_value = mock_session + + http_request = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.POST, + stream=False, + ) + + request_data = ApiRequestData( + model="test-model", + task_group="test", + task="test", + function="test", + input_data={"test": "data"}, + form=None, + is_binary_input=False, + api_protocol=ApiProtocol.HTTPS, + ) + http_request.data = request_data + + with patch.object( + http_request, + "_handle_response", + return_value=iter([mock_response]), + ): + _ = http_request.call() + + # 验证临时 session 被关闭 + mock_session.close.assert_called_once() + + @patch("requests.Session") + def test_temporary_session_closed_on_exception(self, mock_session_class): + """测试临时 session 在异常时也被关闭""" + mock_session = Mock() + mock_session.post.side_effect = Exception("Network error") + mock_session_class.return_value = mock_session + + http_request = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.POST, + stream=False, + ) + + request_data = ApiRequestData( + model="test-model", + task_group="test", + task="test", + function="test", + input_data={"test": "data"}, + form=None, + is_binary_input=False, + api_protocol=ApiProtocol.HTTPS, + ) + http_request.data = request_data + + # 执行请求应该抛出异常 + with pytest.raises(Exception, match="Network error"): + _ = http_request.call() + + # 验证临时 session 仍然被关闭 + mock_session.close.assert_called_once() + + +class TestSyncSessionWithCustomConfiguration: + """测试自定义配置的 Session""" + + def test_custom_session_with_connection_pool(self): + """测试带连接池配置的自定义 session""" + custom_session = requests.Session() + adapter = HTTPAdapter( + pool_connections=10, + pool_maxsize=20, + max_retries=3, + ) + custom_session.mount("http://", adapter) + custom_session.mount("https://", adapter) + + http_request = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.POST, + session=custom_session, + ) + + assert http_request._external_session is custom_session + # 验证 adapter 已配置 + assert "http://" in custom_session.adapters + assert "https://" in custom_session.adapters + + def test_custom_session_with_headers(self): + """测试带自定义 headers 的 session""" + custom_session = requests.Session() + custom_session.headers.update( + { + "User-Agent": "Custom-Agent/1.0", + "X-Custom-Header": "custom-value", + }, + ) + + http_request = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.POST, + session=custom_session, + ) + + assert http_request._external_session is custom_session + assert "User-Agent" in custom_session.headers + assert custom_session.headers["User-Agent"] == "Custom-Agent/1.0" + + def test_custom_session_with_proxies(self): + """测试带代理配置的 session""" + custom_session = requests.Session() + custom_session.proxies = { + "http": "http://proxy.example.com:8080", + "https": "https://proxy.example.com:8080", + } + + http_request = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.POST, + session=custom_session, + ) + + assert http_request._external_session is custom_session + assert ( + custom_session.proxies["http"] == "http://proxy.example.com:8080" + ) + + +class TestSyncSessionPriority: + """测试 Session 优先级""" + + def test_custom_session_has_priority(self): + """测试自定义 session 优先于临时 session""" + custom_session = requests.Session() + + http_request = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.POST, + session=custom_session, + ) + + # 验证存储了自定义 session + assert http_request._external_session is custom_session + assert http_request._external_session is not None + + +class TestSyncSessionWithDifferentMethods: + """测试不同 HTTP 方法的 Session 使用""" + + @patch("requests.Session") + def test_custom_session_with_post_request(self, _mock_session_class): + """测试 POST 请求使用自定义 session""" + mock_session = Mock() + mock_response = Mock() + mock_response.status_code = 200 + mock_session.post.return_value = mock_response + + http_request = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.POST, + stream=False, + session=mock_session, + ) + + request_data = ApiRequestData( + model="test-model", + task_group="test", + task="test", + function="test", + input_data={"test": "data"}, + form=None, + is_binary_input=False, + api_protocol=ApiProtocol.HTTPS, + ) + http_request.data = request_data + + with patch.object( + http_request, + "_handle_response", + return_value=iter([mock_response]), + ): + _ = http_request.call() + + # 验证使用了 POST 方法 + mock_session.post.assert_called_once() + + @patch("requests.Session") + def test_custom_session_with_get_request(self, _mock_session_class): + """测试 GET 请求使用自定义 session""" + mock_session = Mock() + mock_response = Mock() + mock_response.status_code = 200 + mock_session.get.return_value = mock_response + + http_request = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.GET, + stream=False, + session=mock_session, + ) + + request_data = ApiRequestData( + model="test-model", + task_group="test", + task="test", + function="test", + input_data={"test": "data"}, + form=None, + is_binary_input=False, + api_protocol=ApiProtocol.HTTPS, + ) + http_request.data = request_data + + with patch.object( + http_request, + "_handle_response", + return_value=iter([mock_response]), + ): + _ = http_request.call() + + # 验证使用了 GET 方法 + mock_session.get.assert_called_once() + + +class TestSyncBackwardCompatibility: + """测试向后兼容性""" + + def test_works_without_session_parameter(self): + """测试不传 session 参数时保持原有行为""" + http_request = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.POST, + stream=False, + ) + + # 验证不传 session 时,_external_session 为 None + assert http_request._external_session is None + + # 验证其他参数正常 + assert http_request.url == "http://example.com/api" + assert http_request.method == HTTPMethod.POST + + @patch("requests.Session") + def test_default_behavior_unchanged(self, mock_session_class): + """测试默认行为未改变""" + mock_session = Mock() + mock_response = Mock() + mock_response.status_code = 200 + mock_response.headers = {"content-type": "application/json"} + mock_response.text = '{"status": "success"}' + mock_session.post.return_value = mock_response + mock_session_class.return_value = mock_session + + # 不传 session 参数 + http_request = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.POST, + stream=False, + ) + + request_data = ApiRequestData( + model="test-model", + task_group="test", + task="test", + function="test", + input_data={"test": "data"}, + form=None, + is_binary_input=False, + api_protocol=ApiProtocol.HTTPS, + ) + http_request.data = request_data + + with patch.object( + http_request, + "_handle_response", + return_value=iter([mock_response]), + ): + _ = http_request.call() + + # 验证临时 session 被创建和关闭(原有行为) + mock_session_class.assert_called_once() + mock_session.close.assert_called_once()