From 88f67c1de091cdd674458818c4be5be6dbedbf00 Mon Sep 17 00:00:00 2001
From: Huang Jianwei
Date: Fri, 27 Apr 2018 13:53:44 +0800
Subject: [PATCH] Fix error when calculating adjust_factor for assets other
 than stocks

---
 jaqs/data/dataview.py | 1335 +++++++++++++++++++++++++----------------
 1 file changed, 820 insertions(+), 515 deletions(-)

diff --git a/jaqs/data/dataview.py b/jaqs/data/dataview.py
index 18e380f..99e3437 100644
--- a/jaqs/data/dataview.py
+++ b/jaqs/data/dataview.py
@@ -18,8 +18,6 @@
 from jaqs.data.py_expression_eval import Parser
 
-
-
 class DataView(object):
     """
     Prepare data before research / trade. Support file I/O.
@@ -42,9 +40,10 @@ class DataView(object):
         index is date, columns is symbol-field MultiIndex
 
     """
+
    def __init__(self):
        self.data_api = None
-
+
        self.universe = ""
        self.symbol = []
        self.benchmark = ""
@@ -57,26 +56,31 @@ def __init__(self):
        self.all_price = True
        self._snapshot = None
 
-        self.meta_data_list = ['start_date', 'end_date',
-                               'extended_start_date_d', 'extended_start_date_q',
-                               'freq', 'fields', 'symbol', 'universe', 'benchmark',
-                               'custom_daily_fields', 'custom_quarterly_fields']
+        self.meta_data_list = [
+            'start_date', 'end_date', 'extended_start_date_d',
+            'extended_start_date_q', 'freq', 'fields', 'symbol', 'universe',
+            'benchmark', 'custom_daily_fields', 'custom_quarterly_fields'
+        ]
        self.adjust_mode = 'post'
-
+
        self.data_d = None
        self.data_q = None
        self._data_benchmark = None
        self._data_inst = None
        # self._data_group = None
-
+
        common_list = {'symbol', 'start_date', 'end_date'}
-        market_bar_list = {'open', 'high', 'low', 'close', 'volume', 'turnover', 'vwap', 'oi'}
-        market_tick_list = {'volume', 'oi',
-                            'askprice1', 'askprice2', 'askprice3', 'askprice4', 'askprice5',
-                            'bidprice1', 'bidprice2', 'bidprice3', 'bidprice4', 'bidprice5',
-                            'askvolume1', 'askvolume2', 'askvolume3', 'askvolume4', 'askvolume5',
-                            'bidvolume1', 'bidvolume2', 'bidvolume3', 'bidvolume4', 'bidvolume5'}
+        market_bar_list = {
+            'open', 'high', 'low', 'close', 'volume', 'turnover', 'vwap', 'oi'
+        }
+        market_tick_list = {
+            'volume', 'oi', 'askprice1', 'askprice2', 'askprice3', 'askprice4',
+            'askprice5', 'bidprice1', 'bidprice2', 'bidprice3', 'bidprice4',
+            'bidprice5', 'askvolume1', 'askvolume2', 'askvolume3',
+            'askvolume4', 'askvolume5', 'bidvolume1', 'bidvolume2',
+            'bidvolume3', 'bidvolume4', 'bidvolume5'
+        }
        # fields map
        # TODO: 'freq' is not in market_daily_fields yet.
        self.market_daily_fields = \
@@ -213,21 +217,21 @@ def __init__(self):
             "yoyocfps","yoyop","yoyebt","yoynetprofit","yoynetprofit_deducted","yoyocf","yoyroe","yoybps","yoyassets",
             "yoyequity","yoy_tr","yoy_or","qfa_yoygr","qfa_cgrgr","qfa_yoysales","qfa_cgrsales","qfa_yoyop","qfa_cgrop",
             "qfa_yoyprofit","qfa_cgrprofit","qfa_yoynetprofit","qfa_cgrnetprofit","yoy_equity","rd_expense","waa_roe"}
-        self .custom_daily_fields = []
-        self .custom_quarterly_fields = []
-
+        self.custom_daily_fields = []
+        self.custom_quarterly_fields = []
+
        # const
-        self .ANN_DATE_FIELD_NAME = 'ann_date'
-        self .REPORT_DATE_FIELD_NAME = 'report_date'
+        self.ANN_DATE_FIELD_NAME = 'ann_date'
+        self.REPORT_DATE_FIELD_NAME = 'report_date'
        self.TRADE_STATUS_FIELD_NAME = 'trade_status'
        self.TRADE_DATE_FIELD_NAME = 'trade_date'
-
+
    # --------------------------------------------------------------------------------------------------------
    # Properties
    @property
    def data_benchmark(self):
        return self._data_benchmark
-
+
    @property
    def data_inst(self):
        """
@@ -238,11 +242,13 @@ def data_inst(self):
 
        """
        return self._data_inst
-
+
    @data_benchmark.setter
    def data_benchmark(self, df_new):
        if self.data_d is not None and df_new.shape[0] != self.data_d.shape[0]:
-            raise ValueError("You must provide a DataFrame with the same shape as data_benchmark.")
+            raise ValueError(
+                "You must provide a DataFrame with the same shape as data_benchmark."
+            )
        self._data_benchmark = df_new
 
    @property
@@ -259,10 +265,13 @@ def dates(self):
        if self.data_d is not None:
            res = self.data_d.index.values
        elif self.data_api is not None:
-            res = self.data_api.query_trade_dates(self.extended_start_date_d, self.end_date)
+            res = self.data_api.query_trade_dates(self.extended_start_date_d,
+                                                  self.end_date)
        else:
-            raise ValueError("Cannot get dates array when neither data nor data_api exists.")
-
+            raise ValueError(
+                "Cannot get dates array when neither data nor data_api exists."
+ ) + return res # -------------------------------------------------------------------------------------------------------- @@ -322,7 +331,8 @@ def _is_predefined_field(self, field_name): bool """ - return self._is_quarter_field(field_name) or self._is_daily_field(field_name) + return self._is_quarter_field(field_name) or self._is_daily_field( + field_name) def _get_fields(self, field_type, fields, complement=False, append=False): """ @@ -340,51 +350,50 @@ def _get_fields(self, field_type, fields, complement=False, append=False): list """ - pool_map = {'market_daily': self.market_daily_fields, - 'ref_daily': self.reference_daily_fields, - 'income': self.fin_stat_income, - 'balance_sheet': self.fin_stat_balance_sheet, - 'cash_flow': self.fin_stat_cash_flow, - 'fin_indicator': self.fin_indicator, - 'group': self.group_fields} + pool_map = { + 'market_daily': self.market_daily_fields, + 'ref_daily': self.reference_daily_fields, + 'income': self.fin_stat_income, + 'balance_sheet': self.fin_stat_balance_sheet, + 'cash_flow': self.fin_stat_cash_flow, + 'fin_indicator': self.fin_indicator, + 'group': self.group_fields + } pool_map['daily'] = set.union(pool_map['market_daily'], - pool_map['ref_daily'], - pool_map['group'], + pool_map['ref_daily'], pool_map['group'], self.custom_daily_fields) - pool_map['quarterly'] = set.union(pool_map['income'], - pool_map['balance_sheet'], - pool_map['cash_flow'], - pool_map['fin_indicator'], - self.custom_quarterly_fields) - + pool_map['quarterly'] = set.union( + pool_map['income'], pool_map['balance_sheet'], + pool_map['cash_flow'], pool_map['fin_indicator'], + self.custom_quarterly_fields) + pool = pool_map.get(field_type, None) if pool is None: raise NotImplementedError("field_type = {:s}".format(field_type)) - + s = set.intersection(set(pool), set(fields)) if not s: return [] - + if complement: s = set(fields) - s - + if field_type == 'market_daily' and self.all_price: # turnover will not be adjusted s.update({'open', 'high', 'close', 'low', 'vwap'}) - + if append: s.add('symbol') if field_type == 'market_daily' or field_type == 'ref_daily': s.add('trade_date') if field_type == 'market_daily': s.add(self.TRADE_STATUS_FIELD_NAME) - elif (field_type == 'income' - or field_type == 'balance_sheet' + elif (field_type == 'income' or field_type == 'balance_sheet' or field_type == 'cash_flow' or field_type == 'fin_indicator'): s.add(self.ANN_DATE_FIELD_NAME) s.add(self.REPORT_DATE_FIELD_NAME) - + l = list(s) return l @@ -404,31 +413,36 @@ def init_from_config(self, props, data_api): """ # data_api.init_from_config(props) self.data_api = data_api - + sep = ',' - + # initialize parameters self.start_date = props['start_date'] - self.extended_start_date_d = jutil.shift(self.start_date, n_weeks=-8) # query more data + self.extended_start_date_d = jutil.shift( + self.start_date, n_weeks=-8) # query more data self.extended_start_date_q = jutil.shift(self.start_date, n_weeks=-80) self.end_date = props['end_date'] self.all_price = props.get('all_price', True) self.freq = props.get('freq', 1) - + # get and filter fields fields = props.get('fields', []) if fields: fields = props['fields'].split(sep) - self.fields = [field for field in fields if self._is_predefined_field(field)] + self.fields = [ + field for field in fields if self._is_predefined_field(field) + ] if len(self.fields) < len(fields): - print("Field name [{}] not valid, ignore.".format(set.difference(set(fields), set(self.fields)))) - + print("Field name [{}] not valid, ignore.".format( + 
                set.difference(set(fields), set(self.fields))))
+
        # append additional fields
        if self.all_price:
-            self.fields.extend(['open_adj', 'high_adj', 'low_adj', 'close_adj',
-                                'open', 'high', 'low', 'close',
-                                'vwap', 'vwap_adj'])
-
+            self.fields.extend([
+                'open_adj', 'high_adj', 'low_adj', 'close_adj', 'open', 'high',
+                'low', 'close', 'vwap', 'vwap_adj'
+            ])
+
        # initialize universe/symbol
        universe = props.get('universe', "")
        symbol = props.get('symbol', "")
@@ -442,7 +456,9 @@ def init_from_config(self, props, data_api):
            self.universe = univ_list
            symbols_list = []
            for univ in univ_list:
-                symbols_list.extend(data_api.query_index_member(univ, self.extended_start_date_d, self.end_date))
+                symbols_list.extend(
+                    data_api.query_index_member(
+                        univ, self.extended_start_date_d, self.end_date))
            self.symbol = sorted(list(set(symbols_list)))
 
        #else:
@@ -459,42 +475,52 @@ def init_from_config(self, props, data_api):
        if self.universe:
            if len(self.universe) > 1:
                print("More than one universe is used: {}, "
-                      "use the first one ({}) as index by default. "
-                      "If you want to use other benchmark, "
-                      "please specify benchmark in configs.".format(repr(self.universe), self.universe[0]))
+                      "use the first one ({}) as index by default. "
+                      "If you want to use other benchmark, "
+                      "please specify benchmark in configs.".format(
+                          repr(self.universe), self.universe[0]))
            self.benchmark = self.universe[0]
-
+
        print("Initialize config success.")
 
-    def distributed_query(self, query_func_name, symbol, start_date, end_date, limit=100000, **kwargs):
+    def distributed_query(self,
+                          query_func_name,
+                          symbol,
+                          start_date,
+                          end_date,
+                          limit=100000,
+                          **kwargs):
        n_symbols = len(symbol.split(','))
        dates = self.data_api.query_trade_dates(start_date, end_date)
        n_days = len(dates)
-
+
        if n_symbols * n_days > limit:
            n = limit // n_symbols
-
+
            df_list = []
            i = 0
            pos1, pos2 = n * i, n * (i + 1) - 1
            while pos2 < n_days:
                print(pos2)
-                df, msg = getattr(self.data_api, query_func_name)(symbol=symbol,
-                                                                  start_date=dates[pos1], end_date=dates[pos2],
-                                                                  **kwargs)
+                df, msg = getattr(self.data_api, query_func_name)(
+                    symbol=symbol,
+                    start_date=dates[pos1],
+                    end_date=dates[pos2],
+                    **kwargs)
                df_list.append(df)
                i += 1
                pos1, pos2 = n * i, n * (i + 1) - 1
            if pos1 < n_days:
-                df, msg = getattr(self.data_api, query_func_name)(symbol=symbol,
-                                                                  start_date=dates[pos1], end_date=dates[-1],
-                                                                  **kwargs)
+                df, msg = getattr(self.data_api, query_func_name)(
+                    symbol=symbol,
+                    start_date=dates[pos1],
+                    end_date=dates[-1],
+                    **kwargs)
                df_list.append(df)
            df = pd.concat(df_list, axis=0)
        else:
-            df, msg = getattr(self.data_api, query_func_name)(symbol,
-                                                              start_date=start_date, end_date=end_date,
-                                                              **kwargs)
+            df, msg = getattr(self.data_api, query_func_name)(
+                symbol, start_date=start_date, end_date=end_date, **kwargs)
        return df, msg
 
    def prepare_data(self):
@@ -506,20 +532,20 @@ def prepare_data(self):
        if self.data_q is not None:
            self._prepare_report_date()
            self._align_and_merge_q_into_d()
-
+
        print("Query instrument info...")
        self._prepare_inst_info()
-
+
        print("Query adj_factor...")
        self._prepare_adj_factor()
-
+
        if self.benchmark:
            print("Query benchmark...")
            self._data_benchmark = self._prepare_benchmark()
        if self.universe:
            print("Query benchmark member info...")
            self._prepare_comp_info()
-
+
        group_fields = self._get_fields('group', self.fields)
        if group_fields:
            print("Query groups (industry)...")
@@ -551,11 +577,11 @@ def _prepare_daily_quarterly(self, fields):
 
        """
        if not fields:
            return None, None
-
+
        # query data
        print("Query data - query...")
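# Illustrative sketch, not part of the diff: a typical props dict consumed by
# init_from_config above. The concrete values are hypothetical; 'universe' and
# 'fields' are comma-separated strings and dates are yyyymmdd integers, which
# is what the parsing code assumes.
props_example = {
    'start_date': 20170101,
    'end_date': 20171229,
    'universe': '000300.SH',      # hypothetical index universe
    'fields': 'close,volume,pb',  # subset of the predefined daily fields
    'freq': 1,
}
# dv = DataView(); dv.init_from_config(props_example, data_api); dv.prepare_data()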
daily_list, quarterly_list = self._query_data(self.symbol, fields) - + def pivot_and_sort(df, index_name): df = self._process_index_co(df, index_name) df = df.pivot(index=index_name, columns='symbol') @@ -565,21 +591,31 @@ def pivot_and_sort(df, index_name): df = df.sort_index(axis=1, level=col_names) df.index.name = index_name return df - + multi_daily = None multi_quarterly = None if daily_list: - daily_list_pivot = [pivot_and_sort(df, self.TRADE_DATE_FIELD_NAME) for df in daily_list] - multi_daily = self._merge_data(daily_list_pivot, self.TRADE_DATE_FIELD_NAME) + daily_list_pivot = [ + pivot_and_sort(df, self.TRADE_DATE_FIELD_NAME) + for df in daily_list + ] + multi_daily = self._merge_data(daily_list_pivot, + self.TRADE_DATE_FIELD_NAME) # use self.dates as index because original data have weekends - multi_daily = self._fill_missing_idx_col(multi_daily, index=self.dates, symbols=self.symbol) + multi_daily = self._fill_missing_idx_col( + multi_daily, index=self.dates, symbols=self.symbol) print("Query data - daily fields prepared.") if quarterly_list: - quarterly_list_pivot = [pivot_and_sort(df, self.REPORT_DATE_FIELD_NAME) for df in quarterly_list] - multi_quarterly = self._merge_data(quarterly_list_pivot, self.REPORT_DATE_FIELD_NAME) - multi_quarterly = self._fill_missing_idx_col(multi_quarterly, index=None, symbols=self.symbol) + quarterly_list_pivot = [ + pivot_and_sort(df, self.REPORT_DATE_FIELD_NAME) + for df in quarterly_list + ] + multi_quarterly = self._merge_data(quarterly_list_pivot, + self.REPORT_DATE_FIELD_NAME) + multi_quarterly = self._fill_missing_idx_col( + multi_quarterly, index=None, symbols=self.symbol) print("Query data - quarterly fields prepared.") - + return multi_daily, multi_quarterly def _query_data(self, symbol, fields): @@ -601,70 +637,112 @@ def _query_data(self, symbol, fields): """ sep = ',' symbol_str = sep.join(symbol) - + if self.freq == 1: daily_list = [] quarterly_list = [] - + # TODO : use fields = {field: kwargs} to enable params - fields_market_daily = self._get_fields('market_daily', fields, append=True) + fields_market_daily = self._get_fields( + 'market_daily', fields, append=True) if fields_market_daily: - print("NOTE: price adjust method is [{:s} adjust]".format(self.adjust_mode)) + print("NOTE: price adjust method is [{:s} adjust]".format( + self.adjust_mode)) # no adjust prices and other market daily fields - df_daily, msg1 = self.distributed_query('daily', symbol_str, - start_date=self.extended_start_date_d, end_date=self.end_date, - adjust_mode=None, fields=sep.join(fields_market_daily), limit=100000) + df_daily, msg1 = self.distributed_query( + 'daily', + symbol_str, + start_date=self.extended_start_date_d, + end_date=self.end_date, + adjust_mode=None, + fields=sep.join(fields_market_daily), + limit=100000) #df_daily, msg1 = self.data_api.daily(symbol_str, start_date=self.extended_start_date_d, end_date=self.end_date, # adjust_mode=None, fields=sep.join(fields_market_daily)) - + if self.all_price: adj_cols = ['open', 'high', 'low', 'close', 'vwap'] # adjusted prices #df_daily_adjust, msg11 = self.data_api.daily(symbol_str, start_date=self.extended_start_date_d, end_date=self.end_date, # adjust_mode=self.adjust_mode, fields=','.join(adj_cols)) - df_daily_adjust, msg1 = self.distributed_query('daily', symbol_str, - start_date=self.extended_start_date_d, end_date=self.end_date, - adjust_mode=self.adjust_mode, fields=sep.join(fields_market_daily), limit=100000) - - df_daily = pd.merge(df_daily, df_daily_adjust, how='outer', - on=['symbol', 
'trade_date'], suffixes=('', '_adj')) + df_daily_adjust, msg1 = self.distributed_query( + 'daily', + symbol_str, + start_date=self.extended_start_date_d, + end_date=self.end_date, + adjust_mode=self.adjust_mode, + fields=sep.join(fields_market_daily), + limit=100000) + + df_daily = pd.merge( + df_daily, + df_daily_adjust, + how='outer', + on=['symbol', 'trade_date'], + suffixes=('', '_adj')) daily_list.append(df_daily.loc[:, fields_market_daily]) - - fields_ref_daily = self._get_fields('ref_daily', fields, append=True) + + fields_ref_daily = self._get_fields( + 'ref_daily', fields, append=True) if fields_ref_daily: - df_ref_daily, msg2 = self.distributed_query('query_lb_dailyindicator', symbol_str, - start_date=self.extended_start_date_d, end_date=self.end_date, - fields=sep.join(fields_ref_daily), limit=20000) + df_ref_daily, msg2 = self.distributed_query( + 'query_lb_dailyindicator', + symbol_str, + start_date=self.extended_start_date_d, + end_date=self.end_date, + fields=sep.join(fields_ref_daily), + limit=20000) daily_list.append(df_ref_daily.loc[:, fields_ref_daily]) - + fields_income = self._get_fields('income', fields, append=True) if fields_income: - df_income, msg3 = self.data_api.query_lb_fin_stat('income', symbol_str, self.extended_start_date_q, self.end_date, - sep.join(fields_income), drop_dup_cols=['symbol', self.REPORT_DATE_FIELD_NAME]) + df_income, msg3 = self.data_api.query_lb_fin_stat( + 'income', + symbol_str, + self.extended_start_date_q, + self.end_date, + sep.join(fields_income), + drop_dup_cols=['symbol', self.REPORT_DATE_FIELD_NAME]) quarterly_list.append(df_income.loc[:, fields_income]) - - fields_balance = self._get_fields('balance_sheet', fields, append=True) + + fields_balance = self._get_fields( + 'balance_sheet', fields, append=True) if fields_balance: - df_balance, msg3 = self.data_api.query_lb_fin_stat('balance_sheet', symbol_str, self.extended_start_date_q, self.end_date, - sep.join(fields_balance), drop_dup_cols=['symbol', self.REPORT_DATE_FIELD_NAME]) + df_balance, msg3 = self.data_api.query_lb_fin_stat( + 'balance_sheet', + symbol_str, + self.extended_start_date_q, + self.end_date, + sep.join(fields_balance), + drop_dup_cols=['symbol', self.REPORT_DATE_FIELD_NAME]) quarterly_list.append(df_balance.loc[:, fields_balance]) - + fields_cf = self._get_fields('cash_flow', fields, append=True) if fields_cf: - df_cf, msg3 = self.data_api.query_lb_fin_stat('cash_flow', symbol_str, self.extended_start_date_q, self.end_date, - sep.join(fields_cf), drop_dup_cols=['symbol', self.REPORT_DATE_FIELD_NAME]) + df_cf, msg3 = self.data_api.query_lb_fin_stat( + 'cash_flow', + symbol_str, + self.extended_start_date_q, + self.end_date, + sep.join(fields_cf), + drop_dup_cols=['symbol', self.REPORT_DATE_FIELD_NAME]) quarterly_list.append(df_cf.loc[:, fields_cf]) - - fields_fin_ind = self._get_fields('fin_indicator', fields, append=True) + + fields_fin_ind = self._get_fields( + 'fin_indicator', fields, append=True) if fields_fin_ind: - df_fin_ind, msg4 = self.data_api.query_lb_fin_stat('fin_indicator', symbol_str, - self.extended_start_date_q, self.end_date, - sep.join(fields_fin_ind), drop_dup_cols=['symbol', self.REPORT_DATE_FIELD_NAME]) + df_fin_ind, msg4 = self.data_api.query_lb_fin_stat( + 'fin_indicator', + symbol_str, + self.extended_start_date_q, + self.end_date, + sep.join(fields_fin_ind), + drop_dup_cols=['symbol', self.REPORT_DATE_FIELD_NAME]) quarterly_list.append(df_fin_ind.loc[:, fields_fin_ind]) - + else: raise NotImplementedError("freq = {}".format(self.freq)) - + 
return daily_list, quarterly_list @staticmethod @@ -687,22 +765,22 @@ def _merge_data(dfs, index_name='trade_date'): """ # dfs = [df for df in dfs if df is not None] - + merge = pd.concat(dfs, axis=1, join='outer') - + # drop duplicated columns. ONE LINE EFFICIENT version mask_duplicated = merge.columns.duplicated() if np.any(mask_duplicated): # print("Duplicated columns found. Dropped.") merge = merge.loc[:, ~mask_duplicated] - + # if merge.isnull().sum().sum() > 0: # print "WARNING: nan in final merged data. NO fill" # merge.fillna(method='ffill', inplace=True) - + merge = merge.sort_index(axis=1, level=['symbol', 'field']) merge.index.name = index_name - + return merge def _fill_missing_idx_col(self, df, index=None, symbols=None): @@ -711,20 +789,25 @@ def _fill_missing_idx_col(self, df, index=None, symbols=None): if symbols is None: symbols = self.symbol fields = df.columns.levels[1] - - if len(fields) * len(self.symbol) != len(df.columns) or len(index) != len(df.index): - cols_multi = pd.MultiIndex.from_product([symbols, fields], names=['symbol', 'field']) + + if len(fields) * len(self.symbol) != len( + df.columns) or len(index) != len(df.index): + cols_multi = pd.MultiIndex.from_product( + [symbols, fields], names=['symbol', 'field']) cols_multi = cols_multi.sort_values() - df_final = pd.DataFrame(index=index, columns=cols_multi, data=np.nan) + df_final = pd.DataFrame( + index=index, columns=cols_multi, data=np.nan) df_final.index.name = df.index.name - + df_final.update(df) - + # idx_diff = sorted(set(df_final.index) - set(df.index)) - col_diff = sorted(set(df_final.columns.levels[0].values) - set(df.columns.levels[0].values)) - print ("WARNING: some data is unavailable: " - # + "\n At index " + ', '.join(idx_diff) - + "\n At fields " + ', '.join(col_diff)) + col_diff = sorted( + set(df_final.columns.levels[0].values) - + set(df.columns.levels[0].values)) + print("WARNING: some data is unavailable: " + # + "\n At index " + ', '.join(idx_diff) + + "\n At fields " + ', '.join(col_diff)) return df_final else: return df @@ -732,39 +815,55 @@ def _fill_missing_idx_col(self, df, index=None, symbols=None): def _align_and_merge_q_into_d(self): data_d, data_q = self.data_d, self.data_q if data_d is not None and data_q is not None: - df_ref_ann = data_q.loc[:, pd.IndexSlice[:, self.ANN_DATE_FIELD_NAME]].copy() + df_ref_ann = data_q.loc[:, + pd.IndexSlice[:, self. 
+ ANN_DATE_FIELD_NAME]].copy() df_ref_ann.columns = df_ref_ann.columns.droplevel(level='field') - + dic_expanded = dict() - for field_name, df in data_q.groupby(level=1, axis=1): # by column multiindex fields + for field_name, df in data_q.groupby( + level=1, axis=1): # by column multiindex fields df_expanded = align(df, df_ref_ann, self.dates) dic_expanded[field_name] = df_expanded df_quarterly_expanded = pd.concat(dic_expanded.values(), axis=1) df_quarterly_expanded.index.name = self.TRADE_DATE_FIELD_NAME - - data_d_merge = self._merge_data([data_d, df_quarterly_expanded], index_name=self.TRADE_DATE_FIELD_NAME) + + data_d_merge = self._merge_data( + [data_d, df_quarterly_expanded], + index_name=self.TRADE_DATE_FIELD_NAME) data_d = data_d_merge.loc[data_d.index, :] self.data_d = data_d def _prepare_adj_factor(self): """Query and append daily adjust factor for prices.""" mask_stocks = self.data_inst['inst_type'] == 1 - if mask_stocks.sum() == 0: - return - symbol_stocks = self.data_inst.loc[mask_stocks].index.values - symbol_str = ','.join(symbol_stocks) - df_adj = self.data_api.query_adj_factor_daily(symbol_str, - start_date=self.extended_start_date_d, end_date=self.end_date, div=False) + if mask_stocks.sum() > 0: + symbol_stocks = self.data_inst.loc[mask_stocks].index.values + symbol_str = ','.join(symbol_stocks) + df_adj = self.data_api.query_adj_factor_daily( + symbol_str, + start_date=self.extended_start_date_d, + end_date=self.end_date, + div=False) + else: + dt_idx = self.data_api.query_trade_dates( + start_date=self.extended_start_date_d, + end_date=self.end_date, + ) + df_adj = pd.DataFrame( + index=dt_idx, columns=self.data_inst.index).fillna(1) self.append_df(df_adj, 'adjust_factor', is_quarterly=False) def _prepare_comp_info(self): # if a symbol is index member of any one universe, its value of index_member will be 1.0 res = dict() for univ in self.universe: - df = self.data_api.query_index_member_daily(univ, self.extended_start_date_d, self.end_date) + df = self.data_api.query_index_member_daily( + univ, self.extended_start_date_d, self.end_date) res[univ] = df df_res = pd.concat(res, axis=0) - df = df_res.groupby(by='trade_date').apply(lambda df: df.any(axis=0)).astype(float) + df = df_res.groupby( + by='trade_date').apply(lambda df: df.any(axis=0)).astype(float) # Always include additional symbols for code in self.symbol: @@ -772,9 +871,10 @@ def _prepare_comp_info(self): df[code] = 1.0 self.append_df(df, 'index_member', is_quarterly=False) - + # use weights of the first universe - df_weights = self.data_api.query_index_weights_daily(self.universe[0], self.extended_start_date_d, self.end_date) + df_weights = self.data_api.query_index_weights_daily( + self.universe[0], self.extended_start_date_d, self.end_date) self.append_df(df_weights, 'index_weight', is_quarterly=False) def _prepare_report_date(self): @@ -783,36 +883,44 @@ def _prepare_report_date(self): n = len(idx) quarter = idx.values // 100 % 100 df_report_date.loc[:, :] = quarter.reshape(n, -1) - + self.append_df(df_report_date, 'quarter', is_quarterly=True) - + def _prepare_inst_info(self): - res = self.data_api.query_inst_info(symbol=','.join(self.symbol), - fields='symbol,inst_type,name,list_date,' - 'delist_date,product,pricetick,multiplier,' - 'buylot,setlot', - inst_type="") + res = self.data_api.query_inst_info( + symbol=','.join(self.symbol), + fields='symbol,inst_type,name,list_date,' + 'delist_date,product,pricetick,multiplier,' + 'buylot,setlot', + inst_type="") self._data_inst = res def 
_prepare_group(self, group_fields):
-        data_map = {'sw1': ('SW', 1),
-                    'sw2': ('SW', 2),
-                    'sw3': ('SW', 3),
-                    'sw4': ('SW', 4),
-                    'zz1': ('ZZ', 1),
-                    'zz2': ('ZZ', 2)}
+        data_map = {
+            'sw1': ('SW', 1),
+            'sw2': ('SW', 2),
+            'sw3': ('SW', 3),
+            'sw4': ('SW', 4),
+            'zz1': ('ZZ', 1),
+            'zz2': ('ZZ', 2)
+        }
        for field in group_fields:
            type_, level = data_map[field]
-            df = self.data_api.query_industry_daily(symbol=','.join(self.symbol),
-                                                    start_date=self.extended_start_date_q, end_date=self.end_date,
-                                                    type_=type_, level=level)
+            df = self.data_api.query_industry_daily(
+                symbol=','.join(self.symbol),
+                start_date=self.extended_start_date_q,
+                end_date=self.end_date,
+                type_=type_,
+                level=level)
            self.append_df(df, field, is_quarterly=False)
 
    def _prepare_benchmark(self):
-        df_bench, msg = self.data_api.daily(self.benchmark,
-                                            start_date=self.extended_start_date_d, end_date=self.end_date,
-                                            adjust_mode=self.adjust_mode,
-                                            fields='trade_date,symbol,close,vwap,volume,turnover')
+        df_bench, msg = self.data_api.daily(
+            self.benchmark,
+            start_date=self.extended_start_date_d,
+            end_date=self.end_date,
+            adjust_mode=self.adjust_mode,
+            fields='trade_date,symbol,close,vwap,volume,turnover')
        # TODO: we want more than just close price of benchmark
        df_bench = df_bench.set_index('trade_date').loc[:, ['close']]
        return df_bench
@@ -823,7 +931,9 @@ def _add_field(self, field_name, is_quarterly=None):
        self.fields.append(field_name)
        if not self._is_predefined_field(field_name):
            if is_quarterly is None:
-                raise ValueError("Field [{:s}] is not a predefined field, but no frequency information is provided.")
+                raise ValueError(
+                    "Field [{:s}] is not a predefined field, but no frequency information is provided."
+                )
            if is_quarterly:
                self.custom_quarterly_fields.append(field_name)
            else:
@@ -847,11 +957,13 @@ def add_field(self, field_name, data_api=None):
        """
        if data_api is None:
            if self.data_api is None:
-                print("Add field failed. No data_api available. Please specify one in parameter.")
+                print(
+                    "Add field failed. No data_api available. Please specify one in parameter."
+                )
                return False
        else:
            self.data_api = data_api
-
+
        if field_name in self.fields:
            print("Field name [{:s}] already exists.".format(field_name))
            return False
@@ -861,31 +973,40 @@ def add_field(self, field_name, data_api=None):
            return False
 
        merge_d, merge_q = self._prepare_daily_quarterly([field_name])
-
+
        if self._is_daily_field(field_name):
            if self.data_d is None:
-                raise ValueError("Please prepare [{:s}] first.".format(field_name))
+                raise ValueError(
+                    "Please prepare [{:s}] first.".format(field_name))
            merge, _ = self._prepare_daily_quarterly([field_name])
            is_quarterly = False
        else:
            if self.data_q is None:
-                raise ValueError("Please prepare [{:s}] first.".format(field_name))
+                raise ValueError(
+                    "Please prepare [{:s}] first.".format(field_name))
            _, merge = self._prepare_daily_quarterly([field_name])
            is_quarterly = True
-
+
        merge = merge.loc[:, pd.IndexSlice[:, field_name]]
        merge.columns = merge.columns.droplevel(level=1)
-        self.append_df(merge, field_name, is_quarterly=is_quarterly)  # whether it contains only trade days is decided by existing data.
-
+        self.append_df(
+            merge, field_name, is_quarterly=is_quarterly
+        )  # whether it contains only trade days is decided by existing data.
+
        if is_quarterly:
            df_ann = merge_q.loc[:, pd.IndexSlice[:, self.ANN_DATE_FIELD_NAME]]
            df_ann.columns = df_ann.columns.droplevel(level='field')
            df_expanded = align(merge, df_ann, self.dates)
            self.append_df(df_expanded, field_name, is_quarterly=False)
        return True
-
-    def add_formula(self, field_name, formula, is_quarterly, overwrite=True,
-                    formula_func_name_style='camel', data_api=None,
+
+    def add_formula(self,
+                    field_name,
+                    formula,
+                    is_quarterly,
+                    overwrite=True,
+                    formula_func_name_style='camel',
+                    data_api=None,
                    within_index=True):
        """
        Add a new field, which is calculated using existing fields.
@@ -913,24 +1034,28 @@ def add_formula(self, field_name, formula, is_quarterly, overwrite=True,
 
        """
        if data_api is not None:
            self.data_api = data_api
-
+
        if field_name in self.fields:
            if overwrite:
                self.remove_field(field_name)
                print("Field [{:s}] is overwritten.".format(field_name))
            else:
-                raise ValueError("Add formula failed: name [{:s}] exists. Try another name.".format(field_name))
+                raise ValueError(
+                    "Add formula failed: name [{:s}] exists. Try another name.".
+                    format(field_name))
        elif self._is_predefined_field(field_name):
-            raise ValueError("[{:s}] is already a pre-defined field. Please use another name.".format(field_name))
+            raise ValueError(
+                "[{:s}] is already a pre-defined field. Please use another name.".
+                format(field_name))
+
        parser = Parser()
        parser.set_capital(formula_func_name_style)
-
+
        expr = parser.parse(formula)
-
+
        var_df_dic = dict()
        var_list = expr.variables()
-
+
        # TODO: users do not need to prepare data before add_formula
        if not self.fields:
            self.fields.extend(var_list)
@@ -943,23 +1068,35 @@ def add_formula(self, field_name, formula, is_quarterly, overwrite=True,
                success = self.add_field(var)
                if not success:
                    return
-
+
        for var in var_list:
            if self._is_quarter_field(var):
-                df_var = self.get_ts_quarter(var, start_date=self.extended_start_date_q)
+                df_var = self.get_ts_quarter(
+                    var, start_date=self.extended_start_date_q)
            else:
                # must use extended date. Default is start_date
-                df_var = self.get_ts(var, start_date=self.extended_start_date_d, end_date=self.end_date)
-
+                df_var = self.get_ts(
+                    var,
+                    start_date=self.extended_start_date_d,
+                    end_date=self.end_date)
+
            var_df_dic[var] = df_var
-
+
        # TODO: send ann_date into expr.evaluate. We assume that ann_date of all fields of a symbol is the same
        df_ann = self._get_ann_df()
        if within_index:
-            df_index_member = self.get_ts('index_member', start_date=self.extended_start_date_d, end_date=self.end_date)
-            df_eval = parser.evaluate(var_df_dic, ann_dts=df_ann, trade_dts=self.dates, index_member=df_index_member)
+            df_index_member = self.get_ts(
+                'index_member',
+                start_date=self.extended_start_date_d,
+                end_date=self.end_date)
+            df_eval = parser.evaluate(
+                var_df_dic,
+                ann_dts=df_ann,
+                trade_dts=self.dates,
+                index_member=df_index_member)
        else:
-            df_eval = parser.evaluate(var_df_dic, ann_dts=df_ann, trade_dts=self.dates)
+            df_eval = parser.evaluate(
+                var_df_dic, ann_dts=df_ann, trade_dts=self.dates)
 
        self.append_df(df_eval, field_name, is_quarterly=is_quarterly)
 
@@ -967,7 +1104,7 @@ def add_formula(self, field_name, formula, is_quarterly, overwrite=True,
            df_ann = self._get_ann_df()
            df_expanded = align(df_eval, df_ann, self.dates)
            self.append_df(df_expanded, field_name, is_quarterly=False)
-
+
    def append_df(self, df, field_name, is_quarterly=False):
        """
        Append DataFrame to existing multi-index DataFrame and add corresponding field name.
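# Illustrative sketch, not part of the diff: the kind of column-wise arithmetic
# a formula such as 'close_adj / Delay(close_adj, 1) - 1' reduces to once
# add_formula has fetched each variable as a DataFrame indexed by trade_date
# with one column per symbol. Symbols and prices are hypothetical; Delay(x, 1)
# behaves like a one-row shift.
import pandas as pd

close_adj = pd.DataFrame(
    {'600030.SH': [10.0, 10.5, 11.0], '000001.SZ': [8.0, 8.2, 8.4]},
    index=[20180423, 20180424, 20180425])
daily_return = close_adj / close_adj.shift(1) - 1  # second row: 0.05 and 0.025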
@@ -991,16 +1128,19 @@ def append_df(self, df, field_name, is_quarterly=False):
        elif isinstance(df, pd.Series):
            df = pd.DataFrame(df)
        else:
-            raise ValueError("Data to be appended must be in pandas format. But we have {}".format(type(df)))
-
+            raise ValueError(
+                "Data to be appended must be in pandas format. But we have {}".
+                format(type(df)))
+
        if is_quarterly:
            the_data = self.data_q
        else:
            the_data = self.data_d
-
+
        exist_symbols = the_data.columns.levels[0]
        if len(df.columns) < len(exist_symbols):
-            df2 = pd.DataFrame(index=df.index, columns=exist_symbols, data=np.nan)
+            df2 = pd.DataFrame(
+                index=df.index, columns=exist_symbols, data=np.nan)
            df2.update(df)
            df = df2
        elif len(df.columns) > len(exist_symbols):
@@ -1010,11 +1150,12 @@ def append_df(self, df, field_name, is_quarterly=False):
        #the_data = apply_in_subprocess(pd.merge, args=(the_data, df),
        #                               kwargs={'left_index': True, 'right_index': True, 'how': 'left'})  # runs in *only* one process
-        the_data = pd.merge(the_data, df, left_index=True, right_index=True, how='left')
+        the_data = pd.merge(
+            the_data, df, left_index=True, right_index=True, how='left')
        the_data = the_data.sort_index(axis=1)
        #merge = the_data.join(df, how='left')  # left: keep index of existing data unchanged
        #sort_columns(the_data)
-
+
        if is_quarterly:
            self.data_q = the_data
        else:
@@ -1041,27 +1182,29 @@ def remove_field(self, field_names):
            field_names = field_names.split(',')
        else:
            raise ValueError("field_names must be str separated by comma.")
-
+
        for field_name in field_names:
            # parameter validation
            if field_name not in self.fields:
-                print("Field name [{:s}] does not exist. Stop remove_field.".format(field_name))
+                print("Field name [{:s}] does not exist. Stop remove_field.".
+                      format(field_name))
                return
-
+
            if self._is_daily_field(field_name):
                is_quarterly = False
            elif self._is_quarter_field(field_name):
                is_quarterly = True
            else:
-                print("Field name [{}] is a pre-defined field, ignore.".format(field_name))
+                print("Field name [{}] is a pre-defined field, ignore.".format(
+                    field_name))
                return
-
+
            # remove field data
-
+
            self.data_d = self.data_d.drop(field_name, axis=1, level=1)
            if is_quarterly:
                self.data_q = self.data_q.drop(field_name, axis=1, level=1)
-
+
            # remove field name from list
            self.fields.remove(field_name)
            if is_quarterly:
@@ -1095,25 +1238,26 @@ def get(self, symbol="", start_date=0, end_date=0, fields=""):
 
        """
        sep = ','
-
+
        if not fields:
            fields = slice(None)  # self.fields
        else:
            fields = fields.split(sep)
-
+
        if not symbol:
            symbol = slice(None)  # this is 3X faster than symbol = self.symbol
        else:
            symbol = symbol.split(sep)
-
+
        if not start_date:
            start_date = self.start_date
        if not end_date:
            end_date = self.end_date
-
-        res = self.data_d.loc[pd.IndexSlice[start_date: end_date], pd.IndexSlice[symbol, fields]]
+
+        res = self.data_d.loc[pd.IndexSlice[start_date:end_date],
+                              pd.IndexSlice[symbol, fields]]
        return res
-
+
    def get_snapshot(self, snapshot_date, symbol="", fields=""):
        """
        Get snapshot of given fields and symbol at snapshot_date.
@@ -1144,17 +1288,21 @@ def get_snapshot(self, snapshot_date, symbol="", fields=""):
            else:
                return df
 
-        res = self.get(symbol=symbol, start_date=snapshot_date, end_date=snapshot_date, fields=fields)
+        res = self.get(
+            symbol=symbol,
+            start_date=snapshot_date,
+            end_date=snapshot_date,
+            fields=fields)
        if res is None:
-            print("No data for date={}, fields={}, symbol={}".format(snapshot_date, fields, symbol))
+            print("No data 
for date={}, fields={}, symbol={}".format(
+                snapshot_date, fields, symbol))
            return
-
+
        res = res.stack(level='symbol', dropna=False)
        res.index = res.index.droplevel(level=self.TRADE_DATE_FIELD_NAME)
-
+
        return res
-
+
    def _get_ann_df(self):
        """
        Query announcement date of financial statements of all securities.
@@ -1170,15 +1318,17 @@ def _get_ann_df(self):
            return None
        df_ann = self.data_q.loc[:, pd.IndexSlice[:, self.ANN_DATE_FIELD_NAME]]
        df_ann.columns = df_ann.columns.droplevel(level='field')
-
+
        return df_ann
 
    def get_symbol(self, symbol, start_date=0, end_date=0, fields=""):
-        res = self.get(symbol, start_date=start_date, end_date=end_date, fields=fields)
+        res = self.get(
+            symbol, start_date=start_date, end_date=end_date, fields=fields)
        if res is None:
-            raise ValueError("No data for "
-                             "start_date={}, end_date={}, field={}, symbol={}".format(start_date, end_date,
-                                                                                      fields, symbol))
+            raise ValueError(
+                "No data for "
+                "start_date={}, end_date={}, field={}, symbol={}".format(
+                    start_date, end_date, fields, symbol))
        res.columns = res.columns.droplevel(level='symbol')
        return res
 
@@ -1190,18 +1340,24 @@ def get_ts_quarter(self, field, symbol="", start_date=0, end_date=0):
            symbol = self.symbol
        else:
            symbol = symbol.split(sep)
-
+
        if not start_date:
            start_date = self.start_date
        if not end_date:
            end_date = self.end_date
-
+
        df_ref_quarterly = self.data_q.loc[:, pd.IndexSlice[symbol, field]]
-        df_ref_quarterly.columns = df_ref_quarterly.columns.droplevel(level='field')
-
+        df_ref_quarterly.columns = df_ref_quarterly.columns.droplevel(
+            level='field')
+
        return df_ref_quarterly
-
-    def get_ts(self, field, symbol="", start_date=0, end_date=0, keep_level=False):
+
+    def get_ts(self,
+               field,
+               symbol="",
+               start_date=0,
+               end_date=0,
+               keep_level=False):
        """
        Get time series data of single field.
 
@@ -1222,18 +1378,20 @@ def get_ts(self, field, symbol="", start_date=0, end_date=0, keep_level=False):
        Index is int date, column is symbol.
 
        """
-        res = self.get(symbol, start_date=start_date, end_date=end_date, fields=field)
+        res = self.get(
+            symbol, start_date=start_date, end_date=end_date, fields=field)
        if res is None:
-            print("No data for start_date={}, end_date={}, field={}, symbol={}".format(start_date,
-                                                                                       end_date, field, symbol))
+            print(
+                "No data for start_date={}, end_date={}, field={}, symbol={}".
+ format(start_date, end_date, field, symbol)) raise ValueError return - if not keep_level and len(res.columns) and len(field.split(','))==1: + if not keep_level and len(res.columns) and len(field.split(',')) == 1: res.columns = res.columns.droplevel(level='field') return res - + # -------------------------------------------------------------------------------------------------------- # DataView I/O @staticmethod @@ -1247,13 +1405,13 @@ def _load_h5(fp): """ h5 = pd.HDFStore(fp) - + res = dict() for key in h5.keys(): res[key] = h5.get(key) - + h5.close() - + return res def _process_data(self, large_memory=False): @@ -1266,7 +1424,6 @@ def _process_data(self, large_memory=False): b = (a / a.shift(1)).fillna(1.0) self.append_df(b, '_daily_adjust_factor', is_quarterly=False) - t = self.get_ts("_limit") if t is None or len(t.columns) == 0: dates = self.dates @@ -1274,8 +1431,9 @@ def _process_data(self, large_memory=False): before_first_day = dates[mask][-1] open = self.get_ts('open') - preclose = self.get_ts('close', start_date=before_first_day).shift(1) - limit = np.abs((open - preclose)/preclose) + preclose = self.get_ts( + 'close', start_date=before_first_day).shift(1) + limit = np.abs((open - preclose) / preclose) self.append_df(limit, "_limit", is_quarterly=False) # Snapshot dict may use large memory. @@ -1283,7 +1441,6 @@ def _process_data(self, large_memory=False): if large_memory: self.update_snapshot() - def update_snapshot(self): dates = self.data_d.index.values df = self.data_d.T.unstack() @@ -1307,8 +1464,9 @@ def load_dataview(self, folder_path='.', large_memory=True): path_meta_data = os.path.join(folder_path, 'meta_data.json') path_data = os.path.join(folder_path, 'data.hd5') if not (os.path.exists(path_meta_data) and os.path.exists(path_data)): - raise IOError("There is no data file under directory {}".format(folder_path)) - + raise IOError( + "There is no data file under directory {}".format(folder_path)) + meta_data = jutil.read_json(path_meta_data) dic = self._load_h5(path_data) self.data_d = dic.get('/data_d', None) @@ -1335,19 +1493,29 @@ def save_dataview(self, folder_path): abs_folder = os.path.abspath(folder_path) meta_path = os.path.join(folder_path, 'meta_data.json') data_path = os.path.join(folder_path, 'data.hd5') - - data_to_store = {'data_d': self.data_d, 'data_q': self.data_q, - 'data_benchmark': self.data_benchmark, 'data_inst': self.data_inst} - data_to_store = {k: v for k, v in data_to_store.items() if v is not None} - meta_data_to_store = {key: self.__dict__[key] for key in self.meta_data_list} + + data_to_store = { + 'data_d': self.data_d, + 'data_q': self.data_q, + 'data_benchmark': self.data_benchmark, + 'data_inst': self.data_inst + } + data_to_store = { + k: v + for k, v in data_to_store.items() if v is not None + } + meta_data_to_store = { + key: self.__dict__[key] + for key in self.meta_data_list + } print("\nStore data...") jutil.save_json(meta_data_to_store, meta_path) self._save_h5(data_path, data_to_store) - - print ("Dataview has been successfully saved to:\n" - + abs_folder + "\n\n" - + "You can load it with load_dataview('{:s}')".format(abs_folder)) + + print( + "Dataview has been successfully saved to:\n" + abs_folder + "\n\n" + + "You can load it with load_dataview('{:s}')".format(abs_folder)) @staticmethod def _save_h5(fp, dic): @@ -1362,8 +1530,9 @@ def _save_h5(fp, dic): """ import warnings - warnings.filterwarnings('ignore', category=pd.io.pytables.PerformanceWarning) - + warnings.filterwarnings( + 'ignore', 
category=pd.io.pytables.PerformanceWarning)
+
        jutil.create_dir(fp)
        h5 = pd.HDFStore(fp, complevel=9, complib='blosc')
        for key, value in dic.items():
@@ -1393,9 +1562,10 @@ class EventDataView(object):
        index is date, columns is symbol-field MultiIndex
 
    """
+
    def __init__(self):
        self.data_api = None
-
+
        self.universe = ""
        self.symbol = []
        self.benchmark = ""
@@ -1405,45 +1575,49 @@ def __init__(self):
        self.fields = []
        self.freq = 1
        self.all_price = True
-
-        self.meta_data_list = ['start_date', 'end_date',
-                               'extended_start_date_d',
-                               'freq', 'fields', 'symbol', 'universe', 'benchmark',
-                               'custom_daily_fields']
+
+        self.meta_data_list = [
+            'start_date', 'end_date', 'extended_start_date_d', 'freq',
+            'fields', 'symbol', 'universe', 'benchmark', 'custom_daily_fields'
+        ]
        self.adjust_mode = 'post'
-
+
        self.data_d = None
        self.data_q = None
        self._data_benchmark = None
        self._data_inst = None
        self.data_custom = None
        # self._data_group = None
-
+
        common_list = {'symbol', 'start_date', 'end_date'}
-        market_bar_list = {'open', 'high', 'low', 'close', 'volume', 'turnover', 'vwap', 'oi'}
-        market_tick_list = {'volume', 'oi',
-                            'askprice1', 'askprice2', 'askprice3', 'askprice4', 'askprice5',
-                            'bidprice1', 'bidprice2', 'bidprice3', 'bidprice4', 'bidprice5',
-                            'askvolume1', 'askvolume2', 'askvolume3', 'askvolume4', 'askvolume5',
-                            'bidvolume1', 'bidvolume2', 'bidvolume3', 'bidvolume4', 'bidvolume5'}
+        market_bar_list = {
+            'open', 'high', 'low', 'close', 'volume', 'turnover', 'vwap', 'oi'
+        }
+        market_tick_list = {
+            'volume', 'oi', 'askprice1', 'askprice2', 'askprice3', 'askprice4',
+            'askprice5', 'bidprice1', 'bidprice2', 'bidprice3', 'bidprice4',
+            'bidprice5', 'askvolume1', 'askvolume2', 'askvolume3',
+            'askvolume4', 'askvolume5', 'bidvolume1', 'bidvolume2',
+            'bidvolume3', 'bidvolume4', 'bidvolume5'
+        }
        # fields map
        # TODO: 'freq' is not in market_daily_fields yet.
        self.market_daily_fields = \
            {'open', 'high', 'low', 'close', 'volume', 'turnover', 'vwap', 'oi', 'trade_status',
             'open_adj', 'high_adj', 'low_adj', 'close_adj', 'vwap_adj', 'index_member', 'index_weight'}
        self.group_fields = {'sw1', 'sw2', 'sw3', 'sw4', 'zz1', 'zz2'}
-        self .custom_daily_fields = []
-
+        self.custom_daily_fields = []
+
        # const
        self.TRADE_STATUS_FIELD_NAME = 'trade_status'
        self.TRADE_DATE_FIELD_NAME = 'trade_date'
-
+
    # --------------------------------------------------------------------------------------------------------
    # Properties
    @property
    def data_benchmark(self):
        return self._data_benchmark
-
+
    @property
    def data_inst(self):
        """
@@ -1454,13 +1628,15 @@ def data_inst(self):
 
        """
        return self._data_inst
-
+
    @data_benchmark.setter
    def data_benchmark(self, df_new):
        if self.data_d is not None and df_new.shape[0] != self.data_d.shape[0]:
-            raise ValueError("You must provide a DataFrame with the same shape as data_benchmark.")
+            raise ValueError(
+                "You must provide a DataFrame with the same shape as data_benchmark."
+            )
        self._data_benchmark = df_new
-
+
    @property
    def dates(self):
        """
@@ -1475,12 +1651,15 @@ def dates(self):
        if self.data_d is not None:
            res = self.data_d.index.values
        elif self.data_api is not None:
-            res = self.data_api.query_trade_dates(self.extended_start_date_d, self.end_date)
+            res = self.data_api.query_trade_dates(self.extended_start_date_d,
+                                                  self.end_date)
        else:
-            raise ValueError("Cannot get dates array when neither data nor data_api exists.")
-
+            raise ValueError(
+                "Cannot get dates array when neither data nor data_api exists."
+ ) + return res - + # -------------------------------------------------------------------------------------------------------- # Fields def _is_quarter_field(self, field_name): @@ -1503,7 +1682,7 @@ def _is_quarter_field(self, field_name): or field_name in self.fin_indicator or field_name in self.custom_quarterly_fields) return res - + def _is_daily_field(self, field_name): """ Check whether a field name is daily frequency. @@ -1523,7 +1702,7 @@ def _is_daily_field(self, field_name): or field_name in self.custom_daily_fields or field_name in self.group_fields) return flag - + def _is_predefined_field(self, field_name): """ Check whether a field name can be recognized. @@ -1538,8 +1717,9 @@ def _is_predefined_field(self, field_name): bool """ - return self._is_quarter_field(field_name) or self._is_daily_field(field_name) - + return self._is_quarter_field(field_name) or self._is_daily_field( + field_name) + def _get_fields(self, field_type, fields, complement=False, append=False): """ Get list of fields that are in ref_quarterly_fields. @@ -1556,43 +1736,44 @@ def _get_fields(self, field_type, fields, complement=False, append=False): list """ - pool_map = {'market_daily': self.market_daily_fields, - 'group': self.group_fields} + pool_map = { + 'market_daily': self.market_daily_fields, + 'group': self.group_fields + } pool_map['daily'] = set.union(pool_map['market_daily'], pool_map['group'], self.custom_daily_fields) - + pool = pool_map.get(field_type, None) if pool is None: raise NotImplementedError("field_type = {:s}".format(field_type)) - + s = set.intersection(set(pool), set(fields)) if not s: return [] - + if complement: s = set(fields) - s - + if field_type == 'market_daily' and self.all_price: # turnover will not be adjusted s.update({'open', 'high', 'close', 'low', 'vwap'}) - + if append: s.add('symbol') if field_type == 'market_daily' or field_type == 'ref_daily': s.add('trade_date') if field_type == 'market_daily': s.add(self.TRADE_STATUS_FIELD_NAME) - elif (field_type == 'income' - or field_type == 'balance_sheet' + elif (field_type == 'income' or field_type == 'balance_sheet' or field_type == 'cash_flow' or field_type == 'fin_indicator'): s.add(self.ANN_DATE_FIELD_NAME) s.add(self.REPORT_DATE_FIELD_NAME) - + l = list(s) return l - + # -------------------------------------------------------------------------------------------------------- # Prepare data def init_from_config(self, props, data_api): @@ -1609,31 +1790,36 @@ def init_from_config(self, props, data_api): """ # data_api.init_from_config(props) self.data_api = data_api - + sep = ',' - + # initialize parameters self.start_date = props['start_date'] - self.extended_start_date_d = jutil.shift(self.start_date, n_weeks=-8) # query more data + self.extended_start_date_d = jutil.shift( + self.start_date, n_weeks=-8) # query more data self.extended_start_date_q = jutil.shift(self.start_date, n_weeks=-130) self.end_date = props['end_date'] self.all_price = props.get('all_price', True) self.freq = props.get('freq', 1) - + # get and filter fields fields = props.get('fields', []) if fields: fields = props['fields'].split(sep) - self.fields = [field for field in fields if self._is_predefined_field(field)] + self.fields = [ + field for field in fields if self._is_predefined_field(field) + ] if len(self.fields) < len(fields): - print("Field name [{}] not valid, ignore.".format(set.difference(set(fields), set(self.fields)))) - + print("Field name [{}] not valid, ignore.".format( + set.difference(set(fields), set(self.fields)))) + # 
append additional fields
        if self.all_price:
-            self.fields.extend(['open_adj', 'high_adj', 'low_adj', 'close_adj',
-                                'open', 'high', 'low', 'close',
-                                'vwap', 'vwap_adj'])
-
+            self.fields.extend([
+                'open_adj', 'high_adj', 'low_adj', 'close_adj', 'open', 'high',
+                'low', 'close', 'vwap', 'vwap_adj'
+            ])
+
        # initialize universe/symbol
        universe = props.get('universe', "")
        symbol = props.get('symbol', "")
@@ -1647,7 +1833,9 @@ def init_from_config(self, props, data_api):
            self.universe = univ_list
            symbols_list = []
            for univ in univ_list:
-                symbols_list.extend(data_api.query_index_member(univ, self.extended_start_date_d, self.end_date))
+                symbols_list.extend(
+                    data_api.query_index_member(
+                        univ, self.extended_start_date_d, self.end_date))
            self.symbol = sorted(list(set(symbols_list)))
        else:
            self.symbol = sorted(symbol.split(sep))
@@ -1659,42 +1847,52 @@ def init_from_config(self, props, data_api):
            print("More than one universe is used: {}, "
                  "use the first one ({}) as index by default. "
                  "If you want to use other benchmark, "
-                  "please specify benchmark in configs.".format(repr(self.universe), self.universe[0]))
+                  "please specify benchmark in configs.".format(
+                      repr(self.universe), self.universe[0]))
            self.benchmark = self.universe[0]
-
+
        print("Initialize config success.")
 
-    def distributed_query(self, query_func_name, symbol, start_date, end_date, limit=100000, **kwargs):
+    def distributed_query(self,
+                          query_func_name,
+                          symbol,
+                          start_date,
+                          end_date,
+                          limit=100000,
+                          **kwargs):
        n_symbols = len(symbol.split(','))
        dates = self.data_api.query_trade_dates(start_date, end_date)
        n_days = len(dates)
-
+
        if n_symbols * n_days > limit:
            n = limit // n_symbols
-
+
            df_list = []
            i = 0
            pos1, pos2 = n * i, n * (i + 1) - 1
            while pos2 < n_days:
                print(pos2)
-                df, msg = getattr(self.data_api, query_func_name)(symbol=symbol,
-                                                                  start_date=dates[pos1], end_date=dates[pos2],
-                                                                  **kwargs)
+                df, msg = getattr(self.data_api, query_func_name)(
+                    symbol=symbol,
+                    start_date=dates[pos1],
+                    end_date=dates[pos2],
+                    **kwargs)
                df_list.append(df)
                i += 1
                pos1, pos2 = n * i, n * (i + 1) - 1
            if pos1 < n_days:
-                df, msg = getattr(self.data_api, query_func_name)(symbol=symbol,
-                                                                  start_date=dates[pos1], end_date=dates[-1],
-                                                                  **kwargs)
+                df, msg = getattr(self.data_api, query_func_name)(
+                    symbol=symbol,
+                    start_date=dates[pos1],
+                    end_date=dates[-1],
+                    **kwargs)
                df_list.append(df)
            df = pd.concat(df_list, axis=0)
        else:
-            df, msg = getattr(self.data_api, query_func_name)(symbol,
-                                                              start_date=start_date, end_date=end_date,
-                                                              **kwargs)
+            df, msg = getattr(self.data_api, query_func_name)(
+                symbol, start_date=start_date, end_date=end_date, **kwargs)
        return df, msg
-
+
    def prepare_data(self):
        """Prepare data for the FIRST time."""
        # prepare benchmark and group
        data_d = self._prepare_daily_quarterly(self.fields)
        self.data_d = data_d
        #self._align_and_merge_q_into_d()
-
+
        print("Query instrument info...")
        self._prepare_inst_info()
-
+
        if self.benchmark:
            print("Query benchmark...")
            self._data_benchmark = self._prepare_benchmark()
-
+
        print("Data has been successfully prepared.")
-
+
    @staticmethod
    def _process_index_co(df, index_name):
        df = df.astype(dtype={index_name: int})
        df = df.drop_duplicates(subset=['symbol', index_name])
        return df
-
+
    def _prepare_daily_quarterly(self, fields):
        """
        Query and process data from data_api.
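# Illustrative sketch, not part of the diff: the windowing arithmetic that
# distributed_query above uses to keep each request at or under `limit` cells
# (symbols x days). The numbers are hypothetical.
n_symbols, n_days, limit = 400, 1000, 100000
n = limit // n_symbols          # at most 250 trade dates per request
windows = []
i = 0
pos1, pos2 = 0, n - 1
while pos2 < n_days:            # full-size windows, mirroring the loop above
    windows.append((pos1, pos2))
    i += 1
    pos1, pos2 = n * i, n * (i + 1) - 1
if pos1 < n_days:               # remainder window, if any days are left over
    windows.append((pos1, n_days - 1))
# windows == [(0, 249), (250, 499), (500, 749), (750, 999)]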
@@ -1734,11 +1932,11 @@ def _prepare_daily_quarterly(self, fields): """ if not fields: return None, None - + # query data print("Query data - query...") daily_list = self._query_data(self.symbol, fields) - + def pivot_and_sort(df, index_name): df = self._process_index_co(df, index_name) df = df.pivot(index=index_name, columns='symbol') @@ -1748,17 +1946,22 @@ def pivot_and_sort(df, index_name): df = df.sort_index(axis=1, level=col_names) df.index.name = index_name return df - + multi_daily = None if daily_list: - daily_list_pivot = [pivot_and_sort(df, self.TRADE_DATE_FIELD_NAME) for df in daily_list] - multi_daily = self._merge_data(daily_list_pivot, self.TRADE_DATE_FIELD_NAME) + daily_list_pivot = [ + pivot_and_sort(df, self.TRADE_DATE_FIELD_NAME) + for df in daily_list + ] + multi_daily = self._merge_data(daily_list_pivot, + self.TRADE_DATE_FIELD_NAME) # use self.dates as index because original data have weekends - multi_daily = self._fill_missing_idx_col(multi_daily, index=self.dates, symbols=self.symbol) + multi_daily = self._fill_missing_idx_col( + multi_daily, index=self.dates, symbols=self.symbol) print("Query data - daily fields prepared.") - + return multi_daily - + def _query_data(self, symbol, fields): """ Query data using different APIs, then store them in dict. @@ -1778,39 +1981,55 @@ def _query_data(self, symbol, fields): """ sep = ',' symbol_str = sep.join(symbol) - + if self.freq == 1: daily_list = [] - + # TODO : use fields = {field: kwargs} to enable params - fields_market_daily = self._get_fields('market_daily', fields, append=True) + fields_market_daily = self._get_fields( + 'market_daily', fields, append=True) if fields_market_daily: - print("NOTE: price adjust method is [{:s} adjust]".format(self.adjust_mode)) + print("NOTE: price adjust method is [{:s} adjust]".format( + self.adjust_mode)) # no adjust prices and other market daily fields - df_daily, msg1 = self.distributed_query('daily', symbol_str, - start_date=self.extended_start_date_d, end_date=self.end_date, - adjust_mode=None, fields=sep.join(fields_market_daily), limit=100000) + df_daily, msg1 = self.distributed_query( + 'daily', + symbol_str, + start_date=self.extended_start_date_d, + end_date=self.end_date, + adjust_mode=None, + fields=sep.join(fields_market_daily), + limit=100000) #df_daily, msg1 = self.data_api.daily(symbol_str, start_date=self.extended_start_date_d, end_date=self.end_date, # adjust_mode=None, fields=sep.join(fields_market_daily)) - + if self.all_price: adj_cols = ['open', 'high', 'low', 'close', 'vwap'] # adjusted prices #df_daily_adjust, msg11 = self.data_api.daily(symbol_str, start_date=self.extended_start_date_d, end_date=self.end_date, # adjust_mode=self.adjust_mode, fields=','.join(adj_cols)) - df_daily_adjust, msg1 = self.distributed_query('daily', symbol_str, - start_date=self.extended_start_date_d, end_date=self.end_date, - adjust_mode=self.adjust_mode, fields=sep.join(fields_market_daily), limit=100000) - - df_daily = pd.merge(df_daily, df_daily_adjust, how='outer', - on=['symbol', 'trade_date'], suffixes=('', '_adj')) + df_daily_adjust, msg1 = self.distributed_query( + 'daily', + symbol_str, + start_date=self.extended_start_date_d, + end_date=self.end_date, + adjust_mode=self.adjust_mode, + fields=sep.join(fields_market_daily), + limit=100000) + + df_daily = pd.merge( + df_daily, + df_daily_adjust, + how='outer', + on=['symbol', 'trade_date'], + suffixes=('', '_adj')) daily_list.append(df_daily.loc[:, fields_market_daily]) - + else: raise NotImplementedError("freq = 
{}".format(self.freq)) - + return daily_list - + @staticmethod def _merge_data(dfs, index_name='trade_date'): """ @@ -1831,104 +2050,129 @@ def _merge_data(dfs, index_name='trade_date'): """ # dfs = [df for df in dfs if df is not None] - + merge = pd.concat(dfs, axis=1, join='outer') - + # drop duplicated columns. ONE LINE EFFICIENT version mask_duplicated = merge.columns.duplicated() if np.any(mask_duplicated): # print("Duplicated columns found. Dropped.") merge = merge.loc[:, ~mask_duplicated] - + # if merge.isnull().sum().sum() > 0: # print "WARNING: nan in final merged data. NO fill" # merge.fillna(method='ffill', inplace=True) - + merge = merge.sort_index(axis=1, level=['symbol', 'field']) merge.index.name = index_name - + return merge - + def _fill_missing_idx_col(self, df, index=None, symbols=None): if index is None: index = df.index if symbols is None: symbols = self.symbol fields = df.columns.levels[1] - - if len(fields) * len(self.symbol) != len(df.columns) or len(index) != len(df.index): - cols_multi = pd.MultiIndex.from_product([symbols, fields], names=['symbol', 'field']) + + if len(fields) * len(self.symbol) != len( + df.columns) or len(index) != len(df.index): + cols_multi = pd.MultiIndex.from_product( + [symbols, fields], names=['symbol', 'field']) cols_multi = cols_multi.sort_values() - df_final = pd.DataFrame(index=index, columns=cols_multi, data=np.nan) + df_final = pd.DataFrame( + index=index, columns=cols_multi, data=np.nan) df_final.index.name = df.index.name - + df_final.update(df) - + # idx_diff = sorted(set(df_final.index) - set(df.index)) - col_diff = sorted(set(df_final.columns.levels[0].values) - set(df.columns.levels[0].values)) - print ("WARNING: some data is unavailable: " - # + "\n At index " + ', '.join(idx_diff) - + "\n At fields " + ', '.join(col_diff)) + col_diff = sorted( + set(df_final.columns.levels[0].values) - + set(df.columns.levels[0].values)) + print("WARNING: some data is unavailable: " + # + "\n At index " + ', '.join(idx_diff) + + "\n At fields " + ', '.join(col_diff)) return df_final else: return df - + def _prepare_adj_factor(self): """Query and append daily adjust factor for prices.""" mask_stocks = self.data_inst['inst_type'] == 1 - if mask_stocks.sum() == 0: - return - symbol_stocks = self.data_inst.loc[mask_stocks].index.values - symbol_str = ','.join(symbol_stocks) - df_adj = self.data_api.query_adj_factor_daily(symbol_str, - start_date=self.extended_start_date_d, end_date=self.end_date, div=False) + if mask_stocks.sum() > 0: + symbol_stocks = self.data_inst.loc[mask_stocks].index.values + symbol_str = ','.join(symbol_stocks) + df_adj = self.data_api.query_adj_factor_daily( + symbol_str, + start_date=self.extended_start_date_d, + end_date=self.end_date, + div=False) + else: + dt_idx = self.data_api.query_trade_dates( + start_date=self.extended_start_date_d, + end_date=self.end_date, + ) + df_adj = pd.DataFrame( + index=dt_idx, columns=self.data_inst.index).fillna(1) self.append_df(df_adj, 'adjust_factor', is_quarterly=False) - + def _prepare_comp_info(self): # if a symbol is index member of any one universe, its value of index_member will be 1.0 res = dict() for univ in self.universe: - df = self.data_api.query_index_member_daily(univ, self.extended_start_date_d, self.end_date) + df = self.data_api.query_index_member_daily( + univ, self.extended_start_date_d, self.end_date) res[univ] = df df_res = pd.concat(res, axis=0) - df = df_res.groupby(by='trade_date').apply(lambda df: df.any(axis=0)).astype(float) + df = df_res.groupby( + 
+        df = df_res.groupby(
+            by='trade_date').apply(lambda df: df.any(axis=0)).astype(float)
         self.append_df(df, 'index_member', is_quarterly=False)
-        
+
         # use weights of the first universe
-        df_weights = self.data_api.query_index_weights_daily(self.universe[0], self.extended_start_date_d, self.end_date)
+        df_weights = self.data_api.query_index_weights_daily(
+            self.universe[0], self.extended_start_date_d, self.end_date)
         self.append_df(df_weights, 'index_weight', is_quarterly=False)
-    
+
     def _prepare_inst_info(self):
-        res = self.data_api.query_inst_info(symbol=','.join(self.symbol),
-                                            fields='symbol,inst_type,name,list_date,'
-                                                   'delist_date,product,pricetick,multiplier,'
-                                                   'buylot,setlot',
-                                            inst_type="")
+        res = self.data_api.query_inst_info(
+            symbol=','.join(self.symbol),
+            fields='symbol,inst_type,name,list_date,'
+            'delist_date,product,pricetick,multiplier,'
+            'buylot,setlot',
+            inst_type="")
         self._data_inst = res
-    
+
     def _prepare_group(self, group_fields):
-        data_map = {'sw1': ('SW', 1),
-                    'sw2': ('SW', 2),
-                    'sw3': ('SW', 3),
-                    'sw4': ('SW', 4),
-                    'zz1': ('ZZ', 1),
-                    'zz2': ('ZZ', 2)}
+        data_map = {
+            'sw1': ('SW', 1),
+            'sw2': ('SW', 2),
+            'sw3': ('SW', 3),
+            'sw4': ('SW', 4),
+            'zz1': ('ZZ', 1),
+            'zz2': ('ZZ', 2)
+        }
         for field in group_fields:
             type_, level = data_map[field]
-            df = self.data_api.query_industry_daily(symbol=','.join(self.symbol),
-                                                    start_date=self.extended_start_date_q, end_date=self.end_date,
-                                                    type_=type_, level=level)
+            df = self.data_api.query_industry_daily(
+                symbol=','.join(self.symbol),
+                start_date=self.extended_start_date_q,
+                end_date=self.end_date,
+                type_=type_,
+                level=level)
             self.append_df(df, field, is_quarterly=False)
-    
+
     def _prepare_benchmark(self):
-        df_bench, msg = self.data_api.daily(self.benchmark,
-                                            start_date=self.extended_start_date_d, end_date=self.end_date,
-                                            adjust_mode=self.adjust_mode,
-                                            fields='trade_date,symbol,close,vwap,volume,turnover')
+        df_bench, msg = self.data_api.daily(
+            self.benchmark,
+            start_date=self.extended_start_date_d,
+            end_date=self.end_date,
+            adjust_mode=self.adjust_mode,
+            fields='trade_date,symbol,close,vwap,volume,turnover')
         # TODO: we want more than just close price of benchmark
         df_bench = df_bench.set_index('trade_date').loc[:, ['close']]
         return df_bench
-    
+
     # --------------------------------------------------------------------------------------------------------
     # Add/Remove Fields&Formulas
     def _add_field(self, field_name, is_quarterly=None):
@@ -1938,7 +2182,8 @@ def _add_field(self, field_name, is_quarterly=None):
     
     def _add_symbol(self, symbol_name):
         if symbol_name in self.symbol:
-            print("symbol [{:s}] already exists, add_symbol failed.".format(symbol_name))
+            print("symbol [{:s}] already exists, add_symbol failed.".format(
+                symbol_name))
             return
         self.symbol.append(symbol_name)
 
@@ -1960,45 +2205,56 @@ def add_field(self, field_name, data_api=None):
         """
         if data_api is None:
             if self.data_api is None:
-                print("Add field failed. No data_api available. Please specify one in parameter.")
+                print(
+                    "Add field failed. No data_api available. Please specify one in parameter."
+                )
                 return False
             else:
                 self.data_api = data_api
-        
+
         if field_name in self.fields:
             print("Field name [{:s}] already exists.".format(field_name))
             return False
-        
+
         if not self._is_predefined_field(field_name):
             print("Field name [{}] not valid, ignore.".format(field_name))
             return False
-        
+
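+        # NOTE: both the daily and quarterly panels are queried here; the branch
+        # below keeps the one that matches the field's frequency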
         merge_d, merge_q = self._prepare_daily_quarterly([field_name])
-        
+
         if self._is_daily_field(field_name):
             if self.data_d is None:
-                raise ValueError("Please prepare [{:s}] first.".format(field_name))
+                raise ValueError(
+                    "Please prepare [{:s}] first.".format(field_name))
             merge, _ = self._prepare_daily_quarterly([field_name])
             is_quarterly = False
         else:
             if self.data_q is None:
-                raise ValueError("Please prepare [{:s}] first.".format(field_name))
+                raise ValueError(
+                    "Please prepare [{:s}] first.".format(field_name))
             _, merge = self._prepare_daily_quarterly([field_name])
             is_quarterly = True
-        
+
         merge = merge.loc[:, pd.IndexSlice[:, field_name]]
         merge.columns = merge.columns.droplevel(level=1)
-        self.append_df(merge, field_name, is_quarterly=is_quarterly)  # whether contain only trade days is decided by existing data.
-        
+        self.append_df(
+            merge, field_name, is_quarterly=is_quarterly
+        )  # whether it contains only trade days is decided by existing data.
+
         if is_quarterly:
             df_ann = merge_q.loc[:, pd.IndexSlice[:, self.ANN_DATE_FIELD_NAME]]
             df_ann.columns = df_ann.columns.droplevel(level='field')
             df_expanded = align(merge, df_ann, self.dates)
             self.append_df(df_expanded, field_name, is_quarterly=False)
         return True
-    
-    def add_formula(self, field_name, formula, is_quarterly, overwrite=True,
-                    formula_func_name_style='camel', data_api=None,
+
+    def add_formula(self,
+                    field_name,
+                    formula,
+                    is_quarterly,
+                    overwrite=True,
+                    formula_func_name_style='camel',
+                    data_api=None,
                     within_index=True):
         """
         Add a new field, which is calculated using existing fields.
@@ -2026,23 +2282,25 @@ def add_formula(self, field_name, formula, is_quarterly, overwrite=True,
         """
         if data_api is not None:
             self.data_api = data_api
-        
+
         if field_name in self.fields:
             if overwrite:
                 self.remove_field(field_name)
                 print("Field [{:s}] is overwritten.".format(field_name))
             else:
-                print("Add formula failed: name [{:s}] exist. Try another name.".format(field_name))
+                print(
+                    "Add formula failed: name [{:s}] exists. Try another name.".
+                    format(field_name))
                 return
-        
+
         parser = Parser()
         parser.set_capital(formula_func_name_style)
-        
+
         expr = parser.parse(formula)
-        
+
         var_df_dic = dict()
         var_list = expr.variables()
-        
+
         # TODO: users do not need to prepare data before add_formula
         if not self.fields:
             self.fields.extend(var_list)
@@ -2055,31 +2313,43 @@ def add_formula(self, field_name, formula, is_quarterly, overwrite=True,
             success = self.add_field(var)
             if not success:
                 return
-        
+
         for var in var_list:
             if self._is_quarter_field(var):
-                df_var = self.get_ts_quarter(var, start_date=self.extended_start_date_q)
+                df_var = self.get_ts_quarter(
+                    var, start_date=self.extended_start_date_q)
             else:
                 # must use extended date. Default is start_date
-                df_var = self.get_ts(var, start_date=self.extended_start_date_d, end_date=self.end_date)
-            
+                df_var = self.get_ts(
+                    var,
+                    start_date=self.extended_start_date_d,
+                    end_date=self.end_date)
+
             var_df_dic[var] = df_var
-        
+
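+        # NOTE: quarterly variables stay on their report-date index at this point;
+        # when is_quarterly is True the result is re-aligned to trade dates below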
         # TODO: send ann_date into expr.evaluate. We assume that ann_date of all fields of a symbol is the same
         df_ann = self._get_ann_df()
         if within_index:
-            df_index_member = self.get_ts('index_member', start_date=self.extended_start_date_d, end_date=self.end_date)
-            df_eval = parser.evaluate(var_df_dic, ann_dts=df_ann, trade_dts=self.dates, index_member=df_index_member)
+            df_index_member = self.get_ts(
+                'index_member',
+                start_date=self.extended_start_date_d,
+                end_date=self.end_date)
+            df_eval = parser.evaluate(
+                var_df_dic,
+                ann_dts=df_ann,
+                trade_dts=self.dates,
+                index_member=df_index_member)
         else:
-            df_eval = parser.evaluate(var_df_dic, ann_dts=df_ann, trade_dts=self.dates)
-        
+            df_eval = parser.evaluate(
+                var_df_dic, ann_dts=df_ann, trade_dts=self.dates)
+
         self.append_df(df_eval, field_name, is_quarterly=is_quarterly)
-        
+
         if is_quarterly:
             df_ann = self._get_ann_df()
             df_expanded = align(df_eval, df_ann, self.dates)
             self.append_df(df_expanded, field_name, is_quarterly=False)
-    
+
     def append_df(self, df, field_name, is_quarterly=False):
         """
         Append DataFrame to existing multi-index DataFrame and add corresponding field name.
@@ -2103,30 +2373,34 @@ def append_df(self, df, field_name, is_quarterly=False):
         elif isinstance(df, pd.Series):
             df = pd.DataFrame(df)
         else:
-            raise ValueError("Data to be appended must be pandas format. But we have {}".format(type(df)))
-        
+            raise ValueError(
+                "Data to be appended must be pandas format. But we have {}".
+                format(type(df)))
+
         if is_quarterly:
             the_data = self.data_q
         else:
             the_data = self.data_d
-        
+
         exist_symbols = the_data.columns.levels[0]
         if len(df.columns) < len(exist_symbols):
-            df2 = pd.DataFrame(index=df.index, columns=exist_symbols, data=np.nan)
+            df2 = pd.DataFrame(
+                index=df.index, columns=exist_symbols, data=np.nan)
             df2.update(df)
             df = df2
         elif len(df.columns) > len(exist_symbols):
             df = df.loc[:, exist_symbols]
         multi_idx = pd.MultiIndex.from_product([exist_symbols, [field_name]])
         df.columns = multi_idx
-        
+
         #the_data = apply_in_subprocess(pd.merge, args=(the_data, df),
         #                               kwargs={'left_index': True, 'right_index': True, 'how': 'left'})  # runs in *only* one process
-        the_data = pd.merge(the_data, df, left_index=True, right_index=True, how='left')
+        the_data = pd.merge(
+            the_data, df, left_index=True, right_index=True, how='left')
         the_data = the_data.sort_index(axis=1)
         #merge = the_data.join(df, how='left')  # left: keep index of existing data unchanged
         #sort_columns(the_data)
-        
+
         if is_quarterly:
             self.data_q = the_data
         else:
@@ -2156,26 +2430,30 @@ def append_df_symbol(self, df, symbol_name):
         elif isinstance(df, pd.Series):
             df = pd.DataFrame(df)
         else:
-            raise ValueError("Data to be appended must be pandas format. But we have {}".format(type(df)))
-        
+            raise ValueError(
+                "Data to be appended must be pandas format. But we have {}".
+                format(type(df)))
+
         the_data = self.data_d
-        
+
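+        # NOTE: a symbol-wise append must cover every existing field; missing
+        # columns are padded with NaN below before the MultiIndex is rebuilt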
         exist_fields = the_data.columns.levels[1]
         if len(set(exist_fields) - set(df.columns)):
-            #if set(df.columns) < set(exist_fields):
-            df2 = pd.DataFrame(index=df.index, columns=exist_fields, data=np.nan)
+            #if set(df.columns) < set(exist_fields):
+            df2 = pd.DataFrame(
+                index=df.index, columns=exist_fields, data=np.nan)
             df2.update(df)
             df = df2
         multi_idx = pd.MultiIndex.from_product([[symbol_name], exist_fields])
         df.columns = multi_idx
-        
+
         #the_data = apply_in_subprocess(pd.merge, args=(the_data, df),
         #                               kwargs={'left_index': True, 'right_index': True, 'how': 'left'})  # runs in *only* one process
-        the_data = pd.merge(the_data, df, left_index=True, right_index=True, how='left')
+        the_data = pd.merge(
+            the_data, df, left_index=True, right_index=True, how='left')
         the_data = the_data.sort_index(axis=1)
         #merge = the_data.join(df, how='left')  # left: keep index of existing data unchanged
         #sort_columns(the_data)
-        
+
         self.data_d = the_data
         self._add_symbol(symbol_name)
 
@@ -2200,27 +2478,28 @@ def remove_field(self, field_names):
             pass
         else:
             raise ValueError("field_names must be str or list of str.")
-        
+
         for field_name in field_names:
             # parameter validation
             if field_name not in self.fields:
                 print("Field name [{:s}] does not exist.".format(field_name))
                 return
-            
+
             if self._is_daily_field(field_name):
                 is_quarterly = False
             elif self._is_quarter_field(field_name):
                 is_quarterly = True
             else:
-                print("Field name [{}] is a pre-defined field, ignore.".format(field_name))
+                print("Field name [{}] is a pre-defined field, ignore.".format(
+                    field_name))
                 return
-            
+
             # remove field data
-            
+
             self.data_d = self.data_d.drop(field_name, axis=1, level=1)
             if is_quarterly:
                 self.data_q = self.data_q.drop(field_name, axis=1, level=1)
-            
+
             # remove fields name from list
             self.fields.remove(field_name)
             if is_quarterly:
@@ -2229,10 +2508,15 @@ def remove_field(self, field_names):
             else:
                 if field_name in self.custom_daily_fields:
                     self.custom_daily_fields.remove(field_name)
-    
+
     # --------------------------------------------------------------------------------------------------------
     # Get Data API
-    def get(self, symbol="", start_date=0, end_date=0, fields="", data_format='wide'):
+    def get(self,
+            symbol="",
+            start_date=0,
+            end_date=0,
+            fields="",
+            data_format='wide'):
         """
         Basic API to get arbitrary data. If nothing fetched, return None.
@@ -2256,30 +2540,31 @@ def get(self, symbol="", start_date=0, end_date=0, fields="", data_format='wide'
         """
         sep = ','
-        
+
         if not fields:
             fields = slice(None)  # self.fields
         else:
             fields = fields.split(sep)
-        
+
         if not symbol:
             symbol = slice(None)  # this is 3X faster than symbol = self.symbol
         else:
             symbol = symbol.split(sep)
-        
+
         if not start_date:
             start_date = self.start_date
         if not end_date:
             end_date = self.end_date
-        
-        res = self.data_d.loc[pd.IndexSlice[start_date: end_date], pd.IndexSlice[symbol, fields]]
-        
+
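+        # NOTE: rows are sliced by the date range and columns by (symbol, field)
+        # pairs on the MultiIndex, both via pd.IndexSlice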
+        res = self.data_d.loc[pd.IndexSlice[start_date:end_date],
+                              pd.IndexSlice[symbol, fields]]
+
         if data_format == 'wide':
             pass
         else:
             res = res.stack(level='symbol').reset_index()
         return res
-    
+
     def get_snapshot(self, snapshot_date, symbol="", fields=""):
         """
         Get snapshot of given fields and symbol at snapshot_date.
@@ -2299,16 +2584,21 @@ def get_snapshot(self, snapshot_date, symbol="", fields=""):
             symbol as index, field as columns
 
         """
-        res = self.get(symbol=symbol, start_date=snapshot_date, end_date=snapshot_date, fields=fields)
+        res = self.get(
+            symbol=symbol,
+            start_date=snapshot_date,
+            end_date=snapshot_date,
+            fields=fields)
         if res is None:
-            print("No data. for date={}, fields={}, symbol={}".format(snapshot_date, fields, symbol))
+            print("No data for date={}, fields={}, symbol={}".format(
+                snapshot_date, fields, symbol))
             return
-        
+
         res = res.stack(level='symbol', dropna=False)
         res.index = res.index.droplevel(level=self.TRADE_DATE_FIELD_NAME)
-        
+
         return res
-    
+
     def _get_ann_df(self):
         """
         Query announcement date of financial statements of all securities.
@@ -2324,15 +2614,17 @@ def _get_ann_df(self):
             return None
         df_ann = self.data_q.loc[:, pd.IndexSlice[:, self.ANN_DATE_FIELD_NAME]]
         df_ann.columns = df_ann.columns.droplevel(level='field')
-        
+
         return df_ann
-    
+
     def get_symbol(self, symbol, start_date=0, end_date=0, fields=""):
-        res = self.get(symbol, start_date=start_date, end_date=end_date, fields=fields)
+        res = self.get(
+            symbol, start_date=start_date, end_date=end_date, fields=fields)
         if res is None:
-            raise ValueError("No data. for "
-                             "start_date={}, end_date={}, field={}, symbol={}".format(start_date, end_date,
-                                                                                      fields, symbol))
+            raise ValueError(
+                "No data for "
+                "start_date={}, end_date={}, field={}, symbol={}".format(
+                    start_date, end_date, fields, symbol))
         res.columns = res.columns.droplevel(level='symbol')
         return res
 
@@ -2358,15 +2650,17 @@ def get_ts(self, field, symbol="", start_date=0, end_date=0):
             Index is int date, column is symbol.
 
         """
-        res = self.get(symbol, start_date=start_date, end_date=end_date, fields=field)
+        res = self.get(
+            symbol, start_date=start_date, end_date=end_date, fields=field)
         if res is None:
-            raise ValueError("No data. for "
-                             "start_date={}, end_date={}, field={}, symbol={}".format(start_date, end_date,
-                                                                                      field, symbol))
-        
+            raise ValueError(
+                "No data for "
+                "start_date={}, end_date={}, field={}, symbol={}".format(
+                    start_date, end_date, field, symbol))
+
         res.columns = res.columns.droplevel(level='field')
         return res
-    
+
     # --------------------------------------------------------------------------------------------------------
     # DataView I/O
     @staticmethod
     def _load_h5(fp):
         """
@@ -2380,15 +2674,15 @@ def _load_h5(fp):
         """
         h5 = pd.HDFStore(fp)
-        
+
         res = dict()
         for key in h5.keys():
             res[key] = h5.get(key)
-        
+
         h5.close()
-        
+
         return res
-    
+
     def load_dataview(self, folder_path='.'):
         """
         Load data from local file.
@@ -2402,8 +2696,9 @@ def load_dataview(self, folder_path='.'):
         path_meta_data = os.path.join(folder_path, 'meta_data.json')
         path_data = os.path.join(folder_path, 'data.hd5')
         if not (os.path.exists(path_meta_data) and os.path.exists(path_data)):
-            raise IOError("There is no data file under directory {}".format(folder_path))
-        
+            raise IOError(
+                "There is no data file under directory {}".format(folder_path))
+
         meta_data = jutil.read_json(path_meta_data)
         dic = self._load_h5(path_data)
         self.data_d = dic.get('/data_d', None)
@@ -2412,9 +2707,9 @@ def load_dataview(self, folder_path='.'):
         self._data_inst = dic.get('/data_inst', None)
         self.data_custom = dic.get('/data_custom', None)
         self.__dict__.update(meta_data)
-        
+
         print("Dataview loaded successfully.")
-    
+
     def save_dataview(self, folder_path):
         """
         Save data and meta_data_to_store to a single hd5 file.
@@ -2429,21 +2724,31 @@ def save_dataview(self, folder_path):
         abs_folder = os.path.abspath(folder_path)
         meta_path = os.path.join(folder_path, 'meta_data.json')
         data_path = os.path.join(folder_path, 'data.hd5')
-        
-        data_to_store = {'data_d': self.data_d, 'data_q': self.data_q,
-                         'data_benchmark': self.data_benchmark, 'data_inst': self.data_inst,
-                         'data_custom': self.data_custom}
-        data_to_store = {k: v for k, v in data_to_store.items() if v is not None}
-        meta_data_to_store = {key: self.__dict__[key] for key in self.meta_data_list}
-        
+
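+        # NOTE: only non-None frames are written to HDF5; scalar settings listed
+        # in meta_data_list are saved separately as JSON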
+        data_to_store = {
+            'data_d': self.data_d,
+            'data_q': self.data_q,
+            'data_benchmark': self.data_benchmark,
+            'data_inst': self.data_inst,
+            'data_custom': self.data_custom
+        }
+        data_to_store = {
+            k: v
+            for k, v in data_to_store.items() if v is not None
+        }
+        meta_data_to_store = {
+            key: self.__dict__[key]
+            for key in self.meta_data_list
+        }
+
         print("\nStore data...")
         jutil.save_json(meta_data_to_store, meta_path)
         self._save_h5(data_path, data_to_store)
-        
-        print ("Dataview has been successfully saved to:\n"
-               + abs_folder + "\n\n"
-               + "You can load it with load_dataview('{:s}')".format(abs_folder))
-        
+
+        print(
+            "Dataview has been successfully saved to:\n" + abs_folder + "\n\n"
+            + "You can load it with load_dataview('{:s}')".format(abs_folder))
+
     @staticmethod
     def _save_h5(fp, dic):
         """
@@ -2457,11 +2762,11 @@ def _save_h5(fp, dic):
         """
         import warnings
-        warnings.filterwarnings('ignore', category=pd.io.pytables.PerformanceWarning)
-        
+        warnings.filterwarnings(
+            'ignore', category=pd.io.pytables.PerformanceWarning)
+
         jutil.create_dir(fp)
         h5 = pd.HDFStore(fp, complevel=9, complib='blosc')
         for key, value in dic.items():
             h5[key] = value
         h5.close()
-
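
---- Editor's note (appended; not part of the patch) --------------------------
The _prepare_adj_factor() hunk above is the fix named in the subject line:
previously the method returned early when data_inst contained no stocks
(inst_type != 1), so the 'adjust_factor' field was never appended and any
formula referencing it failed for non-stock universes. Below is a minimal
sketch of the fallback it introduces, assuming `api` is a configured
jaqs.data.DataApi and `inst` is the DataView.data_inst frame (both names are
stand-ins, not identifiers from the patch):

    import pandas as pd

    def constant_adj_factor(api, inst, start_date, end_date):
        # non-stock assets have no price-adjustment events, so every symbol
        # gets a flat adjust factor of 1 on every trade date
        dt_idx = api.query_trade_dates(start_date=start_date,
                                       end_date=end_date)
        return pd.DataFrame(index=dt_idx, columns=inst.index).fillna(1)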