This package contains code related to the Active Skill-level Data Aggregation (ASkDAgger) paper.
In particular, the code in this package can be used to recreate the results related to the S-Aware Gating (SAG) algorithm on the MNIST dataset.
The SAG algorithm allows one, in an interactive learning setting, to dynamically adjust the gating threshold so that a desired sensitivity level is maintained.
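The paper defines the exact SAG update rule; purely as an illustration of the idea (a hypothetical helper, not the repository's code), a sensitivity-aware threshold could be chosen as a quantile of the uncertainties observed on past novice failures:

```python
import numpy as np

def sag_threshold(uncertainties, failures, s_des=0.9):
    """Illustrative (not the paper's) sensitivity-aware gating threshold.

    Chooses a threshold such that a fraction `s_des` of previously observed
    novice failures would have been flagged, i.e. would have had an
    uncertainty above the threshold (an empirical sensitivity of `s_des`).
    """
    fail_unc = np.asarray(uncertainties)[np.asarray(failures, dtype=bool)]
    if fail_unc.size == 0:
        return float("-inf")  # no failures observed yet: query everything
    # the (1 - s_des) quantile leaves ~s_des of failure uncertainties above it
    return float(np.quantile(fail_unc, 1.0 - s_des))
```

With `s_des = 0.9`, roughly 90% of past failures would have exceeded this threshold and would therefore have triggered a query.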
It is advised to use uv to install the dependencies of the askdagger_mnist package.
Please make sure uv is installed according to its installation instructions.
First clone and enter the askdagger_mnist folder:

```shell
git clone https://github.com/askdagger/askdagger_mnist.git
cd askdagger_mnist
```

Create a virtual environment:

```shell
uv venv --python 3.10
```

Source the virtual environment:

```shell
source .venv/bin/activate
```

Install the askdagger_mnist package:

```shell
uv pip install -e .
```
The main training script can be run as follows. In case you have a CUDA-enabled GPU, you can run:

```shell
python ./scripts/main.py --reps 2 --s_des 0.9
```

Otherwise, for CPU training, run:

```shell
python ./scripts/main.py --reps 2 --s_des 0.9 --accelerator cpu
```

To reproduce the experiments from the paper, run:

```shell
python ./scripts/main.py
```

To also reproduce the ablations from the paper, run:

```shell
python ./scripts/ablations.py
```

This will train LeNet model(s) interactively with SAG on the MNIST dataset.
The training procedure goes as follows.
At every time step, [batch_size] novel digit images are sampled from the MNIST dataset.
Then we perform inference with the LeNet model (the novice) on these images and quantify the model's uncertainty for each sample.
Using SAG, the threshold for gating is determined.
For every sample with an uncertainty level that exceeds this threshold, a ground truth label is queried.
Additionally, for samples with an uncertainty lower than the threshold, a ground-truth label is queried with a probability of [p_rand].
All samples for which a ground truth label is queried are added to the training dataset.
Finally, the model is updated with the training dataset every [update_every] steps.
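The steps above can be sketched as follows. This is illustrative pseudocode only: all callables (`sample_batch`, `predict_with_uncertainty`, `gating_threshold`, `update`, `oracle`) are hypothetical stand-ins, and the actual implementation lives in `scripts/main.py`.

```python
import random

def interactive_training_loop(novice, oracle, sample_batch,
                              predict_with_uncertainty, gating_threshold,
                              update, steps, p_rand=0.004, update_every=10):
    """Illustrative sketch of the interactive SAG training loop.

    Only the control flow mirrors the description above; every callable
    argument is a hypothetical stand-in for a real component.
    """
    dataset = []  # samples for which a ground-truth label was queried
    for t in range(steps):
        images = sample_batch()                        # novel MNIST digits
        _, uncertainties = predict_with_uncertainty(novice, images)
        tau = gating_threshold(dataset)                # SAG threshold for this step
        for img, u in zip(images, uncertainties):
            # query above the threshold; below it, query with probability p_rand
            if u > tau or random.random() < p_rand:
                dataset.append((img, oracle(img)))
        if (t + 1) % update_every == 0:
            novice = update(novice, dataset)           # retrain on queried data
    return novice, dataset
```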
Uncertainty quantification is performed through Monte-Carlo dropout with a dropout rate of 40% and 16 dropout evaluations.
This means there is an implicit ensemble of 16 stochastic forward passes, where the disagreement among the passes provides the uncertainty estimate for each sample.
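A minimal numerical sketch of this idea (not the repository's implementation, which builds on TorchUncertainty; `logits_fn` is a hypothetical stand-in for a network evaluated with dropout kept active):

```python
import numpy as np

def mc_dropout_uncertainty(logits_fn, x, n_eval=16, rng=None):
    """Illustrative MC-dropout uncertainty: run the stochastic forward pass
    `n_eval` times and use the entropy of the mean softmax as the estimate."""
    rng = np.random.default_rng(rng)
    probs = []
    for _ in range(n_eval):
        z = logits_fn(x, rng)            # one stochastic forward pass
        e = np.exp(z - z.max())          # numerically stable softmax
        probs.append(e / e.sum())
    mean_p = np.mean(probs, axis=0)
    # predictive entropy: high when the passes disagree or are individually unsure
    return float(-(mean_p * np.log(mean_p + 1e-12)).sum())
```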
Instead of training the models yourself, it is also possible to download the results data from the experiments in the paper:
```shell
python scripts/download_results.py
```

After training or downloading the results, you can plot the results as in the paper by running:

```shell
python ./scripts/plot.py
```

The resulting figure is saved at figures/mnist.pdf.
After downloading the results or performing the ablation experiments, you can create the ablation plots with:

```shell
python ./scripts/plot_reg_albation.py
```

and

```shell
python ./scripts/plot_prand_albation.py
```

This work uses code from the TorchUncertainty open-source project.
Original: https://github.com/torch-uncertainty/torch-uncertainty
License: Apache 2.0
Changes: Our main training script is adapted from this tutorial.
The data modules are modified to allow for interactive training with a subset of the MNIST dataset.