diff --git a/README.md b/README.md
index bb3af606..aadef6c1 100644
--- a/README.md
+++ b/README.md
@@ -194,7 +194,7 @@ Please cite with:
   title = {sequifier - causal transformer models for multivariate sequence modelling},
   year = {2025},
   publisher = {GitHub},
-  version = {v1.0.0.6},
+  version = {v1.1.0.0},
   url = {https://github.com/0xideas/sequifier}
 }
 ```
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 5c60768c..76e0441d 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -15,7 +15,7 @@
 project = 'sequifier'
 copyright = '2025, Leon Luithlen'
 author = 'Leon Luithlen'
-release = 'v1.0.0.6'
+release = 'v1.1.0.0'
 html_baseurl = 'https://www.sequifier.com/'
 
 # -- General configuration ---------------------------------------------------
diff --git a/documentation/configs/preprocess.md b/documentation/configs/preprocess.md
index 29c7e7c7..eb0e0b9b 100644
--- a/documentation/configs/preprocess.md
+++ b/documentation/configs/preprocess.md
@@ -86,3 +86,27 @@ After running `preprocess`, the following are generated:
 2. **Metadata Config:** Located in `configs/metadata_configs/[NAME].json`.
    * **Crucial:** This file contains the integer mappings for categorical variables (`id_maps`) and normalization stats for real variables (`selected_columns_statistics`).
    * **Next Step:** You must link this file path in your `train.yaml` and `infer.yaml` under `metadata_config_path`.
+
+
+## 5\. Advanced: Custom ID Mapping
+
+By default, Sequifier automatically generates integer IDs for categorical columns starting from index 2 (indices 0 and 1 are reserved: 0 for "unknown" values and 1 for "other", i.e. out-of-vocabulary values).
+
+If you need to enforce specific integer mappings (e.g., to maintain consistency across different training runs or datasets), you can provide **precomputed ID maps**.
+
+1. Create a folder named `id_maps` inside your configs directory: `configs/id_maps/`.
+2. Create a JSON file named exactly after the column you want to map (e.g., `my_column_name.json`).
+3. The JSON file must contain a key-value dictionary where keys are the raw values and values are the integer IDs.
+
+**Constraints:**
+* Integer IDs must start at exactly **2**: the smallest value in each map must be 2, otherwise preprocessing raises an error.
+* IDs **0** and **1** are reserved.
+
+**Example `configs/id_maps/category_col.json`:**
+```json
+{
+  "cat": 2,
+  "dog": 3,
+  "mouse": 4
+}
+```
diff --git a/pyproject.toml b/pyproject.toml
index 53735579..1ca0e5a9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "sequifier"
-version = "v1.0.0.6"
+version = "v1.1.0.0"
 authors = [
   { name = "Leon Luithlen", email = "leontimnaluithlen@gmail.com" },
 ]
diff --git a/src/sequifier/config/train_config.py b/src/sequifier/config/train_config.py
index 55618467..00c42bef 100644
--- a/src/sequifier/config/train_config.py
+++ b/src/sequifier/config/train_config.py
@@ -100,6 +100,12 @@ class DotDict(dict):
     def __deepcopy__(self, memo=None):
         return DotDict(copy.deepcopy(dict(self), memo=memo))
 
+    def __getstate__(self):
+        return dict(self)
+
+    def __setstate__(self, state):
+        self.update(state)
+
 
 class TrainingSpecModel(BaseModel):
     """Pydantic model for training specifications.
diff --git a/src/sequifier/helpers.py b/src/sequifier/helpers.py
index e374ba26..7a51d48d 100644
--- a/src/sequifier/helpers.py
+++ b/src/sequifier/helpers.py
@@ -73,7 +73,9 @@ def construct_index_maps(
 
-    A special mapping for index 0 is added:
+    Special mappings for indices 0 and 1 are added:
     - If original IDs are strings, 0 maps to "unknown".
-    - If original IDs are integers, 0 maps to (minimum original ID) - 1.
+ - If original IDs are strings, 1 maps to "other". + - If original IDs are integers, 0 maps to (minimum original ID) - 2. + - If original IDs are integers, 1 maps to (minimum original ID) - 1. Args: id_maps: A nested dictionary mapping column names to their @@ -105,10 +107,13 @@ def construct_index_maps( val = next(iter(map_.values())) if isinstance(val, str): map_[0] = "unknown" + map_[1] = "other" else: if not isinstance(val, int): raise TypeError(f"Expected integer ID in map, got {type(val)}") - map_[0] = min(map_.values()) - 1 # type: ignore + min_id = int(min(map_.values())) + map_[0] = min_id - 2 # type: ignore + map_[1] = min_id - 1 index_map[target_column] = map_ return index_map diff --git a/src/sequifier/preprocess.py b/src/sequifier/preprocess.py index da2ed70e..426856ab 100644 --- a/src/sequifier/preprocess.py +++ b/src/sequifier/preprocess.py @@ -130,10 +130,22 @@ def __init__( col for col in data.columns if col not in ["sequenceId", "itemPosition"] ] id_maps, selected_columns_statistics = {}, {} + + precomputed_id_maps = load_precomputed_id_maps( + self.project_root, data_columns + ) + id_maps, selected_columns_statistics = _get_column_statistics( - data, data_columns, id_maps, selected_columns_statistics, 0 + data, + data_columns, + id_maps, + selected_columns_statistics, + 0, + precomputed_id_maps, ) + id_maps = id_maps | precomputed_id_maps + data, n_classes, col_types = _apply_column_statistics( data, data_columns, id_maps, selected_columns_statistics ) @@ -319,9 +331,14 @@ def _get_column_metadata_across_files( - data_columns (list[str]): List of all processed data column names. """ + n_rows_running_count = 0 id_maps, selected_columns_statistics = {}, {} + col_types, data_columns = None, None + + precomputed_id_maps = load_precomputed_id_maps(self.project_root, data_columns) + files_to_process = [] logger.info(f"Data path: {data_path}") for root, dirs, files in os.walk(data_path): @@ -354,6 +371,12 @@ def _get_column_metadata_across_files( if col_types is None: data_columns = current_file_cols col_types = {col: str(data.schema[col]) for col in data_columns} + + for col in precomputed_id_maps.keys(): + if col not in data_columns: + raise ValueError( + f"Precomputed column {col} not found in {file}" + ) else: if set(current_file_cols) != set(col_types.keys()): missing = set(col_types.keys()) - set(current_file_cols) @@ -382,12 +405,15 @@ def _get_column_metadata_across_files( id_maps, selected_columns_statistics, n_rows_running_count, + precomputed_id_maps, ) n_rows_running_count += data.shape[0] + id_maps = id_maps | precomputed_id_maps + if data_columns is None: raise RuntimeError("data_columns was not initialized correctly.") - n_classes = {col: len(id_maps[col]) + 1 for col in id_maps} + n_classes = {col: max(id_maps[col].values()) + 1 for col in id_maps} if col_types is None: raise RuntimeError("col_types was not initialized correctly.") @@ -785,14 +811,14 @@ def _apply_column_statistics( - `col_types`: The (potentially computed) column type dictionary. 
""" if n_classes is None: - n_classes = {col: len(id_maps[col]) + 1 for col in id_maps} + n_classes = {col: max(id_maps[col].values()) + 1 for col in id_maps} if col_types is None: col_types = {col: str(data.schema[col]) for col in data_columns} for col in data_columns: if col in id_maps: - data = data.with_columns(pl.col(col).replace(id_maps[col])) + data = data.with_columns(pl.col(col).replace(id_maps[col], default=1)) col_types[col] = "Int64" elif col in selected_columns_statistics: data = data.with_columns( @@ -805,6 +831,47 @@ def _apply_column_statistics( return (data, n_classes, col_types) +@beartype +def load_precomputed_id_maps( + project_root: str, data_columns: Optional[list[str]] +) -> dict[str, dict[Union[str, int], int]]: + """Loads custom ID maps from configs/id_maps if the folder exists. + + Args: + project_root: The path to the project root directory. + data_columns: Optional list of columns present in the data to validate + against the found map files. + + Returns: + A dictionary mapping column names to their ID maps. + """ + custom_maps = {} + path = os.path.join(project_root, "configs", "id_maps") + + if os.path.exists(path): + for file in os.listdir(path): + if file.endswith(".json"): + col_name = os.path.splitext(file)[0] + if data_columns is not None and col_name not in data_columns: + raise ValueError( + f"{file} does not correspond to any column in the data" + ) + + with open(os.path.join(path, file), "r") as f: + # Load and ensure values are integers + m = {k: int(v) for k, v in json.load(f).items()} + + if not len(m) > 0: + raise ValueError(f"map in {file} does not contain any values") + min_val = min(m.values()) + if min_val != 2: + raise ValueError( + f"minimum value in map {file} is {min_val}, must be 2." + ) + custom_maps[col_name] = m + return custom_maps + + @beartype def _get_column_statistics( data: pl.DataFrame, @@ -812,6 +879,7 @@ def _get_column_statistics( id_maps: dict[str, dict[Union[str, int], int]], selected_columns_statistics: dict[str, dict[str, float]], n_rows_running_count: int, + precomputed_id_maps: dict[str, dict[Union[str, int], int]], ) -> tuple[ dict[str, dict[Union[str, int], int]], dict[str, dict[str, float]], @@ -837,6 +905,8 @@ def _get_column_statistics( statistics to be updated. n_rows_running_count: The total number of rows processed *before* this chunk, used for weighting statistics. + precomputed_id_maps: A dictionary of pre-loaded ID maps that should + be applied and not re-computed. Returns: A tuple `(id_maps, selected_columns_statistics)` containing the @@ -863,9 +933,17 @@ def _get_column_statistics( pl.UInt64, ), ): - new_id_map = create_id_map(data, column=data_col) - id_maps[data_col] = combine_maps(new_id_map, id_maps.get(data_col, {})) + if data_col not in precomputed_id_maps: + new_id_map = create_id_map(data, column=data_col) + id_maps[data_col] = combine_maps(new_id_map, id_maps.get(data_col, {})) + else: + logger.info(f"Applying precomputed map for {data_col}") elif isinstance(dtype, (pl.Float32, pl.Float64)): + if data_col in precomputed_id_maps: + raise ValueError( + f"Column {data_col} is not categorical, precomputed map is invalid." 
+ ) + combined_mean, combined_std = get_combined_statistics( data.shape[0], data.get_column(data_col).mean(), @@ -1262,7 +1340,7 @@ def create_id_map(data: pl.DataFrame, column: str) -> dict[Union[str, int], int] ids = sorted( [int(x) if not isinstance(x, str) else x for x in np.unique(data[column])] ) # type: ignore - id_map = {id_: i + 1 for i, id_ in enumerate(ids)} + id_map = {id_: i + 2 for i, id_ in enumerate(ids)} return dict(id_map) @@ -1330,7 +1408,7 @@ def combine_maps( A new, combined, and re-indexed ID map. """ combined_keys = sorted(list(set(list(map1.keys())).union(list(set(map2.keys()))))) - id_map = {id_: i + 1 for i, id_ in enumerate(combined_keys)} + id_map = {id_: i + 2 for i, id_ in enumerate(combined_keys)} return id_map diff --git a/tests/unit/test_preprocess.py b/tests/unit/test_preprocess.py index d14bbaad..98191bad 100644 --- a/tests/unit/test_preprocess.py +++ b/tests/unit/test_preprocess.py @@ -170,13 +170,13 @@ def test_get_column_statistics_state_accumulation(): # Pass 1 id_maps, stats = _get_column_statistics( - chunk1, ["cat_col", "num_col"], id_maps, stats, running_count + chunk1, ["cat_col", "num_col"], id_maps, stats, running_count, {} ) running_count += len(chunk1) # Pass 2 id_maps, stats = _get_column_statistics( - chunk2, ["cat_col", "num_col"], id_maps, stats, running_count + chunk2, ["cat_col", "num_col"], id_maps, stats, running_count, {} ) # Validations @@ -197,7 +197,7 @@ def test_create_id_map(): df = pl.DataFrame({"A": ["z", "x", "y", "x"]}) mapping = create_id_map(df, "A") - # Sorted unique values: x, y, z -> 1, 2, 3 - assert mapping["x"] == 1 - assert mapping["y"] == 2 - assert mapping["z"] == 3 + # Sorted unique values: x, y, z -> 2, 3, 4 + assert mapping["x"] == 2 + assert mapping["y"] == 3 + assert mapping["z"] == 4 diff --git a/tools/resize_pt_files.py b/tools/resize_pt_files.py new file mode 100644 index 00000000..1bfe7312 --- /dev/null +++ b/tools/resize_pt_files.py @@ -0,0 +1,284 @@ +import argparse +import json +import os +import sys +from typing import Any, Dict, Tuple + +import torch + + +def unpack_dataset_tuple(data_tuple: Tuple) -> Dict[str, Any]: + """ + Unpacks the standard Sequifier 5-element tuple into a structured dictionary + that is easier to handle programmatically. + """ + return { + "sequences": data_tuple[0], + "targets": data_tuple[1], + "seq_ids": data_tuple[2], + "sub_ids": data_tuple[3], + "start_pos": data_tuple[4], + } + + +def pack_dataset_tuple(data_dict: Dict[str, Any]) -> Tuple: + """ + Repacks the dictionary back into the 5-element tuple format expected by Sequifier. + """ + return ( + data_dict["sequences"], + data_dict["targets"], + data_dict["seq_ids"], + data_dict["sub_ids"], + data_dict["start_pos"], + ) + + +def concat_datasets(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]: + """ + Concatenates two dataset dictionaries along the batch dimension (dim 0). 
+ """ + combined = {} + + # Concatenate sequence dicts + combined["sequences"] = { + k: torch.cat([left["sequences"][k], right["sequences"][k]], dim=0) + for k in left["sequences"] + } + + # Concatenate target dicts + combined["targets"] = { + k: torch.cat([left["targets"][k], right["targets"][k]], dim=0) + for k in left["targets"] + } + + # Concatenate metadata tensors + combined["seq_ids"] = torch.cat([left["seq_ids"], right["seq_ids"]], dim=0) + combined["sub_ids"] = torch.cat([left["sub_ids"], right["sub_ids"]], dim=0) + combined["start_pos"] = torch.cat([left["start_pos"], right["start_pos"]], dim=0) + + return combined + + +def slice_dataset(data: Dict[str, Any], start: int, end: int) -> Dict[str, Any]: + """ + Slices a dataset dictionary from start to end index. + Returns a view (fast), not a copy. + """ + sliced = {} + sliced["sequences"] = {k: v[start:end] for k, v in data["sequences"].items()} + sliced["targets"] = {k: v[start:end] for k, v in data["targets"].items()} + sliced["seq_ids"] = data["seq_ids"][start:end] + sliced["sub_ids"] = data["sub_ids"][start:end] + sliced["start_pos"] = data["start_pos"][start:end] + return sliced + + +def clone_dataset(data: Dict[str, Any]) -> Dict[str, Any]: + """ + Deep clones a dataset. Used for the remainder to allow the + large original tensor to be freed from memory. + """ + cloned = {} + cloned["sequences"] = {k: v.clone() for k, v in data["sequences"].items()} + cloned["targets"] = {k: v.clone() for k, v in data["targets"].items()} + cloned["seq_ids"] = data["seq_ids"].clone() + cloned["sub_ids"] = data["sub_ids"].clone() + cloned["start_pos"] = data["start_pos"].clone() + return cloned + + +def get_row_size_bytes(data: Dict[str, Any]) -> float: + """ + Calculates the exact size in bytes of a single row in the dataset. + """ + total_bytes = 0 + # Add size of one row for every tensor in sequences + for t in data["sequences"].values(): + total_bytes += t.element_size() * t.shape[1] + # Add size of one row for every tensor in targets + for t in data["targets"].values(): + total_bytes += t.element_size() * t.shape[1] + + # Add metadata sizes (1 element each) + total_bytes += data["seq_ids"].element_size() + total_bytes += data["sub_ids"].element_size() + total_bytes += data["start_pos"].element_size() + + return total_bytes + + +def process_split( + input_dir: str, + output_dir: str, + dataset_name: str, + target_size_mb: float, + split_suffix: str, +): + os.makedirs(output_dir, exist_ok=True) + + # 1. Determine File Order via metadata + meta_path = os.path.join(input_dir, "metadata.json") + if os.path.exists(meta_path): + with open(meta_path, "r") as f: + old_metadata = json.load(f) + input_files = [entry["path"] for entry in old_metadata.get("batch_files", [])] + expected_total_samples = old_metadata.get("total_samples", 0) + else: + input_files = sorted([f for f in os.listdir(input_dir) if f.endswith(".pt")]) + expected_total_samples = None + print( + f"Warning: No metadata.json found in {input_dir}. Using alphabetical sort." + ) + + # State variables + remainder_data = None + output_batch_idx = 0 + new_batch_files_metadata = [] + total_samples_processed = 0 + + target_bytes = target_size_mb * 1024 * 1024 + + print(f"Processing {input_dir} -> {output_dir}") + + # 2. 
Iterate Files (Vectorized) + for file_name in input_files: + file_path = os.path.join(input_dir, file_name) + if not os.path.exists(file_path): + continue + + try: + # Load full file into RAM + loaded_tuple = torch.load(file_path, map_location="cpu", weights_only=False) + current_data = unpack_dataset_tuple(loaded_tuple) + + # Concatenate with remainder from previous file (if any) + if remainder_data is not None: # pyright: ignore + full_data = concat_datasets(remainder_data, current_data) # pyright: ignore + # Free memory + del remainder_data # pyright: ignore + del current_data + else: + full_data = current_data + + num_rows = full_data["seq_ids"].shape[0] + if num_rows == 0: + continue + + # Calculate slice size (Rows per output file) + bytes_per_row = get_row_size_bytes(full_data) + target_rows = max(1, int(target_bytes / bytes_per_row)) + + # Slice and Save Loop + start_idx = 0 + while start_idx + target_rows <= num_rows: + end_idx = start_idx + target_rows + + # Create slice view + chunk_data = slice_dataset(full_data, start_idx, end_idx) + + # Save + fname = f"{dataset_name}-{split_suffix}-{output_batch_idx}.pt" + out_path = os.path.join(output_dir, fname) + torch.save(pack_dataset_tuple(chunk_data), out_path) + + # Update Metadata + chunk_len = end_idx - start_idx + new_batch_files_metadata.append({"path": fname, "samples": chunk_len}) + total_samples_processed += chunk_len + output_batch_idx += 1 + + start_idx = end_idx + + # Handle Remainder + if start_idx < num_rows: + # We have leftovers. We must CLONE them so we can drop the reference + # to the massive `full_data` tensor, freeing RAM for the next file load. + remainder_data = clone_dataset( + slice_dataset(full_data, start_idx, num_rows) + ) + else: + remainder_data = None + + # Explicitly free full_data to be safe + del full_data + + except Exception as e: + print(f"Error processing {file_path}: {e}") + sys.exit(1) + + # 3. Flush final remainder + if remainder_data is not None: # pyright: ignore + fname = f"{dataset_name}-{split_suffix}-{output_batch_idx}.pt" + out_path = os.path.join(output_dir, fname) + torch.save(pack_dataset_tuple(remainder_data), out_path) # pyright: ignore + + chunk_len = remainder_data["seq_ids"].shape[0] # pyright: ignore + new_batch_files_metadata.append({"path": fname, "samples": chunk_len}) + total_samples_processed += chunk_len + + # 4. Write New Metadata + new_metadata = { + "total_samples": total_samples_processed, + "batch_files": new_batch_files_metadata, + } + with open(os.path.join(output_dir, "metadata.json"), "w") as f: + json.dump(new_metadata, f, indent=4) + + # 5. Validation + if ( + expected_total_samples is not None + and total_samples_processed != expected_total_samples + ): + print( + f"WARNING: Sample count mismatch! Input: {expected_total_samples}, Output: {total_samples_processed}" + ) + else: + print( + f"Success: {split_suffix} processed. 
Total samples: {total_samples_processed}" + ) + + +def main(): + parser = argparse.ArgumentParser( + description="Fast Rechunker for Sequifier Datasets" + ) + parser.add_argument("data_folder", type=str, help="Path containing split folders") + parser.add_argument("dataset_name", type=str, help="Root name of dataset") + parser.add_argument("target_size_mb", type=float, help="Target file size in MB") + + args = parser.parse_args() + + if not os.path.exists(args.data_folder): + print("Data folder not found.") + sys.exit(1) + + contents = os.listdir(args.data_folder) + split_folders = [ + f + for f in contents + if f.startswith(f"{args.dataset_name}-split") + and os.path.isdir(os.path.join(args.data_folder, f)) + ] + + if not split_folders: + print("No matching split folders found.") + sys.exit(1) + + for folder in split_folders: + # Extract "split0" from "mydata-split0" + suffix = folder.split("-")[-1] + + input_path = os.path.join(args.data_folder, folder) + output_folder_name = ( + f"{args.dataset_name}-{int(args.target_size_mb)}MB-{suffix}" + ) + output_path = os.path.join(args.data_folder, output_folder_name) + + process_split( + input_path, output_path, args.dataset_name, args.target_size_mb, suffix + ) + + +if __name__ == "__main__": + main()
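
To make the new `configs/id_maps` mechanism above concrete, here is a minimal sketch of registering a precomputed ID map before running `sequifier preprocess`. The column name `category_col` and its values are hypothetical; the fixed parts are the folder location, the `<column name>.json` naming convention, and the rule enforced by `load_precomputed_id_maps` that the smallest ID is exactly 2 (indices 0 and 1 are reserved for "unknown" and "other").

import json
import os

# Hypothetical column and values; only the folder, the file naming convention,
# and the "smallest ID == 2" rule are fixed by the preprocessor.
os.makedirs("configs/id_maps", exist_ok=True)
with open("configs/id_maps/category_col.json", "w") as f:
    json.dump({"cat": 2, "dog": 3, "mouse": 4}, f, indent=2)

# During preprocessing, raw values that are missing from this map are encoded
# with the reserved "other" index 1, via pl.col(...).replace(id_map, default=1).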
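
The new `tools/resize_pt_files.py` rechunker can also be driven from Python instead of via `python tools/resize_pt_files.py <data_folder> <dataset_name> <target_size_mb>`. The sketch below uses assumed paths and names (`data/`, `myproject`, 256 MB) and presumes the repository root is on `PYTHONPATH`; it only illustrates the `process_split` signature, not a prescribed workflow.

from tools.resize_pt_files import process_split

# Rechunk one preprocessed split into ~256 MB .pt files; the output folder name
# mirrors what main() would generate for these arguments.
process_split(
    input_dir="data/myproject-split0",
    output_dir="data/myproject-256MB-split0",
    dataset_name="myproject",
    target_size_mb=256.0,
    split_suffix="split0",
)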