diff --git a/qnt/backtester.py b/qnt/backtester.py index f010186..b1550f9 100644 --- a/qnt/backtester.py +++ b/qnt/backtester.py @@ -285,15 +285,7 @@ def backtest( current_date = pd.to_datetime('today') data_days_difference = (current_date - checking_start_data_date).days data = load_data(data_days_difference) - try: - if data.name == 'stocks' and competition_type != 'stocks' and competition_type != 'stocks_long' \ - or data.name == 'stocks_nasdaq100' and competition_type != 'stocks_nasdaq100' \ - or data.name == 'cryptofutures' and competition_type != 'cryptofutures' and competition_type != 'crypto_futures' \ - or data.name == 'crypto' and competition_type != 'crypto' \ - or data.name == 'futures' and competition_type != 'futures': - log_err("WARNING! The data type and the competition type are mismatch.") - except: - pass + check_data_type_mismatch(data, competition_type) data, time_series = extract_time_series(data) log_info("Run strategy...") @@ -309,15 +301,7 @@ def backtest( days=60))) else: data = load_data(lookback_period) - try: - if data.name == 'stocks' and competition_type != 'stocks' and competition_type != 'stocks_long' \ - or data.name == 'stocks_nasdaq100' and competition_type != 'stocks_nasdaq100' \ - or data.name == 'cryptofutures' and competition_type != 'cryptofutures' and competition_type != 'crypto_futures' \ - or data.name == 'crypto' and competition_type != 'crypto' \ - or data.name == 'futures' and competition_type != 'futures': - log_err("WARNING! The data type and the competition type are mismatch.") - except: - pass + check_data_type_mismatch(data, competition_type) data, time_series = extract_time_series(data) log_info("Run strategy...") @@ -418,32 +402,20 @@ def copy_window(data, dt, tail): output_time_coord = ts[ts >= start_date] output_time_coord = output_time_coord[::step] - i = 0 - - sys.stdout.flush() - with progressbar.ProgressBar(max_value=len(output_time_coord), poll_interval=1) as p: state = None - for t in output_time_coord: + for i, t in enumerate(output_time_coord): tail = copy_window(data, t, lookback_period) result = strategy(tail, copy.deepcopy(state)) output, state = unpack_result(result) - if type(output) != xr.DataArray: - log_err("Output is not xarray!") - return - if set(output.dims) != {'asset'} and set(output.dims) != {'asset', 'time'}: - log_err("Wrong output dimensions. ", output.dims, "Should contain only:", {'asset', 'time'}) - return + validate_output(output) if 'time' in output.dims: - output = output.sel(time=t) + output = output.isel(time=-1) output = output.drop_vars(['field', 'time'], errors='ignore') outputs.append(output) if collect_all_states: all_states.append(state) - i += 1 - p.update(i) - - sys.stderr.flush() + p.update(i+1) log_info("Merge outputs...") output = xr.concat(outputs, pd.Index(output_time_coord, name=qndata.ds.TIME)) @@ -451,6 +423,25 @@ def copy_window(data, dt, tail): return output, all_states if collect_all_states else state +def check_data_type_mismatch(data, competition_type: str): + valid_competition_types = { + 'stocks': ['stocks', 'stocks_long'], + 'stocks_nasdaq100': ['stocks_nasdaq100'], + 'cryptofutures': ['cryptofutures', 'crypto_futures'], + 'crypto': ['crypto'], + 'futures': ['futures'] + } + if competition_type not in valid_competition_types[data.name]: + log_err("WARNING! The data type and the competition type are mismatched.") + + +def validate_output(output): + if not isinstance(output, xr.DataArray): + raise ValueError("Invalid output type. Expected xarray.DataArray.") + if set(output.dims) not in [{'asset'}, {'asset', 'time'}]: + raise ValueError(f"Wrong output dimensions. {output.dims} Should contain only: {'asset', 'time'}") + + def standard_window(data, max_date: np.datetime64, lookback_period:int): min_date = max_date - np.timedelta64(lookback_period,'D') return data.loc[dict(time=slice(min_date, max_date))]