Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 25 additions & 34 deletions qnt/backtester.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,15 +285,7 @@ def backtest(
current_date = pd.to_datetime('today')
data_days_difference = (current_date - checking_start_data_date).days
data = load_data(data_days_difference)
try:
if data.name == 'stocks' and competition_type != 'stocks' and competition_type != 'stocks_long' \
or data.name == 'stocks_nasdaq100' and competition_type != 'stocks_nasdaq100' \
or data.name == 'cryptofutures' and competition_type != 'cryptofutures' and competition_type != 'crypto_futures' \
or data.name == 'crypto' and competition_type != 'crypto' \
or data.name == 'futures' and competition_type != 'futures':
log_err("WARNING! The data type and the competition type are mismatch.")
except:
pass
check_data_type_mismatch(data, competition_type)
data, time_series = extract_time_series(data)

log_info("Run strategy...")
Expand All @@ -309,15 +301,7 @@ def backtest(
days=60)))
else:
data = load_data(lookback_period)
try:
if data.name == 'stocks' and competition_type != 'stocks' and competition_type != 'stocks_long' \
or data.name == 'stocks_nasdaq100' and competition_type != 'stocks_nasdaq100' \
or data.name == 'cryptofutures' and competition_type != 'cryptofutures' and competition_type != 'crypto_futures' \
or data.name == 'crypto' and competition_type != 'crypto' \
or data.name == 'futures' and competition_type != 'futures':
log_err("WARNING! The data type and the competition type are mismatch.")
except:
pass
check_data_type_mismatch(data, competition_type)
data, time_series = extract_time_series(data)

log_info("Run strategy...")
Expand Down Expand Up @@ -418,39 +402,46 @@ def copy_window(data, dt, tail):
output_time_coord = ts[ts >= start_date]
output_time_coord = output_time_coord[::step]

i = 0

sys.stdout.flush()

with progressbar.ProgressBar(max_value=len(output_time_coord), poll_interval=1) as p:
state = None
for t in output_time_coord:
for i, t in enumerate(output_time_coord):
tail = copy_window(data, t, lookback_period)
result = strategy(tail, copy.deepcopy(state))
output, state = unpack_result(result)
if type(output) != xr.DataArray:
log_err("Output is not xarray!")
return
if set(output.dims) != {'asset'} and set(output.dims) != {'asset', 'time'}:
log_err("Wrong output dimensions. ", output.dims, "Should contain only:", {'asset', 'time'})
return
validate_output(output)
if 'time' in output.dims:
output = output.sel(time=t)
output = output.isel(time=-1)
output = output.drop_vars(['field', 'time'], errors='ignore')
outputs.append(output)
if collect_all_states:
all_states.append(state)
i += 1
p.update(i)

sys.stderr.flush()
p.update(i+1)

log_info("Merge outputs...")
output = xr.concat(outputs, pd.Index(output_time_coord, name=qndata.ds.TIME))

return output, all_states if collect_all_states else state


def check_data_type_mismatch(data, competition_type: str):
valid_competition_types = {
'stocks': ['stocks', 'stocks_long'],
'stocks_nasdaq100': ['stocks_nasdaq100'],
'cryptofutures': ['cryptofutures', 'crypto_futures'],
'crypto': ['crypto'],
'futures': ['futures']
}
if competition_type not in valid_competition_types[data.name]:
log_err("WARNING! The data type and the competition type are mismatched.")


def validate_output(output):
if not isinstance(output, xr.DataArray):
raise ValueError("Invalid output type. Expected xarray.DataArray.")
if set(output.dims) not in [{'asset'}, {'asset', 'time'}]:
raise ValueError(f"Wrong output dimensions. {output.dims} Should contain only: {'asset', 'time'}")


def standard_window(data, max_date: np.datetime64, lookback_period:int):
min_date = max_date - np.timedelta64(lookback_period,'D')
return data.loc[dict(time=slice(min_date, max_date))]
Expand Down