Skip to content

Commit dd55eee

Browse files
authored
feat(ci): add unit tests (#100)
1 parent 27d2d97 commit dd55eee

29 files changed

+539
-284
lines changed

tests/legacy/test_websocket_async_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from dashscope.protocol.websocket import WebsocketStreamingMode
88
from tests.unit.base_test import BaseTestEnvironment
99
from tests.unit.constants import TestTasks
10-
from tests.legacy.websocket_task_request import WebSocketRequest
10+
from tests.unit.websocket_task_request import WebSocketRequest
1111

1212
# set mock server url.
1313
base_websocket_api_url = "ws://localhost:8080/ws/aigc/v1"

tests/legacy/test_websocket_parameters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
TEST_ENABLE_DATA_INSPECTION_REQUEST_ID,
1010
TestTasks,
1111
)
12-
from tests.legacy.websocket_task_request import WebSocketRequest
12+
from tests.unit.websocket_task_request import WebSocketRequest
1313

1414

1515
def pytest_generate_tests(metafunc):

tests/legacy/test_websocket_sync_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from dashscope.protocol.websocket import WebsocketStreamingMode
66
from tests.unit.base_test import BaseTestEnvironment
77
from tests.unit.constants import TestTasks
8-
from tests.legacy.websocket_task_request import WebSocketRequest
8+
from tests.unit.websocket_task_request import WebSocketRequest
99

1010

1111
def pytest_generate_tests(metafunc):

tests/unit/mock_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
list_fine_tune_handler,
3434
)
3535
from tests.unit.mock_sse import sse_response
36-
from tests.legacy.websocket_mock_server_task_handler import (
36+
from tests.unit.websocket_mock_server_task_handler import (
3737
WebSocketTaskProcessor,
3838
)
3939

Lines changed: 54 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -45,25 +45,29 @@ def test_rag_call(self, mock_server: MockServer):
4545
"action_type": "api",
4646
"action_name": "文档检索",
4747
"action": "searchDocument",
48-
"action_input_stream": '{"query":"API接口说明中, TopP参数改如何传递?"}',
48+
"action_input_stream": (
49+
'{"query":"API接口说明中, ' 'TopP参数改如何传递?"}'
50+
),
4951
"action_input": {
50-
"query": "API接口说明中, TopP参数改如何传递?",
52+
"query": ("API接口说明中, TopP参数改如何传递?"),
5153
},
52-
"observation": """{"data": [
53-
{
54-
"docId": "1234",
55-
"docName": "API接口说明",
56-
"docUrl": "https://127.0.0.1/dl/API接口说明.pdf",
57-
"indexId": "1",
58-
"score": 0.11992252,
59-
"text": "填(0,1.0),取值越大,生成的随机性越高;启用文档检索后,文档引用类型,取值包括:simple|indexed。",
60-
"title": "API接口说明",
61-
"titlePath": "API接口说明>>>接口说明>>>是否必 说明>>>填"
62-
}
63-
],
64-
"status": "SUCCESS"
65-
}""",
66-
"response": "API接口说明中, TopP参数是一个float类型的参数,取值范围为0到1.0,默认为1.0。取值越大,生成的随机性越高。[5]",
54+
"observation": (
55+
'{"data": [{"docId": "1234", '
56+
'"docName": "API接口说明", '
57+
'"docUrl": "https://127.0.0.1/dl/'
58+
'API接口说明.pdf", "indexId": "1", '
59+
'"score": 0.11992252, "text": "填(0,1.0),'
60+
"取值越大,生成的随机性越高;启用文档检索后,"
61+
'文档引用类型,取值包括:simple|indexed。", '
62+
'"title": "API接口说明", "titlePath": '
63+
'"API接口说明>>>接口说明>>>是否必 说明>>>填"}], '
64+
'"status": "SUCCESS"}'
65+
),
66+
"response": (
67+
"API接口说明中, TopP参数是一个float类型的"
68+
"参数,取值范围为0到1.0,默认为1.0。取值越大,"
69+
"生成的随机性越高。[5]"
70+
),
6771
},
6872
],
6973
},
@@ -86,11 +90,12 @@ def test_rag_call(self, mock_server: MockServer):
8690
top_p=0.2,
8791
temperature=1.0,
8892
doc_tag_codes=["t1234", "t2345"],
89-
doc_reference_type=Application.DocReferenceType.simple,
93+
doc_reference_type=(Application.DocReferenceType.simple),
9094
has_thoughts=True,
9195
)
9296

93-
self.check_result(resp, test_response)
97+
# Test mock response type
98+
self.check_result(resp, test_response) # type: ignore[arg-type]
9499

95100
def test_flow_call(self, mock_server: MockServer):
96101
test_response = {
@@ -105,13 +110,19 @@ def test_flow_call(self, mock_server: MockServer):
105110
"action_type": "api",
106111
"action_name": "plugin",
107112
"action": "api",
108-
"action_input_stream": '{"userId": "123", "date": "202402", "city": "hangzhou"}',
113+
"action_input_stream": (
114+
'{"userId": "123", "date": "202402", '
115+
'"city": "hangzhou"}'
116+
),
109117
"action_input": {
110118
"userId": "123",
111119
"date": "202402",
112120
"city": "hangzhou",
113121
},
114-
"observation": """{"quantity": 102, "type": "resident", "date": "202402", "unit": "千瓦"}""",
122+
"observation": (
123+
'{"quantity": 102, "type": "resident", '
124+
'"date": "202402", "unit": "千瓦"}'
125+
),
115126
"response": "当月的居民用电量为102千瓦。",
116127
},
117128
],
@@ -140,7 +151,8 @@ def test_flow_call(self, mock_server: MockServer):
140151
has_thoughts=True,
141152
)
142153

143-
self.check_result(resp, test_response)
154+
# Test mock response type
155+
self.check_result(resp, test_response) # type: ignore[arg-type]
144156

145157
def test_call_with_error(self, mock_server: MockServer):
146158
test_response = {
@@ -212,21 +224,21 @@ def check_result(resp: ApplicationResponse, test_response: Dict):
212224
expected_doc_refs,
213225
)
214226

215-
for i in range(len(doc_refs)):
216-
assert doc_refs[i].index_id == expected_doc_refs[i].get(
227+
for i, doc_ref in enumerate(doc_refs):
228+
assert doc_ref.index_id == expected_doc_refs[i].get(
217229
"index_id",
218230
)
219-
assert doc_refs[i].doc_id == expected_doc_refs[i].get("doc_id")
220-
assert doc_refs[i].doc_name == expected_doc_refs[i].get(
231+
assert doc_ref.doc_id == expected_doc_refs[i].get("doc_id")
232+
assert doc_ref.doc_name == expected_doc_refs[i].get(
221233
"doc_name",
222234
)
223-
assert doc_refs[i].doc_url == expected_doc_refs[i].get(
235+
assert doc_ref.doc_url == expected_doc_refs[i].get(
224236
"doc_url",
225237
)
226-
assert doc_refs[i].title == expected_doc_refs[i].get("title")
227-
assert doc_refs[i].text == expected_doc_refs[i].get("text")
228-
assert doc_refs[i].biz_id == expected_doc_refs[i].get("biz_id")
229-
assert json.dumps(doc_refs[i].images) == json.dumps(
238+
assert doc_ref.title == expected_doc_refs[i].get("title")
239+
assert doc_ref.text == expected_doc_refs[i].get("text")
240+
assert doc_ref.biz_id == expected_doc_refs[i].get("biz_id")
241+
assert json.dumps(doc_ref.images) == json.dumps(
230242
expected_doc_refs[i].get("images"),
231243
)
232244

@@ -238,26 +250,26 @@ def check_result(resp: ApplicationResponse, test_response: Dict):
238250
expected_thoughts,
239251
)
240252

241-
for i in range(len(thoughts)):
242-
assert thoughts[i].thought == expected_thoughts[i].get(
253+
for i, thought in enumerate(thoughts):
254+
assert thought.thought == expected_thoughts[i].get(
243255
"thought",
244256
)
245-
assert thoughts[i].action == expected_thoughts[i].get("action")
246-
assert thoughts[i].action_name == expected_thoughts[i].get(
257+
assert thought.action == expected_thoughts[i].get("action")
258+
assert thought.action_name == expected_thoughts[i].get(
247259
"action_name",
248260
)
249-
assert thoughts[i].action_type == expected_thoughts[i].get(
261+
assert thought.action_type == expected_thoughts[i].get(
250262
"action_type",
251263
)
252-
assert json.dumps(thoughts[i].action_input) == json.dumps(
264+
assert json.dumps(thought.action_input) == json.dumps(
253265
expected_thoughts[i].get("action_input"),
254266
)
255-
assert thoughts[i].action_input_stream == expected_thoughts[
256-
i
257-
].get("action_input_stream")
258-
assert thoughts[i].observation == expected_thoughts[i].get(
267+
assert thought.action_input_stream == (
268+
expected_thoughts[i].get("action_input_stream")
269+
)
270+
assert thought.observation == expected_thoughts[i].get(
259271
"observation",
260272
)
261-
assert thoughts[i].response == expected_thoughts[i].get(
273+
assert thought.response == expected_thoughts[i].get(
262274
"response",
263275
)

0 commit comments

Comments
 (0)