
Compre-SSM

This repository contains Makram Chahine's implementation of in-training compression of state space models (SSMs).

This repository is an extension of https://github.com/tk-rusch/linoss which is itself an extension of https://github.com/Benjamin-Walker/log-neural-cdes.


Requirements

This repository is implemented in Python 3.10 and uses JAX as its machine learning framework.

Environment

The code for preprocessing the datasets and training LinOSS, S5, LRU, NCDE, NRDE, and Log-NCDE uses the following packages:

  • jax and jaxlib for automatic differentiation.
  • equinox for constructing neural networks.
  • optax for neural network optimisers.
  • diffrax for differential equation solvers.
  • signax for computing path signatures.
  • sktime for handling time series data in ARFF format.
  • tqdm for progress bars.
  • matplotlib for plotting.
  • pre-commit for code formatting.
conda create -n compressm python=3.10
conda activate compressm
conda install pre-commit=3.7.1 sktime=0.30.1 tqdm=4.66.4 matplotlib=3.8.4 -c conda-forge
# Substitute the correct JAX pip install for your platform: https://jax.readthedocs.io/en/latest/installation.html
pip install -U "jax[cuda12]" "jaxlib[cuda12]" equinox==0.11.4 optax==0.2.2 diffrax==0.5.1 signax==0.1.1
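
To verify that the JAX installation can see your accelerator, a quick optional sanity check (not part of the repository's instructions):

python -c "import jax; print(jax.__version__, jax.devices())"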

After installing the requirements, run pre-commit install to install the pre-commit hooks.


Data

The folder data_dir contains the scripts for downloading the raw data, preprocessing it, and creating datasets and dataloaders. Raw data should be downloaded into the data_dir/raw folder. Processed data should be saved into the data_dir/processed folder in the following format:

processed/{collection}/{dataset_name}/data.pkl, 
processed/{collection}/{dataset_name}/labels.pkl,
processed/{collection}/{dataset_name}/original_idxs.pkl (if the dataset has original data splits)

where data.pkl and labels.pkl are jnp.arrays with shapes (n_samples, n_timesteps, n_features) and (n_samples, n_classes) respectively. If the dataset has original data splits, the split indices should be saved as a list of jnp.arrays with shapes [(n_train,), (n_val,), (n_test,)].
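As a minimal sketch of this layout (the collection name, dataset name, shapes, and values below are placeholders, not taken from the repository):

import os
import pickle

import jax.numpy as jnp

# Hypothetical dataset: 100 samples, 50 timesteps, 3 channels, 2 classes.
data = jnp.zeros((100, 50, 3))
labels = jnp.zeros((100, 2))

out_dir = "data_dir/processed/toy_collection/toy_dataset"
os.makedirs(out_dir, exist_ok=True)

with open(os.path.join(out_dir, "data.pkl"), "wb") as f:
    pickle.dump(data, f)
with open(os.path.join(out_dir, "labels.pkl"), "wb") as f:
    pickle.dump(labels, f)

# Only if the dataset ships with fixed train/val/test splits:
original_idxs = [jnp.arange(70), jnp.arange(70, 85), jnp.arange(85, 100)]
with open(os.path.join(out_dir, "original_idxs.pkl"), "wb") as f:
    pickle.dump(original_idxs, f)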

Experiments

The code for training and evaluating the models is contained in train.py. Experiments can be run using the run_experiment.py script, which requires you to specify the names of the models to train, the datasets to train on, and a directory containing the configuration files. By default, it runs the LinOSS experiments. The configuration files should be organised as config_dir/{model_name}/{dataset_name}.json and contain the following fields:

  • seeds: A list of seeds to use for training.
  • data_dir: The directory containing the data.
  • output_parent_dir: The directory to save the output.
  • lr_scheduler: A function which takes the learning rate and returns the new learning rate.
  • num_steps: The number of steps to train for.
  • print_steps: The number of steps between printing the loss.
  • batch_size: The batch size.
  • metric: The metric to use for evaluation.
  • classification: Whether the task is a classification task.
  • linoss_discretization: Only for LinOSS -- which discretization to use. Choices are ['IM', 'IMEX'].
  • lr: The initial learning rate.
  • time: Whether to include time as a channel.
  • Any further specific model parameters.

See experiment_configs/repeats for examples.
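
For illustration, a hypothetical configuration file for LinOSS. Every value below is a placeholder rather than a tuned hyperparameter, the model-specific field hidden_dim is invented for this sketch, and the string encoding of lr_scheduler is a guess (the README describes it as a function; see the files in experiment_configs/repeats for the actual representation):

{
  "seeds": [1234, 5678],
  "data_dir": "data_dir/processed",
  "output_parent_dir": "outputs",
  "lr_scheduler": "constant",
  "num_steps": 10000,
  "print_steps": 100,
  "batch_size": 32,
  "metric": "accuracy",
  "classification": true,
  "linoss_discretization": "IM",
  "lr": 0.001,
  "time": true,
  "hidden_dim": 64
}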


Reproducing the Results

The configuration files for all the experiments with fixed hyperparameters can be found in the experiment_configs folder, and run_experiment.py is currently configured to run the repeat experiments on the UEA datasets. The outputs folder contains a zip file of the output files from the UEA and PPG experiments.

