88import threading
99import time
1010import uuid
11+ from dataclasses import dataclass
1112from enum import Enum , unique
13+ from typing import Dict , List , Optional
1214
1315import websocket
1416
2628)
2729
2830
31+ @dataclass
32+ class HotFix :
33+ """
34+ Hot fix parameters for pronunciation and text replacement.
35+
36+ Attributes:
37+ pronunciation: List of pronunciation, e.g., [{"草地": "cao3 di4"}]
38+ replace: List of text replacement, e.g., [{"草地": "草弟"}]
39+
40+ Example:
41+ hot_fix = HotFix(
42+ pronunciation=[{"草地": "cao3 di4"}],
43+ replace=[{"草地": "草弟"}]
44+ )
45+ hot_fix_dict = hot_fix.to_dict()
46+ """
47+
48+ pronunciation : Optional [List [Dict [str , str ]]] = None
49+ replace : Optional [List [Dict [str , str ]]] = None
50+
51+ def to_dict (self ) -> Dict [str , List [Dict [str , str ]]]:
52+ result = {}
53+ if self .pronunciation is not None :
54+ result ["pronunciation" ] = self .pronunciation
55+ if self .replace is not None :
56+ result ["replace" ] = self .replace
57+ return result
58+
59+
2960class ResultCallback :
3061 """
3162 An interface that defines callback methods for getting speech synthesis results. # noqa E501
@@ -246,6 +277,7 @@ def __init__( # pylint: disable=redefined-builtin
246277 callback : ResultCallback = None ,
247278 workspace = None ,
248279 url = None ,
280+ hot_fix = None ,
249281 additional_params = None ,
250282 ):
251283 """
@@ -282,6 +314,14 @@ def __init__( # pylint: disable=redefined-builtin
282314 The language hints of the synthesizer. supported language: zh, en.
283315 additional_params: Dict
284316 Additional parameters for the Dashscope API.
317+ hot_fix: Dict or HotFix
318+ Hot fix parameters for pronunciation and text replacement.
319+ Example: {
320+ "pronunciation": [{"草地": "cao3 di4"}],
321+ "replace": [{"草地": "草弟"}]
322+ }
323+ enable_markdown_filter: bool
324+ Whether to enable markdown filter. should be set into additional_params.
285325 """
286326 self .ws = None
287327 self .start_event = threading .Event ()
@@ -316,6 +356,7 @@ def __init__( # pylint: disable=redefined-builtin
316356 workspace ,
317357 url ,
318358 additional_params ,
359+ hot_fix ,
319360 )
320361
321362 def __send_str (self , data : str ):
@@ -404,6 +445,7 @@ def __update_params( # pylint: disable=redefined-builtin
404445 url = None ,
405446 additional_params = None ,
406447 close_ws_after_use = True ,
448+ hot_fix = None ,
407449 ):
408450 if model is None :
409451 raise ModelRequired ("Model is required!" )
@@ -417,6 +459,17 @@ def __update_params( # pylint: disable=redefined-builtin
417459 raise InputRequired ("apikey is required!" )
418460 self .headers = headers
419461 self .workspace = workspace
462+
463+ # Merge hot_fix into additional_params
464+ if hot_fix is not None :
465+ if additional_params is None :
466+ additional_params = {}
467+ # Support both HotFix instance and dict
468+ if isinstance (hot_fix , HotFix ):
469+ additional_params ["hot_fix" ] = hot_fix .to_dict ()
470+ else :
471+ additional_params ["hot_fix" ] = hot_fix
472+
420473 self .additional_params = additional_params
421474 self .model = model
422475 self .voice = voice
0 commit comments