Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 69 additions & 41 deletions aster/lib/utils.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,92 @@
import json
import time
from urllib.parse import urlencode, quote_plus # quote_plus for cleaner URL encoding
from typing import Any, Dict, List, Type # For improved type hinting
from enum import Enum # Type added for check_enum_parameter

from urllib.parse import urlencode
# Assuming these custom errors are properly defined in the 'aster.error' module
from aster.error import (
ParameterRequiredError,
ParameterValueError,
ParameterTypeError,
)


def cleanNoneValue(d) -> dict:
out = {}
for k in d.keys():
if d[k] is not None:
out[k] = d[k]
return out


def check_required_parameter(value, name):
if not value and value != 0:
raise ParameterRequiredError([name])
# --- DATA CLEANING AND TRANSFORMATION UTILITIES ---

def clean_none_value(d: Dict[str, Any]) -> Dict[str, Any]:
"""
Removes entries with None values from a dictionary.

Args:
d: The dictionary to clean.

Returns:
A new dictionary with None values removed (Pythonic approach).
"""
if not isinstance(d, dict):
return {}
# Optimization: Use dictionary comprehension for efficiency and clean style.
return {k: v for k, v in d.items() if v is not None}

def check_required_parameters(params):
"""validate multiple parameters
params = [
['btcusdt', 'symbol'],
[10, 'price']
]

def convert_list_to_json_array(data: List[Any]) -> str | List[Any]:
"""
for p in params:
check_required_parameter(p[0], p[1])

Converts a Python list to a JSON string, ensuring no internal spaces.

Args:
data: The list to convert.

Returns:
The compact JSON string representation or the original data if None.
"""
if data is None:
return None
# Use json.dumps with separators to guarantee no spaces, which is cleaner
# and more reliable than .replace(" ", "").
return json.dumps(data, separators=(',', ':'))

def check_enum_parameter(value, enum_class):
if value not in set(item.value for item in enum_class):
raise ParameterValueError([value])

# --- VALIDATION AND ERROR CHECKING ---

def check_type_parameter(value, name, data_type):
if value is not None and type(value) != data_type:
raise ParameterTypeError([name, data_type])
def check_required_parameter(value: Any, name: str):
"""
Validates that a parameter is not None or an empty string.
Allows boolean False and integer 0.
"""
# Optimized check: Allows 0 and False, but rejects None and empty string/collection.
if value is None or (isinstance(value, str) and not value) or (hasattr(value, '__len__') and len(value) == 0):
raise ParameterRequiredError([name])


def get_timestamp():
return int(time.time() * 1000)
def check_required_parameters(params: List[List[Any]]):
"""
Validates multiple parameters based on a list of [value, name] pairs.

Args:
params: List of lists, e.g., [['btcusdt', 'symbol'], [10, 'price']].
"""
for value, name in params:
check_required_parameter(value, name)

def encoded_string(query, special = False):
if(special):
return urlencode(query).replace("%40", "@").replace('%27', '%22')
else:
return urlencode(query, True).replace("%40", "@")

def convert_list_to_json_array(symbols):
if symbols is None:
return symbols
res = json.dumps(symbols)
return res.replace(" ", "")
def check_enum_parameter(value: Any, enum_class: Type[Enum]):
"""
Checks if a given value is one of the defined values in the Enum class.

Args:
value: The value to check.
enum_class: The Enum class to check against.
"""
# Create the set of valid values only once for the enum_class type.
# Note: For production use, this set should be cached outside the function.
valid_values = {item.value for item in enum_class}

if value not in valid_values:
# Use the name of the Enum class in the error message for clarity
raise ParameterValueError([f"'{value}' not in {enum_class.__name__} valid values: {valid_values}"])


def config_logging(logging, logging_devel, log_file=None):
logging.basicConfig(level=logging_devel, filename=log_file)
def check_type_parameter(value: Any, name: str, data_type: Type):
"""
Validates that