Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ on:
pull_request:
branches: [ "master" ]

env:
BUILD_ENV: ci

permissions:
contents: read

Expand Down
14 changes: 13 additions & 1 deletion src/meegsim/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ def add_patch_sources(
waveform_params=dict(),
snr_params=dict(),
extents=None,
subject=None,
subjects_dir=None,
names=None,
):
"""
Expand Down Expand Up @@ -227,6 +229,14 @@ def add_patch_sources(
grown (using :func:`mne.grow_labels`) from vertices specified in
location according to the provided values of extent. If a single number
is provided, all patch sources have the same extent.
subject : str, optional
Subject name, only used when growing patch sources from the central vertex.
If None (default), it is derived from the ``src`` object provided when
initializing the simulator.
subject_dir : str, optional
Path to the directory with FreeSurfer output, only used when growing patch
sources from the central vertex. If None (default), it is resolved automatically
by MNE-Python.
names : list, optional
A list of names for each source. If not specified, the names will be
autogenerated using the format 'auto-sgN-sM', where N is the index
Expand All @@ -247,8 +257,10 @@ def add_patch_sources(
std=std,
location_params=location_params,
waveform_params=waveform_params,
extents=extents,
snr_params=snr_params,
extents=extents,
subject=subject,
subjects_dir=subjects_dir,
names=names,
group=f"sg{next_group_idx}",
existing=self._sources,
Expand Down
35 changes: 33 additions & 2 deletions src/meegsim/source_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,17 @@ def create(

class PatchSourceGroup(_BaseSourceGroup):
def __init__(
self, n_sources, location, waveform, snr, snr_params, std, extents, names
self,
n_sources,
location,
waveform,
snr,
snr_params,
std,
extents,
subject,
subjects_dir,
names,
):
super().__init__()

Expand All @@ -178,6 +188,8 @@ def __init__(
self.std = std
self.names = names
self.extents = extents
self.subject = subject
self.subjects_dir = subjects_dir

def __repr__(self):
location_desc = "list"
Expand All @@ -204,6 +216,8 @@ def simulate(self, src, times, random_state=None):
self.std,
self.names,
self.extents,
self.subject,
self.subjects_dir,
random_state=random_state,
)

Expand All @@ -219,6 +233,8 @@ def create(
waveform_params,
snr_params,
extents,
subject,
subjects_dir,
names,
group,
existing,
Expand Down Expand Up @@ -247,6 +263,10 @@ def create(
Additional parameters for the adjustment of SNR.
extents: list
Extents (radius in mm) of each patch provided by the user.
subject: str, optional
Subject name.
subject_dir: str, optional
Path to the directory with FreeSurfer output.
names:
The names of sources provided by the user.
group:
Expand Down Expand Up @@ -280,4 +300,15 @@ def create(
else:
check_names(names, n_sources, existing)

return cls(n_sources, location, waveform, snr, snr_params, std, extents, names)
return cls(
n_sources,
location,
waveform,
snr,
snr_params,
std,
extents,
subject,
subjects_dir,
names,
)
15 changes: 12 additions & 3 deletions src/meegsim/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,8 @@ def _create(
stds,
names,
extents,
subject,
subjects_dir,
random_state=None,
):
"""
Expand All @@ -323,8 +325,11 @@ def _create(
if data.shape[1] != len(times):
raise ValueError("The number of samples in waveform does not match")

# find patch vertices
subject = src[0].get("subject_his_id", None)
# Pick subject name from src if not provided explicitly
if subject is None:
subject = src[0].get("subject_his_id", None)

# Find patch vertices
patch_vertices = []
patch_stds = [] if isinstance(stds, mne.SourceEstimate) else stds
for isource, extent in enumerate(extents):
Expand All @@ -347,7 +352,11 @@ def _create(

# Grow the patch from center otherwise
patch = mne.grow_labels(
subject, vertno, extent, src_idx, subjects_dir=None
subject=subject,
seeds=vertno,
extents=extent,
hemis=src_idx,
subjects_dir=subjects_dir,
)[0]

# Prune vertices
Expand Down
2 changes: 2 additions & 0 deletions tests/test_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,8 @@ def test_simulate_std_adjustment():
snr_params=dict(),
std=[3],
extents=[None],
subject=None,
subjects_dir=None,
names=["patch"],
),
]
Expand Down
4 changes: 4 additions & 0 deletions tests/test_snr.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,8 @@ def test_adjust_snr_local_patch(adjust_snr_mock):
snr_params=dict(fmin=8, fmax=12),
std=1,
extents=None,
subject=None,
subjects_dir=None,
names=["s1"],
),
PatchSourceGroup(
Expand All @@ -253,6 +255,8 @@ def test_adjust_snr_local_patch(adjust_snr_mock):
snr_params=dict(),
std=1,
extents=None,
subject=None,
subjects_dir=None,
names=["s2"],
),
]
Expand Down
8 changes: 5 additions & 3 deletions tests/test_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,8 @@ def test_patchsource_create_with_extent():
stds=stds,
names=names,
extents=extents,
subject=None,
subjects_dir=None,
random_state=None,
)

Expand All @@ -369,7 +371,7 @@ def test_patchsource_create_with_extent():

# Verify that grow_labels was called once for the source with extent
mock_grow_labels.assert_called_once_with(
"meegsim", [2], 3, 0, subjects_dir=None
subject="meegsim", seeds=[2], extents=3, hemis=0, subjects_dir=None
)


Expand All @@ -391,13 +393,13 @@ def test_patchsource_create_std_sourceestimate(get_param_mock):

# Values are passed directly - the mock should not be used
sources = PatchSource._create(
src, times, n_sources, location, waveform, stds, names, extents
src, times, n_sources, location, waveform, stds, names, extents, None, None
)
get_param_mock.assert_not_called()

# Values are passed in stc - the mock should be called once per patch
sources = PatchSource._create(
src, times, n_sources, location, waveform, std_stc, names, extents
src, times, n_sources, location, waveform, std_stc, names, extents, None, None
)
assert get_param_mock.call_count == n_sources

Expand Down
77 changes: 77 additions & 0 deletions tests/test_subject.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""
Tests for ensuring that subject-specific info is processed and set correctly.
"""

import mne
import pytest

from mne.datasets import sample

from meegsim.location import select_random
from meegsim.simulate import SourceSimulator
from meegsim.waveform import narrowband_oscillation

from utils.misc import running_on_ci


def prepare_real_data():
data_path = sample.data_path() / "MEG" / "sample"
fwd_path = data_path / "sample_audvis-meg-eeg-oct-6-fwd.fif"
raw_path = data_path / "sample_audvis_raw.fif"
subjects_dir = sample.data_path() / "subjects"

# Load the prerequisites: fwd, src, and info
fwd = mne.read_forward_solution(fwd_path)
fwd = mne.convert_forward_solution(fwd, force_fixed=True)
raw = mne.io.read_raw(raw_path)
info = raw.info

# Pick EEG channels only
eeg_idx = mne.pick_types(info, eeg=True)
info_eeg = mne.pick_info(info, eeg_idx)
fwd_eeg = fwd.pick_channels(info_eeg.ch_names)

return fwd_eeg, info_eeg, subjects_dir


@pytest.mark.skipif(running_on_ci(), reason="Skip tests with real data on CI")
def test_grow_patch_source():
fwd, info, subjects_dir = prepare_real_data()
src = fwd["src"]

sfreq = 250
duration = 60
seed = 123

sim = SourceSimulator(src, snr_mode="local")

# Select some vertices randomly
sim.add_patch_sources(
location=select_random,
waveform=narrowband_oscillation,
location_params=dict(n=3),
waveform_params=dict(fmin=8, fmax=12),
extents=5,
subjects_dir=subjects_dir,
names=["s1", "s2", "s3"],
)

sc = sim.simulate(sfreq, duration, fwd=fwd, random_state=seed)
sc.to_raw(fwd, info)


@pytest.mark.skipif(running_on_ci(), reason="Skip tests with real data on CI")
def test_sourceconfiguration_plot():
fwd, _, subjects_dir = prepare_real_data()
src = fwd["src"]

sfreq = 250
duration = 60
seed = 123

sim = SourceSimulator(src, snr_mode="local")
sim.add_noise_sources(location=select_random, location_params=dict(n=10))

sc = sim.simulate(sfreq, duration, fwd=fwd, random_state=seed)
brain = sc.plot(subject="sample", subjects_dir=subjects_dir)
brain.close()
5 changes: 5 additions & 0 deletions tests/utils/misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import os


def running_on_ci():
return os.environ.get("BUILD_ENV", "local") == "ci"