Skip to content

Commit e6624cc

Browse files
committed
feat(http): reuse http session
1 parent dbca9d3 commit e6624cc

File tree

6 files changed

+1443
-14
lines changed

6 files changed

+1443
-14
lines changed

dashscope/api_entities/api_request_factory.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def _get_protocol_params(kwargs):
3636
base_address = kwargs.pop("base_address", None)
3737
flattened_output = kwargs.pop("flattened_output", False)
3838
extra_url_parameters = kwargs.pop("extra_url_parameters", None)
39+
session = kwargs.pop("session", None)
3940

4041
# Extract user-agent from headers if present
4142
user_agent = ""
@@ -58,6 +59,7 @@ def _get_protocol_params(kwargs):
5859
flattened_output,
5960
extra_url_parameters,
6061
user_agent,
62+
session,
6163
)
6264

6365

@@ -87,6 +89,7 @@ def _build_api_request( # pylint: disable=too-many-branches
8789
flattened_output,
8890
extra_url_parameters,
8991
user_agent,
92+
session,
9093
) = _get_protocol_params(kwargs)
9194
task_id = kwargs.pop("task_id", None)
9295
enable_encryption = kwargs.pop("enable_encryption", False)
@@ -130,6 +133,7 @@ def _build_api_request( # pylint: disable=too-many-branches
130133
flattened_output=flattened_output,
131134
encryption=encryption,
132135
user_agent=user_agent,
136+
session=session,
133137
)
134138
elif api_protocol == ApiProtocol.WEBSOCKET:
135139
if base_address is not None:

dashscope/api_entities/http_request.py

Lines changed: 67 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import json
55
import ssl
66
from http import HTTPStatus
7-
from typing import Optional
7+
from typing import Optional, Dict, Union
88

99
import aiohttp
1010
import certifi
@@ -42,6 +42,9 @@ def __init__(
4242
flattened_output: bool = False,
4343
encryption: Optional[Encryption] = None,
4444
user_agent: str = "",
45+
session: Optional[
46+
Union[requests.Session, aiohttp.ClientSession]
47+
] = None,
4548
) -> None:
4649
"""HttpSSERequest, processing http server sent event stream.
4750
@@ -54,14 +57,38 @@ def __init__(
5457
Defaults to DEFAULT_REQUEST_TIMEOUT_SECONDS.
5558
user_agent (str, optional): Additional user agent string to
5659
append. Defaults to ''.
60+
session (Optional[Union[requests.Session,
61+
aiohttp.ClientSession]], optional):
62+
Custom Session for connection reuse. Can be either
63+
requests.Session for sync calls or aiohttp.ClientSession
64+
for async calls. Defaults to None.
5765
"""
5866

5967
super().__init__(user_agent=user_agent)
6068
self.url = url
6169
self.flattened_output = flattened_output
6270
self.async_request = async_request
6371
self.encryption = encryption
64-
self.headers = {
72+
73+
# Auto-detect session type and store accordingly
74+
if session is not None:
75+
session_type = type(session).__name__
76+
session_module = type(session).__module__
77+
78+
# Check if it's an aiohttp ClientSession
79+
if (
80+
session_type == "ClientSession" and "aiohttp" in session_module
81+
) or isinstance(session, aiohttp.ClientSession):
82+
self._external_session = None
83+
self._external_aio_session = session
84+
else:
85+
# Treat as requests Session
86+
self._external_session = session
87+
self._external_aio_session = None
88+
else:
89+
self._external_session = None
90+
self._external_aio_session = None
91+
self.headers: Dict = {
6592
"Accept": "application/json",
6693
"Authorization": f"Bearer {api_key}",
6794
**self.headers,
@@ -132,18 +159,27 @@ async def aio_call(self):
132159
pass
133160
return result
134161

135-
async def _handle_aio_request(self):
162+
async def _handle_aio_request(self): # pylint: disable=too-many-branches
136163
try:
137-
connector = aiohttp.TCPConnector(
138-
ssl=ssl.create_default_context(
139-
cafile=certifi.where(),
140-
),
141-
)
142-
async with aiohttp.ClientSession(
143-
connector=connector,
144-
timeout=aiohttp.ClientTimeout(total=self.timeout),
145-
headers=self.headers,
146-
) as session:
164+
# Use external aio_session if provided,
165+
# otherwise create temporary session
166+
if self._external_aio_session is not None:
167+
session = self._external_aio_session
168+
should_close = False
169+
else:
170+
connector = aiohttp.TCPConnector(
171+
ssl=ssl.create_default_context(
172+
cafile=certifi.where(),
173+
),
174+
)
175+
session = aiohttp.ClientSession(
176+
connector=connector,
177+
timeout=aiohttp.ClientTimeout(total=self.timeout),
178+
headers=self.headers,
179+
)
180+
should_close = True
181+
182+
try:
147183
logger.debug("Starting request: %s", self.url)
148184
if self.method == HTTPMethod.POST:
149185
is_form, obj = False, {}
@@ -183,6 +219,10 @@ async def _handle_aio_request(self):
183219
async with response:
184220
async for rsp in self._handle_aio_response(response):
185221
yield rsp
222+
finally:
223+
# Only close if we created the session
224+
if should_close:
225+
await session.close()
186226
except aiohttp.ClientConnectorError as e:
187227
logger.error(e)
188228
raise e
@@ -408,7 +448,16 @@ def _handle_response( # pylint: disable=too-many-branches
408448

409449
def _handle_request(self):
410450
try:
411-
with requests.Session() as session:
451+
# Use external session if provided,
452+
# otherwise create temporary session
453+
if self._external_session is not None:
454+
session = self._external_session
455+
should_close = False
456+
else:
457+
session = requests.Session()
458+
should_close = True
459+
460+
try:
412461
if self.method == HTTPMethod.POST:
413462
is_form, form, obj = self.data.get_http_payload()
414463
if is_form:
@@ -443,6 +492,10 @@ def _handle_request(self):
443492
)
444493
for rsp in self._handle_response(response):
445494
yield rsp
495+
finally:
496+
# Only close if we created the session
497+
if should_close:
498+
session.close()
446499
except BaseException as e:
447500
logger.error(e)
448501
raise e

samples/test_aio_generation.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
import asyncio
55
import os
6+
import aiohttp
7+
import ssl
8+
import certifi
69
from dashscope.aigc.generation import AioGeneration
710

811

@@ -439,6 +442,118 @@ async def process_responses(responses, step_name):
439442
await call_deep_research_model(messages, "第二步:深入研究")
440443
print("\n 研究完成!")
441444

445+
@staticmethod
446+
async def test_with_custom_session():
447+
"""示例:使用自定义 ClientSession 进行连接复用"""
448+
print("\n=== 使用自定义 ClientSession 示例 ===")
449+
450+
# 配置 TCPConnector
451+
connector = aiohttp.TCPConnector(
452+
limit=100,
453+
limit_per_host=30,
454+
ssl=ssl.create_default_context(cafile=certifi.where()),
455+
)
456+
457+
# 创建自定义 ClientSession
458+
async with aiohttp.ClientSession(connector=connector) as session:
459+
# 使用同一个 session 进行多次请求
460+
for i in range(3):
461+
print(f"\n--- 请求 {i+1} ---")
462+
463+
messages = [
464+
{"role": "system", "content": "You are a helpful assistant."},
465+
{"role": "user", "content": f"你好"},
466+
]
467+
468+
response = await AioGeneration.call(
469+
api_key=os.getenv("DASHSCOPE_API_KEY"),
470+
model="qwen-turbo",
471+
messages=messages,
472+
result_format="message",
473+
session=session, # ← 传入自定义 session
474+
)
475+
476+
print(f"响应: {response.output.choices[0].message.content}")
477+
478+
print("\n✅ ClientSession 已自动关闭")
479+
480+
@staticmethod
481+
async def test_with_custom_session_streaming():
482+
"""示例:使用自定义 ClientSession 进行流式输出"""
483+
print("\n=== 使用自定义 ClientSession 流式输出示例 ===")
484+
485+
# 配置连接池
486+
connector = aiohttp.TCPConnector(
487+
limit=100,
488+
ssl=ssl.create_default_context(cafile=certifi.where()),
489+
)
490+
491+
async with aiohttp.ClientSession(connector=connector) as session:
492+
messages = [
493+
{"role": "system", "content": "You are a helpful assistant."},
494+
{"role": "user", "content": "请写一首关于秋天的诗"},
495+
]
496+
497+
print("\n流式输出:")
498+
response = await AioGeneration.call(
499+
api_key=os.getenv("DASHSCOPE_API_KEY"),
500+
model="qwen-turbo",
501+
messages=messages,
502+
result_format="message",
503+
stream=True,
504+
incremental_output=True,
505+
session=session, # ← 传入自定义 session
506+
)
507+
508+
async for chunk in response:
509+
if chunk.status_code == 200:
510+
print(chunk.output.choices[0].message.content, end='', flush=True)
511+
512+
print("\n")
513+
514+
print("✅ ClientSession 已自动关闭")
515+
516+
@staticmethod
517+
async def test_with_custom_session_concurrent():
518+
"""示例:使用自定义 ClientSession 进行并发请求"""
519+
print("\n=== 使用自定义 ClientSession 并发请求示例 ===")
520+
521+
# 配置连接池
522+
connector = aiohttp.TCPConnector(
523+
limit=100,
524+
ssl=ssl.create_default_context(cafile=certifi.where()),
525+
)
526+
527+
async with aiohttp.ClientSession(connector=connector) as session:
528+
# 创建多个并发任务
529+
tasks = []
530+
topics = ["Python", "JavaScript", "Go"]
531+
532+
for topic in topics:
533+
messages = [
534+
{"role": "system", "content": "You are a helpful assistant."},
535+
{"role": "user", "content": f"请用一句话介绍:{topic}"},
536+
]
537+
538+
task = AioGeneration.call(
539+
api_key=os.getenv("DASHSCOPE_API_KEY"),
540+
model="qwen-turbo",
541+
messages=messages,
542+
result_format="message",
543+
session=session, # ← 所有请求共享同一个 session
544+
)
545+
tasks.append(task)
546+
547+
# 并发执行所有任务
548+
print("\n开始并发请求...")
549+
responses = await asyncio.gather(*tasks)
550+
551+
# 处理响应
552+
for i, response in enumerate(responses):
553+
print(f"\n请求 {i+1} 响应: {response.output.choices[0].message.content}")
554+
555+
print("\n✅ ClientSession 已自动关闭")
556+
442557

443558
async def main():
444559
"""Main function to run all async tests."""
@@ -450,6 +565,11 @@ async def main():
450565
# await TestAioGeneration.test_response_with_search_info()
451566
# await TestAioGeneration.test_response_with_reasoning_content()
452567

568+
# 自定义 Session 示例
569+
# await TestAioGeneration.test_with_custom_session()
570+
# await TestAioGeneration.test_with_custom_session_streaming()
571+
# await TestAioGeneration.test_with_custom_session_concurrent()
572+
453573
print("\n所有异步测试用例执行完成!")
454574

455575

0 commit comments

Comments
 (0)