Skip to content

Commit 538813d

Browse files
committed
feat(ci): add unit tests
1 parent 27d2d97 commit 538813d

File tree

7 files changed

+305
-148
lines changed

7 files changed

+305
-148
lines changed
Lines changed: 54 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -45,25 +45,29 @@ def test_rag_call(self, mock_server: MockServer):
4545
"action_type": "api",
4646
"action_name": "文档检索",
4747
"action": "searchDocument",
48-
"action_input_stream": '{"query":"API接口说明中, TopP参数改如何传递?"}',
48+
"action_input_stream": (
49+
'{"query":"API接口说明中, ' 'TopP参数改如何传递?"}'
50+
),
4951
"action_input": {
50-
"query": "API接口说明中, TopP参数改如何传递?",
52+
"query": ("API接口说明中, TopP参数改如何传递?"),
5153
},
52-
"observation": """{"data": [
53-
{
54-
"docId": "1234",
55-
"docName": "API接口说明",
56-
"docUrl": "https://127.0.0.1/dl/API接口说明.pdf",
57-
"indexId": "1",
58-
"score": 0.11992252,
59-
"text": "填(0,1.0),取值越大,生成的随机性越高;启用文档检索后,文档引用类型,取值包括:simple|indexed。",
60-
"title": "API接口说明",
61-
"titlePath": "API接口说明>>>接口说明>>>是否必 说明>>>填"
62-
}
63-
],
64-
"status": "SUCCESS"
65-
}""",
66-
"response": "API接口说明中, TopP参数是一个float类型的参数,取值范围为0到1.0,默认为1.0。取值越大,生成的随机性越高。[5]",
54+
"observation": (
55+
'{"data": [{"docId": "1234", '
56+
'"docName": "API接口说明", '
57+
'"docUrl": "https://127.0.0.1/dl/'
58+
'API接口说明.pdf", "indexId": "1", '
59+
'"score": 0.11992252, "text": "填(0,1.0),'
60+
"取值越大,生成的随机性越高;启用文档检索后,"
61+
'文档引用类型,取值包括:simple|indexed。", '
62+
'"title": "API接口说明", "titlePath": '
63+
'"API接口说明>>>接口说明>>>是否必 说明>>>填"}], '
64+
'"status": "SUCCESS"}'
65+
),
66+
"response": (
67+
"API接口说明中, TopP参数是一个float类型的"
68+
"参数,取值范围为0到1.0,默认为1.0。取值越大,"
69+
"生成的随机性越高。[5]"
70+
),
6771
},
6872
],
6973
},
@@ -86,11 +90,12 @@ def test_rag_call(self, mock_server: MockServer):
8690
top_p=0.2,
8791
temperature=1.0,
8892
doc_tag_codes=["t1234", "t2345"],
89-
doc_reference_type=Application.DocReferenceType.simple,
93+
doc_reference_type=(Application.DocReferenceType.simple),
9094
has_thoughts=True,
9195
)
9296

93-
self.check_result(resp, test_response)
97+
# Test mock response type
98+
self.check_result(resp, test_response) # type: ignore[arg-type]
9499

95100
def test_flow_call(self, mock_server: MockServer):
96101
test_response = {
@@ -105,13 +110,19 @@ def test_flow_call(self, mock_server: MockServer):
105110
"action_type": "api",
106111
"action_name": "plugin",
107112
"action": "api",
108-
"action_input_stream": '{"userId": "123", "date": "202402", "city": "hangzhou"}',
113+
"action_input_stream": (
114+
'{"userId": "123", "date": "202402", '
115+
'"city": "hangzhou"}'
116+
),
109117
"action_input": {
110118
"userId": "123",
111119
"date": "202402",
112120
"city": "hangzhou",
113121
},
114-
"observation": """{"quantity": 102, "type": "resident", "date": "202402", "unit": "千瓦"}""",
122+
"observation": (
123+
'{"quantity": 102, "type": "resident", '
124+
'"date": "202402", "unit": "千瓦"}'
125+
),
115126
"response": "当月的居民用电量为102千瓦。",
116127
},
117128
],
@@ -140,7 +151,8 @@ def test_flow_call(self, mock_server: MockServer):
140151
has_thoughts=True,
141152
)
142153

143-
self.check_result(resp, test_response)
154+
# Test mock response type
155+
self.check_result(resp, test_response) # type: ignore[arg-type]
144156

145157
def test_call_with_error(self, mock_server: MockServer):
146158
test_response = {
@@ -212,21 +224,21 @@ def check_result(resp: ApplicationResponse, test_response: Dict):
212224
expected_doc_refs,
213225
)
214226

215-
for i in range(len(doc_refs)):
216-
assert doc_refs[i].index_id == expected_doc_refs[i].get(
227+
for i, doc_ref in enumerate(doc_refs):
228+
assert doc_ref.index_id == expected_doc_refs[i].get(
217229
"index_id",
218230
)
219-
assert doc_refs[i].doc_id == expected_doc_refs[i].get("doc_id")
220-
assert doc_refs[i].doc_name == expected_doc_refs[i].get(
231+
assert doc_ref.doc_id == expected_doc_refs[i].get("doc_id")
232+
assert doc_ref.doc_name == expected_doc_refs[i].get(
221233
"doc_name",
222234
)
223-
assert doc_refs[i].doc_url == expected_doc_refs[i].get(
235+
assert doc_ref.doc_url == expected_doc_refs[i].get(
224236
"doc_url",
225237
)
226-
assert doc_refs[i].title == expected_doc_refs[i].get("title")
227-
assert doc_refs[i].text == expected_doc_refs[i].get("text")
228-
assert doc_refs[i].biz_id == expected_doc_refs[i].get("biz_id")
229-
assert json.dumps(doc_refs[i].images) == json.dumps(
238+
assert doc_ref.title == expected_doc_refs[i].get("title")
239+
assert doc_ref.text == expected_doc_refs[i].get("text")
240+
assert doc_ref.biz_id == expected_doc_refs[i].get("biz_id")
241+
assert json.dumps(doc_ref.images) == json.dumps(
230242
expected_doc_refs[i].get("images"),
231243
)
232244

@@ -238,26 +250,26 @@ def check_result(resp: ApplicationResponse, test_response: Dict):
238250
expected_thoughts,
239251
)
240252

241-
for i in range(len(thoughts)):
242-
assert thoughts[i].thought == expected_thoughts[i].get(
253+
for i, thought in enumerate(thoughts):
254+
assert thought.thought == expected_thoughts[i].get(
243255
"thought",
244256
)
245-
assert thoughts[i].action == expected_thoughts[i].get("action")
246-
assert thoughts[i].action_name == expected_thoughts[i].get(
257+
assert thought.action == expected_thoughts[i].get("action")
258+
assert thought.action_name == expected_thoughts[i].get(
247259
"action_name",
248260
)
249-
assert thoughts[i].action_type == expected_thoughts[i].get(
261+
assert thought.action_type == expected_thoughts[i].get(
250262
"action_type",
251263
)
252-
assert json.dumps(thoughts[i].action_input) == json.dumps(
264+
assert json.dumps(thought.action_input) == json.dumps(
253265
expected_thoughts[i].get("action_input"),
254266
)
255-
assert thoughts[i].action_input_stream == expected_thoughts[
256-
i
257-
].get("action_input_stream")
258-
assert thoughts[i].observation == expected_thoughts[i].get(
267+
assert thought.action_input_stream == (
268+
expected_thoughts[i].get("action_input_stream")
269+
)
270+
assert thought.observation == expected_thoughts[i].get(
259271
"observation",
260272
)
261-
assert thoughts[i].response == expected_thoughts[i].get(
273+
assert thought.response == expected_thoughts[i].get(
262274
"response",
263275
)
Lines changed: 70 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,10 @@ def setup_class(cls):
4040
cls.update_phrase = {"黄鸡": 2, "红鸡": 1}
4141
cls.phrase_id = TEST_JOB_ID
4242

43-
def test_create_phrases(self, http_server):
43+
def test_create_phrases(
44+
self,
45+
http_server,
46+
): # pylint: disable=unused-argument
4447
result = AsrPhraseManager.create_phrases(
4548
model=self.model,
4649
phrases=self.phrase,
@@ -54,7 +57,10 @@ def test_create_phrases(self, http_server):
5457
assert len(result.output["finetuned_output"]) > 0
5558
self.phrase_id = result.output["finetuned_output"]
5659

57-
def test_update_phrases(self, http_server):
60+
def test_update_phrases(
61+
self,
62+
http_server,
63+
): # pylint: disable=unused-argument
5864
result = AsrPhraseManager.update_phrases(
5965
model=self.model,
6066
phrase_id=self.phrase_id,
@@ -66,7 +72,10 @@ def test_update_phrases(self, http_server):
6672
assert result.output["finetuned_output"] is not None
6773
assert len(result.output["finetuned_output"]) > 0
6874

69-
def test_query_phrases(self, http_server):
75+
def test_query_phrases(
76+
self,
77+
http_server,
78+
): # pylint: disable=unused-argument
7079
result = AsrPhraseManager.query_phrases(phrase_id=self.phrase_id)
7180
assert result is not None
7281
assert result.status_code == HTTPStatus.OK
@@ -75,46 +84,55 @@ def test_query_phrases(self, http_server):
7584
assert result.output["model"] is not None
7685
assert len(result.output["model"]) > 0
7786

78-
def test_list_phrases(self, http_server):
87+
def test_list_phrases(
88+
self,
89+
http_server,
90+
): # pylint: disable=unused-argument
7991
result = AsrPhraseManager.list_phrases(page=1, page_size=10)
8092
assert result is not None
8193
assert result.status_code == HTTPStatus.OK
8294
assert result.output["finetuned_outputs"] is not None
8395
assert len(result.output["finetuned_outputs"]) > 0
8496

85-
def test_delete_phrases(self, http_server):
97+
def test_delete_phrases(
98+
self,
99+
http_server,
100+
): # pylint: disable=unused-argument
86101
result = AsrPhraseManager.delete_phrases(phrase_id=self.phrase_id)
87102
assert result is not None
88103
assert result.status_code == HTTPStatus.OK
89104
assert result.output["finetuned_output"] is not None
90105
assert len(result.output["finetuned_output"]) > 0
91106

92107

93-
def str2bool(str):
94-
return True if str.lower() == "true" else False
108+
def str2bool(test): # pylint: disable=redefined-builtin
109+
# Return True if test string is "true", False otherwise
110+
return test.lower() == "true"
95111

96112

97-
def complete_url(url: str) -> str:
113+
def complete_url(url: str) -> None:
114+
# Set base URLs for dashscope API
98115
parsed = urlparse(url)
99116
base_url = "".join([parsed.scheme, "://", parsed.netloc])
100117
dashscope.base_websocket_api_url = "/".join(
101118
[base_url, "api-ws", dashscope.common.env.api_version, "inference"],
102119
)
103-
dashscope.base_http_api_url = url = "/".join(
120+
dashscope.base_http_api_url = "/".join(
104121
[base_url, "api", dashscope.common.env.api_version],
105122
)
106123
print("Set base_websocket_api_url: ", dashscope.base_websocket_api_url)
107124
print("Set base_http_api_url: ", dashscope.base_http_api_url)
108125

109126

110-
def phrases(
127+
def phrases( # pylint: disable=redefined-outer-name,too-many-branches
111128
model,
112129
phrase_id: str,
113130
phrases: dict,
114131
page: int,
115132
page_size: int,
116133
delete: bool,
117134
):
135+
# Manage ASR phrases based on provided parameters
118136
print("phrase_id: ", phrase_id)
119137
print("phrase: ", phrases)
120138
print("delete flag: ", delete)
@@ -126,33 +144,30 @@ def phrases(
126144
phrase_id=phrase_id,
127145
phrases=phrases,
128146
)
129-
else:
130-
print("Create phrases -->")
131-
return AsrPhraseManager.create_phrases(
132-
model=model,
133-
phrases=phrases,
134-
)
135-
else:
136-
if delete:
137-
print("Delete phrases -->")
138-
return AsrPhraseManager.delete_phrases(phrase_id=phrase_id)
139-
else:
140-
if phrase_id is not None:
141-
print("Query phrases -->")
142-
return AsrPhraseManager.query_phrases(phrase_id=phrase_id)
143-
if page is not None and page_size is not None:
144-
print(
145-
"List phrases page %d page_size %d -->"
146-
% (page, page_size),
147-
)
148-
return AsrPhraseManager.list_phrases(
149-
page=page,
150-
page_size=page_size,
151-
)
147+
print("Create phrases -->")
148+
return AsrPhraseManager.create_phrases(
149+
model=model,
150+
phrases=phrases,
151+
)
152+
if delete:
153+
print("Delete phrases -->")
154+
return AsrPhraseManager.delete_phrases(phrase_id=phrase_id)
155+
if phrase_id is not None:
156+
print("Query phrases -->")
157+
return AsrPhraseManager.query_phrases(phrase_id=phrase_id)
158+
if page is not None and page_size is not None:
159+
print(
160+
f"List phrases page {page} page_size {page_size} -->",
161+
)
162+
return AsrPhraseManager.list_phrases(
163+
page=page,
164+
page_size=page_size,
165+
)
166+
return None
152167

153168

154169
@pytest.mark.skip
155-
def test_by_user():
170+
def test_by_user(): # pylint: disable=too-many-branches
156171
parser = argparse.ArgumentParser()
157172
parser.add_argument("--model", type=str, default="paraformer-realtime-v1")
158173
parser.add_argument("--phrase", type=str, default="")
@@ -184,46 +199,47 @@ def test_by_user():
184199
print("Response of phrases: ", resp)
185200
if resp is not None and resp.output is not None:
186201
output = resp.output
187-
print("\nGet output: %s\n" % (str(output)))
202+
print(f"\nGet output: {str(output)}\n")
188203

189204
if (
190205
"finetuned_output" in output
191206
and output["finetuned_output"] is not None
192207
):
193-
print("Get phrase_id: %s" % (output["finetuned_output"]))
208+
print(
209+
f"Get phrase_id: {output['finetuned_output']}",
210+
)
194211
if "job_id" in output and output["job_id"] is not None:
195-
print("Get job_id: %s" % (output["job_id"]))
212+
print(f"Get job_id: {output['job_id']}")
196213
if "create_time" in output and output["create_time"] is not None:
197-
print("Get create_time: %s" % (output["create_time"]))
214+
print(f"Get create_time: {output['create_time']}")
198215
if "model" in output and output["model"] is not None:
199-
print("Get model_id: %s" % (output["model"]))
216+
print(f"Get model_id: {output['model']}")
200217
if "output_type" in output and output["output_type"] is not None:
201-
print("Get output_type: %s" % (output["output_type"]))
218+
print(f"Get output_type: {output['output_type']}")
202219

203220
if (
204221
"finetuned_outputs" in output
205222
and output["finetuned_outputs"] is not None
206223
):
207224
outputs = output["finetuned_outputs"]
208225
print(
209-
"Get %d info from page_no:%d page_size:%d total:%d ->"
210-
% (
211-
len(outputs),
212-
output["page_no"],
213-
output["page_size"],
214-
output["total"],
215-
),
226+
f"Get {len(outputs)} info from "
227+
f"page_no:{output['page_no']} "
228+
f"page_size:{output['page_size']} "
229+
f"total:{output['total']} ->",
216230
)
217231
for item in outputs:
218-
print(" get phrase_id: %s" % (item["finetuned_output"]))
219-
print(" get job_id: %s" % (item["job_id"]))
220-
print(" get create_time: %s" % (item["create_time"]))
221-
print(" get model_id: %s" % (item["model"]))
222-
print(" get output_type: %s\n" % (item["output_type"]))
232+
print(
233+
f" get phrase_id: {item['finetuned_output']}",
234+
)
235+
print(f" get job_id: {item['job_id']}")
236+
print(f" get create_time: {item['create_time']}")
237+
print(f" get model_id: {item['model']}")
238+
print(f" get output_type: {item['output_type']}\n")
223239
else:
224240
print(
225-
"ERROR, status_code:%d, code_message:%s, error_message:%s"
226-
% (resp.status_code, resp.code, resp.message),
241+
f"ERROR, status_code:{resp.status_code}, "
242+
f"code_message:{resp.code}, error_message:{resp.message}",
227243
)
228244

229245

0 commit comments

Comments
 (0)