This repository contains the implementation of in-training compression of state-space models (SSMs) by Makram Chahine.
This repository is an extension of https://github.com/tk-rusch/linoss which is itself an extension of https://github.com/Benjamin-Walker/log-neural-cdes.
This repository is implemented in Python 3.10 and uses JAX as its machine learning framework.
The code for preprocessing the datasets and training LinOSS, S5, LRU, NCDE, NRDE, and Log-NCDE uses the following packages:
- `jax` and `jaxlib` for automatic differentiation.
- `equinox` for constructing neural networks.
- `optax` for neural network optimisers.
- `diffrax` for differential equation solvers.
- `signax` for calculating the signature.
- `sktime` for handling time series data in ARFF format.
- `tqdm` for progress bars.
- `matplotlib` for plotting.
- `pre-commit` for code formatting.
```bash
conda create -n compressm python=3.10
conda activate compressm
conda install pre-commit=3.7.1 sktime=0.30.1 tqdm=4.66.4 matplotlib=3.8.4 -c conda-forge
# Substitute the correct JAX pip install for your hardware: https://jax.readthedocs.io/en/latest/installation.html
pip install -U "jax[cuda12]" "jaxlib[cuda12]" equinox==0.11.4 optax==0.2.2 diffrax==0.5.1 signax==0.1.1
```
After installing the requirements, run `pre-commit install` to install the pre-commit hooks.
The folder `data_dir` contains the scripts for downloading the data, preprocessing it, and creating the dataloaders and datasets. Raw data should be downloaded into the `data_dir/raw` folder, and processed data should be saved into the `data_dir/processed` folder in the following format:
```
processed/{collection}/{dataset_name}/data.pkl
processed/{collection}/{dataset_name}/labels.pkl
processed/{collection}/{dataset_name}/original_idxs.pkl  (only if the dataset has original data splits)
```
where `data.pkl` and `labels.pkl` are `jnp.array`s with shapes (n_samples, n_timesteps, n_features) and (n_samples, n_classes) respectively. If the dataset has original data splits, the split indices should be saved in `original_idxs.pkl` as a list of `jnp.array`s with shapes [(n_train,), (n_val,), (n_test,)].
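For example, here is a minimal sketch of writing a toy dataset in this layout, assuming standard pickle serialisation; the collection and dataset names and the array contents are placeholders:

```python
import os
import pickle

import jax.numpy as jnp

# Placeholder arrays: 100 samples, 50 timesteps, 3 features, 2 classes.
data = jnp.zeros((100, 50, 3))
labels = jnp.zeros((100, 2))

# Write the pickles into the expected processed/{collection}/{dataset_name} layout.
out_dir = "data_dir/processed/toy_collection/toy_dataset"
os.makedirs(out_dir, exist_ok=True)
with open(os.path.join(out_dir, "data.pkl"), "wb") as f:
    pickle.dump(data, f)
with open(os.path.join(out_dir, "labels.pkl"), "wb") as f:
    pickle.dump(labels, f)
```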
The code for training and evaluating the models is contained in `train.py`, and experiments can be run using the `run_experiment.py` script. This script requires you to specify the names of the models you want to train, the names of the datasets you want to train on, and a directory containing the configuration files. By default, it will run the LinOSS experiments. The configuration files should be organised as `config_dir/{model_name}/{dataset_name}.json` and contain the following fields:
- `seeds`: A list of seeds to use for training.
- `data_dir`: The directory containing the data.
- `output_parent_dir`: The directory to save the output.
- `lr_scheduler`: A function which takes the learning rate and returns the new learning rate.
- `num_steps`: The number of steps to train for.
- `print_steps`: The number of steps between printing the loss.
- `batch_size`: The batch size.
- `metric`: The metric to use for evaluation.
- `classification`: Whether the task is a classification task.
- `linoss_discretization`: ONLY for LinOSS; which discretization to use. Choices are `['IM', 'IMEX']`.
- `lr`: The initial learning rate.
- `time`: Whether to include time as a channel.
- Any further model-specific parameters.
See `experiment_configs/repeats` for examples.
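As a rough illustration only, a LinOSS classification config might look like the following sketch, where all values are placeholders rather than tuned settings, and `lr_scheduler` and any model-specific parameters are omitted:

```json
{
  "seeds": [1234, 2345],
  "data_dir": "data_dir",
  "output_parent_dir": "outputs",
  "num_steps": 100000,
  "print_steps": 1000,
  "batch_size": 32,
  "metric": "accuracy",
  "classification": true,
  "linoss_discretization": "IM",
  "lr": 0.001,
  "time": true
}
```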
The configuration files for all the experiments with fixed hyperparameters can be found in the `experiment_configs` folder, and `run_experiment.py` is currently configured to run the repeat experiments on the UEA datasets.
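Once the models, datasets, and config directory are set in the script, the experiments can be launched with, for example:

```bash
python run_experiment.py
```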
The `outputs` folder contains a zip file of the output files from the UEA and PPG experiments.