diff --git a/download_test.py b/download_test.py new file mode 100644 index 000000000..fc76544e5 --- /dev/null +++ b/download_test.py @@ -0,0 +1,23 @@ +import os + +import pooch +from pooch import Unzip + + +FILES = [] +FILES.append("https://dataverse.harvard.edu/api/access/datafile/2499178") + +base_path = os.path.join(os.path.expanduser("~"), "mne_data", "WEIBO", "MNE-weibo-2014") +if not os.path.isdir(base_path): + os.makedirs(base_path) + +print(f"Downloading to {base_path}") +pooch.retrieve( + FILES[0], + None, + "data0.zip", + base_path, + processor=Unzip(), + progressbar=True, +) +print("Download finished") diff --git a/moabb/datasets/braininvaders.py b/moabb/datasets/braininvaders.py index dc4a2204f..e0d283fc6 100644 --- a/moabb/datasets/braininvaders.py +++ b/moabb/datasets/braininvaders.py @@ -284,7 +284,7 @@ def _bi_data_path( # noqa: C901 zip_ref = z.ZipFile(path_zip, "r") zip_ref.extractall(path_folder) os.makedirs(osp.join(directory, f"Session{i + 1}")) - shutil.copy_tree(path_zip.strip(".zip"), directory) + shutil.copytree(path_zip.strip(".zip"), directory) shutil.rmtree(path_zip.strip(".zip")) # filter the data regarding the experimental conditions diff --git a/moabb/tests/test_datasets.py b/moabb/tests/test_datasets.py index 084c11c61..b43ebd72b 100644 --- a/moabb/tests/test_datasets.py +++ b/moabb/tests/test_datasets.py @@ -574,6 +574,33 @@ def test_epochs(self, data, dataset): assert all([a == b for a, b in zip(raw.annotations.description[:3], description)]) +class TestWeibo2014: + @pytest.mark.parametrize("subject", [1]) + def test_get_data(self, subject, dl_data): + if not dl_data: + pytest.skip("Skipping download test") + ds = db.Weibo2014() + data = ds.get_data(subjects=[subject]) + + # we should get a dict + assert isinstance(data, dict) + + # we get the right number of subject + assert len(data) == 1 + assert subject in data + + # right number of session + assert len(data[subject]) == 1 + assert "0" in data[subject] + + # right number of run + assert len(data[subject]["0"]) == 1 + assert "0" in data[subject]["0"] + + # We should get a raw array at the end + assert isinstance(data[subject]["0"]["0"], mne.io.BaseRaw) + + class TestBIDSDataset: @pytest.fixture(scope="class") def cached_dataset_root(self, tmpdir_factory):