diff --git a/.gitignore b/.gitignore
index e168bc2..296a331 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,7 +2,7 @@
 __pycache__/
 *.py[cod]
 *$py.class
-
+example-files
 # C extensions
 *.so
diff --git a/pyxdf/__init__.py b/pyxdf/__init__.py
index 50ecda7..81b6354 100644
--- a/pyxdf/__init__.py
+++ b/pyxdf/__init__.py
@@ -3,11 +3,15 @@
 #
 # License: BSD (2-clause)
 
-from pkg_resources import get_distribution, DistributionNotFound
 try:
-    __version__ = get_distribution(__name__).version
-except DistributionNotFound:  # package is not installed
+    from pkg_resources import get_distribution, DistributionNotFound
+
+    try:
+        __version__ = get_distribution(__name__).version
+    except DistributionNotFound:  # package is not installed
+        __version__ = None
+except ImportError:  # pkg_resources is not available
     __version__ = None
 
-from .pyxdf import load_xdf, resolve_streams, match_streaminfos
+from .pyxdf import load_xdf, resolve_streams, match_streaminfos, align_streams
 
-__all__ = [load_xdf, resolve_streams, match_streaminfos]
+__all__ = ["load_xdf", "resolve_streams", "match_streaminfos", "align_streams"]
diff --git a/pyxdf/align.py b/pyxdf/align.py
new file mode 100644
index 0000000..a9929ce
--- /dev/null
+++ b/pyxdf/align.py
@@ -0,0 +1,213 @@
+import numpy as np
+import warnings
+from collections import defaultdict, Counter
+
+
+def _interpolate(
+    x: np.ndarray, y: np.ndarray, new_x: np.ndarray, kind="linear"
+) -> np.ndarray:
+    """Perform interpolation for align_streams.
+
+    If scipy is not installed, the method falls back to numpy, which only
+    supports linear interpolation. With scipy, `kind` can be 'linear',
+    'nearest', 'nearest-up', 'zero', 'slinear', 'quadratic', 'cubic',
+    'previous', or 'next'.
+    """
+    try:
+        from scipy.interpolate import interp1d
+
+        f = interp1d(
+            x,
+            y,
+            kind=kind,
+            axis=0,
+            assume_sorted=True,  # speed up
+            bounds_error=False,
+        )
+        return f(new_x)
+    except ImportError as e:
+        if kind != "linear":
+            raise e
+        else:
+            # np.interp only handles 1-D data, so the numpy fallback is
+            # limited to single-channel timeseries
+            return np.interp(new_x, xp=x, fp=y, left=np.nan, right=np.nan)
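+
+
+# Illustrative sketch (not executed): _interpolate is meant to be plugged into
+# align_streams via its align_foo argument, replacing the default
+# sample-shifting (_shift_align) for a given stream_id, e.g.
+#
+#     align_foo={3: lambda x, y, new_x: _interpolate(x, y, new_x, "cubic")}
+#
+# The stream_id 3 is made up for illustration; any kind other than "linear"
+# requires scipy (see test_data.py for a working call).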
+
+
+def _shift_align(old_timestamps, old_timeseries, new_timestamps):
+    """Assign each new timestamp the sample of its closest old timestamp."""
+    # Convert inputs to numpy arrays
+    old_timestamps = np.array(old_timestamps)
+    old_timeseries = np.array(old_timeseries)
+    new_timestamps = np.array(new_timestamps)
+
+    ts_last = old_timestamps[-1]
+    ts_first = old_timestamps[0]
+
+    # Initialize variables
+    source = []
+    target = []
+
+    new_timeseries = np.full(
+        (new_timestamps.shape[0], old_timeseries.shape[1]), np.nan
+    )
+
+    too_old = []
+    too_young = []
+
+    # Loop through new timestamps to find the closest old timestamp.
+    # Timestamps outside of the segment (too old or too young) are handled
+    # differently from timestamps within the segment.
+    for nix, nts in enumerate(new_timestamps):
+        if nts > ts_last:
+            too_young.append((nix, nts))
+        elif nts < ts_first:
+            too_old.append((nix, nts))
+        else:
+            closest = np.abs(old_timestamps - nts).argmin()
+            if closest not in source:  # Ensure unique mapping
+                source.append(closest)
+                target.append(nix)
+            else:
+                raise RuntimeError(
+                    f"Non-unique mapping. Closest old timestamp for "
+                    f"{new_timestamps[nix]} is {old_timestamps[closest]}, but "
+                    f"that one was already assigned to "
+                    f"{new_timestamps[target[source.index(closest)]]}"
+                )
+
+    # Handle too old timestamps (those before the first old timestamp)
+    for nix, nts in too_old:
+        closest = 0  # Assign to the first timestamp
+        if closest not in source:  # Ensure unique mapping
+            source.append(closest)
+            target.append(nix)
+            break  # only one, because we only need the edge
+
+    # Handle too young timestamps (those after the last old timestamp)
+    for nix, nts in too_young:
+        closest = len(old_timestamps) - 1  # Assign to the last timestamp
+        if closest not in source:  # Ensure unique mapping
+            source.append(closest)
+            target.append(nix)
+            break  # only one, because we only need the edge
+
+    # Sanity check: every old timestamp should be assigned to at least one
+    # new timestamp
+    missed = len(old_timestamps) - len(set(source))
+    if missed > 0:
+        unassigned_old = [i for i in range(len(old_timestamps)) if i not in source]
+        raise RuntimeError(
+            f"Too few new timestamps. {missed} old timestamps "
+            f"({unassigned_old}: {old_timestamps[unassigned_old]}) found no "
+            "corresponding new timestamp because all were already taken by "
+            "other old timestamps. If your stream has multiple segments, this "
+            "might be caused by small differences in effective srate between "
+            "segments. Try different dejittering thresholds or supply your "
+            "own aligned_timestamps."
+        )
+
+    # Populate new timeseries with aligned values from old_timeseries
+    for chan in range(old_timeseries.shape[1]):
+        new_timeseries[target, chan] = old_timeseries[source, chan]
+
+    return new_timeseries
+
+
+def align_streams(
+    streams,  # List[defaultdict]
+    align_foo=dict(),  # Dict[int, Callable]
+    aligned_timestamps=None,  # Optional[List[float]]
+    sampling_rate=None,  # Optional[float | int]
+):  # -> Tuple[np.ndarray, List[float]]
+    """Align multiple streams onto a common set of timestamps.
+
+    Args:
+
+        streams: a list of defaultdicts (i.e. streams) as returned by
+            load_xdf.
+        align_foo: a dictionary mapping stream IDs (i.e. int) to interpolation
+            callables. These callables must have the signature
+            `interpolate(old_timestamps, old_timeseries, new_timestamps)` and
+            return a np.ndarray. See `_shift_align` and `_interpolate` for
+            examples.
+        aligned_timestamps (optional): a list of floats with the new
+            timestamps to be used for alignment/interpolation. This list of
+            timestamps can be irregular and have gaps.
+        sampling_rate (optional): a float defining the sampling rate which
+            will be used to calculate aligned_timestamps.
+
+    Returns:
+        (aligned_timeseries, aligned_timestamps): tuple
+
+    The user can define either aligned_timestamps or sampling_rate, or
+    neither. If neither is defined, the algorithm takes the sampling rate of
+    the fastest stream and creates aligned_timestamps from the oldest sample
+    of all streams to the youngest.
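+
+    Example (illustrative sketch; assumes `streams` was returned by load_xdf):
+
+        # default: shift-align onto the grid of the fastest stream
+        series, stamps = align_streams(streams)
+
+        # or force a particular sampling rate for the common grid
+        series, stamps = align_streams(streams, sampling_rate=10.0)
+
+        # or interpolate stream 0 cubically instead of shifting its samples
+        from pyxdf.align import _interpolate
+        series, stamps = align_streams(
+            streams,
+            align_foo={0: lambda x, y, new_x: _interpolate(x, y, new_x, "cubic")},
+        )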
+    """
+
+    if sampling_rate is not None and aligned_timestamps is not None:
+        raise ValueError(
+            "You can not specify aligned_timestamps and sampling_rate at the "
+            "same time"
+        )
+
+    if sampling_rate is None:
+        # we pick the effective sampling rate from the fastest stream
+        srates = [stream["info"]["effective_srate"] for stream in streams]
+        sampling_rate = max(srates, default=0)
+        if sampling_rate <= 0:  # either no valid stream or all streams are async
+            warnings.warn(
+                "Can not align streams: fastest effective sampling rate "
+                "was 0 or smaller."
+            )
+            return streams
+
+    if aligned_timestamps is None:
+        # we pick the oldest and youngest timestamp of all streams
+        stamps = [stream["time_stamps"] for stream in streams]
+        ts_first = min((min(s) for s in stamps))
+        ts_last = max((max(s) for s in stamps))
+        # add one sample to include the last sample (see _jitter_removal)
+        full_dur = ts_last - ts_first + (1 / sampling_rate)
+        # We create new regularized timestamps with np.linspace: it gives
+        # precise control over the number of points and guarantees that the
+        # stop value is included. np.arange would give direct control over
+        # the step size, but may exclude the stop value and accumulate
+        # floating-point errors.
+        # arange implementation:
+        #   step = 1 / sampling_rate
+        #   aligned_timestamps = np.arange(ts_first, ts_last + step / 2, step)
+        # linspace implementation:
+        # add 1 to the number of samples to include the last sample
+        n_samples = int(np.round(full_dur * sampling_rate, 0)) + 1
+        aligned_timestamps = np.linspace(ts_first, ts_last, n_samples)
+
+    channels = 0
+    for stream in streams:
+        channels += int(stream["info"]["channel_count"][0])
+    # https://stackoverflow.com/questions/1704823/create-numpy-matrix-filled-with-nans
+    # The timings there show a preference for ndarray.fill(..) as the faster
+    # alternative.
+    aligned_timeseries = np.empty(
+        (
+            len(aligned_timestamps),
+            channels,
+        ),
+        dtype=object,
+    )
+    aligned_timeseries.fill(np.nan)
+
+    chan_start = 0
+    chan_end = 0
+    for stream in streams:
+        sid = stream["info"]["stream_id"]
+        align = align_foo.get(sid, _shift_align)
+        chan_cnt = int(stream["info"]["channel_count"][0])
+        new_timeseries = np.empty((len(aligned_timestamps), chan_cnt), dtype=object)
+        new_timeseries.fill(np.nan)
+        print("Stream #", sid, " has ", len(stream["info"]["segments"]), "segments")
+        for seg_idx, (seg_start, seg_stop) in enumerate(stream["info"]["segments"]):
+            print(seg_idx, ": from index ", seg_start, "to ", seg_stop + 1)
+            # segments have been created including the stop index, so we need
+            # to add 1 to include the last sample when slicing
+            segment_old_timestamps = stream["time_stamps"][seg_start : seg_stop + 1]
+            segment_old_timeseries = stream["time_series"][seg_start : seg_stop + 1]
+            # Sanity check for duplicate timestamps
+            if len(np.unique(segment_old_timestamps)) != len(segment_old_timestamps):
+                raise RuntimeError("Duplicate timestamps found in old_timestamps")
+            # apply align function as defined by the user (or default)
+            segment_new_timeseries = align(
+                segment_old_timestamps,
+                segment_old_timeseries,
+                aligned_timestamps,
+            )
+            # pick indices of the NEW timestamps closest to when the segment
+            # starts and stops
+            a = stream["time_stamps"][seg_start]
+            b = stream["time_stamps"][seg_stop]
+            aix = np.argmin(np.abs(aligned_timestamps - a))
+            bix = np.argmin(np.abs(aligned_timestamps - b))
+            # and store only this aligned segment, leaving the rest as nans
+            # (or as aligned from other segments)
+            new_timeseries[aix : bix + 1] = segment_new_timeseries[aix : bix + 1]
+
+        # store the new timeseries at the respective channel indices in the
+        # 2D array
+        chan_start = chan_end
+        chan_end += chan_cnt
+        aligned_timeseries[:, chan_start:chan_end] = new_timeseries
+    return aligned_timeseries, aligned_timestamps
diff --git a/pyxdf/pyxdf.py b/pyxdf/pyxdf.py
index e1f16d5..76e923d 100644
--- a/pyxdf/pyxdf.py
+++ b/pyxdf/pyxdf.py
@@ -17,15 +17,12 @@
 from collections import OrderedDict, defaultdict
 import logging
 from pathlib import Path
-
 import numpy as np
-
-
-__all__ = ["load_xdf"]
+from pyxdf.align import align_streams
 
+__all__ = ["load_xdf", "align_streams"]
 logger = logging.getLogger(__name__)
 
-
 class StreamData:
     """Temporary per-stream data."""
@@ -74,15 +71,15 @@ def load_xdf(
     synchronize_clocks=True,
     handle_clock_resets=True,
     dejitter_timestamps=True,
-    jitter_break_threshold_seconds=1,
+    jitter_break_threshold_seconds=1.0,
     jitter_break_threshold_samples=500,
     clock_reset_threshold_seconds=5,
     clock_reset_threshold_stds=5,
     clock_reset_threshold_offset_seconds=1,
     clock_reset_threshold_offset_stds=10,
     winsor_threshold=0.0001,
-    verbose=None
-):
+    verbose=None,
+):
     """Import an XDF file.
@@ -120,8 +117,8 @@ def load_xdf(
       ClockOffset chunks. (default: true)
     dejitter_timestamps : Whether to perform jitter removal for regularly
-      sampled streams. (default: true)
-
+      sampled streams. (default: true)
+
     on_chunk : Function that is called for each chunk of data as it is being
       retrieved from the file; the function is allowed to modify the data
       (for example, sub-sample it). The four input arguments
@@ -372,6 +369,10 @@ def load_xdf(
             )
 
     # perform jitter removal if requested
+    for stream in temp.values():
+        # initialize the segment list in case jitter removal was not selected
+        stream.segments = [(0, len(stream.time_series) - 1)]  # inclusive
+
     if dejitter_timestamps:
         logger.info("  performing jitter removal...")
         temp = _jitter_removal(
@@ -386,6 +387,8 @@ def load_xdf(
                 stream.effective_srate = len(stream.time_stamps) / duration
             else:
                 stream.effective_srate = 0.0
+
+
     for k in streams.keys():
         stream = streams[k]
@@ -399,11 +402,11 @@ def load_xdf(
             )
         stream["info"]["stream_id"] = k
         stream["info"]["effective_srate"] = tmp.effective_srate
+        stream["info"]["segments"] = tmp.segments
         stream["time_series"] = tmp.time_series
         stream["time_stamps"] = tmp.time_stamps
-    streams = [s for s in streams.values()]
-    return streams, fileheader
+    return list(streams.values()), fileheader
 
 
 def open_xdf(file):
@@ -501,7 +504,7 @@ def _xml2dict(t):
 
 def _scan_forward(f):
     """Scan forward through file object until after the next boundary chunk."""
-    blocklen = 2 ** 20
+    blocklen = 2**20
     signature = bytes(
         [
             0x43,
@@ -645,21 +648,25 @@ def _clock_sync(
 def _jitter_removal(streams, threshold_seconds=1, threshold_samples=500):
     for stream_id, stream in streams.items():
         stream.effective_srate = 0  # will be recalculated if possible
-        nsamples = len(stream.time_stamps)
+        stream.segments = []
+        nsamples = len(stream.time_stamps)
         if nsamples > 0 and stream.srate > 0:
             # Identify breaks in the time_stamps
             diffs = np.diff(stream.time_stamps)
-            b_breaks = diffs > np.max(
+            threshold = np.max(
                 (threshold_seconds, threshold_samples * stream.tdiff)
             )
+            b_breaks = diffs > threshold
             # find indices (+ 1 to compensate for lost sample in np.diff)
-            break_inds = np.where(b_breaks)[0] + 1
-
+            break_inds = np.where(b_breaks)[0] + 1
             # Get indices delimiting segments without breaks
             # 0th sample is a segment start and last sample is a segment stop
             seg_starts = np.hstack(([0], break_inds))
             seg_stops = np.hstack((break_inds - 1, nsamples - 1))  # inclusive
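+            # remember the break-free segments as inclusive (start, stop)
+            # index pairs; align_streams later aligns each segment separately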
+            for a, b in zip(seg_starts, seg_stops):
+                stream.segments.append((a, b))
+            stream.seg_stops = seg_stops.tolist()
             # Process each segment separately
             for start_ix, stop_ix in zip(seg_starts, seg_stops):
                 # Calculate time stamps assuming constant intervals within each
@@ -679,15 +686,15 @@ def _jitter_removal(streams, threshold_seconds=1, threshold_samples=500):
                 stream.time_stamps[seg_stops] + stream.tdiff
             ) - stream.time_stamps[seg_starts]
             stream.effective_srate = np.sum(counts) / np.sum(durations)
-
+        else:
+            stream.segments = [(0, nsamples - 1)]
         srate, effective_srate = stream.srate, stream.effective_srate
         if srate != 0 and np.abs(srate - effective_srate) / srate > 0.1:
             msg = (
                 "Stream %d: Calculated effective sampling rate %.4f Hz is"
                 " different from specified rate %.4f Hz."
             )
-            logger.warning(msg, stream_id, effective_srate, srate)
-
+            logger.warning(msg, stream_id, effective_srate, srate)
     return streams
@@ -859,3 +866,4 @@ def _read_chunks(f):
 def _parse_streamheader(xml):
     """Parse stream header XML."""
     return {el.tag: el.text for el in xml if el.tag != "desc"}
+
diff --git a/pyxdf/test/conftest.py b/pyxdf/test/conftest.py
new file mode 100644
index 0000000..f5f7d7b
--- /dev/null
+++ b/pyxdf/test/conftest.py
@@ -0,0 +1,16 @@
+from pytest import fixture
+
+
+@fixture(scope="session")
+def example_files():
+    from pathlib import Path
+
+    # requires git clone https://github.com/xdf-modules/example-files.git
+    # into the root xdf-python folder
+    path = Path("example-files")
+    extensions = ["*.xdf", "*.xdfz", "*.xdf.gz"]
+    files = []
+    for ext in extensions:
+        files.extend(path.glob(ext))
+    files = [str(file) for file in files]
+    yield files
diff --git a/pyxdf/test/test_data.py b/pyxdf/test/test_data.py
index b8005d0..efd939e 100644
--- a/pyxdf/test/test_data.py
+++ b/pyxdf/test/test_data.py
@@ -1,5 +1,5 @@
 from pathlib import Path
-from pyxdf import load_xdf
+from pyxdf import load_xdf, align_streams
 
 import pytest
 import numpy as np
@@ -29,16 +29,21 @@ def test_load_file(file):
     assert streams[0]["info"]["channel_format"][0] == "int16"
     assert streams[0]["info"]["stream_id"] == 0
 
-    s = np.array([[192, 255, 238],
-                  [12, 22, 32],
-                  [13, 23, 33],
-                  [14, 24, 34],
-                  [15, 25, 35],
-                  [12, 22, 32],
-                  [13, 23, 33],
-                  [14, 24, 34],
-                  [15, 25, 35]], dtype=np.int16)
-    t = np.array([5., 5.1, 5.2, 5.3, 5.4, 5.5, 5.6, 5.7, 5.8])
+    s = np.array(
+        [
+            [192, 255, 238],
+            [12, 22, 32],
+            [13, 23, 33],
+            [14, 24, 34],
+            [15, 25, 35],
+            [12, 22, 32],
+            [13, 23, 33],
+            [14, 24, 34],
+            [15, 25, 35],
+        ],
+        dtype=np.int16,
+    )
+    t = np.array([5.0, 5.1, 5.2, 5.3, 5.4, 5.5, 5.6, 5.7, 5.8])
     np.testing.assert_array_equal(streams[0]["time_series"], s)
     np.testing.assert_array_almost_equal(streams[0]["time_stamps"], t)
@@ -49,20 +54,73 @@ def test_load_file(file):
     assert streams[1]["info"]["channel_format"][0] == "string"
     assert streams[1]["info"]["stream_id"] == 0x02C0FFEE
 
-    s = [['LabRecorder xdfwriter'
-          '5.1'
-          '5.99'
-          '-.01'
-          '-.02'
-          ''],
-         ['Hello'],
-         ['World'],
-         ['from'],
-         ['LSL'],
-         ['Hello'],
-         ['World'],
-         ['from'],
-         ['LSL']]
+    s = [
+        [
+            'LabRecorder xdfwriter'
+            "5.1"
+            "5.99"
+            "-.01"
+            "-.02"
+            ""
+        ],
+        ["Hello"],
+        ["World"],
+        ["from"],
+        ["LSL"],
+        ["Hello"],
+        ["World"],
+        ["from"],
+        ["LSL"],
+    ]
     t = np.array([5.1, 5.2, 5.3, 5.4, 5.5, 5.6, 5.7, 5.8, 5.9])
     assert streams[1]["time_series"] == s
     np.testing.assert_array_almost_equal(streams[1]["time_stamps"], t)
+
+
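+# The smoke tests below only exercise align_streams on the minimal.xdf example
+# file and print the aligned output for manual inspection; they do not assert
+# specific values.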
+def test_smoketest_sync_unsegmented(example_files):
+    for file in example_files:
+        streams, header = load_xdf(file)
+        if file.endswith("minimal.xdf"):
+            print("unsegmented")
+            a_series, a_stamps = align_streams(streams)
+            for d, s in zip(a_series, a_stamps):
+                print(f"{s:5.3f} : {d}")
+
+
+def test_smoketest_sync_forced(example_files):
+    for file in example_files:
+        streams, header = load_xdf(file)
+        if file.endswith("minimal.xdf"):
+            print("forced_stamps")
+            a_series, a_stamps = align_streams(
+                streams, aligned_timestamps=np.arange(5.001, 5.92, 0.1)
+            )
+            for d, s in zip(a_series, a_stamps):
+                print(f"{s:5.3f} : {d}")
+
+            print("forced_rate")
+            a_series, a_stamps = align_streams(streams, sampling_rate=10.15)
+            for d, s in zip(a_series, a_stamps):
+                print(f"{s:5.3f} : {d}")
+
+            from pyxdf.align import _interpolate
+
+            print("forced cubic")
+            a_series, a_stamps = align_streams(
+                streams,
+                align_foo={0: lambda x, y, xh: _interpolate(x, y, xh, "cubic")},
+            )
+            for d, s in zip(a_series, a_stamps):
+                print(f"{s:5.3f} : {d}")
+
+
+def test_smoketest_sync_with_gaps(example_files):
+    for file in example_files:
+        streams, header = load_xdf(
+            file,
+            jitter_break_threshold_seconds=0.001,
+            jitter_break_threshold_samples=1,
+        )
+        if file.endswith("minimal.xdf"):
+            a_series, a_stamps = align_streams(streams)
+            print("segmented")
+            for d, s in zip(a_series, a_stamps):
+                print(f"{s:5.3f} : {d}")
diff --git a/pyxdf/test/test_shift_align.py b/pyxdf/test/test_shift_align.py
new file mode 100644
index 0000000..ac7007c
--- /dev/null
+++ b/pyxdf/test/test_shift_align.py
@@ -0,0 +1,58 @@
+from pyxdf.align import _shift_align
+import numpy as np
+import pytest
+
+old_timestamps = np.linspace(1.0, 1.5, 51)
+old_timeseries = np.empty((51, 1))
+old_timeseries[:, 0] = np.linspace(0, 50, 51)
+
+
+def test_shift_align_too_few_new_stamps():
+    # not all old samples were assigned
+    new_timestamps = np.linspace(1.001, 1.5001, 50)
+    with pytest.raises(RuntimeError):
+        new_timeseries = _shift_align(old_timestamps, old_timeseries, new_timestamps)
+
+
+def test_shift_align_slightly_later():
+    print("\n==================")
+    new_timestamps = np.arange(1.001, 1.5011, 0.01)
+    new_timeseries = _shift_align(old_timestamps, old_timeseries, new_timestamps)
+    for x, y, xhat, yhat in zip(
+        old_timestamps, old_timeseries, new_timestamps, new_timeseries
+    ):
+        print(f"{x:3.4f} -> {xhat:3.4f} = {y[0]:3.0f} / {yhat[0]:3.0f} ")
+        assert y == yhat
+
+
+def test_shift_align_slightly_earlier():
+    print("\n==================")
+    new_timestamps = np.arange(0.999, 1.499, 0.01)
+    new_timeseries = _shift_align(old_timestamps, old_timeseries, new_timestamps)
+    for x, y, xhat, yhat in zip(
+        old_timestamps, old_timeseries, new_timestamps, new_timeseries
+    ):
+        print(f"{x:3.4f} -> {xhat:3.4f} = {y[0]:3.0f} / {yhat[0]:3.0f} ")
+        assert y == yhat
+
+
+def test_shift_align_jittered():
+    print("\n==================")
+    jittered_timestamps = np.random.randn(*old_timestamps.shape) * 0.0005
+    jittered_timestamps += old_timestamps
+    new_timeseries = _shift_align(jittered_timestamps, old_timeseries, old_timestamps)
+    for x, y, xhat, yhat in zip(
+        jittered_timestamps, old_timeseries, old_timestamps, new_timeseries
+    ):
+        print(f"{x:3.4f} -> {xhat:3.4f} = {y[0]:3.0f} / {yhat[0]:3.0f} ")
+        assert y == yhat
+
+
+def test_shift_align_edges():
+    print("\n==================")
+    new_timestamps = np.arange(0.5, 1.7, 0.01)
+    # idx = np.argwhere(old_timestamps < 1.2)[-1][0] + 1
+    idx = len(old_timestamps)
+    new_timeseries = _shift_align(old_timestamps[:idx], old_timeseries, new_timestamps)
+    for xhat, yhat in zip(new_timestamps, new_timeseries):
+        delta = xhat - old_timestamps
+        idx_old = np.argmin(np.abs(delta))
+        x = old_timestamps[idx_old]
+        if xhat < 1.00 or xhat > 1.5001:
+            y = np.nan
+        else:
+            y = old_timeseries[idx_old, 0]
+        print(f"{x:3.4f} -> {xhat:3.4f} = {y:3.0f} / {yhat[0]:3.0f} ")
+        if np.isnan(y):
+            assert np.isnan(yhat[0])
+        else:
+            assert y == yhat[0]
\ No newline at end of file
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..b0897b4
--- /dev/null
+++ b/test.py
@@ -0,0 +1,15 @@
+import matplotlib.pyplot as plt
+import pyxdf
+
+if __name__ == "__main__":
+    fname = "/home/rtgugg/Downloads/sub-13_ses-S001_task-HCT_run-001_eeg.xdf"
+    # streams, header = pyxdf.load_xdf(
+    #     fname, select_streams=[2, 5]
+    # )  # EEG and ACC streams
+
+    # pyxdf.align_streams(streams)
+
+    streams, header = pyxdf.load_xdf(fname, select_streams=[2])  # EEG stream
+    plt.plot(streams[0]["time_stamps"])
+    plt.show()
+    pyxdf.align_streams(streams)